[draft] Split each tape into a different jaxpr during trace_quantum_function
Before submitting
TODO: changelog, tests, writeup of implementation
Context: Split each tape into a separate jaxpr during trace_quantum_function, and return one overall jaxpr that calls these smaller per-tape ("each_tape") jaxprs.
This is a preliminary design and it still has some problems. The most significant is that the tape order does not necessarily match the return-value order of the pre-transform QNode.
Also, many pytest tests simply crash under the current design.
An alternative is to apply the transform before tracing begins, which avoids having to build jaxprs manually.
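The alternative of applying the transform before tracing begins can be sketched in plain Python. This is a hypothetical illustration, not Catalyst code: ToyTape, transform_then_trace, my_toy_transform and toy_trace are made-up names, and each "traced" tape is just a Python closure standing in for a per-tape jaxpr. It shows the property the ordering problem hinges on: if results are collected in the order of the transform's tape list, the post-processing function receives them in the order it indexes them.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Sequence, Tuple


@dataclass
class ToyTape:
    """Hypothetical stand-in for a quantum tape: a name plus recorded op names."""
    name: str
    ops: List[str] = field(default_factory=list)


def transform_then_trace(
    tape: ToyTape,
    tape_transform: Callable[[ToyTape], Tuple[Sequence[ToyTape], Callable]],
    trace_tape: Callable[[ToyTape], Callable],
) -> Callable:
    """Apply the tape transform first, then build one callable per resulting tape."""
    tapes, post_processing_fn = tape_transform(tape)
    # One independent sub-program per tape, analogous to one small jaxpr per tape.
    per_tape_fns = [trace_tape(t) for t in tapes]

    def outer(*args):
        # Results are collected in tape-list order, which is the order
        # post_processing_fn indexes into (results[0], results[1], ...).
        results = [fn(*args) for fn in per_tape_fns]
        return post_processing_fn(results)

    return outer


def my_toy_transform(tape: ToyTape):
    # Hypothetical transform: keep the original tape and append a fixed extra tape.
    extra = ToyTape("extra", ["RY(0.5)"])
    return [tape, extra], lambda results: " + ".join(results)


def toy_trace(tape: ToyTape) -> Callable:
    # Stand-in for tracing: the "compiled" tape just reports which tape ran.
    return lambda x: f"{tape.name}({x})"


circuit = transform_then_trace(ToyTape("original", ["RY", "RX"]), my_toy_transform, toy_trace)
print(circuit(0.1))  # original(0.1) + extra(0.1)
```

Because the transform runs before any tracing, each tape can be traced with an ordinary, fresh tracing context, and no jaxpr has to be spliced together by hand.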
Related GitHub Issues: closes #442 [sc-67125]
Hello. You may have forgotten to update the changelog!
Please edit doc/changelog.md on your branch with:
- A one-to-two sentence description of the change. You may include a small working example for new features.
- A link back to this PR.
- Your name (or GitHub username) in the contributors section.
Romain's previous draft on this: #692
After a lot of experimentation, I now think that modifying the current tracing frame's jaxpr directly is a bad idea.
The next approach would be to insert some sort of "tape cut" primitive at the point where we currently reset.
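To make the "tape cut" idea concrete: instead of resetting, the tracer would record a marker, and a later pass would split the recorded equations at each marker. The sketch below covers only that splitting step, under the assumption that the trace is available as a flat list. TAPE_CUT and split_on_tape_cuts are hypothetical names; a real implementation would bind a JAX primitive and operate on jaxpr equations rather than strings.

```python
from typing import List

# Hypothetical marker. In the proposed design this would be a primitive bound
# during tracing; here it is just a unique sentinel object.
TAPE_CUT = object()


def split_on_tape_cuts(recorded: List[object]) -> List[List[object]]:
    """Split a flat stream of recorded operations into one segment per tape."""
    segments: List[List[object]] = [[]]
    for item in recorded:
        if item is TAPE_CUT:
            segments.append([])  # everything after a cut belongs to the next tape
        else:
            segments[-1].append(item)
    return segments


stream = ["qalloc", "RY", "RX", "expval", TAPE_CUT, "qalloc", "RY", "expval"]
print(split_on_tape_cuts(stream))
# [['qalloc', 'RY', 'RX', 'expval'], ['qalloc', 'RY', 'expval']]
```

Each resulting segment could then be turned into its own function, which keeps the current tracing frame untouched during tracing itself.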
Oh dear, that was intense...
Only a few frontend tests still fail. Apart from some tests that need to be rewritten (since we now handle multiple tapes in an entirely different way), only one true bug remains.
We can now do the following: given a transformed circuit with multiple tapes, each tape is traced into its own jaxpr and its own function!
from typing import Callable, Sequence
import pennylane as qml
from catalyst import qjit
from pennylane.transforms.core import transform
dev = qml.device("lightning.qubit", wires=2)
# A tape transform that returns the original tape plus a fixed second tape,
# and sums their results in post-processing.
def my_quantum_transform(tape: qml.tape.QuantumTape) -> (Sequence[qml.tape.QuantumTape], Callable):
    tape1 = tape
    tape2 = qml.tape.QuantumTape([qml.RY(0.5, wires=0)], [qml.expval(qml.X(0))])
    def post_processing_fn(results):
        return results[0] + results[1]
    return [tape1, tape2], post_processing_fn
dispatched_transform = transform(my_quantum_transform)
@qml.qnode(dev)
def circuit(x):
    qml.adjoint(qml.RY)(x[0], wires=0)
    qml.RX(x[1], wires=1)
    return qml.expval(qml.X(0))
circuit = dispatched_transform(circuit)
print("core PL results: ", circuit([0.1, 0.2]))
# Compile the same transformed QNode with Catalyst and inspect the program.
circuit = qjit(circuit)
print("qjit results: ", circuit([0.1, 0.2]))
print(circuit.jaxpr, circuit.mlir)
>>>
core PL results: 0.37959212195737485
qjit results: 0.37959212195737485
{ lambda ; a:f64[] b:f64[]. let
    c:f64[] = func[
      call_jaxpr={ lambda ; d:f64[] e:f64[]. let
          f:f64[] = func[
            call_jaxpr={ lambda ; g:f64[] h:f64[]. let
                qdevice[
                  rtd_kwargs={'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}
                  rtd_lib=/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/utils/../../../runtime/build/lib/librtd_lightning.so
                  rtd_name=LightningSimulator
                ]
                i:AbstractQreg() = qalloc 2
                j:AbstractQbit() = qextract i 0
                k:AbstractQbit() = qinst[
                  adjoint=True
                  ctrl_len=0
                  op=RY
                  params_len=1
                  qubits_len=1
                ] j g
                l:AbstractObs(num_qubits=None,primitive=None) = namedobs[
                  kind=PauliX
                ] k
                m:f64[] = expval[shots=None] l
                n:AbstractQreg() = qinsert i 0 k
                o:AbstractQbit() = qextract i 1
                p:AbstractQbit() = qinst[
                  adjoint=False
                  ctrl_len=0
                  op=RX
                  params_len=1
                  qubits_len=1
                ] o h
                q:AbstractQreg() = qinsert n 1 p
                qdealloc q
              in (m,) }
            fn=<function trace_quantum_function.<locals>._f at 0x7a30bf90a680>
          ] d e
          r:f64[] = func[
            call_jaxpr={ lambda ; s:f64[] t:f64[]. let
                qdevice[
                  rtd_kwargs={'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}
                  rtd_lib=/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/utils/../../../runtime/build/lib/librtd_lightning.so
                  rtd_name=LightningSimulator
                ]
                u:AbstractQreg() = qalloc 2
                v:AbstractQbit() = qextract u 0
                w:AbstractQbit() = qinst[
                  adjoint=False
                  ctrl_len=0
                  op=RY
                  params_len=1
                  qubits_len=1
                ] v 0.5
                x:AbstractObs(num_qubits=None,primitive=None) = namedobs[
                  kind=PauliX
                ] w
                y:f64[] = expval[shots=None] x
                z:AbstractQreg() = qinsert u 0 w
                qdealloc z
              in (y,) }
            fn=<function trace_quantum_function.<locals>._f at 0x7a30bf90a0e0>
          ] d e
          ba:f64[] = func[
            call_jaxpr={ lambda ; bb:f64[] bc:f64[]. let
                bd:f64[] = add bb bc
              in (bd,) }
            fn=<function circuit at 0x7a30bfe149d0>
          ] f r
        in (ba,) }
      fn=<QNode: device='<lightning.qubit device (wires=2) at 0x7a3105207f10>', interface='auto', diff_method='best'>
    ] a b
  in (c,) }
module @circuit {
  func.func public @jit_circuit(%arg0: tensor<f64>, %arg1: tensor<f64>) -> tensor<f64> attributes {llvm.emit_c_interface} {
    %0 = call @circuit(%arg0, %arg1) : (tensor<f64>, tensor<f64>) -> tensor<f64>
    return %0 : tensor<f64>
  }
  func.func private @circuit(%arg0: tensor<f64>, %arg1: tensor<f64>) -> tensor<f64> attributes {diff_method = "parameter-shift", llvm.linkage = #llvm.linkage<internal>, qnode} {
    %0 = call @_f(%arg0, %arg1) : (tensor<f64>, tensor<f64>) -> tensor<f64>
    %1 = call @_f_0(%arg0, %arg1) : (tensor<f64>, tensor<f64>) -> tensor<f64>
    %2 = call @circuit_1(%0, %1) : (tensor<f64>, tensor<f64>) -> tensor<f64>
    return %2 : tensor<f64>
  }
  func.func private @_f(%arg0: tensor<f64>, %arg1: tensor<f64>) -> tensor<f64> attributes {llvm.linkage = #llvm.linkage<internal>} {
    quantum.device["/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/utils/../../../runtime/build/lib/librtd_lightning.so", "LightningSimulator", "{'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
    %0 = quantum.alloc( 2) : !quantum.reg
    %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
    %extracted = tensor.extract %arg0[] : tensor<f64>
    %out_qubits = quantum.custom "RY"(%extracted) %1 {adjoint} : !quantum.bit
    %2 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs
    %3 = quantum.expval %2 : f64
    %from_elements = tensor.from_elements %3 : tensor<f64>
    %4 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit
    %5 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit
    %extracted_0 = tensor.extract %arg1[] : tensor<f64>
    %out_qubits_1 = quantum.custom "RX"(%extracted_0) %5 : !quantum.bit
    %6 = quantum.insert %4[ 1], %out_qubits_1 : !quantum.reg, !quantum.bit
    quantum.dealloc %6 : !quantum.reg
    quantum.device_release
    return %from_elements : tensor<f64>
  }
  func.func private @_f_0(%arg0: tensor<f64>, %arg1: tensor<f64>) -> tensor<f64> attributes {llvm.linkage = #llvm.linkage<internal>} {
    %cst = arith.constant 5.000000e-01 : f64
    quantum.device["/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/utils/../../../runtime/build/lib/librtd_lightning.so", "LightningSimulator", "{'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
    %0 = quantum.alloc( 2) : !quantum.reg
    %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
    %out_qubits = quantum.custom "RY"(%cst) %1 : !quantum.bit
    %2 = quantum.namedobs %out_qubits[ PauliX] : !quantum.obs
    %3 = quantum.expval %2 : f64
    %from_elements = tensor.from_elements %3 : tensor<f64>
    %4 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit
    quantum.dealloc %4 : !quantum.reg
    quantum.device_release
    return %from_elements : tensor<f64>
  }
  func.func private @circuit_1(%arg0: tensor<f64>, %arg1: tensor<f64>) -> tensor<f64> attributes {llvm.linkage = #llvm.linkage<internal>} {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f64>
    return %0 : tensor<f64>
  }
  func.func @setup() {
    quantum.init
    return
  }
  func.func @teardown() {
    quantum.finalize
    return
  }
}
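The printed value 0.37959212195737485 can be checked by hand. The sketch below mirrors the structure of the lowered @circuit function (one function per tape, then the post-processing function) using plain Python functions. The closed-form expectation values are our own derivation, not code from this PR: for RY(θ)|0⟩ the expectation of PauliX is sin θ, and adjoint(RY)(θ) equals RY(-θ). The functions take two scalars to match %arg0/%arg1 in the MLIR rather than the list x used by the QNode.

```python
import math


def tape1(x0: float, x1: float) -> float:
    # adjoint(RY)(x0) on wire 0 is RY(-x0); for RY(theta)|0>, <X> = sin(theta).
    # RX(x1) acts only on wire 1, so it does not change <X> on wire 0.
    return math.sin(-x0)


def tape2(x0: float, x1: float) -> float:
    # Fixed RY(0.5) on wire 0 followed by <X> on wire 0. The inputs are
    # unused, just like %arg0/%arg1 in @_f_0.
    return math.sin(0.5)


def post_processing_fn(results):
    return results[0] + results[1]


def circuit(x0: float, x1: float) -> float:
    # Mirrors @circuit: call each per-tape function, then post-process.
    return post_processing_fn([tape1(x0, x1), tape2(x0, x1)])


print(circuit(0.1, 0.2))  # sin(0.5) - sin(0.1), approximately 0.3795921
```

This agrees with both the core PennyLane and the qjit results above up to floating-point rounding, which is consistent with @_f and @_f_0 each computing one tape and @circuit_1 adding them.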
Marking this as ready for review only to trigger CI. This is not really ready.
Marking this as a draft again, as we switch lanes and try to split tapes at the MLIR level instead of during tracing.
Closing, as #1017 accomplishes the same goal.