Skip to content

Commit

Permalink
Implement a frontend UI to invoke the mlir quantum peephole transform…
Browse files Browse the repository at this point in the history
…ation passes (#911)

**Context:** Implement a frontend UI to invoke the mlir quantum peephole
transformation passes.

Currently there is a `remove-chained-self-inverse` mlir pass in Catalyst
that removes two neighbouring Hadamard gates in a circuit, but there is
no way to actually invoke it from the frontend. This PR implements this
frontend as a decorator on a qnode. When the `catalyst.cancel_inverses`
decorator is added onto a qnode, the `remove-chained-self-inverse` mlir
pass will be run on that qnode.

Following the decision to split #883 into two PRs, this is the frontend
portion.

**Description of the Change:**
We implment the peephole optimization library with the mlir transform
dialect https://mlir.llvm.org/docs/Dialects/Transform/,
https://mlir.llvm.org/docs/Tutorials/transform/

Briefly, we wish to generate the following mlir during frontend tracing
and lowering

```
module @workflow {

  func.func private @f{ ... }
  func.func private @g{ ... }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) {
      %0 = transform.apply_registered_pass "remove-chained-self-inverse" to %arg0 {options = "func-name=f"} : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module">
      transform.yield 
    }
  }
}
```

The outer module is the payload (the module to run the transform on),
and the inner module is the transformer (the module that schedules what
passes to run). The schedule in this example is run the
`-remove-chained-self-inverse` pass on the payload module `@workflow`,
with pass option `func-name=f`.

The `transform.named_sequence` must be terminated by a
`transform.yield`.

A new pass, `ApplyTransformSequence.cpp`, will perform the following:
1. Remove the transformer module from the top-level payload module
(which is its parent), and save the transformer module operation in
memory
2. Perform the transform, through API provided by the mlir transform
dialect
3. Delete the transformer module

######

We generate this mlir through two new jax primitives. 

The first one, `transform_named_sequence_p`, will be lowered to
`transform.named_sequence` with a `transform.yield` inside and put it in
a parent transformer module marked with the unit attribute
`transform.with_named_sequence`. The parent transformer module is
inserted into the top-level payload module.

The second one, `apply_registered_pass_p`, will be lowered to a
`transform.apply_registered_pass` inside the `transform.named_sequence`.
If one `transform.apply_registered_pass` already exists in the sequence,
the new pass will be added to after the previous one.


The design is as follows:

0. QJIT start
1. Capture jaxpr (`QJIT.capture()`)
1.1. Before capture, we insert the `transform_named_sequence_p` to
**every** jaxpr produced (the purpose is so that transform dialect will
not fail)
1.2. During capture, if the API `@cancel_inverses` is called, inject a
corresponding `apply_registered_pass_p` into jaxpr

2. Generate mlir and compile (`QJIT.generate_ir()`, `QJIT.compile()`)
   2.1. The jaxpr primitives are lowered to mlir. 
2.2. Here the `Compiler` class will run the pipeline, with the new
`-apply-transform-sequence` pass


**Benefits:** The user can now run quantum mlir passes from the
frontend.


[sc-67519]
  • Loading branch information
paul0403 authored Aug 6, 2024
1 parent 7231b5b commit dbc6376
Show file tree
Hide file tree
Showing 22 changed files with 1,016 additions and 13 deletions.
57 changes: 57 additions & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,63 @@
)
```

* A frontend decorator can be applied to a qnode to signal a compiler pass run.
[(#911)](https://github.com/PennyLaneAI/catalyst/pull/911)

A new module, `catalyst.passes` (file `frontend/catalyst/passes.py`), is added to
provide UI access for pass decorators. This PR adds the `cancel_inverses` decorator,
which runs the `-removed-chained-self-inverse` mlir pass that cancels two neighbouring
Hadamard gates.

```python
from catalyst.debug import print_compilation_stage
from catalyst.passes import cancel_inverses

dev = qml.device("lightning.qubit", wires=1)

@qjit(keep_intermediate=True)
def workflow():
@cancel_inverses
@qml.qnode(dev)
def f(x: float):
qml.RX(x, wires=0)
qml.Hadamard(wires=0)
qml.Hadamard(wires=0)
return qml.expval(qml.PauliZ(0))

@qml.qnode(dev)
def g(x: float):
qml.RX(x, wires=0)
qml.Hadamard(wires=0)
qml.Hadamard(wires=0)
return qml.expval(qml.PauliZ(0))

ff = f(1.0)
gg = g(1.0)

return ff, gg

>>> workflow()
(Array(0.54030231, dtype=float64), Array(0.54030231, dtype=float64))
>>> print_compilation_stage(workflow, "QuantumCompilationPass")
func.func private @f {
...
%out_qubits = quantum.custom "RX"(%extracted) %1 : !quantum.bit
%2 = quantum.namedobs %out_qubits[ PauliZ] : !quantum.obs
%3 = quantum.expval %2 : f64
...
}
func.func private @g {
...
%out_qubits = quantum.custom "RX"(%extracted) %1 : !quantum.bit
%out_qubits_0 = quantum.custom "Hadamard"() %out_qubits : !quantum.bit
%out_qubits_1 = quantum.custom "Hadamard"() %out_qubits_0 : !quantum.bit
%2 = quantum.namedobs %out_qubits_1[ PauliZ] : !quantum.obs
%3 = quantum.expval %2 : f64
...
}
```

<h3>Improvements</h3>

* Catalyst is now compatible with Enzyme `v0.0.130`
Expand Down
9 changes: 9 additions & 0 deletions doc/code/__init__.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,12 @@ Module: catalyst.third_party.cuda
:no-heading:
:no-inheritance-diagram:
:fullname:


Module: catalyst.passes
---------------------------------

.. automodapi:: catalyst.passes
:no-heading:
:no-inheritance-diagram:
:fullname:
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __getattr__(cls, name):
"mlir_quantum.dialects.gradient",
"mlir_quantum.dialects.catalyst",
"mlir_quantum.dialects.mitigation",
"mlir_quantum.dialects._transform_ops_gen",
"mlir_quantum.compiler_driver",
"pybind11",
"cudaq",
Expand Down
1 change: 1 addition & 0 deletions frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def run_writing_command(command: List[str], compile_options: Optional[CompileOpt
QUANTUM_COMPILATION_PASS = (
"QuantumCompilationPass",
[
"apply-transform-sequence", # Run the transform sequence defined in the MLIR module
"annotate-function",
"lower-mitigation",
"lower-gradients",
Expand Down
164 changes: 164 additions & 0 deletions frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@
from jaxlib.mlir.dialects.scf import ConditionOp, ForOp, IfOp, WhileOp, YieldOp
from jaxlib.mlir.dialects.stablehlo import ConstantOp as StableHLOConstantOp
from jaxlib.mlir.dialects.stablehlo import ConvertOp as StableHLOConvertOp
from mlir_quantum.dialects._transform_ops_gen import (
ApplyRegisteredPassOp,
NamedSequenceOp,
)
from mlir_quantum.dialects._transform_ops_gen import YieldOp as TransformYieldOp
from mlir_quantum.dialects.catalyst import (
AssertionOp,
CallbackCallOp,
Expand Down Expand Up @@ -183,6 +188,18 @@ def _obs_lowering(aval):
return (ir.OpaqueType.get("quantum", "obs"),)


#
# Transform Module Type
#
class AbstractTransformMod(AbstractValue):
"""Abstract transform module type."""


def _transform_mod_lowering(aval):
assert isinstance(aval, AbstractTransformMod)
return (ir.OpaqueType.get("transform", 'op<"builtin.module">'),)


#
# registration
#
Expand All @@ -195,6 +212,9 @@ def _obs_lowering(aval):
core.raise_to_shaped_mappings[AbstractObs] = lambda aval, _: aval
mlir.ir_type_handlers[AbstractObs] = _obs_lowering

core.raise_to_shaped_mappings[AbstractTransformMod] = lambda aval, _: aval
mlir.ir_type_handlers[AbstractTransformMod] = _transform_mod_lowering


class Folding(Enum):
"""
Expand Down Expand Up @@ -263,6 +283,9 @@ class Folding(Enum):
value_and_grad_p.multiple_results = True
assert_p = core.Primitive("assert")
assert_p.multiple_results = True
apply_registered_pass_p = core.Primitive("apply_registered_pass")
transform_named_sequence_p = core.Primitive("transform_named_sequence")
transform_named_sequence_p.multiple_results = True


def _assert_jaxpr_without_constants(jaxpr: ClosedJaxpr):
Expand Down Expand Up @@ -386,6 +409,145 @@ def _print_lowering(jax_ctx: mlir.LoweringRuleContext, *args, string=None, memre
return PrintOp(val=val, const_val=None, print_descriptor=memref).results


#
# transform dialect lowering
#


def get_named_sequence_in_module(mod):
for op in mod.body.operations:
if op.operation.name == "transform.named_sequence":
return op.operation
return None


#
# transform_named_sequence
#
@transform_named_sequence_p.def_abstract_eval
def _transform_named_sequence_p_abstract_eval(*args):
return ()


@transform_named_sequence_p.def_impl
def _transform_named_sequence_p_def_impl(*args): # pragma: no cover
raise NotImplementedError()


def _transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, *args):
transform_mod_type = ir.OpaqueType.get("transform", 'op<"builtin.module">')
module = jax_ctx.module_context.module

# We wish to generate the transformer module, and place it in the top-level module
# The transformer module must be marked with the "transform.with_named_sequence" attribute
# The transformer module has a single block, and the block contains the
# "transform.named_sequence @__transform_main" operation

with ir.InsertionPoint(module.body):
transformer_module = ir.Operation.create("builtin.module", regions=1)
with_named_sequence_attr = ir.UnitAttr.get(jax_ctx.module_context.context)
transformer_module.operation.attributes["transform.with_named_sequence"] = (
with_named_sequence_attr
)
bb_transformer = ir.Block.create_at_start(transformer_module.bodyRegion)

functype = ir.FunctionType.get(inputs=[transform_mod_type], results=[])
functype_attr = ir.TypeAttr.get(functype)

# Insert the transform.named_sequence op into the transformer module
# Note that InsertionPoint(Block) inserts after the last operation but still inside the block.
with ir.InsertionPoint(bb_transformer):
named_sequence_op = NamedSequenceOp(
sym_name="__transform_main",
function_type=functype_attr,
)

# transform.named_sequence op is the "main function" of the transform dialect
# and thus needs an entry block (which also should be its only block)
# The argument of the block is the payload module
bb_named_sequence = ir.Block.create_at_start(
named_sequence_op.body, arg_types=[transform_mod_type]
)

# The transform.named_sequence needs a terminator called "transform.yield"
with ir.InsertionPoint(bb_named_sequence):
transform_yield_op = TransformYieldOp(operands_=[]) # pylint: disable=unused-variable

return named_sequence_op.results


#
# apply_registered_pass
#
@apply_registered_pass_p.def_abstract_eval
def _apply_registered_pass_abstract_eval(*args, pass_name, options=None):
return AbstractTransformMod()


@apply_registered_pass_p.def_impl
def _apply_registered_pass_def_impl(*args, pass_name, options=None): # pragma: no cover
raise NotImplementedError()


def _apply_registered_pass_lowering(
jax_ctx: mlir.LoweringRuleContext, *args, pass_name, options=None
):
transform_mod_type = ir.OpaqueType.get("transform", 'op<"builtin.module">')
module = jax_ctx.module_context.module
named_sequence_op = None
for op in reversed(module.body.operations):
# transformer module usually is at the end of the module, so look for it from the end
if op.operation.name == "builtin.module":
named_sequence_op = get_named_sequence_in_module(op)
break
assert (
named_sequence_op is not None
), """
transform.apply_registered_pass must be placed in a transform.named_sequence,
but none exist in the module.
"""

# If there already is a apply_registered_pass,
# insert after the last pass in the existing pass sequence.
# Note that ir.InsertionPoint(op) sets the insertion point to immediately BEFORE the op
named_sequence_op_block = named_sequence_op.regions[0].blocks[0]
first_op_in_block = named_sequence_op_block.operations[0].operation

assert first_op_in_block.name in (
"transform.apply_registered_pass",
"transform.yield",
), """
Unexpected operation in transform.named_sequence!
Only transform.apply_registered_pass and transform.yield are allowed.
"""

if first_op_in_block.name == "transform.apply_registered_pass":
_ = len(named_sequence_op_block.operations)
yield_op = named_sequence_op_block.operations[_ - 1].operation
current_last_pass = named_sequence_op_block.operations[_ - 2].operation
with ir.InsertionPoint(yield_op):
apply_registered_pass_op = ApplyRegisteredPassOp(
result=transform_mod_type,
target=current_last_pass.result,
pass_name=pass_name,
options=options,
)

# otherwise it's the first pass, i.e. only a yield op is in the block
# so insert right before the yield op
else:
ip = named_sequence_op.regions[0].blocks[0]
with ir.InsertionPoint(ip.operations[len(ip.operations) - 1]):
apply_registered_pass_op = ApplyRegisteredPassOp(
result=transform_mod_type,
target=ip.arguments[0],
pass_name=pass_name,
options=options,
)

return apply_registered_pass_op.results


#
# func
#
Expand Down Expand Up @@ -1987,6 +2149,8 @@ def _adjoint_lowering(
mlir.register_lowering(assert_p, _assert_lowering)
mlir.register_lowering(python_callback_p, _python_callback_lowering)
mlir.register_lowering(value_and_grad_p, _value_and_grad_lowering)
mlir.register_lowering(apply_registered_pass_p, _apply_registered_pass_lowering)
mlir.register_lowering(transform_named_sequence_p, _transform_named_sequence_lowering)


def _scalar_abstractify(t):
Expand Down
19 changes: 18 additions & 1 deletion frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from catalyst.debug.instruments import instrument
from catalyst.jax_tracer import lower_jaxpr_to_mlir, trace_to_jaxpr
from catalyst.logging import debug_logger, debug_logger_init
from catalyst.passes import _inject_transform_named_sequence
from catalyst.qfunc import QFunc
from catalyst.tracing.contexts import EvaluationContext
from catalyst.tracing.type_signatures import (
Expand Down Expand Up @@ -584,8 +585,24 @@ def closure(*args, **kwargs):
(qml.QNode, "__call__", closure),
):
# TODO: improve PyTree handling

def fn_with_transform_named_sequence(*args, **kwargs):
"""
This function behaves exactly like the user function being jitted,
taking in the same arguments and producing the same results, except
it injects a transform_named_sequence jax primitive at the beginning
of the jaxpr when being traced.
Note that we do not overwrite self.original_function and self.user_function;
this fn_with_transform_named_sequence is ONLY used here to produce tracing
results with a transform_named_sequence primitive at the beginning of the
jaxpr. It is never executed or used anywhere, except being traced here.
"""
_inject_transform_named_sequence()
return self.user_function(*args, **kwargs)

jaxpr, out_type, treedef = trace_to_jaxpr(
self.user_function, static_argnums, abstracted_axes, full_sig, {}
fn_with_transform_named_sequence, static_argnums, abstracted_axes, full_sig, {}
)

return jaxpr, out_type, treedef, dynamic_sig
Expand Down
Loading

0 comments on commit dbc6376

Please sign in to comment.