How to train on grayscale images?
When image_datasets loads images, it automatically converts each picture to RGB. However, I want to keep the original single-channel format.
To train the model on 1-channel images, open image_datasets.py and comment out line 97: arr = np.array(pil_image.convert("RGB")). This line forces every input image to three RGB channels.
Then add these two lines after the line you commented out:
arr = np.array(pil_image)
arr = arr.reshape((self.resolution, self.resolution, 1))
Note that the reshape only succeeds if the source files are already single-channel; an RGB or RGBA file will raise an error at this point.
I made the following changes, and it worked:
In image_datasets.py,
# arr = np.array(pil_image.convert("RGB")) # comment out this line
arr = np.array(pil_image) # add this line
arr = arr.reshape((arr.shape[0], arr.shape[1], 1)) # add this line
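For anyone whose dataset mixes grayscale and color files, here is a minimal standalone sketch of the same idea. It is not the repository's code: the helper name to_single_channel_array is hypothetical, and it swaps in pil_image.convert("L") instead of keeping the original mode, so that every file ends up with exactly one channel. The scaling to [-1, 1] and the channels-first transpose are my assumption about what the dataset loader does after this point, so check them against your copy of image_datasets.py.

```python
import numpy as np
from PIL import Image


def to_single_channel_array(pil_image):
    """Hypothetical helper: PIL image -> float32 array of shape (1, H, W)."""
    # Force 8-bit grayscale so RGB/RGBA/palette files also yield one channel.
    gray = pil_image.convert("L")
    arr = np.array(gray)                  # shape (H, W), dtype uint8
    arr = arr[..., np.newaxis]            # shape (H, W, 1), same as the reshape above
    # Assumed normalization: map [0, 255] to [-1, 1].
    arr = arr.astype(np.float32) / 127.5 - 1.0
    # Assumed layout: channels first, as PyTorch models expect.
    return np.transpose(arr, [2, 0, 1])


if __name__ == "__main__":
    demo = Image.new("L", (8, 6), color=255)    # PIL size is (width, height)
    print(to_single_channel_array(demo).shape)
```

If your files are all mode L already, the two-line change above is enough and the convert("L") call is a no-op.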
In script_util.py, change the in_channels and out_channels arguments passed to UNetModel so they match the single image channel. Specifically:
return UNetModel(
in_channels=1,
model_channels=num_channels,
out_channels=(1 if not learn_sigma else 2),
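The channel arithmetic in that call can be written out as a small sketch. The helper unet_channels is hypothetical and not part of script_util.py; it assumes, as the snippet above suggests, that learn_sigma doubles the output channels because the model then predicts a variance value alongside the mean for each image channel.

```python
def unet_channels(image_channels, learn_sigma):
    """Hypothetical helper: return (in_channels, out_channels) for the UNet."""
    in_channels = image_channels
    # Assumption: with learn_sigma the network emits two values per image channel.
    out_channels = image_channels * 2 if learn_sigma else image_channels
    return in_channels, out_channels


if __name__ == "__main__":
    print(unet_channels(1, learn_sigma=False))
    print(unet_channels(1, learn_sigma=True))
```

For grayscale this gives the 1 and 2 used above, and passing 3 for image_channels gives the values an RGB setup would use.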