blip2_detect.py
import torch
import numpy as np
import os
import pandas as pd
from dataset import ImageCaptioningDataset
from torch.utils.data import DataLoader
from transformers import AutoProcessor, Blip2ForConditionalGeneration
from peft import LoraConfig, get_peft_model
import argparse
# Set random seed for PyTorch
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)  # also seed all CUDA devices
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Set random seed for NumPy
np.random.seed(RANDOM_SEED)
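# Note: DataLoader shuffling below draws on torch's global RNG seeded here;
# multi-worker loading would additionally need a worker_init_fn for full
# reproducibility.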
def collate_fn(batch):
    # Stack tensor fields as-is; tokenize and pad the raw text captions
    # into input_ids and attention_mask.
    processed_batch = {}
    for key in batch[0].keys():
        if key != "text":
            processed_batch[key] = torch.stack([example[key] for example in batch])
        else:
            text_inputs = processor.tokenizer(
                [example["text"] for example in batch], padding=True, return_tensors="pt"
            )
            processed_batch["input_ids"] = text_inputs["input_ids"]
            processed_batch["attention_mask"] = text_inputs["attention_mask"]
    return processed_batch
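# Note (assumption): ImageCaptioningDataset is expected to yield dicts with a
# raw "text" caption plus tensor fields such as "pixel_values"; anything that
# is not "text" is stacked unchanged by the collator above.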
# Model Creation and Initialisation
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map="auto", load_in_8bit=True)
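# Note: newer transformers releases deprecate passing load_in_8bit directly to
# from_pretrained; the equivalent there is
# quantization_config=BitsAndBytesConfig(load_in_8bit=True), with
# BitsAndBytesConfig imported from transformers and bitsandbytes installed.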
# Low-Rank Adaptation (LoRA) configuration
config = LoraConfig(
    r=16,                                # rank of the low-rank update matrices
    lora_alpha=32,                       # scaling factor for the updates
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj"]  # attention projections to adapt
)
model = get_peft_model(model, config)
model.print_trainable_parameters()
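# print_trainable_parameters() reports trainable vs. total parameter counts,
# confirming that only the LoRA adapter weights will be updated.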
# Main body
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    parser = argparse.ArgumentParser(description="Fine-tune BLIP-2 to detect diffusion-generated images.")
    parser.add_argument('--dataset', default='./data/train.csv', type=str,
                        help='Path to the training CSV file.')
    parser.add_argument('--epochs', default=20, type=int,
                        help='Number of training epochs.')
    parser.add_argument('--lr', default=5e-5, type=float,
                        help='Learning rate for training (default: 5e-5).')
    parser.add_argument('--save_path', type=str, default='./SaveFineTune',
                        help='Path to save the fine-tuned model.')
    opt = parser.parse_args()
    if not os.path.exists(opt.save_path):
        os.makedirs(opt.save_path)
    data = pd.read_csv(opt.dataset)
    train_dataset = ImageCaptioningDataset(data, processor)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=32, collate_fn=collate_fn)
    print(f'Training on device: {device}')
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    model.train()
    for epoch in range(opt.epochs):
        print("Epoch:", epoch)
        for idx, batch in enumerate(train_dataloader):
            input_ids = batch.pop("input_ids").to(device)
            attention_mask = batch.pop("attention_mask").to(device)
            pixel_values = batch.pop("pixel_values").to(device, torch.float16)
            # Causal LM objective: the caption tokens serve as both inputs and labels.
            outputs = model(input_ids=input_ids,
                            pixel_values=pixel_values,
                            attention_mask=attention_mask,
                            labels=input_ids)
            loss = outputs.loss
            print("Loss:", loss.item())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
    # Save the LoRA adapter weights once training is complete.
    model.save_pretrained(opt.save_path)
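# Example usage (a sketch; paths and hyperparameters are illustrative):
#   python blip2_detect.py --dataset ./data/train.csv --epochs 20 --lr 5e-5 \
#       --save_path ./SaveFineTune
#
# save_pretrained stores only the LoRA adapter, so reloading for inference
# means wrapping the same base model again, e.g.:
#   from peft import PeftModel
#   base = Blip2ForConditionalGeneration.from_pretrained(
#       "Salesforce/blip2-opt-2.7b", device_map="auto", load_in_8bit=True)
#   model = PeftModel.from_pretrained(base, "./SaveFineTune")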