DSFD-Pytorch-Inference
DSFD-Pytorch-Inference copied to clipboard
Incorrect face-detection URL / workaround
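The download link for the DSFD face-detector model weights is broken. Ideally it should be updated; in the meantime, here is my monkey-patch, which loads the weights from a local file (`models/WIDERFace_DSFD_RES152.pth`, relative to the current working directory) instead of downloading them. The referenced model can be downloaded from the FaceDetection-DSFD repository.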
import torch
import numpy as np
import typing
from .face_ssd import SSD
from .config import resnet152_model_config
from ..base import Detector
from ..build import DETECTOR_REGISTRY
import os
@DETECTOR_REGISTRY.register_module
class DSFDDetector(Detector):
    """DSFD face detector (ResNet-152 SSD) loaded from a local checkpoint.

    Workaround for the broken upstream weight URL: instead of downloading,
    the weights are read from disk. By default the checkpoint is expected at
    ``<current working directory>/models/WIDERFace_DSFD_RES152.pth``; pass
    ``model_path=...`` to load it from somewhere else.
    """

    # File name of the checkpoint looked up under ``<cwd>/models`` by default.
    MODEL_FILENAME = "WIDERFace_DSFD_RES152.pth"

    def __init__(self, *args, model_path: typing.Optional[str] = None, **kwargs):
        """Build the SSD network and load the DSFD weights into it.

        Args:
            *args: Forwarded unchanged to ``Detector.__init__``.
            model_path: Optional path to the ``.pth`` checkpoint. When ``None``
                (the default), ``<cwd>/models/WIDERFace_DSFD_RES152.pth`` is
                used, matching the original hard-coded location.
            **kwargs: Forwarded unchanged to ``Detector.__init__``.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        super().__init__(*args, **kwargs)
        face_model_path = self._resolve_model_path(model_path)
        # Load onto CPU first so the checkpoint can be read regardless of the
        # device it was saved from; the net is moved to self.device below.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(face_model_path, map_location=torch.device("cpu"))
        self.net = SSD(resnet152_model_config)
        self.net.load_state_dict(state_dict)
        self.net.eval()
        # self.device is presumably set by Detector.__init__ -- confirm in ..base.
        self.net = self.net.to(self.device)

    @classmethod
    def _resolve_model_path(cls, model_path: typing.Optional[str]) -> str:
        """Return the checkpoint path to load, failing early with a helpful message."""
        if model_path is None:
            # Resolved at construction time, relative to the current working directory.
            model_path = os.path.join(os.getcwd(), "models", cls.MODEL_FILENAME)
        if not os.path.isfile(model_path):
            raise FileNotFoundError(
                f"DSFD weights not found at {model_path!r}. Download "
                f"{cls.MODEL_FILENAME} (e.g. from the FaceDetection-DSFD repository) "
                "into ./models or pass model_path=... to DSFDDetector."
            )
        return model_path

    @torch.no_grad()
    def _detect(self, x: torch.Tensor) -> typing.List[np.ndarray]:
        """Run batched face detection.

        Args:
            x: Image batch with channels on dim 1 (the code indexes
                ``x[:, [2, 1, 0], :, :]``), presumably ``[N, 3, H, W]`` in RGB
                order -- TODO confirm against ``Detector``.

        Returns:
            Whatever ``self.net`` returns; per the original contract, a list of
            length N whose elements have shape ``[num_boxes, 5]``.
        """
        # The network expects BGR: reverse the channel axis
        # (RGB -> BGR, assuming the caller supplies RGB).
        x = x[:, [2, 1, 0], :, :]
        # Mixed precision only when fp16_inference is enabled; fp16_inference and
        # the thresholds below are presumably set by Detector.__init__.
        with torch.cuda.amp.autocast(enabled=self.fp16_inference):
            boxes = self.net(x, self.confidence_threshold, self.nms_iou_threshold)
        return boxes