Hello, questions of training on customised dataset?
Hi thanks for such an amazing work!
I have noticed the authors haven't revealed the training code, if in this way, could anyone please can guide how to modify the TinySAM working on the customised dataset? The customised dataset is ready with COCO-seg format!
Thanks with appreciation.
John
Specially, how to fine-tune the original TinySAM onto new tasks with few-shot scenarios.
I'm intrested as well!
I'm intrested as well!
Hi, I tried to use LoRa to fine tune the image encoder while keep other parameters frozen, it works and the training code is written by customization.
Could you please share the code you used for training the model.
I've done something like:
num_epochs = 100
tiny_sam.to(DEVICE)
tiny_sam.train()
for epoch in range(num_epochs):
epoch_losses = []
for batch in tqdm(train_dataloader):
# preparing the batch as expected by the model
batched_input = [{
'image': batch["pixel_values"][0].to(DEVICE),
'input_boxes': batch["input_boxes"][0].to(DEVICE),
}]
# forward pass
outputs = tiny_sam(batched_input)
# compute loss
predicted_masks = outputs.pred_masks.squeeze(1)
ground_truth_masks = batch["ground_truth_mask"].float().to(DEVICE)
loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))
# backward pass (compute gradients of parameters w.r.t. loss)
optimizer.zero_grad()
loss.backward()
# optimize
optimizer.step()
epoch_losses.append(loss.item())
print(f'EPOCH: {epoch}')
print(f'Mean loss: {mean(epoch_losses)}')
but I got this error
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Cell In[488], line 16
10 batched_input = [{
11 'image': batch["pixel_values"][0].to(DEVICE),
12 'input_boxes': batch["input_boxes"][0].to(DEVICE),
13 }]
15 # forward pass
---> 16 outputs = tiny_sam(batched_input)
18 # compute loss
19 predicted_masks = outputs.pred_masks.squeeze(1)
File c:\Users\clinica_medica\anaconda3\envs\model_training\Lib\site-packages\torch\nn\modules\module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File c:\Users\clinica_medica\anaconda3\envs\model_training\Lib\site-packages\torch\nn\modules\module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
1764 result = None
1765 called_always_called_hooks = set()
File c:\Users\clinica_medica\anaconda3\envs\model_training\Lib\site-packages\torch\utils\_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File c:\Users\clinica_medica\Desktop\SEGMENTAZIONE - QUIPU\dl-image-segmentation\notebooks\model_training\TinySAM\tinysam\modeling\sam.py:118, in Sam.forward(self, batched_input)
104 sparse_embeddings, dense_embeddings = self.prompt_encoder(
105 points=points,
106 boxes=image_record.get("boxes", None),
107 masks=image_record.get("mask_inputs", None),
108 )
109 low_res_masks, iou_predictions = self.mask_decoder(
110 image_embeddings=curr_embedding.unsqueeze(0),
111 image_pe=self.prompt_encoder.get_dense_pe(),
112 sparse_prompt_embeddings=sparse_embeddings,
113 dense_prompt_embeddings=dense_embeddings,
114 )
115 masks = self.postprocess_masks(
116 low_res_masks,
117 input_size=image_record["image"].shape[-2:],
--> 118 original_size=image_record["original_size"],
119 )
120 masks = masks > self.mask_threshold
121 outputs.append(
122 {
123 "masks": masks,
(...) 126 }
127 )
KeyError: 'original_size'