segment-anything
How to store an image embedding and then reuse it in the SamPredictor class?
I would like to compute the image embedding on one machine and then use it on another machine to run predictions. I have been trying to save it like this:
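```python
import pickle

from segment_anything import SamPredictor

# sam is an already loaded SAM model, img is the image as an RGB array
predictor = SamPredictor(sam)
predictor.set_image(img)
embedding = predictor.get_image_embedding().detach().cpu().numpy()
with open(pickle_file_name, "wb") as f:
    pickle.dump(embedding, f)
```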
Then I load the pickle file in another script, but I don't know how to put the embedding back into the predictor without computing it again with `set_image(img)`.
I think you can do it in a simple way like this :)
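```python
import torch


class SamPredictor:
    # ... existing SamPredictor methods (set_image, predict, etc.) ...

    def save_image_embedding(self, path):
        # Only valid after set_image() has computed the embedding.
        if not self.is_image_set:
            raise RuntimeError("An image must be set with .set_image(...) before embedding saving.")
        res = {
            'original_size': self.original_size,  # (H, W) of the original image
            'input_size': self.input_size,  # (H, W) after resizing for the model
            'features': self.features,  # the image embedding itself
            'is_image_set': True,
        }
        torch.save(res, path)

    def load_image_embedding(self, path):
        # map_location moves the tensors onto this predictor's device (CPU or GPU).
        res = torch.load(path, self.device)
        # Restore each saved attribute so predict() can run without set_image().
        for k, v in res.items():
            setattr(self, k, v)
```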
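If you would rather not edit the installed segment-anything package, the same idea can be sketched as two free functions that operate on an existing predictor instance. This is a minimal sketch, assuming SamPredictor keeps its per-image state in the `features`, `original_size`, `input_size` and `is_image_set` attributes (the methods above rely on the same assumption); `save_embedding_state` and `load_embedding_state` are hypothetical helper names. It keeps the pickle + NumPy format from the question, but also stores the two sizes that `predict()` needs to map prompts and masks between the original image and the model input, which the embedding alone does not carry.

```python
import pickle

import torch


def save_embedding_state(predictor, path):
    """Pickle the per-image state that set_image() computes (hypothetical helper)."""
    if not predictor.is_image_set:
        raise RuntimeError("Call .set_image(...) before saving the embedding.")
    state = {
        # The image embedding, moved to CPU so a machine without a GPU can load it.
        "features": predictor.features.detach().cpu().numpy(),
        # (H, W) of the original image and of the resized model input.
        "original_size": tuple(predictor.original_size),
        "input_size": tuple(predictor.input_size),
    }
    with open(path, "wb") as f:
        pickle.dump(state, f)


def load_embedding_state(predictor, path):
    """Restore the pickled state so predict() can run without set_image() (hypothetical helper)."""
    with open(path, "rb") as f:
        state = pickle.load(f)
    # Convert back to a tensor on the predictor's device.
    predictor.features = torch.from_numpy(state["features"]).to(predictor.device)
    predictor.original_size = state["original_size"]
    predictor.input_size = state["input_size"]
    predictor.is_image_set = True
```

On the second machine you would build the predictor as usual with `SamPredictor(sam)`, call `load_embedding_state(predictor, path)` in place of `predictor.set_image(img)`, and then call `predictor.predict(...)`. The model on both machines should come from the same checkpoint, because the embedding is produced by that checkpoint's image encoder.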
You are my God!