
Can the gradient pass through the ODE solver?

Open dgbsg opened this issue 4 years ago • 0 comments

If the output of a neural network is used as a parameter of the differential equation, can the gradient still be backpropagated in this case? My setup is sketched below:

```python
import torch
import torch.nn as nn
from torchdiffeq import odeint

class ODENet(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(ODENet, self).__init__()
        self.hidden = torch.nn.Linear(n_input, n_hidden, bias=False)
        self.output = torch.nn.Linear(n_hidden, n_output, bias=False)

    def forward(self, x):
        x = self.hidden(x)
        x = torch.tanh(x)
        x = self.output(x)
        return x

NN_model = ODENet(n_input, n_hidden, n_output)

def CME_train(t, y):
    NN_out = NN_model(y.T)
    ...
    du = torch.mm(drift_matric, y)
    return du

for i in range(10):
    # odeint already returns a differentiable tensor; re-wrapping it in
    # torch.Tensor(...) would detach it from the autograd graph.
    CME_result = odeint(CME_train, z0.unsqueeze(-1), torch.linspace(0, 51, 51))
    optimizer.zero_grad()
    loss = loss_func(CME_result, true)
    loss.backward()
    optimizer.step()
```
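For reference, gradients do flow through torchdiffeq's `odeint`, provided the output is not detached along the way (for example by re-wrapping it in `torch.Tensor(...)`). Below is a minimal, self-contained sketch demonstrating this; the `ToyField` module, dimensions, and loss here are illustrative assumptions, not taken from the original post:

```python
import torch
import torch.nn as nn
from torchdiffeq import odeint

# Hypothetical toy setup: the network itself defines the vector field dy/dt.
class ToyField(nn.Module):
    def __init__(self, dim=2):
        super().__init__()
        self.net = nn.Linear(dim, dim, bias=False)

    def forward(self, t, y):
        # odeint calls the function as f(t, y).
        return torch.tanh(self.net(y))

field = ToyField()
y0 = torch.randn(2)
t = torch.linspace(0.0, 1.0, 10)

# odeint returns a tensor of shape (len(t), *y0.shape) that stays on the
# autograd graph, so a loss built from it can be backpropagated directly.
ys = odeint(field, y0, t)
loss = ys[-1].pow(2).sum()
loss.backward()

# Non-None gradients confirm that backprop reached the network's weights.
print(field.net.weight.grad)
```

For long integrations, torchdiffeq also provides `odeint_adjoint`, which computes the same gradients via the adjoint method with constant memory cost in the number of solver steps.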

dgbsg · Dec 02 '21 09:12