flashbax

How to load saved buffer properly?

Status: Open. 4ku opened this issue 1 year ago (5 comments).

I tried:

```python
import os
from typing import Optional

import flashbax as fbx
import gymnasium as gym  # assumption: the original may import the older `gym` package
import jax
from flashbax.vault import Vault
from flax.core import frozen_dict


class ReplayBufferDataStore:
    def __init__(
        self,
        env: gym.Env,
        capacity: int,
        sample_batch_size: int = 32,
        priority_exponent: float = 0.8,
        device: str = "gpu",
        name: str = "replay_buffer",
        checkpoint_path: Optional[str] = None,
    ):
        self.sample_batch_size = sample_batch_size
        self.priority_exponent = priority_exponent
        self.device = jax.devices(device)[0]

        self.buffer = fbx.make_prioritised_flat_buffer(
            max_length=capacity,
            min_length=sample_batch_size,
            sample_batch_size=sample_batch_size,
            add_sequences=True,
            add_batch_size=None,
            priority_exponent=priority_exponent,
            device=device,
        )

        # Preprocess the transition once to avoid redundant transformations
        single_transition = self._initialize_single_transition(env)
        # buffer.init allocates storage for `capacity` timesteps
        self.state = self.buffer.init(single_transition)
        self.state = jax.device_put(self.state, device=self.device)

        # Vaults are stored in a "vaults" directory next to the checkpoint
        self.vault = Vault(
            vault_name=name,
            experience_structure=self.state.experience,
            rel_dir=os.path.join(os.path.dirname(checkpoint_path), "vaults"),
        )

    ...  # rest of the class elided in the original post

    def save(self):
        self.vault.write(self.state)

    def load(self, vault_path: str):
        # Expects vault_path of the form <rel_dir>/<vault_name>/<vault_uid>
        vault_name = vault_path.split("/")[-2]
        vault_uid = vault_path.split("/")[-1]
        vault_path = os.path.dirname(os.path.dirname(vault_path))
        vault = Vault(
            vault_name=vault_name,
            experience_structure=self.state.experience,
            rel_dir=vault_path,
            vault_uid=vault_uid,
        )
        # The read state only holds the timesteps that were written (e.g. 500),
        # not the `capacity` timesteps of self.state
        state = vault.read()

        # `_insert` is the author's own helper (not shown in the post)
        loaded_experience = frozen_dict.freeze(state.experience)
        self.state = _insert(self.buffer, self.state, loaded_experience)
```

The experience structure of the loaded state doesn't match self.state. For example, if 500 transitions were stored via the Vault, the loaded state has size 500, but I initialized the buffer with size 100_000.

Posted by 4ku, Nov 26 '24 12:11

Thanks for the question! The proper way to load experience from a Vault into a buffer state is to use the buffer.add function. For example, this notebook uses:

```python
buffer_add = jax.jit(buffer.add, donate_argnums=0)
buffer_state = buffer_add(buffer_state, new_experience)
```

In your case, "new_experience" would be the experience loaded from the Vault. Please let me know if you have further questions!
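A toy NumPy model can make the add-based route concrete without touching flashbax or the disk. `ToyBufferState`, `toy_init` and `toy_add` are hypothetical stand-ins, not flashbax APIs: they only mimic a buffer whose time axis is allocated up front and written at `current_index`, with wraparound. Under those assumptions, adding the 500 loaded transitions to a buffer initialised with 100_000 slots keeps the full capacity and leaves `current_index` at 500.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class ToyBufferState:
    # Hypothetical stand-in for a flashbax buffer state (not the real class).
    experience: np.ndarray  # shape (B, T_capacity, E)
    current_index: int      # next time-axis slot to write
    is_full: bool


def toy_init(batch: int, capacity: int, feature_dim: int) -> ToyBufferState:
    """Allocate the whole time axis up front."""
    return ToyBufferState(
        experience=np.zeros((batch, capacity, feature_dim), dtype=np.float32),
        current_index=0,
        is_full=False,
    )


def toy_add(state: ToyBufferState, new: np.ndarray) -> ToyBufferState:
    """Write a (B, T_new, E) sequence starting at current_index, wrapping around."""
    capacity = state.experience.shape[1]
    t_new = new.shape[1]
    idx = (state.current_index + np.arange(t_new)) % capacity
    experience = state.experience.copy()
    experience[:, idx] = new
    end = state.current_index + t_new
    return ToyBufferState(
        experience=experience,
        current_index=end % capacity,
        is_full=state.is_full or end >= capacity,
    )


# 500 "transitions" standing in for experience read back from a Vault: (1, 500, 4).
loaded = np.random.default_rng(0).normal(size=(1, 500, 4)).astype(np.float32)

state = toy_add(toy_init(batch=1, capacity=100_000, feature_dim=4), loaded)
print(state.experience.shape, state.current_index, state.is_full)
```

The real buffer.add also deals with details this toy ignores (jitting, buffer donation, batch-dimension conventions), so treat it as an illustration of the idea rather than of flashbax internals.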

Posted by lbeyers, Jan 07 '25 10:01

Hey @4ku & @lbeyers! Sorry for not taking a look at this sooner :)

As Louise mentioned, you can load the Vault state using an add function, but this is actually inefficient: the state read from a Vault is directly compatible with a normal flashbax buffer state. So there is no need to load the Vault state and then call add; we can use the loaded state directly with the usual fbx functions.

It seems that the problem above is actually the number of timesteps in the loaded state. Is that correct, @4ku?
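A quick way to test this diagnosis is to compare the freshly initialised experience with the loaded one field by field, ignoring only the length of the time axis (axis 1). The sketch assumes the experience is a flat dict of NumPy arrays shaped (B, T, ...), which simplifies a real flashbax pytree; `diff_except_time_axis` and the field names `obs` and `reward` are hypothetical.

```python
import numpy as np


def diff_except_time_axis(a: dict, b: dict) -> list:
    """List differences between two experience dicts, ignoring the length of axis 1."""
    problems = []
    if a.keys() != b.keys():
        problems.append(f"field names differ: {sorted(a)} vs {sorted(b)}")
        return problems
    for name in sorted(a):
        x, y = a[name], b[name]
        if x.dtype != y.dtype:
            problems.append(f"{name}: dtype {x.dtype} vs {y.dtype}")
        # Compare every axis except the time axis (axis 1).
        if x.ndim != y.ndim or x.shape[:1] + x.shape[2:] != y.shape[:1] + y.shape[2:]:
            problems.append(f"{name}: shape {x.shape} vs {y.shape}")
    return problems


initialised = {"obs": np.zeros((1, 100_000, 4), np.float32), "reward": np.zeros((1, 100_000), np.float32)}
loaded = {"obs": np.zeros((1, 500, 4), np.float32), "reward": np.zeros((1, 500), np.float32)}

print(diff_except_time_axis(initialised, loaded))  # [] -> only the time axis differs
print(initialised["obs"].shape[1], loaded["obs"].shape[1])
```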

If so: indeed, Vault only writes to disk the timesteps up to current_index. For example, if we haven't added any timesteps when we write to the Vault, and we then load that Vault, the state will have shape (B, 0, E), i.e. zero timesteps.
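A minimal model of the write behaviour described above, assuming the buffer has not wrapped around (the toy ignores is_full and incremental writes). `toy_vault_write` is a hypothetical function, not the Vault implementation; it just keeps the timesteps before current_index on the time axis.

```python
import numpy as np


def toy_vault_write(experience: np.ndarray, current_index: int) -> np.ndarray:
    """Hypothetical model: keep only timesteps [0, current_index) on axis 1."""
    return experience[:, :current_index].copy()


B, T, E = 1, 100_000, 4
buffer_experience = np.zeros((B, T, E), dtype=np.float32)

print(toy_vault_write(buffer_experience, current_index=500).shape)  # (1, 500, 4)
print(toy_vault_write(buffer_experience, current_index=0).shape)    # (1, 0, 4)
```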

I have previously thought about adding this functionality, though: perhaps we know how many timesteps we want in the flashbax buffer state, even if the Vault holds fewer than that. I think it should be a fairly simple change, e.g. we could use the timesteps parameter: https://github.com/instadeepai/flashbax/blob/1352bfa06494f45174f5d3498e0795d0d31be77c/flashbax/vault/vault.py#L461-L463

Currently, if you ask for more timesteps than are available (as in the original post, where the buffer size is 100_000 but the Vault only has 500 transitions), it breaks. Oops 😂 But this could easily be fixed, and I think it is a nice way to achieve the desired functionality with the current API.
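One possible shape for such a fix, sketched in NumPy under the assumption that the loaded experience is a dict of (B, T, ...) arrays: zero-pad every field along the time axis up to the requested number of timesteps and report how many timesteps are real. `pad_time_axis` is a hypothetical helper. Whether a real fix pads with zeros, and how it treats requests smaller than what is stored, are open design choices; this sketch simply refuses such requests.

```python
import numpy as np


def pad_time_axis(experience: dict, timesteps: int) -> tuple:
    """Zero-pad every field of a loaded experience dict to `timesteps` on axis 1.

    Returns the padded dict and the number of real (loaded) timesteps.
    """
    n_loaded = next(iter(experience.values())).shape[1]
    if timesteps < n_loaded:
        # Truncation is deliberately out of scope for this sketch.
        raise ValueError(f"requested {timesteps} timesteps but {n_loaded} are stored")
    padded = {}
    for name, x in experience.items():
        pad_width = [(0, 0)] * x.ndim
        pad_width[1] = (0, timesteps - n_loaded)  # pad only the end of the time axis
        padded[name] = np.pad(x, pad_width)  # constant mode, zeros by default
    return padded, n_loaded


loaded = {"obs": np.ones((1, 500, 4), np.float32), "reward": np.ones((1, 500), np.float32)}
padded, n_loaded = pad_time_axis(loaded, timesteps=100_000)
print(padded["obs"].shape, padded["reward"].shape, n_loaded)
```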

I haven't had a chance to look at this properly yet, but let me know if I have understood you correctly. Happy to draft something when I get a moment, probably sometime this week.

Thanks!

Posted by callumtilbury, Jan 07 '25 11:01

@callumtilbury Yes, you are right. The actual problem is the number of timesteps in the loaded state. So it would be great if I could do:

```python
state = vault.read(timesteps=100_000)
```

Then the state would have size 100_000, with the first 500 transitions coming from the Vault. Looking forward to your fix! Thank you

Posted by 4ku, Jan 09 '25 11:01

Also, I think current_index shouldn't be 0 after loading (or there should be an extra parameter for this). I want to continue adding transitions after the end of the loaded state.
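Why current_index matters here can be shown with a toy single-step add on a padded buffer that holds 500 loaded transitions. `add_one` is a hypothetical helper, not buffer.add, and the sketch ignores is_full and any priority bookkeeping a prioritised buffer keeps. With current_index left at 0, the first new transition overwrites loaded transition 0; with current_index set to 500, it lands right after the loaded data.

```python
import numpy as np


def add_one(experience: np.ndarray, current_index: int, step: np.ndarray) -> tuple:
    """Toy single-timestep add: write at current_index, then advance with wraparound."""
    experience = experience.copy()
    experience[:, current_index] = step
    return experience, (current_index + 1) % experience.shape[1]


capacity, n_loaded = 100_000, 500
experience = np.zeros((1, capacity, 4), np.float32)
experience[:, :n_loaded] = 1.0  # stand-in for the 500 loaded transitions
new_step = np.full((1, 4), 2.0, np.float32)

# current_index = 0: the first new transition overwrites loaded transition 0.
after_reset, _ = add_one(experience, 0, new_step)
# current_index = n_loaded: the new transition is appended after the loaded data.
after_append, next_index = add_one(experience, n_loaded, new_step)

print(after_reset[0, 0, 0], after_append[0, 0, 0], after_append[0, n_loaded, 0], next_index)
```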

Posted by 4ku, Jan 09 '25 12:01

@callumtilbury Really looking forward to your fixes :)

Posted by 4ku, Feb 05 '25 14:02