How to get segmentation results with test.py for img0061 ~ img0080 on BTCV
In utils/data_utils.py, test_transform requires a label:
https://github.com/Project-MONAI/research-contributions/blob/4af80e1e2dcacfde8255fcf32616184990d5bf40/SwinUNETR/BTCV/utils/data_utils.py#L119
But the test dataset has no labels.
test.py also requires a label, just like data_utils:
https://github.com/Project-MONAI/research-contributions/blob/4af80e1e2dcacfde8255fcf32616184990d5bf40/SwinUNETR/BTCV/test.py#L91
So how can I get segmentation results for the test dataset?
Thank you, whduddhks
Same question here. Have you solved it?
I think I may have solved it.
Here is my test_transform in utils/data_utils.py:
test_transform = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image"]),
        transforms.AddChanneld(keys=["image"]),
        # transforms.Orientationd(keys=["image"], axcodes="RAS"),
        transforms.Spacingd(keys="image", pixdim=(args.space_x, args.space_y, args.space_z), mode="bilinear"),
        transforms.ScaleIntensityRanged(
            keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
        ),
        transforms.ToTensord(keys=["image"]),
    ]
)
And on line 133:
test_files = load_decathlon_datalist(datalist_json, True, "test", base_dir=data_dir)
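Switching the key to "test" works for an image-only split because MONAI's `load_decathlon_datalist` wraps bare path strings found under the "test" key as `{"image": path}` dicts and joins the image paths with `base_dir`. The sketch below roughly mimics that behavior in plain Python so it runs without MONAI; `load_test_datalist` is a hypothetical helper, and the `imagesTs/...` paths and `/data/BTCV` directory are assumed placeholders, not taken from this thread.

```python
import json
import os
import tempfile


def load_test_datalist(json_path, base_dir):
    """Hypothetical stand-in for the "test" branch of MONAI's load_decathlon_datalist.

    Bare path strings are wrapped as {"image": path}; image paths are joined with base_dir.
    """
    with open(json_path) as f:
        entries = json.load(f)["test"]
    items = []
    for entry in entries:
        # Decathlon-style JSON may list test images as plain strings (no "label" key).
        item = dict(entry) if isinstance(entry, dict) else {"image": entry}
        item["image"] = os.path.normpath(os.path.join(base_dir, item["image"]))
        items.append(item)
    return items


# Assumed layout: 20 test images listed without labels under the "test" key.
datalist = {
    "training": [{"image": "imagesTr/img0001.nii.gz", "label": "labelsTr/label0001.nii.gz"}],
    "test": [f"imagesTs/img{n:04d}.nii.gz" for n in range(61, 81)],
}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "dataset.json")
    with open(path, "w") as f:
        json.dump(datalist, f)
    files = load_test_datalist(path, base_dir="/data/BTCV")

print(len(files), files[0])
```

Since every item then carries only an "image" key, an image-only test_transform like the one above is all the loader needs.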
For test.py:
# Assumed: `raw` is the list of test image paths, in the same order as val_loader
# (it is not defined in this snippet). shape/affine hold each image's original metadata.
shape, affine = [], []
for img_path in raw:
    nii_img = nib.load(img_path)
    shape.append(nii_img.get_fdata().shape)
    affine.append(nii_img.affine)

with torch.no_grad():
    for i, batch in enumerate(val_loader):
        val_inputs = batch["image"].cuda()
        # original_affine = batch["label_meta_dict"]["affine"][0].numpy()
        h, w, d = shape[i]
        target_shape = (h, w, d)
        img_name = batch["image_meta_dict"]["filename_or_obj"][0].split("/")[-1]
        print("Inference on case {}".format(img_name))
        val_outputs = sliding_window_inference(
            val_inputs, (args.roi_x, args.roi_y, args.roi_z), 4, model, overlap=args.infer_overlap, mode="gaussian"
        )
        val_outputs = torch.softmax(val_outputs, 1).cpu().numpy()
        val_outputs = np.argmax(val_outputs, axis=1).astype(np.uint8)[0]
        # Resample the prediction back to the original image shape before saving.
        val_outputs = resample_3d(val_outputs, target_shape)
        nib.save(
            nib.Nifti1Image(val_outputs.astype(np.uint8), affine[i]), os.path.join(output_directory, img_name)
        )
With this, I got segmentation results for img0061 ~ img0080.
Thank you
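The loop in the answer above indexes `shape[i]` and `affine[i]`, which only lines up if `raw` lists the test images in exactly the order `val_loader` yields them. A slightly more robust variant keys the original metadata by file name instead, so the lookup no longer depends on ordering. `build_meta_lookup` and `fake_read_meta` are hypothetical helpers; with real data, `read_meta` could use nibabel as shown in the comment, since `nib.load(p)` returns a proxy image whose `.shape` and `.affine` are available without reading the voxel data through `get_fdata()`.

```python
import os
from typing import Callable, Dict, Iterable, Tuple

Meta = Tuple[tuple, object]  # (original spatial shape, affine)


def build_meta_lookup(image_paths: Iterable[str], read_meta: Callable[[str], Meta]) -> Dict[str, Meta]:
    """Map each image's file name to its original (shape, affine).

    Looking metadata up by file name (instead of by loop index i) does not
    depend on the loader visiting files in the same order as image_paths.
    """
    return {os.path.basename(p): read_meta(p) for p in image_paths}


# With nibabel, read_meta could be, e.g.:
#   def read_meta(p):
#       img = nib.load(p)  # proxy image; voxel data is not read here
#       return img.shape, img.affine
# A fake reader stands in here so the sketch runs without any NIfTI files.
def fake_read_meta(path: str) -> Meta:
    return (512, 512, 100), "affine-of-" + os.path.basename(path)


raw = [f"/data/BTCV/imagesTs/img{n:04d}.nii.gz" for n in range(61, 64)]
lookup = build_meta_lookup(raw, fake_read_meta)

# Inside the inference loop one would then write:
#   target_shape, original_affine = lookup[img_name]
target_shape, original_affine = lookup["img0062.nii.gz"]
print(target_shape, original_affine)
```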
Could I see the complete 'test.py' file? I don't have a definition for 'raw', so the code raises an error.
This at least runs, but I'm not sure it is correct. In utils/data_utils.py, change:
test_transform = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image"]),
        transforms.AddChanneld(keys=["image"]),
        # transforms.Orientationd(keys=["image"], axcodes="RAS"),
        transforms.Spacingd(keys="image", pixdim=(args.space_x, args.space_y, args.space_z), mode="bilinear"),
        transforms.ScaleIntensityRanged(
            keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
        ),
        transforms.ToTensord(keys=["image"]),
    ]
)
On line 133, change it to:
test_files = load_decathlon_datalist(datalist_json, True, "test", base_dir=data_dir)
In test.py, add a new function:
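import scipy.ndimage


def resample_3d(image, target_shape, mode="nearest"):
    # Per-axis zoom factor that maps the current shape onto target_shape.
    zoom_factors = [t / s for t, s in zip(target_shape, image.shape)]
    # order=0 is nearest neighbour (for label maps); order=1 is linear.
    return scipy.ndimage.zoom(image, zoom_factors, order=0 if mode == "nearest" else 1)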
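A quick sanity check of this helper, assuming NumPy and SciPy are installed: with `order=0` (nearest neighbour) the resampled label map ends up with exactly `target_shape` and contains only label values that were already present, whereas linear interpolation could produce in-between values that are not valid classes. The array sizes below are arbitrary, and 14 classes (13 organs plus background) is assumed to match the BTCV setup.

```python
import numpy as np
import scipy.ndimage


def resample_3d(image, target_shape, mode="nearest"):
    # Per-axis zoom factor that maps the current shape onto target_shape.
    zoom_factors = [t / s for t, s in zip(target_shape, image.shape)]
    # order=0 is nearest neighbour, which keeps integer label values intact.
    return scipy.ndimage.zoom(image, zoom_factors, order=0 if mode == "nearest" else 1)


rng = np.random.default_rng(0)
labels = rng.integers(0, 14, size=(96, 96, 48)).astype(np.uint8)  # assumed 14 classes
target_shape = (120, 130, 70)
resampled = resample_3d(labels, target_shape)

print(resampled.shape, resampled.dtype)
print(set(np.unique(resampled).tolist()) <= set(np.unique(labels).tolist()))
```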
Then replace the original `with torch.no_grad():` block with:
with torch.no_grad():
    for i, batch in enumerate(val_loader):
        test_inputs = batch["image"].cuda()
        # Spacingd overwrites image_meta_dict["affine"] with the resampled affine (MONAI 0.x),
        # so use the affine read from the file, which matches the original shape used below.
        original_affine = batch["image_meta_dict"]["original_affine"][0].numpy()
        img_name = batch["image_meta_dict"]["filename_or_obj"][0].split("/")[-1]
        # Original (pre-Spacingd) spatial shape of the input volume.
        target_shape = batch["image_meta_dict"]["spatial_shape"][0].tolist()
        print(f"Inference on case {img_name}")
        test_outputs = sliding_window_inference(
            test_inputs, (args.roi_x, args.roi_y, args.roi_z), 4, model, overlap=args.infer_overlap, mode="gaussian"
        )
        test_outputs = torch.softmax(test_outputs, 1).cpu().numpy()
        test_outputs = np.argmax(test_outputs, axis=1).astype(np.uint8)[0]
        # Nearest-neighbour resampling back to the original grid keeps labels integer.
        test_outputs = resample_3d(test_outputs, target_shape, mode="nearest")
        nib.save(
            nib.Nifti1Image(test_outputs.astype(np.uint8), original_affine),
            os.path.join(output_directory, img_name),
        )
        print(f"Saved inference result for {img_name}")
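The post-processing in this loop turns the network output, of shape (1, C, H, W, D), into a single uint8 label map of shape (H, W, D). Because softmax is monotonic along the class axis, taking the argmax of the raw logits should give the same labels, so the softmax step mainly matters if the probabilities themselves are needed. A NumPy-only sketch with random logits standing in for the model output (the shapes and the hypothetical `logits_to_label_map` helper are assumptions for illustration):

```python
import numpy as np


def logits_to_label_map(logits):
    """(1, C, H, W, D) network output -> (H, W, D) uint8 label map, as in the loop above."""
    # Softmax over the class axis (axis=1), shifted by the max for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=1, keepdims=True)
    # argmax over classes, then drop the batch dimension.
    return np.argmax(probs, axis=1).astype(np.uint8)[0]


rng = np.random.default_rng(0)
logits = rng.standard_normal((1, 14, 8, 8, 4))

label_map = logits_to_label_map(logits)
label_map_no_softmax = np.argmax(logits, axis=1).astype(np.uint8)[0]

print(label_map.shape, label_map.dtype)
print(np.array_equal(label_map, label_map_no_softmax))
```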