Add covariance_type option to GMM class
Description
Currently, the zuko.flows.mixture.GMM class only supports full covariance matrices. However, in a number of use cases (especially high-dimensional ones), a full covariance matrix is either unnecessary or infeasible to estimate. This issue proposes adding an option to choose between different covariance types, similar to the covariance_type parameter of sklearn.mixture.GaussianMixture ('full', 'tied', 'diag', 'spherical').
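For reference, the four sklearn covariance types correspond to the following parameterisations of the covariance $\Sigma_k$ of component $k$ in $D$ dimensions, with the number of free covariance parameters on the right. Here $L_k$ and $L$ are lower-triangular Cholesky factors with positive diagonals; the notation is ours, not Zuko's or sklearn's.

```latex
\begin{aligned}
\text{full:}      &\quad \Sigma_k = L_k L_k^\top,                                    & &\tfrac{D(D+1)}{2} \text{ per component} \\
\text{tied:}      &\quad \Sigma_k = L L^\top \ (\text{shared by all } k),            & &\tfrac{D(D+1)}{2} \text{ in total} \\
\text{diag:}      &\quad \Sigma_k = \operatorname{diag}(\sigma_{k,1}^2, \dots, \sigma_{k,D}^2), & &D \text{ per component} \\
\text{spherical:} &\quad \Sigma_k = \sigma_k^2 I_D,                                  & &1 \text{ per component}
\end{aligned}
```

The quadratic growth of the 'full' count in $D$ is what makes it hard to estimate in high dimensions.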
Here is an example of how different covariance types approximate a mixture of 3 Gaussians with varying covariance matrices.
Implementation
The current structure of the zuko.flows.mixture.GMM class makes it easy to add the enhancement described above. I have implemented the changes in a fork of the repository and could open a pull request if this change is wanted. I have only tested the code for the unconditional case, but I do not see how adding context features would break it.
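As a rough illustration (not the code from the fork), here is a minimal sketch of a covariance_type option written as a standalone PyTorch module rather than against Zuko's internals. The class name CovarianceGMM and its attributes are hypothetical, not Zuko's API. Each covariance type is turned into a Cholesky factor passed to torch.distributions.MultivariateNormal via scale_tril, and the mixture is assembled with MixtureSameFamily. Only the unconditional case is covered.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical, MixtureSameFamily, MultivariateNormal


class CovarianceGMM(nn.Module):
    """Unconditional Gaussian mixture with a selectable covariance type.

    Hypothetical sketch, not Zuko's API. `covariance_type` mirrors sklearn's
    options: 'full', 'tied', 'diag' or 'spherical'.
    """

    def __init__(self, features: int, components: int = 2, covariance_type: str = "full"):
        super().__init__()
        if covariance_type not in ("full", "tied", "diag", "spherical"):
            raise ValueError(f"unknown covariance_type: {covariance_type!r}")
        self.covariance_type = covariance_type
        self.logits = nn.Parameter(torch.zeros(components))
        self.means = nn.Parameter(torch.randn(components, features))
        # 'tied' shares one Cholesky factor across components; the others have one each.
        n_mats = 1 if covariance_type == "tied" else components
        # Log of the Cholesky diagonal: a single value per component for 'spherical'.
        n_diag = 1 if covariance_type == "spherical" else features
        self.log_diag = nn.Parameter(torch.zeros(n_mats, n_diag))
        # Strictly lower-triangular Cholesky entries, only for 'full' and 'tied'.
        if covariance_type in ("full", "tied"):
            self.off_diag = nn.Parameter(torch.zeros(n_mats, features * (features - 1) // 2))
        else:
            self.register_parameter("off_diag", None)

    def scale_tril(self) -> torch.Tensor:
        """Cholesky factors L_k with shape (components, features, features)."""
        K, D = self.means.shape
        # exp keeps the diagonal positive; expand repeats the spherical scale D times.
        L = torch.diag_embed(self.log_diag.exp().expand(-1, D))
        if self.off_diag is not None:
            rows, cols = torch.tril_indices(D, D, offset=-1)
            lower = torch.zeros_like(L)
            lower[:, rows, cols] = self.off_diag
            L = L + lower
        # Broadcast the single shared factor to all components in the 'tied' case.
        return L.expand(K, D, D)

    def forward(self) -> MixtureSameFamily:
        return MixtureSameFamily(
            Categorical(logits=self.logits),
            MultivariateNormal(self.means, scale_tril=self.scale_tril()),
        )


# Usage: maximum-likelihood fit on toy data (standard normal samples, illustration only).
torch.manual_seed(0)
data = torch.randn(256, 2)
gmm = CovarianceGMM(features=2, components=3, covariance_type="diag")
optimizer = torch.optim.Adam(gmm.parameters(), lr=1e-2)
for _ in range(100):
    loss = -gmm().log_prob(data).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

Parameterising the Cholesky factor directly, with an exp on its diagonal, keeps every covariance positive definite without constrained optimisation. In a conditional version, the same unconstrained tensors could be produced by a hyper-network instead of being free parameters.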
Further improvements
When generating the above figure, I (again) realised how easily mode collapse happens for GMMs. The zuko.flows.mixture.GMM class could therefore also benefit from an initialisation procedure, again similar to sklearn.mixture.GaussianMixture. I fully understand if that goes beyond the scope of what Zuko wants to achieve. However, Zuko is very convenient to use and ties in so well with PyTorch code that having such a procedure here would be nice. The downside is that it might add another dependency (e.g., sklearn) if you want to reuse existing implementations of initialisation algorithms. I have a basic implementation of this (using sklearn) lying around and would be happy to polish it up and make another commit if this is wanted.
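sklearn's GaussianMixture initialises from a k-means clustering by default (init_params='kmeans'). A dependency-free alternative could look roughly like the sketch below: k-means++ seeding followed by Lloyd iterations in plain PyTorch, returning cluster weights, centres and per-dimension variances. The function name kmeans_init is hypothetical, and this is not sklearn's exact algorithm (no multiple restarts, for example).

```python
import torch
import torch.nn.functional as F


def kmeans_init(x: torch.Tensor, components: int, iterations: int = 100, generator=None):
    """k-means++ seeding followed by Lloyd iterations, in plain PyTorch.

    Returns mixture weights (K,), means (K, D) and per-dimension variances (K, D)
    of the resulting clusters, which could seed a diagonal GMM.
    """
    n, _ = x.shape
    # k-means++ seeding: first centre uniform, the rest sampled proportional to D(x)^2.
    means = x[torch.randint(n, (1,), generator=generator)]
    for _ in range(1, components):
        d2 = torch.cdist(x, means).min(dim=1).values ** 2
        idx = torch.multinomial(d2 + 1e-12, 1, generator=generator)
        means = torch.cat([means, x[idx]], dim=0)
    # Lloyd iterations: assign points to the nearest centre, move centres to cluster means.
    labels = None
    for _ in range(iterations):
        new_labels = torch.cdist(x, means).argmin(dim=1)
        if labels is not None and torch.equal(new_labels, labels):
            break  # assignments stopped changing
        labels = new_labels
        one_hot = F.one_hot(labels, components).to(x.dtype)  # (n, K)
        counts = one_hot.sum(dim=0)
        updated = one_hot.T @ x / counts.clamp_min(1).unsqueeze(1)
        # Keep the previous centre if a cluster ends up empty.
        means = torch.where(counts.unsqueeze(1) > 0, updated, means)
    # Final statistics from the assignments to the final centres.
    one_hot = F.one_hot(torch.cdist(x, means).argmin(dim=1), components).to(x.dtype)
    counts = one_hot.sum(dim=0)
    weights = counts / n
    sq_dev = (x.unsqueeze(1) - means.unsqueeze(0)) ** 2  # (n, K, D)
    variances = (one_hot.unsqueeze(2) * sq_dev).sum(dim=0) / counts.clamp_min(1).unsqueeze(1)
    return weights, means, variances.clamp_min(1e-6)  # floor avoids zero variance
```

The outputs map onto unconstrained GMM parameters as logits = weights.log(), means as-is, and log standard deviations = 0.5 * variances.log().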
Hello @dominik-strutz, thank you for the feature request!
I think supporting several covariance types would be a valuable improvement for the GMM class. Actually, I have been meaning to add a diagonal-plus-low-rank covariance option at some point.
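A diagonal-plus-low-rank covariance $\Sigma_k = W_k W_k^\top + \operatorname{diag}(d_k)$, with $W_k \in \mathbb{R}^{D \times r}$, needs only $D(r + 1)$ parameters per component. PyTorch already ships torch.distributions.LowRankMultivariateNormal for this, so a mixture of such components can be sketched as below. The helper low_rank_gmm and its argument names are hypothetical, not Zuko's API.

```python
import torch
from torch.distributions import Categorical, LowRankMultivariateNormal, MixtureSameFamily


def low_rank_gmm(logits, means, cov_factor, log_cov_diag) -> MixtureSameFamily:
    """Gaussian mixture whose k-th covariance is W_k W_k^T + diag(exp(log_cov_diag[k])).

    Shapes: logits (K,), means (K, D), cov_factor W (K, D, r), log_cov_diag (K, D).
    """
    # exp keeps the diagonal term strictly positive, so each covariance is positive definite.
    components = LowRankMultivariateNormal(means, cov_factor, log_cov_diag.exp())
    return MixtureSameFamily(Categorical(logits=logits), components)


# Usage with random parameters (illustration only): K=3 components, D=5 features, rank r=2.
torch.manual_seed(0)
K, D, r = 3, 5, 2
mixture = low_rank_gmm(
    torch.zeros(K), torch.randn(K, D), 0.5 * torch.randn(K, D, r), torch.zeros(K, D)
)
samples = mixture.sample((8,))  # shape (8, 5)
log_density = mixture.log_prob(samples)  # shape (8,)
```

LowRankMultivariateNormal evaluates densities via the Woodbury identity, so the cost scales with the rank r rather than requiring a full D-by-D factorisation.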
Initialization is also relevant, although handling the conditional case might be tricky (but possible by editing the weights of the last layer of the hyper-network). However, Zuko cannot depend on sklearn, so it would have to be implemented from scratch.
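One way to read the "edit the last layer" idea for the conditional case: if the hyper-network's final nn.Linear outputs the flattened, unconstrained mixture parameters, setting its bias to the desired initial values and zeroing (or shrinking) its weight makes the network emit those values for every context at initialisation. The sketch below assumes an nn.Sequential hyper-network and a simple diagonal parameter layout; the function init_last_layer and that layout are hypothetical, not how Zuko structures its GMM.

```python
import torch
import torch.nn as nn


def init_last_layer(hyper: nn.Sequential, init_params: torch.Tensor, weight_scale: float = 0.0) -> None:
    """Make the hyper-network output (approximately) `init_params` for every context.

    Assumes the last module of `hyper` is an nn.Linear that outputs the flattened,
    unconstrained mixture parameters. The bias is set to `init_params`; the weight is
    scaled by `weight_scale` (0 makes the initial output exactly context-independent).
    """
    last = hyper[-1]
    if not isinstance(last, nn.Linear) or last.out_features != init_params.numel():
        raise ValueError("expected a final nn.Linear with one output per parameter")
    with torch.no_grad():
        last.weight.mul_(weight_scale)
        last.bias.copy_(init_params.flatten())


# Usage (illustrative sizes): 4 context features, K=3 components in D=2 dimensions.
# Outputs = K logits + K*D means + K*D log-scales (a diagonal parameterisation).
K, D, context = 3, 2, 4
hyper = nn.Sequential(nn.Linear(context, 64), nn.ReLU(), nn.Linear(64, K + 2 * K * D))
# Toy initial values; in practice they could come from a clustering pass over the data.
init_params = torch.cat([
    torch.full((K,), 1 / K).log(),  # logits from equal weights
    torch.randn(K * D),  # means
    torch.zeros(K * D),  # log-scales (unit variances)
])
init_last_layer(hyper, init_params)
```

A zero weight still receives gradients (they depend on the hidden activations, not on the weight itself), so the context dependence can be learned during training; a small non-zero weight_scale such as 1e-2 is an alternative that keeps some context dependence from the start.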
We would accept a PR that implements this feature!