
[SHOW Visualization] Which part of the code to refer to

zelingao98 opened this issue 1 year ago · 8 comments

Dear author,

Thank you for this awesome work!

I ran the inference part of this repo on the SHOW dataset, and I only get a bunch of .npz files.

However, how do I visualize them with the visualization tool in TalkSHOW? I mean, which part of the code should I use to visualize the results?

Best regards

zelingao98 avatar Jun 04 '24 11:06 zelingao98

I tried to use the TalkSHOW code to visualize the data, but I get a bad result.

image

Do you know the reason? My code is as follows (from TalkSHOW/scripts/demo.py):


import logging
import os

import cv2
import numpy as np
import pyrender
import smplx
import torch
from smplx.utils import SMPLXOutput as SMPLOutput  # output dataclass, used below only as a type hint

# render_mesh_helper is assumed to be available from the TalkSHOW / VOCA rendering utilities.
logging.basicConfig(level=logging.INFO)
zelin_log = logging.getLogger(__name__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

lower_pose = torch.tensor(
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0747, -0.0158, -0.0152, -1.1826512813568115, 0.23866955935955048,
     0.15146760642528534, -1.2604516744613647, -0.3160211145877838,
     -0.1603458970785141, 1.1654603481292725, 0.0, 0.0, 1.2521806955337524, 0.041598282754421234, -0.06312154978513718,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
lower_pose_stand = torch.tensor([
    8.9759e-04, 7.1074e-04, -5.9163e-06, 8.9759e-04, 7.1074e-04, -5.9163e-06,
    3.0747, -0.0158, -0.0152,
    -3.6665e-01, -8.8455e-03, 1.6113e-01, -3.6665e-01, -8.8455e-03, 1.6113e-01,
    -3.9716e-01, -4.0229e-02, -1.2637e-01,
    7.9163e-01, 6.8519e-02, -1.5091e-01, 7.9163e-01, 6.8519e-02, -1.5091e-01,
    7.8632e-01, -4.3810e-02, 1.4375e-02,
    -1.0675e-01, 1.2635e-01, 1.6711e-02, -1.0675e-01, 1.2635e-01, 1.6711e-02, ])

def part2full(input, stand=False):
    if stand:
        lp = torch.zeros_like(lower_pose)
        lp[6:9] = torch.tensor([3.0747, -0.0158, -0.0152])
        lp = lp.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
    else:
        lp = lower_pose.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)

    input = torch.cat([input[:, :3],
                       lp[:, :15],
                       input[:, 3:6],
                       lp[:, 15:21],
                       input[:, 6:9],
                       lp[:, 21:27],
                       input[:, 9:12],
                       lp[:, 27:],
                       input[:, 12:]]
                      , dim=1)
    return input

def main():
    # * create the smplx model
    zelin_log.info('init smplx model...')
    dtype = torch.float64
    smplx_path = './visualise/'
    model_params = dict(model_path=smplx_path,
                        model_type='smplx',
                        create_global_orient=True,
                        create_body_pose=True,
                        create_betas=True,
                        num_betas=300,
                        create_left_hand_pose=True,
                        create_right_hand_pose=True,
                        use_pca=False,
                        flat_hand_mean=False,
                        create_expression=True,
                        num_expression_coeffs=100,
                        num_pca_comps=12,
                        create_jaw_pose=True,
                        create_leye_pose=True,
                        create_reye_pose=True,
                        create_transl=False,
                        dtype=dtype,
    )
    smplx_model = smplx.create(**model_params).to(device)
    # * load smplx param
    # this is DiffSHEG output
    pred_smplx = np.load('results/talkshow_88/test_custom_audio/talkshow_GesExpr_unify_addHubert_encodeHubert_mdlpIncludeX_condRes_LN_ClsFree/fixStart10/ckpt_e2599_ddim25_lastStepInterp/pid_1/Forrest_tts.npy')
    pred_smplx = torch.from_numpy(pred_smplx).float().to(device)[0][:100]
    pred_smplx = part2full(pred_smplx, stand=True)
    
    # * pred_smplx size: [n_frames, param_dim]
    import tqdm
    vertices = []
    betas = torch.zeros([1, 300], dtype=torch.float64).to(device)
    for frame_ind in tqdm.tqdm(range(pred_smplx.shape[0]), desc='infer mesh vertices per frame'):
        sample_output: SMPLOutput = smplx_model.forward(
            betas=betas,
            jaw_pose=pred_smplx[frame_ind][0:3].unsqueeze_(dim=0),
            leye_pose=pred_smplx[frame_ind][3:6].unsqueeze_(dim=0),
            reye_pose=pred_smplx[frame_ind][6:9].unsqueeze_(dim=0),
            global_orient=pred_smplx[frame_ind][9:12].unsqueeze_(dim=0),
            body_pose=pred_smplx[frame_ind][12:75].unsqueeze_(dim=0),
            left_hand_pose=pred_smplx[frame_ind][75:120].unsqueeze_(dim=0),
            right_hand_pose=pred_smplx[frame_ind][120:165].unsqueeze_(dim=0),
            expression=pred_smplx[frame_ind][165:265].unsqueeze_(dim=0),
            return_verts=True,
        )
        vertices.append(sample_output.vertices.detach().cpu().numpy().squeeze())
    vertices = np.asarray(vertices)

    print(vertices.shape)

    # * debug Render
    exp_dir = 'exp/speech2smplx'
    os.makedirs(exp_dir, exist_ok=True)
    num_frames = vertices.shape[0]

    # the dataset's Y/Z axes are flipped relative to the renderer, so invert them
    vertices = vertices.reshape(vertices.shape[0], -1, 3)
    vertices[:, :, 1] = -vertices[:, :, 1]
    vertices[:, :, 2] = -vertices[:, :, 2]

    width, height = 800, 1440
    viewport_height = 1440
    z_offset = 1.8

    video_fname = 'demo'
    os.makedirs(f'{exp_dir}/video_frames', exist_ok=True)

    writer = cv2.VideoWriter(f'{exp_dir}/{video_fname}.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 30, (width, height), True)
    center = np.mean(vertices[0], axis=0)

    render_helper = pyrender.OffscreenRenderer(viewport_width=800, viewport_height=viewport_height)

    class Struct(object):
        def __init__(self, **kwargs):
            for key, val in kwargs.items():
                setattr(self, key, val)

    path = os.path.join(os.getcwd(), 'visualise/smplx/SMPLX_NEUTRAL.npz')
    model_data = np.load(path, allow_pickle=True)
    data_struct = Struct(**model_data)

    for i_frame in tqdm.tqdm(range(num_frames), desc='render debug image'):
        frame_vertices = vertices[i_frame]
        # todo: save frame_vertices as npz
        imgi = render_mesh_helper((frame_vertices, data_struct.f), center, camera='o', r=render_helper, y=0.7, z_offset=z_offset)
        imgi = imgi.astype(np.uint8)
        # save image as frame
        cv2.imwrite(f'{exp_dir}/video_frames/{i_frame:04d}.png', imgi)
        # save image as video
        writer.write(imgi)
    writer.release()

if __name__ == '__main__':
    main()

zelingao98 avatar Jun 04 '24 16:06 zelingao98

Hi James, you may want to pay attention to the code here: https://github.com/JeremyCJM/DiffSHEG/blob/3ebf3058f48cba3da9146afb7623e9ec1ab9e9a5/datasets/show.py#L146. The order of the pose channels needs to be carefully aligned with the pose layout in the visualization code of TalkSHOW.

JeremyCJM avatar Jun 05 '24 08:06 JeremyCJM
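For reference, a minimal sketch of what that alignment means, assuming the generated npy keeps the same 232-dim layout as SHOW/TalkSHOW (jaw + upper body + hands + expression); the exact slicing used in datasets/show.py should be checked against this:

import numpy as np

pred = np.load('Forrest_tts.npy')   # hypothetical output file, shape [n_frames, 232]
jaw        = pred[:, 0:3]           # jaw pose (axis-angle)
upper_body = pred[:, 3:42]          # 13 upper-body joints x 3
hands      = pred[:, 42:132]        # left hand (45) + right hand (45)
expression = pred[:, 132:232]       # 100 expression coefficients
# TalkSHOW's part2full then interleaves the preset lower-body, eye and global-orientation
# channels into this to build the full 265-dim SMPL-X parameter vector.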

Owner wrote:

"Hi James, you may want to pay attention to the code here: https://github.com/JeremyCJM/DiffSHEG/blob/3ebf3058f48cba3da9146afb7623e9ec1ab9e9a5/datasets/show.py#L146. The order of the pose channels needs to be carefully aligned with the pose layout in the visualization code of TalkSHOW."

Thank you for the reply!

Yes, so what is the channel order of these output files? I read the npy files and found that they are [n_frame, 232], where 232 is exactly the same as the output of SHOW/TalkSHOW.

The order is important, since I need to feed them into this function to get the mesh:

pred_smplx = np.load('Forrest_tts.npy')
pred_smplx = torch.from_numpy(pred_smplx).float()  # convert to a tensor first; numpy arrays have no unsqueeze_()

sample_output: SMPLOutput = smplx_model.forward(
    betas=betas,
    jaw_pose=pred_smplx[0][0:3].unsqueeze_(dim=0),
    leye_pose=pred_smplx[0][3:6].unsqueeze_(dim=0),
    reye_pose=pred_smplx[0][6:9].unsqueeze_(dim=0),
    global_orient=pred_smplx[0][9:12].unsqueeze_(dim=0),
    body_pose=pred_smplx[0][12:75].unsqueeze_(dim=0),
    left_hand_pose=pred_smplx[0][75:120].unsqueeze_(dim=0),
    right_hand_pose=pred_smplx[0][120:165].unsqueeze_(dim=0),
    expression=pred_smplx[0][165:265].unsqueeze_(dim=0),
    return_verts=True,
)
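If the layout really matches SHOW/TalkSHOW, the dimensionality bookkeeping should work out as follows (a sanity check under that assumption, not a confirmed format):

import numpy as np
import torch

pred_smplx = torch.from_numpy(np.load('Forrest_tts.npy')).float()   # [n_frames, 232]
full = part2full(pred_smplx, stand=True)                            # TalkSHOW helper -> [n_frames, 265]
# 3 jaw + 3 leye + 3 reye + 3 global_orient + 63 body + 45 + 45 hands + 100 expression = 265
assert full.shape[1] == 265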

zelingao98 avatar Jun 06 '24 04:06 zelingao98


Hello, did you render the result correctly?

Mumuwei avatar Jun 06 '24 08:06 Mumuwei

Hi @jameskuma, this is my code to visualize the SHOW results, which is modified from the visualization code in TalkSHOW. Remember to specify the face_path and gesture_path arguments.

import os
import sys

# os.environ['CUDA_VISIBLE_DEVICES'] = '7'
sys.path.append(os.getcwd())

from transformers import Wav2Vec2Processor
from glob import glob

import numpy as np
import json
import smplx as smpl

from nets import *
from trainer.options import parse_args
from data_utils import torch_data
from trainer.config import load_JsonConfig

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis_angle
from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
from visualise.rendering import RenderTool

import time


def init_model(model_name, model_path, args, config):
    if model_name == 's2g_face':
        generator = s2g_face(
            args,
            config,
        )
    elif model_name == 's2g_body_vq':
        generator = s2g_body_vq(
            args,
            config,
        )
    elif model_name == 's2g_body_pixel':
        generator = s2g_body_pixel(
            args,
            config,
        )
    else:
        raise NotImplementedError

    model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
    if model_name == 'smplx_S2G':
        generator.generator.load_state_dict(model_ckpt['generator']['generator'])

    elif 'generator' in list(model_ckpt.keys()):
        generator.load_state_dict(model_ckpt['generator'])
    else:
        model_ckpt = {'generator': model_ckpt}
        generator.load_state_dict(model_ckpt)

    return generator


def init_dataloader(data_root, speakers, args, config):
    if data_root.endswith('.csv'):
        raise NotImplementedError
    else:
        data_class = torch_data
    if 'smplx' in config.Model.model_name or 's2g' in config.Model.model_name:
        data_base = torch_data(
            data_root=data_root,
            speakers=speakers,
            split='test',
            limbscaling=False,
            normalization=config.Data.pose.normalization,
            norm_method=config.Data.pose.norm_method,
            split_trans_zero=False,
            num_pre_frames=config.Data.pose.pre_pose_length,
            num_generate_length=config.Data.pose.generate_length,
            num_frames=30,
            aud_feat_win_size=config.Data.aud.aud_feat_win_size,
            aud_feat_dim=config.Data.aud.aud_feat_dim,
            feat_method=config.Data.aud.feat_method,
            smplx=True,
            audio_sr=22000,
            convert_to_6d=config.Data.pose.convert_to_6d,
            expression=config.Data.pose.expression,
            config=config
        )
    else:
        data_base = torch_data(
            data_root=data_root,
            speakers=speakers,
            split='val',
            limbscaling=False,
            normalization=config.Data.pose.normalization,
            norm_method=config.Data.pose.norm_method,
            split_trans_zero=False,
            num_pre_frames=config.Data.pose.pre_pose_length,
            aud_feat_win_size=config.Data.aud.aud_feat_win_size,
            aud_feat_dim=config.Data.aud.aud_feat_dim,
            feat_method=config.Data.aud.feat_method
        )
    if config.Data.pose.normalization:
        norm_stats_fn = os.path.join(os.path.dirname(args.model_path), "norm_stats.npy")
        norm_stats = np.load(norm_stats_fn, allow_pickle=True)
        data_base.data_mean = norm_stats[0]
        data_base.data_std = norm_stats[1]
    else:
        norm_stats = None

    data_base.get_dataset()
    infer_set = data_base.all_dataset
    infer_loader = data.DataLoader(data_base.all_dataset, batch_size=1, shuffle=False)

    return infer_set, infer_loader, norm_stats


def get_vertices(smplx_model, betas, result_list, exp, require_pose=False):
    vertices_list = []
    poses_list = []
    expression = torch.zeros([1, 50])

    for i in result_list:
        vertices = []
        poses = []
        for j in range(i.shape[0]):
            output = smplx_model(betas=betas,
                                 expression=i[j][165:265].unsqueeze_(dim=0) if exp else expression,
                                 jaw_pose=i[j][0:3].unsqueeze_(dim=0),
                                 leye_pose=i[j][3:6].unsqueeze_(dim=0),
                                 reye_pose=i[j][6:9].unsqueeze_(dim=0),
                                 global_orient=i[j][9:12].unsqueeze_(dim=0),
                                 body_pose=i[j][12:75].unsqueeze_(dim=0),
                                 left_hand_pose=i[j][75:120].unsqueeze_(dim=0),
                                 right_hand_pose=i[j][120:165].unsqueeze_(dim=0),
                                 return_verts=True)
            vertices.append(output.vertices.detach().cpu().numpy().squeeze())
            # pose = torch.cat([output.body_pose, output.left_hand_pose, output.right_hand_pose], dim=1)
            pose = output.body_pose
            poses.append(pose.detach().cpu())
        vertices = np.asarray(vertices)
        vertices_list.append(vertices)
        poses = torch.cat(poses, dim=0)
        poses_list.append(poses)
    if require_pose:
        return vertices_list, poses_list
    else:
        return vertices_list, None


global_orient = torch.tensor([3.0747, -0.0158, -0.0152])


def infer(data_root, g_body, g_face, g_body2, exp_name, infer_loader, infer_set, device, norm_stats, smplx,
          smplx_model, rendertool, args=None, config=None, face_path=None, gesture_path=None):
    am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
    am_sr = 16000
    num_sample = 1
    face = False
    if face:
        body_static = torch.zeros([1, 162], device='cuda')
        body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)
    stand = False
    j = 0
    gt_0 = None

    face_list = os.listdir(face_path)
    face_list.sort()

    gesture_list = os.listdir(gesture_path)
    gesture_list.sort()

    for idx, bat in enumerate(infer_loader):
        poses_ = bat['poses'].to(torch.float32).to(device)
        if poses_.shape[-1] == 300:
            # import pdb; pdb.set_trace()
            j = j + 1
            if j > 1000:
                continue
            id = bat['speaker'].to('cuda') - 20
            if config.Data.pose.expression:
                expression = bat['expression'].to(device).to(torch.float32)
                poses = torch.cat([poses_, expression], dim=1)
            else:
                poses = poses_
            cur_wav_file = bat['aud_file'][0]
            npy_file_name = 'visualise/video/' + config.Log.name + '/' + \
                        cur_wav_file.split('\\')[-1].split('.')[-2].split('/')[-1] + '.npy'

            if os.path.exists(npy_file_name):
                continue
            
            betas = bat['betas'][0].to(torch.float64).to('cuda')
            # betas = torch.zeros([1, 300], dtype=torch.float64).to('cuda')
            gt = poses.to('cuda').squeeze().transpose(1, 0)
            if config.Data.pose.normalization: # false
                gt = denormalize(gt, norm_stats[0], norm_stats[1]).squeeze(dim=0)
            if config.Data.pose.convert_to_6d: # false
                if config.Data.pose.expression:
                    gt_exp = gt[:, -100:]
                    gt = gt[:, :-100]

                gt = gt.reshape(gt.shape[0], -1, 6)

                gt = matrix_to_axis_angle(rotation_6d_to_matrix(gt)).reshape(gt.shape[0], -1)
                gt = torch.cat([gt, gt_exp], -1)
            if face: # false
                gt = torch.cat([gt[:, :3], body_static.repeat(gt.shape[0], 1), gt[:, -100:]], dim=-1)

            result_list = [gt]

            # cur_wav_file = '.\\training_data\\1_song_(Vocals).wav'

            ############################ Prediction ############################
            pred_face = np.load(os.path.join(face_path, face_list[idx]))

            pred_face = torch.tensor(pred_face).squeeze().to('cuda')
            pred_jaw = pred_face[:, :3]
            pred_face = pred_face[:, 3:]

            for i in range(num_sample):
                pred_res = np.load(os.path.join(gesture_path,gesture_list[idx]))
                pred = torch.tensor(pred_res).squeeze().to('cuda')

                if pred.shape[0] < pred_face.shape[0]:
                    repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
                    pred = torch.cat([pred, repeat_frame], dim=0)
                else:
                    pred = pred[:pred_face.shape[0], :]


                # pred = torch.cat([pred, pred_face], dim=-1)
                pred = torch.cat([pred_jaw, pred, pred_face], dim=-1)

                pred = part2full(pred, stand)


                result_list.append(pred)

            vertices_list, _ = get_vertices(smplx_model, betas, result_list, config.Data.pose.expression)

            result_list = [res.to('cpu') for res in result_list]
            dict = np.concatenate(result_list[1:], axis=0)
            file_name = 'visualise/video/' + config.Log.name + '/' + \
                        cur_wav_file.split('\\')[-1].split('.')[-2].split('/')[-1]
            np.save(file_name, dict)

            rendertool._render_sequences(cur_wav_file, vertices_list[1:], stand=stand, face=face)


def main():
    parser = parse_args()
    args = parser.parse_args()
    device = torch.device(args.gpu)
    torch.cuda.set_device(device)

    config = load_JsonConfig(args.config_file)

    face_model_name = args.face_model_name
    face_model_path = args.face_model_path
    body_model_name = args.body_model_name
    body_model_path = args.body_model_path
    smplx_path = './visualise/'

    os.environ['smplx_npz_path'] = config.smplx_npz_path
    os.environ['extra_joint_path'] = config.extra_joint_path
    os.environ['j14_regressor_path'] = config.j14_regressor_path

    print('init model...')
    generator = init_model(body_model_name, body_model_path, args, config)
    generator2 = None
    generator_face = init_model(face_model_name, face_model_path, args, config)
    print('init dataloader...')
    infer_set, infer_loader, norm_stats = init_dataloader(config.Data.data_root, args.speakers, args, config)

    print('init smplx model...')
    dtype = torch.float64
    model_params = dict(model_path=smplx_path,
                        model_type='smplx',
                        create_global_orient=True,
                        create_body_pose=True,
                        create_betas=True,
                        num_betas=300,
                        create_left_hand_pose=True,
                        create_right_hand_pose=True,
                        use_pca=False,
                        flat_hand_mean=False,
                        create_expression=True,
                        num_expression_coeffs=100,
                        num_pca_comps=12,
                        create_jaw_pose=True,
                        create_leye_pose=True,
                        create_reye_pose=True,
                        create_transl=False,
                        # gender='ne',
                        dtype=dtype, )
    smplx_model = smpl.create(**model_params).to('cuda')
    

    if args.rename != None:
        config.Log.name = args.rename
    print('init rendertool...')
    rendertool = RenderTool('visualise/video/' + config.Log.name)
    
    infer(config.Data.data_root, generator, generator_face, generator2, args.exp_name, infer_loader, infer_set, device,
          norm_stats, True, smplx_model, rendertool, args, config, face_path=args.face_path, gesture_path=args.gesture_path)


if __name__ == '__main__':
    main()

JeremyCJM avatar Jul 01 '24 06:07 JeremyCJM
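A note on the two arguments, since they are easy to get wrong: the script pairs the predictions purely by sorted file name, so face_path and gesture_path must each contain one .npy per test clip, in matching order. A minimal check, with hypothetical paths and the per-file shapes the code above implies (jaw + expression for the face, upper body + hands for the gesture):

import os
import numpy as np

face_path = './results/face'        # hypothetical directories
gesture_path = './results/gesture'

face_list = sorted(os.listdir(face_path))
gesture_list = sorted(os.listdir(gesture_path))
assert len(face_list) == len(gesture_list)

for f, g in zip(face_list, gesture_list):
    face = np.load(os.path.join(face_path, f))        # expected [T, 103]: jaw (3) + expression (100)
    gesture = np.load(os.path.join(gesture_path, g))  # expected [T', 129]: upper body (39) + hands (90)
    print(f, face.shape, '|', g, gesture.shape)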

Hello, I get similar results to @jameskuma

I tried to figure out whether there is a mismatch between the parameters in the DiffSHEG output and the SHOW SMPL-X model input, but everything seems okay. Has anyone found the right way to render the SHOW results?

@JeremyCJM I tried running your code, but I cannot figure out what face_path and gesture_path should be, since the DiffSHEG model only gives one npy output. Also, I am not quite sure why it creates a dataset and loader for the whole TalkSHOW dataset while inferring a single output. Can you help me use your code for a single inference from the .npy output DiffSHEG gives?

Any help in visualising would be appreciated!

https://github.com/user-attachments/assets/08c0cd0e-f1ed-43eb-94a5-1ed136528e32

Here is my code:

import os
import sys
os.environ["PYOPENGL_PLATFORM"] = "egl"
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
sys.path.append(os.getcwd())

from transformers import Wav2Vec2Processor
from glob import glob

import numpy as np
import json
import smplx as smpl

from nets import *
from trainer.options import parse_args
from data_utils import torch_data
from trainer.config import load_JsonConfig

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis_angle
from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
from visualise.rendering import RenderTool

global device
device = 'cpu'

def init_model(model_name, model_path, args, config):
    if model_name == 's2g_face':
        generator = s2g_face(
            args,
            config,
        )
    elif model_name == 's2g_body_vq':
        generator = s2g_body_vq(
            args,
            config,
        )
    elif model_name == 's2g_body_pixel':
        generator = s2g_body_pixel(
            args,
            config,
        )
    elif model_name == 's2g_LS3DCG':
        generator = LS3DCG(
            args,
            config,
        )
    else:
        raise NotImplementedError

    model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
    if model_name == 'smplx_S2G':
        generator.generator.load_state_dict(model_ckpt['generator']['generator'])

    elif 'generator' in list(model_ckpt.keys()):
        generator.load_state_dict(model_ckpt['generator'])
    else:
        model_ckpt = {'generator': model_ckpt}
        generator.load_state_dict(model_ckpt)

    return generator

def get_vertices(smplx_model, betas, result_list, exp, require_pose=False):
    vertices_list = []
    poses_list = []
    expression = torch.zeros([1, 50])

    for i in result_list:
        vertices = []
        poses = []
        for j in range(i.shape[0]):
            output = smplx_model(betas=betas,
                                 expression=i[j][165:265].unsqueeze_(dim=0),
                                 jaw_pose=i[j][0:3].unsqueeze_(dim=0),
                                 leye_pose=i[j][3:6].unsqueeze_(dim=0),
                                 reye_pose=i[j][6:9].unsqueeze_(dim=0),
                                 global_orient=i[j][9:12].unsqueeze_(dim=0),
                                 body_pose=i[j][12:75].unsqueeze_(dim=0),
                                 left_hand_pose=i[j][75:120].unsqueeze_(dim=0),
                                 right_hand_pose=i[j][120:165].unsqueeze_(dim=0),
                                 return_verts=True)
            vertices.append(output.vertices.detach().cpu().numpy().squeeze())
            # pose = torch.cat([output.body_pose, output.left_hand_pose, output.right_hand_pose], dim=1)
            pose = output.body_pose
            poses.append(pose.detach().cpu())
        vertices = np.asarray(vertices)
        vertices_list.append(vertices)
        poses = torch.cat(poses, dim=0)
        poses_list.append(poses)
    if require_pose:
        return vertices_list, poses_list
    else:
        return vertices_list, None


global_orient = torch.tensor([3.0747, -0.0158, -0.0152])


def infer(g_body, g_face, smplx_model, rendertool, config, args):
    betas = torch.zeros([1, 300], dtype=torch.float64).to(device)
    am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
    am_sr = 16000
    num_sample = args.num_sample
    cur_wav_file = args.audio_file
    id = args.id
    face = args.only_face
    stand = args.stand
    if face:
        body_static = torch.zeros([1, 162], device=device)
        body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)

    # result_list = []

    # pred_face = g_face.infer_on_audio(cur_wav_file,
    #                                   initial_pose=None,
    #                                   norm_stats=None,
    #                                   w_pre=False,
    #                                   # id=id,
    #                                   frame=None,
    #                                   am=am,
    #                                   am_sr=am_sr
    #                                   )
    # pred_face = torch.tensor(pred_face).squeeze().to(device)
    # # pred_face = torch.zeros([gt.shape[0], 105])

    # if config.Data.pose.convert_to_6d:
    #     pred_jaw = pred_face[:, :6].reshape(pred_face.shape[0], -1, 6)
    #     pred_jaw = matrix_to_axis_angle(rotation_6d_to_matrix(pred_jaw)).reshape(pred_face.shape[0], -1)
    #     pred_face = pred_face[:, 6:]
    # else:
    #     pred_jaw = pred_face[:, :3]
    #     pred_face = pred_face[:, 3:]

    # id = torch.tensor([id], device=device)

    # for i in range(num_sample):
    #     pred_res = g_body.infer_on_audio(cur_wav_file,
    #                                      initial_pose=None,
    #                                      norm_stats=None,
    #                                      txgfile=None,
    #                                      id=id,
    #                                      var=None,
    #                                      fps=30,
    #                                      w_pre=False
    #                                      )
    #     pred = torch.tensor(pred_res).squeeze().to(device)

    #     if pred.shape[0] < pred_face.shape[0]:
    #         repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
    #         pred = torch.cat([pred, repeat_frame], dim=0)
    #     else:
    #         pred = pred[:pred_face.shape[0], :]

    #     body_or_face = False
    #     if pred.shape[1] < 275:
    #         body_or_face = True
    #     if config.Data.pose.convert_to_6d:
    #         pred = pred.reshape(pred.shape[0], -1, 6)
    #         pred = matrix_to_axis_angle(rotation_6d_to_matrix(pred))
    #         pred = pred.reshape(pred.shape[0], -1)

    #     if config.Model.model_name == 's2g_LS3DCG':
    #         pred = torch.cat([pred[:, :3], pred[:, 103:], pred[:, 3:103]], dim=-1)
    #     else:
    #         pred = torch.cat([pred_jaw, pred, pred_face], dim=-1)

    #     # pred[:, 9:12] = global_orient
    #     pred = part2full(pred, stand)
    #     if face:
    #         pred = torch.cat([pred[:, :3], body_static.repeat(pred.shape[0], 1), pred[:, -100:]], dim=-1)
    #     # result_list[0] = poses2pred(result_list[0], stand)
    #     # if gt_0 is None:
    #     #     gt_0 = gt
    #     # pred = pred2poses(pred, gt_0)
    #     # result_list[0] = poses2poses(result_list[0], gt_0)

    #     result_list.append(pred)

    result_list = torch.from_numpy(np.load('../DiffSHEG/results/talkshow_88/test_custom_audio/talkshow_GesExpr_unify_addHubert_encodeHubert_mdlpIncludeX_condRes_LN_ClsFree/fixStart10/ckpt_e2599_ddim25_lastStepInterp/pid_4/gesture/Forrest_tts.npy'))
    result_list = part2full(result_list[0], stand=True).unsqueeze(0)
    print(result_list.shape)
    vertices_list, _ = get_vertices(smplx_model, betas, result_list, config.Data.pose.expression)

    result_list = [res.to('cpu') for res in result_list]
    dict = np.concatenate(result_list[:], axis=0)
    file_name = 'visualise/video/' + config.Log.name + '/' + \
                cur_wav_file.split('\\')[-1].split('.')[-2].split('/')[-1]
    np.save(file_name, dict)
    rendertool._render_sequences(cur_wav_file, vertices_list, stand=stand, face=face, whole_body=args.whole_body)


def main():
    parser = parse_args()
    args = parser.parse_args()
    # device = torch.device(args.gpu)
    # torch.cuda.set_device(device)


    config = load_JsonConfig(args.config_file)

    face_model_name = args.face_model_name
    face_model_path = args.face_model_path
    body_model_name = args.body_model_name
    body_model_path = args.body_model_path
    smplx_path = './visualise/'

    os.environ['smplx_npz_path'] = config.smplx_npz_path
    os.environ['extra_joint_path'] = config.extra_joint_path
    os.environ['j14_regressor_path'] = config.j14_regressor_path

    print('init model...')
    generator = init_model(body_model_name, body_model_path, args, config)
    generator2 = None
    generator_face = init_model(face_model_name, face_model_path, args, config)

    print('init smplx model...')
    dtype = torch.float64
    model_params = dict(model_path=smplx_path,
                        model_type='smplx',
                        create_global_orient=True,
                        create_body_pose=True,
                        create_betas=True,
                        num_betas=300,
                        create_left_hand_pose=True,
                        create_right_hand_pose=True,
                        use_pca=False,
                        flat_hand_mean=False,
                        create_expression=True,
                        num_expression_coeffs=100,
                        num_pca_comps=12,
                        create_jaw_pose=True,
                        create_leye_pose=True,
                        create_reye_pose=True,
                        create_transl=False,
                        # gender='ne',
                        dtype=dtype, )
    smplx_model = smpl.create(**model_params).to(device)
    print('init rendertool...')
    rendertool = RenderTool('visualise/video/' + config.Log.name)

    infer(generator, generator_face, smplx_model, rendertool, config, args)


if __name__ == '__main__':
    main()

TashvikDhamija avatar Sep 03 '24 11:09 TashvikDhamija

@TashvikDhamija @jameskuma Did you manage to render the TalkSHOW result correctly? Could you share the tricks? Thanks.

henryham avatar Mar 02 '25 13:03 henryham

In TalkSHOW/data_utils/lower_body.py:

def part2full(input, stand=False):
    if stand:
        # lp = lower_pose_stand.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
        lp = torch.zeros_like(lower_pose)
        lp[6:9] = torch.tensor([3.0747, -0.0158, -0.0152])
        lp = lp.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
    else:
        lp = lower_pose.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)

    input = torch.cat([input[:, :3],
                       lp[:, :15],
                       input[:, 3:6],
                       lp[:, 15:21],
                       input[:, 6:9],
                       lp[:, 21:27],
                       input[:, 9:12],
                       lp[:, 27:],
                       input[:, 12:]]
                      , dim=1)
    return input

The problem is here: the lp (lower-body) channels should come first.

import numpy as np

# gesture: the [T, 129] upper-body prediction (9 + 30 body channels + 90 hand channels);
# LOWER_PRESET / USE_LOWER_PRESET: an optional 24-dim lower-body preset, otherwise zeros.
T = gesture.shape[0]  # number of frames

up1 = gesture[:, 0:3]
up2 = gesture[:, 3:6]
up3 = gesture[:, 6:9]
up4 = gesture[:, 9:39]
hands = gesture[:, 39:129]    
left_hand  = hands[:, :45]
right_hand = hands[:, 45:90]

zeros6 = np.zeros((T,6), np.float32)
if USE_LOWER_PRESET:
    low1 = LOWER_PRESET[0:6][None].repeat(T, axis=0)
    low2 = LOWER_PRESET[6:12][None].repeat(T, axis=0)
    low3 = LOWER_PRESET[12:18][None].repeat(T, axis=0)
    low4 = LOWER_PRESET[18:24][None].repeat(T, axis=0)
else:
    low1 = low2 = low3 = low4 = zeros6

body_pose = np.concatenate([low1, up1, low2, up2, low3, up3, low4, up4], axis=1)  # [T,63]

this is my code, and it works correctly
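If you want to fill LOWER_PRESET from TalkSHOW's own constants instead of zeros, one option (my assumption: the 24 preset channels correspond to the hip/knee/ankle/foot entries of the lower_pose tensor quoted at the top of this thread, i.e. lower_pose[9:33]; the first nine entries hold the eye poses and global orientation and are handled separately):

import numpy as np

# lower_pose: the 33-value constant from TalkSHOW's lower_body.py (see the first post above)
LOWER_PRESET = lower_pose.numpy()[9:33].astype(np.float32)   # low1..low4, 6 values per joint pair
USE_LOWER_PRESET = True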

However, the problem I encountered was that the hands were bent at an unusual angle. Has anyone else encountered this?

Image

rookie2002 avatar Nov 03 '25 09:11 rookie2002