How to check `is_last_batch` in torchtnt==0.1.0?
🐛 Describe the bug
In the `train_step` method of a `Callback` class:

```python
def train_step(self, state: State, data: TrainBatch) -> None:
    global_step = state.train_state.progress.num_steps_completed
    is_last_batch = state.train_state.is_last_batch  # error
```
Error:

```
AttributeError: 'PhaseState' object has no attribute 'is_last_batch'
```

raised when the train loop invokes the unit's `train_step`:

```
train_state._step_output = train_unit.train_step(state, step_input)
```
Hi @daniellepintz,
Why was `is_last_batch` removed in #367?
This won't pass the test defined in https://github.com/pytorch/tnt/blob/9b3b7b1a3c0cfa8354bd459fe84a46a03b2754f5/tests/framework/test_auto_unit.py#L901
What is the correct way to check `is_last_batch`? Thank you in advance!

Versions:

```
torchtnt==0.1.0
torcheval==0.0.6
torchsnapshot==0.1.0
```
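
For reference, a rough sketch of one possible workaround in torchtnt==0.1.0: track the batch index in the unit itself, assuming the number of batches per epoch is known up front (e.g. `len(train_dataloader)`). `MyTrainUnit` and `num_batches_per_epoch` are names invented for this sketch, not part of the torchtnt API; import paths are as of torchtnt==0.1.0.

```python
import torch
from torchtnt.framework.state import State
from torchtnt.framework.unit import TrainUnit


class MyTrainUnit(TrainUnit[torch.Tensor]):
    """Unit that tracks is_last_batch itself instead of reading it from state."""

    def __init__(self, module: torch.nn.Module, num_batches_per_epoch: int) -> None:
        super().__init__()
        self.module = module
        # caller-supplied, e.g. len(train_dataloader); not provided by torchtnt
        self.num_batches_per_epoch = num_batches_per_epoch
        self._batch_idx = 0  # position within the current epoch

    def train_step(self, state: State, data: torch.Tensor) -> None:
        is_last_batch = self._batch_idx == self.num_batches_per_epoch - 1
        if is_last_batch:
            # e.g. flush metrics, step an epoch-level scheduler, extra logging
            pass
        self._batch_idx += 1

    def on_train_epoch_end(self, state: State) -> None:
        self._batch_idx = 0  # reset for the next epoch
```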
CC @ananthsub
Hi @yiminglin-ai, could you describe your use case in more detail?
For context, that field was added to support gradient accumulation in the AutoUnit extension. However, subsequent PRs made it possible to deduce this information entirely within the AutoUnit, so to keep the framework state as minimal as possible, we removed the `is_last_batch` attribute from `PhaseState`.
Knowing more about how you'd like to use this data will help us out!
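
To illustrate the deduction idea (a sketch only, not the actual AutoUnit code): an iterator wrapper can prefetch one batch ahead, so the current batch is known to be the last when the look-ahead is exhausted. The helper name `with_is_last` is made up for this example.

```python
from typing import Iterator, Tuple, TypeVar

T = TypeVar("T")


def with_is_last(data_iter: Iterator[T]) -> Iterator[Tuple[T, bool]]:
    """Yield (batch, is_last_batch) by prefetching one batch ahead."""
    try:
        current = next(data_iter)
    except StopIteration:
        return
    for nxt in data_iter:
        yield current, False
        current = nxt
    # the look-ahead is exhausted, so `current` is the final batch
    yield current, True


# usage: wrap any dataloader/iterator
for batch, is_last_batch in with_is_last(iter(range(3))):
    print(batch, is_last_batch)  # 0 False / 1 False / 2 True
```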