Load detection json for subsequent classification
Search before asking
- [x] I have searched the Pytorch-Wildlife issues and found no similar bug report.
Question
I have previously used batch_image_detection() to detect animals in a large set of images, and stored the output with the save_detection_json() function.
Now I would like to use the AI4GAmazonRainforest model to classify the images. Is there a way to load the JSON file and then start directly with the DetectionCrops() and batch_image_classification() functions?
Thanks a lot for your help.
Additional
No response
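For context, the save/load round trip in question can be sketched as follows. This is a minimal illustration with stdlib only; the dictionary contents and file name are made-up assumptions, since the real schema is whatever save_detection_json() writes in your Pytorch-Wildlife version:

```python
import json
import os
import tempfile

# Illustrative stand-in for the structure save_detection_json() might write;
# the exact keys depend on the Pytorch-Wildlife version.
detections = {
    "annotations": [
        {"img_id": "img_0001.jpg", "bbox": [[10, 20, 110, 220]], "category": [0]},
    ]
}

path = os.path.join(tempfile.mkdtemp(), "detections.json")
with open(path, "w") as f:
    json.dump(detections, f)

# Loading the file back yields a plain dict, not the in-memory detection-results
# variable that pw_data.DetectionCrops() normally expects.
with open(path) as f:
    loaded = json.load(f)

print(type(loaded).__name__)  # dict
```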
Hi @MattB-SF!
Typically, we load the detection crops for the classification model using the detection results variable (check our Custom Weight Loading demo as an example). However, once the detections are saved as a JSON file, loading it back yields a plain dictionary, which is incompatible with the pw_data.DetectionCrops() class. One solution is to modify the DetectionCrops class so that it loads the detections from a dictionary instead. Here is a modified version that I have used in the past:
```python
# Adapted from data/dataset.py
import os
import json
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
import supervision as sv


class DetectionCrops(Dataset):

    def __init__(self, detection_results, transform=None, path_head=None, animal_cls_id=0):
        """
        Args:
            detection_results (dict): Your loaded JSON file
            transform (callable, optional): Transformations (if needed)
            path_head (str, optional): Path prefix where the images are present
            animal_cls_id (int): Class ID that represents animals.
        """
        self.detection_results = detection_results
        self.transform = transform
        self.path_head = path_head
        self.animal_cls_id = animal_cls_id  # Defines which class represents animals.
        self.img_ids = []
        self.xyxys = []
        self.load_detection_results()

    def load_detection_results(self):
        """
        Load detection results and filter for animal detections.
        """
        for det in self.detection_results["annotations"]:
            img_id = det["img_id"]
            bboxes = np.array(det["bbox"])
            categories = np.array(det["category"])
            mask = categories == self.animal_cls_id
            filtered_bboxes = bboxes[mask]
            for bbox in filtered_bboxes:
                self.img_ids.append(img_id)
                self.xyxys.append(bbox)

    def __getitem__(self, idx):
        """
        Retrieves an image from the dataset.

        Parameters:
            idx (int): Index of the image to retrieve.

        Returns:
            tuple: Contains the cropped image and the image's path.
        """
        img_id = self.img_ids[idx]
        xyxy = self.xyxys[idx]
        img_path = os.path.join(self.path_head, img_id) if self.path_head else img_id
        img = sv.crop_image(np.array(Image.open(img_path).convert("RGB")), xyxy=xyxy)
        if self.transform:
            img = self.transform(Image.fromarray(img))
        return img, img_path

    def __len__(self):
        return len(self.img_ids)
```
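As a quick sanity check of the filtering step in load_detection_results(), here is a dependency-free sketch of the same logic. The annotations dictionary, image ids, and class ids below are made-up examples mimicking the assumed JSON layout, not real Pytorch-Wildlife output:

```python
# Hypothetical toy detection results mimicking the assumed JSON layout:
# each annotation carries an image id, a list of xyxy boxes, and their class ids.
detection_results = {
    "annotations": [
        {"img_id": "cam1/img_0001.jpg",
         "bbox": [[10, 20, 110, 220], [5, 5, 50, 50]],
         "category": [0, 1]},           # one assumed animal (class 0), one non-animal
        {"img_id": "cam1/img_0002.jpg",
         "bbox": [[30, 40, 130, 240]],
         "category": [0]},              # one assumed animal
    ]
}

def collect_animal_crops(detection_results, animal_cls_id=0):
    """Replicates DetectionCrops.load_detection_results without numpy:
    keep one (img_id, bbox) pair per detection of the animal class."""
    img_ids, xyxys = [], []
    for det in detection_results["annotations"]:
        for bbox, category in zip(det["bbox"], det["category"]):
            if category == animal_cls_id:
                img_ids.append(det["img_id"])
                xyxys.append(bbox)
    return img_ids, xyxys

img_ids, xyxys = collect_animal_crops(detection_results)
print(len(img_ids))  # 2 crops pass the animal filter
```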
If you follow the Custom Weight Loading demo, replace the original DetectionCrops class with this custom version and load the saved detections first:

```python
import json

with open("detections.json", "r") as f:
    detection_results = json.load(f)

dataset = DetectionCrops(detection_results, path_head="/path/to/images", animal_cls_id=1)
```
Please let us know if this works. If it does, we will add it as an additional utility class.
Sincerely,
Andrés.