mlx
[BUG] broadcast of scalar array in last dimension fails after #1035
Describe the bug
Broadcasting a scalar array into the last dimension fails after #1035.
To Reproduce
```python
>>> import mlx.core as mx
>>> a = mx.zeros([2, 3, 4, 5, 3])
>>> a[..., 0] = 1
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: [expand_dims] Invalid axes 4 for output array with 1 dimensions.
```
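For reference, `a[..., 0]` lets the `...` cover all leading axes and then removes the last axis (axis 4), so the assignment target has shape `(2, 3, 4, 5)`. Below is a minimal pure-Python sketch of that shape computation. It assumes the index is a single `Ellipsis` followed only by integers, and `indexed_shape` is a hypothetical helper for illustration, not part of MLX:

```python
def indexed_shape(shape, index):
    """Shape of arr[index] for an index of the form (..., int, int, ...).

    Hypothetical helper for illustration only: slices, None and array
    indices are not handled.
    """
    if not isinstance(index, tuple):
        index = (index,)
    if len(index) == 0 or index[0] is not Ellipsis:
        raise NotImplementedError("only (..., int, ...) indices are handled")
    ints = index[1:]
    if not all(isinstance(i, int) for i in ints):
        raise NotImplementedError("only integer indices after the Ellipsis")
    if len(ints) > len(shape):
        raise IndexError("too many indices for array")
    # The Ellipsis keeps the leading axes; each trailing integer removes one axis.
    kept = len(shape) - len(ints)
    for axis, i in zip(range(kept, len(shape)), ints):
        if not -shape[axis] <= i < shape[axis]:
            raise IndexError(f"index {i} out of bounds for axis {axis}")
    return tuple(shape[:kept])
```

With the shape from the report, `indexed_shape((2, 3, 4, 5, 3), (..., 0))` returns `(2, 3, 4, 5)`.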
Expected behavior
Assigning a scalar to `a[..., 0]` should broadcast it across the selected slice (shape `(2, 3, 4, 5)`) without raising an error.
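The expected behavior follows NumPy-style broadcasting. A 0-d (scalar) value has shape `()`, and that shape is compatible with any target shape, including the `(2, 3, 4, 5)` slice selected by `a[..., 0]`. Below is a minimal pure-Python sketch of the rule. `broadcast_shapes` is a hypothetical helper for illustration, not MLX's implementation:

```python
def broadcast_shapes(*shapes):
    """Broadcast shapes with NumPy-style rules.

    Shapes are right-aligned. In each aligned position the sizes must
    either be equal or be 1. Missing leading dimensions count as 1.
    (Hypothetical helper for illustration only.)
    """
    ndim = max((len(s) for s in shapes), default=0)
    result = []
    for axis in range(-ndim, 0):  # aligned axes, left to right
        # Collect the sizes present at this aligned position, ignoring 1s.
        dims = {s[axis] for s in shapes if len(s) >= -axis}
        dims.discard(1)
        if len(dims) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(dims.pop() if dims else 1)
    return tuple(result)
```

Under this rule `broadcast_shapes((), (2, 3, 4, 5))` gives `(2, 3, 4, 5)`, so assigning a scalar to the slice is well defined.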
Desktop:
- Version: commit hash 490c0c4fdc4b5873b1b5f3807abdcca5691f5acf