
SAM: add support for box inputs

Open xenova opened this issue 2 years ago • 3 comments

Code:

import { SamModel, AutoProcessor, RawImage } from '@xenova/transformers';

const model = await SamModel.from_pretrained('Xenova/slimsam-77-uniform', {
    revision: 'boxes',
});
const processor = await AutoProcessor.from_pretrained('Xenova/slimsam-77-uniform', {
    revision: 'boxes',
});

const img_url = 'https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png';
const raw_image = await RawImage.read(img_url);
const input_boxes = [[[650, 900, 1000, 1250]]];

const inputs = await processor(raw_image, null, null, input_boxes);
const outputs = await model(inputs);

const masks = await processor.post_process_masks(outputs.pred_masks, inputs.original_sizes, inputs.reshaped_input_sizes);
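The `input_boxes` argument above is nested three levels deep: one list per image, one entry per box, and each box is `[x_min, y_min, x_max, y_max]` in pixels of the original image. The car box `[650, 900, 1000, 1250]` follows that layout. If boxes come from a detector or a UI in `{x, y, width, height}` form, a small converter can build the nested array. The sketch below uses a hypothetical helper name, `toInputBoxes`, that is not part of transformers.js. Passing more than one box per image is untested here.

```javascript
// Hypothetical helper (not part of transformers.js): converts {x, y, width, height}
// rectangles, in pixels of the original image, into the
// [image][box][x_min, y_min, x_max, y_max] nesting used for `input_boxes` above.
function toInputBoxes(rectsPerImage) {
    return rectsPerImage.map((rects) =>
        rects.map(({ x, y, width, height }) => {
            if (width <= 0 || height <= 0) {
                throw new Error(`Invalid box size: ${width}x${height}`);
            }
            return [x, y, x + width, y + height];
        }),
    );
}

// The car box from the snippet: top-left corner (650, 900), 350 x 350 pixels.
const boxes = toInputBoxes([[{ x: 650, y: 900, width: 350, height: 350 }]]);
console.log(JSON.stringify(boxes)); // [[[650,900,1000,1250]]]
```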

Input: [image not reproduced]

Visualization:

const image = RawImage.fromTensor(masks[0][0].mul(255));
await image.save('img.png');

[Output image not reproduced]

xenova · Feb 04 '24 23:02


Thanks for all your work, Xenova! I'm looking forward to this; it's very useful when segmenting parts of a larger whole, particularly for things that are rare in SA-1B.

In case it's helpful to anyone, here's the code snippet modified to compute the image embeddings in a separate step (so they can be reused):

Code
import { SamModel, AutoProcessor, RawImage } from '@xenova/transformers';

const model = await SamModel.from_pretrained('Xenova/slimsam-77-uniform', {
    revision: 'boxes',
});
const processor = await AutoProcessor.from_pretrained('Xenova/slimsam-77-uniform', {
    revision: 'boxes',
});

const img_url = 'https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png';
const raw_image = await RawImage.read(img_url);
const input_boxes = [[[650, 900, 1000, 1250]]];

// Run the (expensive) image encoder once; the embeddings can be reused for many prompts
const imageInputs = await processor(raw_image);
const imageEmbeddings = await model.get_image_embeddings(imageInputs);

// With precomputed embeddings, only the prompt encoder and mask decoder need to run
const outputs = await model({
    ...imageEmbeddings,
    input_points: null,
    input_labels: null,
    input_boxes: processor.reshape_input_points(
        input_boxes,
        imageInputs.original_sizes,
        imageInputs.reshaped_input_sizes,
        true, // the inputs are bounding boxes, not points
    ),
});

const masks = await processor.post_process_masks(outputs.pred_masks, imageInputs.original_sizes, imageInputs.reshaped_input_sizes);
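Because `get_image_embeddings` is the expensive image-encoder pass, the same embeddings can be reused for every box prompt on one image. Below is a minimal sketch, assuming the `model`, `processor`, `imageInputs` and `imageEmbeddings` objects from the snippet above; `segmentEachBox` is a hypothetical name, and it runs one box at a time rather than as a batch, since batching boxes is still an open question in this thread.

```javascript
// Hypothetical helper: runs the prompt encoder + mask decoder once per box,
// reusing precomputed image embeddings. `model`, `processor`, `imageInputs`
// and `imageEmbeddings` are assumed to come from the snippet above.
async function segmentEachBox(model, processor, imageInputs, imageEmbeddings, boxes) {
    const results = [];
    for (const box of boxes) {
        const outputs = await model({
            ...imageEmbeddings,
            input_points: null,
            input_labels: null,
            // one image, one box: [[[x_min, y_min, x_max, y_max]]]
            input_boxes: processor.reshape_input_points(
                [[box]],
                imageInputs.original_sizes,
                imageInputs.reshaped_input_sizes,
                true, // the inputs are bounding boxes, as in the snippet above
            ),
        });
        const masks = await processor.post_process_masks(
            outputs.pred_masks,
            imageInputs.original_sizes,
            imageInputs.reshaped_input_sizes,
        );
        // keep the predicted IoU scores so the best candidate mask can be picked later
        results.push({ box, masks, iou_scores: outputs.iou_scores });
    }
    return results;
}
```

Processing boxes sequentially keeps memory use predictable; the per-box cost is then dominated by the decoder and post-processing rather than the encoder.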

A couple of questions (I can move these to a separate issue if that would be preferable):

  • Is it possible to pass batches of bounding boxes in the input_boxes array? The nesting and the variable name suggest you could, and I think it works with points (although I haven't explored that much), but it crashes if I try to pass in more than one bbox.
  • post_process_masks accounts for the vast majority (~80%) of my script's runtime. Is this expected, or is there something I'm doing wrong or could do to speed it up? Here's a performance profile from Firefox. I'm running that segmentHold function a bunch of times in a loop because I have many instances to segment.
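To check where the time goes (for example, whether post-processing really dominates), each stage can be wrapped in a small timer. This is a generic sketch using `performance.now()`, which is available in browsers and as a global in recent Node.js versions; `timed` and `reportShares` are hypothetical helper names, not part of transformers.js.

```javascript
// Hypothetical helper: awaits `fn()` and adds its wall-clock time (ms) to timings[label].
async function timed(timings, label, fn) {
    const start = performance.now();
    try {
        return await fn();
    } finally {
        timings[label] = (timings[label] ?? 0) + (performance.now() - start);
    }
}

// Formats each stage's share of the total measured time, largest first.
function reportShares(timings) {
    const total = Object.values(timings).reduce((a, b) => a + b, 0);
    return Object.entries(timings)
        .sort((a, b) => b[1] - a[1])
        .map(([label, ms]) => `${label}: ${ms.toFixed(1)} ms (${((100 * ms) / total).toFixed(0)}%)`);
}

// Usage inside a per-box loop, with names from the snippets in this thread:
// const timings = {};
// const outputs = await timed(timings, 'decoder', () => model({ ...imageEmbeddings, /* box prompt */ }));
// const masks = await timed(timings, 'post_process_masks', () =>
//     processor.post_process_masks(outputs.pred_masks, imageInputs.original_sizes, imageInputs.reshaped_input_sizes));
// console.log(reportShares(timings).join('\n'));
```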

jwlarocque · Feb 22 '24 02:02

Edit: Fixed memory leaks in the snippet.

I put together a version of post_process_masks with higher throughput. It doesn't cover the breadth of use cases that the built-in one does, and it requires a heavy dependency (opencv.js), but it suits my use case. It also converts the mask to contours, but that step could be skipped.

Code
// note: takes only a single mask, e.g. outputs.pred_masks[0][0][best_mask_index]
// requires opencv.js to be loaded (global `cv`)
function post_process_mask(mask, original_size, reshaped_input_size, pad_size) {
    // mask: [256, 256] low-resolution mask logits
    let mat_a = cv.matFromArray(256, 256, cv.CV_32FC1, mask.data);
    // upscale mask to padded size (cv.Size takes width, then height)
    let padded_size = new cv.Size(pad_size.width, pad_size.height);
    let mat_b = new cv.Mat();
    cv.resize(mat_a, mat_b, padded_size, 0, 0, cv.INTER_LINEAR);
    // crop away the padding; the ROI is a view that shares mat_b's data
    let roi = new cv.Rect(0, 0, reshaped_input_size[1], reshaped_input_size[0]);
    mat_a.delete(); // free the 256x256 input before reusing the variable
    mat_a = mat_b.roi(roi);
    // downscale mask to the original image size
    let downscaled_size = new cv.Size(original_size[1], original_size[0]);
    cv.resize(mat_a, mat_b, downscaled_size, 0, 0, cv.INTER_LINEAR);

    // note: You could stop here and skip the contour detection if you didn't need it
    //       (the opencv interpolation is still faster)
    // To 8 bit for contour detection (convertTo saturates, so negative logits become 0)
    mat_b.convertTo(mat_a, cv.CV_8UC1);
    // erode away stray pixels, then dilate (twice) to grow the surviving regions
    let M = cv.Mat.ones(5, 5, cv.CV_8UC1);
    let anchor = new cv.Point(-1, -1);
    cv.erode(mat_a, mat_b, M, anchor, 1, cv.BORDER_CONSTANT, cv.morphologyDefaultBorderValue());
    cv.dilate(mat_b, mat_a, M, anchor, 2, cv.BORDER_CONSTANT, cv.morphologyDefaultBorderValue());
    // find contours (external only)
    let contours = new cv.MatVector();
    let hierarchy = new cv.Mat();
    cv.findContours(mat_a, contours, hierarchy, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE);
    let tmp = new cv.Mat();
    let contourArray = [];
    for (let i = 0; i < contours.size(); i++) {
        let contour = contours.get(i);
        // simplify the polygon; tolerance is 0.2% of the contour's perimeter
        let epsilon = 0.002 * cv.arcLength(contour, true);
        cv.approxPolyDP(contour, tmp, epsilon, true);
        // flat [x0, y0, x1, y1, ...] list of vertex coordinates
        contourArray.push(Array.from(tmp.data32S));
        contour.delete();
    }
    // free memory again
    tmp.delete();
    mat_a.delete();
    mat_b.delete();
    M.delete();
    contours.delete();
    hierarchy.delete();

    return contourArray;
}
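The function above expects a single mask, e.g. `outputs.pred_masks[0][0][best_mask_index]`. One common way to choose `best_mask_index` is to take the candidate with the highest predicted IoU. A minimal sketch, assuming a single image and a single box so that `outputs.iou_scores.data` holds one score per candidate mask; `argmax` is a hypothetical helper.

```javascript
// Hypothetical helper: index of the largest value in an array-like of scores
// (e.g. a Float32Array of predicted IoUs, one per candidate mask).
function argmax(scores) {
    if (scores.length === 0) throw new Error('argmax of empty array');
    let best = 0;
    for (let i = 1; i < scores.length; i++) {
        if (scores[i] > scores[best]) best = i;
    }
    return best;
}

// Assumes one image and one box, so iou_scores.data has one score per candidate mask:
// const best_mask_index = argmax(outputs.iou_scores.data);
// const mask = outputs.pred_masks[0][0][best_mask_index];
```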

(Also, sorry for clogging up the PR with these semi-relevant comments.)

jwlarocque · Feb 27 '24 01:02

Revisiting this now. Thanks @jwlarocque for your testing!

"Is it possible to pass batches of bounding boxes in the input_boxes array? The nesting and the variable name suggest you could, and I think it works with points (although I haven't explored that much), but it crashes if I try to pass in more than one bbox."

Could you provide example code for which it crashes? Also, what error message do you get?

"post_process_masks accounts for the vast majority (~80%) of my script's runtime. Is this expected, or is there something I'm doing wrong or could do to speed it up? Here's a performance profile from Firefox. I'm running that segmentHold function a bunch of times in a loop because I have many instances to segment."

Indeed, this is a known limitation of the current approach (which does the post-processing in JavaScript), and it is not easily solvable without introducing a massive dependency like opencv.js. A better solution would be to implement this in WebGPU, but that will only be supported from v3 onward.
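For reference, the per-mask work that JavaScript post-processing has to do can be written out in plain JavaScript. It mirrors the opencv.js version above: upscale to the padded size, crop away the padding, resize to the original size, then binarize. The sketch below is a simplified, hypothetical `postProcessMaskJS`, not the library's implementation. Its bilinear sampling uses half-pixel centers and its threshold of 0 on the logits is an assumption, so results may differ slightly from the library's. It does make the cost visible: for SAM the intermediate padded mask alone is 1024 × 1024 values, all computed in a scalar loop.

```javascript
// Hypothetical sketch of single-mask post-processing in plain JavaScript.
// Bilinear resize of a row-major Float32Array, sampling at half-pixel centers.
function resizeBilinear(src, srcH, srcW, dstH, dstW) {
    const dst = new Float32Array(dstH * dstW);
    const sy = srcH / dstH;
    const sx = srcW / dstW;
    for (let y = 0; y < dstH; y++) {
        const fy = Math.min(Math.max((y + 0.5) * sy - 0.5, 0), srcH - 1);
        const y0 = Math.floor(fy);
        const y1 = Math.min(y0 + 1, srcH - 1);
        const wy = fy - y0;
        for (let x = 0; x < dstW; x++) {
            const fx = Math.min(Math.max((x + 0.5) * sx - 0.5, 0), srcW - 1);
            const x0 = Math.floor(fx);
            const x1 = Math.min(x0 + 1, srcW - 1);
            const wx = fx - x0;
            const top = src[y0 * srcW + x0] * (1 - wx) + src[y0 * srcW + x1] * wx;
            const bottom = src[y1 * srcW + x0] * (1 - wx) + src[y1 * srcW + x1] * wx;
            dst[y * dstW + x] = top * (1 - wy) + bottom * wy;
        }
    }
    return dst;
}

// Copies the top-left cropH x cropW region of a row-major array of width srcW.
function cropTopLeft(src, srcW, cropH, cropW) {
    const dst = new Float32Array(cropH * cropW);
    for (let y = 0; y < cropH; y++) {
        dst.set(src.subarray(y * srcW, y * srcW + cropW), y * cropW);
    }
    return dst;
}

// mask: row-major [maskH, maskW] logits; sizes are [height, width]; padSize is { height, width }.
// Returns a Uint8Array of 0/1 values with the original image's size.
function postProcessMaskJS(mask, maskH, maskW, originalSize, reshapedSize, padSize, threshold = 0) {
    const padded = resizeBilinear(mask, maskH, maskW, padSize.height, padSize.width);
    const cropped = cropTopLeft(padded, padSize.width, reshapedSize[0], reshapedSize[1]);
    const resized = resizeBilinear(cropped, reshapedSize[0], reshapedSize[1], originalSize[0], originalSize[1]);
    const binary = new Uint8Array(resized.length);
    for (let i = 0; i < resized.length; i++) binary[i] = resized[i] > threshold ? 1 : 0;
    return binary;
}
```

Each output pixel costs two bilinear passes plus a full-size intermediate buffer, which is consistent with post-processing taking a large share of the runtime when many masks are processed per image.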

xenova · Mar 20 '24 14:03