-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Organizing HF subdir and finalizing the integratable_captioner
renamed: HuggingFace/v10.py -> HuggingFace/integrated_captioner.py/integratable_captioner.py; renamed: HuggingFace/V8.py -> HuggingFace/integrated_captioner.py/rough-drafts/V8.py; renamed: HuggingFace/v1.py -> HuggingFace/integrated_captioner.py/rough-drafts/v1.py; new file: HuggingFace/integrated_captioner.py/rough-drafts/v10.py; renamed: HuggingFace/v3.py -> HuggingFace/integrated_captioner.py/rough-drafts/v3.py; renamed: HuggingFace/v4.py -> HuggingFace/integrated_captioner.py/rough-drafts/v4.py; renamed: HuggingFace/v5.py -> HuggingFace/integrated_captioner.py/rough-drafts/v5.py; renamed: HuggingFace/v6.py -> HuggingFace/integrated_captioner.py/rough-drafts/v6.py; renamed: HuggingFace/v7.py -> HuggingFace/integrated_captioner.py/rough-drafts/v7.py; renamed: HuggingFace/v9.py -> HuggingFace/integrated_captioner.py/rough-drafts/v9.py; new file: HuggingFace/integrated_captioner.py/template-config.json; new file: HuggingFace/integrated_captioner.py/template.env
- Loading branch information
Showing
12 changed files
with
250 additions
and
0 deletions.
There are no files selected for viewing
File renamed without changes.
File renamed without changes.
File renamed without changes.
235 changes: 235 additions & 0 deletions
235
HuggingFace/integrated_captioner.py/rough-drafts/v10.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,235 @@ | ||
import os | ||
import logging | ||
import csv | ||
import json | ||
from datetime import datetime | ||
from dotenv import load_dotenv | ||
import asyncio | ||
import torch | ||
from PIL import Image, UnidentifiedImageError | ||
from transformers import BlipProcessor, BlipForConditionalGeneration, PreTrainedModel | ||
|
||
# Configure root logging once, at import time. The level name comes from the
# LOGGING_LEVEL environment variable (default 'INFO'), upper-cased; a name that
# is not an attribute of the logging module falls back to logging.INFO.
logging_level = os.environ.get('LOGGING_LEVEL', 'INFO').upper()
logging.basicConfig(
    level=getattr(logging, logging_level, logging.INFO),
)
|
||
class ImageCaptioner:
    """
    Generate captions for images with a BlipForConditionalGeneration model.

    Attributes:
        processor (BlipProcessor): Processor for image and text data.
        model (BlipForConditionalGeneration): The captioning model.
        is_initialized (bool): True only after the model and processor loaded.
        caption_cache (dict): Generated captions keyed by image content and prompt text.
        device (str): "cuda" when torch reports CUDA as available, otherwise "cpu".
    """

    def __init__(self, model_name: str = "Salesforce/blip-image-captioning-base"):
        """
        Load the processor and model and move the model to the selected device.

        Args:
            model_name (str): The name of the model to be loaded.

        Raises:
            Exception: Whatever the loaders raise; it is logged and re-raised.
        """
        # Only flipped to True after both loads succeed.
        self.is_initialized = False
        self.caption_cache = {}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Chosen lazily by save_to_csv so all default-named rows share one file.
        self._default_csv_name = None
        try:
            self.processor = BlipProcessor.from_pretrained(model_name)
            self.model = BlipForConditionalGeneration.from_pretrained(model_name).to(self.device)
        except Exception as e:
            logging.error(f"Failed to load model and processor: {e}")
            raise
        self.is_initialized = True
        logging.info("Successfully loaded model and processor.")

    def load_image(self, image_path: str) -> "Image.Image | None":
        """
        Load an image from disk and convert it to RGB.

        Args:
            image_path (str): The path to the image file.

        Returns:
            PIL.Image.Image or None: The RGB image, or None if the file is
            missing, unreadable, or not a recognizable image.
        """
        try:
            # The context manager closes the underlying file once the
            # converted copy has been produced.
            with Image.open(image_path) as source:
                return source.convert('RGB')
        except (UnidentifiedImageError, OSError) as e:
            logging.error(f"Failed to load image: {e}")
            return None

    @staticmethod
    def _cache_key(raw_image, text):
        """Build a cache key from the image's pixel content and the prompt text."""
        import hashlib

        # id(raw_image) can be reused once an earlier image is garbage
        # collected, which would hand back another image's caption; hashing
        # the pixel data ties the key to the image content instead.
        digest = hashlib.sha256(raw_image.tobytes()).hexdigest()
        return (raw_image.mode, raw_image.size, digest, text)

    async def generate_caption(self, raw_image: "Image.Image", text: str = None) -> str:
        """
        Generate a caption for the given image, reusing a cached result when available.

        The model call inside this coroutine is synchronous; nothing is awaited.

        Args:
            raw_image (Image.Image): The image for which to generate a caption.
            text (str, optional): Optional text to condition the captioning.

        Returns:
            str or None: The generated caption or None if captioning failed.
        """
        try:
            cache_key = self._cache_key(raw_image, text)
            if cache_key in self.caption_cache:
                return self.caption_cache[cache_key]

            # A falsy text (None or "") selects unconditional captioning.
            if text:
                inputs = self.processor(raw_image, text, return_tensors="pt")
            else:
                inputs = self.processor(raw_image, return_tensors="pt")
            inputs = inputs.to(self.device)

            out = self.model.generate(**inputs)
            caption = self.processor.batch_decode(out, skip_special_tokens=True)[0]

            self.caption_cache[cache_key] = caption
            return caption
        except Exception as e:
            logging.error(f"Failed to generate caption: {e}")
            return None

    def save_to_csv(self, image_name: str, caption: str, file_name: str = None, csvfile=None):
        """
        Append one (image name, caption) row to a CSV file or an open file object.

        Args:
            image_name (str): The name of the image file.
            caption (str): The generated caption.
            file_name (str, optional): The name of the CSV file. Defaults to a
                timestamp-based name that is chosen on the first default-named
                call and reused by this instance afterwards.
            csvfile (file object, optional): The CSV file to write to. Takes
                precedence over file_name if provided; it is left open because
                the caller owns it.
        """
        row = [image_name, caption]

        if csvfile is not None:
            csv.writer(csvfile).writerow(row)
            return

        if file_name is None:
            # getattr keeps this safe on instances created without __init__.
            file_name = getattr(self, "_default_csv_name", None)
            if file_name is None:
                file_name = f"captions_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
                self._default_csv_name = file_name

        with open(file_name, 'a', newline='', encoding='utf-8') as handle:
            csv.writer(handle).writerow(row)
|
||
class ConfigurationManager:
    """
    Manage configuration settings for the ImageCaptioner.

    Settings are layered: built-in defaults, then the JSON configuration file,
    then environment variables with the same names as the setting keys.

    Attributes:
        config_path (str): Path of the JSON configuration file.
        config (dict): The configuration settings.
    """

    # Environment strings treated as True when overriding a boolean setting.
    _TRUE_STRINGS = frozenset({"1", "true", "yes", "on"})

    def __init__(self, config_path: str = 'config.json'):
        """
        Load settings from the JSON file and environment variables.

        Args:
            config_path (str): Path of the JSON configuration file.
        """
        self.config_path = config_path
        self.config = self.load_config()

    def load_config(self, config_path: str = None) -> dict:
        """
        Load and validate configuration settings.

        Args:
            config_path (str, optional): Path of the JSON file. Defaults to the
                path given to the constructor, or 'config.json'.

        Returns:
            dict: The merged settings. Nothing is written back to the file.
        """
        if config_path is None:
            config_path = getattr(self, 'config_path', 'config.json')

        # Defaults, used for any key the file and the environment do not set.
        config = {
            'IMAGE_FOLDER': 'images',
            'BASE_NAME': 'your_image_name_here.jpg',
            'ENDING_CAPTION': "AI generated Artwork by Daethyra using DallE",
            'MODEL_NAME': 'Salesforce/blip-image-captioning-base',
            'USE_CONDITIONAL_CAPTION': True,
        }

        # A missing or broken file is logged and the defaults stay in effect.
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                file_config = json.load(f)
            if isinstance(file_config, dict):
                config.update(file_config)
            else:
                logging.error("Failed to parse configuration file: top-level JSON value is not an object.")
        except FileNotFoundError:
            logging.error(f"Configuration file {config_path} not found.")
        except json.JSONDecodeError as e:
            logging.error(f"Failed to parse configuration file: {e}")
        except Exception as e:
            logging.error(f"An unknown error occurred while loading the configuration file: {e}")

        # Non-empty environment variables override the values gathered so far.
        for key in list(config):
            env_value = os.getenv(key)
            if not env_value:
                continue
            logging.info(f"Falling back to environment variable for {key}: {env_value}")
            if isinstance(config[key], bool):
                # Environment values are strings; "false" must not end up truthy.
                config[key] = env_value.strip().lower() in self._TRUE_STRINGS
            else:
                config[key] = env_value

        # Validate last so the check sees the values that will actually be used.
        self.validate_config(config)

        return config

    def validate_config(self, config: dict):
        """
        Log an error for each required setting that is missing or empty.

        Args:
            config (dict): The loaded configuration settings.
        """
        for key in ('IMAGE_FOLDER', 'BASE_NAME', 'ENDING_CAPTION'):
            if not config.get(key):
                logging.error(f"The {key} is missing or invalid.")
|
||
# File extensions (lower-case) treated as images when scanning the image folder.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tif', '.tiff')


def _list_image_files(folder: str) -> list:
    """Return the sorted paths of the image files directly inside folder ([] if it cannot be listed)."""
    try:
        names = sorted(os.listdir(folder))
    except OSError as e:
        logging.error(f"Could not list image folder {folder}: {e}")
        return []

    paths = (os.path.join(folder, name) for name in names if name.lower().endswith(_IMAGE_EXTENSIONS))
    return [path for path in paths if os.path.isfile(path)]


async def main() -> None:
    """
    Asynchronous main function to initialize and run the image captioning pipeline.

    This function performs the following tasks:
    1. Load environment variables and apply LOGGING_LEVEL to the root logger.
    2. Initialize the configuration manager.
    3. Initialize the ImageCaptioner, using MODEL_NAME from the configuration if set.
    4. List all image files in the configured directory.
    5. Generate one caption per image (conditional or unconditional) and append
       it to a single CSV file for this run.
    """
    # Load environment variables from a .env file
    load_dotenv()

    # Logging was configured at import time, before the .env file was read, so
    # re-apply the level now that LOGGING_LEVEL may have been loaded.
    level = getattr(logging, os.getenv('LOGGING_LEVEL', 'INFO').upper(), None)
    if isinstance(level, int):
        logging.getLogger().setLevel(level)

    # Initialize the configuration manager to load and manage settings
    config_manager = ConfigurationManager()
    config = config_manager.config

    # Honor a configured model name; otherwise use the captioner's default.
    model_name = config.get('MODEL_NAME')
    captioner = ImageCaptioner(model_name) if model_name else ImageCaptioner()

    # Get a list of all image files in the specified directory
    image_files = _list_image_files(config['IMAGE_FOLDER'])
    if not image_files:
        logging.warning(f"No image files found in {config['IMAGE_FOLDER']}.")

    # Default to using the conditional captioning logic
    use_conditional_caption = config.get('USE_CONDITIONAL_CAPTION', True)
    if isinstance(use_conditional_caption, str):
        # A value taken from the environment arrives as text.
        use_conditional_caption = use_conditional_caption.strip().lower() in ('1', 'true', 'yes', 'on')

    # One output file per run, so rows are not scattered across per-call timestamp files.
    csv_name = f"captions_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"

    # Loop through each image file in the directory
    for image_file in image_files:
        try:
            raw_image = captioner.load_image(image_file)
            if not raw_image:
                continue

            if use_conditional_caption:
                caption = await captioner.generate_caption(raw_image, config['ENDING_CAPTION'])
            else:
                # Fallback to unconditional caption if the conditional caption is not selected.
                caption = await captioner.generate_caption(raw_image)

            if caption is None:
                # generate_caption already logged the failure; skip the blank row.
                logging.warning(f"No caption generated for {image_file}; skipping.")
                continue

            captioner.save_to_csv(os.path.basename(image_file), caption, file_name=csv_name)
        except Exception as e:
            # One bad image must not stop the rest of the batch.
            logging.error(f"An unexpected error occurred: {e}")
|
||
# Run the asynchronous pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    asyncio.run(main())
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
{ | ||
"IMAGE_FOLDER": "images", | ||
"BASE_NAME": "your_image_name_here.jpg", | ||
"ENDING_CAPTION": "AI generated Artwork by Daethyra using DallE", | ||
"MODEL_NAME": "Salesforce/blip-image-captioning-base", | ||
"USE_CONDITIONAL_CAPTION": true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Logging level for the application | ||
LOGGING_LEVEL=INFO | ||
|
||
# Pretrained model name | ||
MODEL_NAME=Salesforce/blip-image-captioning-base | ||
|
||
# Whether to use conditional captioning logic | ||
USE_CONDITIONAL_CAPTION=true |