
Add EnumAttr to the Gradient Dialect for different differentiation methods.

Open · erick-xanadu opened this issue 1 year ago · 2 comments

Issue description

  • Actual behavior: Right now, the gradient method is a StringAttr, but it would be better to have an EnumAttr.

erick-xanadu, Jun 06 '24 18:06

Does this refer to the gradient in struct QuantumDevice?

mews6, Jul 18 '24 05:07

No.

The gradient dialect is an MLIR dialect. Its operations are defined in catalyst/mlir/include/Gradient/IR/GradientOps.td, written in the TableGen language. Some operations, like ValueAndGradOp, take a string attribute called method.

// Excerpt of an op definition from GradientOps.td
def ValueAndGradOp : Gradient_Op<"value_and_grad", [
        SameVariadicResultSize,
        DeclareOpInterfaceMethods<CallOpInterface>,
        DeclareOpInterfaceMethods<SymbolUserOpInterface>,
        GradientOpInterface
        ]> {
    let summary = "Compute the value and gradient of a function.";

    let arguments = (ins
        StrAttr:$method,  // <---- this corresponds to whether the derivative is obtained through finite differences or adjoint, etc.
        FlatSymbolRefAttr:$callee,
        Variadic<AnyType>:$operands,
        OptionalAttr<AnyIntElementsAttr>:$diffArgIndices,
        OptionalAttr<Builtin_FloatAttr>:$finiteDiffParam
    );
    // ... (remaining fields omitted)
}
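Because `StrAttr` accepts any string, code that consumes `method` has to validate the spelling by hand. The following is a minimal C++ sketch of that kind of check. The accepted spellings `"auto"` and `"fd"` and the helper name `isKnownGradientMethod` are assumptions for illustration, not Catalyst's actual code.

```cpp
#include <array>
#include <string_view>

// Hypothetical spellings accepted for the `method` StrAttr.
// Assumption: "auto" = automatic differentiation, "fd" = finite differences.
constexpr std::array<std::string_view, 2> kKnownGradientMethods = {"auto", "fd"};

// A StrAttr accepts any string, so a check like this one has to be written
// (and kept in sync with the list above) wherever `method` is consumed.
constexpr bool isKnownGradientMethod(std::string_view method) {
    for (std::string_view known : kKnownGradientMethods) {
        if (method == known) {
            return true;
        }
    }
    return false;
}

// Evaluated at compile time: the typo is caught only because we check for it.
static_assert(isKnownGradientMethod("fd"), "\"fd\" is accepted");
static_assert(!isKnownGradientMethod("fdd"), "a typo is rejected only by this manual check");
```

Nothing forces every pass to call such a helper, so a misspelled method can slip through wherever the check is forgotten.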

The ticket suggests changing this string to an enumeration, similar to how named observables are defined in the quantum dialect.


def NamedObservable : I32EnumAttr<"NamedObservable",
    "Known named observables",
    [
        I32EnumAttrCase<"Identity", 0>,  // These observables are an enum
        I32EnumAttrCase<"PauliX",   1>, // but still have a name.
        I32EnumAttrCase<"PauliY",   2>,
        I32EnumAttrCase<"PauliZ",   3>,
        I32EnumAttrCase<"Hadamard", 4>,
    ]> {
    let cppNamespace = "catalyst::quantum";
    let genSpecializedAttr = 0;
}
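For comparison, an enum-based attribute gives C++ code a closed set of values plus conversions to and from the textual form. For an `I32EnumAttr`, mlir-tblgen generates helpers named `stringify<Enum>` and `symbolize<Enum>`. The sketch below hand-writes rough equivalents for a hypothetical `GradientMethod` enum using only the standard library (the generated code uses `llvm::StringRef` instead of `std::string_view`). The enum name, the `catalyst::gradient` namespace, the enumerators, and their spellings are all assumptions.

```cpp
#include <cstdint>
#include <optional>
#include <string_view>

namespace catalyst::gradient {

// Hypothetical enum that a `GradientMethod` I32EnumAttr might produce.
// Names, integer values, and spellings are assumptions modelled on NamedObservable.
enum class GradientMethod : std::uint32_t {
    Auto = 0,       // assumed spelling "auto"
    FiniteDiff = 1, // assumed spelling "fd"
};

// Stand-in for the generated stringify helper: enum value -> textual spelling.
constexpr std::string_view stringifyGradientMethod(GradientMethod method) {
    switch (method) {
    case GradientMethod::Auto:
        return "auto";
    case GradientMethod::FiniteDiff:
        return "fd";
    }
    return "";
}

// Stand-in for the generated symbolize helper: spelling -> enum value.
// Unknown spellings yield std::nullopt, so a parser can reject them.
constexpr std::optional<GradientMethod> symbolizeGradientMethod(std::string_view str) {
    if (str == "auto") {
        return GradientMethod::Auto;
    }
    if (str == "fd") {
        return GradientMethod::FiniteDiff;
    }
    return std::nullopt;
}

// Compile-time usage: spellings round-trip, and a typo has no enum value.
static_assert(symbolizeGradientMethod("fd") == GradientMethod::FiniteDiff, "fd parses");
static_assert(stringifyGradientMethod(*symbolizeGradientMethod("auto")) == "auto", "round trip");
static_assert(!symbolizeGradientMethod("fdd").has_value(), "typo is rejected");

} // namespace catalyst::gradient
```

With such an enum, a misspelled method would be rejected when the IR is parsed, and lowering code could `switch` over `GradientMethod` instead of comparing strings. On the TableGen side, this would presumably be an `I32EnumAttr` with one `I32EnumAttrCase` per method (its optional third argument sets the textual spelling, e.g. `"fd"`), wrapped in an `EnumAttr` as is usual with `genSpecializedAttr = 0`, and used in place of `StrAttr:$method`.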

QuantumDevice is a class in the runtime, and it does have a Gradient method. I'll take a closer look; I had forgotten about this interface, but I think it is used to implement the adjoint differentiation method at runtime.

erick-xanadu, Jul 18 '24 12:07