Cache the number of elements in the action space
You probably don't need to dispatch to numpy every time you call split to calculate the number of elements in the space. This PR caches the sizes (in a less-than-nice way, imo) as an example. Before and after pictures below.
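A minimal sketch of the caching idea, with hypothetical names (`FlatSpace`, `subspace_shapes` are illustrations, not the project's actual API): compute each subspace's element count once at construction, so `split` never calls `np.prod` in the hot path.

```python
import numpy as np

class FlatSpace:
    """Hypothetical container illustrating size caching for split()."""

    def __init__(self, subspace_shapes):
        self.shapes = subspace_shapes
        # Cached element counts: one np.prod per subspace, done once here
        # instead of on every split() call.
        self.sizes = [int(np.prod(shape)) for shape in subspace_shapes]

    def split(self, flat):
        """Split a flat vector into per-subspace arrays using cached sizes."""
        out, start = [], 0
        for shape, size in zip(self.shapes, self.sizes):
            out.append(flat[start:start + size].reshape(shape))
            start += size
        return out

space = FlatSpace([(2, 3), (4,)])
chunks = space.split(np.arange(10))
```

Here `split` does only slicing and reshapes, which are cheap views, while the numpy dispatch cost is paid once up front.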
This looks reasonable, waiting for hardware to test end to end. Any other optimization ideas for split? It's the main bottleneck right now. From before your patch:
```
🐡 python tests/test_extensions.py
0.00000032: Flatten time
0.00000294: Concatenate time
0.00001958: Split time
0.00000056: Unflatten time
```
You could try to vectorize the generation of samps -> leaves, a la what was done in evaluate? Though I'm unsure whether that'll work if the sz's can vary.
I think it'd look like:

```python
leaves = stacked_sample.reshape(len(flat_space), batch, *next(iter(flat_space.values())).shape)
```

(note the `iter(...)`: `dict.values()` is a view, not an iterator, so bare `next` on it raises a TypeError)
I assume that since sz is the same across all flat spaces, the shape will be the same too.
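A sketch of that reshape under the equal-shape assumption (`flat_space`, `stacked_sample`, and `batch` are stand-ins for the real variables): one reshape turns the stacked sample buffer into per-subspace leaves, replacing a Python loop.

```python
import numpy as np

# Hypothetical flat space: two subspaces, both with shape (3,).
flat_space = {
    "a": np.zeros((3,)),
    "b": np.zeros((3,)),
}
batch = 4
sub_shape = next(iter(flat_space.values())).shape  # (3,)

# Stand-in for the stacked sample buffer: samples for subspace 0 come
# first (batch * 3 values), then subspace 1, and so on.
stacked_sample = np.arange(len(flat_space) * batch * 3, dtype=np.float32)

# One vectorized reshape: leaves[i] is the batch of samples for the
# i-th subspace. Only valid if every subspace shares sub_shape.
leaves = stacked_sample.reshape(len(flat_space), batch, *sub_shape)
```

If the sz's can vary across subspaces, a single reshape like this can't work and you'd be back to per-subspace slicing.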