TinySAM

Hello, question about training on a customised dataset

Open · NTUZZH opened this issue 1 year ago · 4 comments

Hi, thanks for such amazing work!

I noticed the authors haven't released the training code. Given that, could anyone please explain how to adapt TinySAM to a customised dataset? The dataset is already in COCO-seg format.

Thanks, any help is much appreciated.

John

NTUZZH, Jan 21 '25 09:01
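Since the annotations are already in COCO format, one concrete preprocessing step is turning each annotation's `bbox` into a box prompt. COCO stores boxes as `[x, y, width, height]` in original-image pixels, whereas the original SAM code takes `[x0, y0, x1, y1]` in the frame of the image after its longest side has been resized to 1024. The sketch below assumes TinySAM keeps that convention; `preprocess_shape` mirrors the rounding in SAM's `ResizeLongestSide.get_preprocess_shape`, and `coco_bbox_to_prompt_box` is a hypothetical helper name. The matching binary mask for an annotation can be decoded with pycocotools' `COCO.annToMask`.

```python
from typing import List, Sequence, Tuple


def preprocess_shape(old_h: int, old_w: int, long_side: int = 1024) -> Tuple[int, int]:
    """Return (h, w) after resizing so the longest side equals `long_side`.

    Assumption: same rounding as SAM's ResizeLongestSide.get_preprocess_shape.
    """
    scale = long_side * 1.0 / max(old_h, old_w)
    new_h, new_w = old_h * scale, old_w * scale
    return int(new_h + 0.5), int(new_w + 0.5)


def coco_bbox_to_prompt_box(
    bbox: Sequence[float], orig_hw: Tuple[int, int], long_side: int = 1024
) -> List[float]:
    """Convert a COCO [x, y, w, h] box in original pixels to [x0, y0, x1, y1]
    in the resized-image frame (hypothetical helper)."""
    x, y, w, h = bbox
    old_h, old_w = orig_hw
    new_h, new_w = preprocess_shape(old_h, old_w, long_side)
    # Per-axis scale factors, as SAM applies them to point/box coordinates.
    sx, sy = new_w / old_w, new_h / old_h
    return [x * sx, y * sy, (x + w) * sx, (y + h) * sy]
```

For a 480 x 640 image, `coco_bbox_to_prompt_box([10, 20, 30, 40], (480, 640))` returns approximately `[16.0, 32.0, 64.0, 96.0]`, because the image is resized to 768 x 1024 and both axes are scaled by 1.6.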

Specifically, how can the original TinySAM be fine-tuned for new tasks in few-shot scenarios?

NTUZZH, Jan 23 '25 07:01

I'm interested as well!

MauroAndretta, May 09 '25 13:05


Hi, I used LoRA to fine-tune the image encoder while keeping all other parameters frozen. It works; the training code is custom-written.

NTUZZH, May 12 '25 07:05
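A minimal sketch of the approach described above (LoRA adapters on the image encoder, everything else frozen), in plain PyTorch. This is not NTUZZH's code, which wasn't shared. It assumes the encoder's attention blocks expose their query/key/value projection as an `nn.Linear` attribute named `qkv`, as in the original SAM ViT and TinyViT code; printing `model.image_encoder.named_modules()` would confirm the actual names. `LoRALinear`, `add_lora` and `ToyEncoder` are hypothetical names, and `ToyEncoder` only stands in for a real encoder.

```python
import math

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Frozen nn.Linear plus a trainable low-rank update: y = W x + b + (alpha / r) * B(A(x))."""

    def __init__(self, base: nn.Linear, r: int = 4, alpha: float = 8.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # pretrained weights stay fixed
        self.lora_a = nn.Linear(base.in_features, r, bias=False)
        self.lora_b = nn.Linear(r, base.out_features, bias=False)
        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b.weight)  # B = 0, so the wrapped layer starts out unchanged
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.lora_b(self.lora_a(x)) * self.scaling


def add_lora(module: nn.Module, target: str = "qkv", r: int = 4, alpha: float = 8.0) -> nn.Module:
    """Recursively replace every nn.Linear child named `target` with a LoRALinear wrapper."""
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear) and name == target:
            setattr(module, name, LoRALinear(child, r, alpha))
        else:
            add_lora(child, target, r, alpha)
    return module


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for an image encoder block with a `qkv` projection."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(3 * dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(self.qkv(x))
```

With a real model the order matters: first freeze everything (`for p in tiny_sam.parameters(): p.requires_grad = False`), then call `add_lora(tiny_sam.image_encoder)`, because the new LoRA layers are created with `requires_grad=True`. The optimizer would then receive only `[p for p in tiny_sam.parameters() if p.requires_grad]`. Since `lora_b` is zero-initialised, the adapted model reproduces the pretrained outputs before the first update.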

Could you please share the code you used to train the model?

I've done something like:

```python
num_epochs = 100
tiny_sam.to(DEVICE)

tiny_sam.train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):
        # preparing the batch as expected by the model
        batched_input = [{
            'image': batch["pixel_values"][0].to(DEVICE),
            'input_boxes': batch["input_boxes"][0].to(DEVICE),
        }]

        # forward pass
        outputs = tiny_sam(batched_input)

        # compute loss
        predicted_masks = outputs.pred_masks.squeeze(1)
        ground_truth_masks = batch["ground_truth_mask"].float().to(DEVICE)
        loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))

        # backward pass (compute gradients of parameters w.r.t. loss)
        optimizer.zero_grad()
        loss.backward()

        # optimize
        optimizer.step()
        epoch_losses.append(loss.item())

    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')
```

but I got this error:

```
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[488], line 16
     10 batched_input = [{
     11    'image': batch["pixel_values"][0].to(DEVICE),
     12    'input_boxes': batch["input_boxes"][0].to(DEVICE),
     13 }]
     15 # forward pass
---> 16 outputs = tiny_sam(batched_input)
     18 # compute loss
     19 predicted_masks = outputs.pred_masks.squeeze(1)

File c:\Users\clinica_medica\anaconda3\envs\model_training\Lib\site-packages\torch\nn\modules\module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File c:\Users\clinica_medica\anaconda3\envs\model_training\Lib\site-packages\torch\nn\modules\module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File c:\Users\clinica_medica\anaconda3\envs\model_training\Lib\site-packages\torch\utils\_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File c:\Users\clinica_medica\Desktop\SEGMENTAZIONE - QUIPU\dl-image-segmentation\notebooks\model_training\TinySAM\tinysam\modeling\sam.py:118, in Sam.forward(self, batched_input)
    104 sparse_embeddings, dense_embeddings = self.prompt_encoder(
    105     points=points,
    106     boxes=image_record.get("boxes", None),
    107     masks=image_record.get("mask_inputs", None),
    108 )
    109 low_res_masks, iou_predictions = self.mask_decoder(
    110     image_embeddings=curr_embedding.unsqueeze(0),
    111     image_pe=self.prompt_encoder.get_dense_pe(),
    112     sparse_prompt_embeddings=sparse_embeddings,
    113     dense_prompt_embeddings=dense_embeddings,
    114 )
    115 masks = self.postprocess_masks(
    116     low_res_masks,
    117     input_size=image_record["image"].shape[-2:],
--> 118     original_size=image_record["original_size"],
    119 )
    120 masks = masks > self.mask_threshold
    121 outputs.append(
    122     {
    123         "masks": masks,
   (...)    126     }
    127 )

KeyError: 'original_size'
```

MauroAndretta, May 12 '25 14:05
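For reference: the `KeyError` comes from `Sam.forward` reading `image_record["original_size"]`, a key the batched input above does not provide. The traceback also shows the box prompt is read from `"boxes"`, so the `"input_boxes"` key is silently ignored. Adding `"original_size"` and renaming the box key would get past this error, but the loop would still fail a few lines later: the `decorate_context` frame is consistent with `forward` being wrapped in `@torch.no_grad()` (as in the original SAM code), and `forward` returns a list of dicts whose `"masks"` are already thresholded to booleans, not an object with `.pred_masks`. A common workaround is to call the submodules directly so gradients flow. The sketch below assumes TinySAM keeps SAM's `preprocess`, `image_encoder`, `prompt_encoder`, `mask_decoder` and `postprocess_masks` with the signatures visible in the traceback; `tinysam_train_step` is a hypothetical helper, not part of the repository.

```python
import torch


def tinysam_train_step(
    sam: torch.nn.Module,
    image: torch.Tensor,
    box: torch.Tensor,
    gt_mask: torch.Tensor,
    optimizer: torch.optim.Optimizer,
    loss_fn,
) -> float:
    """One differentiable update that bypasses Sam.forward (hypothetical helper).

    Assumed inputs:
      image:   (3, H, W) float tensor, raw pixel values, longest side already resized to the model input size
      box:     (1, 4) XYXY box prompt in the same resized frame
      gt_mask: (H0, W0) binary mask at the original image resolution
      loss_fn: a loss on raw logits, e.g. nn.BCEWithLogitsLoss()
    """
    # Normalise and pad as Sam.forward does, then add a batch dimension.
    x = sam.preprocess(image).unsqueeze(0)
    image_embeddings = sam.image_encoder(x)
    sparse, dense = sam.prompt_encoder(points=None, boxes=box, masks=None)
    low_res_logits, _ = sam.mask_decoder(
        image_embeddings=image_embeddings,
        image_pe=sam.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse,
        dense_prompt_embeddings=dense,
    )
    # Keep a single mask channel; taking the first one is an assumption
    # (selecting by the predicted IoU is another option).
    logits = low_res_logits[:, :1]
    # postprocess_masks only interpolates and crops, so it stays differentiable.
    # Unlike Sam.forward, the result is NOT thresholded: the loss needs logits.
    logits = sam.postprocess_masks(
        logits, input_size=image.shape[-2:], original_size=gt_mask.shape[-2:]
    )
    loss = loss_fn(logits, gt_mask[None, None].float())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

Inside the loop above, each batch would then call something like `tinysam_train_step(tiny_sam, image, box, gt_mask, optimizer, seg_loss)`, provided `seg_loss` expects logits. If the dataloader comes from a Hugging Face `SamProcessor`, as the key names suggest, `pixel_values` may already be normalised and padded, in which case the `sam.preprocess` call should be dropped to avoid normalising twice.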