Add a jvpvjp test that uses a decomposition and changes dtypes
Original conversation: https://github.com/pytorch/functorch/pull/818#discussion_r878187686
Right now, none of the functions that do forward-over-reverse testing using a decomposition change dtypes: if we remove the lines that cast all the jacobian outputs to floats, all tests still pass.
Step 1: Add a test that changes dtype. One option that doesn't involve finding a new function is to use type promotion (e.g., two inputs, one of type float and one of type double).
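A minimal sketch of the type-promotion idea, assuming a plain `x + y` is enough to act as the dtype-changing op (`add_promoting` is a hypothetical name, not an existing OpInfo). A float32 input combined with a float64 input yields a float64 output, so the output dtype differs from at least one input dtype.

```python
import torch

def add_promoting(x, y):
    # Hypothetical dtype-changing op: the dtype change comes only from type promotion.
    return x + y

x = torch.randn(3, dtype=torch.float32)
y = torch.randn(3, dtype=torch.float64)
out = add_promoting(x, y)

# The result takes the promoted dtype (float64), not the dtype of x (float32).
print(x.dtype, y.dtype, out.dtype)
print(torch.result_type(x, y))
```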
Step 2: Update the lines that cast all outputs to floats* so that they first check that the original outputs have the correct dtype. From @zou3519:
If the operation has inputs of dtype A and outputs of dtype B:
- then the vjp is a function that takes inputs of dtype (A, B) and gives outputs of dtype B
- so the jacobian_vjp should return jacobians of dtype A (for the primals) and dtype B (for the cotangents)
- the jacobian_jvp should return jacobians of dtype B
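One way to see the input-side dtype (A) and output-side dtype (B) in practice is to run reverse and forward mode directly on a promoting add. This sketch assumes PyTorch 2.x, where functorch's `vjp`/`jvp` are available as `torch.func.vjp`/`torch.func.jvp`; the function `f` is illustrative only. It shows that the cotangent carries the output dtype, each gradient comes back in its own primal's dtype, and the jvp's output tangent carries the output dtype.

```python
import torch
from torch.func import jvp, vjp

def f(x, y):
    # float32 + float64 promotes to float64 (inputs of mixed dtype, output dtype B = float64)
    return x + y

x = torch.randn(3, dtype=torch.float32)
y = torch.randn(3, dtype=torch.float64)

# Reverse mode: the cotangent has the output dtype, and each gradient
# is returned in the dtype of the primal it belongs to.
out, vjp_fn = vjp(f, x, y)
cotangent = torch.randn_like(out)
grad_x, grad_y = vjp_fn(cotangent)

# Forward mode: tangents match their primals' dtypes, and the output
# tangent has the output dtype.
tx, ty = torch.randn_like(x), torch.randn_like(y)
_, tangent_out = jvp(f, (x, y), (tx, ty))

print(out.dtype, cotangent.dtype, grad_x.dtype, grad_y.dtype, tangent_out.dtype)
```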
* Lines changing all outputs to floats look like:
```python
# For dtype changing operations, the jacobians have different dtype.
jacobian_jvp = tree_map(lambda x: x.to(torch.float), jacobian_jvp)
jacobian_vjp = tree_map(lambda x: x.to(torch.float), jacobian_vjp)
```
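A sketch of what the Step 2 change could look like. `check_then_cast_to_float` is a hypothetical helper, and it assumes `tree_map` comes from `torch.utils._pytree`. The flat-tuple layout of `jacobian_jvp`/`jacobian_vjp` used in the usage lines (primal jacobians first, then cotangent jacobians) is also an assumption; the real test's pytree structure may differ. The helper asserts the expected dtypes from the rules above and only then performs the same float cast as the existing lines.

```python
import torch
from torch.utils._pytree import tree_flatten, tree_map

def check_then_cast_to_float(jacobians, expected_dtypes):
    """Hypothetical helper: verify each jacobian's dtype, then cast everything to float32."""
    leaves, _ = tree_flatten(jacobians)
    expected, _ = tree_flatten(expected_dtypes)
    if len(leaves) != len(expected):
        raise AssertionError(f"expected {len(expected)} jacobians, got {len(leaves)}")
    for i, (jac, dtype) in enumerate(zip(leaves, expected)):
        if jac.dtype != dtype:
            raise AssertionError(f"jacobian {i}: expected {dtype}, got {jac.dtype}")
    # Same cast as the existing lines, applied only after the dtype check passes.
    return tree_map(lambda x: x.to(torch.float), jacobians)

# Usage with made-up stand-in tensors, for an op with input dtype A and output dtype B.
A, B = torch.float32, torch.float64
jacobian_jvp = (torch.zeros(3, 3, dtype=B),)
jacobian_vjp = (torch.zeros(3, 3, dtype=A), torch.zeros(3, 3, dtype=B))
jacobian_jvp = check_then_cast_to_float(jacobian_jvp, (B,))
jacobian_vjp = check_then_cast_to_float(jacobian_vjp, (A, B))
```

Raising `AssertionError` explicitly (instead of a bare `assert`) keeps the check active even when Python runs with `-O`.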