Super_Selfish issue: BYOL
Hi authors,
Thank you for the great library. I am trying to use it to reproduce the results of the BYOL paper, but I ran into a few problems:
- When I tried to use ResNet as my backbone, I got a matrix size mismatch error. Could you share some examples of using a different network as the backbone?
from torchvision.models import resnet18

# Use an ImageNet-pretrained ResNet-18 from torchvision as the BYOL backbone
# (BYOLSupervisor, train_dataset, device, lr, epochs, batch_size and
# supervisor_name come from the full script below).
# NOTE(review): torchvision's resnet18 ends in a 1000-way `fc` layer; if the
# BYOL head expects a different feature size, that may explain the reported
# matrix size mismatch — confirm against BYOLSupervisor's default backbone.
resnet_backbone = resnet18(pretrained=True)
supervisor = BYOLSupervisor(train_dataset, backbone=resnet_backbone).to(device)
supervisor.supervise(
    lr=lr,
    epochs=epochs,
    batch_size=batch_size,
    name=supervisor_name + 'resnet',
)

- I followed the example in tests.py and trained the network only briefly (`epochs = 2` in the script below). Is a test accuracy of about 0.1 (chance level for 10 classes) expected, or have I missed something?
from torch.utils.data import Subset
from torchvision.datasets import STL10  # was missing: STL10 is used below
from torchvision.transforms import Compose, Resize, ToTensor

from super_selfish.models import CombinedNet, Classification
from super_selfish.supervisors import BYOLSupervisor, LabelSupervisor
from super_selfish.utils import test

supervisor_name = 'byol'
lr = 1e-2
epochs = 2
batch_size = 32
device = 'cuda'
data_root = '/content/data/'

# One transform shared by every split: resize to 225x225, then convert to a tensor.
transform = Compose([Resize((225, 225)), ToTensor()])


def _stl10(split):
    """Return the requested STL10 split from `data_root` with the shared transform.

    download=False, so the dataset files must already exist under `data_root`.
    """
    return STL10(root=data_root, split=split, download=False, transform=transform)


# Unlabeled split for self-supervised pretraining, restricted to the first 20000 images.
train_dataset = Subset(_stl10('unlabeled'), range(20000))
# Labeled 'train' split; despite the variable name, it is used below to fine-tune.
val_dataset = _stl10('train')
test_dataset = _stl10('test')

# Self-supervised BYOL pretraining. Reported loss: -3.287230 after two batches.
supervisor = BYOLSupervisor(train_dataset).to(device)
supervisor.supervise(lr=lr, epochs=epochs, batch_size=batch_size,
                     name=supervisor_name, pretrained=False)

# Reload the self-supervised network saved under `supervisor_name` by the run above.
supervisor = BYOLSupervisor(train_dataset).to(device)
supervisor._load_pretrained(supervisor_name, True)

# Fine-tune the pretrained backbone plus a 10-way classification head on labeled data.
# NOTE(review): 3136 must equal the flattened backbone output size for 225x225
# inputs — confirm against the default BYOL backbone.
combined = CombinedNet(supervisor.get_backbone(),
                       Classification(layers=[3136, 256, 128, 10])).to(device)
supervisor = LabelSupervisor(combined, val_dataset)
supervisor.supervise(lr=lr, epochs=epochs, batch_size=batch_size,
                     name="store/finetuned_" + supervisor_name, pretrained=False)

# Reported accuracy: 0.101880, i.e. roughly chance level for a 10-class problem.
test(combined, test_dataset)
Thank you!