
Add a jvpvjp test that uses a decomposition and changes dtypes

Open samdow opened this issue 3 years ago • 0 comments

Original conversation: https://github.com/pytorch/functorch/pull/818#discussion_r878187686

Right now, none of the functions that get forward-over-reverse testing using a decomposition change dtypes: if we remove the lines that cast all the Jacobian outputs to float, all tests still pass.

Step 1: Add a test for an operation that changes dtype. One option that doesn't require finding a new function is to rely on type promotion (e.g. two inputs, one of dtype float and one of dtype double).
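A minimal sketch of such a test input, assuming PyTorch 2.x (where these transforms live in `torch.func`; older installs import them from `functorch`). The op `promoting_mul` and the helper `get_vjp` are hypothetical names chosen for illustration: multiplying a float32 tensor by a float64 tensor promotes the output to float64, and `jvp` is then applied over the `vjp` function to run forward-over-reverse on a dtype-changing op.

```python
import torch

try:
    from torch.func import jvp, vjp  # PyTorch >= 2.0
except ImportError:  # older installs ship these transforms in functorch
    from functorch import jvp, vjp


def promoting_mul(x, y):
    # Hypothetical op under test: float32 * float64 promotes to float64,
    # so the output dtype differs from the dtype of `x`.
    return x * y


def get_vjp(cotangents, x, y):
    # Reverse mode: gradients of promoting_mul with respect to (x, y).
    _, vjp_fn = vjp(promoting_mul, x, y)
    return vjp_fn(cotangents)


x = torch.randn(3, dtype=torch.float32)
y = torch.randn(3, dtype=torch.float64)
print(promoting_mul(x, y).dtype)  # torch.float64

# Forward over reverse: push tangents (each with the same dtype as its
# primal) through the vjp function.
cotangents = torch.randn(3, dtype=torch.float64)
primals = (cotangents, x, y)
tangents = tuple(torch.randn_like(p) for p in primals)
vjp_out, jvp_out = jvp(get_vjp, primals, tangents)
```

Feeding inputs like these through the existing Jacobian comparison would exercise the dtype-changing path that the unconditional float casts currently hide.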

Step 2: Update the lines that cast all outputs to float* so that they first check that the original outputs have the correct dtypes. From @zou3519:

If the operation has inputs of dtype A and outputs of dtype B:

  • then the vjp is a function that takes inputs of dtype (A, B) and gives outputs of dtype B
  • so the jacobian_vjp should return jacobians of dtype A (for the primals) and dtype B (for the cotangents)
  • the jacobian_jvp should return jacobians of dtype B

* The lines that cast all outputs to float look like this:

```python
import torch
from torch.utils._pytree import tree_map

# For dtype-changing operations, the jacobians have different dtypes.
jacobian_jvp = tree_map(lambda x: x.to(torch.float), jacobian_jvp)
jacobian_vjp = tree_map(lambda x: x.to(torch.float), jacobian_vjp)
```
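One possible shape for the Step 2 change, sketched under the assumption that the Jacobians are pytrees of tensors and that the caller supplies the expected dtype of each leaf (for example, derived from the rules quoted above). `check_dtypes_then_cast` is a hypothetical helper name, not an existing functorch utility. The point of the ordering is that the dtype check must run before the cast, because the cast erases the evidence.

```python
import torch
from torch.utils._pytree import tree_flatten, tree_map


def check_dtypes_then_cast(jacobians, expected_dtypes):
    # Hypothetical helper (not part of functorch): `jacobians` is any pytree
    # of tensors; `expected_dtypes` lists one torch.dtype per leaf, in the
    # same order that tree_flatten visits the leaves.
    leaves, _ = tree_flatten(jacobians)
    if len(leaves) != len(expected_dtypes):
        raise ValueError(
            f"got {len(leaves)} jacobians but {len(expected_dtypes)} expected dtypes"
        )
    for i, (jac, dtype) in enumerate(zip(leaves, expected_dtypes)):
        if jac.dtype != dtype:
            raise AssertionError(f"jacobian {i} has dtype {jac.dtype}, expected {dtype}")
    # Only once the dtypes are verified, cast to a common dtype so the
    # jvp- and vjp-based jacobians can be compared directly.
    return tree_map(lambda x: x.to(torch.float), jacobians)


# Usage: a float32 Jacobian for the primals and a float64 one for the
# cotangents (A = float32, B = float64 in the notation above).
jac_primal = torch.zeros(3, 3, dtype=torch.float32)
jac_cotangent = torch.zeros(3, 3, dtype=torch.float64)
casted = check_dtypes_then_cast((jac_primal, jac_cotangent), [torch.float32, torch.float64])
print([t.dtype for t in casted])  # [torch.float32, torch.float32]
```

In the test, a call like `check_dtypes_then_cast(jacobian_vjp, expected)` would replace the corresponding unconditional `tree_map` cast, so a dtype regression fails loudly instead of being silently cast away.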

samdow, May 20 '22 15:05