command line arg parsing for Wav2vec example on TPU
A) Running the wav2vec example on TPU using fairseq CLI arguments leads to the following error:
Traceback (most recent call last):
File "/home/sivaibhav/.local/lib/python3.8/site-packages/hydra/core/override_parser/overrides_visitor.py", line 302, in visitFunction
return self.functions.eval(function)
File "/home/sivaibhav/.local/lib/python3.8/site-packages/hydra/_internal/grammar/functions.py", line 34, in eval
raise HydraException(
hydra.errors.HydraException: Unknown function 'InferredW2vConfig'
Available: bool,choice,float,glob,int,interval,range,shuffle,sort,str,tag
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "./train.py", line 14, in <module>
cli_main()
File "/home/sivaibhav/fairseq-dev/fairseq_cli/train.py", line 496, in cli_main
cfg = convert_namespace_to_omegaconf(args)
File "/home/sivaibhav/fairseq-dev/fairseq/dataclass/utils.py", line 389, in convert_namespace_to_omegaconf
composed_cfg = compose("config", overrides=overrides, strict=False)
File "/home/sivaibhav/.local/lib/python3.8/site-packages/hydra/experimental/compose.py", line 31, in compose
cfg = gh.hydra.compose_config(
File "/home/sivaibhav/.local/lib/python3.8/site-packages/hydra/_internal/hydra.py", line 507, in compose_config
cfg = self.config_loader.load_configuration(
File "/home/sivaibhav/.local/lib/python3.8/site-packages/hydra/_internal/config_loader_impl.py", line 151, in load_configuration
return self._load_configuration(
File "/home/sivaibhav/.local/lib/python3.8/site-packages/hydra/_internal/config_loader_impl.py", line 180, in _load_configuration
parsed_overrides = parser.parse_overrides(overrides=overrides)
File "/home/sivaibhav/.local/lib/python3.8/site-packages/hydra/core/override_parser/overrides_parser.py", line 95, in parse_overrides
raise OverrideParseException(
hydra.errors.OverrideParseException: Error parsing override 'task.inferred_w2v_config=InferredW2vConfig(mask_length='${model.mask_length}', mask_prob='${model.mask_prob}', mask_selection='${model.mask_selection}', mask_other='${model.mask_other}', no_mask_overlap='${model.no_mask_overlap}', mask_min_space='${model.mask_min_space}', mask_channel_length='${model.mask_channel_length}', mask_channel_prob='${model.mask_channel_prob}', mask_channel_selection='${model.mask_channel_selection}', mask_channel_other='${model.mask_channel_other}', no_mask_channel_overlap='${model.no_mask_channel_overlap}', mask_channel_min_space='${model.mask_channel_min_space}', conv_feature_layers='${model.conv_feature_layers}', encoder_embed_dim='${model.encoder_embed_dim}')'
HydraException while evaluating 'InferredW2vConfig(mask_length='${model.mask_length}', mask_prob='${model.mask_prob}', mask_selection='${model.mask_selection}', mask_other='${model.mask_other}', no_mask_overlap='${model.no_mask_overlap}', mask_min_space='${model.mask_min_space}', mask_channel_length='${model.mask_channel_length}', mask_channel_prob='${model.mask_channel_prob}', mask_channel_selection='${model.mask_channel_selection}', mask_channel_other='${model.mask_channel_other}', no_mask_channel_overlap='${model.no_mask_channel_overlap}', mask_channel_min_space='${model.mask_channel_min_space}', conv_feature_layers='${model.conv_feature_layers}', encoder_embed_dim='${model.encoder_embed_dim}')': Unknown function 'InferredW2vConfig'
Available: bool,choice,float,glob,int,interval,range,shuffle,sort,str,tag
The command line used is:
export OMP_NUM_THREADS=1
python3 ./train.py \
/home/sivaibhav/manifest/ \
--num-batch-buckets 3 \
--tpu \
--max-sentences 4 \
--max-sentences-valid 4 \
--required-batch-size-multiple 4 \
--distributed-world-size 8 \
--distributed-port 12597 \
--update-freq 1 \
--enable-padding \
--log-interval 20 \
--num-workers 6 \
--task audio_pretraining \
--criterion wav2vec \
--arch wav2vec2 \
--log-keys "['prob_perplexity','code_perplexity','temp']" \
--quantize-targets \
--extractor-mode default \
--conv-feature-layers '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2' \
--final-dim 256 \
--latent-vars 320 \
--latent-groups 2 \
--latent-temp '(2,0.5,0.999995)' \
--infonce \
--optimizer adam \
--adam-betas '(0.9,0.98)' \
--adam-eps 1e-06 \
--lr-scheduler polynomial_decay \
--total-num-update 400000 \
--lr 0.0005 \
--warmup-updates 32000 \
--encoder-layerdrop 0 \
--dropout-input 0.0 \
--dropout-features 0.0 \
--feature-grad-mult 0.1 \
--loss-weights '[0.1, 10]' \
--conv-pos 128 \
--conv-pos-groups 16 \
--num-negatives 100 \
--cross-sample-negatives 0 \
--max-sample-size 250000 \
--min-sample-size 32000 \
--dropout 0.0 \
--attention-dropout 0.0 \
--weight-decay 0.01 \
--max-tokens 1400000 \
--skip-invalid-size-inputs-valid-test \
--ddp-backend no_c10d \
--log-format simple \
B) When using a Hydra config file, we see the following error:
Traceback (most recent call last):
File "/home/sivaibhav/fairseq/fairseq_cli/hydra_train.py", line 45, in hydra_main
distributed_utils.call_main(cfg, pre_main)
File "/home/sivaibhav/fairseq/fairseq/distributed/utils.py", line 369, in call_main
main(cfg, **kwargs)
File "/home/sivaibhav/fairseq/fairseq_cli/train.py", line 124, in main
task.load_dataset(valid_sub_split, combine=False, epoch=1)
File "/home/sivaibhav/fairseq/fairseq/tasks/audio_pretraining.py", line 245, in load_dataset
if self.cfg.tpu and task_cfg["mask_channel_prob"] == 0.0:
omegaconf.errors.ConfigKeyError: Key 'mask_channel_prob' not in 'AudioPretrainingConfig'
full_key: mask_channel_prob
reference_type=Optional[AudioPretrainingConfig]
object_type=AudioPretrainingConfig
This can be reproduced using:
OMP_NUM_THREADS=1 fairseq-hydra-train task.data=/home/sivaibhav/manifest --config-dir ./examples/wav2vec/config/pretraining --config-name wav2vec2_large_librivox_tpu.yaml
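One possible workaround for the error in B (an illustrative sketch, not the actual fairseq fix) is to read the key defensively in `audio_pretraining.py`, so task configs that lack `mask_channel_prob` fall back to a default instead of raising:

```python
# Illustrative sketch only: guard the lookup that raises in
# fairseq/tasks/audio_pretraining.py. "task_cfg" here stands in for the
# loaded task config; a pretraining config may lack mask_channel_prob.
task_cfg = {"data": "/path/to/manifest"}  # no mask_channel_prob key

# .get() with a default (supported by both plain dicts and OmegaConf's
# DictConfig) avoids the ConfigKeyError when the key is absent.
mask_channel_prob = task_cfg.get("mask_channel_prob", 0.0)
if mask_channel_prob == 0.0:
    pass  # TPU-specific handling would go here
```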
Any update on this issue? I have a similar issue, though when trying to run fairseq.checkpoint_utils.load_model_ensemble_and_task on a wav2vec model that I fine-tuned myself with fairseq-hydra-train. My issue looks like this:
omegaconf.errors.ConfigKeyError: Key 'eval_wer' not in 'AudioPretrainingConfig'
full_key: eval_wer
reference_type=Optional[AudioPretrainingConfig]
object_type=AudioPretrainingConfig
I met the same issue when loading the multilingual pre-trained wav2vec 2.0 (XLSR) models, using the sample code from the documentation.
import torch
import fairseq
cp_path = './ckpt/xlsr_53_56k.pt'
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
model = model[0]
model.eval()
wav_input_16khz = torch.randn(1,10000)
z = model.feature_extractor(wav_input_16khz)
c = model.feature_aggregator(z)
Errors show:
omegaconf.errors.ConfigKeyError: Key 'eval_wer' not in 'AudioPretrainingConfig'
full_key: eval_wer
reference_type=Optional[AudioPretrainingConfig]
object_type=AudioPretrainingConfig
@kaleko I've solved this issue manually by dropping some keys in the state_dict. See: https://github.com/facebookresearch/fairseq/issues/4585
@xjtupanda Awesome, your code solves the eval_wer key issue for me. However, I now encounter a new one that I haven't been able to solve by dropping keys. Have you encountered this before?
omegaconf.errors.ConfigKeyError: Key 'target_dict' not in 'AudioPretrainingConfig'
full_key: target_dict
reference_type=Optional[AudioPretrainingConfig]
object_type=AudioPretrainingConfig
@kaleko Basically this happens because some keys are missing from AudioPretrainingConfig, which leads to an inconsistency. I guess dropping those keys may solve your problem, although I don't find the key 'target_dict' in my config. But you can use the following steps as a reference for how to locate such keys and drop them.
Make sure you have the omegaconf package installed (pip install omegaconf).
1. Load the checkpoint file and wrap its config in an OmegaConf DictConfig.
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
cp_path = './ckpt/xlsr_53_56k.pt'
cp = torch.load(cp_path)
cfg = DictConfig(cp['cfg'])
2. Convert the DictConfig into an ordinary dict object and iterate over it to find the offending key.
dd = OmegaConf.to_container(cfg, resolve=True)
for k, v in dd.items():
    if not isinstance(v, dict):
        continue
    for key, _ in v.items():
        if key == 'eval_wer':
            print(k)
            break
The result shows it's in the sub-dict 'task'.
3. Drop the keys and save the new checkpoint.
with open_dict(cfg):
    cfg.task.pop('eval_wer')
cp['cfg'] = cfg
torch.save(cp, './ckpt/xlsr_53_56k_new.pt')
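To avoid repeating this for each offending key (eval_wer, target_dict, ...), the same idea can be wrapped in a small helper that operates on the config after converting it to a plain dict. Note that drop_cfg_keys and the example config below are hypothetical and illustrative, not fairseq API:

```python
# Hypothetical helper (not fairseq API): strip a set of unknown keys from
# every top-level sub-dict of a checkpoint's cfg (e.g. 'task', 'model').
def drop_cfg_keys(cfg_dict, bad_keys):
    cleaned = {}
    for section, value in cfg_dict.items():
        if isinstance(value, dict):
            # keep everything in this section except the offending keys
            cleaned[section] = {k: v for k, v in value.items()
                               if k not in bad_keys}
        else:
            cleaned[section] = value
    return cleaned

# Example with an illustrative config; a real one would come from
# OmegaConf.to_container(DictConfig(cp['cfg']), resolve=True).
cfg = {
    "task": {"data": "/path/to/manifest", "eval_wer": False, "target_dict": None},
    "model": {"arch": "wav2vec2"},
}
cleaned = drop_cfg_keys(cfg, {"eval_wer", "target_dict"})
# cleaned["task"] now contains only "data"
```

The cleaned dict can then be wrapped back in a DictConfig and stored under cp['cfg'] before saving, as in the steps above.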
@xjtupanda Great code! This also solves the issue for me. Did you manage to figure out whether dropping keys will have any negative impact on model performance?
@KiriKoppelgaard I suppose not, since I was just trying to extract features using pretrained models. To work around this, though, I ended up using the transformers package and loading the pretrained model from Facebook's Hugging Face pages, and it worked just fine without any error or warning.