PUDM icon indicating copy to clipboard operation
PUDM copied to clipboard

[email protected]

Open kizzyk opened this issue 1 year ago • 4 comments

pointnet2/util.py 中报错 B, N, D = x0.shape diffusion_steps = torch.randint(T, size=(B, 1, 1)).cuda() # t ~ U[T] z = std_normal(x0.shape) # xt = sqrt(at_) * X0 + sqrt(1-at_) * z ==> q(xt|x0) xt = torch.sqrt(Alpha_bar[diffusion_steps]) * x0 + torch.sqrt(1 - Alpha_bar[diffusion_steps]) * z i = midpoint_interpolate(condition.permute(0, 2, 1)).permute(0, 2, 1) xt = torch.cat([xt, i], dim=-1)

第520行 xt = torch.cat([xt, i], dim=-1) 这里xt维度是 [B,N,D] (是由x0变换而来,而B, N, D = x0.shape显示x0维度是[B,N,D]),i维度是[B,rN,D] , 根据论文可知,目的是拼接成[B,N+rN,D]的张量,但是这里操作的维度是dim=-1,报错了,不知道是我设置的数据格式,还是代码的错误呢,修改成dim=1,就不会报错了。

第二个问题,参数 in_fea_dim "pointnet_config": { "model_name": "pointnet", "in_fea_dim": 3,

会报错,改成"in_fea_dim": 0 之后就不会报错了,有可能是我的数据格式问题,能否提供一下您输入到此函数时候x0,label,condition的维度呢 def training_loss( net, loss_fn, x0, diffusion_hyperparams, label=None, condition=None, alpha=1.0, gamma=None ): 我的x0和condition都是[B,2048,3]的 希望收到您的答复!感谢!

kizzyk avatar Oct 11 '24 03:10 kizzyk

你好,代码本身是没有错误的,因为很多人都可以跑通的(参考其他close的问题)。需要改维度,说明输入数据的维度有问题。 我认为你应该是在训练期间碰到错误的是吗,因为涉及到了x0,输入的condition维度应该是[B,N,3],而x0(Ground Truth)是[B,rN,3],label的维度是[3,](参考pointnet2/dataloader/dataset_loader.py)。 x0意味着ground truth,也就是上采样之后的稠密点云,维度应该是[B,rN,3];xt意味着通过diffusion公式加噪x0而来的(xt~q(xt|x0)),因此维度是[B,rN,3];condition意味着需要上采样的点云,也就是稀疏点云,维度应该是[B,N,3];label意味着稠密点云x0与稀疏点云condition之间的倍率关系,维度[r-1,](下标从0开始,比如差距4倍,那么输入是[3,]);i意味着插值点云,从condition通过中点插值而来,维度应该是[B,rN,3]。

同时,xt = torch.cat([xt, i], dim=-1),这一行代码意味着在最后一个维度进行拼接,比如[B,rN,3]与[B,rN,3]在最后一个通道拼接,那么应该是[B,rN,3+3]

QWTforGithub avatar Oct 11 '24 05:10 QWTforGithub

请注意,在训练时,你的数据输入,只需要两个:condition [B,N,3]和x0 [B,rN,3]。label会自动生成(参考pointnet2/dataloader/dataset_loader.py)。

QWTforGithub avatar Oct 11 '24 05:10 QWTforGithub

请注意,在训练时,你的数据输入,只需要两个:condition [B,N,3]和x0 [B,rN,3]。label会自动生成(参考pointnet2/dataloader/dataset_loader.py)。

很高兴收到你的答复,看了你说了,我现在明白了,是我之前的输入有问题,感谢解惑!

kizzyk avatar Oct 11 '24 05:10 kizzyk

请注意,在训练时,你的数据输入,只需要两个:condition [B,N,3]和x0 [B,rN,3]。label会自动生成(参考pointnet2/dataloader/dataset_loader.py)。

很高兴收到你的答复,看了你说了,我现在明白了,是我之前的输入有问题,感谢解惑!

不客气。

QWTforGithub avatar Oct 11 '24 06:10 QWTforGithub