
No object detected in inference after fine-tuning on custom dataset

simoneriggi opened this issue 4 months ago • 5 comments

Dear all, I am fine-tuning SAM3 on a custom astronomical dataset in COCO format. The dataset contains 5 object classes (custom names) with segmentation masks and bounding boxes, and I have added a "noun_phrase" to each annotation. I adapted the Roboflow training configuration file, making these changes:

  • set dataset and log paths
  • enabled segmentation (losses and metrics)
  • adapted the collator to support runs with gradient accumulation > 1
  • added sam3.train.transforms.segmentation.DecodeRle to the validation transforms
  • set Slurm job parameters (batch=8, gradacc=8, 4 A100 GPUs)
  • commented out some Roboflow settings (supercategory, task array)
  • added code support for freezing the backbone or other components (the freeze_cfg block, commented out below)
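A minimal sketch of the per-annotation "noun_phrase" field described above, assuming a plain COCO JSON dict and a user-chosen phrase per category that falls back to the category name. The helper `add_noun_phrases` and the tiny example dataset are hypothetical, not part of the sam3 API.

```python
import json


def add_noun_phrases(coco, phrase_by_category=None):
    """Attach a "noun_phrase" string to every annotation of a COCO-style dict.

    phrase_by_category maps category_id -> phrase; categories missing from it
    fall back to the COCO category "name".
    """
    phrase_by_category = phrase_by_category or {}
    names = {cat["id"]: cat["name"] for cat in coco["categories"]}
    for ann in coco["annotations"]:
        cat_id = ann["category_id"]
        ann["noun_phrase"] = phrase_by_category.get(cat_id, names[cat_id])
    return coco


# Tiny made-up dataset: one image, two categories, two boxes
coco = {
    "images": [{"id": 1, "file_name": "sidelobe0001.png", "width": 128, "height": 128}],
    "categories": [{"id": 1, "name": "sidelobe"}, {"id": 2, "name": "galaxy"}],
    "annotations": [
        {"id": 1, "image_id": 1, "category_id": 1, "bbox": [10, 10, 20, 20]},
        {"id": 2, "image_id": 1, "category_id": 2, "bbox": [50, 50, 30, 30]},
    ],
}
add_noun_phrases(coco, {2: "extended radio galaxy"})
print(json.dumps([a["noun_phrase"] for a in coco["annotations"]]))
```

Whether sam3's dataset loader reads the phrase from exactly this key is an assumption taken from the wording of the issue.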

I fine-tuned for 20 epochs, once with all model components trainable (840M trainable parameters) and once with the backbone frozen (32.7M trainable parameters), keeping the training configuration (learning rate, etc.) at the Roboflow defaults. Below are the final loss and evaluation metrics of the two runs:

--> Full fine-tuning Image Meters: {'Meters_train/val_roboflow100/detection/coco_eval_bbox_AP': 0.6391414438821105, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AP_50': 0.8176683410036266, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AP_75': 0.7222655357848473, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AP_small': 0.6245980936200656, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AP_medium': 0.5391393291034561, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AP_large': 0.95, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AR_maxDets@1': 0.575984978215531, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AR_maxDets@10': 0.7667850931014953, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AR_maxDets@100': 0.790897661111249, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AR_small': 0.7862146778761182, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AR_medium': 0.686418844156647, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AR_large': 0.95, 'Losses/val_all_loss': 0, 'Losses/val_default_loss': 0, 'Losses/val_roboflow100_core_loss': 0.0, 'Trainer/where': 0.9997907949790795, 'Trainer/epoch': 19, 'Trainer/steps_val': 97090}

--> Backbone frozen Image Meters: {'Meters_train/val_roboflow100/detection/coco_eval_bbox_AP': 0.5608939549867914, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AP_50': 0.758768676001815, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AP_75': 0.6316128247121221, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AP_small': 0.5385709234878986, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AP_medium': 0.4265914293746788, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AP_large': 0.95, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AR_maxDets@1': 0.5454540602304996, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AR_maxDets@10': 0.7218796950563996, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AR_maxDets@100': 0.7489443844002187, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AR_small': 0.744400135157136, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AR_medium': 0.6562509361033513, 'Meters_train/val_roboflow100/detection/coco_eval_bbox_AR_large': 0.95, 'Losses/val_all_loss': 0, 'Losses/val_default_loss': 0, 'Losses/val_roboflow100_core_loss': 0.0, 'Trainer/where': 0.9997907949790795, 'Trainer/epoch': 19, 'Trainer/steps_val': 97090}
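To compare the two runs side by side, a small sketch that strips the long meter prefix and prints the differences. The values are a subset copied verbatim from the two dicts above; the `compare` helper is hypothetical.

```python
PREFIX = "Meters_train/val_roboflow100/detection/coco_eval_bbox_"

# A subset of the meters copied from the two runs above
full_ft = {
    PREFIX + "AP": 0.6391414438821105,
    PREFIX + "AP_50": 0.8176683410036266,
    PREFIX + "AP_75": 0.7222655357848473,
    PREFIX + "AR_maxDets@100": 0.790897661111249,
}
frozen = {
    PREFIX + "AP": 0.5608939549867914,
    PREFIX + "AP_50": 0.758768676001815,
    PREFIX + "AP_75": 0.6316128247121221,
    PREFIX + "AR_maxDets@100": 0.7489443844002187,
}


def compare(a, b, prefix=PREFIX):
    """Return {short_name: (a, b, a - b)} for prefixed keys present in both dicts."""
    return {
        key[len(prefix):]: (a[key], b[key], a[key] - b[key])
        for key in a
        if key.startswith(prefix) and key in b
    }


deltas = compare(full_ft, frozen)
for name, (x, y, d) in deltas.items():
    print(f"{name:>16}: full={x:.3f} frozen={y:.3f} delta={d:+.3f}")
```

On these numbers, full fine-tuning leads by roughly 0.08 bbox AP and is ahead on every listed metric.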

As far as I can tell from the metrics and loss, the model is learning something, although I certainly need to train longer and with tuned hyperparameters. Now I would like to run inference on a single image using the fine-tuned checkpoint and the example notebook https://github.com/facebookresearch/sam3/blob/main/examples/sam3_image_predictor_example.ipynb. However, when I run the inference script on train/eval images, using the same noun_phrase prompt and a low confidence threshold (0.1), no objects are detected.

When I load the model, a log message reports that many model keys are missing:

loaded [RUN DIR]/checkpoints/checkpoint.pt and found missing and/or unexpected keys: missing_keys=['backbone.vision_backbone.trunk.pos_embed', 'backbone.vision_backbone.trunk.patch_embed.proj.weight', 'backbone.vision_backbone.trunk.blocks.0.norm1.weight', 'backbone.vision_backbone.trunk.blocks.0.norm1.bias', 'backbone.vision_backbone.trunk.blocks.0.attn.freqs_cis', 'backbone.vision_backbone.trunk.blocks.0.attn.qkv.weight', 'backbone.vision_backbone.trunk.blocks.0.attn.qkv.bias', 'backbone.vision_backbone.trunk.blocks.0.attn.proj.weight', 'backbone.vision_backbone.trunk.blocks.0.attn.proj.bias', 'backbone.vision_backbone.trunk.blocks.0.norm2.weight', 'backbone.vision_backbone.trunk.blocks.0.norm2.bias', 'backbone.vision_backbone.trunk.blocks.0.mlp.fc1.weight', ... ...
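One quick way to see what a warning like this is about is to count checkpoint keys by their leading prefix. This is a minimal sketch: `prefix_counts` is a hypothetical helper, the key names come from the warning above, and the "detector."-prefixed names only mimic the naming of the original weights that is reported later in this thread.

```python
from collections import Counter


def prefix_counts(keys, depth=1):
    """Count state-dict keys by their first `depth` dot-separated components."""
    return Counter(".".join(key.split(".")[:depth]) for key in keys)


# Key names taken from the warning above; the "detector." form is illustrative
finetuned_keys = [
    "backbone.vision_backbone.trunk.pos_embed",
    "backbone.vision_backbone.trunk.patch_embed.proj.weight",
]
reference_keys = ["detector." + k for k in finetuned_keys]

print(prefix_counts(finetuned_keys))   # Counter({'backbone': 2})
print(prefix_counts(reference_keys))   # Counter({'detector': 2})

# With real checkpoints (assumption: trainer checkpoints keep weights under "model"):
#   import torch
#   ckpt = torch.load(path, map_location="cpu")
#   print(prefix_counts(ckpt["model"].keys()))
```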

Could someone give me a hint about what I am doing wrong? Is it a matter of fine-tuning (category embeddings, hyperparameters, etc.), of how I run inference (input normalization/transforms), or both?
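On the normalization side of the question: both the train and val pipelines in the config below use NormalizeAPI with mean = std = [0.5, 0.5, 0.5]. Assuming ToTensorAPI first scales pixels to [0, 1] (as torchvision's ToTensor does; not verified for sam3), pixel values end up in [-1, 1]. A NumPy sketch of that mapping, which can be compared against whatever preprocessing the inference path applies; the function name is hypothetical.

```python
import numpy as np


def normalize_like_config(img_uint8, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Scale an HxWx3 uint8 image to [0, 1], then apply (x - mean) / std per channel."""
    x = img_uint8.astype(np.float32) / 255.0
    mean = np.asarray(mean, dtype=np.float32)
    std = np.asarray(std, dtype=np.float32)
    return (x - mean) / std  # broadcasts over the last (channel) axis


# Black and white pixels map to -1 and +1 with mean = std = 0.5
img = np.array([[[0, 0, 0], [255, 255, 255]]], dtype=np.uint8)
print(normalize_like_config(img))
```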

Thanks a lot.

PS: My config file and inference script are reported below:

CONFIG FILE

# @package _global_
defaults:
  - _self_

# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
paths:
  roboflow_vl_100_root: [DATASET ROOT DIR]
  experiment_log_dir: [RUN LOG DIR]
  bpe_path: [SAM PATH]/sam3/assets/bpe_simple_vocab_16e6.txt.gz

#freeze_cfg:
#  backbone: true
#  backbone_blocks: 0      # e.g. freeze first N blocks; 0 = none
#  text_encoder: true
#  transformer: false
#  segmentation_head: false

# Roboflow dataset configuration
roboflow_train:
  num_images: null # Note: This is the number of images used for training. If null, all images are used.

  # Training transforms pipeline
  train_transforms:
    - _target_: sam3.train.transforms.basic_for_api.ComposeAPI
      transforms:
        - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
          query_filter:
            _target_: sam3.train.transforms.filter_query_transforms.FilterCrowds
        - _target_: sam3.train.transforms.point_sampling.RandomizeInputBbox
          box_noise_std: 0.1
          box_noise_max: 20
        - _target_: sam3.train.transforms.segmentation.DecodeRle
        - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
          sizes:
            _target_: sam3.train.transforms.basic.get_random_resize_scales
            size: ${scratch.resolution}
            min_size: 480
            rounded: false
          max_size:
            _target_: sam3.train.transforms.basic.get_random_resize_max_size
            size: ${scratch.resolution}
          square: true
          consistent_transform: ${scratch.consistent_transform}
        - _target_: sam3.train.transforms.basic_for_api.PadToSizeAPI
          size: ${scratch.resolution}
          consistent_transform: ${scratch.consistent_transform}
        - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
        - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
          query_filter:
            _target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
        - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
          mean: ${scratch.train_norm_mean}
          std: ${scratch.train_norm_std}
        - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
          query_filter:
            _target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
    - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
      query_filter:
        _target_: sam3.train.transforms.filter_query_transforms.FilterFindQueriesWithTooManyOut
        max_num_objects: ${scratch.max_ann_per_img}

  # Validation transforms pipeline
  val_transforms:
    - _target_: sam3.train.transforms.basic_for_api.ComposeAPI
      transforms:
        # 1) Decode COCO RLE/poly into mask tensors
        - _target_: sam3.train.transforms.segmentation.DecodeRle

        # 2) Resize image + masks
        - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
          sizes: ${scratch.resolution}
          max_size:
            _target_: sam3.train.transforms.basic.get_random_resize_max_size
            size: ${scratch.resolution}
          square: true
          consistent_transform: False

        # 3) Convert to torch tensors
        - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI

        # 4) Normalize
        - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
          mean: ${scratch.train_norm_mean}
          std: ${scratch.train_norm_std}

  # NOTE: Loss to be used for training in case of segmentation
  loss:
     _target_: sam3.train.loss.sam3_loss.Sam3LossWrapper
     matcher: ${scratch.matcher}
     o2m_weight: 2.0
     o2m_matcher:
       _target_: sam3.train.matcher.BinaryOneToManyMatcher
       alpha: 0.3
       threshold: 0.4
       topk: 4
     use_o2m_matcher_on_o2m_aux: false
     loss_fns_find:
       - _target_: sam3.train.loss.loss_fns.Boxes
         weight_dict:
           loss_bbox: 5.0
           loss_giou: 2.0
       - _target_: sam3.train.loss.loss_fns.IABCEMdetr
         weak_loss: False
         weight_dict:
           loss_ce: 20.0 # Another option is 100.0
           presence_loss: 20.0
         pos_weight: 10.0 # Another option is 5.0
         alpha: 0.25
         gamma: 2
         use_presence: True  # Change
         pos_focal: false
         pad_n_queries: 200
         pad_scale_pos: 1.0
       - _target_: sam3.train.loss.loss_fns.Masks
         focal_alpha: 0.25
         focal_gamma: 2.0
         weight_dict:
           loss_mask: 200.0
           loss_dice: 10.0
         compute_aux: false
     loss_fn_semantic_seg:
       #_target_: sam3.losses.loss_fns.SemanticSegCriterion
       _target_: sam3.train.loss.loss_fns.SemanticSegCriterion
       presence_head: True
       presence_loss: False  # Change
       focal: True
       focal_alpha: 0.6
       focal_gamma: 2.0
       downsample: False
       weight_dict:
         loss_semantic_seg: 20.0
         loss_semantic_presence: 1.0
         loss_semantic_dice: 30.0
     scale_by_find_batch_size: ${scratch.scale_by_find_batch_size}

# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
  enable_segmentation: True # NOTE: used for mask loading (load_segmentation), collation (with_seg_masks) and the model's segmentation head
  # Model parameters
  d_model: 256
  pos_embed:
    _target_: sam3.model.position_encoding.PositionEmbeddingSine
    num_pos_feats: ${scratch.d_model}
    normalize: true
    scale: null
    temperature: 10000

  # Box processing
  use_presence_eval: True
  original_box_postprocessor:
    _target_: sam3.eval.postprocessors.PostProcessImage
    max_dets_per_img: -1  # infinite detections
    use_original_ids: true
    use_original_sizes_box: true
    use_presence: ${scratch.use_presence_eval}

  # Matcher configuration
  matcher:
    _target_: sam3.train.matcher.BinaryHungarianMatcherV2
    focal: true  # with `focal: true` it is equivalent to BinaryFocalHungarianMatcher
    cost_class: 2.0
    cost_bbox: 5.0
    cost_giou: 2.0
    alpha: 0.25
    gamma: 2
    stable: False
  scale_by_find_batch_size: True

  # Image processing parameters
  resolution: 1008
  consistent_transform: False
  max_ann_per_img: 200

  # Normalization parameters
  train_norm_mean: [0.5, 0.5, 0.5]
  train_norm_std: [0.5, 0.5, 0.5]
  val_norm_mean: [0.5, 0.5, 0.5]
  val_norm_std: [0.5, 0.5, 0.5]

  # Training parameters
  num_train_workers: 10
  num_val_workers: 0
  max_data_epochs: 20
  target_epoch_size: 1500
  hybrid_repeats: 1
  context_length: 2
  gather_pred_via_filesys: false

  # Learning rate and scheduler parameters
  lr_scale: 0.1
  lr_transformer: ${times:8e-4,${scratch.lr_scale}}
  lr_vision_backbone: ${times:2.5e-4,${scratch.lr_scale}}
  lr_language_backbone: ${times:5e-5,${scratch.lr_scale}}
  lrd_vision_backbone: 0.9
  wd: 0.1
  scheduler_timescale: 20
  scheduler_warmup: 20
  scheduler_cooldown: 20

  val_batch_size: 1
  collate_fn_val:
    _target_: sam3.train.data.collator.collate_fn_api
    _partial_: true
    repeats: ${scratch.hybrid_repeats}
    dict_key: roboflow100
    with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks!

  gradient_accumulation_steps: 8
  train_batch_size: 8

  ## USE THIS WITH GRAD ACC=1
  #collate_fn:
  #  _target_: sam3.train.data.collator.collate_fn_api
  #  _partial_: true
  #  repeats: ${scratch.hybrid_repeats}
  #  dict_key: all
  #  with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks!

  ## USE THIS WITH GRAD ACC>1
  collate_fn:
    _target_: sam3.train.data.collator.collate_fn_api_with_chunking
    _partial_: true
    repeats: ${scratch.hybrid_repeats}
    dict_key: all
    with_seg_masks: ${scratch.enable_segmentation}
    num_chunks: ${scratch.gradient_accumulation_steps}

# ============================================================================
# Trainer Configuration
# ============================================================================

trainer:

  _target_: sam3.train.trainer.Trainer
  skip_saving_ckpts: false
  empty_gpu_mem_cache_after_eval: True
  skip_first_val: True
  max_epochs: 20
  accelerator: cuda
  seed_value: 123
  val_epoch_freq: 1
  mode: train
  gradient_accumulation_steps: ${scratch.gradient_accumulation_steps}

  distributed:
    backend: nccl
    find_unused_parameters: True
    gradient_as_bucket_view: True

  loss:
    all: ${roboflow_train.loss}
    default:
      _target_: sam3.train.loss.sam3_loss.DummyLoss
    
  data:

    train:
      _target_: sam3.train.data.torch_dataset.TorchDataset
      dataset:
        _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
        limit_ids: ${roboflow_train.num_images}
        transforms: ${roboflow_train.train_transforms}
        load_segmentation: ${scratch.enable_segmentation}
        max_ann_per_img: 500000
        multiplier: 1
        max_train_queries: 50000
        max_val_queries: 50000
        training: true
        use_caching: False
        img_folder: ${paths.roboflow_vl_100_root}
        ann_file: ${paths.roboflow_vl_100_root}/dataset_sam_train.json

      shuffle: True
      batch_size: ${scratch.train_batch_size}
      num_workers: ${scratch.num_train_workers}
      pin_memory: True
      drop_last: True
      collate_fn: ${scratch.collate_fn}

    val:
      _target_: sam3.train.data.torch_dataset.TorchDataset
      dataset:
        _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
        load_segmentation: ${scratch.enable_segmentation}
        coco_json_loader:
          _target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
          include_negatives: true
          category_chunk_size: 2 # Note: You can increase this based on the memory of your GPU.
          _partial_: true
        img_folder: ${paths.roboflow_vl_100_root}
        ann_file: ${paths.roboflow_vl_100_root}/dataset_sam_val.json
        transforms: ${roboflow_train.val_transforms}
        max_ann_per_img: 100000
        multiplier: 1
        training: false

      shuffle: False
      batch_size: ${scratch.val_batch_size}
      num_workers: ${scratch.num_val_workers}
      pin_memory: True
      drop_last: False
      collate_fn: ${scratch.collate_fn_val}

  model:
    _target_: sam3.model_builder.build_sam3_image_model
    bpe_path: ${paths.bpe_path}
    device: cpus
    eval_mode: false
    enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation.
    checkpoint_path: [HF_HOME]/models/huggingface/hub/models--facebook--sam3/snapshots/3c879f39826c281e95690f02c7821c4de09afae7/sam3.pt

  freeze_cfg: ${freeze_cfg}  # requires the top-level freeze_cfg block above to be uncommented

  meters:
    val:
      roboflow100:
        detection:
          _target_: sam3.eval.coco_writer.PredictionDumper
          iou_type: "bbox"
          dump_dir: ${launcher.experiment_log_dir}/dumps
          merge_predictions: True
          postprocessor: ${scratch.original_box_postprocessor}
          gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
          maxdets: 100
          pred_file_evaluators:
            - _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators
              gt_path: ${paths.roboflow_vl_100_root}/dataset_sam_val.json
              tide: False
              iou_type: "bbox"

  optim:
    amp:
      enabled: True
      amp_dtype: bfloat16

    optimizer:
      _target_: torch.optim.AdamW

    gradient_clip:
      _target_: sam3.train.optim.optimizer.GradientClipper
      max_norm: 0.1
      norm_type: 2

    param_group_modifiers:
      - _target_: sam3.train.optim.optimizer.layer_decay_param_modifier
        _partial_: True
        layer_decay_value: ${scratch.lrd_vision_backbone}
        apply_to: 'backbone.vision_backbone.trunk'
        overrides:
          - pattern: '*pos_embed*'
            value: 1.0

    options:
      lr:
        - scheduler:  # transformer and class_embed
            _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
            base_lr: ${scratch.lr_transformer}
            timescale: ${scratch.scheduler_timescale}
            warmup_steps: ${scratch.scheduler_warmup}
            cooldown_steps: ${scratch.scheduler_cooldown}
        - scheduler:
            _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
            base_lr: ${scratch.lr_vision_backbone}
            timescale: ${scratch.scheduler_timescale}
            warmup_steps: ${scratch.scheduler_warmup}
            cooldown_steps: ${scratch.scheduler_cooldown}
          param_names:
            - 'backbone.vision_backbone.*'
        - scheduler:
            _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
            base_lr: ${scratch.lr_language_backbone}
            timescale: ${scratch.scheduler_timescale}
            warmup_steps: ${scratch.scheduler_warmup}
            cooldown_steps: ${scratch.scheduler_cooldown}
          param_names:
            - 'backbone.language_backbone.*'

      weight_decay:
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: ${scratch.wd}
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.0
          param_names:
            - '*bias*'
          module_cls_names: ['torch.nn.LayerNorm']

  checkpoint:
    save_dir: ${launcher.experiment_log_dir}/checkpoints
    save_freq: 0  # 0: only the last checkpoint is saved

  logging:
    tensorboard_writer:
      _target_: sam3.train.utils.logger.make_tensorboard_logger
      log_dir: ${launcher.experiment_log_dir}/tensorboard
      flush_secs: 120
      should_log: True
    wandb_writer: null
    log_dir: ${launcher.experiment_log_dir}/logs
    log_freq: 10

# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================

launcher:
  num_nodes: 1
  gpus_per_node: 4
  experiment_log_dir: ${paths.experiment_log_dir}

submitit:
  use_cluster: True
  account: XXX
  partition: XXX
  qos: XXX
  timeout_hour: 96
  name: sam3
  cpus_per_task: 8
  port_range: [10000, 65000]
  
# ============================================================================
# Available Roboflow Supercategories (for reference)
# ============================================================================

all_roboflow_supercategories:
  - -grccs
  - zebrasatasturias
  ...
  ...
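The freeze_cfg flags above are the author's own addition, and their implementation is not shown. As a hypothetical sketch only: map each flag to parameter-name prefixes and count what stays trainable. The backbone.vision_backbone and backbone.language_backbone prefixes appear in the optimizer param_names above; the transformer and segmentation_head prefixes are guesses, and backbone_blocks is ignored here.

```python
# Hypothetical flag -> parameter-name prefix mapping; only the first two
# prefixes appear in the config above, the other two are guesses.
FREEZE_PREFIXES = {
    "backbone": ("backbone.vision_backbone.",),
    "text_encoder": ("backbone.language_backbone.",),
    "transformer": ("transformer.",),
    "segmentation_head": ("segmentation_head.",),
}


def frozen_param_names(param_names, freeze_cfg, prefixes=FREEZE_PREFIXES):
    """Names of parameters whose boolean flag in freeze_cfg is true.

    Non-boolean entries such as backbone_blocks are skipped in this sketch.
    """
    frozen = set()
    for flag, enabled in freeze_cfg.items():
        if enabled is True and flag in prefixes:
            frozen.update(n for n in param_names if n.startswith(prefixes[flag]))
    return frozen


def count_trainable(named_numels, frozen):
    """Sum parameter counts over names that are not frozen."""
    return sum(numel for name, numel in named_numels if name not in frozen)


# Made-up parameter names and sizes for illustration
named_numels = [
    ("backbone.vision_backbone.trunk.pos_embed", 1000),
    ("backbone.language_backbone.encoder.weight", 500),
    ("transformer.decoder.weight", 30),
    ("segmentation_head.conv.weight", 7),
]
cfg = {"backbone": True, "backbone_blocks": 0, "text_encoder": True,
       "transformer": False, "segmentation_head": False}
frozen = frozen_param_names([n for n, _ in named_numels], cfg)
print(count_trainable(named_numels, frozen))

# With a real torch model (names may carry a "module." prefix under DDP):
#   for name, p in model.named_parameters():
#       if name in frozen:
#           p.requires_grad_(False)
```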

INFERENCE SCRIPT

import os
import sys
import matplotlib.pyplot as plt
import numpy as np

import sam3
from PIL import Image
from sam3 import build_sam3_image_model
from sam3.model.box_ops import box_xywh_to_cxcywh
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results
from sam3.train.transforms.basic_for_api import ComposeAPI, RandomResizeAPI, ToTensorAPI, NormalizeAPI
from sam3.model.position_encoding import PositionEmbeddingSine
from sam3.eval.postprocessors import PostProcessImage

import torch
import torchvision
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

#########################
##   MAIN
#########################
sam3_root = os.path.join(os.path.dirname(sam3.__file__), "..")
print("sam3_root")
print(sam3_root)
device = "cuda" if torch.cuda.is_available() else "cpu"

# - Build model
print("Loading model ...")
bpe_path = f"{sam3_root}/assets/bpe_simple_vocab_16e6.txt.gz"
checkpoint_path= "[RUN DIR]/checkpoints/checkpoint.pt"

model= build_sam3_image_model(
    bpe_path=bpe_path,
    device=device,
    eval_mode=True,
    checkpoint_path=checkpoint_path,
    load_from_HF=False,
    enable_segmentation=True,
    enable_inst_interactivity=False,
    compile=False,
)

# - Load image
print("Loading image ...")
image_path= "sidelobe0001.png"
image = Image.open(image_path).convert('RGB')
width, height = image.size
print(f"Image width={width}, height={height}")

# - Transform image
print("Transforming image ...")
resize_size= 1008
processor = Sam3Processor(model, resolution=resize_size, confidence_threshold=0.0)
inference_state = processor.set_image(image) ## Looking at the code, image is resized inside set_image method

# - Inference
prompt= "spurious source, imaging artefact, sidelobe"
processor.reset_all_prompts(inference_state)
inference_state= processor.set_text_prompt(state=inference_state, prompt=prompt)

masks, boxes, scores = inference_state["masks"], inference_state["boxes"], inference_state["scores"]

# - Draw results
img0 = Image.open(image_path)
plot_results(img0, inference_state)
plt.show()
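Before picking a threshold, it can help to print the highest-scoring detections returned by the script above. A minimal sketch in plain Python; the helper is hypothetical, and it assumes `inference_state["scores"]` and `inference_state["boxes"]` are tensors that can be converted with `.tolist()`.

```python
def top_detections(scores, boxes, k=5):
    """Return the k (score, box) pairs with the highest scores, best first."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [(float(scores[i]), list(boxes[i])) for i in order[:k]]


# Made-up values; with the script above one would pass
#   inference_state["scores"].tolist(), inference_state["boxes"].tolist()
scores = [0.002, 0.031, 0.0004, 0.018]
boxes = [[0, 0, 5, 5], [10, 10, 40, 40], [3, 3, 6, 6], [12, 8, 38, 41]]
for score, box in top_detections(scores, boxes, k=2):
    print(f"{score:.4f} {box}")
```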

simoneriggi, Dec 10 '25 12:12

I have the same question

noperfect-zhy, Dec 10 '25 14:12

I fixed the missing-keys issue when loading the model by following the suggestion in #270. The original model weights have parameter names prefixed with "detector.", while the fine-tuned checkpoint's keys do not, so the keys need to be renamed. I repost the code here:

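import torch
from collections import OrderedDict

# Fine-tuned checkpoint written by the trainer
checkpoint = "[RUN DIR]/checkpoints/checkpoint.pt"

# The trainer checkpoint stores the model weights under the "model" key
wrapped_model = torch.load(checkpoint, map_location="cpu")
model = wrapped_model["model"]

# Prepend "detector." to every key so the names match the original SAM3 weights
new_state_dict = OrderedDict(("detector." + k, v) for k, v in model.items())
torch.save(new_state_dict, checkpoint.replace(".pt", "_converted.pt"))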
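A slightly more defensive variant of the same renaming, as a sketch: it only prefixes keys that do not already start with "detector.", so running the conversion twice on the same file does not produce "detector.detector." names. The helper name and the second example key are hypothetical; the torch calls in the comments mirror the snippet above.

```python
from collections import OrderedDict


def add_prefix(state_dict, prefix="detector."):
    """Copy of state_dict with `prefix` prepended to every key that lacks it."""
    return OrderedDict(
        (key if key.startswith(prefix) else prefix + key, value)
        for key, value in state_dict.items()
    )


sd = OrderedDict([("backbone.vision_backbone.trunk.pos_embed", 1),
                  ("detector.transformer.weight", 2)])
once = add_prefix(sd)
twice = add_prefix(once)
print(list(twice))

# With the real checkpoint, as in the snippet above:
#   converted = add_prefix(torch.load(checkpoint, map_location="cpu")["model"])
#   torch.save(converted, checkpoint.replace(".pt", "_converted.pt"))
```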

When I use the checkpoint_converted.pt weights in the inference script, the warnings disappear and objects are detected. However, on every image I try, I get 200 objects with very small scores (~1e-6 to 1e-5). Below are the original image and the detection plot.

[Images: original image and detection plot]

Any hints? Should I run a longer fine-tuning, or is there some other check, fix, or post-processing step I can try?
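With 200 detections all scoring around 1e-6 to 1e-5, grouping the scores by order of magnitude gives a compact picture of the distribution. A plain-Python sketch; `score_decades` is a hypothetical helper and the sample scores are made up.

```python
import math
from collections import Counter


def score_decades(scores):
    """Count scores per power of ten: 3e-6 -> -6, 2e-5 -> -5; zeros go under None."""
    return Counter(math.floor(math.log10(s)) if s > 0 else None for s in scores)


print(score_decades([3e-6, 5e-6, 2e-5, 0.02, 0.0]))
```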

simoneriggi, Dec 10 '25 15:12

Hey Simon,

thank you for sharing your code, it's really helpful! Just a suggestion: can you try a confidence threshold of 0.5 instead? It may give you only the high-confidence boxes (I tried it in my case and it worked!)

MohammedAdelFahmi, Dec 10 '25 18:12

Hi @MohammedAdelFahmi, if I increase the confidence threshold to 0.5, I get no detections, as the scores of the 200 detected objects are in the range [0, 0.03]. If I apply a threshold of 0.02, this is what I get:

[Image: detections with a 0.02 confidence threshold]

This is going in the right direction, i.e. the highest-scoring objects correlate with the radio galaxy position, but I need to understand why the scores are so low and whether this is a matter of underfitting. This plot was obtained with the backbone frozen; with full fine-tuning, the confidence scores are even lower.
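The two thresholds discussed here are easy to reproduce in isolation. A sketch with made-up scores in the reported [0, 0.03] range: an absolute cut at 0.5 keeps nothing, while 0.02 keeps a few. A threshold relative to the image's top score is also shown as an inspection aid only; that variant is an addition, not something proposed in the thread, and it does not fix the low scores. Both helper names are hypothetical.

```python
def keep_above(scores, threshold):
    """Indices of scores >= an absolute threshold."""
    return [i for i, s in enumerate(scores) if s >= threshold]


def keep_relative(scores, fraction):
    """Indices of scores >= fraction * max(scores); empty input gives []."""
    if not scores:
        return []
    return keep_above(scores, fraction * max(scores))


scores = [0.0001, 0.004, 0.021, 0.025, 0.03]  # made-up, within [0, 0.03]
print(keep_above(scores, 0.5))    # [] : nothing survives
print(keep_above(scores, 0.02))
print(keep_relative(scores, 0.8))
```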

simoneriggi, Dec 10 '25 19:12

@simoneriggi I think your prompt is not correct. Try making a list of prompts and running them one by one, for example:

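prompts = ['spurious source', 'imaging artefact', 'sidelobe']
for prompt in prompts:
    processor.reset_all_prompts(inference_state)
    inference_state = processor.set_text_prompt(state=inference_state, prompt=prompt)
    # continue inference code here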
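When prompts are run one by one, the per-prompt results still need to be combined. A minimal sketch, assuming each prompt's scores and boxes have been collected into plain lists (e.g. via `.tolist()` on the tensors in `inference_state`); `merge_prompt_results` is hypothetical, and overlapping boxes from different prompts are not deduplicated (no NMS).

```python
def merge_prompt_results(results_by_prompt):
    """Flatten {prompt: (scores, boxes)} into (score, box, prompt) tuples, best first."""
    merged = []
    for prompt, (scores, boxes) in results_by_prompt.items():
        merged.extend((float(s), list(b), prompt) for s, b in zip(scores, boxes))
    merged.sort(key=lambda det: det[0], reverse=True)
    return merged


# Made-up per-prompt results
results = {
    "spurious source": ([0.01], [[0, 0, 4, 4]]),
    "imaging artefact": ([0.03, 0.002], [[5, 5, 9, 9], [1, 1, 2, 2]]),
    "sidelobe": ([0.02], [[10, 10, 20, 20]]),
}
for score, box, prompt in merge_prompt_results(results):
    print(f"{score:.3f} {prompt!r} {box}")
```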

MohammedAdelFahmi, Dec 10 '25 20:12

@simoneriggi Hi Simone, were you able to resolve the issue with diminishing confidence scores? I’m encountering the exact same problem with my custom dataset.

denizcan-pointr, Jan 05 '26 08:01