
Commit b76639a
updated tta to increase batch size in tta
KMarshallX committed Oct 6, 2023
1 parent 62ac1ce commit b76639a
Showing 4 changed files with 14 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -4,6 +4,7 @@
 *.ipynb
 train_zero.py
 train_test.py
+test_time_adaptation_test.py
 /__pycache__
 /data
 /utils/__pycache__
3 changes: 3 additions & 0 deletions config/adapt_config.py
@@ -55,4 +55,7 @@
 # Resource optimization flag. 0: intermediate files are saved, 1: intermediate files are deleted
 adapt_parser.add_argument('--resource', type=int, default=0, help=argparse.SUPPRESS)
 
+# batch size multiplier
+adapt_parser.add_argument('--batch_mul', type=int, default=4, help=argparse.SUPPRESS)
+
 args = adapt_parser.parse_args()
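
A minimal sketch of what the new flag implies; the base batch size of 6 is taken from the print statement added in test_time_adaptation.py below, and the bare parser here is a simplification of config/adapt_config.py:

```python
import argparse

adapt_parser = argparse.ArgumentParser()
# Batch size multiplier, hidden from --help via argparse.SUPPRESS as in the commit.
adapt_parser.add_argument('--batch_mul', type=int, default=4, help=argparse.SUPPRESS)
args = adapt_parser.parse_args([])       # empty argv: fall back to the default of 4

# Each draw appears to yield 6 augmented patches (see the print in the diff below),
# so the effective batch grows linearly with the multiplier.
effective_batch = 6 * args.batch_mul     # -> 24 patches per optimizer step
print(effective_batch)
```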
12 changes: 9 additions & 3 deletions test_time_adaptation.py
@@ -4,7 +4,7 @@
 Test time adaptation module
 Editor: Marshall Xu
-Last Edited: 10/08/2023
+Last Edited: 08/10/2023
 """
 
 import os
@@ -125,10 +125,16 @@
     aug_item = aug_utils(patch_size, aug_mode)
 
     print("Finetuning procedure starts!")
+    print(f"\nIn this test, the batch size is {6*args.batch_mul}\n")
     # training loop
     for epoch in tqdm(range(epoch_num)):
         image, label = next(iter(data_loader))
         image_batch, label_batch = aug_item(image, label)
+        for j in range(1, args.batch_mul):
+            image, label = next(iter(data_loader))
+            image_batch_temp, label_batch_temp = aug_item(image, label)
+            image_batch = torch.cat((image_batch, image_batch_temp), 0)
+            label_batch = torch.cat((label_batch, label_batch_temp), 0)
         image_batch, label_batch = image_batch.to(device), label_batch.to(device)
 
         optimizer.zero_grad()
@@ -144,9 +150,9 @@
         # Learning rate scheduler
         scheduler.step(loss)
 
-        # TODO: debug message, delete this
         current_lr = optimizer.param_groups[0]['lr']
-        tqdm.write(f'Epoch: [{epoch+1}/{epoch_num}], Loss: {loss.item(): .4f}, Current learning rate: {current_lr: .8f}')
+
+        tqdm.write(f'Epoch: [{epoch+1}/{epoch_num}], Loss: {loss.item(): .4f}, Current learning rate: {current_lr : .8f}')
 
     file_name = processed_data_list[i].split('.')[0]
     out_mo_name = os.path.join(out_mo_path, file_name)
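
The heart of the commit is the batch-multiplier loop in the first hunk above. A self-contained sketch of the same pattern, using a toy dataset and a stand-in augmentation; the real `aug_utils(patch_size, aug_mode)` callable presumably returns six augmented patches per draw, which would explain the `6*args.batch_mul` figure in the print:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def build_effective_batch(data_loader, aug_item, batch_mul):
    """Draw `batch_mul` mini-batches, augment each, and concatenate along
    dim 0 into one larger effective batch, mirroring the loop in the diff."""
    image, label = next(iter(data_loader))             # fresh iterator per draw
    image_batch, label_batch = aug_item(image, label)
    for _ in range(1, batch_mul):
        image, label = next(iter(data_loader))
        img_tmp, lbl_tmp = aug_item(image, label)
        image_batch = torch.cat((image_batch, img_tmp), 0)
        label_batch = torch.cat((label_batch, lbl_tmp), 0)
    return image_batch, label_batch

def toy_aug(img, lbl):
    # Stand-in for aug_utils(...): returns 6 copies per drawn sample.
    return img.repeat(6, 1, 1, 1, 1), lbl.repeat(6, 1, 1, 1, 1)

ds = TensorDataset(torch.randn(40, 1, 8, 8, 8), torch.randn(40, 1, 8, 8, 8))
loader = DataLoader(ds, batch_size=1, shuffle=True)    # shuffle: each draw differs
imgs, lbls = build_effective_batch(loader, toy_aug, batch_mul=4)
print(imgs.shape)                                      # torch.Size([24, 1, 8, 8, 8])
```

Note that `next(iter(data_loader))` constructs a fresh iterator on every call, so getting distinct batches per draw relies on the loader shuffling.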
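The `scheduler.step(loss)` call in the second hunk takes a metric argument, which matches `torch.optim.lr_scheduler.ReduceLROnPlateau`. A minimal sketch of that pairing, with placeholder model and hyperparameters:

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# ReduceLROnPlateau steps on a monitored metric (here, the training loss).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                       factor=0.5, patience=10)

loss = torch.tensor(0.42)                      # stand-in for the epoch loss
scheduler.step(loss)                           # reduce LR when loss plateaus
current_lr = optimizer.param_groups[0]['lr']   # the value logged via tqdm.write
print(current_lr)
```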
1 change: 1 addition & 0 deletions utils/unet_utils.py
@@ -341,6 +341,7 @@ class RandomCrop3D():
     """
     def __init__(self, img_sz, exp_sz):
         h, w, d = img_sz
+        # test 0925, constrain the upper bound of the crop size to 128
         crop_h = torch.randint(32, h, (1,)).item()
         crop_w = torch.randint(32, w, (1,)).item()
         crop_d = torch.randint(32, d, (1,)).item()
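
The new comment in `RandomCrop3D.__init__` mentions capping the crop size at 128, but the sampling shown still draws each dimension from [32, dim). A hedged sketch of what the capped variant could look like; only the cap value (128) comes from the comment, the rest is an assumption:

```python
import torch

def sample_crop_size(img_sz, lo=32, cap=128):
    """Sample a random 3D crop size per axis, clamping the exclusive upper
    bound of torch.randint at `cap`. Assumes every dimension exceeds `lo`."""
    h, w, d = img_sz
    crop_h = torch.randint(lo, min(h, cap), (1,)).item()
    crop_w = torch.randint(lo, min(w, cap), (1,)).item()
    crop_d = torch.randint(lo, min(d, cap), (1,)).item()
    return crop_h, crop_w, crop_d

print(sample_crop_size((160, 200, 96)))  # e.g. (97, 114, 58)
```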
