facer
Package: facer (installed from PyPI as `pyfacer`)
Best Solution for Face Parsing (Segmentation)
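Demo
- Input: a face image (image not preserved here)
- Output: the colour-coded face-parsing mask (image not preserved here)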
Install
pip install torch pyfacer==0.0.4
Code
import torch
import facer
import numpy as np
from PIL import Image
def save_result(mask, output_path):
    """Save a face-parsing label mask as a palettized (mode "P") image.

    Every pixel keeps its raw class index, and a fixed RGB palette (one colour
    per "farl/celebm" label, listed below in class-index order) is attached
    so the file can be viewed directly.

    Args:
        mask: Integer label tensor of shape ``(1, h, w)`` or ``(h, w)``, e.g.
            ``logits.argmax(dim=1)`` for a single image. It may live on any
            device; it is copied to the CPU before saving.
        output_path: Destination path, e.g. a ``.png`` file. Pillow picks the
            format from the extension.

    Raises:
        ValueError: If ``mask`` is not a single 2-D label map once the leading
            batch dimension is dropped (e.g. a batch of several images).
    """
    palette = np.array(
        [
            [0, 0, 0],  # "background"
            [255, 153, 51],  # "neck"
            [204, 0, 0],  # "face"
            [0, 204, 0],  # "cloth"
            [102, 51, 0],  # "rr"
            [255, 0, 0],  # "lr"
            [0, 255, 255],  # "rb"
            [255, 204, 204],  # "lb"
            [51, 51, 255],  # "re"
            [204, 0, 204],  # "le"
            [76, 153, 0],  # "nose"
            [102, 204, 0],  # "imouth"
            [0, 0, 153],  # "llip"
            [255, 255, 0],  # "ulip"
            [0, 0, 204],  # "hair"
            [204, 204, 0],  # "eyeg"
            [255, 51, 153],  # "hat"
            [0, 204, 204],  # "earr"
            [0, 51, 0],  # "neck_l"
        ],
        dtype=np.uint8,
    )
    labels = mask.squeeze(0)  # drop the batch dimension of a (1, h, w) mask
    if labels.dim() != 2:
        # Fail with a clear message instead of an opaque Pillow error.
        raise ValueError(
            f"expected a single (h, w) label map, got shape {tuple(mask.shape)}"
        )
    # The 19 class indices fit in uint8, which is what a palette image stores.
    label_array = labels.cpu().byte().numpy()
    # A 2-D uint8 array becomes an "L" image; putpalette() converts it to "P".
    # This avoids Image.fromarray's `mode` argument, which recent Pillow
    # releases deprecate.
    image = Image.fromarray(label_array)
    image.putpalette(palette.flatten().tolist())
    image.save(output_path)
def inference(input_path, output_path):
    """Run FaRL "celebm" face parsing on one image and save the coloured mask.

    The whole image is passed to the parser network without face detection
    or warping, so it should already be a face-aligned crop. It is resized to
    the model's 448x448 working resolution. The predicted logits are resized
    back to the original size, so the saved mask lines up pixel-for-pixel
    with the input.

    Args:
        input_path: Path of the input image.
        output_path: Path of the output palettized mask (see ``save_result``).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    image = facer.hwc2bchw(facer.read_hwc(input_path)).to(device)  # 1 x 3 x h x w
    height, width = image.shape[-2:]
    model = facer.face_parser(
        "farl/celebm/448",
        device,
        model_path="https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.celebm.main_ema_181500_jit.pt",
    )
    with torch.inference_mode():
        pixels = image.float() / 255.0  # scale to [0, 1]
        # The "448" model works at 448x448. Resizing an input that is already
        # 448x448 with bilinear / align_corners=False leaves it unchanged.
        pixels = torch.nn.functional.interpolate(
            pixels, size=(448, 448), mode="bilinear", align_corners=False
        )
        # Calls the underlying network directly; pyfacer>=0.0.5 offers
        # `forward_warped` for the same purpose.
        logits, _ = model.net(pixels)
        # Resize the per-class scores (not the hard labels) back to the input
        # resolution, then pick the best class per pixel.
        logits = torch.nn.functional.interpolate(
            logits, size=(height, width), mode="bilinear", align_corners=False
        )
        mask = logits.argmax(dim=1)  # 1 x h x w
    save_result(mask, output_path)
if __name__ == "__main__":
    import argparse

    # Paths are optional CLI arguments; with none given, this behaves like the
    # original hard-coded call inference("input.jpg", "output.png").
    parser = argparse.ArgumentParser(
        description="Face parsing with facer (FaRL celebm model)."
    )
    parser.add_argument(
        "input", nargs="?", default="input.jpg",
        help="input image path (default: input.jpg)",
    )
    parser.add_argument(
        "output", nargs="?", default="output.png",
        help="output mask path (default: output.png)",
    )
    args = parser.parse_args()
    inference(args.input, args.output)
Thank you for the code snippet; this behavior is now better supported in pyfacer==0.0.5:
import torch.nn.functional as F

# pyfacer>=0.0.5 example for images that are already face-aligned (for example
# FFHQ or aligned CelebA): no detection or warping is needed, and the whole
# image is simply resized to the parser's 448x448 input.
# `device` and `face_parser` are defined here so the snippet runs on its own;
# the model and weights URL are the same as in the first example.
device = "cuda" if torch.cuda.is_available() else "cpu"
face_parser = facer.face_parser(
    "farl/celebm/448",
    device,
    model_path="https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.celebm.main_ema_181500_jit.pt",
)

images = facer.hwc2bchw(facer.read_hwc('data/ffhq_15723.jpg')).to(device=device)  # image: 1 x 3 x h x w
facer.show_bchw(images)

# do some preprocessing
images = images.to(dtype=torch.float32) / 255.0  # [0, 1]
images = F.interpolate(images, size=(448, 448), mode='bilinear', align_corners=False)  # as it's already face aligned, directly resize to 448x448

seg_logits, seg_preds, label_names = face_parser.forward_warped(images)
print(label_names)
# Show the predicted labels as a grayscale map with facer's own display helper
# (matplotlib's `plt` was never imported here). Class ids are scaled into
# 0..255 so the classes are distinguishable.
facer.show_bhw(seg_preds.float() / len(label_names) * 255)
Check out the sample for more details.