Implement tensor.trace

Open oleksandr-pavlyk opened this issue 2 years ago • 0 comments

Implement dpctl.tensor.trace per array API spec.

This could be implemented entirely as a top-level function:

  1. Form a view into diagonals
  2. Call dpctl.tensor.sum on the last axis of the view
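To see why a strided view can expose the diagonals, note that moving from element (i, i) to element (i+1, i+1) advances one step along each of the last two axes, so the stride of the diagonal is the sum of those two strides. The snippet below is only an illustration of that idea, using NumPy as a stand-in for dpctl.tensor (an assumption made so it runs without a SYCL device). NumPy strides are in bytes while usm_ndarray strides count elements, but the arithmetic is the same.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

# A small non-square matrix: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
a = np.arange(12, dtype=np.int64).reshape(3, 4)

# Step 1: form a view into the main diagonal.
# Its length is min(rows, cols); its stride is the sum of both strides.
diag_len = min(a.shape)
diag_stride = a.strides[0] + a.strides[1]
diag_view = as_strided(a, shape=(diag_len,), strides=(diag_stride,))

# Step 2: sum over the (only) axis of the view.
print(diag_view)        # the elements a[0, 0], a[1, 1], a[2, 2]
print(diag_view.sum())  # 15, the same value as np.trace(a)
```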

Possible implementation for offset=0:

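import dpctl.tensor as dpt

def trace(ary):
    # Only USM arrays with at least two dimensions are supported
    assert isinstance(ary, dpt.usm_ndarray)
    assert ary.ndim >= 2
    # The main diagonal of each trailing matrix has min(rows, cols) elements
    res_shape = ary.shape[:-2] + (min(ary.shape[-2:]),)
    # One step along the diagonal advances both trailing axes at once
    res_strides = ary.strides[:-2] + (sum(ary.strides[-2:]),)
    view = dpt.usm_ndarray(res_shape, dtype=ary.dtype, buffer=ary, strides=res_strides)
    # Reduce the diagonal axis of the view
    return dpt.sum(view, axis=-1, dtype=ary.dtype)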

With sample usage:

In [4]: trace(dpt.ones((3,3,4,4)))
Out[4]:
usm_ndarray([[4., 4., 4.],
             [4., 4., 4.],
             [4., 4., 4.]], dtype=float32)
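
The offset=0 sketch could plausibly be extended to other diagonals by shifting the starting element and shortening the diagonal. The following is a hedged illustration of that idea, not the dpctl implementation: it again uses NumPy as a stand-in, and the name `trace_with_offset` is hypothetical. A dpctl version would presumably express the same shift through slicing or the `offset` argument of the `usm_ndarray` constructor.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided


def trace_with_offset(a, offset=0):
    """Sum the offset-th diagonal over the last two axes (NumPy stand-in)."""
    if a.ndim < 2:
        raise ValueError("input must have at least two dimensions")
    n_rows, n_cols = a.shape[-2:]
    if offset >= 0:
        # Diagonal above the main one: starts at element (0, offset)
        diag_len = max(0, min(n_rows, n_cols - offset))
        start = a[..., :, offset:]
    else:
        # Diagonal below the main one: starts at element (-offset, 0)
        diag_len = max(0, min(n_rows + offset, n_cols))
        start = a[..., -offset:, :]
    if diag_len == 0:
        # The requested diagonal lies outside the matrix: the sum is empty
        return np.zeros(a.shape[:-2], dtype=a.dtype)
    # Basic slicing keeps the strides, so the diagonal stride is unchanged
    res_shape = a.shape[:-2] + (diag_len,)
    res_strides = a.strides[:-2] + (a.strides[-2] + a.strides[-1],)
    view = as_strided(start, shape=res_shape, strides=res_strides)
    return view.sum(axis=-1, dtype=a.dtype)


x = np.arange(12.0).reshape(3, 4)
print(trace_with_offset(x, 1))   # x[0, 1] + x[1, 2] + x[2, 3] = 18.0
print(trace_with_offset(x, -1))  # x[1, 0] + x[2, 1] = 13.0
```

Slicing first and building the strided view second keeps the view inside the bounds of the original data, because the diagonal length is clamped before the view is formed.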

oleksandr-pavlyk, Aug 21 '23 09:08