Cannot finetune pi0_base model
Hi,
I'm following all the documentation steps to finetune pi0_base on the Fractal dataset. I managed to compute the stats (the previous step in the documentation) without any problem, but when I run the train script I get the following error:
Traceback (most recent call last):
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 446, in wrapped_fn_impl
param_fn(*args, **kwargs)
File "<@beartype(openpi.training.utils.TrainState) at 0x7f2010728ae0>", line 135, in TrainState
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefmeta.py", line 295, in __instancecheck__
return cls.__is_instance_beartype__(obj)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefabc.py", line 112, in __is_instance_beartype__
return isinstance(obj, cls.__type_beartype__) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefmeta.py", line 451, in __type_beartype__
referent = import_module_attr(
^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_util/module/utilmodimport.py", line 295, in import_module_attr
raise exception_cls(exception_message)
beartype.roar.BeartypeCallHintForwardRefException: Forward reference "ArrayTree" unimportable from module "openpi.training.utils".
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 811, in _get_problem_arg
fn(*args, **kwargs)
File "<@beartype(openpi.training.utils.check_single_arg) at 0x7f201072a660>", line 53, in check_single_arg
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefmeta.py", line 295, in __instancecheck__
return cls.__is_instance_beartype__(obj)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefabc.py", line 112, in __is_instance_beartype__
return isinstance(obj, cls.__type_beartype__) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefmeta.py", line 451, in __type_beartype__
referent = import_module_attr(
^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_util/module/utilmodimport.py", line 295, in import_module_attr
raise exception_cls(exception_message)
beartype.roar.BeartypeCallHintForwardRefException: Forward reference "ArrayTree" unimportable from module "openpi.training.utils".
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 451, in wrapped_fn_impl
argmsg = _get_problem_arg(
^^^^^^^^^^^^^^^^^
File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 814, in _get_problem_arg
raise TypeCheckError(
jaxtyping.TypeCheckError:
The problem arose whilst typechecking parameter 'opt_state'.
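From the traceback, beartype appears to fail while resolving the string forward reference "ArrayTree" (I think this comes from optax.OptState, which aliases chex.ArrayTree, a recursive type that refers to itself by name) against the namespace of openpi.training.utils, where that name is not importable. Here is a minimal sketch of that mechanism, written with beartype directly rather than openpi's jaxtyping wrapper. Illustrative code only, not taken from openpi:

# Sketch only: beartype resolves string forward references in the module that
# defines the decorated callable, so an annotation like "ArrayTree" fails if
# that name is not importable from this module.
import beartype

@beartype.beartype
def check_opt_state(opt_state: "ArrayTree") -> None:  # unresolvable forward reference
    pass

# Calling the function forces beartype to resolve "ArrayTree" here; since the
# name does not exist in this module, it raises
# beartype.roar.BeartypeCallHintForwardRefException (jaxtyping then surfaces
# it as the TypeCheckError shown above).
check_opt_state({"params": [1.0, 2.0]})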
Here is my config file:
TrainConfig(
    name="pi0_fractal",
    # Here is an example of loading a pi0 model for LoRA fine-tuning.
    model=pi0.Pi0Config(
        action_dim=7,
        action_horizon=10,
        max_token_len=180,
        paligemma_variant="gemma_2b_lora",
        action_expert_variant="gemma_300m",
    ),
    data=LeRobotFractalDataConfig(
        repo_id="IPEC-COMMUNITY/fractal20220817_data_lerobot",
        base_config=DataConfig(
            local_files_only=False,  # Set to True for local-only datasets.
            prompt_from_task=True,
        ),
    ),
    batch_size=8,
    num_workers=64,
    weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_base/params"),
    num_train_steps=30_000,
    # The freeze filter defines which parameters should be frozen during training.
    # We have a convenience function in the model config that returns the default freeze filter
    # for the given model config for LoRA finetuning. Just make sure it matches the model config
    # you chose above.
    freeze_filter=pi0.Pi0Config(
        paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m"
    ).get_freeze_filter(),
    # Turn off EMA for LoRA finetuning.
    ema_decay=None,
),
Any help is welcome!! Thanks
Hi, I'm new to OpenPi and ran into the same error while trying to fine-tune the pi0_base model on an HSR dataset.
Since the error seemed to be caused by type checking of "opt_state", I tried modifying the type definition of the TrainState class in openpi/src/openpi/training/utils.py (starting at line 15) from:
opt_state: optax.OptState
to:
opt_state: Any
This change resolved the error, and I was able to start training successfully.
However, I'm a bit concerned that removing the type check might cause issues later. I’d appreciate it if someone more experienced with OpenPi could confirm whether this workaround is acceptable or if there's a better fix.
Hope this helps.
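For reference, the change in context looks roughly like this (only the opt_state annotation reflects the actual edit; the surrounding fields and decorator are placeholders, not copied from openpi):

# Rough sketch of the workaround in openpi/src/openpi/training/utils.py.
from typing import Any
import dataclasses

@dataclasses.dataclass
class TrainState:
    step: int
    params: Any
    # Was: opt_state: optax.OptState. That alias expands to a recursive
    # "ArrayTree" forward reference, which beartype cannot resolve from this
    # module; Any simply skips the runtime check for this field.
    opt_state: Any

Since Any only disables the runtime type check for that one field, it should not change how training behaves, but it does mean malformed optimizer state would no longer be caught there.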
This was addressed at HEAD.