AlexNet-PyTorch
AlexNet-PyTorch copied to clipboard
The `__getitem__` function may have a problem.
The original `__getitem__` function:
def __getitem__(self, batch_index: int) -> dict:
    """Load, preprocess and return one labelled sample.

    The class label comes from the name of the directory containing the image
    (the second-to-last component of the path when split on ``self.delimiter``),
    looked up in ``self.class_to_idx``.

    Args:
        batch_index: Index into ``self.image_file_paths``.

    Returns:
        A dict ``{"image": tensor, "target": target}`` where ``tensor`` is the
        result of ``self.post_transform`` and ``target`` is the value stored in
        ``self.class_to_idx`` for the image's parent directory.

    Raises:
        ValueError: If the file extension is not listed in ``IMG_EXTENSIONS``.
        OSError: If OpenCV cannot read/decode the image file.
    """
    image_path = self.image_file_paths[batch_index]
    image_dir, image_name = image_path.split(self.delimiter)[-2:]

    # Reject unsupported extensions before touching the file system.
    if image_name.split(".")[-1].lower() not in IMG_EXTENSIONS:
        raise ValueError(f"Unsupported image extensions, Only support `{IMG_EXTENSIONS}`, "
                         "please check the image file extensions.")

    # cv2.imread signals failure by returning None instead of raising; without
    # this check the problem surfaces as an opaque error inside cv2.cvtColor.
    image = cv2.imread(image_path)
    if image is None:
        raise OSError(f"Failed to read image file `{image_path}`.")
    target = self.class_to_idx[image_dir]

    # OpenCV loads channels as BGR; convert to RGB before handing off to PIL.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(image)

    # Data preprocess (operates on the PIL image).
    image = self.pre_transform(image)

    # Convert image data into a PyTorch tensor. The original note states the
    # input and output range is [0, 1]; the meaning of the two boolean flags
    # is defined in imgproc.image_to_tensor — TODO confirm.
    tensor = imgproc.image_to_tensor(image, False, False)

    # Data postprocess (operates on the tensor).
    tensor = self.post_transform(tensor)

    return {"image": tensor, "target": target}
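But the return annotation `[torch.Tensor, int]` does not match the dict that is actually returned.
Following the `(sample, target)` convention used in the PyTorch documentation, a consistent form would be: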
def __getitem__(self, batch_index: int) -> tuple:
    """Load, preprocess and return one ``(tensor, target)`` sample.

    The class label comes from the name of the directory containing the image
    (the second-to-last component of the path when split on ``self.delimiter``),
    looked up in ``self.class_to_idx``.

    Args:
        batch_index: Index into ``self.image_file_paths``.

    Returns:
        A 2-tuple ``(tensor, target)`` where ``tensor`` is the result of
        ``self.post_transform`` and ``target`` is the value stored in
        ``self.class_to_idx`` for the image's parent directory.

    Raises:
        ValueError: If the file extension is not listed in ``IMG_EXTENSIONS``.
        OSError: If OpenCV cannot read/decode the image file.
    """
    image_path = self.image_file_paths[batch_index]
    image_dir, image_name = image_path.split(self.delimiter)[-2:]

    # Reject unsupported extensions before touching the file system.
    if image_name.split(".")[-1].lower() not in IMG_EXTENSIONS:
        raise ValueError(f"Unsupported image extensions, Only support `{IMG_EXTENSIONS}`, "
                         "please check the image file extensions.")

    # cv2.imread signals failure by returning None instead of raising; without
    # this check the problem surfaces as an opaque error inside cv2.cvtColor.
    image = cv2.imread(image_path)
    if image is None:
        raise OSError(f"Failed to read image file `{image_path}`.")
    target = self.class_to_idx[image_dir]

    # OpenCV loads channels as BGR; convert to RGB before handing off to PIL.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(image)

    # Data preprocess (operates on the PIL image).
    image = self.pre_transform(image)

    # Convert image data into a PyTorch tensor. The original note states the
    # input and output range is [0, 1]; the meaning of the two boolean flags
    # is defined in imgproc.image_to_tensor — TODO confirm.
    tensor = imgproc.image_to_tensor(image, False, False)

    # Data postprocess (operates on the tensor).
    tensor = self.post_transform(tensor)

    return tensor, target
P.S. I haven't run the code to test this idea.