dpctl
Implement tensor.trace
Implement dpctl.tensor.trace per array API spec.
This could be implemented entirely as a top-level function:
- Form a view into diagonals
- Call dpctl.tensor.sum on the last axis of the view
Possible implementation for offset=0:
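import dpctl.tensor as dpt

def trace(ary):
    assert isinstance(ary, dpt.usm_ndarray)
    assert ary.ndim >= 2
    # The main diagonal of the last two axes has min(rows, cols) elements
    res_shape = ary.shape[:-2] + (min(ary.shape[-2:]), )
    # One step along the diagonal advances both trailing axes at once
    res_strides = ary.strides[:-2] + (sum(ary.strides[-2:]),)
    view = dpt.usm_ndarray(res_shape, dtype=ary.dtype, buffer=ary, strides=res_strides)
    return dpt.sum(view, axis=-1, dtype=ary.dtype)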
With sample usage:
In [4]: trace(dpt.ones((3,3,4,4)))
Out[4]:
usm_ndarray([[4., 4., 4.],
             [4., 4., 4.],
             [4., 4., 4.]], dtype=float32)
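
The sketch above covers only offset=0. One way the same strided-view idea might extend to a nonzero offset is shown below. It uses NumPy as a stand-in so it can run without a SYCL device. The function name trace_via_strides and the use of numpy.lib.stride_tricks.as_strided are illustrative assumptions, not part of dpctl. NumPy strides are in bytes, so a dpctl version would need the matching stride and offset handling of usm_ndarray.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided


def trace_via_strides(ary, offset=0):
    """Sum the offset-th diagonal of the last two axes via a strided view.

    Hypothetical NumPy stand-in for the dpctl sketch; not dpctl API.
    """
    if ary.ndim < 2:
        raise ValueError("trace requires an array with at least two dimensions")
    # Slice so that the requested diagonal becomes the main diagonal:
    # a positive offset skips leading columns, a negative one skips leading rows.
    sub = ary[..., :, offset:] if offset >= 0 else ary[..., -offset:, :]
    diag_len = min(sub.shape[-2:])
    if diag_len == 0:
        # The diagonal lies entirely outside the matrix: the sum is zero.
        return np.zeros(ary.shape[:-2], dtype=ary.dtype)
    # Stepping along the diagonal advances both trailing axes at once.
    diag_stride = sub.strides[-2] + sub.strides[-1]
    view = as_strided(
        sub,
        shape=sub.shape[:-2] + (diag_len,),
        strides=sub.strides[:-2] + (diag_stride,),
    )
    return view.sum(axis=-1, dtype=ary.dtype)
```

Slicing first keeps the view construction the same as in the offset=0 case, because the slice already carries the shifted starting position. The result dtype follows the original sketch (dtype=ary.dtype). Whether that matches the array API spec's default dtype rules for integer inputs would need to be checked separately.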