maxtext
Adding Mixtral-8x22b
- adding the mixtral-8x22b model config (and wiring it into pyconfig)
- improving the llama and mistral conversion script: in-place weight writing to reduce total RAM usage, plus better progress tracking (see the sketch below)
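For illustration, here is a minimal sketch of the in-place weight-writing idea; this is not the PR's actual code, and the helper names (`tensor_specs`, `load_fn`, the per-tensor `.npy` layout) are assumptions:

```python
# Sketch: convert one tensor at a time and write it straight to disk, so peak RAM
# stays close to the size of a single tensor instead of the whole checkpoint.
import numpy as np
from tqdm import tqdm  # simple progress tracking

def convert_in_place(tensor_specs, out_dir):
    """tensor_specs: iterable of (name, shape, dtype, load_fn) tuples (hypothetical layout)."""
    for name, shape, dtype, load_fn in tqdm(tensor_specs, desc="converting"):
        out = np.lib.format.open_memmap(
            f"{out_dir}/{name}.npy", mode="w+", dtype=dtype, shape=shape)
        out[...] = load_fn()   # load and convert a single tensor
        out.flush()            # persist it before moving to the next one
        del out                # release the memory mapping promptly
```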
Are we good to start review? If so, please mark it as ready, and assign it to @RissyRan @gobbleturk and @ZhiyuLi-goog. Thanks!
Not yet, but will do; I'm still finishing the new tests. Sounds good!
- should I remove the model run configs (MaxText/configs/...), leaving only the model configs (MaxText/configs/models/...)?
- assets need to be pushed to gs://maxtext-external for the new test .../8x22b/1_test_mixstral.sh to work
- the model needs a new tokenizer, which I added to assets/
- there's a quality-of-life change to MaxText/checkpointing.py for local storage that could maybe use testing? (a rough sketch of the idea is below)
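For context, a minimal sketch of what a local-storage convenience in checkpointing could look like; the helper name and behavior here are assumptions, not the actual MaxText/checkpointing.py diff:

```python
# Sketch: treat non-gs:// paths as local directories and make sure they exist
# before handing them to the checkpointer, so local runs don't fail on a missing dir.
import os

def prepare_checkpoint_dir(checkpoint_dir: str) -> str:
    if not checkpoint_dir.startswith("gs://"):
        # local run: create the directory so the checkpointer can write to it
        os.makedirs(checkpoint_dir, exist_ok=True)
    return checkpoint_dir
```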
A decoding test produces coherent output ("Input [INST] I love to [/INST]" -> "Sure, I'm here to help answer"), but the logits sometimes disagree (?):
logit 00 --------------------------------------------------------------------------------
metric abs rel
---------- --------- -----------
logit_norm 11.0552 0.0158547
logit_abs 1.60011 181.683
prob_norm 0.274151 0.621473
prob_abs 0.168545 0.801035
kl_div 0.188969 0.188969
top_1 0 0
top_k -1 -1
where_top 2 2
logit 01 --------------------------------------------------------------------------------
metric abs rel
---------- ----------- ------------
logit_norm 12.1327 0.0195044
logit_abs 0.381988 625.316
prob_norm 0.0202508 0.0428087
prob_abs 0.0147027 0.453724
kl_div 0.00461103 0.00461103
top_1 -1 -1
top_k -1 -1
where_top 0 0
logit 02 --------------------------------------------------------------------------------
metric abs rel
---------- ---------- -----------
logit_norm 175.829 0.240218
logit_abs 48.1239 2550.73
prob_norm 0.302262 0.950779
prob_abs 0.212441 18.4295
kl_div 0.868579 0.868579
top_1 0 0
top_k 0 0
where_top 2 2
logit 03 --------------------------------------------------------------------------------
metric abs rel
---------- ----------- ----------
logit_norm 1575.29 1.11016
logit_abs 25.4272 269.331
prob_norm 0.502296 1.13861
prob_abs 0.414026 2022.58
kl_div 3.34534 3.34534
top_1 0 0
top_k 0 0
where_top 27 27
logit 04 --------------------------------------------------------------------------------
metric abs rel
---------- ----------- -----------
logit_norm 687.712 0.478749
logit_abs 15.3228 693.711
prob_norm 0.191784 1.06292
prob_abs 0.148175 170.036
kl_div 1.86368 1.86368
top_1 0 0
top_k 0 0
where_top 9357 9357
logit 05 --------------------------------------------------------------------------------
metric abs rel
---------- ---------- ------------
logit_norm 403.599 0.89769
logit_abs 8.13489 13229.5
prob_norm 0.247333 0.625215
prob_abs 0.215195 14.4916
kl_div 0.218769 0.218769
top_1 0 0
top_k -1 -1
where_top 1 1
logit 06 --------------------------------------------------------------------------------
metric abs rel
---------- ---------- ------------
logit_norm 671.76 0.806734
logit_abs 11.0339 11134.3
prob_norm 0.213499 0.418123
prob_abs 0.130965 318.033
kl_div 1.06856 1.06856
top_1 -1 -1
top_k -1 -1
where_top 0 0
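For reference, a rough sketch of how metrics like kl_div, top_1 agreement, and where_top in the dump above could be computed for a single position; this is assumed, not the PR's actual comparison script, and the exact sign/offset conventions of the dump are not reproduced:

```python
# Sketch: compare per-token logits from a reference model (e.g. the original
# Mixtral weights) against the converted MaxText model.
import jax
import jax.numpy as jnp

def compare_logits(ref_logits, test_logits):
    # ref_logits, test_logits: [vocab]-shaped logits for the same position.
    ref_p = jax.nn.softmax(ref_logits)
    kl_div = jnp.sum(ref_p * (jax.nn.log_softmax(ref_logits)
                              - jax.nn.log_softmax(test_logits)))  # KL(ref || test)
    top_1_match = jnp.argmax(ref_logits) == jnp.argmax(test_logits)
    # rank of the reference model's top token under the test model's logits
    where_top = jnp.sum(test_logits > test_logits[jnp.argmax(ref_logits)])
    return {"kl_div": kl_div, "top_1_match": top_1_match, "where_top": where_top}
```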
Moving this PR to a branch in the maxtext repo instead of a fork: https://github.com/google/maxtext/pull/845