Skip to content

Commit

Permalink
new file: HuggingFace/image_captioner.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Daethyra committed Oct 7, 2023
1 parent 911dc4e commit 38217f6
Showing 1 changed file with 87 additions and 0 deletions.
87 changes: 87 additions & 0 deletions HuggingFace/image_captioner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import os
import logging
import csv
from datetime import datetime
from dotenv import load_dotenv
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

class ImageCaptioner:
"""A class for generating captions for images using the BlipForConditionalGeneration model.
Attributes:
processor (BlipProcessor): Processor for image and text data.
model (BlipForConditionalGeneration): The captioning model.
"""

def __init__(self, model_name: str = "Salesforce/blip-image-captioning-base"):
"""Initializes the ImageCaptioner with a specific model.
Args:
model_name (str): The name of the model to be loaded.
"""
self.processor = BlipProcessor.from_pretrained(model_name)
self.model = BlipForConditionalGeneration.from_pretrained(model_name)
logging.basicConfig(level=logging.INFO)

def load_image(self, image_path: str):
"""Loads an image from a specified path and converts it to RGB format.
Args:
image_path (str): The path to the image file.
Returns:
PIL.Image.Image or None: The loaded image or None if loading failed.
"""
try:
return Image.open(image_path).convert('RGB')
except Exception as e:
logging.error(f"Failed to load image: {e}")
return None

def generate_caption(self, raw_image: Image.Image, text: str = None) -> str:
"""Generates a caption for the given image. An optional text can be provided to condition the captioning.
Args:
raw_image (Image.Image): The image for which to generate a caption.
text (str, optional): Optional text to condition the captioning.
Returns:
str or None: The generated caption or None if captioning failed.
"""
try:
inputs = self.processor(raw_image, text, return_tensors="pt") if text else self.processor(raw_image, return_tensors="pt")
out = self.model.generate(**inputs)
return self.processor.decode(out[0], skip_special_tokens=True)
except Exception as e:
logging.error(f"Failed to generate caption: {e}")
return None

def save_to_csv(self, image_name: str, caption: str, file_name: str = None):
"""Saves the image name and the generated caption to a CSV file.
Args:
image_name (str): The name of the image file.
caption (str): The generated caption.
file_name (str, optional): The name of the CSV file. Defaults to a timestamp-based name.
"""
if file_name is None:
file_name = f"captions_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
with open(file_name, 'a', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow([image_name, caption])

if __name__ == "__main__":
load_dotenv()
image_folder = os.getenv('IMAGE_FOLDER', 'images')
base_name = os.getenv('BASE_NAME', 'your_image_name_here.jpg') # Replace with the actual image name at runtime
ending_caption = os.getenv('ENDING_CAPTION', "AI generated Artwork by Daethyra using DallE") # Ending caption for conditional captioning

# The following lines are commented out for potential future use
# base_name = 'your_image_name_here.jpg'
# ending_caption = "AI generated Artwork by Daethyra using DallE"

image_path = os.path.join(image_folder, base_name)

captioner = ImageCaptioner()
raw_image = captioner.load_image(image_path)

if raw_image:
unconditional_caption = captioner.generate_caption(raw_image)
captioner.save_to_csv(base_name, unconditional_caption)

conditional_caption = captioner.generate_caption(raw_image, ending_caption)
captioner.save_to_csv(base_name, conditional_caption)

0 comments on commit 38217f6

Please sign in to comment.