mlx-examples

Convert Musicgen to MLX

Open akashicMarga opened this issue 2 years ago • 10 comments

Thank you for your hard work on this project! I am interested in converting MusicGen from the Meta team to MLX. Because I have limited RAM, I am considering the small version of the model (https://huggingface.co/facebook/musicgen-stereo-small).

Could you suggest a good starting point for this conversion? From the Hugging Face implementation, I can see that it uses the T5 encoder for text encoding, which is already available in this repository, and EnCodec, also from the Meta team, for audio encoding.

Thank you for your assistance!

akashicMarga, Dec 30 '23

That would be awesome! For conversions, the best place to start is usually a reference PyTorch implementation. This is a somewhat bigger project since it involves multiple models (EnCodec, the music generator, and the LM).

We already have a T5 example as you mentioned.

It might make sense to have EnCodec as a standalone example since it's useful for lots of downstream audio generation tasks. For example, I was recently looking at converting a different TTS model that also uses it. Maybe that is a good place to start?

awni, Dec 31 '23

Yes, I was thinking the same: separating EnCodec into its own module, since it can be used on its own and in many TTS systems like VALL-E and VITS.

akashicMarga, Jan 01 '24

Is anyone working on a port of EnCodec? If not, I might take a stab at it myself, as I'm interested in getting some audio generation up and running!

awni, Jan 02 '24

@awni I just started last night. Some modules that EnCodec uses, such as LSTM and other sequential layers, are available directly in torch but not in MLX. I started from the main EnCodec repo since it was fairly simple. I only get time on weekends because of my work commitments, so honestly I can't give a timeline. But I really want to get audio generation up and running.
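For reference, the LSTM recurrence itself is small enough to write from existing array ops. Below is a minimal NumPy sketch of a single-layer LSTM forward pass, assuming PyTorch's parameter layout (`weight_ih`, `weight_hh`, `bias_ih`, `bias_hh`, with gate rows stacked in input, forget, cell, output order) so converted weights could be loaded as-is. `lstm_forward` is a hypothetical name, and a real port would use `mlx.core` arrays rather than NumPy.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_forward(x, w_ih, w_hh, b_ih, b_hh):
    """Single-layer LSTM over one sequence (hypothetical helper).

    x:    (T, input_size)
    w_ih: (4 * hidden, input_size), w_hh: (4 * hidden, hidden)
    b_ih, b_hh: (4 * hidden,)
    Gate rows are assumed stacked as [input, forget, cell, output] (PyTorch layout).
    Returns the hidden state at every step, shape (T, hidden).
    """
    hidden = w_hh.shape[1]
    h = np.zeros(hidden)  # hidden state, starts at zero
    c = np.zeros(hidden)  # cell state, starts at zero
    outputs = []
    for x_t in x:
        gates = w_ih @ x_t + b_ih + w_hh @ h + b_hh
        i, f, g, o = np.split(gates, 4)  # four equal chunks of size `hidden`
        i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
        g = np.tanh(g)
        c = f * c + i * g        # forget part of the old cell, add the new candidate
        h = o * np.tanh(c)       # expose a gated view of the cell
        outputs.append(h)
    return np.stack(outputs)

rng = np.random.default_rng(0)
T, n_in, n_hid = 5, 3, 4
out = lstm_forward(
    rng.standard_normal((T, n_in)),
    rng.standard_normal((4 * n_hid, n_in)),
    rng.standard_normal((4 * n_hid, n_hid)),
    np.zeros(4 * n_hid),
    np.zeros(4 * n_hid),
)
print(out.shape)  # (5, 4)
```

A multi-layer LSTM would stack this function, feeding each layer's output sequence into the next layer.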

akashicMarga, Jan 03 '24

Got it. OK, let me know what's missing in terms of layers etc. that should be in mlx.nn, and we can prioritize getting them in. For example, there is a PR for RNNs/LSTMs out now that we can try to get merged sooner.

If you think it will take a while, you can always start a draft PR and we can collaborate on it!

awni, Jan 03 '24

Yes, I checked that PR this morning; it has most of what's needed. I will go through the EnCodec code and come back with more details, maybe by tomorrow. @awni, just a suggestion: could we use Discussions instead of Issues? Most of the issues reported here are really enhancement requests since MLX is still growing, and with respect to performance I haven't run into any issues so far.

akashicMarga, Jan 03 '24

@awni

The modules below will be required in MLX; some of them are already being addressed in existing PRs and issues:

  1. Setting dilation and groups in convolution layers - https://github.com/ml-explore/mlx/issues/100
  2. Adding sequential (recurrent) layers such as LSTM, RNN, and GRU - https://github.com/ml-explore/mlx/pull/268
  3. Weight normalization: torch has a `weight_norm` utility that will be helpful here - https://pytorch.org/docs/stable/_modules/torch/nn/utils/weight_norm.html#weight_norm

The rest of the items seem easy to port. For point 1, I tried using convolutions without the dilation and groups parameters, since the PyTorch code only uses their default values, so leaving them out made no real difference.
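If a port does need non-default values before MLX supports them, one workaround is to express both parameters with a plain convolution: dilation by inserting zeros between kernel taps, and groups by splitting the channels and concatenating the per-group outputs. The NumPy sketch below assumes PyTorch's weight layout `(C_out, C_in // groups, K)`, stride 1 and no padding; the helper names `conv1d`, `dilate_kernel`, and `grouped_via_split` are hypothetical. It also makes the point above concrete: with `dilation=1, groups=1` the layer is just an ordinary convolution.

```python
import numpy as np

def conv1d(x, w, dilation=1, groups=1):
    """Reference 1-D convolution (cross-correlation, stride 1, no padding).

    x: (C_in, L), w: (C_out, C_in // groups, K). Returns (C_out, L_out).
    """
    c_in, length = x.shape
    c_out, c_in_g, k = w.shape
    span = dilation * (k - 1) + 1          # input samples covered by one output
    l_out = length - span + 1
    out_per_group = c_out // groups
    y = np.zeros((c_out, l_out))
    for oc in range(c_out):
        g = oc // out_per_group            # input-channel group feeding this output
        xg = x[g * c_in_g:(g + 1) * c_in_g]
        for t in range(l_out):
            taps = xg[:, t:t + span:dilation]  # (c_in_g, k) dilated window
            y[oc, t] = np.sum(taps * w[oc])
    return y

def dilate_kernel(w, dilation):
    """Emulate dilation with a plain conv by inserting zeros between kernel taps."""
    c_out, c_in_g, k = w.shape
    wd = np.zeros((c_out, c_in_g, dilation * (k - 1) + 1))
    wd[:, :, ::dilation] = w
    return wd

def grouped_via_split(x, w, groups):
    """Emulate groups by running one ungrouped conv per channel group."""
    xs = np.split(x, groups, axis=0)       # split input channels
    ws = np.split(w, groups, axis=0)       # split output channels
    return np.concatenate([conv1d(xg, wg) for xg, wg in zip(xs, ws)], axis=0)

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 16))           # 4 channels, 16 samples
w_grouped = rng.standard_normal((6, 2, 3)) # groups=2: each output sees 2 of 4 inputs
w_full = rng.standard_normal((6, 4, 3))
y_grouped = conv1d(x, w_grouped, groups=2)
y_dilated = conv1d(x, w_full, dilation=2)
print(y_grouped.shape, y_dilated.shape)    # (6, 14) (6, 12)
```

The zero-inserted kernel wastes multiplies on the inserted zeros, so it is a stopgap for correctness rather than a fast path.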

akashicMarga, Jan 04 '24

I'm digging into PyTorch's C++ implementation of weight norm: https://github.com/pytorch/pytorch/blob/834c7a1d3ea07878ad87d127ee28606fc140b552/aten/src/ATen/native/WeightNorm.cpp#L50

I'm fine with C++, but I'm not set up to build the MLX project... questioning my motivation on a Saturday night, haha. Do we have a Discord? I'd like to talk with someone; maybe there is already an implementation or an obvious way to get it done. I'm brand new to MLX (a few hours in) and am reading through https://github.com/ml-explore/mlx/blob/main/mlx/primitives.cpp

signalprime, Feb 25 '24

We have a discord link here: https://github.com/ml-explore/mlx/discussions/733

You shouldn't need to implement weight norm in C++. That can all be done in Python using existing ops.
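As an illustration of doing it in Python: `torch.nn.utils.weight_norm` reparameterizes a weight as `w = g * v / ||v||`, with the norm taken over every axis except `dim` (default `0`, i.e. per output channel for a conv weight). A minimal NumPy sketch follows; the helper names `weight_norm` and `split_weight` are hypothetical, and an MLX version would assume the same `sum`, `sqrt`, and elementwise ops from `mlx.core`.

```python
import numpy as np

def weight_norm(v, g, dim=0):
    """Recompute w = g * v / ||v||, norm taken over every axis except `dim`.

    v: direction tensor, e.g. a conv weight of shape (C_out, C_in, K).
    g: magnitude, broadcastable against v, e.g. shape (C_out, 1, 1).
    """
    axes = tuple(a for a in range(v.ndim) if a != dim)
    norm = np.sqrt(np.sum(v * v, axis=axes, keepdims=True))
    return g * v / norm

def split_weight(w, dim=0):
    """Initial (g, v) pair that reproduces w exactly, as weight_norm does at setup."""
    axes = tuple(a for a in range(w.ndim) if a != dim)
    g = np.sqrt(np.sum(w * w, axis=axes, keepdims=True))
    return g, w.copy()

rng = np.random.default_rng(0)
w = rng.standard_normal((8, 4, 3))      # conv weight: (C_out, C_in, K)
g, v = split_weight(w)
w_rebuilt = weight_norm(v, g)
print(np.allclose(w, w_rebuilt))        # True
```

Assuming the PyTorch checkpoint stores each weight as a `weight_g`/`weight_v` pair (which is what `torch.nn.utils.weight_norm` creates), one option for inference-only conversion is to fold the pair into a single weight once at conversion time, so the MLX model needs no weight-norm layer at all.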

awni, Feb 25 '24

That's great, and thanks a lot for the input @awni. Part of me was thinking that too.

signalprime, Feb 25 '24