Separate grad scaler test out from test_app_state_mixin
Summary:
Currently, test_app_state_mixin fails on OSS CI with the following error:
self = <framework.test_auto_unit.TestAutoUnit testMethod=test_app_state_mixin>

    def test_app_state_mixin(self) -> None:
        """
        Test that app_state, tracked_optimizers, tracked_lr_schedulers are set as expected with AutoUnit
        """
        my_module = torch.nn.Linear(2, 2)
        auto_unit = DummyAutoUnit(
            module=my_module,
            precision="fp16",
        )
        self.assertEqual(auto_unit.tracked_modules()["module"], my_module)
        self.assertTrue(
            isinstance(
                auto_unit.tracked_misc_statefuls()["grad_scaler"],
                torch.cuda.amp.GradScaler,
            )
        )
        for key in ("module", "optimizer", "lr_scheduler", "grad_scaler"):
>           self.assertTrue(key in auto_unit.app_state())
E           AssertionError: False is not true

tests/framework/test_auto_unit.py:69: AssertionError
https://github.com/pytorch/tnt/actions/runs/5321328919/jobs/9636295932
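The change described in the title, moving the grad scaler check out of test_app_state_mixin into its own test, can be illustrated with a small stdlib-only sketch. It does not use torch or torchtnt: FakeAutoUnit and FakeGradScaler are hypothetical stand-ins for DummyAutoUnit and torch.cuda.amp.GradScaler, and the test name test_app_state_with_grad_scaler is likewise made up for illustration. The point is only the structure: the core keys are asserted without mixed precision, and the grad scaler gets an isolated test, so a failure in one no longer masks the other.

```python
import unittest


class FakeGradScaler:
    """Hypothetical stand-in for torch.cuda.amp.GradScaler (not the real class)."""


class FakeAutoUnit:
    """Hypothetical stand-in for DummyAutoUnit; this is NOT the torchtnt API."""

    def __init__(self, precision=None):
        self._modules = {"module": object()}
        self._optimizers = {"optimizer": object()}
        self._lr_schedulers = {"lr_scheduler": object()}
        # Assumption for illustration: a grad scaler is only tracked for fp16.
        self._misc = {"grad_scaler": FakeGradScaler()} if precision == "fp16" else {}

    def app_state(self):
        # Merge every tracked stateful into a single dict, keyed by attribute name.
        return {**self._modules, **self._optimizers, **self._lr_schedulers, **self._misc}


class TestAppStateMixin(unittest.TestCase):
    def test_app_state_mixin(self):
        # Core keys only: this test no longer depends on mixed precision.
        unit = FakeAutoUnit()
        for key in ("module", "optimizer", "lr_scheduler"):
            self.assertIn(key, unit.app_state())

    def test_app_state_with_grad_scaler(self):
        # The grad scaler is checked in its own test, so a failure here is isolated.
        unit = FakeAutoUnit(precision="fp16")
        self.assertIsInstance(unit.app_state()["grad_scaler"], FakeGradScaler)
```

Run it with `python -m unittest <module>`. If the real scaler test turned out to need a GPU, one option would be to decorate only that test with `unittest.skipUnless(...)`; whether that applies here is an assumption, not something this PR states. Note also that `assertIn` reports the missing key in its failure message, unlike `assertTrue(key in ...)`, which only prints "False is not true".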
Differential Revision: D46870935
This pull request was exported from Phabricator.