Replies: 1 comment
-
# The following code solves my problem:
def predict(data, model, model_name, batch_size=32, n_fold=0, device='cuda', tta=False):
    """Run sliding-window inference, optionally averaged over flip-based TTA.

    The weights stored at ``<model_name><n_fold>/checkpoint.pth`` are loaded
    into ``model``, which is moved to ``device`` and switched to eval mode.
    ``data`` is then predicted in 128x128 windows with 50% overlap and
    Gaussian blending; window batches run on ``device`` (``sw_device``) while
    the inferer's output ``device`` is set to the CPU.

    Args:
        data: Batched input tensor handed to the sliding-window inferer.
            NOTE(review): the TTA branch only transforms ``data[0]``, so it
            presumably expects a batch containing exactly one image --
            confirm against the caller.
        model: Network whose ``state_dict`` layout matches the checkpoint.
        model_name: Prefix of the checkpoint directory name.
        batch_size: Number of windows evaluated per forward pass.
        n_fold: Fold index appended to ``model_name`` to form the directory.
        device: Device used for the model and for each batch of windows.
        tta: When true, additionally predict on the input flipped along
            spatial axis 0, spatial axis 1 and both, undo each flip on the
            prediction, and average the four softmax maps.

    Returns:
        Softmax (over dim 1) of the stitched prediction; with ``tta`` it is
        the mean of the original and the three un-flipped predictions.
    """
    checkpoint_path = os.path.join(model_name + str(n_fold), 'checkpoint.pth')
    model = model.to(device)
    # map_location remaps the stored tensors onto the requested device, so a
    # checkpoint saved on a GPU can still be loaded on a CPU-only machine or
    # on a different GPU.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

    # The inferer configuration never changes between passes, so build it
    # once and reuse it for the plain prediction and every TTA pass.
    inferer = SlidingWindowInferer(
        roi_size=(128, 128),
        sw_batch_size=batch_size,
        overlap=0.5,
        mode="gaussian",
        progress=True,
        sw_device=device,
        device=torch.device('cpu'),
    )

    with torch.no_grad():
        outputs = torch.softmax(inferer(data, model), dim=1)
        if not tta:
            return outputs

        flips = [
            Flip(spatial_axis=0),
            Flip(spatial_axis=1),
            Compose([Flip(spatial_axis=0), Flip(spatial_axis=1)]),
        ]
        predictions = [outputs]
        for aug in flips:
            # The transforms are applied to the un-batched sample, then the
            # batch dimension is restored for the inferer.
            flipped = aug(data[0]).unsqueeze(0)
            logits = inferer(flipped, model)
            # Undo the flip so the prediction lines up with the original.
            restored = aug.inverse(logits[0]).unsqueeze(0)
            predictions.append(torch.softmax(restored, dim=1))
            gc.collect()
        return torch.stack(predictions, dim=0).mean(dim=0)
Beta Was this translation helpful? Give feedback.
0 replies
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
-
Beta Was this translation helpful? Give feedback.
All reactions