Repository: torchdiffeq

Can't replicate O(1) memory with the adjoint method

Status: Open — opened by petercmh01 about a year ago (0 comments)

Hi, I used the following code to test the memory savings of the adjoint method:

import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint
import time

class LargeODEFunc(nn.Module):
    """Autonomous ODE right-hand side ``f(t, y) = MLP(y)``.

    The MLP is ``Linear(state_dim -> hidden_dim)``, eight hidden
    ``Linear(hidden_dim -> hidden_dim)`` layers and a final
    ``Linear(hidden_dim -> output_dim)``, with a ReLU after every linear
    layer except the last one.

    Args:
        state_dim: Size of the last dimension of the input state ``y``.
        hidden_dim: Width of every hidden layer.
        output_dim: Size of the returned derivative. When used with
            ``odeint`` this must equal ``state_dim`` (dy/dt is shaped like y).
    """

    def __init__(self, state_dim, hidden_dim, output_dim):
        super().__init__()
        # Input width followed by nine hidden widths -> nine Linear+ReLU pairs.
        widths = [state_dim] + [hidden_dim] * 9
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # The projection back to the output size has no activation.
        modules.append(nn.Linear(hidden_dim, output_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, t, y):
        # Time-invariant dynamics: ``t`` is accepted only to match the
        # ``func(t, y)`` signature that odeint expects.
        return self.net(y)

state_dim = 1024
hidden_dim = 512
output_dim = 1024  # must equal state_dim: the ODE func returns dy/dt, shaped like y

# Number of independent trajectories integrated together. With ordinary
# backprop through the solver, stored activations grow roughly with
# batch_size x number of solver steps; that is the storage odeint_adjoint avoids.
# NOTE(review): with batch_size = 1 those activations are tiny next to the
# ~3.15M parameters of LargeODEFunc(1024, 512, 1024) (~12 MB in fp32).
# odeint_adjoint presumably integrates parameter-sized adjoint state during
# backward, so it can legitimately use *more* memory than plain odeint in this
# configuration. Increase batch_size (and/or the number of solver steps) to see
# the adjoint's advantage, and compare torch.cuda.max_memory_allocated() peaks
# rather than memory_allocated() after backward.
batch_size = 1

ode_func = LargeODEFunc(state_dim, hidden_dim, output_dim).cuda()

y0 = torch.randn(batch_size, state_dim).cuda()
# Output time grids over [0, 1]: 1000 points (dense) and 10 points (sparse).
t_dense = torch.linspace(0, 1, 1000).cuda()
t_sparse = torch.linspace(0, 1, 10).cuda()

optimizer = torch.optim.Adam(ode_func.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

def print_cuda_memory_usage():
    """Print currently allocated and reserved CUDA memory, in MB.

    These are *current* values, not high-water marks; use
    ``torch.cuda.max_memory_allocated()`` to see the peak of a step.
    """
    mib = 1024**2
    allocated_mb = torch.cuda.memory_allocated() / mib
    # memory_reserved() is what older PyTorch releases called "cached" memory.
    reserved_mb = torch.cuda.memory_reserved() / mib
    print(f"Memory Allocated: {allocated_mb:.2f} MB")
    print(f"Memory Cached: {reserved_mb:.2f} MB")

def train_step(all_time_steps, label):
    """Run one optimisation step through the ODE solver and report CUDA memory.

    ``odeint`` in this script is torchdiffeq's ``odeint_adjoint`` (aliased at
    import time).

    Args:
        all_time_steps: 1-D tensor of output time points passed to ``odeint``.
        label: Tag for this run (e.g. ``"dense"``), included in the
            peak-memory report.

    Returns:
        Peak CUDA memory allocated during the solve and backward pass, in MB.
    """
    optimizer.zero_grad()
    # Reset the high-water mark so max_memory_allocated() below reflects only
    # this step. memory_allocated() printed after backward() is NOT the peak,
    # because by then temporaries from the solve and backward have been freed.
    torch.cuda.reset_peak_memory_stats()

    print("\nAfter optimizer.zero_grad() and before odeint()")
    print_cuda_memory_usage()

    # odeint returns the solution at every requested time; keep only the last.
    y_ode = odeint(ode_func, y0, all_time_steps)[-1]

    # randn_like already allocates on y_ode's device, so no .cuda() is needed.
    target = torch.randn_like(y_ode)
    loss = loss_fn(y_ode, target)
    loss.backward()
    print("\nAfter loss_backward()")
    print_cuda_memory_usage()

    # Read the peak before optimizer.step(): Adam lazily allocates its state on
    # the first step, which would otherwise inflate the first measurement.
    peak_mem = torch.cuda.max_memory_allocated() / 1024**2
    print(f"[{label}] Peak Memory = {peak_mem:.2f}MB, Decoded Steps = {len(all_time_steps)}")

    optimizer.step()
    torch.cuda.synchronize()
    optimizer.zero_grad()
    return peak_mem

# Run a handful of optimisation steps on the dense (1000-point) time grid,
# printing CUDA memory usage around each solve / backward pass.
print("Training with dense time steps:")
for _epoch in range(5):
    train_step(t_dense, "dense")

For comparison, I seem to get lower memory consumption with the regular odeint than with odeint_adjoint. Can anyone see a problem with my setup?

— petercmh01, Dec 23 '24, 14:12