CPCGAN — how to use
您好,我在尝试使用GAN来做3D点云的生成,参考了您的论文,想请教您一个问题。我直接使用MLP来实现一个简单的Discriminator和Generator。来做无条件的3d点云生成,不过多次训练效果都非常不理想,不知道是否是这个思路就不对。非常期待您的指导。参考您的代码我修改后如下:
import torch
from torch.nn import Conv1d, Sequential, ConvTranspose1d,Linear, Dropout,LeakyReLU, Module, BatchNorm1d, AdaptiveMaxPool1d,AdaptiveAvgPool1d, Upsample, ConvTranspose1d,Tanh,Sigmoid, BCELoss, MaxPool1d, LayerNorm, ConvTranspose1d
from torch.optim import Adam
import matplotlib.pyplot as plt
from pytorch3d.datasets import ShapeNetCore,collate_batched_meshes
from pytorch3d.ops import sample_points_from_meshes
from torch.utils.data import DataLoader
import torch.nn.functional as F
from pytorch3d.structures import Meshes
# Global compute device; assumes a CUDA GPU at index 0 is available.
device = torch.device('cuda:0')
#Discriminator
class Discriminator(Module):
    """PointNet-style critic for a WGAN on point clouds.

    Maps a batch of point clouds of shape (B, C, N) to one unbounded
    score per cloud, shape (B, 1). Per-point features are lifted with
    shared 1x1 convolutions, max-pooled over the point axis (making the
    score permutation-invariant), then scored by an MLP head. No sigmoid
    on the output, as required by the WGAN objective.
    """

    def __init__(self, input_dim):
        """
        Args:
            input_dim: number of channels per point (3 for raw xyz).
        """
        super(Discriminator, self).__init__()
        # Shared per-point feature extractor: (B, input_dim, N) -> (B, 1024, N).
        # kernel_size=1 convolutions act as a per-point MLP shared across points.
        self.layer_1_0 = Conv1d(in_channels=input_dim, out_channels=64, kernel_size=1, stride=1)
        self.layer_1_1 = Conv1d(in_channels=64, out_channels=128, kernel_size=1, stride=1)
        self.layer_1_2 = Conv1d(in_channels=128, out_channels=256, kernel_size=1, stride=1)
        self.layer_1_3 = Conv1d(in_channels=256, out_channels=512, kernel_size=1, stride=1)
        self.layer_1_4 = Conv1d(in_channels=512, out_channels=1024, kernel_size=1, stride=1)
        self.layer_1_A = LeakyReLU(0.2, inplace=True)
        # Symmetric pooling over the point axis: (B, 1024, N) -> (B, 1024, 1).
        self.layer_2 = Sequential(
            AdaptiveMaxPool1d(1),
        )
        # Scoring head: (B, 1024) -> (B, 1).
        # FIX: the original stacked four Linear layers with no nonlinearity
        # between them, which collapses to a single linear map; interleave
        # LeakyReLU so the head has real representational capacity.
        self.layer_3 = Sequential(
            Linear(in_features=1024, out_features=1024),
            LeakyReLU(0.2, inplace=True),
            Linear(in_features=1024, out_features=512),
            LeakyReLU(0.2, inplace=True),
            Linear(in_features=512, out_features=512),
            LeakyReLU(0.2, inplace=True),
            Linear(in_features=512, out_features=1),
        )

    def forward(self, x):
        """Score a batch of point clouds.

        Args:
            x: tensor of shape (B, C, N).
        Returns:
            tensor of shape (B, 1) with unbounded critic scores.
        """
        out = self.layer_1_0(x)
        out = self.layer_1_A(out)
        out = self.layer_1_1(out)
        out = self.layer_1_A(out)
        out = self.layer_1_2(out)
        out = self.layer_1_A(out)
        out = self.layer_1_3(out)
        out = self.layer_1_A(out)
        out = self.layer_1_4(out)
        out = self.layer_1_A(out)
        out = self.layer_2(out)   # (B, 1024, 1)
        out = out.squeeze(-1)     # (B, 1024)
        out = self.layer_3(out)   # (B, 1)
        return out
# Generator
class Generator(Module):
    """MLP generator mapping latent seeds to a 2048-point cloud.

    Input is a latent tensor of shape (B, input_dim, input_k); each of the
    input_k seed vectors is expanded by a shared MLP into a group of points,
    and the groups are concatenated into an output of shape (B, 3, 2048).
    """

    def __init__(self, input_dim, input_k):
        """
        Args:
            input_dim: feature size of each latent seed vector.
            input_k:   number of latent seed vectors per sample.
        """
        super(Generator, self).__init__()
        # Shared per-seed MLP trunk (applied along the last dimension).
        self.layer_1_0 = Linear(in_features=input_dim, out_features=128)
        self.layer_1_1 = Linear(in_features=128, out_features=256)
        self.layer_1_2 = Linear(in_features=256, out_features=256)
        self.layer_1_3 = Linear(in_features=256, out_features=256)
        self.layer_1_4 = Linear(in_features=256, out_features=512)
        self.layer_1_5 = Linear(in_features=512, out_features=512)
        self.layer_1_A = LeakyReLU(0.2, inplace=True)
        # Each seed emits (2048 / input_k) points x 3 coordinates.
        out_f = int((2048 / input_k) * 3)
        self.layer_2 = Sequential(
            Linear(in_features=512, out_features=out_f),
        )

    def forward(self, x):
        """Generate a point cloud from latent codes.

        Args:
            x: latent tensor of shape (B, input_dim, input_k).
        Returns:
            point-cloud tensor of shape (B, 3, 2048).
        """
        batch = x.shape[0]
        # Move features to the last axis so Linear acts per seed:
        # (B, input_dim, input_k) -> (B, input_k, input_dim).
        h = x.transpose(1, 2)
        trunk = (self.layer_1_0, self.layer_1_1, self.layer_1_2,
                 self.layer_1_3, self.layer_1_4, self.layer_1_5)
        for fc in trunk:
            h = self.layer_1_A(fc(h))
        h = self.layer_2(h)                    # (B, input_k, out_f)
        h = h.view(batch, -1, 3)               # (B, 2048, 3)
        return h.transpose(1, 2)               # (B, 3, 2048)
#GP
def cal_gradient_penalty(D, xr, xf):
    """WGAN-GP gradient penalty term.

    Penalizes the deviation of the critic's gradient norm from 1 at points
    randomly interpolated between real and fake samples.

    Args:
        D:  critic callable mapping (B, C, N) -> per-sample scores.
        xr: real batch, shape (B, C, N).
        xf: fake batch, same shape as xr.
    Returns:
        Scalar tensor: mean over the batch of (||grad D(x_hat)||_2 - 1)^2.
    """
    b_size = xr.shape[0]
    # One interpolation coefficient per sample, broadcast over C and N.
    # FIX: derive device/dtype from the input instead of a global, and drop
    # the redundant .cuda()/requires_grad=True on alpha — only the
    # interpolated point itself needs to require gradients.
    alpha = torch.rand(b_size, 1, 1, device=xr.device, dtype=xr.dtype)
    interpolates = (xr + alpha * (xf - xr)).requires_grad_(True)
    disc_interpolates = D(interpolates)
    # create_graph=True so the penalty itself is differentiable w.r.t. D's
    # parameters when d_loss.backward() runs.
    grads = torch.autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    grads = grads.contiguous().view(b_size, -1)
    gp = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()
    return gp
# Hyperparameter settings.
# NOTE(review): real_label/fake_label are unused by the WGAN loss below —
# leftovers from an earlier BCE-based version.
real_label = 1.
fake_label = 0.
# Learning rates for critic (D) and generator (G).
lr_D = 1e-4
lr_G = 1e-4
# Adam first/second-moment hyperparameters.
beta1 = 0.
beta2 = 0.99
# Number of points sampled from each mesh surface.
sample_number = 2048
# Latent layout: generator_number seed vectors of generator_feature dims each.
generator_feature = 96
generator_number = 64
batch_size = 64
# Train on a single ShapeNet category; synset 04099429 = rocket.
shapenet_dataset = ShapeNetCore(data_dir="/mnt/ShapeNetCore.v2",version=2,load_textures=True,synsets=['04099429'])
dataloader = DataLoader(dataset=shapenet_dataset,batch_size=batch_size, collate_fn=collate_batched_meshes, num_workers=8, pin_memory=True, drop_last=True, shuffle=True)
netD=Discriminator(3).to(device)
netG=Generator(generator_feature,generator_number).to(device)
# Setup Adam optimizers for both G and D.
optimizerD = Adam(netD.parameters(), lr=lr_D, betas=(beta1, beta2))
optimizerG = Adam(netG.parameters(), lr=lr_G, betas=(beta1, beta2))
#训练
# Training loop: WGAN-GP — the critic (D) takes 2 update steps per batch,
# the generator one.
for epoch in range(200):
    netG.train()
    for data in dataloader:
        # Real batch: sample point clouds from the meshes.
        mesh = data['mesh']
        real_data = sample_points_from_meshes(mesh, sample_number).to(device)  # (B, N, C)
        real_data = real_data.transpose(2, 1).contiguous()                     # (B, C, N)
        # --- Critic updates ---
        for _ in range(2):
            # FIX: draw an independent latent code for every sample in the
            # batch. The original drew a single (C, N) tensor and expand()ed
            # it over the batch, so every fake cloud in a batch was identical,
            # which cripples GAN training.
            noise = torch.randn((batch_size, generator_feature, generator_number),
                                dtype=torch.float, device=device)  # (B, C, N)
            netD.zero_grad()
            # Detach so the critic step does not backprop through G.
            fake_data = netG(noise).detach()  # (B, C, N)
            output_real = netD(real_data)
            output_fake = netD(fake_data)
            # WGAN-GP critic loss: E[D(fake)] - E[D(real)] + lambda * GP,
            # with lambda = 10.
            gp = cal_gradient_penalty(netD, real_data, fake_data)
            d_loss = output_fake.mean() - output_real.mean() + gp * 10
            d_loss.backward()
            optimizerD.step()
        # --- Generator update ---
        netG.zero_grad()
        noise = torch.randn((batch_size, generator_feature, generator_number),
                            dtype=torch.float, device=device)
        fake_data = netG(noise)
        output_fake_real = netD(fake_data)
        # Generator maximizes the critic's score on fakes.
        g_loss = -(output_fake_real.mean())
        g_loss.backward()
        optimizerG.step()
    print(f"epoch = {epoch}, loss_d = {d_loss}, errorG = {g_loss}")
    # Checkpoint both networks after every epoch.
    torch.save(netD.state_dict(),"3DGAN-D-09.pt")
    torch.save(netG.state_dict(),"3DGAN-G-09.pt")