
How to check `is_last_batch` in torchtnt==0.1.0?

Open yiminglin-ai opened this issue 2 years ago • 2 comments

🐛 Describe the bug

In the train_step method of a Callback class:

def train_step(self, state: State, data: TrainBatch) -> None:
  global_step = state.train_state.progress.num_steps_completed
  is_last_batch = state.train_state.is_last_batch # error

Error:

AttributeError: 'PhaseState' object has no attribute 'is_last_batch'
train_state._step_output = train_unit.train_step(state, step_input)

Versions

Hi @daniellepintz, why was is_last_batch removed in #367? This won't pass the test defined in https://github.com/pytorch/tnt/blob/9b3b7b1a3c0cfa8354bd459fe84a46a03b2754f5/tests/framework/test_auto_unit.py#L901. What is the correct way to check is_last_batch? Thank you in advance!

The env:

torchtnt==0.1.0
torcheval==0.0.6
torchsnapshot==0.1.0

yiminglin-ai • Jun 14 '23 13:06

CC @ananthsub

yiminglin-ai • Jun 14 '23 13:06

Hi @yiminglin-ai, could you describe your use case in more detail?

For context, that field was added to support gradient accumulation in the AutoUnit extension. However, subsequent PRs made it possible to deduce this information entirely within the AutoUnit. Accordingly, to keep the state as minimal as possible, we removed the is_last_batch attribute from the state.

Knowing more information on how you'd like to use this data will help us out!

ananthsub • Jun 14 '23 13:06
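
For readers who still need a last-batch flag after the attribute was removed, one generic option is to derive it from the data iterator itself by looking one batch ahead. The sketch below is plain Python and does not use any torchtnt API; `with_last_flag` and `should_update_weights` are hypothetical helper names, and this is not necessarily how the AutoUnit deduces the flag internally.

```python
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")


def with_last_flag(batches: Iterable[T]) -> Iterator[Tuple[T, bool]]:
    """Yield (batch, is_last_batch) pairs by peeking one item ahead.

    Hypothetical helper, not part of torchtnt.
    """
    it = iter(batches)
    try:
        current = next(it)
    except StopIteration:
        return  # empty input: nothing to yield
    for upcoming in it:
        # Another batch exists, so `current` is not the last one.
        yield current, False
        current = upcoming
    # The iterator is exhausted, so `current` is the final batch.
    yield current, True


def should_update_weights(
    step_idx: int, is_last_batch: bool, accumulate_every: int
) -> bool:
    """Gradient-accumulation rule (an assumption for illustration):
    update every `accumulate_every` steps, and also on the final batch
    so that leftover gradients are not dropped."""
    return (step_idx + 1) % accumulate_every == 0 or is_last_batch


if __name__ == "__main__":
    for step_idx, (batch, is_last) in enumerate(with_last_flag(["b0", "b1", "b2"])):
        print(step_idx, batch, is_last, should_update_weights(step_idx, is_last, 2))
```

A wrapper like this could be applied to the dataloader before it is handed to the training loop, so each batch carries its own flag and nothing has to be read from the shared state. Whether that fits depends on the use case asked about above.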