Training on CIFAR10
Hello,
Thank you for this excellent repository!
Do you have any suggestions for changes needed to train BYOL on the CIFAR10 dataset?
This is how I am doing it in main.py (I am also training my own custom models, but I don't think that is relevant here):
from torchvision import datasets

# MultiViewDataInjector and data_transform come from the repository code and are assumed to be in scope
DATASET = 'CIFAR10'  # Can change to 'STL10'

if DATASET == 'STL10':
    train_dataset = datasets.STL10('/workspace/STLDataset', split='train+unlabeled', download=True,
                                   transform=MultiViewDataInjector([data_transform, data_transform]))
elif DATASET == 'CIFAR10':
    train_dataset = datasets.CIFAR10('/workspace/CIFAR10Dataset', train=True, download=True,
                                     transform=MultiViewDataInjector([data_transform, data_transform]))
else:
    # Fail loudly instead of exiting with status 0, which signals success
    raise ValueError(f"Dataset {DATASET!r} not supported; choose 'CIFAR10' or 'STL10'")
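In BYOL each training image yields two independently augmented views, one for the online network and one for the target network; passing [data_transform, data_transform] to MultiViewDataInjector appears to serve that purpose. The class below is a hypothetical, framework-free stand-in (TwoViewTransform and jitter are illustrative names) that shows the idea; it is not the repository's implementation.

```python
import random


class TwoViewTransform:
    """Hypothetical stand-in: apply each transform in a list to the same input sample."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, sample):
        # One output view per transform; with a random augmentation the views differ.
        return [t(sample) for t in self.transforms]


# Usage with a toy random "augmentation" on a list of pixel values
rng = random.Random(0)


def jitter(pixels):
    """Toy augmentation: add independent noise in [-0.1, 0.1] to every value."""
    return [p + rng.uniform(-0.1, 0.1) for p in pixels]


injector = TwoViewTransform([jitter, jitter])
view_a, view_b = injector([0.2, 0.5, 0.8])
print(view_a)
print(view_b)
```

Reusing the same random transform twice is enough to get two different views, because each call draws fresh random parameters.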
I also changed the config to use input_shape: (32,32,3).
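YAML has no tuple syntax, so a loader such as PyYAML reads input_shape: (32,32,3) as the plain string "(32,32,3)". The helper below is a hypothetical sketch (parse_input_shape, check_input_shape and NATIVE_SHAPES are illustrative names, and the repository's code may already handle this conversion) that parses the value safely with ast.literal_eval and checks it against the native image size of each dataset: 32x32 RGB for CIFAR10 and 96x96 RGB for STL10.

```python
import ast

# Native image resolutions (H, W, C) of the two torchvision datasets in the question.
NATIVE_SHAPES = {"CIFAR10": (32, 32, 3), "STL10": (96, 96, 3)}


def parse_input_shape(value):
    """Convert "(32,32,3)" (a string after YAML loading) or a list/tuple into an (H, W, C) tuple."""
    shape = ast.literal_eval(value) if isinstance(value, str) else value
    shape = tuple(shape)
    if len(shape) != 3 or not all(isinstance(s, int) and s > 0 for s in shape):
        raise ValueError(f"expected (H, W, C) with positive integers, got {value!r}")
    return shape


def check_input_shape(dataset, value):
    """Hypothetical sanity check: the configured shape should match the dataset's images."""
    shape = parse_input_shape(value)
    if shape != NATIVE_SHAPES[dataset]:
        raise ValueError(f"{dataset} images are {NATIVE_SHAPES[dataset]}, config says {shape}")
    return shape


print(check_input_shape("CIFAR10", "(32,32,3)"))
```

ast.literal_eval only accepts Python literals, so unlike eval it cannot execute arbitrary code from a config file.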
Also, I may not have looked very deeply into this code base, but how are the 'STL10 Top 1' accuracies (75.2%) produced after training the model on the self-supervised task? Do we take the trained model and fine-tune it on the supervised STL10 dataset? I assume that code is not included in this library?
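Self-supervised papers, BYOL included, commonly report top-1 accuracy with the linear evaluation protocol rather than full fine-tuning: freeze the pretrained encoder, extract features for the labeled train and test splits, train only a linear classifier on the train features, and report top-1 on the test features. Whether the 75.2% figure was produced exactly this way is not stated in this thread. The sketch below is a minimal NumPy version of such a linear probe; the function names are hypothetical, and synthetic clustered arrays stand in for real encoder outputs.

```python
import numpy as np


def standardize(train_feats, test_feats):
    """Z-score both splits using statistics computed on the training split only."""
    mean = train_feats.mean(axis=0)
    std = train_feats.std(axis=0) + 1e-8
    return (train_feats - mean) / std, (test_feats - mean) / std


def train_linear_probe(feats, labels, num_classes, lr=0.1, epochs=500, weight_decay=1e-4):
    """Multinomial logistic regression fitted with full-batch gradient descent.

    Only the linear head (W, b) is learned; the encoder that produced `feats` stays frozen.
    """
    n, d = feats.shape
    W = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    onehot = np.eye(num_classes)[labels]
    for _ in range(epochs):
        logits = feats @ W + b
        logits -= logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / n  # gradient of mean cross-entropy w.r.t. the logits
        W -= lr * (feats.T @ grad + weight_decay * W)
        b -= lr * grad.sum(axis=0)
    return W, b


def top1_accuracy(feats, labels, W, b):
    """Fraction of samples whose highest-scoring class equals the label."""
    return float(np.mean(np.argmax(feats @ W + b, axis=1) == labels))


# Usage with synthetic stand-in features; a real run would use frozen encoder outputs.
rng = np.random.default_rng(0)
centers = rng.normal(size=(3, 16)) * 3.0
y_train = rng.integers(0, 3, size=300)
y_test = rng.integers(0, 3, size=100)
x_train = centers[y_train] + rng.normal(size=(300, 16))
x_test = centers[y_test] + rng.normal(size=(100, 16))
x_train, x_test = standardize(x_train, x_test)
W, b = train_linear_probe(x_train, y_train, num_classes=3)
acc = top1_accuracy(x_test, y_test, W, b)
print(f"linear-probe top-1 accuracy: {acc:.3f}")
```

Because the encoder is never updated, this measures how linearly separable the learned representation is, which is why it is preferred over fine-tuning for comparing self-supervised methods.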
Thank you!
Hi Akhauriyash, you can just modify the input shape and the dataset name. I am testing the model, but it doesn't work well on CIFAR10 (about 54% top-1 accuracy). Should the config, in particular the learning rate, stay the same or be changed?
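For comparison, the BYOL paper's ImageNet setup (not this repository's config) scales the base learning rate linearly with batch size, LearningRate = 0.2 × BatchSize/256, applies cosine decay after a warm-up, and raises the target-network EMA coefficient from τ_base = 0.996 toward 1 with τ = 1 − (1 − τ_base)(cos(πk/K) + 1)/2, where k is the current step and K the total number of steps. The sketch below computes these schedules; the function names are hypothetical, and since the paper's values were tuned for LARS with a batch size of 4096, they may not transfer directly to CIFAR10.

```python
import math


def scaled_learning_rate(batch_size, base_lr=0.2):
    """Linear scaling rule from the BYOL paper: lr = base_lr * batch_size / 256."""
    return base_lr * batch_size / 256


def cosine_lr(step, total_steps, peak_lr, warmup_steps=0):
    """Linear warm-up to peak_lr, then cosine decay to 0 at total_steps."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * peak_lr * (1 + math.cos(math.pi * progress))


def target_ema_tau(step, total_steps, tau_base=0.996):
    """EMA coefficient for the target network, rising from tau_base at step 0 to 1 at the end."""
    return 1 - (1 - tau_base) * (math.cos(math.pi * step / total_steps) + 1) / 2


# Illustrative numbers only: batch size 256, 100 total steps, no warm-up.
peak = scaled_learning_rate(256)
for step in (0, 50, 100):
    print(step, round(cosine_lr(step, 100, peak), 4), round(target_ema_tau(step, 100), 5))
```

When comparing against this, it is worth checking which optimizer, batch size and τ the repository's config actually uses, since the linear scaling rule assumes the paper's optimizer.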