Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions .dep-versions
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# Always update the version check in catalyst.__init__ when changing the JAX version.
# To update JAX version alongside compatible dependency tags, run the following script:
# python3 .github/workflows/set_dep_versions.py {JAX_version}
jax=0.7.1
stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d
llvm=113f01aa82d055410f22a9d03b3468fa68600589
enzyme=v0.0.203
stablehlo=v1.13.7
llvm=8f264586d7521b0e305ca7bb78825aa3382ffef7
enzyme=v0.0.238

# Always remove custom PL/LQ versions before release.

Expand Down
1 change: 1 addition & 0 deletions .github/workflows/build-wheel-linux-arm64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ jobs:
run: |
cd $GITHUB_WORKSPACE/mlir/llvm-project
git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch
git apply $GITHUB_WORKSPACE/mlir/patches/llvm-python-bindinggen-annotations.patch

- name: Clone Stablehlo Submodule
if: steps.cache-stablehlo-source.outputs.cache-hit != 'true'
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/build-wheel-linux-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ jobs:
run: |
cd $GITHUB_WORKSPACE/mlir/llvm-project
git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch
git apply $GITHUB_WORKSPACE/mlir/patches/llvm-python-bindinggen-annotations.patch

- name: Clone Stablehlo Submodule
if: steps.cache-stablehlo-source.outputs.cache-hit != 'true'
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/build-wheel-macos-arm64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ jobs:
run: |
cd $GITHUB_WORKSPACE/mlir/llvm-project
git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch
git apply $GITHUB_WORKSPACE/mlir/patches/llvm-python-bindinggen-annotations.patch

- name: Clone Stablehlo Submodule
if: steps.cache-stablehlo-source.outputs.cache-hit != 'true'
Expand Down
52 changes: 27 additions & 25 deletions doc/dev/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -390,13 +390,14 @@ In C++ it will look as follows:
// performs C+=A*B
// so we need to create a zero matrix of the desired type and shape first
tensor::EmptyOp zeromat =
rewriter.create<tensor::EmptyOp>(op.getLoc(), MatrixType, ValueRange{});
tensor::EmptyOp::create(rewriter, op.getLoc(), MatrixType, ValueRange{});

// The first argument to the `create` need to be a `Location`
// The first argument to the `create` method is the `OpBuilder` (rewriter)
// The second argument to the `create` need to be a `Location`
// which can usually just be a `getLoc()` from any operation you have handy
// The second argument needs to be (a list of) type(s) of the operation's output
// The third argument needs to be (a list of) input value(s) to the operation
linalg::MatmulOp matmul = rewriter.create<linalg::MatmulOp>(
// The third argument needs to be (a list of) type(s) of the operation's output
// The fourth argument needs to be (a list of) input value(s) to the operation
linalg::MatmulOp matmul = linalg::MatmulOp::create(rewriter,
op.getLoc(), TypeRange{MatrixType}, ValueRange{m1, m2}, ValueRange{zeromat});

// Some peculiarity for the matmul operation; no need to worry about it here
Expand Down Expand Up @@ -427,26 +428,27 @@ changes (also called the insertion point). Let's have a look at some of these el

- **Constructing new operations**:

New operations are created via the ``rewriter.create`` method. Here we want to generate a matrix
New operations are created via the ``OpTy::create`` method. Here we want to generate a matrix
multiplication instruction from the ``linalg`` dialect. C++ namespaces usually correspond to the
dialect name. The first thing the rewriter needs is always a `location object <https://mlir.llvm.org/docs/Diagnostics/#source-locations>`_,
dialect name. The first argument to ``create`` is always the ``OpBuilder`` (or ``PatternRewriter``),
followed by a `location object <https://mlir.llvm.org/docs/Diagnostics/#source-locations>`_,
which is used in debugging to refer back to the original source code line, for example.
Following this, we need to provide the right arguments to instantiate the operation. So-called
operation builders are automatically defined for this purpose, whose source can be referenced to
Following this, we need to provide the right arguments to instantiate the operation. The ``create``
methods are automatically defined for this purpose, whose source can be referenced to
consult which arguments are required. Looking into ``LinalgStructuredOps.h.inc`` for example
reveals the following options:

.. code-block:: cpp

static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ValueRange inputs, ValueRange outputs, ArrayRef<NamedAttribute> attributes = {});
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, TypeRange resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef<NamedAttribute> attributes = {});
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, TypeRange resultTensorTypes, ValueRange operands, ArrayRef<NamedAttribute> attributes = {});
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, TypeRange resultTensorTypes, ValueRange inputs, ValueRange outputs, Attribute cast, ArrayRef<NamedAttribute> attributes = {});
static MatmulOp create(::mlir::OpBuilder &builder, ::mlir::Location location, ValueRange inputs, ValueRange outputs, ArrayRef<NamedAttribute> attributes = {});
static MatmulOp create(::mlir::OpBuilder &builder, ::mlir::Location location, TypeRange resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef<NamedAttribute> attributes = {});
static MatmulOp create(::mlir::OpBuilder &builder, ::mlir::Location location, TypeRange resultTensorTypes, ValueRange operands, ArrayRef<NamedAttribute> attributes = {});
static MatmulOp create(::mlir::OpBuilder &builder, ::mlir::Location location, TypeRange resultTensorTypes, ValueRange inputs, ValueRange outputs, Attribute cast, ArrayRef<NamedAttribute> attributes = {});

We can always ignore the first two arguments, ``odsBuilder`` and ``odsState``, but the remaining
ones are the arguments we'll need to provide to the rewriter. We chose the simplest one which
only requires specifying a range of values for the operation ``inputs`` (two to be precise). We
can ignore ``outputs`` argument for now as it is a peculiarity of the ``linalg`` dialect.
The first argument is always the ``OpBuilder`` (or ``PatternRewriter``), and the second is the ``Location``.
The remaining arguments depend on the specific operation. We chose the version which
requires specifying result types, input values, and output values. We can ignore ``outputs``
argument for now as it is a peculiarity of the ``linalg`` dialect.
If necessary, the result types of an operation may be specified as can be seen in the second
version, but for ``matmul`` the result types can be automatically deduced.

Expand Down Expand Up @@ -618,7 +620,7 @@ For the rewriting part we'll want to introduce a few new elements, such as looki
// The function type should be identical to the type signature of the grad operation.
FunctionType fnType = rewriter.getFunctionType(op.getOperandTypes(), op.getResultTypes());

gradFn = rewriter.create<func::FuncOp>(op.getLoc(), fnName, fnType, visibility, nullptr, nullptr);
gradFn = func::FuncOp::create(rewriter, op.getLoc(), fnName, fnType, visibility, nullptr, nullptr);

// Now we just to populate the actual body of the function. First create an empty body.
Block *fnBody = gradFn.addEntryBlock();
Expand Down Expand Up @@ -678,30 +680,30 @@ In code:
ValueRange callArgs = gradFn.getArguments();

// We can reuse the same f(x, y, z) evaluation for all partial derivatives.
func::CallOp callOp = rewriter.create<func::CallOp>(loc, callee, callArgs);
func::CallOp callOp = func::CallOp::create(rewriter, loc, callee, callArgs);

// Loop through x, y, z to collect the partial derivatives.
std::vector<Value> gradient;
for (auto [idx, arg] : llvm::enumerate(callArgs)) {

FloatAttr hAttr = rewriter.getF64FloatAttr(0.1); // or another small fd parameter
Value hValue = rewriter.create<arith::ConstantOp>(loc, hAttr);
Value hValue = arith::ConstantOp::create(rewriter, loc, hAttr);

Value argPlusH = rewriter.create<arith::AddFOp>(loc, arg, hValue);
Value argPlusH = arith::AddFOp::create(rewriter, loc, arg, hValue);

// Make a copy of arguments to replace the argument with it's shifted value.
std::vector<Value> callArgsForward(callArgs.begin(), callArgs.end());
callArgsForward[idx] = argPlusH;
func::CallOp callOpForward =
rewriter.create<func::CallOp>(loc, callee, callArgsForward);
func::CallOp::create(rewriter, loc, callee, callArgsForward);

// Compute the finite difference.
Value difference = rewriter.create<arith::SubFOp>(loc, callOpForward.getResult(0), callOp.getResult(0));
Value partialDerivative = rewriter.create<arith::DivFOp>(loc, difference, hValue);
Value difference = arith::SubFOp::create(rewriter, loc, callOpForward.getResult(0), callOp.getResult(0));
Value partialDerivative = arith::DivFOp::create(rewriter, loc, difference, hValue);
gradient.push_back(partialDerivative);
}

rewriter.create<func::ReturnOp>(loc, gradient);
func::ReturnOp::create(rewriter, loc, gradient);
}

Alright, our function should now look something like this:
Expand Down
14 changes: 14 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,19 @@

<h3>Breaking changes 💔</h3>

* (Compiler integrators only) The versions of StableHLO/LLVM/Enzyme used by Catalyst have been updated.
[(#2415)](https://github.com/PennyLaneAI/catalyst/pull/2415)
[(#2416)](https://github.com/PennyLaneAI/catalyst/pull/2416)
[(#2444)](https://github.com/PennyLaneAI/catalyst/pull/2444)
[(#2445)](https://github.com/PennyLaneAI/catalyst/pull/2445)

- The StableHLO version has been updated to
[v1.13.7](https://github.com/openxla/stablehlo/tree/v1.13.7).
- The LLVM version has been updated to
[commit 8f26458](https://github.com/llvm/llvm-project/tree/8f264586d7521b0e305ca7bb78825aa3382ffef7).
- The Enzyme version has been updated to
[v0.0.238](https://github.com/EnzymeAD/Enzyme/releases/tag/v0.0.238).

* When an integer argnums is provided to `catalyst.vjp`, a singleton dimension is now squeezed
out. This brings the behaviour in line with that of `grad` and `jacobian`.
[(#2279)](https://github.com/PennyLaneAI/catalyst/pull/2279)
Expand Down Expand Up @@ -158,6 +171,7 @@ Lillian Frederiksen,
Sengthai Heng,
David Ittah,
Jeffrey Kam,
Mehrdad Malekmohammadi,
River McCubbin,
Mudit Pandey,
Andrija Paurevic,
Expand Down
5 changes: 2 additions & 3 deletions frontend/test/pytest/test_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -1643,7 +1643,6 @@ def interpreted(x):
assert np.allclose(compiled(inp), interpreted(inp))


@pytest.mark.xfail(reason="Need PR 332.")
@pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)])
def test_preprocessing_outside_qnode(inp, backend):
"""Test the preprocessing outside qnode."""
Expand Down Expand Up @@ -1723,9 +1722,9 @@ def circuit(weights):
assert np.allclose(cat_res, jax_res)


@pytest.mark.xfail(reason="First need #332, then Vmap yields wrong results when differentiated")
def test_vmap_worflow_derivation(backend):
"""Check the gradient of a vmap workflow"""
pytest.xfail("Avoid segfault in CI: vmap differentiation not stable yet.")
n_wires = 5
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3

Expand Down Expand Up @@ -1779,9 +1778,9 @@ def loss_fn(params, data, targets):
assert jnp.allclose(data_cat[1], data_jax[1])


@pytest.mark.xfail(reason="First need #332, then Vmap yields wrong results when differentiated")
def test_forloop_vmap_worflow_derivation(backend):
"""Test a forloop vmap."""
pytest.xfail("Avoid segfault in CI: vmap differentiation not stable yet.")
n_wires = 5
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3
weights = jnp.ones([n_wires, 3])
Expand Down
1 change: 1 addition & 0 deletions mlir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ set(STABLEHLO_LIBS
StablehloAssemblyFormat
StablehloBase
StablehloBroadcastUtils
StablehloBroadcastLowering
StablehloCAPI
StablehloLinalgTransforms
StablehloOps
Expand Down
2 changes: 1 addition & 1 deletion mlir/Enzyme
Submodule Enzyme updated 819 files
5 changes: 5 additions & 0 deletions mlir/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ llvm:
@if cd llvm-project; git apply --check $(MK_DIR)/patches/llvm-bufferization-segfault.patch; then \
git apply $(MK_DIR)/patches/llvm-bufferization-segfault.patch; \
fi
# Patch mlir python bindinggen annotations
# Remove patch after JAX is updated to v0.8.2 or later
@if cd llvm-project; git apply --check $(MK_DIR)/patches/llvm-python-bindinggen-annotations.patch; then \
git apply $(MK_DIR)/patches/llvm-python-bindinggen-annotations.patch; \
fi

cmake -G Ninja -S llvm-project/llvm -B $(LLVM_BUILD_DIR) \
-DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/Catalyst/Utils/EnsureFunctionDeclaration.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ OpT ensureFunctionDeclaration(PatternRewriter &rewriter, Operation *op, StringRe
rewriter.setInsertionPointToStart(mod.getBody());

// Create the specific function operation (LLVMFuncOp or FuncOp)
auto newFunc = rewriter.create<OpT>(op->getLoc(), fnSymbol, fnType);
auto newFunc = OpT::create(rewriter, op->getLoc(), fnSymbol, fnType);

// Handle visibility differences:
// func::FuncOp usually requires explicit private visibility for runtime decls.
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/Ion/IR/IonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def ParallelProtocolOp : Ion_Op<"parallelprotocol", [SingleBlockImplicitTerminat

let builders = [
OpBuilder<(ins
CArg<"mlir::ValueRange", "std::nullopt">:$in_qubits,
CArg<"mlir::ValueRange", "{}">:$in_qubits,
Comment thread
paul0403 marked this conversation as resolved.
CArg<"llvm::function_ref<void(mlir::OpBuilder &, mlir::Location, mlir::ValueRange)>",
"nullptr">)>
];
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/QEC/IR/QECOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -671,8 +671,8 @@ def LayerOp : QEC_Op<"layer", [SingleBlockImplicitTerminator<"YieldOp">]> {
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins
CArg<"mlir::ValueRange", "std::nullopt">:$initArgs,
CArg<"mlir::ValueRange", "std::nullopt">:$outArgs,
CArg<"mlir::ValueRange", "{}">:$initArgs,
CArg<"mlir::ValueRange", "{}">:$outArgs,
CArg<"llvm::function_ref<void(mlir::OpBuilder &, mlir::Location, mlir::ValueRange, mlir::ValueRange)>",
"nullptr">)>
];
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/Quantum/Utils/QuantumSplitting.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ class AugmentedCircuitGenerator {
template <typename IndexingOp> void cacheDynamicWire(IndexingOp op, mlir::OpBuilder &builder)
{
if (!op.getIdxAttr().has_value()) {
builder.create<ListPushOp>(op.getLoc(), oldToCloned.lookupOrDefault(op.getIdx()),
cache.wireVector);
ListPushOp::create(builder, op.getLoc(), oldToCloned.lookupOrDefault(op.getIdx()),
cache.wireVector);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def createDenormalIEEE : NativeCodeCall<
"::mlir::arith::DenormalModeAttr::get("
"$_builder.getContext(), ::mlir::arith::DenormalMode::ieee"
")">;
def createExactNone : NativeCodeCall<"::mlir::UnitAttr()">;


// Unary Lowering Patterns.
Expand Down Expand Up @@ -108,7 +109,7 @@ def : Pat<(StableHLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r),
(Arith_MulIOp $l, $r, (createOverflowNone )),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(StableHLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r),
(Arith_DivSIOp $l, $r),
(Arith_DivSIOp $l, $r, (createExactNone )),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(StableHLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r),
(Arith_RemSIOp $l, $r),
Expand Down
Loading