
Solve dispatch for "pure producers" kernels

Open voronaam opened this issue 4 years ago • 2 comments

Kernels like ConstantOp and InputOp have the signature () -> Tensor and currently always return a Float64 tensor.

We need a way to pass dtype information into these kernels so that they can produce the kind of tensor expected at the lower level.

Other kernels use the dtype of their input arguments to determine the output's dtype. But "pure producers" have no inputs and cannot do that.

The affected code is marked with a TODO like this:

        // TODO: figure out which dtype to return here
        Ok(AbstractTensor::Float64(z))
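To make the asymmetry concrete, here is a minimal sketch using simplified stand-in types. The enum below only borrows the name AbstractTensor from the snippet above; its Vec-backed variants and the add_one and ones functions are assumptions made for illustration, not the actual Moose definitions.

```rust
// Hypothetical, simplified stand-in; not the real Moose type.
#[derive(Debug, Clone, PartialEq)]
enum AbstractTensor {
    Float32(Vec<f32>),
    Float64(Vec<f64>),
}

// A kernel with an input can mirror the variant of its argument.
fn add_one(x: &AbstractTensor) -> AbstractTensor {
    match x {
        AbstractTensor::Float32(v) => {
            AbstractTensor::Float32(v.iter().map(|a| a + 1.0).collect())
        }
        AbstractTensor::Float64(v) => {
            AbstractTensor::Float64(v.iter().map(|a| a + 1.0).collect())
        }
    }
}

// A pure producer has no argument to inspect, so the variant
// ends up hard-coded, which is the situation the TODO describes.
fn ones(len: usize) -> AbstractTensor {
    AbstractTensor::Float64(vec![1.0; len])
}
```

In this sketch add_one preserves whatever dtype it is given, while ones can only ever produce Float64.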

voronaam avatar Nov 03 '21 16:11 voronaam

Other kernels use the dtype of their input arguments to determine the output's dtype. But "pure producers" have no inputs and cannot do that.

Can't the return type in the signature be used? What am I missing?

mortendahl avatar Nov 08 '21 10:11 mortendahl

A good example is the Ones kernel.

modelled_kernel! {
    PlacementOnes::ones, OnesOp,
    [
        (HostPlacement, (HostShape) -> Tensor => [hybrid] Self::logical_host_kernel),

and later

impl OnesOp {
    #[allow(clippy::type_complexity)]
    pub(crate) fn logical_host_kernel<S: Session>(
// ....
    {
        let result = plc.ones(sess, &shape);
        Ok(AbstractTensor::Float64(result))
    }

Now that we have TensorDType inside the Tensor type, we should be able to use the return type from the signature. We might need a custom kernel for that, though.
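One possible shape for this, as a rough sketch only: the producer receives the expected dtype as a value and matches on it. The TensorDType and AbstractTensor enums here are simplified, hypothetical stand-ins that reuse names from this thread, and how the dtype would actually reach the kernel in Moose (for example, read from the signature's return type) is not shown and is an assumption.

```rust
// Hypothetical stand-ins for illustration; not the real Moose types.
#[derive(Debug, Clone, Copy, PartialEq)]
enum TensorDType {
    Float32,
    Float64,
}

#[derive(Debug, Clone, PartialEq)]
enum AbstractTensor {
    Float32(Vec<f32>),
    Float64(Vec<f64>),
}

// The producer takes the expected dtype explicitly and picks the
// matching variant instead of always returning Float64.
fn ones(dtype: TensorDType, len: usize) -> AbstractTensor {
    match dtype {
        TensorDType::Float32 => AbstractTensor::Float32(vec![1.0; len]),
        TensorDType::Float64 => AbstractTensor::Float64(vec![1.0; len]),
    }
}
```

With this shape, the caller decides the dtype, and the kernel body no longer needs a hard-coded variant.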

voronaam avatar Mar 09 '22 21:03 voronaam