Merge contrib text2image to src #880

Open: wants to merge 6 commits into base: main
2 changes: 2 additions & 0 deletions .gitignore
@@ -135,6 +135,8 @@ dmypy.json
# vscode
.vs
.vscode
# vscode extensions
.history

# Pycharm
.idea
17 changes: 17 additions & 0 deletions configs/accelerate_t2i_config.yaml
@@ -0,0 +1,17 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: false
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
8 changes: 4 additions & 4 deletions contrib/text2image/README.md
@@ -28,10 +28,10 @@ The `train.json` should be the format as follow:

```json
{
"type": "text-image",
"type": "image_text",
"instances": [
{
"image": "00.jpg",
"images": "00.jpg",
"text": "A photo of a <SKS> dog"
},
...
@@ -61,8 +61,8 @@ For convenience, we provide a script `finetune_t2i.sh` for fine-tuning. It can b

```bash
bash finetune_t2i.sh \
model_name_or_path=stabilityai/stable-diffusion-2-1 \
dataset_path=data/example
--model_name_or_path "stabilityai/stable-diffusion-2-1" \
--dataset_path "data/example"
```

The `model_name_or_path` is the model name in [huggingface](https://huggingface.co/) or path of the pre-trained model. The `dataset_path` is the path of the dataset, which should be organized as the above tree struct.
12 changes: 6 additions & 6 deletions contrib/text2image/finetune_t2i.sh
@@ -2,8 +2,8 @@
model_name_or_path=stabilityai/stable-diffusion-2-1
model_type="unet"
dataset_path=data/example
output_dir=output
main_port=29500
output_dir=output_dir
main_process_port=29500
img_size=768

while [[ $# -ge 1 ]]; do
@@ -25,8 +25,8 @@ while [[ $# -ge 1 ]]; do
output_dir="$2"
shift
;;
-p|--main_port)
main_port="$2"
-p|--main_process_port)
main_process_port="$2"
shift
;;
-i|--img_size)
@@ -44,13 +44,13 @@ echo "model_name_or_path: ${model_name_or_path}"
echo "model_type: ${model_type}"
echo "dataset_path: ${dataset_path}"
echo "output_dir: ${output_dir}"
echo "main_port: ${main_port}"
echo "main_process_port: ${main_process_port}"
echo "img_size: ${img_size}"


accelerate launch \
--config_file=./accelerate_t2i_config.yaml \
--main_port=${main_port} \
--main_process_port=${main_process_port} \
finetune_t2i.py \
--model_name_or_path=${model_name_or_path} \
--model_type=${model_type} \
4 changes: 2 additions & 2 deletions contrib/text2image/t2i_dataset.py
@@ -26,7 +26,7 @@ def __init__(self, data_args: T2IDatasetArguments):
self.data_file = osp.join(data_args.dataset_path, data_args.train_file)

self.data_dict = json.load(open(self.data_file, "r"))
assert self.data_dict["type"] == "text-image", "The dataset type must be text-image."
assert self.data_dict["type"] == "image_text", "The dataset type must be image_text"

self.data_instances = self.data_dict["instances"]

@@ -35,7 +35,7 @@ def __len__(self):

def __getitem__(self, idx):
instance = self.data_instances[idx]
image_path = osp.join(self.image_folder, instance["image"])
image_path = osp.join(self.image_folder, instance["images"])
image = Image.open(image_path)
image = image.convert("RGB")

105 changes: 105 additions & 0 deletions examples/finetune_t2i.py
@@ -0,0 +1,105 @@
import sys
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_MODE"] = "offline"
import shutil
from pathlib import Path
import gc

import torch
from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel
)
from transformers import (
    AutoTokenizer,
    CLIPTextModel
)
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from transformers import HfArgumentParser
from peft import LoraConfig

from lmflow.args import (
    DiffuserModelArguments,
    T2IDatasetArguments,
    AutoArguments,
)
from lmflow.datasets import Dataset
from lmflow.pipeline.auto_pipeline import AutoPipeline

def main():
    pipeline_name = "diffuser_tuner"
    PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)

    parser = HfArgumentParser((DiffuserModelArguments, T2IDatasetArguments, PipelineArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):  # single JSON config file
        model_args, data_args, pipeline_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses()


    logging_dir = Path(pipeline_args.output_dir, pipeline_args.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=pipeline_args.output_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        mixed_precision=pipeline_args.mixed_precision,
        log_with="wandb",
        project_config=accelerator_project_config,
    )

    if accelerator.is_main_process and pipeline_args.overwrite_output_dir and os.path.exists(pipeline_args.output_dir):
        shutil.rmtree(pipeline_args.output_dir)

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_args.model_name_or_path, subfolder="text_encoder").to("cuda")
    vae = AutoencoderKL.from_pretrained(model_args.model_name_or_path, subfolder="vae").to("cuda")

    # dataset = build_t2i_dataset(data_args, tokenizer, text_encoder, vae)
    kwargs = {"tokenizer": tokenizer, "text_encoder": text_encoder, "vae": vae}
    dataset = Dataset(data_args, backend="t2i", **kwargs)

    del tokenizer, text_encoder, vae  # only needed to build the dataset
    torch.cuda.empty_cache()
    gc.collect()

    model = None
    if model_args.arch_type == "unet":
        model = UNet2DConditionModel.from_pretrained(model_args.model_name_or_path, subfolder=model_args.arch_type)
    elif model_args.arch_type == "transformer":
        raise NotImplementedError("Transformer model is not implemented.")
    else:
        raise ValueError("The model type is not supported.")
    if model_args.use_lora:
        accelerator.print(f"Using LoRA of {model_args.lora_target_modules} for training")
        model.requires_grad_(False)  # freeze base weights; only the adapter is trained
        lora_config = LoraConfig(
            r=model_args.lora_r,
            lora_alpha=model_args.lora_alpha,
            lora_dropout=model_args.lora_dropout,
            init_lora_weights="gaussian",
            target_modules=model_args.lora_target_modules,
        )
        model.add_adapter(lora_config)
    else:
        model.requires_grad_(True)

    finetuner = AutoPipeline.get_pipeline(
        pipeline_name=pipeline_name,
        model_args=model_args,
        data_args=data_args,
        pipeline_args=pipeline_args,
    )
    accelerator.init_trackers("text2image-finetune", config={
        "data_args": data_args,
        "model_args": model_args,
        "pipeline_args": pipeline_args,
    })

    accelerator.wait_for_everyone()
    finetuner.tune(
        accelerator=accelerator,
        model=model, dataset=dataset
    )

if __name__ == '__main__':
    main()
3 changes: 2 additions & 1 deletion requirements.txt
@@ -22,4 +22,5 @@ pydantic
gradio
accelerate>=0.27.2
einops>=0.6.1
vllm>=0.4.1
vllm>=0.4.1
diffusers>=0.29.2
75 changes: 75 additions & 0 deletions scripts/diffuser/README.md
@@ -0,0 +1,75 @@
# Diffusers fine-tuning

## Data Preparation

The required data layout is shown below. Under a `dataset_path` (here `data/example`), an `img` directory holds the image files by default, and `train.json`, `valid.json` and `test.json` describe the training, validation and testing data. `valid.json` and `test.json` are optional. If only one of them is provided, it is used for both.

```bash
data
└── example
    ├── img
    │   ├── 00.jpg
    │   ├── 01.jpg
    │   ├── 02.jpg
    │   ├── 03.jpg
    │   └── 04.jpg
    ├── train.json
    ├── [valid.json]
    └── [test.json]
```
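
The fallback between `valid.json` and `test.json` could be sketched as follows. This is only an illustration of the rule stated above, not the pipeline's actual code: `resolve_eval_files` is a hypothetical helper name, and returning `None` when neither file exists is an assumption.

```python
import os


def resolve_eval_files(dataset_path, valid_name="valid.json", test_name="test.json"):
    """Return (valid_path, test_path), reusing whichever file exists for both.

    Hypothetical helper: returns (None, None) when neither file is present.
    """
    valid_path = os.path.join(dataset_path, valid_name)
    test_path = os.path.join(dataset_path, test_name)
    has_valid = os.path.isfile(valid_path)
    has_test = os.path.isfile(test_path)

    if has_valid and has_test:
        return valid_path, test_path
    if has_valid:
        # Only valid.json exists, so it doubles as the test file.
        return valid_path, valid_path
    if has_test:
        # Only test.json exists, so it doubles as the validation file.
        return test_path, test_path
    return None, None
```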

The `train.json` file should have the following format:

```json
{
  "type": "image_text",
  "instances": [
    {
      "images": "00.jpg",
      "text": "A photo of a <SKS> dog"
    },
    ...
  ]
}
```
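
Before launching a run, it can help to check a training file against this layout. The sketch below is a minimal example using only the standard library. `check_train_file` is a hypothetical helper, not part of LMFlow, and the default names `train.json` and `img` are the defaults described in this README.

```python
import json
import os


def check_train_file(dataset_path, train_file="train.json", image_folder="img"):
    """Return a list of problems found in the training file (empty if none)."""
    with open(os.path.join(dataset_path, train_file), "r", encoding="utf-8") as f:
        data = json.load(f)

    problems = []
    if data.get("type") != "image_text":
        problems.append(f"type must be 'image_text', got {data.get('type')!r}")

    for idx, instance in enumerate(data.get("instances", [])):
        # Each instance needs an image file name and a caption, both strings.
        for key in ("images", "text"):
            if not isinstance(instance.get(key), str):
                problems.append(f"instance {idx}: missing or non-string {key!r}")
        image_name = instance.get("images")
        if isinstance(image_name, str):
            # The image name is resolved relative to the image folder.
            image_path = os.path.join(dataset_path, image_folder, image_name)
            if not os.path.isfile(image_path):
                problems.append(f"instance {idx}: image not found: {image_path}")
    return problems
```

An empty list means the file matches the layout. Note that the `...` in the example above is an elision and must not appear in a real file.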

The `valid.json` and `test.json` files should have the following format:

```json
{
  "type": "text-only",
  "instances": [
    {
      "text": "A photo of a <SKS> dog in front of Eiffel Tower."
    },
    ...
  ]
}
```
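
A prompt file in this layout can be generated from a plain list of strings. The sketch below assumes nothing beyond the format shown above. `build_text_only` and `write_text_only` are hypothetical helper names.

```python
import json


def build_text_only(prompts):
    """Wrap a list of prompt strings in the text-only layout."""
    return {
        "type": "text-only",
        "instances": [{"text": prompt} for prompt in prompts],
    }


def write_text_only(path, prompts):
    """Write the prompts to `path` as a text-only JSON file."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(build_text_only(prompts), f, indent=2, ensure_ascii=False)
```

For example, `write_text_only("data/example/valid.json", ["A photo of a <SKS> dog in front of Eiffel Tower."])` would produce a one-prompt validation file.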

A concrete example dataset is available here: [dog_t2i_data_example](https://drive.google.com/drive/folders/106ahvIrXbiuZMBw0NuOTjY0vnM_xXARW?usp=sharing)

## Pretrained Models

The script automatically downloads pretrained weights from Hugging Face. Pass a model identifier such as `stabilityai/stable-diffusion-2-1` to the script argument `model_name_or_path`.

## Finetune

For convenience, we provide the script `run_finetune_t2i.sh` for fine-tuning. It can be used as follows:

```bash
bash scripts/diffuser/run_finetune_t2i.sh \
--model_name_or_path "stabilityai/stable-diffusion-2-1" \
--dataset_path "data/example"
```

`model_name_or_path` is either a model name on [Hugging Face](https://huggingface.co/) or a local path to a pretrained model. `dataset_path` is the path to the dataset, which should be organized as in the tree above.

There are also some optional arguments for the script:

- `arch_type`: The model architecture, either `unet` or `transformer`. Default is `unet`. (`transformer` is not supported yet.)
- `output_dir`: The output directory for the fine-tuned model. Default is `output_dir`.
- `main_process_port`: The port used by the main process, passed to `accelerate launch` as `--main_process_port`. Default is `29500`.
- `img_size`: The image size used for fine-tuning, validation and testing. Default is `768`.
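
When the script is driven from Python, for instance to sweep over image sizes, the command line can be assembled from these options. The sketch below is a minimal example. `build_finetune_command` is a hypothetical helper, while the flag names and defaults are the ones `run_finetune_t2i.sh` parses.

```python
import shlex


def build_finetune_command(
    model_name_or_path="stabilityai/stable-diffusion-2-1",
    dataset_path="data/example",
    arch_type="unet",
    output_dir="output_dir",
    main_process_port=29500,
    img_size=768,
):
    """Return the argv list for run_finetune_t2i.sh with the given options."""
    return [
        "bash", "scripts/diffuser/run_finetune_t2i.sh",
        "--model_name_or_path", model_name_or_path,
        "--dataset_path", dataset_path,
        "--arch_type", arch_type,
        "--output_dir", output_dir,
        "--main_process_port", str(main_process_port),
        "--img_size", str(img_size),
    ]


# Override two options and print a shell-quoted command line for inspection.
command = build_finetune_command(output_dir="output_lora", img_size=512)
print(shlex.join(command))
```

The resulting list can be passed to `subprocess.run(command, check=True)` from the repository root, since the script refers to `configs/` and `examples/` by relative path.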

For further customization, see [run_finetune_t2i.sh](./run_finetune_t2i.sh) and the example script [finetune_t2i.py](../../examples/finetune_t2i.py).
72 changes: 72 additions & 0 deletions scripts/diffuser/run_finetune_t2i.sh
@@ -0,0 +1,72 @@
# Parses arguments
model_name_or_path=stabilityai/stable-diffusion-2-1
arch_type="unet"
dataset_path=data/example
output_dir=output_dir
main_process_port=29500
img_size=768

while [[ $# -ge 1 ]]; do
  key="$1"
  case ${key} in
    -m|--model_name_or_path)
      model_name_or_path="$2"
      shift
      ;;
    -t|--arch_type)
      arch_type="$2"
      shift
      ;;
    -d|--dataset_path)
      dataset_path="$2"
      shift
      ;;
    -o|--output_dir)
      output_dir="$2"
      shift
      ;;
    -p|--main_process_port)
      main_process_port="$2"
      shift
      ;;
    -i|--img_size)
      img_size="$2"
      shift
      ;;
    *)
      echo "error: unknown option \"${key}\"" 1>&2
      exit 1
  esac
  shift
done

echo "model_name_or_path: ${model_name_or_path}"
echo "arch_type: ${arch_type}"
echo "dataset_path: ${dataset_path}"
echo "output_dir: ${output_dir}"
echo "main_process_port: ${main_process_port}"
echo "img_size: ${img_size}"


accelerate launch \
--config_file=configs/accelerate_t2i_config.yaml \
--main_process_port=${main_process_port} \
examples/finetune_t2i.py \
--model_name_or_path=${model_name_or_path} \
--arch_type=${arch_type} \
--use_lora=True \
--lora_target_modules "to_k" "to_q" "to_v" "to_out.0" "add_k_proj" "add_v_proj" \
--dataset_path=${dataset_path} \
--image_folder="img" \
--image_size=${img_size} \
--train_file="train.json" \
--validation_file="valid.json" \
--test_file="test.json" \
--output_dir=${output_dir} \
--logging_dir="logs" \
--overwrite_output_dir=True \
--mixed_precision="fp16" \
--num_train_epochs=100 \
--train_batch_size=1 \
--learning_rate=1e-4 \
--valid_steps=50