
[draft] Split each tape into a different jaxpr during trace_quantum_function

Open paul0403 opened this issue 1 year ago • 6 comments

Before submitting

TODO: changelog, tests, writeup of implementation

Please complete the following checklist when submitting a PR:

  • [ ] All new functions and code must be clearly commented and documented.

  • [ ] Ensure that code is properly formatted by running make format. The latest version of black and clang-format-14 are used in CI/CD to check formatting.

  • [ ] All new features must include a unit test. Integration and frontend tests should be added to frontend/test, Quantum dialect and MLIR tests should be added to mlir/test, and Runtime tests should be added to runtime/tests.

When all the above are checked, delete everything above the dashed line and fill in the pull request template.


Context: Split each tape into a different jaxpr during trace_quantum_function, and return an overall big jaxpr that calls these small "each_tape" jaxprs.
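The intended call structure can be sketched in plain Python, assuming each tape has already been turned into its own callable. `build_outer` and the stand-in tape functions below are hypothetical illustrations, not Catalyst APIs. The sketch only shows the shape: one outer function that calls every per-tape function with the full argument list and then applies the transform's post-processing.

```python
from typing import Callable, List, Sequence

TapeFn = Callable[..., float]


def build_outer(tape_fns: Sequence[TapeFn],
                post_processing: Callable[[List[float]], float]) -> TapeFn:
    """Hypothetical helper: an outer function that calls one function per tape."""
    def outer(*args: float) -> float:
        # Each per-tape function receives all outer arguments, even ones it ignores.
        results = [fn(*args) for fn in tape_fns]
        # Results reach post-processing in tape order.
        return post_processing(results)
    return outer


# Toy stand-ins for two traced tapes (not real quantum circuits).
tape_fns = [
    lambda a, b: a + b,  # uses both arguments
    lambda a, b: 10.0,   # ignores its arguments, like a tape with fixed parameters
]
outer = build_outer(tape_fns, lambda results: results[0] + results[1])
print(outer(1.0, 2.0))  # 13.0
```

Since `results` is indexed by tape position, any consumer that expects the QNode's original return order has to map between the two orders explicitly.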

This is a preliminary design. There are still some problems, the most significant being that the tape order does not necessarily match the order of the pre-transform QNode's return values.

Also, many pytest tests simply crash under the current design.

An alternative is to apply the transform before tracing begins, which avoids the need to build jaxprs manually.

Description of the Change:

Benefits:

Possible Drawbacks:

Related GitHub Issues: closes #442 [sc-67125]

paul0403, Aug 05 '24 16:08

Hello. You may have forgotten to update the changelog! Please edit doc/changelog.md on your branch with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

github-actions[bot], Aug 05 '24 16:08

Romain's previous draft on this: #692

paul0403, Aug 06 '24 13:08

After a lot of experimenting, I now think that modifying the current tracing frame's jaxpr directly is a bad idea.

The next approach would be to insert some sort of "tape cut" primitive at the point where we currently reset.
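To make the "tape cut" idea concrete, here is a minimal plain-Python sketch. It assumes tracing records a flat stream of operations and that a marker is emitted wherever we currently reset. `TAPE_CUT` and `split_at_cuts` are hypothetical names, and a real implementation would work on jaxpr equations rather than strings.

```python
from typing import Any, List

# Hypothetical sentinel standing in for a "tape cut" primitive in the recorded stream.
TAPE_CUT = object()


def split_at_cuts(ops: List[Any]) -> List[List[Any]]:
    """Split a flat stream of recorded ops into one segment per tape.

    A cut closes the current segment and opens a new one, so a trailing
    cut yields an empty final segment.
    """
    segments: List[List[Any]] = [[]]
    for op in ops:
        if op is TAPE_CUT:
            segments.append([])
        else:
            segments[-1].append(op)
    return segments


# Primitive names borrowed from the jaxpr in this PR, used here only as plain strings.
recorded = ["qdevice", "qalloc", "qinst", "expval", "qdealloc",
            TAPE_CUT,
            "qdevice", "qalloc", "qinst", "expval", "qdealloc"]
segments = split_at_cuts(recorded)
print(len(segments))  # 2
```

Emitting a marker leaves the recording step itself untouched. Splitting into per-tape segments then happens afterwards as a separate pass, instead of by editing the tracing frame's jaxpr in place.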

paul0403, Aug 08 '24 21:08

Oh dear, that was intense...

Only a few frontend tests still fail. Apart from some tests that need to be rewritten (since we now handle multiple tapes in an entirely different way), only one true bug remains.

We can now do the following: given a circuit that a transform has split into multiple tapes, each tape ends up in its own jaxpr and its own function!

```python
from typing import Callable, Sequence

import pennylane as qml
from pennylane import transform
from catalyst import qjit

dev = qml.device("lightning.qubit", wires=2)

def my_quantum_transform(tape: qml.tape.QuantumTape) -> tuple[Sequence[qml.tape.QuantumTape], Callable]:
    # Keep the original tape and add a second, independent tape
    tape1 = tape
    tape2 = qml.tape.QuantumTape([qml.RY(0.5, wires=0)], [qml.expval(qml.X(0))])

    def post_processing_fn(results):
        # Combine the results of the two tapes into a single value
        return results[0] + results[1]

    return [tape1, tape2], post_processing_fn

dispatched_transform = transform(my_quantum_transform)

@qml.qnode(dev)
def circuit(x):
    qml.adjoint(qml.RY)(x[0], wires=0)
    qml.RX(x[1], wires=1)
    return qml.expval(qml.X(0))


circuit = dispatched_transform(circuit)
print("core PL results: ", circuit([0.1, 0.2]))

circuit = qjit(circuit)
print("qjit results: ", circuit([0.1, 0.2]))
print(circuit.jaxpr, circuit.mlir)
```

>>>
core PL results:  0.37959212195737485
qjit results:  0.37959212195737485

{ lambda ; a:f64[] b:f64[]. let
    c:f64[] = func[
      call_jaxpr={ lambda ; d:f64[] e:f64[]. let
          f:f64[] = func[
            call_jaxpr={ lambda ; g:f64[] h:f64[]. let
                qdevice[
                  rtd_kwargs={'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}
                  rtd_lib=/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/utils/../../../runtime/build/lib/librtd_lightning.so
                  rtd_name=LightningSimulator
                ] 
                i:AbstractQreg() = qalloc 2
                j:AbstractQbit() = qextract i 0
                k:AbstractQbit() = qinst[
                  adjoint=True
                  ctrl_len=0
                  op=RY
                  params_len=1
                  qubits_len=1
                ] j g
                l:AbstractObs(num_qubits=None,primitive=None) = namedobs[
                  kind=PauliX
                ] k
                m:f64[] = expval[shots=None] l
                n:AbstractQreg() = qinsert i 0 k
                o:AbstractQbit() = qextract i 1
                p:AbstractQbit() = qinst[
                  adjoint=False
                  ctrl_len=0
                  op=RX
                  params_len=1
                  qubits_len=1
                ] o h
                q:AbstractQreg() = qinsert n 1 p
                qdealloc q
              in (m,) }
            fn=<function trace_quantum_function.<locals>._f at 0x7a30bf90a680>
          ] d e
          r:f64[] = func[
            call_jaxpr={ lambda ; s:f64[] t:f64[]. let
                qdevice[
                  rtd_kwargs={'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}
                  rtd_lib=/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/utils/../../../runtime/build/lib/librtd_lightning.so
                  rtd_name=LightningSimulator
                ] 
                u:AbstractQreg() = qalloc 2
                v:AbstractQbit() = qextract u 0
                w:AbstractQbit() = qinst[
                  adjoint=False
                  ctrl_len=0
                  op=RY
                  params_len=1
                  qubits_len=1
                ] v 0.5
                x:AbstractObs(num_qubits=None,primitive=None) = namedobs[
                  kind=PauliX
                ] w
                y:f64[] = expval[shots=None] x
                z:AbstractQreg() = qinsert u 0 w
                qdealloc z
              in (y,) }
            fn=<function trace_quantum_function.<locals>._f at 0x7a30bf90a0e0>
          ] d e
          ba:f64[] = func[
            call_jaxpr={ lambda ; bb:f64[] bc:f64[]. let
                bd:f64[] = add bb bc
              in (bd,) }
            fn=<function circuit at 0x7a30bfe149d0>
          ] f r
        in (ba,) }
      fn=<QNode: device='<lightning.qubit device (wires=2) at 0x7a3105207f10>', interface='auto', diff_method='best'>
    ] a b
  in (c,) } 

module @circuit {
  func.func public @jit_circuit(%arg0: tensor<f64>, %arg1: tensor<f64>) -> tensor<f64> attributes {llvm.emit_c_interface} {
    %0 = call @circuit(%arg0, %arg1) : (tensor<f64>, tensor<f64>) -> tensor<f64>
    return %0 : tensor<f64>
  }
  func.func private @circuit(%arg0: tensor<f64>, %arg1: tensor<f64>) -> tensor<f64> attributes {diff_method = "parameter-shift", llvm.linkage = #llvm.linkage<internal>, qnode} {
    %0 = call @_f(%arg0, %arg1) : (tensor<f64>, tensor<f64>) -> tensor<f64>
    %1 = call @_f_0(%arg0, %arg1) : (tensor<f64>, tensor<f64>) -> tensor<f64>
    %2 = call @circuit_1(%0, %1) : (tensor<f64>, tensor<f64>) -> tensor<f64>
    return %2 : tensor<f64>
  }
  func.func private @_f(%arg0: tensor<f64>, %arg1: tensor<f64>) -> tensor<f64> attributes {llvm.linkage = #llvm.linkage<internal>} {
    quantum.device["/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/utils/../../../runtime/build/lib/librtd_lightning.so", "LightningSimulator", "{'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
    %0 = quantum.alloc( 2) : !quantum.reg
    %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
    %extracted = tensor.extract %arg0[] : tensor<f64>
    %out_qubits = quantum.custom "RY"(%extracted) %1 {adjoint} : !quantum.bit
    %2 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs
    %3 = quantum.expval %2 : f64
    %from_elements = tensor.from_elements %3 : tensor<f64>
    %4 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit
    %5 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit
    %extracted_0 = tensor.extract %arg1[] : tensor<f64>
    %out_qubits_1 = quantum.custom "RX"(%extracted_0) %5 : !quantum.bit
    %6 = quantum.insert %4[ 1], %out_qubits_1 : !quantum.reg, !quantum.bit
    quantum.dealloc %6 : !quantum.reg
    quantum.device_release
    return %from_elements : tensor<f64>
  }
  func.func private @_f_0(%arg0: tensor<f64>, %arg1: tensor<f64>) -> tensor<f64> attributes {llvm.linkage = #llvm.linkage<internal>} {
    %cst = arith.constant 5.000000e-01 : f64
    quantum.device["/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/utils/../../../runtime/build/lib/librtd_lightning.so", "LightningSimulator", "{'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
    %0 = quantum.alloc( 2) : !quantum.reg
    %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
    %out_qubits = quantum.custom "RY"(%cst) %1 : !quantum.bit
    %2 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs
    %3 = quantum.expval %2 : f64
    %from_elements = tensor.from_elements %3 : tensor<f64>
    %4 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit
    quantum.dealloc %4 : !quantum.reg
    quantum.device_release
    return %from_elements : tensor<f64>
  }
  func.func private @circuit_1(%arg0: tensor<f64>, %arg1: tensor<f64>) -> tensor<f64> attributes {llvm.linkage = #llvm.linkage<internal>} {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f64>
    return %0 : tensor<f64>
  }
  func.func @setup() {
    quantum.init
    return
  }
  func.func @teardown() {
    quantum.finalize
    return
  }
}
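As a hand check on the matching PennyLane and qjit results, the value can be reproduced directly. This derivation is ours, not part of the PR output. Applying RY(θ) to |0⟩ gives ⟨X⟩ = sin θ, adjoint(RY)(0.1) acts as RY(-0.1), and the RX on wire 1 does not affect ⟨X⟩ on wire 0. Since post_processing_fn adds the two tape results, the expected value is sin(-0.1) + sin(0.5).

```python
import math

# Tape 1: adjoint(RY)(0.1) on wire 0 equals RY(-0.1); <X_0> after RY(theta)|0> is sin(theta).
# RX(0.2) acts on wire 1 and does not change <X_0>.
tape1_expval = math.sin(-0.1)
# Tape 2: RY(0.5) on wire 0, measured in X.
tape2_expval = math.sin(0.5)
# post_processing_fn sums the two tape results.
total = tape1_expval + tape2_expval
print(total)  # approximately 0.3795921
```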

paul0403, Aug 09 '24 16:08

Marking this as ready for review only to trigger CI. This is not really ready.

paul0403, Aug 09 '24 16:08

Marking this as a draft again as we switch lanes and try to split tapes at the MLIR level instead of during tracing.

paul0403, Aug 13 '24 14:08

Closing, as #1017 accomplishes the same goal.

paul0403, Sep 10 '24 20:09