# train.py
import sys
from utils.NiftiPromptDataset import NiftiPromptDataset
import utils.NiftiDataset as NiftiDataset
from torch.utils.data import DataLoader, DistributedSampler
from options.train_options import TrainOptions
import time
from collections import OrderedDict
from models import create_model
from utils.visualizer import Visualizer
from utils import util
from test import inference
import torch
import random
import numpy as np
import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.distributed import init_process_group
import warnings
# Ignore all user warnings
warnings.filterwarnings("ignore", category=UserWarning)
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

def DDP_step_up():
    init_process_group(backend='nccl')
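
# init_process_group(backend='nccl') uses the default env:// init method, so it expects
# MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE in the environment, plus the LOCAL_RANK
# variable read in the main block below; these are typically set by launching the script
# with torchrun.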

DDP_flag = False


def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def get_gpu_memory():
    torch.cuda.synchronize()  # make sure all pending GPU operations are done
    # memory_allocated() reports the tensor memory PyTorch has allocated on the current device, in GB
    return torch.cuda.memory_allocated() / 1e9


def DDP_reduce_loss_dict(loss_dict, world_size):
    """
    Reduce the loss dictionary across all processes so that rank 0 holds the averaged results.
    """
    with torch.no_grad():
        loss_names = []
        all_losses = []
        for k in sorted(loss_dict.keys()):
            loss_names.append(k)
            all_losses.append(loss_dict[k])
        # Stack all losses into a single tensor and reduce across processes
        all_losses = torch.stack(all_losses, 0)
        dist.reduce(all_losses, dst=0)
        if dist.get_rank() == 0:
            # Only the master process gets the correct (averaged) results
            all_losses /= world_size
            reduced_losses = {name: loss for name, loss in zip(loss_names, all_losses)}
        else:
            # Other processes get dummy data
            reduced_losses = {name: torch.tensor(0) for name in loss_names}
    return reduced_losses


if __name__ == '__main__':
    # seed = 50
    # seed_everything(seed)
    # ----- Loading the init options -----
    opt = TrainOptions().parse()  # get training options
    print("--------------------------------------------->:", os.environ.get('CUDA_VISIBLE_DEVICES'))
    if DDP_flag:
        DDP_step_up()
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = dist.get_world_size()
        print(torch.cuda.is_available())
        opt.device = torch.device(local_rank)
        print('----------------->world_size: ', world_size)
        print('----------------->local_rank: ', local_rank)
    else:
        opt.device = torch.device(int(opt.gpu_ids[0]))
        print('opt.device : ', opt.device)

    # Check for CUDA availability and report it
    cuda_available = torch.cuda.is_available()
    print(f"CUDA is available: {cuda_available}")
    # If CUDA is available, print the number and name of the available GPUs
    if cuda_available:
        num_gpus = torch.cuda.device_count()
        print(f"Number of available GPUs: {num_gpus}")
        for i in range(num_gpus):
            print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
    else:
        print("No GPUs detected.")

    # -----------------------------------------------------
    model = create_model(opt)  # creation of the model
    model.setup(opt)
    if opt.epoch_count > 1:
        model.load_networks(opt.which_epoch)
    visualizer = Visualizer(opt)
    total_steps = 0
    if DDP_flag:
        model = DistributedDataParallel(model, device_ids=[local_rank])

    # ----- Transformation and augmentation process for the data -----
    min_pixel = int(opt.min_pixel * ((opt.patch_size[0] * opt.patch_size[1] * opt.patch_size[2]) / 50))
    trainTransforms = [
        NiftiDataset.Padding(
            (opt.patch_size[0], opt.patch_size[1], opt.patch_size[2])),
        NiftiDataset.RandomCrop(
            (opt.patch_size[0], opt.patch_size[1], opt.patch_size[2]), opt.drop_ratio, min_pixel),
    ]
    train_set = NiftiPromptDataset(opt.data_path, path_A=opt.generate_path, path_B=opt.input_path,
                                   which_direction='AtoB', transforms=trainTransforms,
                                   shuffle_labels=False, train=True, add_graph=False,
                                   label2mask=opt.model in ["medddpm", "medddpmtext", "medddpmtextcross", "medddpmvisualprompt"])
    print('length of train set:', len(train_set))

    # If the dataset provides its own 'collate_fn', use it in the DataLoader
    if hasattr(train_set, 'collate_fn'):
        train_loader = DataLoader(train_set, batch_size=opt.batch_size, shuffle=True,
                                  num_workers=opt.workers, pin_memory=True, collate_fn=train_set.collate_fn)
    else:
        # Otherwise fall back to the default collate_fn
        if not DDP_flag:
            train_loader = DataLoader(train_set, batch_size=opt.batch_size, shuffle=True,
                                      num_workers=opt.workers, pin_memory=True)
        else:
            # With DDP, a DistributedSampler shards the dataset across processes
            sampler = DistributedSampler(train_set, shuffle=True)
            train_loader = DataLoader(train_set, batch_size=opt.batch_size, sampler=sampler,
                                      num_workers=opt.workers, pin_memory=True)

    initial_memory = get_gpu_memory()
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0
        for i, data in enumerate(train_loader):
            iter_start_time = time.time()
            if total_steps % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
            visualizer.reset()
            total_steps += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(data)
            model.optimize_parameters(step=total_steps)
            memory_used_by_batch = get_gpu_memory() - initial_memory
            losses = model.get_current_losses()
            t = (time.time() - iter_start_time) / opt.batch_size
            # The `visualizer` is a helper class that prints and saves the current losses and
            # results during training, so progress can be monitored.
            if DDP_flag:
                # dist.reduce inside DDP_reduce_loss_dict is a collective op, so every rank
                # must call it; a single call averages the whole loss dict, and only rank 0
                # receives the averaged values and prints them.
                losses = DDP_reduce_loss_dict(losses, world_size=world_size)
                if local_rank == 0:
                    visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data)
                    print('Finished (epoch %d, total_steps %d)' % (epoch, total_steps))
            if total_steps % opt.print_freq == 0 or (DDP_flag and local_rank == 0):
                losses = model.get_current_losses()
                t = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(
                    epoch, epoch_iter, losses, t, t_data)
                # ------- visualisation of a 3D slice -----------
                z_slice = 31
                generated = model.get_current_visuals(epoch=epoch)
                generated = generated['fake_B']
                print("generated shape:", generated.shape)
                if opt.segmentation == 1:
                    if data['label'].shape[1] > 1:
                        input_label = util.tensor2label(
                            data['label'][0, :1, :, :, z_slice], 5)
                    else:
                        input_label = util.tensor2label(
                            data['label'][0, :, :, :, z_slice], 5)
                else:
                    input_label = util.tensor2im(data['label'][0, :, :, :, z_slice])
                # `generated` holds the output of the model's forward pass, i.e. the synthesized image
                visuals = OrderedDict([('input_label', input_label),
                                       ('synthesized_image', util.tensor2im(generated.data[0, :, :, :, z_slice])),
                                       ('real_image', util.tensor2im(data['image'][0, :, :, :, z_slice]))])
                visualizer.display_current_results(visuals, epoch, total_steps)
                # ------- end of 3D slice visualisation -----------
            if (total_steps % opt.save_latest_freq == 0) or (DDP_flag and local_rank == 0):
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save_networks('latest')
            print(f'------------------> Memory_used_by_batch: {memory_used_by_batch} GB')
            iter_data_time = time.time()
        if (epoch % opt.save_epoch_freq == 0) or (DDP_flag and local_rank == 0):
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save_networks('latest')
            model.save_networks(epoch)
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
        model.update_learning_rate()
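
# Example invocations (a sketch only; the flags below are placeholders, since the real CLI
# options are defined in options/train_options.py):
#   single GPU:                  python train.py --gpu_ids 0
#   DDP (set DDP_flag = True):   torchrun --nproc_per_node=2 train.py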