store_mass_matrix in low-rank adaptation mode
The code
import nutpie

# Stan model source (not included in the report)
code = """
stan_model
"""
compiled = nutpie.compile_stan_model(code=code)
compiled = compiled.with_data(mu=3.)
trace = nutpie.sample(compiled, save_warmup=True, low_rank_modified_mass_matrix=True, store_mass_matrix=True)
This raises

ValueError: cannot reshape array of size 0 into shape (1800,)

from

--> 407 return _trace_to_arviz(
    408     results,
    409     self._settings.num_tune,
    ...
--> 115 data[i, : len(chunk)] = values.reshape((len(chunk), *last_shape))
    116 stats_dict[name] = data[:, n_tune:]
    117 stats_dict_tune[name] = data[:, :n_tune]
Thanks for reporting. It seems like the Python side of the mass matrix storage is broken. The Rust code stores the diagonal values and the eigenvalues (not the eigenvectors, which would be useful...), but when we convert the trace from the internal Arrow-based format to ArviZ, things break.
I think the issue is that for the very first iterations there are no mass_matrix_eigenvals to report, since these only get computed at switch() times? So the storage writes a null, which pyarrow does not like. We get a pyarrow ListArray of the form [null, null, ..., null, [1, 2, 3], ..., [4, 5, 6]], which can't be converted with .to_numpy() the way the other columns can.
I think two changes could resolve this:
- Can we fill the initial window with nulls that have a shape matching the later values? Right now we get [null, null, ..., [1, 2]], and I think the ragged array is causing the numpy conversion to misbehave (it just slices the null values off to get a regular array).
- I think this bit isn't using the correct condition to decide which columns get a special shape. For inverse_mass_matrix, it correctly finds that the last shape equals the number of parameters, because inverse_mass_matrix is a FixedSizeListType. But mass_matrix_eigenvals is a ListType, so it concludes the last shape is an empty tuple. That causes problems when reshaping downstream, because the data array won't be expecting the extra dimension.
https://github.com/pymc-devs/nutpie/blob/8f15f8e2cc2b88d587db2fe4506c78d563bae05c/python/nutpie/sample.py#L95C1-L100C28
if hasattr(col_type, "list_size"):
last_shape = (col_type.list_size,)
dtype = col_type.field(0).type.to_pandas_dtype()
else:
dtype = col_type.to_pandas_dtype()
last_shape = ()