Some Mappings missed when checking for Dict type only
Describe the bug
In masked-language-modeling experiments, the inputs to the model (the outputs of the tokenizer) are instances of the BatchEncoding class (HF source code), which is a subclass of UserDict.
Unfortunately, a UserDict is not an instance of the Dict class, so every isinstance(tensors, Dict) check misses these inputs entirely.
```python
from collections import UserDict
from typing import Dict
assert not isinstance(UserDict(), Dict)
```
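A minimal sketch of the fix implied by the title, assuming the affected checks can simply test against collections.abc.Mapping instead of typing.Dict. BatchEncodingStub and apply_recursively are hypothetical names used only for illustration; they are not part of this project or of transformers.

```python
from collections import UserDict
from collections.abc import Mapping
from typing import Any, Callable, Dict


class BatchEncodingStub(UserDict):
    """Hypothetical stand-in for HF's BatchEncoding (a UserDict subclass)."""


def apply_recursively(data: Any, fn: Callable[[Any], Any]) -> Any:
    """Hypothetical helper: apply fn to every leaf, recursing into containers.

    Checking against collections.abc.Mapping (instead of typing.Dict / dict)
    also matches UserDict subclasses such as BatchEncoding.
    """
    if isinstance(data, Mapping):
        # Note: the result is a plain dict, not the original mapping type.
        return {key: apply_recursively(value, fn) for key, value in data.items()}
    if isinstance(data, list):
        return [apply_recursively(value, fn) for value in data]
    if isinstance(data, tuple):
        return tuple(apply_recursively(value, fn) for value in data)
    return fn(data)


batch = BatchEncodingStub({"input_ids": [1, 2, 3], "attention_mask": [1, 1, 0]})
print(isinstance(batch, Dict))     # False: the Dict check misses it
print(isinstance(batch, Mapping))  # True: the Mapping check catches it
print(apply_recursively(batch, lambda x: x * 10))
# {'input_ids': [10, 20, 30], 'attention_mask': [10, 10, 0]}
```

The helper rebuilds mappings as plain dicts; code that must keep the BatchEncoding type afterwards would need to reconstruct it, for example with type(data)(...).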
To Reproduce
- Run a masked-language-modeling experiment with KD and PyTorch 1.11. All isinstance(tensors, Dict) checks will be missed, resulting in many exceptions.
- Run a masked-language-modeling experiment with KD and any PyTorch version released before 29 Nov. 2021, with the pin_memory feature disabled.
Additional Context
This bug did not appear in current experiments with PyTorch versions before 29 Nov. 2021 because all runs use pin_memory = True by default. In those versions, pin_memory = True had a small issue: input data of any Mapping type was always converted to a Dict instead of preserving the original type(data). Newer PyTorch versions fix this, so pin_memory = True now returns the original type, which in our MLM case is BatchEncoding (a UserDict subclass), and all isinstance(tensors, Dict) checks miss it.
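The pin_memory behavior change described above can be illustrated with a simplified sketch. It is not PyTorch's actual implementation: pin, pin_memory_old, pin_memory_new and BatchEncodingStub are hypothetical names, and pinning is replaced by an identity function so the example runs without PyTorch or a GPU.

```python
from collections import UserDict
from collections.abc import Mapping
from typing import Any


class BatchEncodingStub(UserDict):
    """Hypothetical stand-in for HF's BatchEncoding (a UserDict subclass)."""


def pin(leaf: Any) -> Any:
    """Hypothetical placeholder for pinning a single tensor; returns it unchanged."""
    return leaf


def pin_memory_old(data: Any) -> Any:
    # Older behavior: every Mapping is rebuilt as a plain dict.
    if isinstance(data, Mapping):
        return {key: pin_memory_old(value) for key, value in data.items()}
    return pin(data)


def pin_memory_new(data: Any) -> Any:
    # Newer behavior: try to rebuild the mapping with its original type.
    if isinstance(data, Mapping):
        pinned = {key: pin_memory_new(value) for key, value in data.items()}
        try:
            return type(data)(pinned)
        except TypeError:
            # Fall back to a plain dict if the type cannot be built from a dict.
            return pinned
    return pin(data)


batch = BatchEncodingStub({"input_ids": [1, 2, 3]})
print(type(pin_memory_old(batch)).__name__)     # dict
print(type(pin_memory_new(batch)).__name__)     # BatchEncodingStub
print(isinstance(pin_memory_new(batch), dict))  # False
```

With the old behavior, a downstream isinstance(tensors, Dict) check happens to pass because the batch has already been turned into a dict; with the new behavior, the UserDict subclass reaches that check unchanged and is missed.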
Reopening because the same issue happens at a different place now.