Solve dispatch for "pure producer" kernels
Kernels like `ConstantOp` and `InputOp` have signatures `() -> Tensor` and currently always return a `Float64` tensor.
We need a way to pass extra dtype information into these kernels so that they can produce the kind of tensor expected at the lower level.
Other kernels use the dtype of their input arguments to figure out the output's dtype. But "pure producers" have no inputs, so they cannot do that.
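To make the contrast concrete, here is a minimal sketch of an input-driven kernel. `AbstractTensor` below is a simplified, hypothetical stand-in with only two variants (not Moose's actual definition), and `neg_kernel` is a made-up example. The output variant is chosen by matching on the input, which is exactly what a nullary kernel has no way of doing.

```rust
// Hypothetical, simplified stand-in for a dtype-tagged tensor enum.
#[derive(Debug, Clone, PartialEq)]
pub enum AbstractTensor {
    Float32(Vec<f32>),
    Float64(Vec<f64>),
}

/// Hypothetical input-driven kernel: the output variant mirrors the input
/// variant, so the output dtype is known without any extra information.
pub fn neg_kernel(x: &AbstractTensor) -> AbstractTensor {
    match x {
        AbstractTensor::Float32(v) => AbstractTensor::Float32(v.iter().map(|a| -a).collect()),
        AbstractTensor::Float64(v) => AbstractTensor::Float64(v.iter().map(|a| -a).collect()),
    }
}

// A pure producer such as `fn ones_kernel(len: usize) -> AbstractTensor`
// has no tensor argument to match on, so it must pick a variant by itself.
```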
The affected code is marked with a TODO like this:
```rust
// TODO: figure out which dtype to return here
Ok(AbstractTensor::Float64(z))
```
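One way to feed the extra dtype information in is to store it on the op itself. The sketch below assumes a hypothetical `dtype` attribute on a simplified `ConstantOp` (the real op's fields may differ) and trimmed-down, illustrative `TensorDType`/`AbstractTensor` enums. The kernel matches on that attribute instead of hard-coding `Float64`:

```rust
// Simplified, hypothetical dtype enum; variants are illustrative only.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorDType {
    Float32,
    Float64,
    Bool,
}

// Simplified, hypothetical tensor enum with one variant per supported dtype.
#[derive(Debug, Clone, PartialEq)]
pub enum AbstractTensor {
    Float32(Vec<f32>),
    Float64(Vec<f64>),
}

/// Hypothetical `ConstantOp` that records which dtype it must produce.
pub struct ConstantOp {
    pub value: Vec<f64>,
    pub dtype: TensorDType,
}

impl ConstantOp {
    /// Nullary kernel: the output variant comes from `self.dtype`,
    /// not from any input argument.
    pub fn kernel(&self) -> Result<AbstractTensor, String> {
        match self.dtype {
            TensorDType::Float64 => Ok(AbstractTensor::Float64(self.value.clone())),
            TensorDType::Float32 => Ok(AbstractTensor::Float32(
                self.value.iter().map(|&x| x as f32).collect(),
            )),
            // Dtypes this sketch cannot build a constant for are rejected.
            other => Err(format!("ConstantOp cannot produce {:?}", other)),
        }
    }
}
```

The same pattern would apply to `InputOp`: the dtype travels with the op, and the kernel dispatches on it.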
> Other kernels use the dtype of their input arguments to figure out the output's dtype. But "pure producers" have no inputs, so they cannot do that.
Can't the return type in the signature be used? What am I missing?
A good example is the `Ones` kernel:
```rust
modelled_kernel! {
    PlacementOnes::ones, OnesOp,
    [
        (HostPlacement, (HostShape) -> Tensor => [hybrid] Self::logical_host_kernel),
        // ...
```
And later:
```rust
impl OnesOp {
    #[allow(clippy::type_complexity)]
    pub(crate) fn logical_host_kernel<S: Session>(
        // ...
    {
        let result = plc.ones(sess, &shape);
        Ok(AbstractTensor::Float64(result))
    }
}
```
Now that `TensorDType` is part of the `Tensor` type, we should be able to do that, although we might need a custom kernel for it.
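A rough sketch of that idea follows. It assumes a simplified type language in which the tensor return type carries a `TensorDType`. The `Ty`, `Signature`, and `ones_kernel` names are hypothetical and are not Moose's real API. Here the kernel reads the requested dtype from the signature's return type instead of from an input:

```rust
// Simplified, hypothetical dtype enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorDType {
    Float32,
    Float64,
}

// Hypothetical type language: the tensor type carries its dtype.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Ty {
    HostShape,
    Tensor(TensorDType),
}

// Hypothetical signature: argument types plus the return type.
pub struct Signature {
    pub args: Vec<Ty>,
    pub ret: Ty,
}

// Simplified, hypothetical tensor enum.
#[derive(Debug, Clone, PartialEq)]
pub enum AbstractTensor {
    Float32(Vec<f32>),
    Float64(Vec<f64>),
}

/// Sketch of a Ones kernel that picks the output variant from `sig.ret`.
pub fn ones_kernel(sig: &Signature, shape: &[usize]) -> Result<AbstractTensor, String> {
    // Total number of elements for the requested shape.
    let len: usize = shape.iter().product();
    match sig.ret {
        Ty::Tensor(TensorDType::Float32) => Ok(AbstractTensor::Float32(vec![1.0; len])),
        Ty::Tensor(TensorDType::Float64) => Ok(AbstractTensor::Float64(vec![1.0; len])),
        other => Err(format!("unsupported return type for Ones: {:?}", other)),
    }
}
```

Compared with storing the dtype on the op, this keeps the choice in the type system, so the output variant could be selected at dispatch time. Whether the existing `modelled_kernel!` machinery can hand the full signature to the kernel this way is unclear, which may be why a custom kernel would be needed.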