Skip to content

Commit

Permalink
Iterated through a ton of versions of the image captioner
Browse files Browse the repository at this point in the history
  • Loading branch information
Daethyra committed Oct 8, 2023
1 parent 6594baa commit d20ea4f
Show file tree
Hide file tree
Showing 8 changed files with 1,014 additions and 13 deletions.
210 changes: 210 additions & 0 deletions HuggingFace/V8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
import os
import logging
import csv
import json
from datetime import datetime
from dotenv import load_dotenv
import asyncio
import torch
from PIL import Image, UnidentifiedImageError
from transformers import BlipProcessor, BlipForConditionalGeneration, PreTrainedModel

# Initialize logging at the beginning of the script
logging_level = os.getenv('LOGGING_LEVEL', 'INFO').upper()
logging.basicConfig(level=getattr(logging, logging_level, logging.INFO))

class ImageCaptioner:
"""
A class for generating captions for images using the BlipForConditionalGeneration model.
Attributes:
processor (BlipProcessor): Processor for image and text data.
model (BlipForConditionalGeneration): The captioning model.
is_initialized (bool): Flag indicating successful initialization.
caption_cache (dict): Cache for storing generated captions.
device (str): The device (CPU or GPU) on which the model will run.
"""

def __init__(self, model_name: str = "Salesforce/blip-image-captioning-base"):
"""
Initializes the ImageCaptioner with a specific model and additional features like caching and device selection.
Args:
model_name (str): The name of the model to be loaded.
"""
self.is_initialized = True
self.caption_cache = {}
self.device = "cuda" if torch.cuda.is_available() else "cpu"
try:
self.processor = BlipProcessor.from_pretrained(model_name)
self.model = BlipForConditionalGeneration.from_pretrained(model_name).to(self.device)
logging.info("Successfully loaded model and processor.")
except Exception as e:
logging.error(f"Failed to load model and processor: {e}")
self.is_initialized = False
raise

def load_image(self, image_path: str) -> Image.Image:
"""
Loads an image from a specified path and converts it to RGB format with enhanced error handling.
Args:
image_path (str): The path to the image file.
Returns:
PIL.Image.Image or None: The loaded image or None if loading failed.
"""
try:
return Image.open(image_path).convert('RGB')
except UnidentifiedImageError as e:
logging.error(f"Failed to load image: {e}")
return None

async def generate_caption(self, raw_image: Image.Image, text: str = None) -> str:
"""
Generates a caption for the given image asynchronously with added features like caching and device selection.
Args:
raw_image (Image.Image): The image for which to generate a caption.
text (str, optional): Optional text to condition the captioning.
Returns:
str or None: The generated caption or None if captioning failed.
"""
try:
# Check if this image has been processed before
cache_key = f"{id(raw_image)}_{text}"
if cache_key in self.caption_cache:
return self.caption_cache[cache_key]

inputs = self.processor(raw_image, text, return_tensors="pt").to(self.device) if text else self.processor(raw_image, return_tensors="pt").to(self.device)
out = self.model.generate(**inputs)
caption = self.processor.batch_decode(out, skip_special_tokens=True)[0]

# Store the generated caption in cache
self.caption_cache[cache_key] = caption

return caption
except Exception as e:
logging.error(f"Failed to generate caption: {e}")
return None

def save_to_csv(self, image_name: str, caption: str, file_name: str = None, csvfile=None):
"""
Saves the image name and the generated caption to a CSV file, supporting both file name and file object inputs.
Args:
image_name (str): The name of the image file.
caption (str): The generated caption.
file_name (str, optional): The name of the CSV file. Defaults to a timestamp-based name.
csvfile (file object, optional): The CSV file to write to. Takes precedence over file_name if provided.
"""
if csvfile is None:
if file_name is None:
file_name = f"captions_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
with open(file_name, 'a', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow([image_name, caption])
if csvfile is not None and file_name is not None:
csvfile.close()

class ConfigurationManager:
"""
A class for managing configuration settings for the ImageCaptioner.
Attributes:
config (dict): The configuration settings.
"""

def __init__(self):
"""
Initializes the ConfigurationManager and loads settings from a JSON file and environment variables.
"""
self.config = self.load_config()

def load_config(self) -> dict:
"""
Loads and validates configuration settings from a JSON file and environment variables.
Returns:
dict: The loaded and validated configuration settings.
"""
# Initialize with default values
config_updated = False
config = {
'IMAGE_FOLDER': 'images',
'BASE_NAME': 'your_image_name_here.jpg',
'ENDING_CAPTION': "AI generated Artwork by Daethyra using DallE"
}

# Try to load settings from configuration file
try:
with open('config.json', 'r') as f:
file_config = json.load(f)
config.update(file_config)
except FileNotFoundError:
logging.error("Configuration file config.json not found.")
except json.JSONDecodeError as e:
logging.error(f"Failed to parse configuration file: {e}")
except Exception as e:
logging.error(f"An unknown error occurred while loading the configuration file: {e}")

# Validate the loaded settings
self.validate_config(config)

# Fallback to environment variables and offer to update the JSON configuration
for key in config.keys():
env_value = os.getenv(key, None)
if env_value:
logging.info(f"Falling back to environment variable for {key}: {env_value}")
config[key] = env_value

# Offering to update the JSON configuration file with new settings
if config_updated:
try:
with open('config.json', 'w') as f:
json.dump(config, f, indent=4)
except Exception as e:
logging.error(f"Failed to update configuration file: {e}")

return config

def validate_config(self, config: dict):
"""
Validates the loaded configuration settings.
Args:
config (dict): The loaded configuration settings.
"""
if not config.get('IMAGE_FOLDER'):
logging.error("The IMAGE_FOLDER is missing or invalid.")

if not config.get('BASE_NAME'):
logging.error("The BASE_NAME is missing or invalid.")

if not config.get('ENDING_CAPTION'):
logging.error("The ENDING_CAPTION is missing or invalid.")

async def main():
load_dotenv()

# Initialize configuration manager
config_manager = ConfigurationManager()
config = config_manager.config

# Remaining logic for running the ImageCaptioner
image_path = os.path.join(config['IMAGE_FOLDER'], config['BASE_NAME'])
captioner = ImageCaptioner()
raw_image = captioner.load_image(image_path)
try:
if raw_image:
unconditional_caption = await captioner.generate_caption(raw_image)
captioner.save_to_csv(config['BASE_NAME'], unconditional_caption)

conditional_caption = await captioner.generate_caption(raw_image, config['ENDING_CAPTION'])
captioner.save_to_csv(config['BASE_NAME'], conditional_caption)
except Exception as e:
logging.error(f"An unexpected error occurred: {e}")

if __name__ == "__main__":
asyncio.run(main())
File renamed without changes.
Empty file added HuggingFace/v10.py
Empty file.
14 changes: 1 addition & 13 deletions HuggingFace/v4.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import os
import logging
import csv
Expand All @@ -10,7 +9,6 @@
from PIL import Image, UnidentifiedImageError
from transformers import BlipProcessor, BlipForConditionalGeneration, PreTrainedModel


class ImageCaptioner:
"""
A class for generating captions for images using the BlipForConditionalGeneration model.
Expand All @@ -23,8 +21,6 @@ class ImageCaptioner:
device (str): The device (CPU or GPU) on which the model will run.
"""



def __init__(self, model_name: str = "Salesforce/blip-image-captioning-base"):
"""
Initializes the ImageCaptioner with a specific model and additional features like caching and device selection.
Expand All @@ -47,8 +43,6 @@ def __init__(self, model_name: str = "Salesforce/blip-image-captioning-base"):
logging_level = os.getenv('LOGGING_LEVEL', 'INFO').upper()
logging.basicConfig(level=getattr(logging, logging_level, logging.INFO))



def load_image(self, image_path: str) -> Image.Image:
"""
Loads an image from a specified path and converts it to RGB format with enhanced error handling.
Expand All @@ -65,8 +59,6 @@ def load_image(self, image_path: str) -> Image.Image:
logging.error(f"Failed to load image: {e}")
return None



async def generate_caption(self, raw_image: Image.Image, text: str = None) -> str:
"""
Generates a caption for the given image asynchronously with added features like caching and device selection.
Expand Down Expand Up @@ -96,8 +88,6 @@ async def generate_caption(self, raw_image: Image.Image, text: str = None) -> st
logging.error(f"Failed to generate caption: {e}")
return None



def save_to_csv(self, image_name: str, caption: str, file_name: str = None, csvfile=None):
"""
Saves the image name and the generated caption to a CSV file, supporting both file name and file object inputs.
Expand All @@ -119,10 +109,8 @@ def save_to_csv(self, image_name: str, caption: str, file_name: str = None, csvf
if csvfile is not None and file_name is not None:
csvfile.close()


# Remaining code for running the script can go here.

async def main():
# Instantiate environment variables
load_dotenv()

# Load settings from configuration file
Expand Down
Loading

0 comments on commit d20ea4f

Please sign in to comment.