Possible mx.array bug?
I think mx.array is not initializing arrays the way one would expect, and it does not raise an error when one references array elements that are out of bounds. This can be a trap when implementing complex logic, especially for guys like me who were trained in Fortran and always have to be conscious of indexing 😄 ...
Here's a simple example to highlight what I think is a bug:
```python
import mlx.core as mx
import numpy as np

# Initializing the temperature field and copy on the GPU
T_mx = mx.ones([4, 4])
T_np = np.ones([4, 4])
print(T_mx)
print(T_np)
print(T_mx[0,0], T_mx[0,1], T_mx[0,2], T_mx[0,3], T_mx[0,4], T_mx[0,5], T_mx[0,6])
print(T_np[0,0], T_np[0,1], T_np[0,2], T_np[0,3], T_np[0,4], T_np[0,5], T_np[0,6])
```
And here's the output:

```
array([[1, 1, 1, 1],
       [1, 1, 1, 1],
       [1, 1, 1, 1],
       [1, 1, 1, 1]], dtype=float32)
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]]
array(1, dtype=float32) array(1, dtype=float32) array(1, dtype=float32) array(1, dtype=float32) array(1, dtype=float32) array(1, dtype=float32) array(1, dtype=float32)
Traceback (most recent call last):
  File "/Users/m2/PycharmProjects/pythonProject_MXL/bug.py", line 11, in <module>
    print(T_np[0,0], T_np[0,1], T_np[0,2], T_np[0,3], T_np[0,4], T_np[0,5], T_np[0,6])
                                                      ~~~~^^^^^
IndexError: index 4 is out of bounds for axis 1 with size 4

Process finished with exit code 1
```
MLX happily returns 1 for the out-of-bounds elements, while NumPy properly raises an error for them.
Thanks for flagging this! I'll get working on a fix right away.
Actually, this is expected behavior. In general, this is one place where we might end up being different from NumPy and behave more like JAX. You can try the same example in JAX; I believe the default behavior there is to use the last value for out-of-bounds indexing (but it doesn't throw).
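NumPy itself can emulate the "use the last value" behavior described above: `np.take` with `mode="clip"` clamps out-of-bounds indices into the valid range instead of raising. The sketch below only illustrates those semantics on one row of the field; it is not how MLX or JAX implement their gathers.

```python
import numpy as np

# One row of the 4x4 field, with distinct values so the clamping is visible.
row = np.array([10.0, 20.0, 30.0, 40.0])

# mode="clip" clamps each index into [0, len(row) - 1] instead of raising,
# so indices 4, 5 and 6 all read the last element.
clamped = np.take(row, [0, 3, 4, 5, 6], mode="clip")
print(clamped)  # [10. 40. 40. 40. 40.]

# The default mode="raise" behaves like ordinary NumPy indexing.
strict_error = None
try:
    np.take(row, [4])
except IndexError as err:
    strict_error = err
print(strict_error)
```

Note that `mode="clip"` also maps negative indices to 0, which differs from ordinary negative indexing; the example sticks to non-negative indices.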
One reason we go this route in general is that we don't raise exceptions from Metal kernels, and we don't want to sanitize inputs before the kernel since that is costly and doesn't make sense with the execution model.
There may be some simple cases (like Python integer-based indexing) in which we can properly throw (or at least make it an option). But once you start doing fancier indexing with other mx.array objects, we may have to follow JAX.
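The difference between the simple and the fancy case can be sketched in plain Python/NumPy. Checking a Python int against an axis size is a constant-time comparison that can run before anything is launched, whereas checking an index array means scanning every element, i.e. an extra reduction over the indices. In a lazily evaluated framework like MLX that would presumably also force the index array to be computed before the gather (an assumption, not something stated above). `check_int_index` and `check_array_index` are hypothetical helpers, not MLX functions.

```python
import numpy as np


def check_int_index(index: int, size: int) -> int:
    """Hypothetical check for a plain Python int: O(1), no array data touched."""
    if not -size <= index < size:
        raise IndexError(f"index {index} is out of bounds for axis with size {size}")
    return index % size  # normalize negative indices the way NumPy does


def check_array_index(indices: np.ndarray, size: int) -> np.ndarray:
    """Hypothetical check for an index array: every element must be inspected."""
    bad = (indices < -size) | (indices >= size)
    if bad.any():  # a full reduction over the index array
        first = int(indices[bad][0])
        raise IndexError(f"index {first} is out of bounds for axis with size {size}")
    # Normalize negative indices element-wise.
    return np.where(indices < 0, indices + size, indices)
```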
@jagrit06, another thing we should do is document this more clearly so it's not a surprise.
I'm going to mark this as docs/enhancement since it's not actually a bug.
@awni I agree about not raising exceptions from Metal kernels.
My thinking was that we let the underlying gather/take Metal kernel remain the same, but instead do a check in the `__getitem__` call. That way most cases are covered before we end up at the Metal kernel. What do you think?
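A minimal sketch of a check at the `__getitem__` level, assuming the wrapped object exposes `.shape` and `__getitem__` (as NumPy arrays and mx.array do). `StrictIndexer` is a hypothetical name, not part of MLX; it checks only plain Python int indices and passes slices and index arrays through untouched, so the kernel path itself is unchanged.

```python
import numpy as np


class StrictIndexer:
    """Hypothetical wrapper that bounds-checks plain int indices before delegating."""

    def __init__(self, array):
        # Any object with .shape and __getitem__ (e.g. a NumPy array).
        self._array = array

    def __getitem__(self, key):
        key_tuple = key if isinstance(key, tuple) else (key,)
        # Skip checking when Ellipsis/None would shift the axis positions.
        if any(k is Ellipsis or k is None for k in key_tuple):
            return self._array[key]
        shape = tuple(self._array.shape)
        for axis, k in enumerate(key_tuple[: len(shape)]):
            # Only plain Python ints are checked; bool is excluded on purpose.
            if isinstance(k, int) and not isinstance(k, bool):
                size = shape[axis]
                if not -size <= k < size:
                    raise IndexError(
                        f"index {k} is out of bounds for axis {axis} with size {size}"
                    )
        # In-bounds (or unchecked) keys go to the underlying indexing unchanged.
        return self._array[key]


strict = StrictIndexer(np.ones([4, 4]))
print(strict[0, 3])  # 1.0
```

Under this sketch, wrapping `mx.ones([4, 4])` the same way would make `[0, 4]` raise `IndexError` before any MLX indexing runs, while in-bounds reads are forwarded as-is.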
I can see the performance motive, but if by mistake you use an out-of-bounds index to set a value (i.e., you think you are setting the last value but your index actually overshoots), the last element is not set. So this works one way (reading) but not the other (writing), which can be nasty to track down. I ran into exactly this error: I thought I was setting the last element, but I was overshooting the bounds by 1 in the time stepping, so I was actually setting an unused ghost element. It took me quite a while to figure out. It would be nice if there were at least an option to enforce NumPy-like behavior.
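The read/write asymmetry can be reproduced in plain NumPy. The sketch below emulates what JAX documents as its defaults (out-of-bounds reads are clamped, out-of-bounds scatter updates are dropped); that MLX's write path behaves exactly the same way is an assumption based on the report above. `clamped_read`, `dropped_write` and `checked_write` are hypothetical helpers, not MLX API.

```python
import numpy as np


def clamped_read(a: np.ndarray, i: int) -> float:
    """Emulated read: an out-of-bounds index reads the nearest valid element."""
    return float(np.take(a, i, mode="clip"))


def dropped_write(a: np.ndarray, i: int, value: float) -> None:
    """Emulated write: an out-of-bounds update is silently skipped.

    Non-negative indices only, for simplicity.
    """
    if 0 <= i < a.shape[0]:
        a[i] = value


def checked_write(a: np.ndarray, i: int, value: float) -> None:
    """Hypothetical NumPy-like guard: fail loudly instead of skipping."""
    if not 0 <= i < a.shape[0]:
        raise IndexError(f"index {i} is out of bounds for axis 0 with size {a.shape[0]}")
    a[i] = value


T = np.arange(4.0)  # [0. 1. 2. 3.]
last = T.shape[0]   # off-by-one: should be T.shape[0] - 1

dropped_write(T, last, 99.0)
print(T)                      # [0. 1. 2. 3.]: the "last" element was never set
print(clamped_read(T, last))  # 3.0: yet reading the same index looks fine
```

Using a guard like `checked_write` in the time-stepping loop (or an `assert 0 <= i < n` before each update) turns the silent overshoot into an immediate error.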