How does this provide the same gradient as a larger batch size?
Looking through the code, I notice that there are sub-batch combinations consisting only of negative examples that appear to be ignored entirely. If the code skips those combinations, how does using GradCache give the same result as running the larger batch on a larger GPU?
I also ran an experiment: an image-text contrastive learning setup with a batch size of 64, trained once with the batch of 64 directly and once with GradCache using sub-batches of 16. The direct batch of 64 performed much better under linear evaluation than the GradCache run.
I am not sure which part of the code you are confused about. I'd suggest first reading the math derivations in our paper to see if it can help you get a clearer picture of the grad cache approach.
The math derivations in the paper make sense to me, but unless I am mistaken the code does not seem to match them. There should be some handling of the sub-batch pairs that contain only negatives (these should be covered by the pairs of s in S and t in T), but looking through the code I do not see where that happens.
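To make sure we are reading the derivation the same way, this is the decomposition I have in mind (my own notation, which may not match the paper's symbols exactly): for the parameters $\Theta$ of the encoder $f$ over the sub-batched inputs $S$,

$$\frac{\partial \mathcal{L}}{\partial \Theta} = \sum_{s \in S} \frac{\partial \mathcal{L}}{\partial f(s)} \cdot \frac{\partial f(s)}{\partial \Theta},$$

and analogously for the other encoder $g$ over $T$. The cached factor $\partial \mathcal{L} / \partial f(s)$ comes from the full-batch loss, so it should depend on every $t \in T$, including the negative-only pairings, and my question is where in the code those cross-sub-batch terms are handled.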
print(f"Using CUDA: {is_available()}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class HuggingFaceImageEncoder(Module):
""""Wrapper for HuggingFace pretrained CLIP image encoder"""
def __init__(self, projection_dim):
super().__init__()
# ViT_dim is 768 for 'openai/clip-vit-base-patch32'
self.model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
self.proj = Linear(768, projection_dim, bias=False)
def forward(self, image_lst):
image_out = self.model(image_lst.to(device)).pooler_output
return self.proj(image_out)
class HuggingFaceTextEncoder(Module):
""""Wrapper for HuggingFace pretrained CLIP image encoder"""
def __init__(self, projection_dim):
super().__init__()
# bert_dim is 512 for 'openai/clip-vit-base-patch32'
self.model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
self.proj = Linear(512, projection_dim, bias=False)
def forward(self, **kwargs):
for key in kwargs:
kwargs[key] = kwargs[key].to(device)
text_out = self.model(**kwargs).pooler_output
return self.proj(text_out)
# Load the image encoder
img_encoder = HuggingFaceImageEncoder(512)
# Load the text encoder
text_encoder = HuggingFaceTextEncoder(512)
loss_fn = SimpleContrastiveLoss()
#Custom representation function that enables GradCache to be used for multimodal learning.
def rep_fn(v):
try:
return v.pooler_output
except:
return v
gc = GradCache(
models = [img_encoder, text_encoder],
chunk_sizes = 16,
loss_fn = loss_fn,
get_rep_fn = rep_fn
)
optimizer = Adam(list(img_encoder.parameters()) + list(text_encoder.parameters()))
dataloader = get_clip_dataloader(64)
def one_epoch():
loss = []
for batch in dataloader:
optimizer.zero_grad()
cur_loss = gc(batch['image'], batch['caption'], reduction='mean').item()
loss.append(cur_loss)
optimizer.step()
for _ in range(10):
one_epoch()
I just want to confirm that this is how one should do img-text contrastive learning using this API.
What are subsets with just negative values?
Say we choose a batch size of 64 and a sub-batch size of 16. This means we split each input into 64/16 = 4 sub-batches. For contrastive learning, we need to evaluate all combinations of inputs to get the full benefit of the negative samples, a total of 4*4 = 16 sub-batch computations in this case. However, my understanding is that GradCache only looks at the sub-batches along the diagonal, i.e. only 4 of those 16 computations. If that is the case, I don't see how GradCache provides the additional negative samples that larger batch sizes give.
It seems to essentially compute the contrastive loss assuming a small batch size, but cache the gradients until all samples in the batch have been processed, and only then use the optimizer to update the weights.
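To spell out the counting I have in mind, here is the arithmetic only (my own illustration, no GradCache calls):

n_sub = 64 // 16        # 4 sub-batches per encoder
all_pairs = n_sub ** 2  # 16 sub-batch pairs in the full Cartesian product
diag_pairs = n_sub      # 4 pairs if only matching sub-batches were scored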
GradCache computes the loss based on the full Cartesian product between S and T, as long as the provided loss function does so. All representations are cached first, and the full batch's gradient cache is then computed from them.
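As a rough sketch of the kind of loss this refers to (my own illustrative code, not GradCache's actual SimpleContrastiveLoss; the function name and temperature are made up), a loss that receives the full set of cached representations scores every image sub-batch against every text sub-batch automatically:

import torch
import torch.nn.functional as F

def full_batch_contrastive_loss(img_reps, txt_reps, temperature=0.05):
    # img_reps, txt_reps: [batch_size, dim] tensors assembled from the cached
    # representations of all sub-batches before the loss is evaluated.
    logits = img_reps @ txt_reps.t() / temperature  # [batch_size, batch_size]
    # Each row scores one image against every caption in the full batch, so the
    # off-diagonal, negative-only sub-batch pairings contribute to the loss.
    targets = torch.arange(img_reps.size(0), device=img_reps.device)
    return F.cross_entropy(logits, targets)

Because the cached representation gradients come from this full [batch_size, batch_size] score matrix, the second, per-sub-batch backward pass still propagates the effect of all 16 sub-batch pairings in the 64/16 example, not just the 4 diagonal ones.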