openpi

Cannot finetune pi0_base model

Open • pablovalle opened this issue 8 months ago • 1 comment

Hi,

I'm following all the documentation steps to fine-tune pi0_base on the Fractal dataset. Computing the stats (the previous step in the documentation) worked without any problem, but when I run the train script I get the following error:

Traceback (most recent call last):
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 446, in wrapped_fn_impl
    param_fn(*args, **kwargs)
  File "<@beartype(openpi.training.utils.TrainState) at 0x7f2010728ae0>", line 135, in TrainState
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefmeta.py", line 295, in __instancecheck__
    return cls.__is_instance_beartype__(obj)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefabc.py", line 112, in __is_instance_beartype__
    return isinstance(obj, cls.__type_beartype__)  # type: ignore[arg-type]
                           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefmeta.py", line 451, in __type_beartype__
    referent = import_module_attr(
               ^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_util/module/utilmodimport.py", line 295, in import_module_attr
    raise exception_cls(exception_message)
beartype.roar.BeartypeCallHintForwardRefException: Forward reference "ArrayTree" unimportable from module "openpi.training.utils".

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 811, in _get_problem_arg
    fn(*args, **kwargs)
  File "<@beartype(openpi.training.utils.check_single_arg) at 0x7f201072a660>", line 53, in check_single_arg
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefmeta.py", line 295, in __instancecheck__
    return cls.__is_instance_beartype__(obj)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefabc.py", line 112, in __is_instance_beartype__
    return isinstance(obj, cls.__type_beartype__)  # type: ignore[arg-type]
                           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_check/forward/reference/fwdrefmeta.py", line 451, in __type_beartype__
    referent = import_module_attr(
               ^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/beartype/_util/module/utilmodimport.py", line 295, in import_module_attr
    raise exception_cls(exception_message)
beartype.roar.BeartypeCallHintForwardRefException: Forward reference "ArrayTree" unimportable from module "openpi.training.utils".

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 451, in wrapped_fn_impl
    argmsg = _get_problem_arg(
             ^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/Desktop/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 814, in _get_problem_arg
    raise TypeCheckError(
jaxtyping.TypeCheckError:
The problem arose whilst typechecking parameter 'opt_state'.

Here is my config file:


TrainConfig(
    name="pi0_fractal",
    # Here is an example of loading a pi0 model for LoRA fine-tuning.
    model=pi0.Pi0Config(
        action_dim=7,
        action_horizon=10,
        max_token_len=180,
        paligemma_variant="gemma_2b_lora",
        action_expert_variant="gemma_300m",
    ),
    data=LeRobotFractalDataConfig(
        repo_id="IPEC-COMMUNITY/fractal20220817_data_lerobot",
        base_config=DataConfig(
            local_files_only=False,  # Set to True for local-only datasets.
            prompt_from_task=True,
        ),
    ),
    batch_size=8,
    num_workers=64,
    weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_base/params"),
    num_train_steps=30_000,
    # The freeze filter defines which parameters should be frozen during training.
    # We have a convenience function in the model config that returns the default freeze filter
    # for the given model config for LoRA finetuning. Just make sure it matches the model config
    # you chose above.
    freeze_filter=pi0.Pi0Config(
        paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m"
    ).get_freeze_filter(),
    # Turn off EMA for LoRA finetuning.
    ema_decay=None,
),

Any help is welcome! Thanks.

pablovalle, May 21 '25 07:05

Hi, I'm new to OpenPi and ran into the same error while trying to fine-tune the pi0_base model on an HSR dataset.

Since the error seemed to be caused by the type check on "opt_state", I tried changing the type annotation in the TrainState class in openpi/src/openpi/training/utils.py (starting at line 15) from "opt_state: optax.OptState" to "opt_state: Any". This change resolved the error, and I was able to start training successfully.

However, I'm a bit concerned that removing the type check might cause issues later. I’d appreciate it if someone more experienced with OpenPi could confirm whether this workaround is acceptable or if there's a better fix.
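Why relaxing the annotation helps can be illustrated with a small, stdlib-only sketch. It assumes, as the error message suggests, that optax.OptState expands to a recursive alias containing the string forward reference "ArrayTree", and that the runtime checker tries to resolve that string in openpi.training.utils, where the name is not defined. Here typing.get_type_hints stands in for beartype's forward-reference resolution, and fake_lib and fake_app are hypothetical module names; this is not openpi's, optax's, or beartype's actual code.

```python
import sys
import types
import typing

# Hypothetical library module (a stand-in for where OptState comes from): a
# recursive alias whose self-reference is the *string* "ArrayTree".
fake_lib = types.ModuleType("fake_lib")
exec(
    "from typing import List, Union\n"
    "ArrayTree = Union[int, List['ArrayTree']]\n"
    "OptState = ArrayTree\n",
    fake_lib.__dict__,
)
sys.modules["fake_lib"] = fake_lib

# Hypothetical user module (a stand-in for openpi.training.utils): it imports
# OptState but never defines or imports the name ArrayTree.
fake_app = types.ModuleType("fake_app")
exec(
    "from typing import Any\n"
    "from fake_lib import OptState\n"
    "\n"
    "class StrictState:\n"
    "    opt_state: OptState  # mirrors the original annotation\n"
    "\n"
    "class RelaxedState:\n"
    "    opt_state: Any  # mirrors the workaround\n",
    fake_app.__dict__,
)
sys.modules["fake_app"] = fake_app


def resolve_hints(cls):
    """Resolve cls's annotations using the globals of the module defining it."""
    try:
        return typing.get_type_hints(cls)
    except NameError as exc:  # the forward reference "ArrayTree" is not found
        return exc


strict = resolve_hints(fake_app.StrictState)
relaxed = resolve_hints(fake_app.RelaxedState)
print(type(strict).__name__, strict)
print(relaxed)
```

Under these assumptions, the strict annotation fails because "ArrayTree" is looked up in the wrong module, while Any needs no lookup. Note that the workaround only disables checking for the opt_state field; the other TrainState fields keep their annotations.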

Hope this helps.

t-rakko, May 28 '25 13:05

This was addressed at HEAD.

uzhilinsky, Jun 01 '25 22:06