Skip to content

Commit

Permalink
Organizing HF subdir and finalizing the integratable_captioner
Browse files Browse the repository at this point in the history
	renamed:    HuggingFace/v10.py -> HuggingFace/integrated_captioner.py/integratable_captioner.py
	renamed:    HuggingFace/V8.py -> HuggingFace/integrated_captioner.py/rough-drafts/V8.py
	renamed:    HuggingFace/v1.py -> HuggingFace/integrated_captioner.py/rough-drafts/v1.py
	new file:   HuggingFace/integrated_captioner.py/rough-drafts/v10.py
	renamed:    HuggingFace/v3.py -> HuggingFace/integrated_captioner.py/rough-drafts/v3.py
	renamed:    HuggingFace/v4.py -> HuggingFace/integrated_captioner.py/rough-drafts/v4.py
	renamed:    HuggingFace/v5.py -> HuggingFace/integrated_captioner.py/rough-drafts/v5.py
	renamed:    HuggingFace/v6.py -> HuggingFace/integrated_captioner.py/rough-drafts/v6.py
	renamed:    HuggingFace/v7.py -> HuggingFace/integrated_captioner.py/rough-drafts/v7.py
	renamed:    HuggingFace/v9.py -> HuggingFace/integrated_captioner.py/rough-drafts/v9.py
	new file:   HuggingFace/integrated_captioner.py/template-config.json
	new file:   HuggingFace/integrated_captioner.py/template.env
  • Loading branch information
Daethyra committed Oct 8, 2023
1 parent 09b1d84 commit e400d86
Show file tree
Hide file tree
Showing 12 changed files with 250 additions and 0 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
235 changes: 235 additions & 0 deletions HuggingFace/integrated_captioner.py/rough-drafts/v10.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
import os
import logging
import csv
import json
from datetime import datetime
from dotenv import load_dotenv
import asyncio
import torch
from PIL import Image, UnidentifiedImageError
from transformers import BlipProcessor, BlipForConditionalGeneration, PreTrainedModel

# Initialize logging at the beginning of the script
logging_level = os.getenv('LOGGING_LEVEL', 'INFO').upper()
logging.basicConfig(level=getattr(logging, logging_level, logging.INFO))

class ImageCaptioner:
"""
A class for generating captions for images using the BlipForConditionalGeneration model.
Attributes:
processor (BlipProcessor): Processor for image and text data.
model (BlipForConditionalGeneration): The captioning model.
is_initialized (bool): Flag indicating successful initialization.
caption_cache (dict): Cache for storing generated captions.
device (str): The device (CPU or GPU) on which the model will run.
"""

def __init__(self, model_name: str = "Salesforce/blip-image-captioning-base"):
"""
Initializes the ImageCaptioner with a specific model and additional features like caching and device selection.
Args:
model_name (str): The name of the model to be loaded.
"""
self.is_initialized = True
self.caption_cache = {}
self.device = "cuda" if torch.cuda.is_available() else "cpu"
try:
self.processor = BlipProcessor.from_pretrained(model_name)
self.model = BlipForConditionalGeneration.from_pretrained(model_name).to(self.device)
logging.info("Successfully loaded model and processor.")
except Exception as e:
logging.error(f"Failed to load model and processor: {e}")
self.is_initialized = False
raise

def load_image(self, image_path: str) -> Image.Image:
"""
Loads an image from a specified path and converts it to RGB format with enhanced error handling.
Args:
image_path (str): The path to the image file.
Returns:
PIL.Image.Image or None: The loaded image or None if loading failed.
"""
try:
return Image.open(image_path).convert('RGB')
except UnidentifiedImageError as e:
logging.error(f"Failed to load image: {e}")
return None

async def generate_caption(self, raw_image: Image.Image, text: str = None) -> str:
"""
Generates a caption for the given image asynchronously with added features like caching and device selection.
Args:
raw_image (Image.Image): The image for which to generate a caption.
text (str, optional): Optional text to condition the captioning.
Returns:
str or None: The generated caption or None if captioning failed.
"""
try:
# Check if this image has been processed before
cache_key = f"{id(raw_image)}_{text}"
if cache_key in self.caption_cache:
return self.caption_cache[cache_key]

inputs = self.processor(raw_image, text, return_tensors="pt").to(self.device) if text else self.processor(raw_image, return_tensors="pt").to(self.device)
out = self.model.generate(**inputs)
caption = self.processor.batch_decode(out, skip_special_tokens=True)[0]

# Store the generated caption in cache
self.caption_cache[cache_key] = caption

return caption
except Exception as e:
logging.error(f"Failed to generate caption: {e}")
return None

def save_to_csv(self, image_name: str, caption: str, file_name: str = None, csvfile=None):
"""
Saves the image name and the generated caption to a CSV file, supporting both file name and file object inputs.
Args:
image_name (str): The name of the image file.
caption (str): The generated caption.
file_name (str, optional): The name of the CSV file. Defaults to a timestamp-based name.
csvfile (file object, optional): The CSV file to write to. Takes precedence over file_name if provided.
"""
if csvfile is None:
if file_name is None:
file_name = f"captions_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
with open(file_name, 'a', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow([image_name, caption])
if csvfile is not None and file_name is not None:
csvfile.close()

class ConfigurationManager:
"""
A class for managing configuration settings for the ImageCaptioner.
Attributes:
config (dict): The configuration settings.
"""

def __init__(self):
"""
Initializes the ConfigurationManager and loads settings from a JSON file and environment variables.
"""
self.config = self.load_config()

def load_config(self) -> dict:
"""
Loads and validates configuration settings from a JSON file and environment variables.
Returns:
dict: The loaded and validated configuration settings.
"""
# Initialize with default values
config_updated = False
config = {
'IMAGE_FOLDER': 'images',
'BASE_NAME': 'your_image_name_here.jpg',
'ENDING_CAPTION': "AI generated Artwork by Daethyra using DallE"
}

# Try to load settings from configuration file
try:
with open('config.json', 'r') as f:
file_config = json.load(f)
config.update(file_config)
except FileNotFoundError:
logging.error("Configuration file config.json not found.")
except json.JSONDecodeError as e:
logging.error(f"Failed to parse configuration file: {e}")
except Exception as e:
logging.error(f"An unknown error occurred while loading the configuration file: {e}")

# Validate the loaded settings
self.validate_config(config)

# Fallback to environment variables and offer to update the JSON configuration
for key in config.keys():
env_value = os.getenv(key, None)
if env_value:
logging.info(f"Falling back to environment variable for {key}: {env_value}")
config[key] = env_value

# Offering to update the JSON configuration file with new settings
if config_updated:
try:
with open('config.json', 'w') as f:
json.dump(config, f, indent=4)
except Exception as e:
logging.error(f"Failed to update configuration file: {e}")

return config

def validate_config(self, config: dict):
"""
Validates the loaded configuration settings.
Args:
config (dict): The loaded configuration settings.
"""
if not config.get('IMAGE_FOLDER'):
logging.error("The IMAGE_FOLDER is missing or invalid.")

if not config.get('BASE_NAME'):
logging.error("The BASE_NAME is missing or invalid.")

if not config.get('ENDING_CAPTION'):
logging.error("The ENDING_CAPTION is missing or invalid.")

async def main() -> None:
"""
Asynchronous main function to initialize and run the image captioning pipeline.
This function performs the following tasks:
1. Load environment variables.
2. Initialize the configuration manager.
3. Initialize the ImageCaptioner.
4. List all image files in the configured directory.
5. Loop through each image file to generate and save both unconditional and conditional captions.
"""
# Load environment variables from a .env file
load_dotenv()

# Initialize the configuration manager to load and manage settings
config_manager = ConfigurationManager()
config = config_manager.config

# Initialize the ImageCaptioner with the specified model
captioner = ImageCaptioner()

# Get a list of all image files in the specified directory
image_files = list_image_files(config['IMAGE_FOLDER'])

# Default to using the conditional captioning logic
use_conditional_caption = config.get('USE_CONDITIONAL_CAPTION', True)

# Loop through each image file in the directory
for image_file in image_files:
raw_image = captioner.load_image(image_file)

try:
if raw_image:
# If the user has opted for conditional captions, generate and save them.
if use_conditional_caption:
caption = await captioner.generate_caption(raw_image, config['ENDING_CAPTION'])
else:
# Fallback to unconditional caption if the conditional caption is not selected.
caption = await captioner.generate_caption(raw_image)

# Save the chosen caption to a CSV file.
captioner.save_to_csv(os.path.basename(image_file), caption)

except Exception as e:
logging.error(f"An unexpected error occurred: {e}")

if __name__ == "__main__":
asyncio.run(main())
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
7 changes: 7 additions & 0 deletions HuggingFace/integrated_captioner.py/template-config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"IMAGE_FOLDER": "images",
"BASE_NAME": "your_image_name_here.jpg",
"ENDING_CAPTION": "AI generated Artwork by Daethyra using DallE",
"MODEL_NAME": "Salesforce/blip-image-captioning-base",
"USE_CONDITIONAL_CAPTION": true
}
8 changes: 8 additions & 0 deletions HuggingFace/integrated_captioner.py/template.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Logging level for the application
LOGGING_LEVEL=INFO

# Pretrained model name
MODEL_NAME=Salesforce/blip-image-captioning-base

# Whether to use conditional captioning logic
USE_CONDITIONAL_CAPTION=true

0 comments on commit e400d86

Please sign in to comment.