diff --git a/.dep-versions b/.dep-versions index 215a27612b..8e0934c8ef 100644 --- a/.dep-versions +++ b/.dep-versions @@ -1,10 +1,8 @@ # Always update the version check in catalyst.__init__ when changing the JAX version. -# To update JAX version alongside compatible dependency tags, run the following script: -# python3 .github/workflows/set_dep_versions.py {JAX_version} jax=0.7.1 -stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d -llvm=113f01aa82d055410f22a9d03b3468fa68600589 -enzyme=v0.0.203 +stablehlo=v1.13.7 +llvm=8f264586d7521b0e305ca7bb78825aa3382ffef7 +enzyme=v0.0.238 # Always remove custom PL/LQ versions before release. diff --git a/.github/workflows/build-wheel-linux-arm64.yaml b/.github/workflows/build-wheel-linux-arm64.yaml index 91f6066a7a..f054a17cdb 100644 --- a/.github/workflows/build-wheel-linux-arm64.yaml +++ b/.github/workflows/build-wheel-linux-arm64.yaml @@ -113,6 +113,7 @@ jobs: run: | cd $GITHUB_WORKSPACE/mlir/llvm-project git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch + git apply $GITHUB_WORKSPACE/mlir/patches/llvm-python-bindinggen-annotations.patch - name: Clone Stablehlo Submodule if: steps.cache-stablehlo-source.outputs.cache-hit != 'true' diff --git a/.github/workflows/build-wheel-linux-x86_64.yaml b/.github/workflows/build-wheel-linux-x86_64.yaml index 0f8d5f9776..8a6975d967 100644 --- a/.github/workflows/build-wheel-linux-x86_64.yaml +++ b/.github/workflows/build-wheel-linux-x86_64.yaml @@ -132,6 +132,7 @@ jobs: run: | cd $GITHUB_WORKSPACE/mlir/llvm-project git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch + git apply $GITHUB_WORKSPACE/mlir/patches/llvm-python-bindinggen-annotations.patch - name: Clone Stablehlo Submodule if: steps.cache-stablehlo-source.outputs.cache-hit != 'true' diff --git a/.github/workflows/build-wheel-macos-arm64.yaml b/.github/workflows/build-wheel-macos-arm64.yaml index 2b309b08f8..7462a9e9d9 100644 --- a/.github/workflows/build-wheel-macos-arm64.yaml +++ b/.github/workflows/build-wheel-macos-arm64.yaml @@ -118,6 +118,7 @@ jobs: run: | cd $GITHUB_WORKSPACE/mlir/llvm-project git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch + git apply $GITHUB_WORKSPACE/mlir/patches/llvm-python-bindinggen-annotations.patch - name: Clone Stablehlo Submodule if: steps.cache-stablehlo-source.outputs.cache-hit != 'true' diff --git a/doc/dev/transforms.rst b/doc/dev/transforms.rst index d30b244d87..5cf3e19c78 100644 --- a/doc/dev/transforms.rst +++ b/doc/dev/transforms.rst @@ -390,13 +390,14 @@ In C++ it will look as follows: // performs C+=A*B // so we need to create a zero matrix of the desired type and shape first tensor::EmptyOp zeromat = - rewriter.create(op.getLoc(), MatrixType, ValueRange{}); + tensor::EmptyOp::create(rewriter, op.getLoc(), MatrixType, ValueRange{}); - // The first argument to the `create` need to be a `Location` + // The first argument to the `create` method is the `OpBuilder` (rewriter) + // The second argument to the `create` need to be a `Location` // which can usually just be a `getLoc()` from any operation you have handy - // The second argument needs to be (a list of) type(s) of the operation's output - // The third argument needs to be (a list of) input value(s) to the operation - linalg::MatmulOp matmul = rewriter.create( + // The third argument needs to be (a list of) type(s) of the operation's output + // The fourth argument needs to be (a list of) input value(s) to the operation + linalg::MatmulOp matmul = linalg::MatmulOp::create(rewriter, op.getLoc(), TypeRange{MatrixType}, ValueRange{m1, m2}, ValueRange{zeromat}); // Some peculiarity for the matmul operation; no need to worry about it here @@ -427,26 +428,27 @@ changes (also called the insertion point). Let's have a look at some of these el - **Constructing new operations**: - New operations are created via the ``rewriter.create`` method. Here we want to generate a matrix + New operations are created via the ``OpTy::create`` method. Here we want to generate a matrix multiplication instruction from the ``linalg`` dialect. C++ namespaces usually correspond to the - dialect name. The first thing the rewriter needs is always a `location object `_, + dialect name. The first argument to ``create`` is always the ``OpBuilder`` (or ``PatternRewriter``), + followed by a `location object `_, which is used in debugging to refer back to the original source code line, for example. - Following this, we need to provide the right arguments to instantiate the operation. So-called - operation builders are automatically defined for this purpose, whose source can be referenced to + Following this, we need to provide the right arguments to instantiate the operation. The ``create`` + methods are automatically defined for this purpose, whose source can be referenced to consult which arguments are required. Looking into ``LinalgStructuredOps.h.inc`` for example reveals the following options: .. code-block:: cpp - static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ValueRange inputs, ValueRange outputs, ArrayRef attributes = {}); - static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, TypeRange resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef attributes = {}); - static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, TypeRange resultTensorTypes, ValueRange operands, ArrayRef attributes = {}); - static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, TypeRange resultTensorTypes, ValueRange inputs, ValueRange outputs, Attribute cast, ArrayRef attributes = {}); + static MatmulOp create(::mlir::OpBuilder &builder, ::mlir::Location location, ValueRange inputs, ValueRange outputs, ArrayRef attributes = {}); + static MatmulOp create(::mlir::OpBuilder &builder, ::mlir::Location location, TypeRange resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef attributes = {}); + static MatmulOp create(::mlir::OpBuilder &builder, ::mlir::Location location, TypeRange resultTensorTypes, ValueRange operands, ArrayRef attributes = {}); + static MatmulOp create(::mlir::OpBuilder &builder, ::mlir::Location location, TypeRange resultTensorTypes, ValueRange inputs, ValueRange outputs, Attribute cast, ArrayRef attributes = {}); - We can always ignore the first two arguments, ``odsBuilder`` and ``odsState``, but the remaining - ones are the arguments we'll need to provide to the rewriter. We chose the simplest one which - only requires specifying a range of values for the operation ``inputs`` (two to be precise). We - can ignore ``outputs`` argument for now as it is a peculiarity of the ``linalg`` dialect. + The first argument is always the ``OpBuilder`` (or ``PatternRewriter``), and the second is the ``Location``. + The remaining arguments depend on the specific operation. We chose the version which + requires specifying result types, input values, and output values. We can ignore ``outputs`` + argument for now as it is a peculiarity of the ``linalg`` dialect. If necessary, the result types of an operation may be specified as can be seen in the second version, but for ``matmul`` the result types can be automatically deduced. @@ -618,7 +620,7 @@ For the rewriting part we'll want to introduce a few new elements, such as looki // The function type should be identical to the type signature of the grad operation. FunctionType fnType = rewriter.getFunctionType(op.getOperandTypes(), op.getResultTypes()); - gradFn = rewriter.create(op.getLoc(), fnName, fnType, visibility, nullptr, nullptr); + gradFn = func::FuncOp::create(rewriter, op.getLoc(), fnName, fnType, visibility, nullptr, nullptr); // Now we just to populate the actual body of the function. First create an empty body. Block *fnBody = gradFn.addEntryBlock(); @@ -678,30 +680,30 @@ In code: ValueRange callArgs = gradFn.getArguments(); // We can reuse the same f(x, y, z) evaluation for all partial derivatives. - func::CallOp callOp = rewriter.create(loc, callee, callArgs); + func::CallOp callOp = func::CallOp::create(rewriter, loc, callee, callArgs); // Loop through x, y, z to collect the partial derivatives. std::vector gradient; for (auto [idx, arg] : llvm::enumerate(callArgs)) { FloatAttr hAttr = rewriter.getF64FloatAttr(0.1); // or another small fd parameter - Value hValue = rewriter.create(loc, hAttr); + Value hValue = arith::ConstantOp::create(rewriter, loc, hAttr); - Value argPlusH = rewriter.create(loc, arg, hValue); + Value argPlusH = arith::AddFOp::create(rewriter, loc, arg, hValue); // Make a copy of arguments to replace the argument with it's shifted value. std::vector callArgsForward(callArgs.begin(), callArgs.end()); callArgsForward[idx] = argPlusH; func::CallOp callOpForward = - rewriter.create(loc, callee, callArgsForward); + func::CallOp::create(rewriter, loc, callee, callArgsForward); // Compute the finite difference. - Value difference = rewriter.create(loc, callOpForward.getResult(0), callOp.getResult(0)); - Value partialDerivative = rewriter.create(loc, difference, hValue); + Value difference = arith::SubFOp::create(rewriter, loc, callOpForward.getResult(0), callOp.getResult(0)); + Value partialDerivative = arith::DivFOp::create(rewriter, loc, difference, hValue); gradient.push_back(partialDerivative); } - rewriter.create(loc, gradient); + func::ReturnOp::create(rewriter, loc, gradient); } Alright, our function should now look something like this: diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index b91b6208dc..bfe16e8d91 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -30,6 +30,19 @@

Breaking changes 💔

+* (Compiler integrators only) The versions of StableHLO/LLVM/Enzyme used by Catalyst have been updated. + [(#2415)](https://github.com/PennyLaneAI/catalyst/pull/2415) + [(#2416)](https://github.com/PennyLaneAI/catalyst/pull/2416) + [(#2444)](https://github.com/PennyLaneAI/catalyst/pull/2444) + [(#2445)](https://github.com/PennyLaneAI/catalyst/pull/2445) + + - The StableHLO version has been updated to + [v1.13.7](https://github.com/openxla/stablehlo/tree/v1.13.7). + - The LLVM version has been updated to + [commit 8f26458](https://github.com/llvm/llvm-project/tree/8f264586d7521b0e305ca7bb78825aa3382ffef7). + - The Enzyme version has been updated to + [v0.0.238](https://github.com/EnzymeAD/Enzyme/releases/tag/v0.0.238). + * When an integer argnums is provided to `catalyst.vjp`, a singleton dimension is now squeezed out. This brings the behaviour in line with that of `grad` and `jacobian`. [(#2279)](https://github.com/PennyLaneAI/catalyst/pull/2279) @@ -158,6 +171,7 @@ Lillian Frederiksen, Sengthai Heng, David Ittah, Jeffrey Kam, +Mehrdad Malekmohammadi, River McCubbin, Mudit Pandey, Andrija Paurevic, diff --git a/frontend/test/pytest/test_gradient.py b/frontend/test/pytest/test_gradient.py index 988811fb85..b053a1ad18 100644 --- a/frontend/test/pytest/test_gradient.py +++ b/frontend/test/pytest/test_gradient.py @@ -1643,7 +1643,6 @@ def interpreted(x): assert np.allclose(compiled(inp), interpreted(inp)) -@pytest.mark.xfail(reason="Need PR 332.") @pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)]) def test_preprocessing_outside_qnode(inp, backend): """Test the preprocessing outside qnode.""" @@ -1723,9 +1722,9 @@ def circuit(weights): assert np.allclose(cat_res, jax_res) -@pytest.mark.xfail(reason="First need #332, then Vmap yields wrong results when differentiated") def test_vmap_worflow_derivation(backend): """Check the gradient of a vmap workflow""" + pytest.xfail("Avoid segfault in CI: vmap differentiation not stable yet.") n_wires = 5 data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3 @@ -1779,9 +1778,9 @@ def loss_fn(params, data, targets): assert jnp.allclose(data_cat[1], data_jax[1]) -@pytest.mark.xfail(reason="First need #332, then Vmap yields wrong results when differentiated") def test_forloop_vmap_worflow_derivation(backend): """Test a forloop vmap.""" + pytest.xfail("Avoid segfault in CI: vmap differentiation not stable yet.") n_wires = 5 data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3 weights = jnp.ones([n_wires, 3]) diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index 26b88efab1..7ca4eef6ea 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -52,6 +52,7 @@ set(STABLEHLO_LIBS StablehloAssemblyFormat StablehloBase StablehloBroadcastUtils + StablehloBroadcastLowering StablehloCAPI StablehloLinalgTransforms StablehloOps diff --git a/mlir/Enzyme b/mlir/Enzyme index 476c8e3193..33a5cff906 160000 --- a/mlir/Enzyme +++ b/mlir/Enzyme @@ -1 +1 @@ -Subproject commit 476c8e3193a8577ba24ff845ae2294109225f83a +Subproject commit 33a5cff90674ed82cef8cb78650d4744711ff8c6 diff --git a/mlir/Makefile b/mlir/Makefile index 73cd9d240e..a67709ded7 100644 --- a/mlir/Makefile +++ b/mlir/Makefile @@ -66,6 +66,11 @@ llvm: @if cd llvm-project; git apply --check $(MK_DIR)/patches/llvm-bufferization-segfault.patch; then \ git apply $(MK_DIR)/patches/llvm-bufferization-segfault.patch; \ fi + # Patch mlir python bindinggen annotations + # Remove patch after JAX is updated to v0.8.2 or later + @if cd llvm-project; git apply --check $(MK_DIR)/patches/llvm-python-bindinggen-annotations.patch; then \ + git apply $(MK_DIR)/patches/llvm-python-bindinggen-annotations.patch; \ + fi cmake -G Ninja -S llvm-project/llvm -B $(LLVM_BUILD_DIR) \ -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \ diff --git a/mlir/include/Catalyst/Utils/EnsureFunctionDeclaration.h b/mlir/include/Catalyst/Utils/EnsureFunctionDeclaration.h index 2703da0756..356ff21bb2 100644 --- a/mlir/include/Catalyst/Utils/EnsureFunctionDeclaration.h +++ b/mlir/include/Catalyst/Utils/EnsureFunctionDeclaration.h @@ -41,7 +41,7 @@ OpT ensureFunctionDeclaration(PatternRewriter &rewriter, Operation *op, StringRe rewriter.setInsertionPointToStart(mod.getBody()); // Create the specific function operation (LLVMFuncOp or FuncOp) - auto newFunc = rewriter.create(op->getLoc(), fnSymbol, fnType); + auto newFunc = OpT::create(rewriter, op->getLoc(), fnSymbol, fnType); // Handle visibility differences: // func::FuncOp usually requires explicit private visibility for runtime decls. diff --git a/mlir/include/Ion/IR/IonOps.td b/mlir/include/Ion/IR/IonOps.td index 56045fa19c..ab53bb099b 100644 --- a/mlir/include/Ion/IR/IonOps.td +++ b/mlir/include/Ion/IR/IonOps.td @@ -213,7 +213,7 @@ def ParallelProtocolOp : Ion_Op<"parallelprotocol", [SingleBlockImplicitTerminat let builders = [ OpBuilder<(ins - CArg<"mlir::ValueRange", "std::nullopt">:$in_qubits, + CArg<"mlir::ValueRange", "{}">:$in_qubits, CArg<"llvm::function_ref", "nullptr">)> ]; diff --git a/mlir/include/QEC/IR/QECOps.td b/mlir/include/QEC/IR/QECOps.td index fa4b2c119b..3a58986516 100644 --- a/mlir/include/QEC/IR/QECOps.td +++ b/mlir/include/QEC/IR/QECOps.td @@ -671,8 +671,8 @@ def LayerOp : QEC_Op<"layer", [SingleBlockImplicitTerminator<"YieldOp">]> { let skipDefaultBuilders = 1; let builders = [ OpBuilder<(ins - CArg<"mlir::ValueRange", "std::nullopt">:$initArgs, - CArg<"mlir::ValueRange", "std::nullopt">:$outArgs, + CArg<"mlir::ValueRange", "{}">:$initArgs, + CArg<"mlir::ValueRange", "{}">:$outArgs, CArg<"llvm::function_ref", "nullptr">)> ]; diff --git a/mlir/include/Quantum/Utils/QuantumSplitting.h b/mlir/include/Quantum/Utils/QuantumSplitting.h index a22470d846..55365333fb 100644 --- a/mlir/include/Quantum/Utils/QuantumSplitting.h +++ b/mlir/include/Quantum/Utils/QuantumSplitting.h @@ -66,8 +66,8 @@ class AugmentedCircuitGenerator { template void cacheDynamicWire(IndexingOp op, mlir::OpBuilder &builder) { if (!op.getIdxAttr().has_value()) { - builder.create(op.getLoc(), oldToCloned.lookupOrDefault(op.getIdx()), - cache.wireVector); + ListPushOp::create(builder, op.getLoc(), oldToCloned.lookupOrDefault(op.getIdx()), + cache.wireVector); } } diff --git a/mlir/include/hlo-extensions/Transforms/stablehlo_legalize_to_standard_patterns.td b/mlir/include/hlo-extensions/Transforms/stablehlo_legalize_to_standard_patterns.td index ffeddb9e93..84f21a2189 100644 --- a/mlir/include/hlo-extensions/Transforms/stablehlo_legalize_to_standard_patterns.td +++ b/mlir/include/hlo-extensions/Transforms/stablehlo_legalize_to_standard_patterns.td @@ -71,6 +71,7 @@ def createDenormalIEEE : NativeCodeCall< "::mlir::arith::DenormalModeAttr::get(" "$_builder.getContext(), ::mlir::arith::DenormalMode::ieee" ")">; +def createExactNone : NativeCodeCall<"::mlir::UnitAttr()">; // Unary Lowering Patterns. @@ -108,7 +109,7 @@ def : Pat<(StableHLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r), (Arith_MulIOp $l, $r, (createOverflowNone )), [(IsSameSizeConstraint $l, $r)]>; def : Pat<(StableHLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r), - (Arith_DivSIOp $l, $r), + (Arith_DivSIOp $l, $r, (createExactNone )), [(IsSameSizeConstraint $l, $r)]>; def : Pat<(StableHLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r), (Arith_RemSIOp $l, $r), diff --git a/mlir/lib/Catalyst/Transforms/ArrayListToMemRefPass.cpp b/mlir/lib/Catalyst/Transforms/ArrayListToMemRefPass.cpp index 226d9da47f..89809f5ceb 100644 --- a/mlir/lib/Catalyst/Transforms/ArrayListToMemRefPass.cpp +++ b/mlir/lib/Catalyst/Transforms/ArrayListToMemRefPass.cpp @@ -38,7 +38,7 @@ struct ArrayListBuilder { return failure(); } - auto unpacked = b.create(loc, resultTypes, list); + auto unpacked = UnrealizedConversionCastOp::create(b, loc, resultTypes, list); return ArrayListBuilder{.dataField = unpacked.getResult(0), .sizeField = unpacked.getResult(1), .capacityField = unpacked.getResult(2), @@ -62,7 +62,7 @@ struct ArrayListBuilder { ctx, /*inputs=*/ {dataField.getType(), sizeField.getType(), capacityField.getType(), elementType}, /*outputs=*/{}); - auto pushFn = b.create(loc, funcName, pushFnType); + auto pushFn = func::FuncOp::create(b, loc, funcName, pushFnType); pushFn.setPrivate(); Block *entryBlock = pushFn.addEntryBlock(); @@ -72,31 +72,32 @@ struct ArrayListBuilder { BlockArgument capacityField = pushFn.getArgument(2); BlockArgument value = pushFn.getArgument(3); - Value sizeVal = b.create(loc, sizeField); - Value capacityVal = b.create(loc, capacityField); + Value sizeVal = memref::LoadOp::create(b, loc, sizeField); + Value capacityVal = memref::LoadOp::create(b, loc, capacityField); Value predicate = - b.create(loc, arith::CmpIPredicate::eq, sizeVal, capacityVal); - b.create(loc, predicate, [&](OpBuilder &thenBuilder, Location loc) { - Value two = thenBuilder.create(loc, 2); - Value newCapacity = thenBuilder.create(loc, capacityVal, two); - Value oldElements = thenBuilder.create(loc, elementsField); - Value newElements = thenBuilder.create( - loc, cast(oldElements.getType()), oldElements, newCapacity); - thenBuilder.create(loc, newElements, elementsField); - thenBuilder.create(loc, newCapacity, capacityField); - thenBuilder.create(loc); + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq, sizeVal, capacityVal); + scf::IfOp::create(b, loc, predicate, [&](OpBuilder &thenBuilder, Location loc) { + Value two = arith::ConstantIndexOp::create(thenBuilder, loc, 2); + Value newCapacity = arith::MulIOp::create(thenBuilder, loc, capacityVal, two); + Value oldElements = memref::LoadOp::create(thenBuilder, loc, elementsField); + Value newElements = + memref::ReallocOp::create(thenBuilder, loc, cast(oldElements.getType()), + oldElements, newCapacity); + memref::StoreOp::create(thenBuilder, loc, newElements, elementsField); + memref::StoreOp::create(thenBuilder, loc, newCapacity, capacityField); + scf::YieldOp::create(thenBuilder, loc); }); - Value elementsVal = b.create(loc, elementsField); - b.create(loc, value, elementsVal, - /*indices=*/sizeVal); + Value elementsVal = memref::LoadOp::create(b, loc, elementsField); + memref::StoreOp::create(b, loc, value, elementsVal, + /*indices=*/sizeVal); - Value one = b.create(loc, 1); - Value newSize = b.create(loc, sizeVal, one); + Value one = arith::ConstantIndexOp::create(b, loc, 1); + Value newSize = arith::AddIOp::create(b, loc, sizeVal, one); - b.create(loc, newSize, sizeField); - b.create(loc); + memref::StoreOp::create(b, loc, newSize, sizeField); + func::ReturnOp::create(b, loc); return SymbolRefAttr::get(ctx, funcName); } @@ -118,7 +119,7 @@ struct ArrayListBuilder { FunctionType::get(ctx, /*inputs=*/ {dataField.getType(), sizeField.getType(), capacityField.getType()}, /*outputs=*/elementType); - auto popFn = builder.create(loc, funcName, popFnType); + auto popFn = func::FuncOp::create(builder, loc, funcName, popFnType); popFn.setPrivate(); Block *entryBlock = popFn.addEntryBlock(); @@ -128,28 +129,28 @@ struct ArrayListBuilder { BlockArgument elementsField = arguments[0]; BlockArgument sizeField = arguments[1]; - Value elementsVal = builder.create(loc, elementsField); - Value sizeVal = builder.create(loc, sizeField); - Value one = builder.create(loc, 1); - Value newSize = builder.create(loc, sizeVal, one); - Value poppedVal = builder.create(loc, elementsVal, newSize); + Value elementsVal = memref::LoadOp::create(builder, loc, elementsField); + Value sizeVal = memref::LoadOp::create(builder, loc, sizeField); + Value one = arith::ConstantIndexOp::create(builder, loc, 1); + Value newSize = arith::SubIOp::create(builder, loc, sizeVal, one); + Value poppedVal = memref::LoadOp::create(builder, loc, elementsVal, newSize); - builder.create(loc, newSize, sizeField); - builder.create(loc, poppedVal); + memref::StoreOp::create(builder, loc, newSize, sizeField); + func::ReturnOp::create(builder, loc, poppedVal); return SymbolRefAttr::get(ctx, funcName); } void emitPush(Location loc, Value value, OpBuilder &b, FlatSymbolRefAttr pushFn) const { - b.create(loc, pushFn, /*results=*/TypeRange{}, - /*operands=*/ValueRange{dataField, sizeField, capacityField, value}); + func::CallOp::create(b, loc, pushFn, /*results=*/TypeRange{}, + /*operands=*/ValueRange{dataField, sizeField, capacityField, value}); } Value emitPop(Location loc, OpBuilder &builder, FlatSymbolRefAttr popFn) const { - auto callOp = builder.create( - loc, popFn, /*results=*/elementType, - /*operands=*/ValueRange{dataField, sizeField, capacityField}); + auto callOp = + func::CallOp::create(builder, loc, popFn, /*results=*/elementType, + /*operands=*/ValueRange{dataField, sizeField, capacityField}); return callOp.getResult(0); } }; @@ -165,20 +166,20 @@ struct LowerListInit : public OpConversionPattern { op.emitError() << "Failed to convert type " << op.getType(); return failure(); } - Value capacity = rewriter.create(op.getLoc(), 32); - Value initialSize = rewriter.create(op.getLoc(), 0); + Value capacity = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 32); + Value initialSize = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); auto dataType = cast(resultTypes[0]); auto sizeType = cast(resultTypes[1]); auto capacityType = cast(resultTypes[2]); - Value buffer = rewriter.create(op.getLoc(), - cast(dataType.getElementType()), - /*dynamicSize=*/capacity); - Value bufferField = rewriter.create(op.getLoc(), dataType); - Value sizeField = rewriter.create(op.getLoc(), sizeType); - Value capacityField = rewriter.create(op.getLoc(), capacityType); - rewriter.create(op.getLoc(), buffer, bufferField); - rewriter.create(op.getLoc(), initialSize, sizeField); - rewriter.create(op.getLoc(), capacity, capacityField); + Value buffer = memref::AllocOp::create(rewriter, op.getLoc(), + cast(dataType.getElementType()), + /*dynamicSize=*/capacity); + Value bufferField = memref::AllocOp::create(rewriter, op.getLoc(), dataType); + Value sizeField = memref::AllocOp::create(rewriter, op.getLoc(), sizeType); + Value capacityField = memref::AllocOp::create(rewriter, op.getLoc(), capacityType); + memref::StoreOp::create(rewriter, op.getLoc(), buffer, bufferField); + memref::StoreOp::create(rewriter, op.getLoc(), initialSize, sizeField); + memref::StoreOp::create(rewriter, op.getLoc(), capacity, capacityField); rewriter.replaceOpWithNewOp( op, op.getType(), ValueRange{bufferField, sizeField, capacityField}); return success(); @@ -198,11 +199,11 @@ struct LowerListDealloc : public OpConversionPattern { return failure(); } - Value data = rewriter.create(op.getLoc(), arraylistBuilder->dataField); - rewriter.create(op.getLoc(), data); - rewriter.create(op.getLoc(), arraylistBuilder->dataField); - rewriter.create(op.getLoc(), arraylistBuilder->sizeField); - rewriter.create(op.getLoc(), arraylistBuilder->capacityField); + Value data = memref::LoadOp::create(rewriter, op.getLoc(), arraylistBuilder->dataField); + memref::DeallocOp::create(rewriter, op.getLoc(), data); + memref::DeallocOp::create(rewriter, op.getLoc(), arraylistBuilder->dataField); + memref::DeallocOp::create(rewriter, op.getLoc(), arraylistBuilder->sizeField); + memref::DeallocOp::create(rewriter, op.getLoc(), arraylistBuilder->capacityField); rewriter.eraseOp(op); return success(); } @@ -266,14 +267,14 @@ struct LowerListLoadData : public OpConversionPattern { // Ensure the result memref has the correct underlying size (which may be different than the // list's underlying memref due to the geometric reallocation). Value data = - rewriter.create(op.getLoc(), arraylistBuilder.value().dataField); + memref::LoadOp::create(rewriter, op.getLoc(), arraylistBuilder.value().dataField); auto memrefType = cast(data.getType()); Value size = - rewriter.create(op.getLoc(), arraylistBuilder.value().sizeField); + memref::LoadOp::create(rewriter, op.getLoc(), arraylistBuilder.value().sizeField); SmallVector offsets{rewriter.getIndexAttr(0)}, sizes{size}, strides{rewriter.getIndexAttr(1)}; - Value dataView = rewriter.create(op.getLoc(), memrefType, data, offsets, - sizes, strides); + Value dataView = memref::SubViewOp::create(rewriter, op.getLoc(), memrefType, data, offsets, + sizes, strides); rewriter.replaceOp(op, dataView); return success(); } diff --git a/mlir/lib/Catalyst/Transforms/BufferDeallocation.cpp b/mlir/lib/Catalyst/Transforms/BufferDeallocation.cpp index 3100e602bf..96ec6d4710 100644 --- a/mlir/lib/Catalyst/Transforms/BufferDeallocation.cpp +++ b/mlir/lib/Catalyst/Transforms/BufferDeallocation.cpp @@ -618,7 +618,7 @@ class BufferDeallocation : public BufferPlacementTransformationBase { } else { // Build a "default" DeallocOp for unknown allocation sources. - builder.create(alloc.getLoc(), alloc); + memref::DeallocOp::create(builder, alloc.getLoc(), alloc); } return success(); } @@ -641,7 +641,7 @@ class BufferDeallocation : public BufferPlacementTransformationBase { "are not supported"); } // Build a "default" CloneOp for unknown allocation sources. - return builder.create(alloc.getLoc(), alloc).getResult(); + return bufferization::CloneOp::create(builder, alloc.getLoc(), alloc).getResult(); } /// The dominator info to find the appropriate start operation to move the diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp index d3ff275b6a..3703daa7ed 100644 --- a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp @@ -147,9 +147,9 @@ struct CustomCallOpInterface MemRefType::get(bufferedOperandMemrefType.getShape(), bufferedOperandMemrefType.getElementType()); auto allocOp = - rewriter.create(op->getLoc(), copiedOperandMemrefType); + memref::AllocOp::create(rewriter, op->getLoc(), copiedOperandMemrefType); auto copyOp = - rewriter.create(op->getLoc(), *opBuffer, allocOp.getResult()); + memref::CopyOp::create(rewriter, op->getLoc(), *opBuffer, allocOp.getResult()); bufferArgs.push_back(copyOp.getTarget()); } else { @@ -171,7 +171,7 @@ struct CustomCallOpInterface MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); auto newBuffer = - rewriter.create(op->getLoc(), memrefType, *tensorAlloc); + bufferization::ToBufferOp::create(rewriter, op->getLoc(), memrefType, *tensorAlloc); bufferArgs.push_back(newBuffer); } @@ -180,8 +180,8 @@ struct CustomCallOpInterface IntegerAttr numArgumentsAttr = rewriter.getI32IntegerAttr(numArguments); // Create an updated custom call operation - rewriter.create(op->getLoc(), TypeRange{}, bufferArgs, - customCallOp.getCallTargetName(), numArgumentsAttr); + CustomCallOp::create(rewriter, op->getLoc(), TypeRange{}, bufferArgs, + customCallOp.getCallTargetName(), numArgumentsAttr); size_t startIndex = bufferArgs.size() - customCallOp.getNumResults(); SmallVector bufferResults(bufferArgs.begin() + startIndex, bufferArgs.end()); bufferization::replaceOpWithBufferizedValues(rewriter, op, bufferResults); @@ -318,15 +318,15 @@ struct CallbackCallOpInterface auto shape = tensorTy.getShape(); auto elementTy = tensorTy.getElementType(); auto memrefType = MemRefType::get(shape, elementTy); - auto toBufferOp = rewriter.create(loc, memrefType, tensor); + auto toBufferOp = bufferization::ToBufferOp::create(rewriter, loc, memrefType, tensor); auto memref = toBufferOp.getResult(); outmemrefs.push_back(memref); newInputs.push_back(memref); } SmallVector emptyRets; - rewriter.create(loc, emptyRets, callOp.getCallee(), newInputs, - /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); + CallbackCallOp::create(rewriter, loc, emptyRets, callOp.getCallee(), newInputs, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); bufferization::replaceOpWithBufferizedValues(rewriter, op, outmemrefs); return success(); } diff --git a/mlir/lib/Catalyst/Transforms/DetectQNodes.cpp b/mlir/lib/Catalyst/Transforms/DetectQNodes.cpp index 37f0283aa3..9c62f023ec 100644 --- a/mlir/lib/Catalyst/Transforms/DetectQNodes.cpp +++ b/mlir/lib/Catalyst/Transforms/DetectQNodes.cpp @@ -224,7 +224,7 @@ LogicalResult AddExceptionHandlingTransform::matchAndRewrite(LLVM::CallOp callOp if (successBlock->hasNoSuccessors()) { PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToEnd(failBlock); - rewriter.create(invokeOp->getLoc()); + LLVM::UnreachableOp::create(rewriter, invokeOp->getLoc()); } else { auto successor = successBlock->getSuccessor(0); @@ -520,7 +520,7 @@ LogicalResult LivenessAnalysisDropRef::matchAndRewrite(LLVM::CallOp sink, Type llvmInt64Type = IntegerType::get(sink->getContext(), 64); auto one = rewriter.getIntegerAttr(llvmInt64Type, 1); - Value c1 = rewriter.create(sink->getLoc(), llvmInt64Type, one); + Value c1 = LLVM::ConstantOp::create(rewriter, sink->getLoc(), llvmInt64Type, one); // We just need to await for the tokens. // The tokens is the aggregate of all values. @@ -531,7 +531,7 @@ LogicalResult LivenessAnalysisDropRef::matchAndRewrite(LLVM::CallOp sink, for (auto awaitMe : tokens) { auto contains = valuesToDrop.find(awaitMe) != valuesToDrop.end(); if (contains) - rewriter.create(sink.getLoc(), awaitFnDecl, awaitMe); + LLVM::CallOp::create(rewriter, sink.getLoc(), awaitFnDecl, awaitMe); } // We will drop all values that were alive. Tokens and values. @@ -541,7 +541,7 @@ LogicalResult LivenessAnalysisDropRef::matchAndRewrite(LLVM::CallOp sink, // llvm.call @__catalyst__host__rt__unrecoverable_error() { catalyst.sink } for (auto dropMe : valuesToDrop) { SmallVector params = {dropMe, c1}; - rewriter.create(sink.getLoc(), dropRefFnDecl, params); + LLVM::CallOp::create(rewriter, sink.getLoc(), dropRefFnDecl, params); } // It is important that we do not cleanup the source, as other sinks @@ -582,7 +582,7 @@ LogicalResult BranchToUnreachableTransform::matchAndRewrite(LLVM::BrOp candidate if (!hasAttr) return failure(); - auto unreachable = rewriter.create(candidate.getLoc()); + auto unreachable = LLVM::UnreachableOp::create(rewriter, candidate.getLoc()); rewriter.replaceOp(candidate, unreachable); return success(); } @@ -642,7 +642,7 @@ void replaceCallsWithCallToTarget(SmallVector &oldCallOps, LLVM::L PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPoint(oldCallOp); auto newCallOp = - rewriter.create(oldCallOp.getLoc(), target, oldCallOp.getOperands()); + LLVM::CallOp::create(rewriter, oldCallOp.getLoc(), target, oldCallOp.getOperands()); rewriter.replaceOp(oldCallOp, newCallOp); newCalls.push_back(newCallOp); } @@ -762,7 +762,7 @@ void replaceTerminatorWithUnconditionalJumpToSuccessBlock(SmallVector a PatternRewriter::InsertionGuard insertGuard(rewriter); auto terminator = abort->getTerminator(); rewriter.setInsertionPoint(terminator); - auto brOp = rewriter.create(terminator->getLoc(), success); + auto brOp = LLVM::BrOp::create(rewriter, terminator->getLoc(), success); // Make sure we clean it up later. AsyncUtils::annotateBrToUnreachable(brOp, rewriter); rewriter.replaceOp(terminator, brOp); @@ -779,7 +779,7 @@ std::tuple getBlocks(LLVM::CallOp callOp, PatternRewr rewriter.setInsertionPoint(callOp); Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); - auto zeroOp = rewriter.create(callOp.getLoc(), ptrTy); + auto zeroOp = LLVM::ZeroOp::create(rewriter, callOp.getLoc(), ptrTy); Block *unwindBlock = rewriter.createBlock(successBlock); rewriter.setInsertionPointToEnd(unwindBlock); @@ -787,7 +787,7 @@ std::tuple getBlocks(LLVM::CallOp callOp, PatternRewr std::vector operands = {zeroOp.getResult()}; auto i32Ty = IntegerType::get(rewriter.getContext(), 32); auto structTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), {ptrTy, i32Ty}); - rewriter.create(callOp.getLoc(), structTy, isCleanUp, operands); + LLVM::LandingpadOp::create(rewriter, callOp.getLoc(), structTy, isCleanUp, operands); return std::tuple(blockContainingCall, successBlock, unwindBlock); } @@ -806,9 +806,9 @@ LLVM::InvokeOp transformCallToInvoke(LLVM::CallOp callOp, Block *successBlock, B { auto calleeAttr = callOp.getCalleeAttr(); SmallVector unwindArgs; - auto invokeOp = rewriter.create(callOp.getLoc(), callOp.getResultTypes(), - calleeAttr, callOp.getOperands(), successBlock, - ValueRange(), failBlock, unwindArgs); + auto invokeOp = LLVM::InvokeOp::create(rewriter, callOp.getLoc(), callOp.getResultTypes(), + calleeAttr, callOp.getOperands(), successBlock, + ValueRange(), failBlock, unwindArgs); rewriter.replaceOp(callOp, invokeOp); return invokeOp; } @@ -856,7 +856,7 @@ void insertCallToMlirAsyncRuntimeErrorFunction(Value value, LLVM::LLVMFuncOp fnD PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToEnd(failBlock); SmallVector operands = {value}; - rewriter.create(fnDecl.getLoc(), fnDecl, operands); + LLVM::CallOp::create(rewriter, fnDecl.getLoc(), fnDecl, operands); } void insertErrorCalls(std::vector tokens, std::vector values, Block *failBlock, @@ -895,7 +895,7 @@ void insertBranchFromFailToSuccessor(Block *fail, Block *success, PatternRewrite auto landingPad = fail->begin(); auto loc = landingPad->getLoc(); - rewriter.create(loc, success); + LLVM::BrOp::create(rewriter, loc, success); } } // namespace diff --git a/mlir/lib/Catalyst/Transforms/DetensorizeFunctionBoundaryPass.cpp b/mlir/lib/Catalyst/Transforms/DetensorizeFunctionBoundaryPass.cpp index 252d31e43a..973772528a 100644 --- a/mlir/lib/Catalyst/Transforms/DetensorizeFunctionBoundaryPass.cpp +++ b/mlir/lib/Catalyst/Transforms/DetensorizeFunctionBoundaryPass.cpp @@ -89,7 +89,7 @@ struct DetensorizeCallSitePattern : public OpRewritePattern { // Create the new function, passing the collected signature auto newFuncType = FunctionType::get(getContext(), newArgTypes, newResultTypes); newFuncOp = - rewriter.create(funcOp.getLoc(), newFuncName, newFuncType, newAttrs); + func::FuncOp::create(rewriter, funcOp.getLoc(), newFuncName, newFuncType, newAttrs); // Map FuncOp body and return operation Block *newEntryBlock = newFuncOp.addEntryBlock(); @@ -134,8 +134,8 @@ struct DetensorizeCallSitePattern : public OpRewritePattern { if (isScalarTensor(oldArg.getType())) { // Insert a FromElementsOp if the old argument is a scalar tensor - auto fromElementsOp = rewriter.create( - newArg.getLoc(), oldArg.getType(), newArg); + auto fromElementsOp = tensor::FromElementsOp::create(rewriter, newArg.getLoc(), + oldArg.getType(), newArg); mapper.map(oldArg, fromElementsOp.getResult()); } else { @@ -157,15 +157,15 @@ struct DetensorizeCallSitePattern : public OpRewritePattern { Value newOperand = mapper.lookup(operand); if (isScalarTensor(newOperand.getType())) { // Insert ExtractOp if the operand is a scalar tensor - auto extractOp = rewriter.create(oldReturnOp.getLoc(), - newOperand, ValueRange{}); + auto extractOp = tensor::ExtractOp::create(rewriter, oldReturnOp.getLoc(), + newOperand, ValueRange{}); newReturnOperands.push_back(extractOp.getResult()); } else { newReturnOperands.push_back(newOperand); } } - rewriter.create(oldReturnOp.getLoc(), newReturnOperands); + func::ReturnOp::create(rewriter, oldReturnOp.getLoc(), newReturnOperands); } void replaceCallOp(PatternRewriter &rewriter, func::CallOp &callOp, @@ -178,7 +178,7 @@ struct DetensorizeCallSitePattern : public OpRewritePattern { // function if (isScalarTensor(operand.getType())) { auto extractOp = - rewriter.create(callOp.getLoc(), operand, ValueRange{}); + tensor::ExtractOp::create(rewriter, callOp.getLoc(), operand, ValueRange{}); newOperands.push_back(extractOp.getResult()); } else { @@ -186,7 +186,7 @@ struct DetensorizeCallSitePattern : public OpRewritePattern { } } - auto newCallOp = rewriter.create(callOp.getLoc(), newFuncOp, newOperands); + auto newCallOp = func::CallOp::create(rewriter, callOp.getLoc(), newFuncOp, newOperands); SmallVector newResults; for (size_t i = 0; i < callOp.getNumResults(); ++i) { @@ -195,8 +195,8 @@ struct DetensorizeCallSitePattern : public OpRewritePattern { if (isScalarTensor(oldResult.getType())) { // Insert a FromElementsOp if the old result is a scalar tensor to bridge the // detensorized function - auto fromElementsOp = rewriter.create( - callOp.getLoc(), oldResult.getType(), newResult); + auto fromElementsOp = tensor::FromElementsOp::create( + rewriter, callOp.getLoc(), oldResult.getType(), newResult); newResults.push_back(fromElementsOp.getResult()); } else { diff --git a/mlir/lib/Catalyst/Transforms/DetensorizeSCFPass.cpp b/mlir/lib/Catalyst/Transforms/DetensorizeSCFPass.cpp index a5a1beb90b..ffebc35d0e 100644 --- a/mlir/lib/Catalyst/Transforms/DetensorizeSCFPass.cpp +++ b/mlir/lib/Catalyst/Transforms/DetensorizeSCFPass.cpp @@ -58,8 +58,8 @@ struct DetensorizeForOp : public OpRewritePattern { if (isScalarTensor(opOperand.get())) { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(forOp); - Value value = rewriter.create(forOp->getLoc(), opOperand.get(), - ValueRange{}); + Value value = tensor::ExtractOp::create(rewriter, forOp->getLoc(), opOperand.get(), + ValueRange{}); newIterOperands.push_back(value); newIterOperandsIndices.push_back(it.index()); continue; @@ -71,8 +71,8 @@ struct DetensorizeForOp : public OpRewritePattern { OpBuilder::InsertionGuard forOpInsertionGuard(rewriter); rewriter.setInsertionPoint(forOp); scf::ForOp newForOp = - rewriter.create(forOp.getLoc(), forOp.getLowerBound(), - forOp.getUpperBound(), forOp.getStep(), newIterOperands); + scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), + forOp.getUpperBound(), forOp.getStep(), newIterOperands); newForOp->setAttrs(forOp->getAttrs()); // 3. Copy body @@ -87,8 +87,8 @@ struct DetensorizeForOp : public OpRewritePattern { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(&newBlock, newBlock.begin()); for (std::size_t index : newIterOperandsIndices) { - Value value = rewriter.create( - newForOp->getLoc(), + Value value = tensor::FromElementsOp::create( + rewriter, newForOp->getLoc(), RankedTensorType::get({}, newForOp.getRegionIterArg(index).getType()), newForOp.getRegionIterArg(index)); rewriter.replaceUsesWithIf( @@ -104,10 +104,11 @@ struct DetensorizeForOp : public OpRewritePattern { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(clonedYieldOp); for (std::size_t index : newIterOperandsIndices) { - newYieldOperands[index] = rewriter.create( - clonedYieldOp->getLoc(), clonedYieldOp->getOperand(index), ValueRange{}); + newYieldOperands[index] = + tensor::ExtractOp::create(rewriter, clonedYieldOp->getLoc(), + clonedYieldOp->getOperand(index), ValueRange{}); } - rewriter.create(newForOp.getLoc(), newYieldOperands); + scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands); rewriter.eraseOp(clonedYieldOp); } @@ -117,8 +118,8 @@ struct DetensorizeForOp : public OpRewritePattern { rewriter.setInsertionPointAfter(forOp); for (std::size_t index : newIterOperandsIndices) { Value for_result = newForOp->getResult(index); - Value value = rewriter.create( - newForOp->getLoc(), RankedTensorType::get({}, for_result.getType()), + Value value = tensor::FromElementsOp::create( + rewriter, newForOp->getLoc(), RankedTensorType::get({}, for_result.getType()), for_result); rewriter.replaceUsesWithIf(forOp->getResult(index), value, [&](OpOperand &op) { return !isa(op.getOwner()); @@ -151,8 +152,8 @@ struct DetensorizeIfOp : public OpRewritePattern { Value operand = it.value(); // Detensorize operand: extract tensor element before yielding if (isScalarTensor(operand)) { - Value value = rewriter.create(yield_op->getLoc(), operand, - ValueRange{}); + Value value = tensor::ExtractOp::create(rewriter, yield_op->getLoc(), operand, + ValueRange{}); yield_op.setOperand(it.index(), value); } } @@ -172,7 +173,7 @@ struct DetensorizeIfOp : public OpRewritePattern { OpBuilder::InsertionGuard ifOpInsertionGuard(rewriter); rewriter.setInsertionPoint(ifOp); auto newIfOp = - rewriter.create(ifOp.getLoc(), newResultTypes, ifOp.getCondition()); + scf::IfOp::create(rewriter, ifOp.getLoc(), newResultTypes, ifOp.getCondition()); newIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); newIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); @@ -184,8 +185,9 @@ struct DetensorizeIfOp : public OpRewritePattern { auto oldResult = std::get<0>(results); auto newResult = std::get<1>(results); if (isScalarTensor(oldResult)) { - Value value = rewriter.create( - ifOp->getLoc(), RankedTensorType::get({}, newResult.getType()), newResult); + Value value = tensor::FromElementsOp::create( + rewriter, ifOp->getLoc(), RankedTensorType::get({}, newResult.getType()), + newResult); rewriter.replaceAllUsesWith(oldResult, value); } } @@ -218,7 +220,7 @@ struct DetensorizeWhileOp : public OpRewritePattern { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(whileOp); Value value = - rewriter.create(whileOp->getLoc(), opOperand, ValueRange{}); + tensor::ExtractOp::create(rewriter, whileOp->getLoc(), opOperand, ValueRange{}); newIterOperands.push_back(value); newIterOperandsIndices.push_back(it.index()); continue; @@ -242,8 +244,8 @@ struct DetensorizeWhileOp : public OpRewritePattern { OpBuilder::InsertionGuard newWhileOpInsertionGuard(rewriter); rewriter.setInsertionPoint(whileOp); scf::WhileOp newWhileOp = - rewriter.create(whileOp.getLoc(), newResultTypes, newIterOperands, - /*beforeBody*/ nullptr, /*afterBody*/ nullptr); + scf::WhileOp::create(rewriter, whileOp.getLoc(), newResultTypes, newIterOperands, + /*beforeBody*/ nullptr, /*afterBody*/ nullptr); // 3. Copy body Block &newBeforeBlock = *newWhileOp.getBeforeBody(); @@ -257,8 +259,8 @@ struct DetensorizeWhileOp : public OpRewritePattern { { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(&newBeforeBlock, newBeforeBlock.begin()); - Value value = rewriter.create( - newWhileOp->getLoc(), + Value value = tensor::FromElementsOp::create( + rewriter, newWhileOp->getLoc(), RankedTensorType::get({}, newBeforeBlock.getArgument(index).getType()), newBeforeBlock.getArgument(index)); rewriter.replaceUsesWithIf( @@ -271,8 +273,8 @@ struct DetensorizeWhileOp : public OpRewritePattern { { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(&newAfterBlock, newAfterBlock.begin()); - Value value = rewriter.create( - newWhileOp->getLoc(), + Value value = tensor::FromElementsOp::create( + rewriter, newWhileOp->getLoc(), RankedTensorType::get({}, newAfterBlock.getArgument(index).getType()), newAfterBlock.getArgument(index)); rewriter.replaceUsesWithIf( @@ -291,8 +293,8 @@ struct DetensorizeWhileOp : public OpRewritePattern { for (const auto &it : llvm::enumerate(condOpArgs)) { Value condOpArg = it.value(); if (isScalarTensor(condOpArg)) { - Value value = rewriter.create(condOp->getLoc(), condOpArg, - ValueRange{}); + Value value = tensor::ExtractOp::create(rewriter, condOp->getLoc(), condOpArg, + ValueRange{}); newCondOpArgs.push_back(value); } else { @@ -310,10 +312,11 @@ struct DetensorizeWhileOp : public OpRewritePattern { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(clonedYieldOp); for (std::size_t index : newIterOperandsIndices) { - newYieldOperands[index] = rewriter.create( - clonedYieldOp->getLoc(), clonedYieldOp->getOperand(index), ValueRange{}); + newYieldOperands[index] = + tensor::ExtractOp::create(rewriter, clonedYieldOp->getLoc(), + clonedYieldOp->getOperand(index), ValueRange{}); } - rewriter.create(newWhileOp.getLoc(), newYieldOperands); + scf::YieldOp::create(rewriter, newWhileOp.getLoc(), newYieldOperands); rewriter.eraseOp(clonedYieldOp); } @@ -323,8 +326,8 @@ struct DetensorizeWhileOp : public OpRewritePattern { rewriter.setInsertionPointAfter(whileOp); for (std::size_t index : newResultIndices) { Value for_result = newWhileOp->getResult(index); - Value value = rewriter.create( - newWhileOp->getLoc(), RankedTensorType::get({}, for_result.getType()), + Value value = tensor::FromElementsOp::create( + rewriter, newWhileOp->getLoc(), RankedTensorType::get({}, for_result.getType()), for_result); rewriter.replaceUsesWithIf(whileOp->getResult(index), value, [&](OpOperand &op) { return !isa(op.getOwner()); diff --git a/mlir/lib/Catalyst/Transforms/QnodeToAsyncPatterns.cpp b/mlir/lib/Catalyst/Transforms/QnodeToAsyncPatterns.cpp index b2621a3816..69e4155ae5 100644 --- a/mlir/lib/Catalyst/Transforms/QnodeToAsyncPatterns.cpp +++ b/mlir/lib/Catalyst/Transforms/QnodeToAsyncPatterns.cpp @@ -56,11 +56,11 @@ struct CallOpToAsyncOPRewritePattern : public mlir::OpRewritePattern(op.getLoc(), executeOp.getResults().front()); + async::AwaitOp::create(rewriter, op.getLoc(), executeOp.getResults().front()); for (auto refCountedValue : executeOp.getResults()) { - rewriter.create(op.getLoc(), refCountedValue, - rewriter.getI64IntegerAttr(1)); + async::RuntimeDropRefOp::create(rewriter, op.getLoc(), refCountedValue, + rewriter.getI64IntegerAttr(1)); } } @@ -149,7 +149,7 @@ struct CallOpToAsyncOPRewritePattern : public mlir::OpRewritePattern(op.getLoc(), newVal); + auto awaitOp = async::AwaitOp::create(rewriter, op.getLoc(), newVal); auto awaitVal = awaitOp.getResults(); rewriter.replaceUsesWithIf(oldVal, awaitVal, [&](OpOperand &use) { // TODO: @@ -192,14 +192,14 @@ struct CallOpToAsyncOPRewritePattern : public mlir::OpRewritePatternsetAttr("transformed", rewriter.getUnitAttr()); }); IRMapping map; - auto executeOp = - rewriter.create(op.getLoc(), retTy, dependencies, operands, noopExec); + auto executeOp = async::ExecuteOp::create(rewriter, op.getLoc(), retTy, dependencies, + operands, noopExec); { PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPoint(executeOp.getBody(), executeOp.getBody()->end()); Operation *cloneOp = op->clone(map); rewriter.insert(cloneOp); - rewriter.create(op.getLoc(), cloneOp->getResults()); + async::YieldOp::create(rewriter, op.getLoc(), cloneOp->getResults()); } insertDropRefOps(op, executeOp, rewriter); diff --git a/mlir/lib/Catalyst/Transforms/RegisterInactiveCallbackPass.cpp b/mlir/lib/Catalyst/Transforms/RegisterInactiveCallbackPass.cpp index 4439b0c1ef..9909c0e3ed 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterInactiveCallbackPass.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterInactiveCallbackPass.cpp @@ -46,16 +46,17 @@ struct RegisterInactiveCallbackPass auto isConstant = false; auto linkage = LLVM::Linkage::External; auto key = catalyst::gradient::enzyme_inactivefn_key; - auto glb = builder.create(loc, arrTy, isConstant, linkage, key, nullptr); + auto glb = LLVM::GlobalOp::create(builder, loc, arrTy, isConstant, linkage, key, nullptr); // Create a block and push it to the global Block *block = new Block(); glb.getInitializerRegion().push_back(block); builder.setInsertionPointToStart(block); - auto undef = builder.create(glb.getLoc(), arrTy); + auto undef = LLVM::UndefOp::create(builder, glb.getLoc(), arrTy); auto fnSym = SymbolRefAttr::get(context, inactive_callbackFnName); - auto fnPtr = builder.create(glb.getLoc(), ptrTy, fnSym); - auto filledInArray = builder.create(glb.getLoc(), undef, fnPtr, 0); - builder.create(glb.getLoc(), filledInArray); + auto fnPtr = LLVM::AddressOfOp::create(builder, glb.getLoc(), ptrTy, fnSym); + auto filledInArray = LLVM::InsertValueOp::create(builder, glb.getLoc(), undef, fnPtr, + SmallVector{0}); + LLVM::ReturnOp::create(builder, glb.getLoc(), filledInArray); } }; diff --git a/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp b/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp index db6d9ad071..783a8a1e04 100644 --- a/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp +++ b/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp @@ -45,13 +45,12 @@ Value getGlobalString(Location loc, OpBuilder &rewriter, StringRef key, StringRe if (!glb) { OpBuilder::InsertionGuard guard(rewriter); // to reset the insertion point rewriter.setInsertionPointToStart(mod.getBody()); - glb = rewriter.create(loc, type, true, LLVM::Linkage::Internal, key, - rewriter.getStringAttr(value)); + glb = LLVM::GlobalOp::create(rewriter, loc, type, true, LLVM::Linkage::Internal, key, + rewriter.getStringAttr(value)); } - return rewriter.create(loc, LLVM::LLVMPointerType::get(rewriter.getContext()), - type, rewriter.create(loc, glb), - ArrayRef{0, 0}, - LLVM::GEPNoWrapFlags::inbounds); + return LLVM::GEPOp::create(rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), + type, LLVM::AddressOfOp::create(rewriter, loc, glb), + ArrayRef{0, 0}, LLVM::GEPNoWrapFlags::inbounds); } enum NumericType : int8_t { @@ -150,22 +149,22 @@ Value EncodeOpaqueMemRef(Location loc, PatternRewriter &rewriter, MemRefType mem // Create values for filling encoded memref struct. Value dtype = - rewriter.create(loc, rewriter.getI8IntegerAttr(elementDtype.value())); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI8IntegerAttr(elementDtype.value())); Value rank = - rewriter.create(loc, rewriter.getI64IntegerAttr(memrefType.getRank())); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(memrefType.getRank())); // Construct encoded memref value. - Value memref = rewriter.create(loc, type); + Value memref = LLVM::UndefOp::create(rewriter, loc, type); // Rank - memref = rewriter.create(loc, memref, rank, 0); + memref = LLVM::InsertValueOp::create(rewriter, loc, memref, rank, SmallVector{0}); // Memref Value memrefPtr = getStaticAlloca(loc, rewriter, llvmMemrefType, 1); - rewriter.create(loc, memrefLlvm, memrefPtr); - memref = rewriter.create(loc, memref, memrefPtr, 1); + LLVM::StoreOp::create(rewriter, loc, memrefLlvm, memrefPtr); + memref = LLVM::InsertValueOp::create(rewriter, loc, memref, memrefPtr, SmallVector{1}); // Dtype - memref = rewriter.create(loc, memref, dtype, 2); + memref = LLVM::InsertValueOp::create(rewriter, loc, memref, dtype, SmallVector{2}); return memref; } @@ -196,7 +195,7 @@ struct PrintOpPattern : public OpConversionPattern { StringRef stringValue = op.getConstVal().value(); std::string symbolName = std::to_string(std::hash()(stringValue.str())); Value global = getGlobalString(loc, rewriter, symbolName, stringValue, mod); - rewriter.create(loc, fnDecl, global); + LLVM::CallOp::create(rewriter, loc, fnDecl, global); rewriter.eraseOp(op); } else { @@ -226,10 +225,10 @@ struct PrintOpPattern : public OpConversionPattern { EncodeOpaqueMemRef(loc, rewriter, memrefType, llvmMemrefType, llvmMemref); Value structPtr = getStaticAlloca(loc, rewriter, structType, 1); - rewriter.create(loc, structValue, structPtr); + LLVM::StoreOp::create(rewriter, loc, structValue, structPtr); - Value printDescriptor = rewriter.create(loc, rewriter.getI1Type(), - op.getPrintDescriptor()); + Value printDescriptor = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), + op.getPrintDescriptor()); SmallVector callArgs{structPtr, printDescriptor}; rewriter.replaceOpWithNewOp(op, fnDecl, callArgs); } @@ -266,7 +265,7 @@ struct AssertionOpPattern : public OpConversionPattern { Value globalString = getGlobalString(loc, rewriter, symbolName, errorMessage, mod); SmallVector callArgs{assertionDescriptor, globalString}; - rewriter.create(loc, assertFunc, callArgs); + LLVM::CallOp::create(rewriter, loc, assertFunc, callArgs); rewriter.eraseOp(op); @@ -296,25 +295,25 @@ Value EncodeDataMemRef(Location loc, PatternRewriter &rewriter, MemRefType memre // Create values for filling encoded memref struct. Value dtype = - rewriter.create(loc, rewriter.getI8IntegerAttr(elementDtype.value())); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI8IntegerAttr(elementDtype.value())); Value rank = - rewriter.create(loc, rewriter.getI64IntegerAttr(memrefType.getRank())); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(memrefType.getRank())); // Construct encoded memref value. - Value memref = rewriter.create(loc, type); + Value memref = LLVM::UndefOp::create(rewriter, loc, type); // Rank - memref = rewriter.create(loc, memref, rank, 0); + memref = LLVM::InsertValueOp::create(rewriter, loc, memref, rank, SmallVector{0}); // Memref data MemRefDescriptor desc = MemRefDescriptor(memrefLlvm); - Value c0 = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); - Value data = rewriter.create(loc, ptr, memrefType.getElementType(), - desc.alignedPtr(rewriter, loc), c0, - LLVM::GEPNoWrapFlags::inbounds); - memref = rewriter.create(loc, memref, data, 1); + Value c0 = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); + Value data = + LLVM::GEPOp::create(rewriter, loc, ptr, memrefType.getElementType(), + desc.alignedPtr(rewriter, loc), c0, LLVM::GEPNoWrapFlags::inbounds); + memref = LLVM::InsertValueOp::create(rewriter, loc, memref, data, SmallVector{1}); // Dtype - memref = rewriter.create(loc, memref, dtype, 2); + memref = LLVM::InsertValueOp::create(rewriter, loc, memref, dtype, SmallVector{2}); return memref; } @@ -363,7 +362,7 @@ struct CustomCallOpPattern : public OpConversionPattern { auto encodedArg = EncodeDataMemRef(loc, rewriter, memref_type, llvmMemrefType, std::get<1>(tuple)); LLVM::AllocaOp alloca = getStaticAlloca(loc, rewriter, encodedArg.getType(), 1); - rewriter.create(loc, encodedArg, alloca); + LLVM::StoreOp::create(rewriter, loc, encodedArg, alloca); encodedArgs.push_back(alloca); } @@ -372,9 +371,9 @@ struct CustomCallOpPattern : public OpConversionPattern { Type typeArgs = LLVM::LLVMArrayType::get(ptr, len); // Prepare an array for encoded arguments. - Value arrArgs = rewriter.create(loc, typeArgs); + Value arrArgs = LLVM::UndefOp::create(rewriter, loc, typeArgs); auto insertValueArgs = [&](Value value, int64_t offset) { - arrArgs = rewriter.create(loc, arrArgs, value, offset); + arrArgs = LLVM::InsertValueOp::create(rewriter, loc, arrArgs, value, offset); }; // Store pointer to encoded arguments into the allocated storage. for (const auto &pair : llvm::enumerate(encodedArgs)) { @@ -385,7 +384,7 @@ struct CustomCallOpPattern : public OpConversionPattern { LLVM::AllocaOp alloca = getStaticAlloca(loc, rewriter, arrArgs.getType(), 1); // Store constructed arguments pointers array into the alloca. - rewriter.create(loc, arrArgs, alloca); + LLVM::StoreOp::create(rewriter, loc, arrArgs, alloca); // Alloca that encodes the custom call arguments. auto encodedArguments = alloca.getResult(); @@ -400,7 +399,7 @@ struct CustomCallOpPattern : public OpConversionPattern { auto encodedRes = EncodeDataMemRef(loc, rewriter, memref_type, llvmMemrefType, std::get<1>(tuple)); LLVM::AllocaOp alloca = getStaticAlloca(loc, rewriter, encodedRes.getType(), 1); - rewriter.create(loc, encodedRes, alloca); + LLVM::StoreOp::create(rewriter, loc, encodedRes, alloca); encodedRess.push_back(alloca); } @@ -409,9 +408,9 @@ struct CustomCallOpPattern : public OpConversionPattern { Type typeRes = LLVM::LLVMArrayType::get(ptr, lenRes); // Prepare an array for encoding results. - Value arrRes = rewriter.create(loc, typeRes); + Value arrRes = LLVM::UndefOp::create(rewriter, loc, typeRes); auto insertValueRes = [&](Value value, int64_t offset) { - arrRes = rewriter.create(loc, arrRes, value, offset); + arrRes = LLVM::InsertValueOp::create(rewriter, loc, arrRes, value, offset); }; // Store encoded results into the allocated storage. @@ -424,13 +423,13 @@ struct CustomCallOpPattern : public OpConversionPattern { LLVM::AllocaOp allocaRes = getStaticAlloca(loc, rewriter, arrRes.getType(), 1); // Store constructed results pointers array on the stack. - rewriter.create(loc, arrRes, allocaRes); + LLVM::StoreOp::create(rewriter, loc, arrRes, allocaRes); // Alloca that encodes the custom call returns. auto encodedResults = allocaRes.getResult(); // Call op SmallVector callArgs{encodedArguments, encodedResults}; - rewriter.create(loc, customCallFnOp, callArgs); + LLVM::CallOp::create(rewriter, loc, customCallFnOp, callArgs); rewriter.eraseOp(op); return success(); } @@ -455,11 +454,11 @@ struct DefineCallbackOpPattern : public OpConversionPattern { auto ctx = rewriter.getContext(); auto loc = op.getLoc(); auto idAttr = op.getIdAttr(); - auto constantId = rewriter.create(loc, idAttr); + auto constantId = LLVM::ConstantOp::create(rewriter, loc, idAttr); auto argcAttr = op.getArgcAttr(); - auto constantArgc = rewriter.create(loc, argcAttr); + auto constantArgc = LLVM::ConstantOp::create(rewriter, loc, argcAttr); auto rescAttr = op.getRescAttr(); - auto constantResc = rewriter.create(loc, rescAttr); + auto constantResc = LLVM::ConstantOp::create(rewriter, loc, rescAttr); SmallVector callArgs = {constantId, constantArgc, constantResc}; @@ -481,13 +480,13 @@ struct DefineCallbackOpPattern : public OpConversionPattern { for (auto arg : op.getArguments()) { Type structTy = typeConverter->convertType(arg.getType()); auto structVal = - rewriter.create(loc, structTy, arg).getResult(0); + UnrealizedConversionCastOp::create(rewriter, loc, structTy, arg).getResult(0); Value ptr = getStaticAlloca(loc, rewriter, structTy, 1); - rewriter.create(loc, structVal, ptr); + LLVM::StoreOp::create(rewriter, loc, structVal, ptr); callArgs.push_back(ptr); } - rewriter.create(loc, customCallFnOp, callArgs); - rewriter.create(loc, TypeRange{}, ValueRange{}); + LLVM::CallOp::create(rewriter, loc, customCallFnOp, callArgs); + func::ReturnOp::create(rewriter, loc, TypeRange{}, ValueRange{}); return success(); } }; @@ -506,8 +505,8 @@ struct ReplaceCallbackOpWithFuncOp : public OpConversionPattern { ModuleOp mod = op->getParentOfType(); rewriter.setInsertionPointToStart(mod.getBody()); - auto func = - rewriter.create(op.getLoc(), op.getSymName(), op.getFunctionType()); + auto func = mlir::func::FuncOp::create(rewriter, op.getLoc(), op.getSymName(), + op.getFunctionType()); func.setPrivate(); auto noinline = rewriter.getStringAttr("noinline"); rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end()); @@ -535,11 +534,11 @@ struct CallbackCallOpPattern : public OpConversionPattern { // allocate a memref descriptor on the stack Value ptr = getStaticAlloca(loc, rewriter, structVal.getType(), 1); // store the memref descriptor on the pointer - rewriter.create(loc, structVal, ptr); + LLVM::StoreOp::create(rewriter, loc, structVal, ptr); // add the ptr to the arguments callArgs.push_back(ptr); } - rewriter.create(loc, adaptor.getCallee(), TypeRange{}, callArgs); + func::CallOp::create(rewriter, loc, adaptor.getCallee(), TypeRange{}, callArgs); rewriter.eraseOp(op); return success(); } diff --git a/mlir/lib/Catalyst/Utils/StaticAllocas.cpp b/mlir/lib/Catalyst/Utils/StaticAllocas.cpp index eab0956637..f17bf02727 100644 --- a/mlir/lib/Catalyst/Utils/StaticAllocas.cpp +++ b/mlir/lib/Catalyst/Utils/StaticAllocas.cpp @@ -29,9 +29,9 @@ LLVM::AllocaOp getStaticAlloca(Location &loc, RewriterBase &rewriter, Type ty, i PatternRewriter::InsertionGuard insertGuard(rewriter); // Move the value at the beginning rewriter.setInsertionPointAfter(&entryBlock->front()); - auto valueOp = rewriter.create(loc, rewriter.getI64IntegerAttr(value)); - return rewriter.create(loc, LLVM::LLVMPointerType::get(rewriter.getContext()), - ty, valueOp); + auto valueOp = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(value)); + return LLVM::AllocaOp::create(rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), + ty, valueOp); } mlir::memref::AllocaOp getStaticMemrefAlloca(Location &loc, RewriterBase &rewriter, @@ -45,7 +45,7 @@ mlir::memref::AllocaOp getStaticMemrefAlloca(Location &loc, RewriterBase &rewrit if (insertionBlock != entryBlock) { rewriter.setInsertionPoint(&entryBlock->front()); } - return rewriter.create(loc, paramCountType); + return memref::AllocaOp::create(rewriter, loc, paramCountType); } } // namespace catalyst diff --git a/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp index 4397aafab2..298ed740c8 100644 --- a/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp @@ -84,13 +84,13 @@ Value generateAllocation(OpBuilder &builder, Location loc, Value reference) if (!memrefType.hasStaticShape()) { for (int64_t dim = 0; dim < memrefType.getRank(); dim++) { if (memrefType.isDynamicDim(dim)) { - Value dimIndex = builder.create(loc, dim); - dynamicDims.push_back(builder.create(loc, reference, dimIndex)); + Value dimIndex = index::ConstantOp::create(builder, loc, dim); + dynamicDims.push_back(memref::DimOp::create(builder, loc, reference, dimIndex)); } } } - return builder.create(loc, memrefType, dynamicDims); + return memref::AllocOp::create(builder, loc, memrefType, dynamicDims); } // Helper function to generate a list of memref allocations. @@ -131,10 +131,11 @@ getBufferizedFunctionArgType(FunctionOpInterface funcOp, int64_t index, auto tensorType = dyn_cast(funcOp.getArgumentTypes()[index]); assert(tensorType && "expected TensorType"); - BaseMemRefType memrefType = options.functionArgTypeConverterFn( - tensorType, *options.defaultMemorySpaceFn(tensorType), nullptr, options); + auto tensorLikeType = dyn_cast(tensorType); + bufferization::BufferLikeType memrefType = options.functionArgTypeConverterFn( + tensorLikeType, *options.defaultMemorySpaceFn(tensorType), nullptr, options); - return cast(memrefType); + return memrefType; } static ReturnOp getAssumedUniqueReturnOp(FunctionOpInterface funcOp) @@ -195,7 +196,7 @@ struct AdjointOpInterface for (const auto &[i, resType] : llvm::enumerate(resTypes)) { if (isa(resType)) { MemRefType memrefType = cast(resType); - Value memrefValue = rewriter.create(loc, memrefType, gradSize); + Value memrefValue = memref::AllocOp::create(rewriter, loc, memrefType, gradSize); memrefValues.push_back(memrefValue); } else { @@ -220,8 +221,8 @@ struct AdjointOpInterface } auto newAdjointOp = - rewriter.create(loc, nonTensorResultTypes, adjointOp.getCalleeAttr(), - adjointOp.getGradSize(), bufferArgs, memrefValues); + AdjointOp::create(rewriter, loc, nonTensorResultTypes, adjointOp.getCalleeAttr(), + adjointOp.getGradSize(), bufferArgs, memrefValues); SmallVector bufferdNewValues; size_t nonTensorResultCounter = 0; size_t tensorResultCounter = 0; @@ -347,9 +348,9 @@ struct BackpropOpInterface // 4. Create bufferized backprop op DenseIntElementsAttr diffArgIndicesAttr = backpropOp.getDiffArgIndices().value_or(nullptr); - auto bufferizedBackpropOp = rewriter.create( - loc, TypeRange{}, scalarReturnTypes, backpropOp.getCalleeAttr(), bufferArgs, argShadows, - calleeResults, bufferCotangents, diffArgIndicesAttr, + auto bufferizedBackpropOp = BackpropOp::create( + rewriter, loc, TypeRange{}, scalarReturnTypes, backpropOp.getCalleeAttr(), bufferArgs, + argShadows, calleeResults, bufferCotangents, diffArgIndicesAttr, backpropOp.getKeepValueResultsAttr()); // Fill in the null placeholders. for (const auto &[idx, scalarResult] : @@ -479,7 +480,7 @@ struct ForwardOpInterface options.unknownTypeConverterFn(cast(returnVal.getType()), *options.defaultMemorySpaceFn(tensorType), options); Value toBufferOp = - rewriter.create(loc, resultType, returnVal); + bufferization::ToBufferOp::create(rewriter, loc, resultType, returnVal); returnValues.push_back(toBufferOp); } @@ -588,7 +589,7 @@ struct ReverseOpInterface options.unknownTypeConverterFn(cast(returnVal.getType()), *options.defaultMemorySpaceFn(tensorType), options); Value toBufferOp = - rewriter.create(loc, resultType, returnVal); + bufferization::ToBufferOp::create(rewriter, loc, resultType, returnVal); returnValues.push_back(toBufferOp); } diff --git a/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp b/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp index 6cc86a98a1..c943821bd3 100644 --- a/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp +++ b/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp @@ -71,13 +71,14 @@ LogicalResult wrapMemRefArgsFunc(func::FuncOp func, const TypeConverter *typeCon // used. Otherwise, LLVM may optimize it away with a poison value. // Note: both the volatile load and store are necessary. MLIR respects the store // but not the load, while LLVM respects the volatile load but not the store. - Value replacedMemref = rewriter.create( - loc, structType, wrappedMemref, /*alignment*/ 0, /*isVolatile=*/volatileArgs); + Value replacedMemref = + LLVM::LoadOp::create(rewriter, loc, structType, wrappedMemref, /*alignment*/ 0, + /*isVolatile=*/volatileArgs); if (volatileArgs) { - rewriter.create(loc, replacedMemref, wrappedMemref); + LLVM::StoreOp::create(rewriter, loc, replacedMemref, wrappedMemref); } replacedMemref = - rewriter.create(loc, argType, replacedMemref) + UnrealizedConversionCastOp::create(rewriter, loc, argType, replacedMemref) .getResult(0); memrefArg.replaceAllUsesWith(replacedMemref); if (failed(func.eraseArgument(memrefArg.getArgNumber()))) { @@ -106,9 +107,9 @@ void wrapMemRefArgsCallsites(func::FuncOp func, const TypeConverter *typeConvert Value space = getStaticAlloca(loc, rewriter, convertedType, 1); Value convertedValue = - rewriter.create(loc, convertedType, memref) + UnrealizedConversionCastOp::create(rewriter, loc, convertedType, memref) .getResult(0); - rewriter.create(loc, convertedValue, space); + LLVM::StoreOp::create(rewriter, loc, convertedValue, space); return space; }; for (Value oldOperand : callOp.getOperands()) { @@ -119,7 +120,7 @@ void wrapMemRefArgsCallsites(func::FuncOp func, const TypeConverter *typeConvert for (Type resultType : callOp.getResultTypes()) { if (auto memrefType = dyn_cast(resultType)) { assert(memrefType.hasStaticShape()); - Value memref = rewriter.create(loc, memrefType); + Value memref = memref::AllocOp::create(rewriter, loc, memrefType); outputs.push_back(memref); memref = wrapMemref(memref); @@ -127,7 +128,7 @@ void wrapMemRefArgsCallsites(func::FuncOp func, const TypeConverter *typeConvert } } - rewriter.create(callOp.getLoc(), func, operands); + func::CallOp::create(rewriter, callOp.getLoc(), func, operands); rewriter.replaceOp(callOp, outputs); } } @@ -150,24 +151,25 @@ LLVM::GlobalOp insertEnzymeCustomGradient(OpBuilder &builder, ModuleOp moduleOp, auto ptrType = LLVM::LLVMPointerType::get(context); auto resultType = LLVM::LLVMArrayType::get(ptrType, 3); - customGradient = builder.create(loc, resultType, - /*isConstant=*/false, LLVM::Linkage::External, - key, /*address space=*/nullptr); + customGradient = LLVM::GlobalOp::create(builder, loc, resultType, + /*isConstant=*/false, LLVM::Linkage::External, key, + /*address space=*/nullptr); builder.createBlock(&customGradient.getInitializerRegion()); - Value origFnPtr = builder.create(loc, originalFunc.getFunctionType(), - originalFunc.getName()); - Value augFnPtr = builder.create(loc, augmentedPrimal.getFunctionType(), - augmentedPrimal.getName()); + Value origFnPtr = func::ConstantOp::create(builder, loc, originalFunc.getFunctionType(), + originalFunc.getName()); + Value augFnPtr = func::ConstantOp::create(builder, loc, augmentedPrimal.getFunctionType(), + augmentedPrimal.getName()); Value gradFnPtr = - builder.create(loc, gradient.getFunctionType(), gradient.getName()); + func::ConstantOp::create(builder, loc, gradient.getFunctionType(), gradient.getName()); SmallVector fnPtrs{origFnPtr, augFnPtr, gradFnPtr}; - Value result = builder.create(loc, resultType); + Value result = LLVM::UndefOp::create(builder, loc, resultType); for (const auto &[idx, fnPtr] : llvm::enumerate(fnPtrs)) { - Value casted = builder.create(loc, ptrType, fnPtr).getResult(0); - result = builder.create(loc, result, casted, idx); + Value casted = + UnrealizedConversionCastOp::create(builder, loc, ptrType, fnPtr).getResult(0); + result = LLVM::InsertValueOp::create(builder, loc, result, casted, idx); } - builder.create(loc, result); + LLVM::ReturnOp::create(builder, loc, result); return customGradient; } @@ -237,33 +239,33 @@ struct AdjointOpPattern : public ConvertOpToLLVMPattern { rewriter, op, gradFnName, gradFnSignature); // Run the forward pass and cache the circuit. - Value c_true = rewriter.create( - loc, rewriter.getIntegerAttr(IntegerType::get(ctx, 1), 1)); - Value c_false = rewriter.create( - loc, rewriter.getIntegerAttr(IntegerType::get(ctx, 1), 0)); - rewriter.create(loc, cacheFnDecl, c_true); - Value qreg = rewriter.create(loc, callee, op.getArgs()).getResult(0); + Value c_true = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(IntegerType::get(ctx, 1), 1)); + Value c_false = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(IntegerType::get(ctx, 1), 0)); + LLVM::CallOp::create(rewriter, loc, cacheFnDecl, c_true); + Value qreg = func::CallOp::create(rewriter, loc, callee, op.getArgs()).getResult(0); if (!isa(qreg.getType())) return callee.emitOpError("qfunc must return quantum register"); - rewriter.create(loc, cacheFnDecl, c_false); + LLVM::CallOp::create(rewriter, loc, cacheFnDecl, c_false); // We follow the C ABI convention of passing result memrefs as struct pointers in the // arguments to the C function, although in this case as a variadic argument list to allow // for a varying number of results in a single signature. - Value c1 = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - Value numResults = rewriter.create( - loc, rewriter.getI64IntegerAttr(op.getDataIn().size())); + Value c1 = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value numResults = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(op.getDataIn().size())); SmallVector args = {numResults}; for (Value memref : adaptor.getDataIn()) { - Value newArg = rewriter.create( - loc, LLVM::LLVMPointerType::get(rewriter.getContext()), vectorType, c1); - rewriter.create(loc, memref, newArg); + Value newArg = LLVM::AllocaOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), vectorType, c1); + LLVM::StoreOp::create(rewriter, loc, memref, newArg); args.push_back(newArg); } - rewriter.create(loc, gradFnDecl, args); - rewriter.create(loc, qreg); - rewriter.create(loc); + LLVM::CallOp::create(rewriter, loc, gradFnDecl, args); + catalyst::quantum::DeallocOp::create(rewriter, loc, qreg); + catalyst::quantum::DeviceReleaseOp::create(rewriter, loc); rewriter.eraseOp(op); @@ -363,7 +365,7 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { // The first argument to Enzyme is a function pointer of the function to be differentiated Value calleePtr = - rewriter.create(loc, callee.getFunctionType(), callee.getName()); + func::ConstantOp::create(rewriter, loc, callee.getFunctionType(), callee.getName()); calleePtr = castToConvertedType(calleePtr, rewriter, loc); SmallVector callArgs = {calleePtr}; @@ -372,8 +374,8 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { insertGlobalSymbol(rewriter, moduleOp, enzyme_dupnoneed_key, std::nullopt); ValueRange argShadows = adaptor.getDiffArgShadows(); - Value enzymeConst = rewriter.create(loc, LLVM::LLVMPointerType::get(ctx), - enzyme_const_key); + Value enzymeConst = LLVM::AddressOfOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(ctx), enzyme_const_key); ValueRange convArgs = adaptor.getArgs(); // Add the arguments and the argument shadows of memrefs for (auto [index, arg] : llvm::enumerate(op.getArgs())) { @@ -417,7 +419,7 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { // The results of backprop are in argShadows, except scalar derivatives which are in the // results of the enzyme call. - auto enzymeCall = rewriter.create(loc, backpropFnDecl, callArgs); + auto enzymeCall = LLVM::CallOp::create(rewriter, loc, backpropFnDecl, callArgs); SmallVector scalarResults; unpackScalarResults(enzymeCall, scalarResults, rewriter, loc); rewriter.replaceOp(op, scalarResults); @@ -427,8 +429,8 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { private: Value castToConvertedType(Value value, OpBuilder &builder, Location loc) const { - auto casted = builder.create( - loc, getTypeConverter()->convertType(value.getType()), value); + auto casted = UnrealizedConversionCastOp::create( + builder, loc, getTypeConverter()->convertType(value.getType()), value); return casted.getResult(0); } @@ -439,9 +441,9 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { { auto llvmPtrType = LLVM::LLVMPointerType::get(builder.getContext()); auto memRefType = cast(memRefArg.getType()); - Value enzymeConst = builder.create(loc, llvmPtrType, enzyme_const_key); + Value enzymeConst = LLVM::AddressOfOp::create(builder, loc, llvmPtrType, enzyme_const_key); Value enzymeDupNoNeed = - builder.create(loc, llvmPtrType, enzyme_dupnoneed_key); + LLVM::AddressOfOp::create(builder, loc, llvmPtrType, enzyme_dupnoneed_key); Value argStruct = castToConvertedType(memRefArg, builder, loc); MemRefDescriptor desc(argStruct); @@ -462,9 +464,9 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { if (options.zeroOut) { Value bufferSizeBytes = computeMemRefSizeInBytes(memRefType, shadowDesc, builder, loc); - Value zero = builder.create(loc, builder.getI8Type(), 0); - builder.create(loc, shadowPtr, zero, bufferSizeBytes, - /*isVolatile=*/false); + Value zero = LLVM::ConstantOp::create(builder, loc, builder.getI8Type(), 0); + LLVM::MemsetOp::create(builder, loc, shadowPtr, zero, bufferSizeBytes, + /*isVolatile=*/false); } callArgs.push_back(shadowPtr); } @@ -513,7 +515,7 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { if (auto structType = dyn_cast(result.getType())) { size_t numResults = structType.getBody().size(); for (size_t i = 0; i < numResults; i++) { - results.push_back(builder.create(loc, result, i)); + results.push_back(LLVM::ExtractValueOp::create(builder, loc, result, i)); } } } @@ -528,17 +530,17 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { Value bufferSize; Type indexType = getTypeConverter()->getIndexType(); if (type.getRank() == 0) { - bufferSize = builder.create(loc, indexType, builder.getIndexAttr(1)); + bufferSize = LLVM::ConstantOp::create(builder, loc, indexType, builder.getIndexAttr(1)); } else { - bufferSize = builder.create(loc, descriptor.size(builder, loc, 0), - descriptor.stride(builder, loc, 0)); + bufferSize = LLVM::MulOp::create(builder, loc, descriptor.size(builder, loc, 0), + descriptor.stride(builder, loc, 0)); bufferSize = - builder.create(loc, descriptor.offset(builder, loc), bufferSize); + LLVM::AddOp::create(builder, loc, descriptor.offset(builder, loc), bufferSize); } - Value elementByteSize = builder.create( - loc, indexType, builder.getIndexAttr(type.getElementTypeBitWidth() / 8)); - Value bufferSizeBytes = builder.create(loc, elementByteSize, bufferSize); + Value elementByteSize = LLVM::ConstantOp::create( + builder, loc, indexType, builder.getIndexAttr(type.getElementTypeBitWidth() / 8)); + Value bufferSizeBytes = LLVM::MulOp::create(builder, loc, elementByteSize, bufferSize); return bufferSizeBytes; } @@ -575,8 +577,8 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { auto tapeType = LLVM::LLVMPointerType::get(ctx); SmallVector argTypes; convertCustomGradArgumentTypes(qnode.getArgumentTypes(), argTypes); - augmentedForward = builder.create( - qnode.getLoc(), augmentedName, FunctionType::get(ctx, argTypes, {tapeType})); + augmentedForward = func::FuncOp::create(builder, qnode.getLoc(), augmentedName, + FunctionType::get(ctx, argTypes, {tapeType})); augmentedForward.setPrivate(); Location loc = qnode.getLoc(); @@ -589,9 +591,9 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { arguments.push_back(augmentedForward.getArgument(i * 2)); } - builder.create(loc, qnode, arguments); - Value tape = builder.create(loc, tapeType); - builder.create(loc, tape); + func::CallOp::create(builder, loc, qnode, arguments); + Value tape = LLVM::ZeroOp::create(builder, loc, tapeType); + func::ReturnOp::create(builder, loc, tape); return augmentedForward; } @@ -626,7 +628,7 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { OpBuilder::InsertionGuard insertionGuard(builder); builder.setInsertionPoint(qnode); auto funcType = FunctionType::get(ctx, argTypes, {}); - customQGrad = builder.create(qnode.getLoc(), customQGradName, funcType); + customQGrad = func::FuncOp::create(builder, qnode.getLoc(), customQGradName, funcType); customQGrad.setPrivate(); Block *block = customQGrad.addEntryBlock(); builder.setInsertionPointToStart(block); @@ -643,8 +645,8 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { auto unwrapMemRef = [&](Value wrapped, Type unwrappedType) { auto structType = getTypeConverter()->convertType(unwrappedType); - Value unwrapped = builder.create(loc, structType, wrapped); - unwrapped = builder.create(loc, unwrappedType, unwrapped) + Value unwrapped = LLVM::LoadOp::create(builder, loc, structType, wrapped); + unwrapped = UnrealizedConversionCastOp::create(builder, loc, unwrappedType, unwrapped) .getResult(0); return unwrapped; }; @@ -666,10 +668,10 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { Value gateParamShadow = unwrappedShadows.back(); assert(cast(gateParamShadow.getType()).getRank() == 1 && "Expected gate parameter list to be a rank-1 memref"); - Value pcount = builder.create(loc, gateParamShadow, 0); + Value pcount = memref::DimOp::create(builder, loc, gateParamShadow, 0); primalInputs.push_back(pcount); - auto qgrad = builder.create(loc, qgradFn, primalInputs); + auto qgrad = func::CallOp::create(builder, loc, qgradFn, primalInputs); for (unsigned i = 0; i < qnodeType.getNumResults(); i++) { // The QNode has n inputs and m outputs (in destination-passing style). @@ -706,7 +708,7 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { catalyst::einsumLinalgGeneric(builder, loc, resultDims, qgradDims, {0}, resultShadow, qgrad.getResult(i), gateParamShadow); } - builder.create(loc); + func::ReturnOp::create(builder, loc); return customQGrad; } @@ -733,28 +735,30 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { auto ptrType = LLVM::LLVMPointerType::get(context); auto resultType = LLVM::LLVMArrayType::get(ptrType, 4); - builder.create(loc, LLVM::LLVMArrayType::get(builder.getI8Type(), 3), true, - LLVM::Linkage::Linkonce, "dealloc_indices", - builder.getStringAttr(StringRef("-1", 3))); - allocationLike = - builder.create(loc, resultType, - /*isConstant=*/false, LLVM::Linkage::External, - enzyme_allocation_key, /*address space=*/nullptr); + LLVM::GlobalOp::create(builder, loc, LLVM::LLVMArrayType::get(builder.getI8Type(), 3), true, + LLVM::Linkage::Linkonce, "dealloc_indices", + builder.getStringAttr(StringRef("-1", 3))); + allocationLike = LLVM::GlobalOp::create(builder, loc, resultType, + /*isConstant=*/false, LLVM::Linkage::External, + enzyme_allocation_key, /*address space=*/nullptr); builder.createBlock(&allocationLike.getInitializerRegion()); - Value allocFn = builder.create(loc, ptrType, allocFuncName); + Value allocFn = LLVM::AddressOfOp::create(builder, loc, ptrType, allocFuncName); Value sizeArgIndex = - builder.create(loc, indexType, builder.getIndexAttr(0)); - Value sizeArgIndexPtr = builder.create(loc, ptrType, sizeArgIndex); + LLVM::ConstantOp::create(builder, loc, indexType, builder.getIndexAttr(0)); + Value sizeArgIndexPtr = LLVM::IntToPtrOp::create(builder, loc, ptrType, sizeArgIndex); Value deallocIndicesPtr = - builder.create(loc, ptrType, "dealloc_indices"); - Value freeFn = builder.create(loc, ptrType, freeFuncName); - - Value result = builder.create(loc, resultType); - result = builder.create(loc, result, allocFn, 0); - result = builder.create(loc, result, sizeArgIndexPtr, 1); - result = builder.create(loc, result, deallocIndicesPtr, 2); - result = builder.create(loc, result, freeFn, 3); - builder.create(loc, result); + LLVM::AddressOfOp::create(builder, loc, ptrType, "dealloc_indices"); + Value freeFn = LLVM::AddressOfOp::create(builder, loc, ptrType, freeFuncName); + + Value result = LLVM::UndefOp::create(builder, loc, resultType); + result = + LLVM::InsertValueOp::create(builder, loc, result, allocFn, SmallVector{0}); + result = LLVM::InsertValueOp::create(builder, loc, result, sizeArgIndexPtr, + SmallVector{1}); + result = LLVM::InsertValueOp::create(builder, loc, result, deallocIndicesPtr, + SmallVector{2}); + result = LLVM::InsertValueOp::create(builder, loc, result, freeFn, SmallVector{3}); + LLVM::ReturnOp::create(builder, loc, result); return allocationLike; } @@ -774,13 +778,13 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { auto shortTy = IntegerType::get(context, 8); if (!glb) { if (!value) { - rewriter.create(op.getLoc(), shortTy, - /*isConstant=*/true, LLVM::Linkage::Linkonce, key, - IntegerAttr::get(shortTy, 0)); + LLVM::GlobalOp::create(rewriter, op.getLoc(), shortTy, + /*isConstant=*/true, LLVM::Linkage::Linkonce, key, + IntegerAttr::get(shortTy, 0)); } else { - rewriter.create( - op.getLoc(), LLVM::LLVMArrayType::get(shortTy, value->size()), true, + LLVM::GlobalOp::create( + rewriter, op.getLoc(), LLVM::LLVMArrayType::get(shortTy, value->size()), true, LLVM::Linkage::Linkonce, key, rewriter.getStringAttr(*value)); } } @@ -802,9 +806,8 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { if (glb) { return glb; } - glb = rewriter.create(op.getLoc(), LLVM::LLVMArrayType::get(ptrType, 2), - /*isConstant=*/false, LLVM::Linkage::External, key, - nullptr); + glb = LLVM::GlobalOp::create(rewriter, op.getLoc(), LLVM::LLVMArrayType::get(ptrType, 2), + /*isConstant=*/false, LLVM::Linkage::External, key, nullptr); // Create the block and push it back in the global auto *contextGlb = glb.getContext(); @@ -817,19 +820,19 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { // Get original global name auto originalNameRefAttr = SymbolRefAttr::get(contextGlb, originalName); auto originalGlobal = - rewriter.create(glb.getLoc(), llvmPtr, originalNameRefAttr); + LLVM::AddressOfOp::create(rewriter, glb.getLoc(), llvmPtr, originalNameRefAttr); // Get global name auto nameRefAttr = SymbolRefAttr::get(contextGlb, name); - auto enzymeGlobal = rewriter.create(glb.getLoc(), llvmPtr, nameRefAttr); + auto enzymeGlobal = LLVM::AddressOfOp::create(rewriter, glb.getLoc(), llvmPtr, nameRefAttr); auto undefArray = - rewriter.create(glb.getLoc(), LLVM::LLVMArrayType::get(ptrType, 2)); - Value llvmInsert0 = - rewriter.create(glb.getLoc(), undefArray, originalGlobal, 0); - Value llvmInsert1 = - rewriter.create(glb.getLoc(), llvmInsert0, enzymeGlobal, 1); - rewriter.create(glb.getLoc(), llvmInsert1); + LLVM::UndefOp::create(rewriter, glb.getLoc(), LLVM::LLVMArrayType::get(ptrType, 2)); + Value llvmInsert0 = LLVM::InsertValueOp::create(rewriter, glb.getLoc(), undefArray, + originalGlobal, SmallVector{0}); + Value llvmInsert1 = LLVM::InsertValueOp::create(rewriter, glb.getLoc(), llvmInsert0, + enzymeGlobal, SmallVector{1}); + LLVM::ReturnOp::create(rewriter, glb.getLoc(), llvmInsert1); return glb; } }; @@ -866,7 +869,7 @@ struct ForwardOpPattern : public ConvertOpToLLVMPattern { auto oldFuncTy = op.getFunctionType(); auto funcTy = FunctionType::get(ctx, oldFuncTy.getInputs(), {retTy}); - auto func = rewriter.create(op.getLoc(), op.getSymName(), funcTy); + auto func = func::FuncOp::create(rewriter, op.getLoc(), op.getSymName(), funcTy); func.setPrivate(); auto noinline = rewriter.getStringAttr("noinline"); @@ -941,7 +944,7 @@ struct ReverseOpPattern : public ConvertOpToLLVMPattern { PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(mod.getBody()); - auto func = rewriter.create(op.getLoc(), op.getSymName(), newFuncTy); + auto func = func::FuncOp::create(rewriter, op.getLoc(), op.getSymName(), newFuncTy); func.setPrivate(); Block *entry = func.addEntryBlock(); @@ -958,14 +961,14 @@ struct ReverseOpPattern : public ConvertOpToLLVMPattern { for (size_t i = 0; i < tapeCount; i++) { SmallVector pos = {0, static_cast(i)}; Value tapeStructIth = - rewriter.create(loc, wrappedStructValAgg, pos); + LLVM::ExtractValueOp::create(rewriter, loc, wrappedStructValAgg, pos); tapestructs.push_back(tapeStructIth); } SmallVector tapememrefs; for (auto [_struct, memref] : llvm::zip(tapestructs, tapeElements)) { auto memrefTy = memref.getType(); - auto castOp = rewriter.create(loc, memrefTy, _struct); + auto castOp = UnrealizedConversionCastOp::create(rewriter, loc, memrefTy, _struct); tapememrefs.push_back(castOp.getResult(0)); } @@ -1032,7 +1035,7 @@ struct ReturnOpPattern : public ConvertOpToLLVMPattern { { auto loc = op.getLoc(); if (op.getEmpty()) { - auto returnOp = rewriter.create(loc, ValueRange{}); + auto returnOp = LLVM::ReturnOp::create(rewriter, loc, ValueRange{}); rewriter.replaceOp(op, returnOp); return success(); } @@ -1042,8 +1045,8 @@ struct ReturnOpPattern : public ConvertOpToLLVMPattern { if (tape.empty()) { // Just return an empty pointer auto ptrType = LLVM::LLVMPointerType::get(ctx); - Value nullPtr = rewriter.create(loc, ptrType); - auto returnOp = rewriter.create(loc, nullPtr); + Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, ptrType); + auto returnOp = LLVM::ReturnOp::create(rewriter, loc, nullPtr); rewriter.replaceOp(op, returnOp); return success(); } @@ -1056,13 +1059,13 @@ struct ReturnOpPattern : public ConvertOpToLLVMPattern { // This gives me { { memref, ..., memrefn } } auto wrappedFlatTapeStructTy = LLVM::LLVMStructType::getLiteral(ctx, {tapeTy}); - Value result = rewriter.create(loc, wrappedFlatTapeStructTy); + Value result = LLVM::UndefOp::create(rewriter, loc, wrappedFlatTapeStructTy); for (auto [idx, elem] : llvm::enumerate(tapeStructVals)) { SmallVector pos = {0, static_cast(idx)}; - result = rewriter.create(loc, result, elem, pos); + result = LLVM::InsertValueOp::create(rewriter, loc, result, elem, pos); } - auto returnOp = rewriter.create(loc, result); + auto returnOp = LLVM::ReturnOp::create(rewriter, loc, result); rewriter.replaceOp(op, returnOp); return success(); } diff --git a/mlir/lib/Gradient/Transforms/GradMethods/Adjoint.cpp b/mlir/lib/Gradient/Transforms/GradMethods/Adjoint.cpp index 7ab95d8705..307b982dde 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/Adjoint.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/Adjoint.cpp @@ -105,7 +105,7 @@ func::FuncOp AdjointLowering::discardAndReturnReg(PatternRewriter &rewriter, Loc PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointAfter(callee); unallocFn = - rewriter.create(loc, fnName, fnType, visibility, nullptr, nullptr); + func::FuncOp::create(rewriter, loc, fnName, fnType, visibility, nullptr, nullptr); // Clone the body. IRMapping mapper; @@ -153,14 +153,14 @@ func::FuncOp AdjointLowering::genQGradFunction(PatternRewriter &rewriter, Locati PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointAfter(callee); - qGradFn = rewriter.create(loc, fnName, fnType, visibility, nullptr, nullptr); + qGradFn = func::FuncOp::create(rewriter, loc, fnName, fnType, visibility, nullptr, nullptr); rewriter.setInsertionPointToStart(qGradFn.addEntryBlock()); - AdjointOp qGradOp = rewriter.create( - loc, computeQGradTypes(callee), SymbolRefAttr::get(unallocFn), + AdjointOp qGradOp = AdjointOp::create( + rewriter, loc, computeQGradTypes(callee), SymbolRefAttr::get(unallocFn), qGradFn.getArguments().back(), qGradFn.getArguments().drop_back(), ValueRange{}); - rewriter.create(loc, qGradOp.getResults()); + func::ReturnOp::create(rewriter, loc, qGradOp.getResults()); } return qGradFn; diff --git a/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp b/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp index 5c24252a1e..99ea3ec2c1 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp @@ -58,7 +58,7 @@ func::FuncOp genParamCountFunction(PatternRewriter &rewriter, Location loc, func // their gate parameters instead. rewriter.setInsertionPointAfter(callee); paramCountFn = - rewriter.create(loc, fnName, fnType, visibility, nullptr, nullptr); + func::FuncOp::create(rewriter, loc, fnName, fnType, visibility, nullptr, nullptr); rewriter.cloneRegionBefore(callee.getBody(), paramCountFn.getBody(), paramCountFn.end()); PatternRewriter::InsertionGuard insertGuard(rewriter); @@ -68,8 +68,8 @@ func::FuncOp genParamCountFunction(PatternRewriter &rewriter, Location loc, func // for updated parameter counts from arbitrary regions/ops. MemRefType paramCountType = MemRefType::get({}, rewriter.getIndexType()); Value paramCountBuffer = getStaticMemrefAlloca(loc, rewriter, paramCountType); - Value cZero = rewriter.create(loc, 0); - rewriter.create(loc, cZero, paramCountBuffer); + Value cZero = index::ConstantOp::create(rewriter, loc, 0); + memref::StoreOp::create(rewriter, loc, cZero, paramCountBuffer); paramCountFn.walk([&](Operation *op) { // For each quantum gate add the number of parameters to the counter. @@ -79,10 +79,10 @@ func::FuncOp genParamCountFunction(PatternRewriter &rewriter, Location loc, func ValueRange diffParams = gate.getDiffParams(); if (!diffParams.empty()) { - Value currCount = rewriter.create(loc, paramCountBuffer); - Value numParams = rewriter.create(loc, diffParams.size()); - Value newCount = rewriter.create(loc, currCount, numParams); - rewriter.create(loc, newCount, paramCountBuffer); + Value currCount = memref::LoadOp::create(rewriter, loc, paramCountBuffer); + Value numParams = index::ConstantOp::create(rewriter, loc, diffParams.size()); + Value newCount = index::AddOp::create(rewriter, loc, currCount, numParams); + memref::StoreOp::create(rewriter, loc, newCount, paramCountBuffer); } rewriter.replaceOp(gate, gate.getQubitOperands()); @@ -110,7 +110,7 @@ func::FuncOp genParamCountFunction(PatternRewriter &rewriter, Location loc, func PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPoint(op); - Value paramCount = rewriter.create(loc, paramCountBuffer); + Value paramCount = memref::LoadOp::create(rewriter, loc, paramCountBuffer); op->setOperands(paramCount); } }); @@ -139,7 +139,7 @@ func::FuncOp genSplitPreprocessed(PatternRewriter &rewriter, Location loc, func: // First copy the original function as is, then we can replace all quantum ops by collecting // their gate parameters in a memory buffer instead. This buffer is passed into a modified // qnodeQuantum. - splitFn = rewriter.create(loc, fnName, fnType, visibility, nullptr, nullptr); + splitFn = func::FuncOp::create(rewriter, loc, fnName, fnType, visibility, nullptr, nullptr); rewriter.cloneRegionBefore(qnode.getBody(), splitFn.getBody(), splitFn.end()); Block &argMapBlock = splitFn.getFunctionBody().front(); SmallVector qnodeQuantumArgs{argMapBlock.getArguments()}; @@ -147,16 +147,17 @@ func::FuncOp genSplitPreprocessed(PatternRewriter &rewriter, Location loc, func: Value paramCount = argMapBlock.addArgument(rewriter.getIndexType(), loc); PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(&splitFn.getBody().front()); - Value paramsBuffer = rewriter.create(loc, paramsBufferType, paramCount); - Value paramsTensor = rewriter.create( - loc, memref::getTensorTypeFromMemRefType(paramsBuffer.getType()), paramsBuffer, true); + Value paramsBuffer = memref::AllocOp::create(rewriter, loc, paramsBufferType, paramCount); + Value paramsTensor = bufferization::ToTensorOp::create( + rewriter, loc, memref::getTensorTypeFromMemRefType(paramsBuffer.getType()), + paramsBuffer, true); qnodeQuantumArgs.push_back(paramsTensor); MemRefType paramsProcessedType = MemRefType::get({}, rewriter.getIndexType()); Value paramsProcessed = getStaticMemrefAlloca(loc, rewriter, paramsProcessedType); - Value cZero = rewriter.create(loc, 0); - rewriter.create(loc, cZero, paramsProcessed); - Value cOne = rewriter.create(loc, 1); + Value cZero = index::ConstantOp::create(rewriter, loc, 0); + memref::StoreOp::create(rewriter, loc, cZero, paramsProcessed); + Value cOne = index::ConstantOp::create(rewriter, loc, 1); splitFn.walk([&](Operation *op) { // Insert gate parameters into the params buffer. @@ -166,12 +167,12 @@ func::FuncOp genSplitPreprocessed(PatternRewriter &rewriter, Location loc, func: ValueRange diffParams = gate.getDiffParams(); if (!diffParams.empty()) { - Value paramIdx = rewriter.create(loc, paramsProcessed); + Value paramIdx = memref::LoadOp::create(rewriter, loc, paramsProcessed); for (auto param : diffParams) { - rewriter.create(loc, param, paramsBuffer, paramIdx); - paramIdx = rewriter.create(loc, paramIdx, cOne); + memref::StoreOp::create(rewriter, loc, param, paramsBuffer, paramIdx); + paramIdx = index::AddOp::create(rewriter, loc, paramIdx, cOne); } - rewriter.create(loc, paramIdx, paramsProcessed); + memref::StoreOp::create(rewriter, loc, paramIdx, paramsProcessed); } rewriter.replaceOp(op, gate.getQubitOperands()); @@ -199,7 +200,7 @@ func::FuncOp genSplitPreprocessed(PatternRewriter &rewriter, Location loc, func: PatternRewriter::InsertionGuard insertionGuard(rewriter); rewriter.setInsertionPoint(returnOp); auto modifiedCall = - rewriter.create(loc, qnodeQuantum, qnodeQuantumArgs); + func::CallOp::create(rewriter, loc, qnodeQuantum, qnodeQuantumArgs); returnOp.getOperandsMutable().assign(modifiedCall.getResults()); } diff --git a/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp b/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp index 853ecd1e82..aa86af7fa6 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp @@ -56,7 +56,7 @@ LogicalResult FiniteDiffLowering::matchAndRewrite(GradOp op, PatternRewriter &re PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointAfter(callee); - gradFn = rewriter.create(loc, fnName, fnType, visibility, nullptr, nullptr); + gradFn = func::FuncOp::create(rewriter, loc, fnName, fnType, visibility, nullptr, nullptr); rewriter.setInsertionPointToStart(gradFn.addEntryBlock()); computeFiniteDiff(rewriter, loc, gradFn, callee, diffArgIndices, hValue); @@ -75,7 +75,7 @@ void FiniteDiffLowering::computeFiniteDiff(PatternRewriter &rewriter, Location l std::vector gradients; gradients.reserve(gradFn.getNumResults()); - func::CallOp callOp = rewriter.create(loc, callee, callArgs); + func::CallOp callOp = func::CallOp::create(rewriter, loc, callee, callArgs); for (size_t diffResIdx = 0; diffResIdx < callee.getNumResults(); ++diffResIdx) { for (size_t diffArgIdxIdx = 0; diffArgIdxIdx < diffArgIndices.size(); ++diffArgIdxIdx) { size_t diffArgIdx = diffArgIndices[diffArgIdxIdx]; @@ -111,45 +111,45 @@ void FiniteDiffLowering::computeFiniteDiff(PatternRewriter &rewriter, Location l std::vector dynamicDimSizes; for (int64_t j = 0; j < resultRank; j++) { if (resultShape[j] == ShapedType::kDynamic) { - dynamicDimSizes.push_back(rewriter.create(loc, callRes, j)); + dynamicDimSizes.push_back(tensor::DimOp::create(rewriter, loc, callRes, j)); } } for (int64_t i = 0; i < operandRank; i++) { if (operandShape[i] == ShapedType::kDynamic) { - dynamicDimSizes.push_back(rewriter.create(loc, diffArg, i)); + dynamicDimSizes.push_back(tensor::DimOp::create(rewriter, loc, diffArg, i)); } } TypedAttr shiftForResult = rewriter.getFloatAttr(baseResultTy, hValue); - Value hForResult = rewriter.create(loc, shiftForResult); + Value hForResult = arith::ConstantOp::create(rewriter, loc, shiftForResult); if (isGradientTensor && cast(gradientTy).hasStaticShape()) { - hForResult = rewriter.create(loc, hForResult, gradientTy); + hForResult = tensor::SplatOp::create(rewriter, loc, hForResult, gradientTy); } else if (isGradientTensor) { - Value outTensor = rewriter.create(loc, gradientShape, baseResultTy, - dynamicDimSizes); + Value outTensor = tensor::EmptyOp::create(rewriter, loc, gradientShape, + baseResultTy, dynamicDimSizes); hForResult = - rewriter.create(loc, hForResult, outTensor).getResult(0); + linalg::FillOp::create(rewriter, loc, hForResult, outTensor).getResult(0); } TypedAttr shiftForOperand = isOperandScalarTensor ? (TypedAttr)DenseFPElementsAttr::get(cast(operandTy), hValue) : (TypedAttr)rewriter.getFloatAttr(baseOperandTy, hValue); - Value hForOperand = rewriter.create(loc, shiftForOperand); + Value hForOperand = arith::ConstantOp::create(rewriter, loc, shiftForOperand); Value gradient; if (!isOperandTensor || isOperandScalarTensor) { - Value diffArgShifted = rewriter.create(loc, diffArg, hForOperand); + Value diffArgShifted = arith::AddFOp::create(rewriter, loc, diffArg, hForOperand); std::vector callArgsForward(callArgs.begin(), callArgs.end()); callArgsForward[diffArgIdx] = diffArgShifted; func::CallOp callOpForward = - rewriter.create(loc, callee, callArgsForward); + func::CallOp::create(rewriter, loc, callee, callArgsForward); Value callResForward = callOpForward.getResult(diffResIdx); - gradient = rewriter.create(loc, callResForward, callRes); + gradient = arith::SubFOp::create(rewriter, loc, callResForward, callRes); } else { auto bodyBuilder = [&](OpBuilder &rewriter, Location loc, @@ -165,49 +165,51 @@ void FiniteDiffLowering::computeFiniteDiff(PatternRewriter &rewriter, Location l auto memrefTy = bufferization::getMemRefTypeWithStaticIdentityLayout( cast(tensorTy)); auto toBufferOp = - rewriter.create(loc, memrefTy, diffArg); + bufferization::ToBufferOp::create(rewriter, loc, memrefTy, diffArg); - auto cloneOp = rewriter.create(loc, toBufferOp); + auto cloneOp = bufferization::CloneOp::create(rewriter, loc, toBufferOp); - auto toTensorOp = rewriter.create( - loc, memref::getTensorTypeFromMemRefType(cloneOp.getOutput().getType()), - cloneOp, true); + auto toTensorOp = bufferization::ToTensorOp::create( + rewriter, loc, + memref::getTensorTypeFromMemRefType(cloneOp.getOutput().getType()), cloneOp, + true); auto diffArgCopy = toTensorOp.getResult(); - Value diffArgElem = rewriter.create( - loc, diffArgCopy, tensorIndices.take_back(operandRank)); + Value diffArgElem = tensor::ExtractOp::create( + rewriter, loc, diffArgCopy, tensorIndices.take_back(operandRank)); Value diffArgElemShifted = - rewriter.create(loc, diffArgElem, hForOperand); - Value diffArgShifted = rewriter.create( - loc, diffArgElemShifted, diffArgCopy, tensorIndices.take_back(operandRank)); + arith::AddFOp::create(rewriter, loc, diffArgElem, hForOperand); + Value diffArgShifted = + tensor::InsertOp::create(rewriter, loc, diffArgElemShifted, diffArgCopy, + tensorIndices.take_back(operandRank)); std::vector callArgsForward(callArgs.begin(), callArgs.end()); callArgsForward[diffArgIdx] = diffArgShifted; func::CallOp callOpForward = - rewriter.create(loc, callee, callArgsForward); + func::CallOp::create(rewriter, loc, callee, callArgsForward); Value callResForward = callOpForward.getResult(diffResIdx); - Value result = rewriter.create(loc, callResForward, callRes); + Value result = arith::SubFOp::create(rewriter, loc, callResForward, callRes); if (isResultTensor) { - result = rewriter.create( - loc, result, tensorIndices.take_front(resultRank)); + result = tensor::ExtractOp::create(rewriter, loc, result, + tensorIndices.take_front(resultRank)); } - rewriter.create(loc, result); + tensor::YieldOp::create(rewriter, loc, result); }; - gradient = rewriter.create(loc, gradientTy, dynamicDimSizes, - bodyBuilder); + gradient = tensor::GenerateOp::create(rewriter, loc, gradientTy, dynamicDimSizes, + bodyBuilder); } - gradient = rewriter.create(loc, gradient, hForResult); + gradient = arith::DivFOp::create(rewriter, loc, gradient, hForResult); gradients.push_back(gradient); } } - rewriter.create(loc, gradients); + func::ReturnOp::create(rewriter, loc, gradients); } } // namespace gradient diff --git a/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp b/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp index 75ad0a61a0..0a685a137d 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp @@ -69,7 +69,7 @@ void iterateOverEntries(RankedTensorType resultType, OpBuilder &builder, Locatio SmallVector indices; for (int64_t dim = 0; dim < resultType.getRank(); dim++) { indices.push_back( - builder.create(loc, flatIdx / strides[dim] % shape[dim])); + index::ConstantOp::create(builder, loc, flatIdx / strides[dim] % shape[dim])); } processWithIndices(indices); @@ -85,21 +85,21 @@ void initializeCotangents(TypeRange primalResultTypes, unsigned activeResult, Va ? cast(activeResultType).getElementType() : activeResultType); - Value zero = builder.create( - loc, elementType, APFloat(elementType.getFloatSemantics(), 0)); - Value one = builder.create(loc, elementType, - APFloat(elementType.getFloatSemantics(), 1)); + Value zero = arith::ConstantFloatOp::create(builder, loc, elementType, + APFloat(elementType.getFloatSemantics(), 0)); + Value one = arith::ConstantFloatOp::create(builder, loc, elementType, + APFloat(elementType.getFloatSemantics(), 1)); Value zeroTensor; if (auto activeResultTensor = dyn_cast(activeResultType)) { - zeroTensor = builder.create(loc, activeResultTensor, - /*dynamicSizes=*/ValueRange{}); + zeroTensor = tensor::EmptyOp::create(builder, loc, activeResultTensor, + /*dynamicSizes=*/ValueRange{}); } else { - zeroTensor = builder.create(loc, ArrayRef(), activeResultType); + zeroTensor = tensor::EmptyOp::create(builder, loc, ArrayRef(), activeResultType); } - zeroTensor = builder.create(loc, zero, zeroTensor).getResult(0); - Value cotangent = builder.create(loc, one, zeroTensor, indices); + zeroTensor = linalg::FillOp::create(builder, loc, zero, zeroTensor).getResult(0); + Value cotangent = tensor::InsertOp::create(builder, loc, one, zeroTensor, indices); // Initialize cotangents for all of the primal outputs for (const auto &[resultIdx, primalResultType] : llvm::enumerate(primalResultTypes)) { @@ -112,14 +112,14 @@ void initializeCotangents(TypeRange primalResultTypes, unsigned activeResult, Va // an explicit empty + fill vs a constant tensor. Value zeroTensor; if (isa(primalResultType)) { - zeroTensor = builder.create(loc, primalResultType, ValueRange{}); + zeroTensor = tensor::EmptyOp::create(builder, loc, primalResultType, ValueRange{}); } else { zeroTensor = - builder.create(loc, ArrayRef(), primalResultType); + tensor::EmptyOp::create(builder, loc, ArrayRef(), primalResultType); } cotangents.push_back( - builder.create(loc, zero, zeroTensor).getResult(0)); + linalg::FillOp::create(builder, loc, zero, zeroTensor).getResult(0)); } } } @@ -185,7 +185,7 @@ static FailureOr cloneCallee(PatternRewriter &rewriter, Operation rewriter.setInsertionPoint(callSite); Value paramCount = - rewriter.create(loc, paramCountFn, argOperands).getResult(0); + func::CallOp::create(rewriter, loc, paramCountFn, argOperands).getResult(0); backpropArgs.push_back(paramCount); // If the callee is a QNode, we want to backprop through the split preprocessed // version. @@ -199,10 +199,9 @@ static FailureOr cloneCallee(PatternRewriter &rewriter, Operation if (callOp.getCallee() == qnode.getName()) { PatternRewriter::InsertionGuard insertionGuard(rewriter); rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front()); - Value paramCount = - rewriter - .create(loc, paramCountFn, callOp.getArgOperands()) - .getResult(0); + Value paramCount = func::CallOp::create(rewriter, loc, paramCountFn, + callOp.getArgOperands()) + .getResult(0); callOp.setCallee(qnodeSplit.getName()); callOp.getOperandsMutable().append(paramCount); } @@ -298,7 +297,7 @@ static func::FuncOp genQNodeQuantumOnly(PatternRewriter &rewriter, Location loc, return modifiedCallee; } - modifiedCallee = rewriter.create(loc, fnName, fnType); + modifiedCallee = func::FuncOp::create(rewriter, loc, fnName, fnType); modifiedCallee.setPrivate(); rewriter.cloneRegionBefore(qnode.getBody(), modifiedCallee.getBody(), modifiedCallee.end()); Block &entryBlock = modifiedCallee.getFunctionBody().front(); @@ -309,16 +308,16 @@ static func::FuncOp genQNodeQuantumOnly(PatternRewriter &rewriter, Location loc, MemRefType paramsProcessedType = MemRefType::get({}, rewriter.getIndexType()); Value paramCounter = getStaticMemrefAlloca(loc, rewriter, paramsProcessedType); - Value cZero = rewriter.create(loc, 0); - rewriter.create(loc, cZero, paramCounter); - Value cOne = rewriter.create(loc, 1); + Value cZero = index::ConstantOp::create(rewriter, loc, 0); + memref::StoreOp::create(rewriter, loc, cZero, paramCounter); + Value cOne = index::ConstantOp::create(rewriter, loc, 1); auto loadThenIncrementCounter = [&](OpBuilder &builder, Value counter, Value paramTensor) -> Value { - Value index = builder.create(loc, counter); - Value nextIndex = builder.create(loc, index, cOne); - builder.create(loc, nextIndex, counter); - return builder.create(loc, paramTensor, index); + Value index = memref::LoadOp::create(builder, loc, counter); + Value nextIndex = index::AddOp::create(builder, loc, index, cOne); + memref::StoreOp::create(builder, loc, nextIndex, counter); + return tensor::ExtractOp::create(builder, loc, paramTensor, index); }; modifiedCallee.walk([&](quantum::DifferentiableGate gateOp) { @@ -378,7 +377,7 @@ static func::FuncOp genFullGradFunction(PatternRewriter &rewriter, Location loc, PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointAfter(callee); - fullGradFn = rewriter.create(loc, fnName, fnType); + fullGradFn = func::FuncOp::create(rewriter, loc, fnName, fnType); fullGradFn.setPrivate(); Block *entryBlock = fullGradFn.addEntryBlock(); rewriter.setInsertionPointToStart(entryBlock); @@ -392,12 +391,12 @@ static func::FuncOp genFullGradFunction(PatternRewriter &rewriter, Location loc, for (unsigned argIdx = 0; argIdx < diffArgIndices.size(); argIdx++) { Type jacobianType = resultTypes[argIdx + cotangentIdx * diffArgIndices.size()]; if (auto tensorType = dyn_cast(jacobianType)) { - jacobians.push_back(rewriter.create( - loc, tensorType.getShape(), tensorType.getElementType())); + jacobians.push_back(tensor::EmptyOp::create( + rewriter, loc, tensorType.getShape(), tensorType.getElementType())); } else { - jacobians.push_back(rewriter.create( - loc, cast(jacobianType), APFloat(0.0))); + jacobians.push_back(arith::ConstantFloatOp::create( + rewriter, loc, cast(jacobianType), APFloat(0.0))); } } @@ -412,8 +411,8 @@ static func::FuncOp genFullGradFunction(PatternRewriter &rewriter, Location loc, initializeCotangents(callee.getResultTypes(), cotangentIdx, indices, rewriter, loc, cotangents); - auto backpropOp = rewriter.create( - loc, valTypes, computeBackpropTypes(callee, diffArgIndices), + auto backpropOp = gradient::BackpropOp::create( + rewriter, loc, valTypes, computeBackpropTypes(callee, diffArgIndices), SymbolRefAttr::get(callee), entryBlock->getArguments(), /*arg_shadows=*/ValueRange{}, /*primal results=*/ValueRange{}, cotangents, diffArgIndicesAttr, @@ -458,9 +457,9 @@ static func::FuncOp genFullGradFunction(PatternRewriter &rewriter, Location loc, SmallVector strides{jacobianRank, rewriter.getIndexAttr(1)}; - jacobians[backpropIdx] = rewriter.create( - loc, jacobianSlice, jacobians[backpropIdx], offsets, sizes, - strides); + jacobians[backpropIdx] = tensor::InsertSliceOp::create( + rewriter, loc, jacobianSlice, jacobians[backpropIdx], + offsets, sizes, strides); } else { jacobians[backpropIdx] = jacobianSlice; @@ -468,8 +467,8 @@ static func::FuncOp genFullGradFunction(PatternRewriter &rewriter, Location loc, } else { assert(isa(jacobianSlice.getType())); - jacobians[backpropIdx] = rewriter.create( - loc, jacobianSlice, jacobians[backpropIdx], indices); + jacobians[backpropIdx] = tensor::InsertOp::create( + rewriter, loc, jacobianSlice, jacobians[backpropIdx], indices); } backpropGradResults[backpropIdx + cotangentIdx * diffArgIndices.size()] = @@ -483,8 +482,8 @@ static func::FuncOp genFullGradFunction(PatternRewriter &rewriter, Location loc, initializeCotangents(callee.getResultTypes(), cotangentIdx, ValueRange(), rewriter, loc, cotangents); - auto backpropOp = rewriter.create( - loc, valTypes, computeBackpropTypes(callee, diffArgIndices), + auto backpropOp = gradient::BackpropOp::create( + rewriter, loc, valTypes, computeBackpropTypes(callee, diffArgIndices), SymbolRefAttr::get(callee), entryBlock->getArguments(), /*arg_shadows=*/ValueRange{}, /*primal results=*/ValueRange{}, cotangents, diffArgIndicesAttr, keepValueResults); @@ -504,8 +503,8 @@ static func::FuncOp genFullGradFunction(PatternRewriter &rewriter, Location loc, // user requests a scalar, give it to them. if (isa(jacobianSlice.getType()) && isa(resultTypes[resultIdx])) { - backpropGradResults[resultIdx] = rewriter.create( - loc, jacobianSlice, ValueRange{}); + backpropGradResults[resultIdx] = tensor::ExtractOp::create( + rewriter, loc, jacobianSlice, ValueRange{}); } } } @@ -519,7 +518,7 @@ static func::FuncOp genFullGradFunction(PatternRewriter &rewriter, Location loc, backpropResults.insert(backpropResults.end(), backpropGradResults.begin(), backpropGradResults.end()); - rewriter.create(loc, backpropResults); + func::ReturnOp::create(rewriter, loc, backpropResults); } // if (!fullGradFn) return fullGradFn; } diff --git a/mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp b/mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp index ad0e26cdf0..e4f4cf244b 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp @@ -92,12 +92,12 @@ LogicalResult JVPLoweringPattern::matchAndRewrite(JVPOp op, PatternRewriter &rew assert(grad_result_types.size() == func_diff_operand_indices.size() * funcResultTypes.size() && "GradOp does't seem to return a tuple of Jacobians"); - auto fCallOp = rewriter.create(loc, calleeOp, calleeOperands); + auto fCallOp = func::CallOp::create(rewriter, loc, calleeOp, calleeOperands); - auto gradOp = rewriter.create(loc, grad_result_types, op.getMethod(), op.getCallee(), - calleeOperands, op.getDiffArgIndicesAttr(), - op.getFiniteDiffParamAttr(), /*arg_attrs=*/nullptr, - /*res_attrs=*/nullptr); + auto gradOp = GradOp::create(rewriter, loc, grad_result_types, op.getMethod(), op.getCallee(), + calleeOperands, op.getDiffArgIndicesAttr(), + op.getFiniteDiffParamAttr(), /*arg_attrs=*/nullptr, + /*res_attrs=*/nullptr); std::vector einsumResults; for (size_t nout = 0; nout < funcResultTypes.size(); nout++) { @@ -158,8 +158,9 @@ LogicalResult JVPLoweringPattern::matchAndRewrite(JVPOp op, PatternRewriter &rew } else { assert(acc.value().getType() == res.getType()); - auto addOp = rewriter.create( - loc, res.getType(), ValueRange{acc.value(), res}, ValueRange{acc.value()}); + auto addOp = + linalg::AddOp::create(rewriter, loc, res.getType(), + ValueRange{acc.value(), res}, ValueRange{acc.value()}); acc = addOp.getResultTensors()[0]; } } @@ -210,12 +211,12 @@ LogicalResult VJPLoweringPattern::matchAndRewrite(VJPOp op, PatternRewriter &rew assert(grad_result_types.size() == func_diff_operand_indices.size() * funcResultTypes.size() && "GradOp does't seem to return a tuple of Jacobians"); - auto fCallOp = rewriter.create(loc, calleeOp, calleeOperands); + auto fCallOp = func::CallOp::create(rewriter, loc, calleeOp, calleeOperands); - auto gradOp = rewriter.create(loc, grad_result_types, op.getMethod(), op.getCallee(), - calleeOperands, op.getDiffArgIndicesAttr(), - op.getFiniteDiffParamAttr(), /*arg_attrs=*/nullptr, - /*res_attrs=*/nullptr); + auto gradOp = GradOp::create(rewriter, loc, grad_result_types, op.getMethod(), op.getCallee(), + calleeOperands, op.getDiffArgIndicesAttr(), + op.getFiniteDiffParamAttr(), /*arg_attrs=*/nullptr, + /*res_attrs=*/nullptr); std::vector einsumResults; for (size_t nparam = 0; nparam < func_diff_operand_indices.size(); nparam++) { @@ -272,8 +273,9 @@ LogicalResult VJPLoweringPattern::matchAndRewrite(VJPOp op, PatternRewriter &rew else { assert(acc.value().getType() == res.getType()); - auto addOp = rewriter.create( - loc, res.getType(), ValueRange{acc.value(), res}, ValueRange{acc.value()}); + auto addOp = + linalg::AddOp::create(rewriter, loc, res.getType(), + ValueRange{acc.value(), res}, ValueRange{acc.value()}); acc = addOp.getResultTensors()[0]; } } diff --git a/mlir/lib/Gradient/Transforms/GradMethods/PS_FunctionShifting.cpp b/mlir/lib/Gradient/Transforms/GradMethods/PS_FunctionShifting.cpp index 7227a3d35a..8f19717351 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/PS_FunctionShifting.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/PS_FunctionShifting.cpp @@ -31,25 +31,25 @@ static Value genSelectiveShift(PatternRewriter &rewriter, Location loc, Value pa const std::vector> &selectors) { if (selectors.empty()) { - return rewriter.create(loc, shift, param); + return arith::AddFOp::create(rewriter, loc, shift, param); } // Make sure all active iteration variables match the selectors. - Value shiftCondition = rewriter.create(loc, 1, true); + Value shiftCondition = arith::ConstantIntOp::create(rewriter, loc, 1, true); for (auto &[iteration, selector] : selectors) { Value iterationMatch = - rewriter.create(loc, arith::CmpIPredicate::eq, iteration, selector); - shiftCondition = rewriter.create(loc, shiftCondition, iterationMatch); + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, iteration, selector); + shiftCondition = arith::AndIOp::create(rewriter, loc, shiftCondition, iterationMatch); } - scf::IfOp ifOp = rewriter.create( - loc, shiftCondition, + scf::IfOp ifOp = scf::IfOp::create( + rewriter, loc, shiftCondition, [&](OpBuilder &builder, Location loc) { // then - Value shiftedParam = builder.create(loc, shift, param); - builder.create(loc, shiftedParam); + Value shiftedParam = arith::AddFOp::create(builder, loc, shift, param); + scf::YieldOp::create(builder, loc, shiftedParam); }, [&](OpBuilder &builder, Location loc) { // else - builder.create(loc, param); + scf::YieldOp::create(builder, loc, param); }); return ifOp.getResult(0); @@ -79,7 +79,7 @@ func::FuncOp ParameterShiftLowering::genShiftFunction(PatternRewriter &rewriter, PatternRewriter::InsertionGuard insertGuard(rewriter); shiftedFn = - rewriter.create(loc, fnName, fnType, visibility, nullptr, nullptr); + func::FuncOp::create(rewriter, loc, fnName, fnType, visibility, nullptr, nullptr); // First copy the entire function as is, then we can add the shifts. // Make sure to add the shiftVector/selectorVector parameters to the new function. @@ -98,9 +98,9 @@ func::FuncOp ParameterShiftLowering::genShiftFunction(PatternRewriter &rewriter, PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(forOp.getBody()); - Value idx = rewriter.create( - loc, rewriter.getIndexAttr(selectors.size())); - Value selector = rewriter.create(loc, selectorVector, idx); + Value idx = arith::ConstantOp::create(rewriter, loc, + rewriter.getIndexAttr(selectors.size())); + Value selector = tensor::ExtractOp::create(rewriter, loc, selectorVector, idx); Value iteration = forOp.getInductionVar(); selectors.push_back({iteration, selector}); } @@ -117,8 +117,8 @@ func::FuncOp ParameterShiftLowering::genShiftFunction(PatternRewriter &rewriter, shiftedParams.reserve(params.size()); for (size_t i = 0; i < params.size(); i++) { - Value idx = rewriter.create(loc, shiftsProcessed++); - Value shift = rewriter.create(loc, shiftVector, idx); + Value idx = index::ConstantOp::create(rewriter, loc, shiftsProcessed++); + Value shift = tensor::ExtractOp::create(rewriter, loc, shiftVector, idx); Value shiftedParam = genSelectiveShift(rewriter, loc, params[i], shift, selectors); shiftedParams.push_back(shiftedParam); diff --git a/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp b/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp index d5cb0c117e..6abeb701fe 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp @@ -44,8 +44,8 @@ static void updateSelectorVector(PatternRewriter &rewriter, Location loc, for (auto &[forOp, idx] : selectorsToStore) { rewriter.setInsertionPointToStart(forOp.getBody()); Value iteration = forOp.getInductionVar(); - Value idxVal = rewriter.create(loc, idx); - rewriter.create(loc, iteration, selectorBuffer, idxVal); + Value idxVal = index::ConstantOp::create(rewriter, loc, idx); + memref::StoreOp::create(rewriter, loc, iteration, selectorBuffer, idxVal); } selectorsToStore.clear(); @@ -59,8 +59,9 @@ static std::vector computePartialDerivative(PatternRewriter &rewriter, Lo { constexpr double shift = llvm::numbers::pi / 2; ShapedType shiftVectorType = RankedTensorType::get({numShifts}, rewriter.getF64Type()); - Value selectorVector = rewriter.create( - loc, memref::getTensorTypeFromMemRefType(selectorBuffer.getType()), selectorBuffer, true); + Value selectorVector = bufferization::ToTensorOp::create( + rewriter, loc, memref::getTensorTypeFromMemRefType(selectorBuffer.getType()), + selectorBuffer, true); // Define the shift vectors (pos/neg) as sparse tensor constants. DenseElementsAttr nonZeroIndices = rewriter.getI64TensorAttr(currentShift); @@ -69,32 +70,32 @@ static std::vector computePartialDerivative(PatternRewriter &rewriter, Lo DenseFPElementsAttr::get(RankedTensorType::get(1, rewriter.getF64Type()), shift); TypedAttr shiftVectorAttrPos = SparseElementsAttr::get(shiftVectorType, nonZeroIndices, nonZeroValuesPos); - Value shiftVectorPos = rewriter.create(loc, shiftVectorAttrPos); + Value shiftVectorPos = arith::ConstantOp::create(rewriter, loc, shiftVectorAttrPos); DenseElementsAttr nonZeroValuesNeg = DenseFPElementsAttr::get(RankedTensorType::get(1, rewriter.getF64Type()), -shift); TypedAttr shiftVectorAttrNeg = SparseElementsAttr::get(shiftVectorType, nonZeroIndices, nonZeroValuesNeg); - Value shiftVectorNeg = rewriter.create(loc, shiftVectorAttrNeg); + Value shiftVectorNeg = arith::ConstantOp::create(rewriter, loc, shiftVectorAttrNeg); // Compute the partial derivate for this parameter via the simplified // parameter-shift rule: df/dx = [f(x + pi/2) - f(x - pi/2)] / 2. callArgs.push_back(shiftVectorPos); callArgs.push_back(selectorVector); - ValueRange evalPos = rewriter.create(loc, shiftedFn, callArgs).getResults(); + ValueRange evalPos = func::CallOp::create(rewriter, loc, shiftedFn, callArgs).getResults(); callArgs[callArgs.size() - 2] = shiftVectorNeg; - ValueRange evalNeg = rewriter.create(loc, shiftedFn, callArgs).getResults(); + ValueRange evalNeg = func::CallOp::create(rewriter, loc, shiftedFn, callArgs).getResults(); std::vector derivatives; derivatives.reserve(evalPos.size()); for (size_t i = 0; i < evalPos.size(); i++) { - Value diff = rewriter.create(loc, evalPos[i], evalNeg[i]); - Value divisor = rewriter.create(loc, rewriter.getF64FloatAttr(2.0)); + Value diff = arith::SubFOp::create(rewriter, loc, evalPos[i], evalNeg[i]); + Value divisor = arith::ConstantOp::create(rewriter, loc, rewriter.getF64FloatAttr(2.0)); if (auto tensorType = dyn_cast(evalPos[i].getType())) - divisor = rewriter.create(loc, divisor, tensorType); - derivatives.push_back(rewriter.create(loc, diff, divisor)); + divisor = tensor::SplatOp::create(rewriter, loc, divisor, tensorType); + derivatives.push_back(arith::DivFOp::create(rewriter, loc, diff, divisor)); } return derivatives; @@ -105,7 +106,7 @@ static void storePartialDerivative(PatternRewriter &rewriter, Location loc, ValueRange gradientBuffers, Value gradientsProcessed, ValueRange derivatives) { - Value gradIdx = rewriter.create(loc, gradientsProcessed); + Value gradIdx = memref::LoadOp::create(rewriter, loc, gradientsProcessed); for (size_t i = 0; i < gradientBuffers.size(); i++) { Value gradientBuffer = gradientBuffers[i]; @@ -139,8 +140,8 @@ static void storePartialDerivative(PatternRewriter &rewriter, Location loc, std::vector dynSizes; for (int64_t dim = 1; dim < rank; dim++) { if (sizes[dim] == ShapedType::kDynamic) { - Value idx = rewriter.create(loc, dim); - Value dimSize = rewriter.create(loc, gradientBuffer, idx); + Value idx = index::ConstantOp::create(rewriter, loc, dim); + Value dimSize = tensor::DimOp::create(rewriter, loc, gradientBuffer, idx); dynSizes.push_back(dimSize); } } @@ -156,25 +157,25 @@ static void storePartialDerivative(PatternRewriter &rewriter, Location loc, gradientBufferType.getShape().drop_front(), gradientBufferType, offsets, sizes, strides); Value gradientSubview = - rewriter.create(loc, resultType, gradientBuffer, dynOffsets, - dynSizes, dynStrides, offsets, sizes, strides); + memref::SubViewOp::create(rewriter, loc, resultType, gradientBuffer, dynOffsets, + dynSizes, dynStrides, offsets, sizes, strides); - auto materializeOp = rewriter.create( - loc, derivative, gradientSubview); + auto materializeOp = bufferization::MaterializeInDestinationOp::create( + rewriter, loc, derivative, gradientSubview); materializeOp.setWritable(true); } else if (isDerivativeScalarTensor) { - Value extracted = rewriter.create(loc, derivative); - rewriter.create(loc, extracted, gradientBuffer, gradIdx); + Value extracted = tensor::ExtractOp::create(rewriter, loc, derivative); + memref::StoreOp::create(rewriter, loc, extracted, gradientBuffer, gradIdx); } else { - rewriter.create(loc, derivative, gradientBuffer, gradIdx); + memref::StoreOp::create(rewriter, loc, derivative, gradientBuffer, gradIdx); } } - Value cOne = rewriter.create(loc, 1); - Value newGradIdx = rewriter.create(loc, gradIdx, cOne); - rewriter.create(loc, newGradIdx, gradientsProcessed); + Value cOne = index::ConstantOp::create(rewriter, loc, 1); + Value newGradIdx = index::AddOp::create(rewriter, loc, gradIdx, cOne); + memref::StoreOp::create(rewriter, loc, newGradIdx, gradientsProcessed); } func::FuncOp ParameterShiftLowering::genQGradFunction(PatternRewriter &rewriter, Location loc, @@ -199,7 +200,7 @@ func::FuncOp ParameterShiftLowering::genQGradFunction(PatternRewriter &rewriter, PatternRewriter::InsertionGuard insertGuard(rewriter); gradientFn = - rewriter.create(loc, fnName, fnType, visibility, nullptr, nullptr); + func::FuncOp::create(rewriter, loc, fnName, fnType, visibility, nullptr, nullptr); // First copy the entire function as is, then we can modify it to compute the gradient. rewriter.cloneRegionBefore(callee.getBody(), gradientFn.getBody(), gradientFn.end()); @@ -221,15 +222,15 @@ func::FuncOp ParameterShiftLowering::genQGradFunction(PatternRewriter &rewriter, PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(&gradientFn.getBody().front()); - cZero = rewriter.create(loc, 0); - cOne = rewriter.create(loc, 1); + cZero = index::ConstantOp::create(rewriter, loc, 0); + cOne = index::ConstantOp::create(rewriter, loc, 1); // Use stack allocation for selector vector as it's not expected to be too big. selectorBuffer = getStaticMemrefAlloca(loc, rewriter, selectorBufferType); auto gradientsProcessedTy = MemRefType::get({}, rewriter.getIndexType()); gradientsProcessed = getStaticMemrefAlloca(loc, rewriter, gradientsProcessedTy); - rewriter.create(loc, cZero, gradientsProcessed); + memref::StoreOp::create(rewriter, loc, cZero, gradientsProcessed); for (Type gradType : gradResTypes) { TensorType gradTensorType = cast(gradType); @@ -238,7 +239,7 @@ func::FuncOp ParameterShiftLowering::genQGradFunction(PatternRewriter &rewriter, // TODO: add support for dynamic result dimensions gradientBuffers.push_back( - rewriter.create(loc, gradBufferType, gradientSize)); + memref::AllocOp::create(rewriter, loc, gradBufferType, gradientSize)); } } @@ -286,8 +287,9 @@ func::FuncOp ParameterShiftLowering::genQGradFunction(PatternRewriter &rewriter, std::vector gradientTensors; gradientTensors.reserve(gradResTypes.size()); for (Value gradientBuffer : gradientBuffers) { - gradientTensors.push_back(rewriter.create( - loc, memref::getTensorTypeFromMemRefType(gradientBuffer.getType()), + gradientTensors.push_back(bufferization::ToTensorOp::create( + rewriter, loc, + memref::getTensorTypeFromMemRefType(gradientBuffer.getType()), gradientBuffer, true)); } op->setOperands(gradientTensors); diff --git a/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp b/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp index d34dd89974..7c8f2d2e85 100644 --- a/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp +++ b/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp @@ -134,7 +134,7 @@ struct PostprocessForwardOp : public OpRewritePattern { for (Value operand : returnOp.getOperands()) { if (isa(operand.getType()) && idx < resc) { BlockArgument output = op.getArgument(idx * 2 + argc * 2); - rewriter.create(returnOp.getLoc(), operand, output); + memref::CopyOp::create(rewriter, returnOp.getLoc(), operand, output); idx++; } else { @@ -226,7 +226,7 @@ struct PostprocessReverseOp : public OpRewritePattern { for (Value operand : returnOp.getOperands()) { if (isa(operand.getType()) && idx < forwardArgc) { BlockArgument output = op.getArgument(2 * idx + 1); - rewriter.create(returnOp.getLoc(), operand, output); + memref::CopyOp::create(rewriter, returnOp.getLoc(), operand, output); idx++; } } diff --git a/mlir/lib/Gradient/Transforms/PreprocessingPatterns.cpp b/mlir/lib/Gradient/Transforms/PreprocessingPatterns.cpp index dd5c7251f7..13ca2ba17e 100644 --- a/mlir/lib/Gradient/Transforms/PreprocessingPatterns.cpp +++ b/mlir/lib/Gradient/Transforms/PreprocessingPatterns.cpp @@ -52,11 +52,11 @@ struct PreprocessForwardOp : public OpRewritePattern { auto implResTy = implOp.getResultTypes(); Location loc = op.getLoc(); - auto callOp = rewriter.create(loc, impl, implResTy, inputs); + auto callOp = func::CallOp::create(rewriter, loc, impl, implResTy, inputs); SmallVector outputs(callOp.getResults()); auto F = rewriter.getIntegerAttr(rewriter.getI1Type(), 0); - rewriter.create(loc, outputs, F); + catalyst::gradient::ReturnOp::create(rewriter, loc, outputs, F); return success(); } @@ -96,11 +96,11 @@ struct PreprocessReverseOp : public OpRewritePattern { auto implResTy = implOp.getResultTypes(); Location loc = op.getLoc(); - auto callOp = rewriter.create(loc, impl, implResTy, tapeInputs); + auto callOp = func::CallOp::create(rewriter, loc, impl, implResTy, tapeInputs); SmallVector outputs(callOp.getResults()); auto T = rewriter.getIntegerAttr(rewriter.getI1Type(), 1); - rewriter.create(loc, outputs, T); + catalyst::gradient::ReturnOp::create(rewriter, loc, outputs, T); return success(); } diff --git a/mlir/lib/Gradient/Utils/DestinationPassingStyle.cpp b/mlir/lib/Gradient/Utils/DestinationPassingStyle.cpp index 08c082a316..a6b6f73731 100644 --- a/mlir/lib/Gradient/Utils/DestinationPassingStyle.cpp +++ b/mlir/lib/Gradient/Utils/DestinationPassingStyle.cpp @@ -76,7 +76,7 @@ LogicalResult catalyst::convertToDestinationPassingStyle(func::FuncOp callee, Op BlockArgument output = callee.getArgument(idx + dpsOutputIdx); // We need a linalg.copy instead of a memref.copy here because it provides better // type information at the LLVM level for Enzyme. - builder.create(returnOp.getLoc(), operand, output); + linalg::CopyOp::create(builder, returnOp.getLoc(), operand, output); idx++; } else { diff --git a/mlir/lib/Gradient/Utils/EinsumLinalgGeneric.cpp b/mlir/lib/Gradient/Utils/EinsumLinalgGeneric.cpp index 4d1164c140..b1cb7a0315 100644 --- a/mlir/lib/Gradient/Utils/EinsumLinalgGeneric.cpp +++ b/mlir/lib/Gradient/Utils/EinsumLinalgGeneric.cpp @@ -31,8 +31,8 @@ Value buildBufferLinalgGeneric(OpBuilder &builder, Location loc, ValueRange oper ArrayRef iteratorTypes, function_ref buildBody) { - builder.create(loc, operands, output, indexingMaps, iteratorTypes, - buildBody); + linalg::GenericOp::create(builder, loc, operands, output, indexingMaps, iteratorTypes, + buildBody); return output; } @@ -43,14 +43,14 @@ Value buildTensorLinalgGeneric(OpBuilder &builder, Location loc, ValueRange oper { // Initialize the result tensor FloatType elementType = cast(resultType.getElementType()); - Value zero = builder.create( - loc, elementType, APFloat::getZero(elementType.getFloatSemantics())); + Value zero = arith::ConstantFloatOp::create(builder, loc, elementType, + APFloat::getZero(elementType.getFloatSemantics())); Value result = - builder.create(loc, resultType.getShape(), resultType.getElementType()); - result = builder.create(loc, zero, result).getResult(0); + tensor::EmptyOp::create(builder, loc, resultType.getShape(), resultType.getElementType()); + result = linalg::FillOp::create(builder, loc, zero, result).getResult(0); - auto genericOp = builder.create(loc, resultType, operands, result, - indexingMaps, iteratorTypes, buildBody); + auto genericOp = linalg::GenericOp::create(builder, loc, resultType, operands, result, + indexingMaps, iteratorTypes, buildBody); return genericOp.getResult(0); } @@ -114,9 +114,10 @@ Value einsumLinalgGeneric(OpBuilder &ob, Location loc, ArrayRef axisCod maps); inferIteratorTypes(axisDims, axisCodesResult, iteratorTypes); auto bodyBuilder = [](OpBuilder &builder, Location loc, ValueRange args) { - builder.create( - loc, Value(builder.create( - loc, args[2], builder.create(loc, args[0], args[1])))); + linalg::YieldOp::create( + builder, loc, + Value(arith::AddFOp::create(builder, loc, args[2], + arith::MulFOp::create(builder, loc, args[0], args[1])))); }; if (useBufferSemantics) { diff --git a/mlir/lib/Ion/IR/IonOps.cpp b/mlir/lib/Ion/IR/IonOps.cpp index 8438e4b40c..ecb45760aa 100644 --- a/mlir/lib/Ion/IR/IonOps.cpp +++ b/mlir/lib/Ion/IR/IonOps.cpp @@ -58,7 +58,7 @@ void ParallelProtocolOp::build(OpBuilder &builder, OperationState &result, Value bodyBuilder(builder, loc, bodyBlock->getArguments()); builder.setInsertionPointToEnd(bodyBlock); - builder.create(loc, bodyBlock->getArguments()); + ion::YieldOp::create(builder, loc, bodyBlock->getArguments()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Ion/Transforms/ConversionPatterns.cpp b/mlir/lib/Ion/Transforms/ConversionPatterns.cpp index 23d8643259..668406b97f 100644 --- a/mlir/lib/Ion/Transforms/ConversionPatterns.cpp +++ b/mlir/lib/Ion/Transforms/ConversionPatterns.cpp @@ -54,32 +54,35 @@ Value createBeamStruct(Location loc, OpBuilder &rewriter, MLIRContext *ctx, Beam auto polarization = beamAttr.getPolarization().asArrayRef(); auto wavevector = beamAttr.getWavevector().asArrayRef(); - Value beamStruct = rewriter.create(loc, beamStructType); - beamStruct = rewriter.create( - loc, beamStruct, rewriter.create(loc, transitionIndex), 0); - beamStruct = rewriter.create( - loc, beamStruct, rewriter.create(loc, rabi), 1); - beamStruct = rewriter.create( - loc, beamStruct, rewriter.create(loc, detuning), 2); + Value beamStruct = LLVM::UndefOp::create(rewriter, loc, beamStructType); + beamStruct = LLVM::InsertValueOp::create( + rewriter, loc, beamStruct, LLVM::ConstantOp::create(rewriter, loc, transitionIndex), + SmallVector{0}); + beamStruct = LLVM::InsertValueOp::create(rewriter, loc, beamStruct, + LLVM::ConstantOp::create(rewriter, loc, rabi), + SmallVector{1}); + beamStruct = LLVM::InsertValueOp::create(rewriter, loc, beamStruct, + LLVM::ConstantOp::create(rewriter, loc, detuning), + SmallVector{2}); for (size_t i = 0; i < polarization.size(); i++) { - Value polarizaitonConst = rewriter.create( - loc, rewriter.getI64Type(), + Value polarizaitonConst = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI64Type(), rewriter.getIntegerAttr(rewriter.getI64Type(), polarization[i])); - beamStruct = rewriter.create( - loc, beamStruct, polarizaitonConst, ArrayRef({3, static_cast(i)})); + beamStruct = LLVM::InsertValueOp::create(rewriter, loc, beamStruct, polarizaitonConst, + ArrayRef({3, static_cast(i)})); } for (size_t i = 0; i < wavevector.size(); i++) { - Value waveConst = rewriter.create( - loc, rewriter.getI64Type(), - rewriter.getIntegerAttr(rewriter.getI64Type(), wavevector[i])); - beamStruct = rewriter.create( - loc, beamStruct, waveConst, ArrayRef({4, static_cast(i)})); + Value waveConst = + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), + rewriter.getIntegerAttr(rewriter.getI64Type(), wavevector[i])); + beamStruct = LLVM::InsertValueOp::create(rewriter, loc, beamStruct, waveConst, + ArrayRef({4, static_cast(i)})); } Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - Value c1 = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - Value beamStructPtr = rewriter.create(loc, /*resultType=*/ptrType, - /*elementType=*/beamStructType, c1); - rewriter.create(loc, beamStruct, beamStructPtr); + Value c1 = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value beamStructPtr = LLVM::AllocaOp::create(rewriter, loc, /*resultType=*/ptrType, + /*elementType=*/beamStructType, c1); + LLVM::StoreOp::create(rewriter, loc, beamStruct, beamStructPtr); return beamStructPtr; } @@ -240,24 +243,24 @@ struct ParallelProtocolOpPattern : public OpConversionPatternconvertType(PulseType::get(ctx)), parallelPulses.size()); - Value pulseArray = rewriter.create(loc, pulseArrayType); + Value pulseArray = LLVM::UndefOp::create(rewriter, loc, pulseArrayType); for (size_t i = 0; i < parallelPulses.size(); i++) { - auto convertedPulse = rewriter - .create( - loc, LLVM::LLVMPointerType::get(ctx), parallelPulses[i]) - .getResult(0); - pulseArray = rewriter.create(loc, pulseArray, convertedPulse, i); + auto convertedPulse = + UnrealizedConversionCastOp::create(rewriter, loc, LLVM::LLVMPointerType::get(ctx), + parallelPulses[i]) + .getResult(0); + pulseArray = LLVM::InsertValueOp::create(rewriter, loc, pulseArray, convertedPulse, i); } Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - Value c1 = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - Value pulseArrayPtr = rewriter.create(loc, /*resultType=*/ptrType, - /*elementType=*/pulseArrayType, c1); - rewriter.create(loc, pulseArray, pulseArrayPtr); + Value c1 = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value pulseArrayPtr = LLVM::AllocaOp::create(rewriter, loc, /*resultType=*/ptrType, + /*elementType=*/pulseArrayType, c1); + LLVM::StoreOp::create(rewriter, loc, pulseArray, pulseArrayPtr); - Value pulseArraySize = rewriter.create( - loc, rewriter.getI64IntegerAttr(parallelPulses.size())); + Value pulseArraySize = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(parallelPulses.size())); SmallVector operands; operands.push_back(pulseArrayPtr); operands.push_back(pulseArraySize); @@ -268,7 +271,7 @@ struct ParallelProtocolOpPattern : public OpConversionPattern( rewriter, op, protocolFuncName, protocolFuncType); - rewriter.create(loc, protocolFnDecl, operands); + LLVM::CallOp::create(rewriter, loc, protocolFnDecl, operands); SmallVector values; values.insert(values.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); @@ -289,7 +292,7 @@ struct PulseOpPattern : public OpConversionPattern { const TypeConverter *conv = getTypeConverter(); auto time = op.getTime(); - auto phase = rewriter.create(loc, op.getPhase()); + auto phase = LLVM::ConstantOp::create(rewriter, loc, op.getPhase()); Type qubitTy = conv->convertType(catalyst::ion::QubitType::get(ctx)); auto inQubit = adaptor.getInQubit(); auto beamAttr = op.getBeam(); diff --git a/mlir/lib/Ion/Transforms/GatesToPulsesPatterns.cpp b/mlir/lib/Ion/Transforms/GatesToPulsesPatterns.cpp index 7838da1314..8f8e0084d2 100644 --- a/mlir/lib/Ion/Transforms/GatesToPulsesPatterns.cpp +++ b/mlir/lib/Ion/Transforms/GatesToPulsesPatterns.cpp @@ -67,7 +67,7 @@ std::optional> convertQuantumBitsToIonQubits(mlir::PatternRew } auto ionQubitType = ion::QubitType::get(ctx); Value ionQubit = - rewriter.create(loc, ionQubitType, qubit).getResult(0); + UnrealizedConversionCastOp::create(rewriter, loc, ionQubitType, qubit).getResult(0); ionQubits.push_back(ionQubit); } return ionQubits; @@ -97,7 +97,7 @@ std::optional> convertIonQubitsToQuantumBits(mlir::PatternRew } auto qubitType = quantum::QubitType::get(ctx); Value qubit = - rewriter.create(loc, qubitType, ionQubit).getResult(0); + UnrealizedConversionCastOp::create(rewriter, loc, qubitType, ionQubit).getResult(0); qubits.push_back(qubit); } return qubits; @@ -200,32 +200,31 @@ mlir::Value CreateNormalizedAngle(mlir::PatternRewriter &rewriter, mlir::Locatio constexpr double FOUR_PI = 4.0 * PI; auto four_pi_attr = rewriter.getF64FloatAttr(FOUR_PI); - auto four_pi_const = rewriter.create(loc, angle.getType(), four_pi_attr); + auto four_pi_const = arith::ConstantOp::create(rewriter, loc, angle.getType(), four_pi_attr); // Find angle fmod 4pi. mlir::Value remainder = - rewriter.create(loc, angle.getType(), angle, four_pi_const); + arith::RemFOp::create(rewriter, loc, angle.getType(), angle, four_pi_const); // Find if the remainder is less than 0. auto zero_attr = rewriter.getZeroAttr(angle.getType()); - auto zero_const = rewriter.create(loc, angle.getType(), zero_attr); - auto less_than_zero = rewriter.create(loc, arith::CmpFPredicate::OLT, remainder, - zero_const); // Signed less than + auto zero_const = arith::ConstantOp::create(rewriter, loc, angle.getType(), zero_attr); + auto less_than_zero = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OLT, remainder, + zero_const); // Signed less than // Create a conditional add (if remainder < 0, add 4*PI) - auto normalized_angle = - rewriter - .create( - loc, less_than_zero, - [&](OpBuilder &builder, Location loc) { // then - mlir::Value add_op = rewriter.create( - loc, angle.getType(), remainder, four_pi_const); // or AddIOp for integers - builder.create(loc, add_op); - }, - [&](OpBuilder &builder, Location loc) { // else - builder.create(loc, remainder); - }) - .getResult(0); + auto normalized_angle = scf::IfOp::create( + rewriter, loc, less_than_zero, + [&](OpBuilder &builder, Location loc) { // then + mlir::Value add_op = arith::AddFOp::create( + rewriter, loc, angle.getType(), remainder, + four_pi_const); // or AddIOp for integers + scf::YieldOp::create(builder, loc, add_op); + }, + [&](OpBuilder &builder, Location loc) { // else + scf::YieldOp::create(builder, loc, remainder); + }) + .getResult(0); return normalized_angle; } @@ -253,15 +252,15 @@ mlir::Value computePulseDuration(mlir::PatternRewriter &rewriter, mlir::Location { auto normalizedAngle = CreateNormalizedAngle(rewriter, loc, angle); TypedAttr rabiAttr = rewriter.getF64FloatAttr(rabi); - mlir::Value rabiValue = rewriter.create(loc, rabiAttr).getResult(); + mlir::Value rabiValue = arith::ConstantOp::create(rewriter, loc, rabiAttr).getResult(); TypedAttr detuningTimesTwoAttr = rewriter.getF64FloatAttr(detuning * 2); mlir::Value detuningTimesTwoValue = - rewriter.create(loc, detuningTimesTwoAttr).getResult(); + arith::ConstantOp::create(rewriter, loc, detuningTimesTwoAttr).getResult(); mlir::Value detuningTimesTwoTimesAngle = - rewriter.create(loc, normalizedAngle, detuningTimesTwoValue); - mlir::Value rabiSquared = rewriter.create(loc, rabiValue, rabiValue); + arith::MulFOp::create(rewriter, loc, normalizedAngle, detuningTimesTwoValue); + mlir::Value rabiSquared = arith::MulFOp::create(rewriter, loc, rabiValue, rabiValue); mlir::Value duration = - rewriter.create(loc, detuningTimesTwoTimesAngle, rabiSquared); + arith::DivFOp::create(rewriter, loc, detuningTimesTwoTimesAngle, rabiSquared); return duration; } @@ -302,15 +301,16 @@ mlir::LogicalResult oneQubitGateToPulse(CustomOp op, mlir::PatternRewriter &rewr return failure(); } - auto ppOp = rewriter.create( - loc, ionQubits.value(), [&](OpBuilder &builder, Location loc, ValueRange qubits) { + auto ppOp = ion::ParallelProtocolOp::create( + rewriter, loc, ionQubits.value(), + [&](OpBuilder &builder, Location loc, ValueRange qubits) { mlir::FloatAttr phase1Attr = builder.getF64FloatAttr(phase1); mlir::FloatAttr phase2Attr = builder.getF64FloatAttr(phase2); auto qubit = qubits.front(); - builder.create(loc, PulseType::get(ctx), time, qubit, beam0toEAttr, - phase1Attr); - builder.create(loc, PulseType::get(ctx), time, qubit, beam1toEAttr, - phase2Attr); + ion::PulseOp::create(builder, loc, PulseType::get(ctx), time, qubit, beam0toEAttr, + phase1Attr); + ion::PulseOp::create(builder, loc, PulseType::get(ctx), time, qubit, beam1toEAttr, + phase2Attr); }); // Convert ion.qubit back to quantum.bit @@ -406,8 +406,9 @@ mlir::LogicalResult MSGateToPulse(CustomOp op, mlir::PatternRewriter &rewriter, return failure(); } - auto ppOp = rewriter.create( - loc, ionQubits.value(), [&](OpBuilder &builder, Location loc, ValueRange qubits) { + auto ppOp = ion::ParallelProtocolOp::create( + rewriter, loc, ionQubits.value(), + [&](OpBuilder &builder, Location loc, ValueRange qubits) { mlir::FloatAttr phase0Attr = builder.getF64FloatAttr(0.0); auto qubit0 = qubits.front(); auto qubit1 = qubits.back(); @@ -435,8 +436,8 @@ mlir::LogicalResult MSGateToPulse(CustomOp op, mlir::PatternRewriter &rewriter, rewriter.getF64FloatAttr(beam.detuning), rewriter.getDenseI64ArrayAttr(beam.polarization), rewriter.getDenseI64ArrayAttr(beam.wavevector)); - builder.create(loc, PulseType::get(ctx), time, qubit0, beam1Attr, - phase0Attr); + ion::PulseOp::create(builder, loc, PulseType::get(ctx), time, qubit0, beam1Attr, + phase0Attr); // Pulse2( // transition=Transition(level1=1,level2=e), @@ -458,8 +459,8 @@ mlir::LogicalResult MSGateToPulse(CustomOp op, mlir::PatternRewriter &rewriter, rewriter.getF64FloatAttr(beam.detuning + phonon0ComX.energy), rewriter.getDenseI64ArrayAttr(beam.polarization), rewriter.getDenseI64ArrayAttr(flipSign(beam.wavevector))); - builder.create(loc, PulseType::get(ctx), time, qubit0, beam2Attr, - phase0Attr); + ion::PulseOp::create(builder, loc, PulseType::get(ctx), time, qubit0, beam2Attr, + phase0Attr); // Pulse3( // transition=Transition(level1=1,level2=e), @@ -480,8 +481,8 @@ mlir::LogicalResult MSGateToPulse(CustomOp op, mlir::PatternRewriter &rewriter, rewriter.getF64FloatAttr(beam.detuning - phonon0ComX.energy), rewriter.getDenseI64ArrayAttr(beam.polarization), rewriter.getDenseI64ArrayAttr(flipSign(beam.wavevector))); - builder.create(loc, PulseType::get(ctx), time, qubit0, beam3Attr, - phase0Attr); + ion::PulseOp::create(builder, loc, PulseType::get(ctx), time, qubit0, beam3Attr, + phase0Attr); // Pulse4( // transition=Transition(level1=0,level2=e), @@ -500,8 +501,8 @@ mlir::LogicalResult MSGateToPulse(CustomOp op, mlir::PatternRewriter &rewriter, rewriter.getF64FloatAttr(beam.detuning), rewriter.getDenseI64ArrayAttr(beam.polarization), rewriter.getDenseI64ArrayAttr(beam.wavevector)); - builder.create(loc, PulseType::get(ctx), time, qubit1, beam4Attr, - phase0Attr); + ion::PulseOp::create(builder, loc, PulseType::get(ctx), time, qubit1, beam4Attr, + phase0Attr); // Pulse5( // transition=Transition(level1=1,level2=e), @@ -523,8 +524,8 @@ mlir::LogicalResult MSGateToPulse(CustomOp op, mlir::PatternRewriter &rewriter, rewriter.getF64FloatAttr(beam.detuning + phonon1ComX.energy), rewriter.getDenseI64ArrayAttr(beam.polarization), rewriter.getDenseI64ArrayAttr(flipSign(beam.wavevector))); - builder.create(loc, PulseType::get(ctx), time, qubit1, beam5Attr, - phase0Attr); + ion::PulseOp::create(builder, loc, PulseType::get(ctx), time, qubit1, beam5Attr, + phase0Attr); // Pulse6( // transition=Transition(level1=1,level2=e), @@ -546,8 +547,8 @@ mlir::LogicalResult MSGateToPulse(CustomOp op, mlir::PatternRewriter &rewriter, rewriter.getF64FloatAttr(beam.detuning - phonon1ComX.energy), rewriter.getDenseI64ArrayAttr(beam.polarization), rewriter.getDenseI64ArrayAttr(flipSign(beam.wavevector))); - builder.create(loc, PulseType::get(ctx), time, qubit1, beam6Attr, - phase0Attr); + ion::PulseOp::create(builder, loc, PulseType::get(ctx), time, qubit1, beam6Attr, + phase0Attr); }); // Convert ion.qubit back to quantum.bit diff --git a/mlir/lib/Ion/Transforms/IonToRTIOPatterns.cpp b/mlir/lib/Ion/Transforms/IonToRTIOPatterns.cpp index 69784cf2ed..9706091062 100644 --- a/mlir/lib/Ion/Transforms/IonToRTIOPatterns.cpp +++ b/mlir/lib/Ion/Transforms/IonToRTIOPatterns.cpp @@ -42,7 +42,7 @@ Value awaitEvents(ArrayRef events, PatternRewriter &rewriter) return events.front(); } auto eventType = rtio::EventType::get(rewriter.getContext()); - return rewriter.create(rewriter.getUnknownLoc(), eventType, events); + return rtio::RTIOSyncOp::create(rewriter, rewriter.getUnknownLoc(), eventType, events); } //===----------------------------------------------------------------------===// @@ -100,7 +100,7 @@ struct ParallelProtocolToRTIOPattern : public OpConversionPattern(loc, eventType, qubit).getResult(0); + return UnrealizedConversionCastOp::create(rewriter, loc, eventType, qubit).getResult(0); }); Value inputSyncEvent = awaitEvents(llvm::to_vector(events), rewriter); @@ -167,8 +167,8 @@ struct ParallelProtocolToRTIOPattern : public OpConversionPattern results; for (Value result : op.getResults()) { // unrealized conversion cast sync event to result type - auto event = - rewriter.create(loc, result.getType(), outputSyncEvent); + auto event = UnrealizedConversionCastOp::create(rewriter, loc, result.getType(), + outputSyncEvent); results.push_back(event.getResult(0)); } @@ -235,8 +235,9 @@ struct PulseToRTIOPattern : public OpConversionPattern { int64_t transitionIndex = beamAttr.getTransitionIndex().getInt(); double frequency = calculateFrequency(transitionIndex, detuning, ionInfo); Value freqValue = - rewriter.create(loc, rewriter.getF64FloatAttr(frequency)); - Value phaseValue = rewriter.create(loc, rewriter.getF64FloatAttr(phase)); + arith::ConstantOp::create(rewriter, loc, rewriter.getF64FloatAttr(frequency)); + Value phaseValue = + arith::ConstantOp::create(rewriter, loc, rewriter.getF64FloatAttr(phase)); // Convert the qubit to a channel ArrayAttr qualifiers = rewriter.getArrayAttr({rewriter.getI64IntegerAttr(transitionIndex)}); @@ -257,12 +258,12 @@ struct PulseToRTIOPattern : public OpConversionPattern { } Value channel = - rewriter.create(loc, channelType, memrefLoadValue); + rtio::RTIOQubitToChannelOp::create(rewriter, loc, channelType, memrefLoadValue); // Create rtio.pulse auto eventType = rtio::EventType::get(ctx); - Value event = rewriter.create(loc, eventType, channel, duration, - freqValue, phaseValue, nullptr); + Value event = rtio::RTIOPulseOp::create(rewriter, loc, eventType, channel, duration, + freqValue, phaseValue, nullptr); rewriter.replaceOp(op, event); return success(); @@ -363,7 +364,7 @@ struct ResolveChannelMappingPattern : public OpRewritePattern(loc, resolvedChannelType); + Value channel = rtio::RTIOChannelOp::create(rewriter, loc, resolvedChannelType); rewriter.replaceOp(op, channel); @@ -431,7 +432,7 @@ struct PropagateEventsPattern : public OpRewritePattern(op.getLoc(), eventType); + Value emptyEvent = rtio::RTIOEmptyOp::create(rewriter, op.getLoc(), eventType); rewriter.replaceOp(op, emptyEvent); return success(); } diff --git a/mlir/lib/Ion/Transforms/gates_to_pulses.cpp b/mlir/lib/Ion/Transforms/gates_to_pulses.cpp index 9be39a6423..977c1dce92 100644 --- a/mlir/lib/Ion/Transforms/gates_to_pulses.cpp +++ b/mlir/lib/Ion/Transforms/gates_to_pulses.cpp @@ -106,10 +106,10 @@ struct GatesToPulsesPass : impl::GatesToPulsesPassBase { } builder.setInsertionPointToStart(&(op->getRegion(0).front())); - builder.create( - op->getLoc(), IonType::get(ctx), builder.getStringAttr(ion.name), - builder.getF64FloatAttr(ion.mass), builder.getF64FloatAttr(ion.charge), - ion.position, builder.getArrayAttr(levels), builder.getArrayAttr(transitions)); + ion::IonOp::create(builder, op->getLoc(), IonType::get(ctx), + builder.getStringAttr(ion.name), builder.getF64FloatAttr(ion.mass), + builder.getF64FloatAttr(ion.charge), ion.position, + builder.getArrayAttr(levels), builder.getArrayAttr(transitions)); SmallVector phonons; for (const Phonon &phonon : dataManager.getPhononParams()) { @@ -117,7 +117,7 @@ struct GatesToPulsesPass : impl::GatesToPulsesPassBase { } // TODO: For now, we only print one phonon to be consistent with TriCal examples, // but we should print all of them eventually - builder.create(op->getLoc(), builder.getArrayAttr(phonons[0])); + ion::ModesOp::create(builder, op->getLoc(), builder.getArrayAttr(phonons[0])); } RewritePatternSet ionPatterns(&getContext()); diff --git a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp index 9ace474ced..c0a8d9f000 100644 --- a/mlir/lib/Ion/Transforms/ion-to-rtio.cpp +++ b/mlir/lib/Ion/Transforms/ion-to-rtio.cpp @@ -145,7 +145,7 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { return nullptr; Type inputType = inputs.front().getType(); if (inputType != resultType) { - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs) .getResult(0); } return inputs[0]; @@ -157,7 +157,7 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { return nullptr; Type inputType = inputs.front().getType(); if (inputType != resultType) { - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs) .getResult(0); } return inputs[0]; @@ -304,8 +304,8 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { // Get the global memref in the function builder.setInsertionPointAfter(allocOp); - Value qubitMap = builder.create(allocOp.getLoc(), memrefType, - globalOp.getSymName()); + Value qubitMap = memref::GetGlobalOp::create(builder, allocOp.getLoc(), memrefType, + globalOp.getSymName()); qregToMemrefMap[allocOp.getResult()] = qubitMap; }); @@ -320,16 +320,16 @@ struct IonToRTIOPass : public impl::IonToRTIOPassBase { Value memrefLoadValue = nullptr; if (Value idx = extractOp.getIdx()) { // idx is an operand (i64), need to cast to index - Value indexValue = builder.create( - extractOp.getLoc(), builder.getIndexType(), idx); - memrefLoadValue = builder.create( - extractOp.getLoc(), memref, ValueRange{indexValue}); + Value indexValue = arith::IndexCastOp::create( + builder, extractOp.getLoc(), builder.getIndexType(), idx); + memrefLoadValue = memref::LoadOp::create( + builder, extractOp.getLoc(), memref, ValueRange{indexValue}); } else if (IntegerAttr idxAttr = extractOp.getIdxAttrAttr()) { - Value indexValue = builder.create( - extractOp.getLoc(), idxAttr.getInt()); - memrefLoadValue = builder.create( - extractOp.getLoc(), memref, ValueRange{indexValue}); + Value indexValue = arith::ConstantIndexOp::create( + builder, extractOp.getLoc(), idxAttr.getInt()); + memrefLoadValue = memref::LoadOp::create( + builder, extractOp.getLoc(), memref, ValueRange{indexValue}); } if (memrefLoadValue) { qextractToMemrefMap[extractOp.getResult()] = memrefLoadValue; diff --git a/mlir/lib/MBQC/Transforms/ConversionPatterns.cpp b/mlir/lib/MBQC/Transforms/ConversionPatterns.cpp index 95123acf22..91c1f91782 100644 --- a/mlir/lib/MBQC/Transforms/ConversionPatterns.cpp +++ b/mlir/lib/MBQC/Transforms/ConversionPatterns.cpp @@ -61,18 +61,19 @@ struct MeasureInBasisOpPattern : public OpConversionPattern { // Extract the integer value for the plane attribute from its enum const auto planeValueInt = static_cast(op.getPlane()); Value planeValue = - rewriter.create(loc, rewriter.getI32IntegerAttr(planeValueInt)); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(planeValueInt)); // Create the postselect value. If not given, it defaults to NO_POSTSELECT - LLVM::ConstantOp postselect = rewriter.create( - loc, op.getPostselect() ? op.getPostselectAttr() - : rewriter.getI32IntegerAttr(NO_POSTSELECT)); + LLVM::ConstantOp postselect = LLVM::ConstantOp::create( + rewriter, loc, + op.getPostselect() ? op.getPostselectAttr() + : rewriter.getI32IntegerAttr(NO_POSTSELECT)); // Add values as arguments of the CallOp SmallVector args = {adaptor.getInQubit(), planeValue, op.getAngle(), postselect}; - Value resultPtr = rewriter.create(loc, fnDecl, args).getResult(); - Value mres = rewriter.create(loc, IntegerType::get(ctx, 1), resultPtr); + Value resultPtr = LLVM::CallOp::create(rewriter, loc, fnDecl, args).getResult(); + Value mres = LLVM::LoadOp::create(rewriter, loc, IntegerType::get(ctx, 1), resultPtr); rewriter.replaceOp(op, {mres, adaptor.getInQubit()}); return success(); diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 7d8d06bc56..2463b6d087 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -63,7 +63,7 @@ func::FuncOp createZneFunc(func::FuncOp funcOp, PatternRewriter &rewriter) /*outputs=*/funcOp.getResultTypes()); std::string fnFoldedName = funcOp.getName().str() + ".zne"; rewriter.setInsertionPointToStart(funcOp->getParentOfType().getBody()); - auto fnFoldedOp = rewriter.create(loc, fnFoldedName, fnFoldedType); + auto fnFoldedOp = func::FuncOp::create(rewriter, loc, fnFoldedName, fnFoldedType); rewriter.cloneRegionBefore(funcOp.getBody(), fnFoldedOp.getBody(), fnFoldedOp.end()); @@ -156,71 +156,68 @@ LogicalResult ZneLowering::matchAndRewrite(mitigation::ZneOp op, PatternRewriter RankedTensorType resultType = cast(op.getResultTypes().front()); // Loop over the num fold to create a folded circuit per factor - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - Value size = rewriter.create(loc, sizeInt); + Value c0 = index::ConstantOp::create(rewriter, loc, 0); + Value c1 = index::ConstantOp::create(rewriter, loc, 1); + Value size = index::ConstantOp::create(rewriter, loc, sizeInt); // Initialize the results as empty tensor Value results = - rewriter.create(loc, resultType.getShape(), resultType.getElementType()); + tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), resultType.getElementType()); Value resultValues = - rewriter - .create( - loc, c0, size, c1, /*iterArgsInit=*/results, - [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { - std::vector newArgs(op.getArgs().begin(), op.getArgs().end()); - SmallVector index = {i}; - Value numFold = builder.create(loc, numFolds, index); - Value numFoldCasted = - builder.create(loc, builder.getIndexType(), numFold); - newArgs.push_back(numFoldCasted); - func::CallOp callOp = builder.create(loc, fnFoldedOp, newArgs); - - int64_t numResults = callOp.getNumResults(); - - // Measurements - ValueRange resultValuesMulti = callOp.getResults(); - SmallVector vectorResultsMulti; - // Create a tensor - for (Value resultValue : resultValuesMulti) { - Value resultExtracted; - if (isa(resultValue.getType())) { - resultExtracted = builder.create(loc, resultValue); - } - else { - resultExtracted = resultValue; - } - vectorResultsMulti.push_back(resultExtracted); + scf::ForOp::create( + rewriter, loc, c0, size, c1, /*iterArgsInit=*/results, + [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { + std::vector newArgs(op.getArgs().begin(), op.getArgs().end()); + SmallVector index = {i}; + Value numFold = tensor::ExtractOp::create(builder, loc, numFolds, index); + Value numFoldCasted = + index::CastSOp::create(builder, loc, builder.getIndexType(), numFold); + newArgs.push_back(numFoldCasted); + func::CallOp callOp = func::CallOp::create(builder, loc, fnFoldedOp, newArgs); + + int64_t numResults = callOp.getNumResults(); + + // Measurements + ValueRange resultValuesMulti = callOp.getResults(); + SmallVector vectorResultsMulti; + // Create a tensor + for (Value resultValue : resultValuesMulti) { + Value resultExtracted; + if (isa(resultValue.getType())) { + resultExtracted = tensor::ExtractOp::create(builder, loc, resultValue); } - SmallVector resShape = {numResults}; - Type type = RankedTensorType::get(resShape, vectorResultsMulti[0].getType()); - auto tensorResults = - builder.create(loc, type, vectorResultsMulti); - Value sizeResultsValue = rewriter.create(loc, numResults); - Value resultValuesFor = - rewriter - .create( - loc, c0, sizeResultsValue, c1, - /*iterArgsInit=*/iterArgs.front(), - [&](OpBuilder &builder, Location loc, Value j, - ValueRange iterArgsIn) { - Value resultExtracted = - builder.create(loc, tensorResults, j); - SmallVector indices; - if (numResults == 1) { - indices = {i}; - } - else { - indices = {i, j}; - } - Value resultInserted = builder.create( - loc, resultExtracted, iterArgsIn.front(), indices); - - builder.create(loc, resultInserted); - }) - .getResult(0); - builder.create(loc, resultValuesFor); - }) + else { + resultExtracted = resultValue; + } + vectorResultsMulti.push_back(resultExtracted); + } + SmallVector resShape = {numResults}; + Type type = RankedTensorType::get(resShape, vectorResultsMulti[0].getType()); + auto tensorResults = + tensor::FromElementsOp::create(builder, loc, type, vectorResultsMulti); + Value sizeResultsValue = index::ConstantOp::create(rewriter, loc, numResults); + Value resultValuesFor = + scf::ForOp::create( + rewriter, loc, c0, sizeResultsValue, c1, + /*iterArgsInit=*/iterArgs.front(), + [&](OpBuilder &builder, Location loc, Value j, ValueRange iterArgsIn) { + Value resultExtracted = + tensor::ExtractOp::create(builder, loc, tensorResults, j); + SmallVector indices; + if (numResults == 1) { + indices = {i}; + } + else { + indices = {i, j}; + } + Value resultInserted = tensor::InsertOp::create( + builder, loc, resultExtracted, iterArgsIn.front(), indices); + + scf::YieldOp::create(builder, loc, resultInserted); + }) + .getResult(0); + scf::YieldOp::create(builder, loc, resultValuesFor); + }) .getResult(0); // Replace the original results rewriter.replaceOp(op, resultValues); @@ -242,10 +239,10 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::st rewriter.setInsertionPointToStart(fnFoldedOp.addEntryBlock()); // Loop control variables - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); + Value c0 = index::ConstantOp::create(rewriter, loc, 0); + Value c1 = index::ConstantOp::create(rewriter, loc, 1); TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); - Value numberQubitsValue = rewriter.create(loc, numberQubitsAttr); + Value numberQubitsValue = arith::ConstantOp::create(rewriter, loc, numberQubitsAttr); // TODO: in the frontend, calculation of shots will happen outside of the qnode, // before qml.device(..., shots = ) is called, @@ -255,58 +252,57 @@ FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::st Operation *shotsLocal = shots->clone(); rewriter.insert(shotsLocal); - rewriter.create(loc, shotsLocal->getResult(0), lib, name, kwargs); + quantum::DeviceInitOp::create(rewriter, loc, shotsLocal->getResult(0), lib, name, kwargs); - Value allocQreg = rewriter.create(loc, fnAllocOp, numberQubitsValue).getResult(0); + Value allocQreg = + func::CallOp::create(rewriter, loc, fnAllocOp, numberQubitsValue).getResult(0); int64_t sizeArgs = fnFoldedOp.getArguments().size(); Value size = fnFoldedOp.getArgument(sizeArgs - 1); // Add scf for loop to create the folding Value loopedQreg = - rewriter - .create( - loc, c0, size, c1, /*iterArgsInit=*/allocQreg, - [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { - Value qreg = iterArgs.front(); - std::vector argsAndQreg(fnFoldedOp.getArguments().begin(), - fnFoldedOp.getArguments().end()); - argsAndQreg.pop_back(); - argsAndQreg.push_back(qreg); - - // Call the function without measurements - Value fnWithoutMeasurementsQreg = - builder.create(loc, fnWithoutMeasurementsOp, argsAndQreg) - .getResult(0); - - // Call the function without measurements in an adjoint region - auto adjointOp = builder.create(loc, qregType, - fnWithoutMeasurementsQreg); - Region *adjointRegion = &adjointOp.getRegion(); - Block *adjointBlock = builder.createBlock(adjointRegion, {}, qregType, loc); - - std::vector argsAndQregAdjoint(fnFoldedOp.getArguments().begin(), - fnFoldedOp.getArguments().end()); - argsAndQregAdjoint.pop_back(); - argsAndQregAdjoint.push_back(adjointBlock->getArgument(0)); - Value fnWithoutMeasurementsAdjointQreg = - builder - .create(loc, fnWithoutMeasurementsOp, argsAndQregAdjoint) - .getResult(0); - builder.create(loc, fnWithoutMeasurementsAdjointQreg); - builder.setInsertionPointAfter(adjointOp); - builder.create(loc, adjointOp.getResult()); - }) + scf::ForOp::create( + rewriter, loc, c0, size, c1, /*iterArgsInit=*/allocQreg, + [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { + Value qreg = iterArgs.front(); + std::vector argsAndQreg(fnFoldedOp.getArguments().begin(), + fnFoldedOp.getArguments().end()); + argsAndQreg.pop_back(); + argsAndQreg.push_back(qreg); + + // Call the function without measurements + Value fnWithoutMeasurementsQreg = + func::CallOp::create(builder, loc, fnWithoutMeasurementsOp, argsAndQreg) + .getResult(0); + + // Call the function without measurements in an adjoint region + auto adjointOp = + quantum::AdjointOp::create(builder, loc, qregType, fnWithoutMeasurementsQreg); + Region *adjointRegion = &adjointOp.getRegion(); + Block *adjointBlock = builder.createBlock(adjointRegion, {}, qregType, loc); + + std::vector argsAndQregAdjoint(fnFoldedOp.getArguments().begin(), + fnFoldedOp.getArguments().end()); + argsAndQregAdjoint.pop_back(); + argsAndQregAdjoint.push_back(adjointBlock->getArgument(0)); + Value fnWithoutMeasurementsAdjointQreg = + func::CallOp::create(builder, loc, fnWithoutMeasurementsOp, argsAndQregAdjoint) + .getResult(0); + quantum::YieldOp::create(builder, loc, fnWithoutMeasurementsAdjointQreg); + builder.setInsertionPointAfter(adjointOp); + scf::YieldOp::create(builder, loc, adjointOp.getResult()); + }) .getResult(0); std::vector argsAndRegMeasurement(fnFoldedOp.getArguments().begin(), fnFoldedOp.getArguments().end()); argsAndRegMeasurement.pop_back(); argsAndRegMeasurement.push_back(loopedQreg); ValueRange funcFolded = - rewriter.create(loc, fnWithMeasurementsOp, argsAndRegMeasurement) + func::CallOp::create(rewriter, loc, fnWithMeasurementsOp, argsAndRegMeasurement) .getResults(); // Remove device - rewriter.create(loc); - rewriter.create(loc, funcFolded); + quantum::DeviceReleaseOp::create(rewriter, loc); + func::ReturnOp::create(rewriter, loc, funcFolded); return SymbolRefAttr::get(rewriter.getContext(), fnFoldedName); } // In *.cpp module only, to keep extraneous headers out of *.hpp @@ -335,25 +331,23 @@ FlatSymbolRefAttr allLocalFolding(PatternRewriter &rewriter, std::string fnFolde // Insert a for loop immediately before each quantum::QuantumGate const auto forVal = - rewriter - .create( - loc, c0, size, c1, /*iterArgsInit=*/opQubitArgs, - [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { - // Create adjoint and original operations - quantum::QuantumGate origOp = - dyn_cast(builder.clone(*op)); - origOp.setQubitOperands(iterArgs); - auto origOpVal = origOp->getResults(); - - quantum::QuantumGate adjointOp = - dyn_cast(builder.clone(*origOp)); - adjointOp.setQubitOperands(origOpVal); - adjointOp.setAdjointFlag(!adjointOp.getAdjointFlag()); - auto adjointOpVal = adjointOp->getResults(); - - // Yield the qubits. - builder.create(loc, adjointOpVal); - }) + scf::ForOp::create(rewriter, loc, c0, size, c1, /*iterArgsInit=*/opQubitArgs, + [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { + // Create adjoint and original operations + quantum::QuantumGate origOp = + dyn_cast(builder.clone(*op)); + origOp.setQubitOperands(iterArgs); + auto origOpVal = origOp->getResults(); + + quantum::QuantumGate adjointOp = + dyn_cast(builder.clone(*origOp)); + adjointOp.setQubitOperands(origOpVal); + adjointOp.setAdjointFlag(!adjointOp.getAdjointFlag()); + auto adjointOpVal = adjointOp->getResults(); + + // Yield the qubits. + scf::YieldOp::create(builder, loc, adjointOpVal); + }) .getResults(); op.setQubitOperands(forVal); @@ -405,7 +399,7 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew typesFolded, /*outputs=*/fnOp.getResultTypes()); - func::FuncOp fnFoldedOp = rewriter.create(loc, fnFoldedName, fnFoldedType); + func::FuncOp fnFoldedOp = func::FuncOp::create(rewriter, loc, fnFoldedName, fnFoldedType); fnFoldedOp.setPrivate(); if (foldingAlgorithm == Folding(1)) { // Quantum Alloc function @@ -436,8 +430,8 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew Block *fnFoldedOpBlock = &fnFoldedOp.getBody().front(); rewriter.setInsertionPointToStart(fnFoldedOpBlock); // Loop control variables - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); + Value c0 = index::ConstantOp::create(rewriter, loc, 0); + Value c1 = index::ConstantOp::create(rewriter, loc, 1); fnFoldedOpBlock->addArgument(fnFoldedOp.getArgumentTypes().back(), loc); @@ -465,14 +459,14 @@ FlatSymbolRefAttr ZneLowering::getOrInsertQuantumAlloc(Location loc, PatternRewr FunctionType fnAllocType = FunctionType::get(ctx, /*inputs=*/ i64Type, /*outputs=*/qregType); - func::FuncOp fnAlloc = rewriter.create(loc, fnAllocName, fnAllocType); + func::FuncOp fnAlloc = func::FuncOp::create(rewriter, loc, fnAllocName, fnAllocType); fnAlloc.setPrivate(); Block *allocBloc = fnAlloc.addEntryBlock(); rewriter.setInsertionPointToStart(allocBloc); Value nQubits = allocBloc->getArgument(0); IntegerAttr intAttr{}; - auto qreg = rewriter.create(loc, qregType, nQubits, intAttr); - rewriter.create(loc, qreg.getResult()); + auto qreg = quantum::AllocOp::create(rewriter, loc, qregType, nQubits, intAttr); + func::ReturnOp::create(rewriter, loc, qreg.getResult()); return SymbolRefAttr::get(ctx, fnAllocName); } FlatSymbolRefAttr ZneLowering::getOrInsertFnWithoutMeasurements(Location loc, @@ -497,7 +491,7 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFnWithoutMeasurements(Location loc, typesWithoutMeasurements, /*outputs=*/qregType); func::FuncOp fnWithoutMeasurementsOp = - rewriter.create(loc, fnWithoutMeasurementsName, fnWithoutMeasurementsType); + func::FuncOp::create(rewriter, loc, fnWithoutMeasurementsName, fnWithoutMeasurementsType); fnWithoutMeasurementsOp.setPrivate(); rewriter.cloneRegionBefore(fnOp.getBody(), fnWithoutMeasurementsOp.getBody(), fnWithoutMeasurementsOp.end()); @@ -551,7 +545,7 @@ ZneLowering::getOrInsertFnWithMeasurements(Location loc, PatternRewriter &rewrit typesWithQreg, /*outputs=*/fnOp.getResultTypes()); func::FuncOp fnWithMeasurementsOp = - rewriter.create(loc, fnWithMeasurementsName, fnWithMeasurementsType); + func::FuncOp::create(rewriter, loc, fnWithMeasurementsName, fnWithMeasurementsType); fnWithMeasurementsOp.setPrivate(); rewriter.cloneRegionBefore(fnOp.getBody(), fnWithMeasurementsOp.getBody(), fnWithMeasurementsOp.end()); diff --git a/mlir/lib/PauliFrame/Transforms/CliffordTToPauliFramePatterns.cpp b/mlir/lib/PauliFrame/Transforms/CliffordTToPauliFramePatterns.cpp index 4aaebf532c..5d137c4eb1 100644 --- a/mlir/lib/PauliFrame/Transforms/CliffordTToPauliFramePatterns.cpp +++ b/mlir/lib/PauliFrame/Transforms/CliffordTToPauliFramePatterns.cpp @@ -69,11 +69,11 @@ enum class GateEnum { I, X, Y, Z, H, S, T, CNOT, Unknown }; GateEnum hashGate(CustomOp op) { return llvm::StringSwitch(op.getGateName()) - .Cases("Identity", "I", GateEnum::I) - .Cases("PauliX", "X", GateEnum::X) - .Cases("PauliY", "Y", GateEnum::Y) - .Cases("PauliZ", "Z", GateEnum::Z) - .Cases("H", "Hadamard", GateEnum::H) + .Cases({"Identity", "I"}, GateEnum::I) + .Cases({"PauliX", "X"}, GateEnum::X) + .Cases({"PauliY", "Y"}, GateEnum::Y) + .Cases({"PauliZ", "Z"}, GateEnum::Z) + .Cases({"H", "Hadamard"}, GateEnum::H) .Case("S", GateEnum::S) .Case("T", GateEnum::T) .Case("CNOT", GateEnum::CNOT) @@ -84,26 +84,26 @@ GateEnum hashGate(CustomOp op) // Applies the gates in the order X -> Z and returns the output qubit of the Z gate. OpResult insertPauliOpsAfterFlush(PatternRewriter &rewriter, Location loc, FlushOp flushOp) { - auto pauliXIfOp = rewriter.create( - loc, flushOp.getXParity(), + auto pauliXIfOp = scf::IfOp::create( + rewriter, loc, flushOp.getXParity(), [&](OpBuilder &builder, Location loc) { // then - auto pauliX = rewriter.create(loc, "X", flushOp.getOutQubit()); - builder.create(loc, pauliX.getOutQubits()); + auto pauliX = CustomOp::create(rewriter, loc, "X", flushOp.getOutQubit()); + scf::YieldOp::create(builder, loc, pauliX.getOutQubits()); }, [&](OpBuilder &builder, Location loc) { // else - builder.create(loc, flushOp.getOutQubit()); + scf::YieldOp::create(builder, loc, flushOp.getOutQubit()); }); auto pauliXOutQubit = pauliXIfOp->getResult(0); - auto pauliZIfOp = rewriter.create( - loc, flushOp.getZParity(), + auto pauliZIfOp = scf::IfOp::create( + rewriter, loc, flushOp.getZParity(), [&](OpBuilder &builder, Location loc) { // then - auto pauliZ = rewriter.create(loc, "Z", pauliXOutQubit); - builder.create(loc, pauliZ.getOutQubits()); + auto pauliZ = CustomOp::create(rewriter, loc, "Z", pauliXOutQubit); + scf::YieldOp::create(builder, loc, pauliZ.getOutQubits()); }, [&](OpBuilder &builder, Location loc) { // else - builder.create(loc, pauliXOutQubit); + scf::YieldOp::create(builder, loc, pauliXOutQubit); }); return pauliZIfOp->getResult(0); @@ -137,8 +137,8 @@ LogicalResult convertPauliGate(CustomOp op, PatternRewriter &rewriter, bool x_pa auto inQubits = op.getInQubits(); UpdateOp updateOp = - rewriter.create(loc, outQubitTypes, rewriter.getBoolAttr(x_parity), - rewriter.getBoolAttr(z_parity), inQubits); + UpdateOp::create(rewriter, loc, outQubitTypes, rewriter.getBoolAttr(x_parity), + rewriter.getBoolAttr(z_parity), inQubits); rewriter.replaceOp(op, updateOp.getOutQubits()); return success(); @@ -172,7 +172,7 @@ LogicalResult convertCliffordGate(CustomOp op, PatternRewriter &rewriter, Cliffo auto inQubits = op.getInQubits(); UpdateWithCliffordOp updateOp = - rewriter.create(loc, outQubitTypes, gate, inQubits); + UpdateWithCliffordOp::create(rewriter, loc, outQubitTypes, gate, inQubits); rewriter.modifyOpInPlace(op, [&] { op->setOperands(updateOp->getResults()); }); return success(); @@ -219,8 +219,8 @@ LogicalResult convertNonCliffordGate(CustomOp op, PatternRewriter &rewriter) auto outQubitType = outQubitTypes[0]; auto inQubits = op.getInQubits(); - FlushOp flushOp = rewriter.create(loc, rewriter.getI1Type(), rewriter.getI1Type(), - outQubitType, inQubits[0]); + FlushOp flushOp = FlushOp::create(rewriter, loc, rewriter.getI1Type(), rewriter.getI1Type(), + outQubitType, inQubits[0]); auto pauliZOutQubit = insertPauliOpsAfterFlush(rewriter, loc, flushOp); @@ -297,7 +297,7 @@ struct InitPauliRecordQbitPattern : public OpRewritePattern { LLVM_DEBUG(llvm::dbgs() << "Initializing Pauli record of qubit: " << qubit << "\n"); rewriter.setInsertionPointAfter(op); - InitOp initOp = rewriter.create(loc, qubit.getType(), qubit); + InitOp initOp = InitOp::create(rewriter, loc, qubit.getType(), qubit); qubit.replaceAllUsesExcept(initOp.getOutQubits()[0], initOp); return success(); @@ -330,7 +330,7 @@ struct InitPauliRecordQregPattern : public OpRewritePattern { << "\n"); rewriter.setInsertionPointAfter(op); - InitQregOp initQregOp = rewriter.create(loc, qreg.getType(), qreg); + InitQregOp initQregOp = InitQregOp::create(rewriter, loc, qreg.getType(), qreg); qreg.replaceAllUsesExcept(initQregOp.getOutQreg(), initQregOp); return success(); @@ -365,8 +365,8 @@ struct CorrectMeasurementPattern : public OpRewritePattern { << outQubit << "\n"); rewriter.setInsertionPointAfter(op); - CorrectMeasurementOp correctMeasOp = rewriter.create( - loc, mres.getType(), outQubit.getType(), mres, outQubit); + CorrectMeasurementOp correctMeasOp = CorrectMeasurementOp::create( + rewriter, loc, mres.getType(), outQubit.getType(), mres, outQubit); mres.replaceAllUsesExcept(correctMeasOp.getOutMres(), correctMeasOp); outQubit.replaceAllUsesExcept(correctMeasOp.getOutQubit(), correctMeasOp); @@ -452,8 +452,8 @@ struct FlushBeforeMeasurementProcessPattern : public OpRewritePattern flushOp = getFlushOpAppliedToQubit(qubit); if (!flushOp) { - auto flushOp = rewriter.create( - loc, rewriter.getI1Type(), rewriter.getI1Type(), qubit.getType(), qubit); + auto flushOp = FlushOp::create(rewriter, loc, rewriter.getI1Type(), + rewriter.getI1Type(), qubit.getType(), qubit); auto pauliZOutQubit = insertPauliOpsAfterFlush(rewriter, loc, flushOp); rewriter.modifyOpInPlace(obsOp, [&] { obsOp->setOperand(idx, pauliZOutQubit); }); } diff --git a/mlir/lib/QEC/Transforms/CommutePPR.cpp b/mlir/lib/QEC/Transforms/CommutePPR.cpp index 9bd20375fb..0ca896917c 100644 --- a/mlir/lib/QEC/Transforms/CommutePPR.cpp +++ b/mlir/lib/QEC/Transforms/CommutePPR.cpp @@ -126,8 +126,8 @@ void moveCliffordPastNonClifford(const PauliStringWrapper &lhsPauli, // Create the new PPR auto nonCliffordOp = - rewriter.create(rhs->getLoc(), newOutQubitsTypesList, pauliProduct, - rhs.getRotationKindAttr(), newRHSOperands); + PPRotationOp::create(rewriter, rhs->getLoc(), newOutQubitsTypesList, pauliProduct, + rhs.getRotationKindAttr(), newRHSOperands); rewriter.moveOpBefore(nonCliffordOp, rhs); // Update the use of value in newRHSOperands diff --git a/mlir/lib/QEC/Transforms/DecomposeArbitraryPPR.cpp b/mlir/lib/QEC/Transforms/DecomposeArbitraryPPR.cpp index 72c71d3247..f27529714e 100644 --- a/mlir/lib/QEC/Transforms/DecomposeArbitraryPPR.cpp +++ b/mlir/lib/QEC/Transforms/DecomposeArbitraryPPR.cpp @@ -51,9 +51,9 @@ LogicalResult convertArbitraryPPRToArbitraryZ(PPRotationArbitraryOp &op, Pattern auto loc = op.getLoc(); /// |+⟩── - auto allocateQubit = rewriter.create(loc); + auto allocateQubit = AllocQubitOp::create(rewriter, loc); auto plus = LogicalInitKind::plus; - auto plusQubit = rewriter.create(loc, plus, allocateQubit.getOutQubit()); + auto plusQubit = PrepareStateOp::create(rewriter, loc, plus, allocateQubit.getOutQubit()); // ┌───┐── // | P |── @@ -64,7 +64,7 @@ LogicalResult convertArbitraryPPRToArbitraryZ(PPRotationArbitraryOp &op, Pattern PZ.emplace_back("Z"); SmallVector inQubits = op.getInQubits(); inQubits.emplace_back(plusQubit.getOutQubits().front()); - auto ppmPZ = rewriter.create(loc, PZ, inQubits); + auto ppmPZ = PPMeasurementOp::create(rewriter, loc, PZ, inQubits); // ════╗ // ┌───╩───┐ @@ -73,19 +73,19 @@ LogicalResult convertArbitraryPPRToArbitraryZ(PPRotationArbitraryOp &op, Pattern SmallVector X = {"X"}; const uint16_t PI2 = 2; // For rotation of P(PI/2) auto inQubit = ppmPZ.getOutQubits().back(); - auto pprX = rewriter.create(loc, X, PI2, inQubit, ppmPZ.getMres()); + auto pprX = PPRotationOp::create(rewriter, loc, X, PI2, inQubit, ppmPZ.getMres()); // ┌───────┐ // | Z(phi)|── // └───────┘ SmallVector Z = {"Z"}; auto phi = op.getArbitraryAngle(); - auto pprZ = rewriter.create(loc, Z, phi, pprX.getOutQubits()); + auto pprZ = PPRotationArbitraryOp::create(rewriter, loc, Z, phi, pprX.getOutQubits()); // ┌─╩─┐ // | X |── // └───┘ - auto ppmX = rewriter.create(loc, X, pprZ.getOutQubits()); + auto ppmX = PPMeasurementOp::create(rewriter, loc, X, pprZ.getOutQubits()); // ┌───────┐── // | P(π/2)|── @@ -94,12 +94,12 @@ LogicalResult convertArbitraryPPRToArbitraryZ(PPRotationArbitraryOp &op, Pattern SmallVector outPZQubits = ppmPZ.getOutQubits(); outPZQubits.pop_back(); auto P = op.getPauliProduct(); - auto pprP = rewriter.create(loc, P, PI2, outPZQubits, ppmX.getMres()); + auto pprP = PPRotationOp::create(rewriter, loc, P, PI2, outPZQubits, ppmX.getMres()); rewriter.replaceOp(op, pprP.getOutQubits()); // Deallocate the axillary qubits |+⟩ - rewriter.create(loc, ppmX.getOutQubits().back()); + DeallocQubitOp::create(rewriter, loc, ppmX.getOutQubits().back()); return success(); } diff --git a/mlir/lib/QEC/Transforms/DecomposeCliffordPPR.cpp b/mlir/lib/QEC/Transforms/DecomposeCliffordPPR.cpp index a21cd4c49f..99dcab06e5 100644 --- a/mlir/lib/QEC/Transforms/DecomposeCliffordPPR.cpp +++ b/mlir/lib/QEC/Transforms/DecomposeCliffordPPR.cpp @@ -75,15 +75,15 @@ PPRotationOp decompose_pi_over_four_flattening(bool avoidPauliYMeasure, PPRotati } auto ppmPZ = - rewriter.create(loc, pauliP, rotationSign, m1InQubits, measResult); + PPMeasurementOp::create(rewriter, loc, pauliP, rotationSign, m1InQubits, measResult); SmallVector pauliX = {"X"}; auto ppmX = - rewriter.create(loc, pauliX, ppmPZ.getOutQubits().back(), measResult); + PPMeasurementOp::create(rewriter, loc, pauliX, ppmPZ.getOutQubits().back(), measResult); // FIXME: Check global phase on this decomposition - auto cond = rewriter.create(loc, ppmPZ.getMres(), ppmX.getMres()); + auto cond = arith::XOrIOp::create(rewriter, loc, ppmPZ.getMres(), ppmX.getMres()); SmallVector outPZQubits = ppmPZ.getOutQubits(); outPZQubits.pop_back(); @@ -91,10 +91,10 @@ PPRotationOp decompose_pi_over_four_flattening(bool avoidPauliYMeasure, PPRotati const uint16_t PI_DENOMINATOR = 2; // For rotation of P(PI/2) auto pprPI2 = - rewriter.create(loc, pauliP, PI_DENOMINATOR, outPZQubits, cond.getResult()); + PPRotationOp::create(rewriter, loc, pauliP, PI_DENOMINATOR, outPZQubits, cond.getResult()); // Deallocate the axillary qubit - rewriter.create(loc, ppmX.getOutQubits().back()); + DeallocQubitOp::create(rewriter, loc, ppmX.getOutQubits().back()); rewriter.replaceOp(op, pprPI2.getOutQubits()); return pprPI2; diff --git a/mlir/lib/QEC/Transforms/DecomposeNonCliffordPPR.cpp b/mlir/lib/QEC/Transforms/DecomposeNonCliffordPPR.cpp index 573a14cc8b..3d42344e88 100644 --- a/mlir/lib/QEC/Transforms/DecomposeNonCliffordPPR.cpp +++ b/mlir/lib/QEC/Transforms/DecomposeNonCliffordPPR.cpp @@ -76,7 +76,7 @@ void decomposePauliCorrectedPiOverEight(bool avoidPauliYMeasure, PPRotationOp op { auto loc = op.getLoc(); // We always initialize the magic state here, not the conjugate. - auto magic = rewriter.create(loc, LogicalInitKind::magic); + auto magic = FabricateOp::create(rewriter, loc, LogicalInitKind::magic); SmallVector pauliP = extractPauliString(op); // [P] SmallVector inQubits = op.getInQubits(); // [input qubits] @@ -91,61 +91,61 @@ void decomposePauliCorrectedPiOverEight(bool avoidPauliYMeasure, PPRotationOp op if (rotationKind < 0) { rotationSign = -1; } - auto ppmPZ = rewriter.create(loc, extendedPauliP, rotationSign, inQubits); + auto ppmPZ = PPMeasurementOp::create(rewriter, loc, extendedPauliP, rotationSign, inQubits); auto ppmPZRes = ppmPZ.getMres(); if (avoidPauliYMeasure) { auto YBuilder = [&](OpBuilder &builder, Location loc) { // Initialize |Y⟩ state - auto yQubit = rewriter.create(loc, LogicalInitKind::plus_i); + auto yQubit = FabricateOp::create(rewriter, loc, LogicalInitKind::plus_i); // PPM (Z⊗Y) on qubits |m⟩ and |Y⟩ SmallVector axillaryQubits = {ppmPZ.getOutQubits().back(), yQubit.getOutQubits().back()}; SmallVector pauliZZ = {"Z", "Z"}; // [Z, Z] - auto ppmZZ = rewriter.create(loc, pauliZZ, axillaryQubits); + auto ppmZZ = PPMeasurementOp::create(rewriter, loc, pauliZZ, axillaryQubits); SmallVector pauliXX = {"X", "X"}; // [X, X] - auto ppmXX = rewriter.create(loc, pauliXX, ppmZZ.getOutQubits()); + auto ppmXX = PPMeasurementOp::create(rewriter, loc, pauliXX, ppmZZ.getOutQubits()); SmallVector outPZQubits = ppmPZ.getOutQubits(); // [input qubits, |m⟩] outPZQubits.pop_back(); // [input qubits] auto pprPI2 = - rewriter.create(loc, pauliP, 2, outPZQubits, ppmXX.getMres()); + PPRotationOp::create(rewriter, loc, pauliP, 2, outPZQubits, ppmXX.getMres()); for (auto q : axillaryQubits) - rewriter.create(loc, q); - rewriter.create(loc, pprPI2.getOutQubits()); + DeallocQubitOp::create(rewriter, loc, q); + scf::YieldOp::create(rewriter, loc, pprPI2.getOutQubits()); }; auto XBuilder = [&](OpBuilder &builder, Location loc) { // PPM (X) on qubit |m⟩ SmallVector pauliX = {"X"}; - auto ppmX = rewriter.create(loc, pauliX, ppmPZ.getOutQubits().back()); + auto ppmX = PPMeasurementOp::create(rewriter, loc, pauliX, ppmPZ.getOutQubits().back()); SmallVector outPZQubits = ppmPZ.getOutQubits(); // [input qubits, |m⟩] outPZQubits.pop_back(); // [input qubits] auto pprPI2 = - rewriter.create(loc, pauliP, 2, outPZQubits, ppmX.getMres()); - rewriter.create(loc, - ppmX.getOutQubits().back()); // Deallocate |m⟩ qubit - rewriter.create(loc, pprPI2.getOutQubits()); + PPRotationOp::create(rewriter, loc, pauliP, 2, outPZQubits, ppmX.getMres()); + DeallocQubitOp::create(rewriter, loc, + ppmX.getOutQubits().back()); // Deallocate |m⟩ qubit + scf::YieldOp::create(rewriter, loc, pprPI2.getOutQubits()); }; scf::IfOp ifOp; if (rotationKind > 0) { - ifOp = rewriter.create(loc, ppmPZRes, YBuilder, XBuilder); + ifOp = scf::IfOp::create(rewriter, loc, ppmPZRes, YBuilder, XBuilder); } else { - ifOp = rewriter.create(loc, ppmPZRes, XBuilder, YBuilder); + ifOp = scf::IfOp::create(rewriter, loc, ppmPZRes, XBuilder, YBuilder); } rewriter.replaceOp(op, ifOp); } else { SmallVector pauliX = {"X"}; SmallVector pauliY = {"Y"}; - auto ppmXY = rewriter.create(loc, ppmPZRes, pauliY, pauliX, - ppmPZ.getOutQubits().back()); + auto ppmXY = SelectPPMeasurementOp::create(rewriter, loc, ppmPZRes, pauliY, pauliX, + ppmPZ.getOutQubits().back()); // PPR P(π/2) on input qubits if PPM (X or Y) yields -1 SmallVector outPZQubits = ppmPZ.getOutQubits(); // [input qubits, |m⟩] outPZQubits.pop_back(); // [input qubits] - auto pprPI2 = rewriter.create(loc, pauliP, 2, outPZQubits, ppmXY.getMres()); - rewriter.create(loc, ppmXY.getOutQubits().back()); + auto pprPI2 = PPRotationOp::create(rewriter, loc, pauliP, 2, outPZQubits, ppmXY.getMres()); + DeallocQubitOp::create(rewriter, loc, ppmXY.getOutQubits().back()); rewriter.replaceOp(op, pprPI2.getOutQubits()); } } @@ -202,7 +202,7 @@ void decomposeAutoCorrectedPiOverEight(bool avoidPauliYMeasure, PPRotationOp op, // Initialize |0⟩ (zero) or Fabricate |Y⟩ (plus_i) auto axillaryQubit = initializeZeroOrPlusI(avoidPauliYMeasure, loc, rewriter); - auto magic = rewriter.create(loc, getMagicState(op)); + auto magic = FabricateOp::create(rewriter, loc, getMagicState(op)); auto [pauliForAxillaryQubit, rotationSign] = determinePauliAndRotationSignOfMeasurement(avoidPauliYMeasure); @@ -215,34 +215,34 @@ void decomposeAutoCorrectedPiOverEight(bool avoidPauliYMeasure, PPRotationOp op, SmallVector extPauliP = pauliP; extPauliP.emplace_back("Z"); // extend Z for the axillary qubit inQubits.emplace_back(magic.getOutQubits()[0]); // [input qubits, |m⟩] - auto ppmPZ = rewriter.create(loc, extPauliP, inQubits); // [input qubits, |m⟩] + auto ppmPZ = PPMeasurementOp::create(rewriter, loc, extPauliP, inQubits); // [input qubits, |m⟩] // PPM (Z⊗Y/0) on qubits |m⟩ and |Y⟩or|0⟩ SmallVector axillaryQubits = {ppmPZ.getOutQubits().back(), axillaryQubit}; SmallVector pauliZY = {"Z", pauliForAxillaryQubit}; // [Z, Y/Z] - auto ppmZY = rewriter.create(loc, pauliZY, rotationSign, axillaryQubits, - nullptr); // [|m⟩, |Y⟩/|0⟩] + auto ppmZY = PPMeasurementOp::create(rewriter, loc, pauliZY, rotationSign, axillaryQubits, + nullptr); // [|m⟩, |Y⟩/|0⟩] // PPM (X) on qubit |m⟩ SmallVector pauliX = {"X"}; - auto ppmX = rewriter.create(loc, pauliX, ppmZY.getOutQubits().front()); // |m⟩ + auto ppmX = PPMeasurementOp::create(rewriter, loc, pauliX, ppmZY.getOutQubits().front()); // |m⟩ // PPM (X/Z) based on the result of PPM (P⊗Z) on qubit |0⟩ SmallVector pauliZ = {"Z"}; - auto ppmXZ = rewriter.create(loc, ppmPZ.getMres(), pauliX, pauliZ, - ppmZY.getOutQubits().back()); // |0⟩ + auto ppmXZ = SelectPPMeasurementOp::create(rewriter, loc, ppmPZ.getMres(), pauliX, pauliZ, + ppmZY.getOutQubits().back()); // |0⟩ // XOR of the results of PPM (P⊗Z) and PPM (X) - auto condOp = rewriter.create(loc, ppmZY.getMres(), ppmX.getMres()); + auto condOp = arith::XOrIOp::create(rewriter, loc, ppmZY.getMres(), ppmX.getMres()); // PPR P(π/2) based on the result of XOR on input qubits SmallVector outPZQubits = ppmPZ.getOutQubits(); outPZQubits.pop_back(); - auto pprPI2 = rewriter.create(loc, pauliP, 2, outPZQubits, condOp.getResult()); + auto pprPI2 = PPRotationOp::create(rewriter, loc, pauliP, 2, outPZQubits, condOp.getResult()); // Deallocate the axillary qubits - rewriter.create(loc, ppmXZ.getOutQubits().back()); // |0⟩ - rewriter.create(loc, ppmX.getOutQubits().back()); // |m⟩ + DeallocQubitOp::create(rewriter, loc, ppmXZ.getOutQubits().back()); // |0⟩ + DeallocQubitOp::create(rewriter, loc, ppmX.getOutQubits().back()); // |m⟩ rewriter.replaceOp(op, pprPI2.getOutQubits()); } @@ -277,7 +277,7 @@ void decomposeInjectMagicStatePiOverEight(PPRotationOp op, PatternRewriter &rewr auto loc = op.getLoc(); // Fabricate the magic state |m⟩ - auto magic = rewriter.create(loc, getMagicState(op)); + auto magic = FabricateOp::create(rewriter, loc, getMagicState(op)); SmallVector pauliP = extractPauliString(op); // [P = n qubit] SmallVector inQubits = op.getInQubits(); // [input qubits] @@ -286,25 +286,25 @@ void decomposeInjectMagicStatePiOverEight(PPRotationOp op, PatternRewriter &rewr SmallVector extendedPauliP = pauliP; extendedPauliP.emplace_back("Z"); // extend Z for the axillary qubit -> [P, Z] inQubits.emplace_back(magic.getOutQubits()[0]); // [input qubits, |m⟩] - auto ppmPZ = rewriter.create(loc, extendedPauliP, inQubits); + auto ppmPZ = PPMeasurementOp::create(rewriter, loc, extendedPauliP, inQubits); // PPR P(π/4) on input qubits if PPM (P⊗Z) yields -1 SmallVector outPZQubits = ppmPZ.getOutQubits(); // [input qubits, |m⟩] outPZQubits.pop_back(); // [input qubits] const uint16_t PI_DENOMINATOR = 4; // For rotation of P(PI/4) auto pprPI4 = - rewriter.create(loc, pauliP, PI_DENOMINATOR, outPZQubits, ppmPZ.getMres()); + PPRotationOp::create(rewriter, loc, pauliP, PI_DENOMINATOR, outPZQubits, ppmPZ.getMres()); // PPM (X) on |m⟩ SmallVector pauliX = {"X"}; - auto ppmX = rewriter.create(loc, pauliX, ppmPZ.getOutQubits().back()); + auto ppmX = PPMeasurementOp::create(rewriter, loc, pauliX, ppmPZ.getOutQubits().back()); // PPR P(π/2) on input qubits if PPM (X) yields -1 auto pprPI2 = - rewriter.create(loc, pauliP, 2, pprPI4.getOutQubits(), ppmX.getMres()); + PPRotationOp::create(rewriter, loc, pauliP, 2, pprPI4.getOutQubits(), ppmX.getMres()); // Deallocate the axillary qubit - rewriter.create(loc, ppmX.getOutQubits().back()); + DeallocQubitOp::create(rewriter, loc, ppmX.getOutQubits().back()); rewriter.replaceOp(op, pprPI2.getOutQubits()); } diff --git a/mlir/lib/QEC/Transforms/LowerQECInitOps.cpp b/mlir/lib/QEC/Transforms/LowerQECInitOps.cpp index 124cc10df0..45937eda40 100644 --- a/mlir/lib/QEC/Transforms/LowerQECInitOps.cpp +++ b/mlir/lib/QEC/Transforms/LowerQECInitOps.cpp @@ -30,15 +30,15 @@ Value createGate(Location loc, PatternRewriter &rewriter, Value inQubit, StringR bool adjoint = false) { auto outQubitType = inQubit.getType(); - auto gateOp = rewriter.create(loc, - /*out_qubits=*/TypeRange{outQubitType}, - /*out_ctrl_qubits=*/TypeRange{}, - /*params=*/ValueRange{}, - /*in_qubits=*/ValueRange{inQubit}, - /*gate_name=*/gateName, - /*adjoint=*/adjoint, - /*in_ctrl_qubits=*/ValueRange{}, - /*in_ctrl_values=*/ValueRange{}); + auto gateOp = CustomOp::create(rewriter, loc, + /*out_qubits=*/TypeRange{outQubitType}, + /*out_ctrl_qubits=*/TypeRange{}, + /*params=*/ValueRange{}, + /*in_qubits=*/ValueRange{inQubit}, + /*gate_name=*/gateName, + /*adjoint=*/adjoint, + /*in_ctrl_qubits=*/ValueRange{}, + /*in_ctrl_values=*/ValueRange{}); return gateOp.getOutQubits().front(); } @@ -95,7 +95,7 @@ template struct LowerQECInitOpPattern : public OpRewritePatter } else { // FabricateOp: allocate a new qubit - auto allocOp = rewriter.create(loc); + auto allocOp = AllocQubitOp::create(rewriter, loc); qubit = allocOp.getResult(); } diff --git a/mlir/lib/QEC/Transforms/MergePPRIntoPPM.cpp b/mlir/lib/QEC/Transforms/MergePPRIntoPPM.cpp index 397235ada7..05eb527bb1 100644 --- a/mlir/lib/QEC/Transforms/MergePPRIntoPPM.cpp +++ b/mlir/lib/QEC/Transforms/MergePPRIntoPPM.cpp @@ -134,9 +134,8 @@ void moveCliffordPastPPM(const PauliStringWrapper &lhsPauli, const PauliStringWr Type mresType = rhs.getMres().getType(); - auto newPPM = - rewriter.create(rhs->getLoc(), mresType, newOutQubitTypes, pauliProduct, - rhs.getRotationSign(), newRHSOperands); + auto newPPM = PPMeasurementOp::create(rewriter, rhs->getLoc(), mresType, newOutQubitTypes, + pauliProduct, rhs.getRotationSign(), newRHSOperands); rewriter.moveOpBefore(newPPM, rhs); // Update the use of value in newRHSOperands diff --git a/mlir/lib/QEC/Transforms/PPRDecomposeUtils.cpp b/mlir/lib/QEC/Transforms/PPRDecomposeUtils.cpp index 23e1d8815e..b45b5c3f4f 100644 --- a/mlir/lib/QEC/Transforms/PPRDecomposeUtils.cpp +++ b/mlir/lib/QEC/Transforms/PPRDecomposeUtils.cpp @@ -34,12 +34,12 @@ mlir::OpResult initializeZeroOrPlusI(bool avoidPauliYMeasure, mlir::Location loc { if (avoidPauliYMeasure) { // Fabricate |Y⟩ - auto plusIOp = rewriter.create(loc, LogicalInitKind::plus_i); + auto plusIOp = FabricateOp::create(rewriter, loc, LogicalInitKind::plus_i); return plusIOp.getOutQubits().back(); } // Initialize |0⟩ - auto allocatedQubit = rewriter.create(loc); + auto allocatedQubit = quantum::AllocQubitOp::create(rewriter, loc); return allocatedQubit.getOutQubit(); } diff --git a/mlir/lib/QEC/Transforms/PPRToMBQC.cpp b/mlir/lib/QEC/Transforms/PPRToMBQC.cpp index f33f0ef190..6db9ddef91 100644 --- a/mlir/lib/QEC/Transforms/PPRToMBQC.cpp +++ b/mlir/lib/QEC/Transforms/PPRToMBQC.cpp @@ -28,8 +28,8 @@ namespace { CustomOp buildCNOTGate(Value control, Value target, ConversionPatternRewriter &rewriter) { - return rewriter.create( - control.getLoc(), + return quantum::CustomOp::create( + rewriter, control.getLoc(), /*out_qubits=*/mlir::TypeRange({control.getType(), target.getType()}), /*out_ctrl_qubits=*/mlir::TypeRange({}), /*params=*/mlir::ValueRange(), @@ -47,44 +47,44 @@ CustomOp buildSingleQubitGate(Value qubit, StringRef gateName, ArrayRef SmallVector paramValues; auto f64Ty = rewriter.getF64Type(); for (double p : params) { - auto cst = rewriter.create(qubit.getLoc(), f64Ty, - rewriter.getF64FloatAttr(p)); + auto cst = mlir::arith::ConstantOp::create(rewriter, qubit.getLoc(), f64Ty, + rewriter.getF64FloatAttr(p)); paramValues.push_back(cst.getResult()); } - return rewriter.create(qubit.getLoc(), - /*out_qubits=*/mlir::TypeRange({qubit.getType()}), - /*out_ctrl_qubits=*/mlir::TypeRange({}), - /*params=*/mlir::ValueRange(paramValues), - /*in_qubits=*/mlir::ValueRange({qubit}), - /*gate_name=*/gateName, - /*adjoint=*/false, - /*in_ctrl_qubits=*/mlir::ValueRange({}), - /*in_ctrl_values=*/mlir::ValueRange()); + return quantum::CustomOp::create(rewriter, qubit.getLoc(), + /*out_qubits=*/mlir::TypeRange({qubit.getType()}), + /*out_ctrl_qubits=*/mlir::TypeRange({}), + /*params=*/mlir::ValueRange(paramValues), + /*in_qubits=*/mlir::ValueRange({qubit}), + /*gate_name=*/gateName, + /*adjoint=*/false, + /*in_ctrl_qubits=*/mlir::ValueRange({}), + /*in_ctrl_values=*/mlir::ValueRange()); } // Version for dynamic (runtime) parameters CustomOp buildSingleQubitGateWithDynamicParams(Value qubit, StringRef gateName, ValueRange params, ConversionPatternRewriter &rewriter) { - return rewriter.create(qubit.getLoc(), - /*out_qubits=*/mlir::TypeRange({qubit.getType()}), - /*out_ctrl_qubits=*/mlir::TypeRange({}), - /*params=*/params, - /*in_qubits=*/mlir::ValueRange({qubit}), - /*gate_name=*/gateName, - /*adjoint=*/false, - /*in_ctrl_qubits=*/mlir::ValueRange({}), - /*in_ctrl_values=*/mlir::ValueRange()); + return quantum::CustomOp::create(rewriter, qubit.getLoc(), + /*out_qubits=*/mlir::TypeRange({qubit.getType()}), + /*out_ctrl_qubits=*/mlir::TypeRange({}), + /*params=*/params, + /*in_qubits=*/mlir::ValueRange({qubit}), + /*gate_name=*/gateName, + /*adjoint=*/false, + /*in_ctrl_qubits=*/mlir::ValueRange({}), + /*in_ctrl_values=*/mlir::ValueRange()); } MeasureOp buildMeasurementOp(Value qubit, ConversionPatternRewriter &rewriter) { - return rewriter.create(qubit.getLoc(), - /*mres=*/rewriter.getI1Type(), - /*out_qubits=*/qubit.getType(), - /*in_qubits=*/qubit, - /*postselect=*/nullptr); + return quantum::MeasureOp::create(rewriter, qubit.getLoc(), + /*mres=*/rewriter.getI1Type(), + /*out_qubits=*/qubit.getType(), + /*in_qubits=*/qubit, + /*postselect=*/nullptr); } // Applies per-qubit conjugations that map the provided Pauli string to the Z @@ -173,9 +173,9 @@ void constructKernelOperation(SmallVector &qubits, Value &measResult, QEC auto loc = pprArbitraryOp.getLoc(); // Create constant 2.0 and multiply to get RZ angle - auto two = rewriter.create(loc, rewriter.getF64Type(), - rewriter.getF64FloatAttr(2.0)); - auto rzAngle = rewriter.create(loc, angle, two.getResult()); + auto two = mlir::arith::ConstantOp::create(rewriter, loc, rewriter.getF64Type(), + rewriter.getF64FloatAttr(2.0)); + auto rzAngle = mlir::arith::MulFOp::create(rewriter, loc, angle, two.getResult()); // Create RZ gate with dynamic angle qubits[0] = diff --git a/mlir/lib/QEC/Transforms/PartitionLayers.cpp b/mlir/lib/QEC/Transforms/PartitionLayers.cpp index 838f0314eb..afa37270d2 100644 --- a/mlir/lib/QEC/Transforms/PartitionLayers.cpp +++ b/mlir/lib/QEC/Transforms/PartitionLayers.cpp @@ -62,8 +62,8 @@ void constructLayer(QECLayer &layer, IRRewriter &writer) OpBuilder::InsertionGuard guard(writer); writer.setInsertionPointAfter(layer.getOps().back()); - auto layerOp = writer.create( - loc, inOperands, outResults, + auto layerOp = qec::LayerOp::create( + writer, loc, inOperands, outResults, [&](OpBuilder &builder, Location loc, ValueRange operands, ValueRange results) { // Map input operands to the layer's block arguments (block arguments are entries) IRMapping mapper; @@ -86,7 +86,7 @@ void constructLayer(QECLayer &layer, IRRewriter &writer) } } } - builder.create(loc, newResults); + qec::YieldOp::create(builder, loc, newResults); }); // Replace all uses of the original SSA results with the new layer results diff --git a/mlir/lib/QEC/Transforms/ToPPR.cpp b/mlir/lib/QEC/Transforms/ToPPR.cpp index 42977e90e7..9c670cf196 100644 --- a/mlir/lib/QEC/Transforms/ToPPR.cpp +++ b/mlir/lib/QEC/Transforms/ToPPR.cpp @@ -83,14 +83,14 @@ void applyGlobalPhase(Location loc, Value phaseValue, ConversionPatternRewriter // ::mlir::TypeRange out_ctrl_qubits, ::mlir::Value params, /*optional*/bool adjoint, // ::mlir::ValueRange in_ctrl_qubits, ::mlir::ValueRange in_ctrl_values); - rewriter.create(loc, /*out_ctrl_qubits=*/TypeRange{}, /*params=*/phaseValue, - /*adjoint=*/false, /*in_ctrl_qubits*/ ValueRange{}, - /*in_ctrl_values*/ ValueRange{}); + GlobalPhaseOp::create(rewriter, loc, /*out_ctrl_qubits=*/TypeRange{}, /*params=*/phaseValue, + /*adjoint=*/false, /*in_ctrl_qubits*/ ValueRange{}, + /*in_ctrl_values*/ ValueRange{}); } void applyGlobalPhase(Location loc, const double phase, ConversionPatternRewriter &rewriter) { - Value paramValue = rewriter.create(loc, rewriter.getF64FloatAttr(phase)); + Value paramValue = arith::ConstantOp::create(rewriter, loc, rewriter.getF64FloatAttr(phase)); applyGlobalPhase(loc, paramValue, rewriter); } @@ -111,8 +111,8 @@ void applySingleQubitConversion(CustomOp op, const ArrayRef &gat applyAdjointIfNeeded(gateConversion, op); ArrayAttr pauliProduct = rewriter.getStrArrayAttr(gateConversion.pauliOperators); - pprOp = rewriter.create(loc, types, pauliProduct, gateConversion.rotationKind, - inQubits); + pprOp = PPRotationOp::create(rewriter, loc, types, pauliProduct, + gateConversion.rotationKind, inQubits); inQubits = pprOp.getOutQubits(); types = pprOp.getOutQubits().getType(); } @@ -145,24 +145,24 @@ LogicalResult controlledConversion(CustomOp op, StringRef P1, StringRef P2, auto inQubitsValues = op.getInQubits(); auto outQubitsTypesList = op.getOutQubits().getType(); - auto G0 = rewriter.create(loc, outQubitsTypesList, pauliProduct, g0.rotationKind, - inQubitsValues); + auto G0 = PPRotationOp::create(rewriter, loc, outQubitsTypesList, pauliProduct, g0.rotationKind, + inQubitsValues); // G1 = (P1 ⊗ 1)−π/4 pauliProduct = rewriter.getStrArrayAttr(g1.pauliOperators); SmallVector inQubitsValues1{G0.getOutQubits()[0]}; SmallVector outQubitsTypesList1{G0.getOutQubits()[0].getType()}; - auto G1 = rewriter.create(loc, outQubitsTypesList1, pauliProduct, g1.rotationKind, - inQubitsValues1); + auto G1 = PPRotationOp::create(rewriter, loc, outQubitsTypesList1, pauliProduct, + g1.rotationKind, inQubitsValues1); // G2 = (1 ⊗ P2)−π/4 pauliProduct = rewriter.getStrArrayAttr(g2.pauliOperators); SmallVector inQubitsValues2{G0.getOutQubits()[1]}; SmallVector inQubitsTypesList2{G0.getOutQubits()[1].getType()}; - auto G2 = rewriter.create(loc, inQubitsTypesList2, pauliProduct, g1.rotationKind, - inQubitsValues2); + auto G2 = PPRotationOp::create(rewriter, loc, inQubitsTypesList2, pauliProduct, g1.rotationKind, + inQubitsValues2); rewriter.replaceOp(op, {G1.getOutQubits()[0], G2.getOutQubits()[0]}); return success(); @@ -257,8 +257,8 @@ LogicalResult convertMeasureOpToPPM(MeasureOp op, StringRef axis, Type mresType = op.getMres().getType(); SmallVector outQubitTypes({qubitType}); - auto ppmOp = rewriter.create(loc, mresType, outQubitTypes, pauliProduct, - nullptr, inQubits); + auto ppmOp = PPMeasurementOp::create(rewriter, loc, mresType, outQubitTypes, pauliProduct, + nullptr, inQubits); rewriter.replaceOp(op, ppmOp); return success(); @@ -300,8 +300,8 @@ LogicalResult convertPauliRotGate(PauliRotOp op, ConversionPatternRewriter &rewr if (op.getAdjoint()) { rotationKind = -rotationKind; } - auto pprOp = rewriter.create(loc, outQubitTypes, pauliProduct, - rotationKind, inQubits); + auto pprOp = PPRotationOp::create(rewriter, loc, outQubitTypes, pauliProduct, + rotationKind, inQubits); rewriter.replaceOp(op, pprOp.getOutQubits()); return success(); } @@ -312,15 +312,15 @@ LogicalResult convertPauliRotGate(PauliRotOp op, ConversionPatternRewriter &rewr Value constResult; if (op.getAdjoint()) { constResult = - rewriter.create(loc, rewriter.getF64FloatAttr(-2.0)).getResult(); + arith::ConstantOp::create(rewriter, loc, rewriter.getF64FloatAttr(-2.0)).getResult(); } else { constResult = - rewriter.create(loc, rewriter.getF64FloatAttr(2.0)).getResult(); + arith::ConstantOp::create(rewriter, loc, rewriter.getF64FloatAttr(2.0)).getResult(); } - auto result = rewriter.create(loc, angleValue, constResult).getResult(); + auto result = arith::DivFOp::create(rewriter, loc, angleValue, constResult).getResult(); auto pprArbitraryOp = - rewriter.create(loc, outQubitTypes, pauliProduct, result, inQubits); + PPRotationArbitraryOp::create(rewriter, loc, outQubitTypes, pauliProduct, result, inQubits); rewriter.replaceOp(op, pprArbitraryOp.getOutQubits()); diff --git a/mlir/lib/QEC/Transforms/UnrollConditionalPPRPPM.cpp b/mlir/lib/QEC/Transforms/UnrollConditionalPPRPPM.cpp index 369103a311..3dbe011628 100644 --- a/mlir/lib/QEC/Transforms/UnrollConditionalPPRPPM.cpp +++ b/mlir/lib/QEC/Transforms/UnrollConditionalPPRPPM.cpp @@ -54,25 +54,25 @@ struct LowerSelectPPM : public OpRewritePattern { resultTypes.push_back(qubit.getType()); } - auto ifOp = rewriter.create(loc, resultTypes, selectSwitch, true); + auto ifOp = scf::IfOp::create(rewriter, loc, resultTypes, selectSwitch, true); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); - auto ppm0 = rewriter.create(loc, pauliProduct0, inQubits); + auto ppm0 = PPMeasurementOp::create(rewriter, loc, pauliProduct0, inQubits); SmallVector yieldValues; yieldValues.push_back(ppm0.getMres()); yieldValues.append(ppm0.getOutQubits().begin(), ppm0.getOutQubits().end()); - rewriter.create(loc, yieldValues); + scf::YieldOp::create(rewriter, loc, yieldValues); } { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); - auto ppm1 = rewriter.create(loc, pauliProduct1, inQubits); + auto ppm1 = PPMeasurementOp::create(rewriter, loc, pauliProduct1, inQubits); SmallVector yieldValues; yieldValues.push_back(ppm1.getMres()); yieldValues.append(ppm1.getOutQubits().begin(), ppm1.getOutQubits().end()); - rewriter.create(loc, yieldValues); + scf::YieldOp::create(rewriter, loc, yieldValues); } rewriter.replaceOp(op, ifOp.getResults()); @@ -113,20 +113,20 @@ struct LowerCondPPR : public OpRewritePattern { resultTypes.push_back(qubit.getType()); } - auto ifOp = rewriter.create(loc, resultTypes, condition, true); + auto ifOp = scf::IfOp::create(rewriter, loc, resultTypes, condition, true); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); - auto ppr = rewriter.create(loc, resultTypes, pauliProduct, rotationKind, - inQubits); - rewriter.create(loc, ppr.getOutQubits()); + auto ppr = PPRotationOp::create(rewriter, loc, resultTypes, pauliProduct, rotationKind, + inQubits); + scf::YieldOp::create(rewriter, loc, ppr.getOutQubits()); } { // Unchanged else block OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); - rewriter.create(loc, inQubits); + scf::YieldOp::create(rewriter, loc, inQubits); } rewriter.replaceOp(op, ifOp.getResults()); diff --git a/mlir/lib/Quantum/IR/QuantumOps.cpp b/mlir/lib/Quantum/IR/QuantumOps.cpp index f4de39731c..7759472204 100644 --- a/mlir/lib/Quantum/IR/QuantumOps.cpp +++ b/mlir/lib/Quantum/IR/QuantumOps.cpp @@ -57,7 +57,7 @@ LogicalResult CustomOp::canonicalize(CustomOp op, mlir::PatternRewriter &rewrite auto params = op.getParams(); SmallVector paramsNeg; for (auto param : params) { - auto paramNeg = rewriter.create(op.getLoc(), param); + auto paramNeg = mlir::arith::NegFOp::create(rewriter, op.getLoc(), param); paramsNeg.push_back(paramNeg); } @@ -75,7 +75,7 @@ LogicalResult CustomOp::canonicalize(CustomOp op, mlir::PatternRewriter &rewrite LogicalResult MultiRZOp::canonicalize(MultiRZOp op, mlir::PatternRewriter &rewriter) { if (op.getAdjoint()) { - auto paramNeg = rewriter.create(op.getLoc(), op.getTheta()); + auto paramNeg = mlir::arith::NegFOp::create(rewriter, op.getLoc(), op.getTheta()); rewriter.replaceOpWithNewOp( op, op.getOutQubits().getTypes(), op.getOutCtrlQubits().getTypes(), paramNeg, @@ -89,7 +89,7 @@ LogicalResult MultiRZOp::canonicalize(MultiRZOp op, mlir::PatternRewriter &rewri LogicalResult PCPhaseOp::canonicalize(PCPhaseOp op, mlir::PatternRewriter &rewriter) { if (op.getAdjoint()) { - auto paramNeg = rewriter.create(op.getLoc(), op.getTheta()); + auto paramNeg = mlir::arith::NegFOp::create(rewriter, op.getLoc(), op.getTheta()); rewriter.replaceOpWithNewOp( op, op.getOutQubits().getTypes(), op.getOutCtrlQubits().getTypes(), paramNeg, diff --git a/mlir/lib/Quantum/Transforms/AdjointPatterns.cpp b/mlir/lib/Quantum/Transforms/AdjointPatterns.cpp index ecfc7bd74b..254334365a 100644 --- a/mlir/lib/Quantum/Transforms/AdjointPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/AdjointPatterns.cpp @@ -98,8 +98,8 @@ class AdjointGenerator { } else if (auto insertOp = dyn_cast(op)) { Value dynamicWire = getDynamicWire(insertOp, builder); - auto extractOp = builder.create( - insertOp.getLoc(), insertOp.getQubit().getType(), + auto extractOp = quantum::ExtractOp::create( + builder, insertOp.getLoc(), insertOp.getQubit().getType(), remappedValues.lookup(insertOp.getOutQreg()), dynamicWire, insertOp.getIdxAttrAttr()); remappedValues.map(insertOp.getQubit(), extractOp.getResult()); @@ -108,8 +108,8 @@ class AdjointGenerator { } else if (auto extractOp = dyn_cast(op)) { Value dynamicWire = getDynamicWire(extractOp, builder); - auto insertOp = builder.create( - extractOp.getLoc(), extractOp.getQreg().getType(), + auto insertOp = quantum::InsertOp::create( + builder, extractOp.getLoc(), extractOp.getQreg().getType(), remappedValues.lookup(extractOp.getQreg()), dynamicWire, extractOp.getIdxAttrAttr(), remappedValues.lookup(extractOp.getQubit())); remappedValues.map(extractOp.getQreg(), insertOp.getResult()); @@ -139,7 +139,7 @@ class AdjointGenerator { { Value dynamicWire; if (!op.getIdxAttr().has_value()) { - dynamicWire = builder.create(op.getLoc(), cache.wireVector); + dynamicWire = ListPopOp::create(builder, op.getLoc(), cache.wireVector); } return dynamicWire; } @@ -179,7 +179,7 @@ class AdjointGenerator { verifyTypeIsCacheable(paramType, operation); if (paramType.isF64()) { cachedParams[numParams - 1 - idx] = - builder.create(parametrizedGate.getLoc(), cache.paramVector); + ListPopOp::create(builder, parametrizedGate.getLoc(), cache.paramVector); idx++; continue; } @@ -190,17 +190,17 @@ class AdjointGenerator { Type elementType = aTensorType.getElementType(); // Constants auto loc = parametrizedGate.getLoc(); - Value c0 = builder.create(loc, 0); - Value c1 = builder.create(loc, 1); + Value c0 = index::ConstantOp::create(builder, loc, 0); + Value c1 = index::ConstantOp::create(builder, loc, 1); // TODO: Generalize to all possible dimensions bool isDim0Static = ShapedType::kDynamic != shape[0]; bool isDim1Static = ShapedType::kDynamic != shape[1]; Value dim0Length = isDim0Static - ? (Value)builder.create(loc, shape[0]) - : (Value)builder.create(loc, param, c0); + ? (Value)index::ConstantOp::create(builder, loc, shape[0]) + : (Value)tensor::DimOp::create(builder, loc, param, c0); Value dim1Length = isDim1Static - ? (Value)builder.create(loc, shape[1]) - : (Value)builder.create(loc, param, c1); + ? (Value)index::ConstantOp::create(builder, loc, shape[1]) + : (Value)tensor::DimOp::create(builder, loc, param, c1); // Renaming for legibility // Note: Since this is a square matrix, upperBound for both loops is the @@ -211,13 +211,13 @@ class AdjointGenerator { Value lowerBoundDim1 = c0; Value upperBoundDim1 = dim1Length; Value stepDim1 = c1; - Value beginningTensor = builder.create(loc, shape, elementType); + Value beginningTensor = tensor::EmptyOp::create(builder, loc, shape, elementType); // This time, we are in reverse, so we need to start // with N-1 since MLIR does not allow for loops with negative step sizes. SmallVector initialValues = {beginningTensor}; - scf::ForOp iForLoop = builder.create( - loc, lowerBoundDim0, upperBoundDim0, stepDim0, initialValues); + scf::ForOp iForLoop = scf::ForOp::create(builder, loc, lowerBoundDim0, + upperBoundDim0, stepDim0, initialValues); { OpBuilder::InsertionGuard afterIForLoop(builder); builder.setInsertionPointToStart(iForLoop.getBody()); @@ -225,13 +225,14 @@ class AdjointGenerator { Value currIthTensor = iIterArgs.front(); Value i = iForLoop.getInductionVar(); - Value iPlusOne = builder.create(loc, i, c1); - Value nMinusIMinusOne = builder.create(loc, dim0Length, iPlusOne); + Value iPlusOne = index::AddOp::create(builder, loc, i, c1); + Value nMinusIMinusOne = + index::SubOp::create(builder, loc, dim0Length, iPlusOne); // Just for legibility Value iTensorIndex = nMinusIMinusOne; - scf::ForOp jForLoop = builder.create( - loc, lowerBoundDim1, upperBoundDim1, stepDim1, currIthTensor); + scf::ForOp jForLoop = scf::ForOp::create( + builder, loc, lowerBoundDim1, upperBoundDim1, stepDim1, currIthTensor); { OpBuilder::InsertionGuard afterJForLoop(builder); builder.setInsertionPointToStart(jForLoop.getBody()); @@ -240,27 +241,27 @@ class AdjointGenerator { "jForLoop has more induction variables than necessary."); Value currIthJthTensor = jIterArgs.front(); - Value imag = builder.create(loc, cache.paramVector); - Value real = builder.create(loc, cache.paramVector); + Value imag = ListPopOp::create(builder, loc, cache.paramVector); + Value real = ListPopOp::create(builder, loc, cache.paramVector); Value element = - builder.create(loc, elementType, real, imag); + complex::CreateOp::create(builder, loc, elementType, real, imag); // TODO: Generalize to types which are not complex Value j = jForLoop.getInductionVar(); - Value jPlusOne = builder.create(loc, j, c1); + Value jPlusOne = index::AddOp::create(builder, loc, j, c1); Value nMinusJMinusOne = - builder.create(loc, dim1Length, jPlusOne); + index::SubOp::create(builder, loc, dim1Length, jPlusOne); // Just for legibility Value jTensorIndex = nMinusJMinusOne; SmallVector indices = {iTensorIndex, jTensorIndex}; - Value updatedIthJthTensor = builder.create( - loc, element, currIthJthTensor, indices); - builder.create(loc, updatedIthJthTensor); + Value updatedIthJthTensor = tensor::InsertOp::create( + builder, loc, element, currIthJthTensor, indices); + scf::YieldOp::create(builder, loc, updatedIthJthTensor); } Value ithTensor = jForLoop.getResult(0); - builder.create(loc, ithTensor); + scf::YieldOp::create(builder, loc, ithTensor); } Value recreatedTensor = iForLoop.getResult(0); @@ -333,7 +334,7 @@ class AdjointGenerator { /*outputs=*/originalTerminator->getOperandTypes()); std::string adjointName = funcOp.getName().str() + ".adjoint"; Location loc = funcOp.getLoc(); - func::FuncOp adjointFnOp = builder.create(loc, adjointName, adjointFnType); + func::FuncOp adjointFnOp = func::FuncOp::create(builder, loc, adjointName, adjointFnType); adjointFnOp.setPrivate(); // Create the block of the adjoint function @@ -346,7 +347,7 @@ class AdjointGenerator { Value lastArg = adjointFnOp.getArgument(argumentsSize - 1); assert(isa(lastArg.getType()) && "The last argument of the function must be the quantum register."); - quantum::AdjointOp adjointOp = builder.create(loc, qregType, lastArg); + quantum::AdjointOp adjointOp = quantum::AdjointOp::create(builder, loc, qregType, lastArg); Region *adjointRegion = &adjointOp.getRegion(); Region &originalRegion = funcOp.getRegion(); @@ -367,13 +368,13 @@ class AdjointGenerator { ValueRange res = terminator->getOperands(); TypeRange resTypes = terminator->getResultTypes(); builder.setInsertionPointAfter(terminator); - builder.create(loc, resTypes, res); + quantum::YieldOp::create(builder, loc, resTypes, res); // Return the adjoint operation in the adjoint function IRRewriter rewriter(builder); rewriter.eraseOp(terminator); builder.setInsertionPointAfter(adjointOp); - builder.create(loc, adjointOp.getResult()); + func::ReturnOp::create(builder, loc, adjointOp.getResult()); // Leave the adjoint func op to go back at the saved insertion builder.restoreInsertionPoint(insertionSaved); @@ -384,7 +385,7 @@ class AdjointGenerator { args.pop_back(); args.push_back(reversedResult); // Call the adjoint func op - auto adjointCallOp = builder.create(loc, adjointFnOp, args); + auto adjointCallOp = func::CallOp::create(builder, loc, adjointFnOp, args); ValueRange initQreg = callOp.getArgOperands(); // Map the initial quantum register with the adjoint result remappedValues.map(getQuantumReg(initQreg).value(), adjointCallOp.getResult(0)); @@ -402,21 +403,22 @@ class AdjointGenerator { Value tape = cache.controlFlowTapes.at(forOp); // Popping the start, stop, and step implies that these are backwards relative to // the order they were pushed. - Value step = builder.create(forOp.getLoc(), tape); - Value stop = builder.create(forOp.getLoc(), tape); - Value start = builder.create(forOp.getLoc(), tape); + Value step = ListPopOp::create(builder, forOp.getLoc(), tape); + Value stop = ListPopOp::create(builder, forOp.getLoc(), tape); + Value start = ListPopOp::create(builder, forOp.getLoc(), tape); Value reversedResult = remappedValues.lookup(getQuantumReg(forOp.getResults()).value()); - auto replacedFor = builder.create( - forOp.getLoc(), start, stop, step, /*iterArgsInit=*/reversedResult, + auto replacedFor = scf::ForOp::create( + builder, forOp.getLoc(), start, stop, step, /*iterArgsInit=*/reversedResult, [&](OpBuilder &bodyBuilder, Location loc, Value iv, ValueRange iterArgs) { OpBuilder::InsertionGuard insertionGuard(builder); builder.restoreInsertionPoint(bodyBuilder.saveInsertionPoint()); remappedValues.map(yieldedQureg.value(), iterArgs[0]); generateImpl(forOp.getBodyRegion(), builder); - builder.create( - loc, remappedValues.lookup(getQuantumReg(forOp.getRegionIterArgs()).value())); + scf::YieldOp::create( + builder, loc, + remappedValues.lookup(getQuantumReg(forOp.getRegionIterArgs()).value())); }); remappedValues.map(getQuantumReg(forOp.getInitArgs()).value(), replacedFor.getResult(0)); } @@ -430,8 +432,8 @@ class AdjointGenerator { } Value tape = cache.controlFlowTapes.at(ifOp); - Value condition = builder.create(ifOp.getLoc(), tape); - condition = builder.create(ifOp.getLoc(), builder.getI1Type(), condition); + Value condition = ListPopOp::create(builder, ifOp.getLoc(), tape); + condition = index::CastSOp::create(builder, ifOp.getLoc(), builder.getI1Type(), condition); Value reversedResult = remappedValues.lookup(getQuantumReg(ifOp.getResults()).value()); // The quantum register is captured from outside rather than passed in through a @@ -455,13 +457,13 @@ class AdjointGenerator { getQuantumReg(oldRegion.front().getTerminator()->getOperands()); remappedValues.map(yieldedQureg.value(), reversedResult); generateImpl(oldRegion, builder); - builder.create( - loc, remappedValues.lookup(findOldestQuregInRegion(oldRegion))); + scf::YieldOp::create(builder, loc, + remappedValues.lookup(findOldestQuregInRegion(oldRegion))); }; }; - auto reversedIf = builder.create(ifOp.getLoc(), condition, - getRegionBuilder(ifOp.getThenRegion()), - getRegionBuilder(ifOp.getElseRegion())); + auto reversedIf = scf::IfOp::create(builder, ifOp.getLoc(), condition, + getRegionBuilder(ifOp.getThenRegion()), + getRegionBuilder(ifOp.getElseRegion())); Value startingThenQureg = findOldestQuregInRegion(ifOp.getThenRegion()); Value startingElseQureg = findOldestQuregInRegion(ifOp.getElseRegion()); assert(startingThenQureg == startingElseQureg && @@ -479,13 +481,14 @@ class AdjointGenerator { } Value tape = cache.controlFlowTapes.at(whileOp); - Value numIterations = builder.create(whileOp.getLoc(), tape); - Value c0 = builder.create(whileOp.getLoc(), 0); - Value c1 = builder.create(whileOp.getLoc(), 1); + Value numIterations = ListPopOp::create(builder, whileOp.getLoc(), tape); + Value c0 = index::ConstantOp::create(builder, whileOp.getLoc(), 0); + Value c1 = index::ConstantOp::create(builder, whileOp.getLoc(), 1); Value iterArgInit = remappedValues.lookup(getQuantumReg(whileOp.getResults()).value()); - auto replacedWhile = builder.create( - whileOp.getLoc(), /*start=*/c0, /*stop=*/numIterations, /*step=*/c1, iterArgInit, + auto replacedWhile = scf::ForOp::create( + builder, whileOp.getLoc(), /*start=*/c0, /*stop=*/numIterations, /*step=*/c1, + iterArgInit, /*bodyBuilder=*/ [&](OpBuilder &bodyBuilder, Location loc, Value iv, ValueRange iterArgs) { OpBuilder::InsertionGuard insertionGuard(builder); @@ -493,9 +496,10 @@ class AdjointGenerator { remappedValues.map(yieldedQureg.value(), iterArgs[0]); generateImpl(whileOp.getAfter(), builder); - builder.create( - loc, remappedValues.lookup( - getQuantumReg(whileOp.getAfter().front().getArguments()).value())); + scf::YieldOp::create( + builder, loc, + remappedValues.lookup( + getQuantumReg(whileOp.getAfter().front().getArguments()).value())); }); remappedValues.map(getQuantumReg(whileOp.getInits()).value(), replacedWhile.getResult(0)); } diff --git a/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp index 9498f306e8..6155a89a8c 100644 --- a/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp @@ -65,8 +65,8 @@ struct QubitUnitaryOpInterface Location loc = op->getLoc(); auto tensorType = cast(qubitUnitaryOp.getMatrix().getType()); MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - auto toBufferOp = - rewriter.create(loc, memrefType, qubitUnitaryOp.getMatrix()); + auto toBufferOp = bufferization::ToBufferOp::create(rewriter, loc, memrefType, + qubitUnitaryOp.getMatrix()); auto memref = toBufferOp.getResult(); bufferization::replaceOpWithNewBufferizedOp( rewriter, op, qubitUnitaryOp.getOutQubits().getTypes(), @@ -110,10 +110,10 @@ struct HermitianOpInterface auto tensorType = cast(hermitianOp.getMatrix().getType()); MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); auto toBufferOp = - rewriter.create(loc, memrefType, hermitianOp.getMatrix()); + bufferization::ToBufferOp::create(rewriter, loc, memrefType, hermitianOp.getMatrix()); auto memref = toBufferOp.getResult(); - auto newHermitianOp = rewriter.create(loc, hermitianOp.getType(), memref, - hermitianOp.getQubits()); + auto newHermitianOp = HermitianOp::create(rewriter, loc, hermitianOp.getType(), memref, + hermitianOp.getQubits()); bufferization::replaceOpWithBufferizedValues(rewriter, op, newHermitianOp.getObs()); return success(); @@ -153,10 +153,10 @@ struct HamiltonianOpInterface auto tensorType = cast(hamiltonianOp.getCoeffs().getType()); MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); auto toBufferOp = - rewriter.create(loc, memrefType, hamiltonianOp.getCoeffs()); + bufferization::ToBufferOp::create(rewriter, loc, memrefType, hamiltonianOp.getCoeffs()); auto memref = toBufferOp.getResult(); - auto newHamiltonianOp = rewriter.create(loc, hamiltonianOp.getType(), memref, - hamiltonianOp.getTerms()); + auto newHamiltonianOp = HamiltonianOp::create(rewriter, loc, hamiltonianOp.getType(), + memref, hamiltonianOp.getTerms()); bufferization::replaceOpWithBufferizedValues(rewriter, op, newHamiltonianOp.getObs()); return success(); @@ -201,13 +201,13 @@ struct SampleOpInterface SmallVector allocSizes; for (Value dynShapeDimension : sampleOp.getDynamicShape()) { auto indexCastOp = - rewriter.create(loc, rewriter.getIndexType(), dynShapeDimension); + index::CastSOp::create(rewriter, loc, rewriter.getIndexType(), dynShapeDimension); allocSizes.push_back(indexCastOp); } - Value allocVal = rewriter.create(loc, resultType, allocSizes); - auto allocedSampleOp = rewriter.create( - loc, TypeRange{}, ValueRange{sampleOp.getObs(), allocVal}, op->getAttrs()); + Value allocVal = memref::AllocOp::create(rewriter, loc, resultType, allocSizes); + auto allocedSampleOp = SampleOp::create( + rewriter, loc, TypeRange{}, ValueRange{sampleOp.getObs(), allocVal}, op->getAttrs()); allocedSampleOp->setAttr("operandSegmentSizes", rewriter.getDenseI32ArrayAttr({1, 0, 1})); bufferization::replaceOpWithBufferizedValues(rewriter, op, allocVal); return success(); @@ -256,19 +256,19 @@ struct CountsOpInterface Value allocVal; if (shape[0] == ShapedType::kDynamic) { - auto indexCastOp = rewriter.create(loc, rewriter.getIndexType(), - countsOp.getDynamicShape()); + auto indexCastOp = index::CastSOp::create(rewriter, loc, rewriter.getIndexType(), + countsOp.getDynamicShape()); allocVal = - rewriter.create(loc, resultType, ValueRange{indexCastOp}); + memref::AllocOp::create(rewriter, loc, resultType, ValueRange{indexCastOp}); } else { - allocVal = rewriter.create(loc, resultType); + allocVal = memref::AllocOp::create(rewriter, loc, resultType); } buffers.push_back(allocVal); } - rewriter.create(loc, nullptr, nullptr, countsOp.getObs(), nullptr, buffers[0], - buffers[1]); + CountsOp::create(rewriter, loc, nullptr, nullptr, countsOp.getObs(), nullptr, buffers[0], + buffers[1]); bufferization::replaceOpWithBufferizedValues(rewriter, op, buffers); return success(); @@ -313,16 +313,16 @@ struct ProbsOpInterface Value buffer; auto shape = cast(tensorType).getShape(); if (shape[0] == ShapedType::kDynamic) { - auto indexCastOp = rewriter.create(loc, rewriter.getIndexType(), - probsOp.getDynamicShape()); - buffer = rewriter.create(loc, resultType, ValueRange{indexCastOp}); + auto indexCastOp = index::CastSOp::create(rewriter, loc, rewriter.getIndexType(), + probsOp.getDynamicShape()); + buffer = memref::AllocOp::create(rewriter, loc, resultType, ValueRange{indexCastOp}); } else { - buffer = rewriter.create(loc, resultType); + buffer = memref::AllocOp::create(rewriter, loc, resultType); } auto allocedProbsOp = - rewriter.create(loc, TypeRange{}, ValueRange{probsOp.getObs(), buffer}); + ProbsOp::create(rewriter, loc, TypeRange{}, ValueRange{probsOp.getObs(), buffer}); allocedProbsOp->setAttr("operandSegmentSizes", rewriter.getDenseI32ArrayAttr({1, 0, 1})); bufferization::replaceOpWithBufferizedValues(rewriter, op, buffer); return success(); @@ -367,16 +367,16 @@ struct StateOpInterface Value buffer; auto shape = cast(tensorType).getShape(); if (shape[0] == ShapedType::kDynamic) { - auto indexCastOp = rewriter.create(loc, rewriter.getIndexType(), - stateOp.getDynamicShape()); - buffer = rewriter.create(loc, resultType, ValueRange{indexCastOp}); + auto indexCastOp = index::CastSOp::create(rewriter, loc, rewriter.getIndexType(), + stateOp.getDynamicShape()); + buffer = memref::AllocOp::create(rewriter, loc, resultType, ValueRange{indexCastOp}); } else { - buffer = rewriter.create(loc, resultType); + buffer = memref::AllocOp::create(rewriter, loc, resultType); } auto allocedStateOp = - rewriter.create(loc, TypeRange{}, ValueRange{stateOp.getObs(), buffer}); + StateOp::create(rewriter, loc, TypeRange{}, ValueRange{stateOp.getObs(), buffer}); allocedStateOp->setAttr("operandSegmentSizes", rewriter.getDenseI32ArrayAttr({1, 0, 1})); bufferization::replaceOpWithBufferizedValues(rewriter, op, buffer); return success(); @@ -417,10 +417,10 @@ struct SetStateOpInterface MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); auto toBufferOp = - rewriter.create(loc, memrefType, setStateOp.getInState()); + bufferization::ToBufferOp::create(rewriter, loc, memrefType, setStateOp.getInState()); auto memref = toBufferOp.getResult(); - auto newSetStateOp = rewriter.create(loc, setStateOp.getOutQubits().getTypes(), - memref, setStateOp.getInQubits()); + auto newSetStateOp = SetStateOp::create(rewriter, loc, setStateOp.getOutQubits().getTypes(), + memref, setStateOp.getInQubits()); bufferization::replaceOpWithBufferizedValues(rewriter, op, newSetStateOp.getOutQubits()); return success(); } @@ -459,11 +459,12 @@ struct SetBasisStateOpInterface auto tensorType = cast(setBasisStateOp.getBasisState().getType()); MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - auto toBufferOp = rewriter.create( - loc, memrefType, setBasisStateOp.getBasisState()); + auto toBufferOp = bufferization::ToBufferOp::create(rewriter, loc, memrefType, + setBasisStateOp.getBasisState()); auto memref = toBufferOp.getResult(); - auto newSetStateOp = rewriter.create( - loc, setBasisStateOp.getOutQubits().getTypes(), memref, setBasisStateOp.getInQubits()); + auto newSetStateOp = + SetBasisStateOp::create(rewriter, loc, setBasisStateOp.getOutQubits().getTypes(), + memref, setBasisStateOp.getInQubits()); bufferization::replaceOpWithBufferizedValues(rewriter, op, newSetStateOp.getOutQubits()); return success(); } diff --git a/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp b/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp index b297445cad..4b4708462d 100644 --- a/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp @@ -41,13 +41,12 @@ Value getGlobalString(Location loc, OpBuilder &rewriter, StringRef key, StringRe if (!glb) { OpBuilder::InsertionGuard guard(rewriter); // to reset the insertion point rewriter.setInsertionPointToStart(mod.getBody()); - glb = rewriter.create(loc, type, true, LLVM::Linkage::Internal, key, - rewriter.getStringAttr(value)); + glb = LLVM::GlobalOp::create(rewriter, loc, type, true, LLVM::Linkage::Internal, key, + rewriter.getStringAttr(value)); } - return rewriter.create(loc, LLVM::LLVMPointerType::get(rewriter.getContext()), - type, rewriter.create(loc, glb), - ArrayRef{0, 0}, - LLVM::GEPNoWrapFlags::inbounds); + return LLVM::GEPOp::create(rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), + type, LLVM::AddressOfOp::create(rewriter, loc, glb), + ArrayRef{0, 0}, LLVM::GEPNoWrapFlags::inbounds); } /** @@ -73,27 +72,27 @@ Value getModifiersPtr(Location loc, RewriterBase &rewriter, const TypeConverter auto ptrType = LLVM::LLVMPointerType::get(ctx); - Value nullPtr = rewriter.create(loc, ptrType); + Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, ptrType); if (!adjoint && controlledQubits.empty() && controlledValues.empty()) { return nullPtr; } - auto adjointVal = rewriter.create(loc, rewriter.getBoolAttr(adjoint)); + auto adjointVal = LLVM::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(adjoint)); auto structType = LLVM::LLVMStructType::getLiteral(ctx, {boolType, sizeType, ptrType, ptrType}); auto modifiersPtr = catalyst::getStaticAlloca(loc, rewriter, structType, 1).getResult(); - auto adjointPtr = rewriter.create(loc, ptrType, structType, modifiersPtr, - llvm::ArrayRef{0, 0}, - LLVM::GEPNoWrapFlags::inbounds); - auto numControlledPtr = rewriter.create(loc, ptrType, structType, modifiersPtr, - llvm::ArrayRef{0, 1}, - LLVM::GEPNoWrapFlags::inbounds); - auto controlledWiresPtr = rewriter.create(loc, ptrType, structType, modifiersPtr, - llvm::ArrayRef{0, 2}, - LLVM::GEPNoWrapFlags::inbounds); - auto controlledValuesPtr = rewriter.create(loc, ptrType, structType, modifiersPtr, - llvm::ArrayRef{0, 3}, - LLVM::GEPNoWrapFlags::inbounds); + auto adjointPtr = + LLVM::GEPOp::create(rewriter, loc, ptrType, structType, modifiersPtr, + llvm::ArrayRef{0, 0}, LLVM::GEPNoWrapFlags::inbounds); + auto numControlledPtr = + LLVM::GEPOp::create(rewriter, loc, ptrType, structType, modifiersPtr, + llvm::ArrayRef{0, 1}, LLVM::GEPNoWrapFlags::inbounds); + auto controlledWiresPtr = + LLVM::GEPOp::create(rewriter, loc, ptrType, structType, modifiersPtr, + llvm::ArrayRef{0, 2}, LLVM::GEPNoWrapFlags::inbounds); + auto controlledValuesPtr = + LLVM::GEPOp::create(rewriter, loc, ptrType, structType, modifiersPtr, + llvm::ArrayRef{0, 3}, LLVM::GEPNoWrapFlags::inbounds); Value ctrlPtr = nullPtr; Value valuePtr = nullPtr; @@ -104,28 +103,28 @@ Value getModifiersPtr(Location loc, RewriterBase &rewriter, const TypeConverter catalyst::getStaticAlloca(loc, rewriter, boolType, controlledQubits.size()).getResult(); for (int i = 0; static_cast(i) < controlledQubits.size(); i++) { { - auto itemPtr = rewriter.create(loc, ptrType, ptrType, ctrlPtr, - llvm::ArrayRef{i}, - LLVM::GEPNoWrapFlags::inbounds); + auto itemPtr = LLVM::GEPOp::create(rewriter, loc, ptrType, ptrType, ctrlPtr, + llvm::ArrayRef{i}, + LLVM::GEPNoWrapFlags::inbounds); auto qubit = controlledQubits[i]; - rewriter.create(loc, qubit, itemPtr); + LLVM::StoreOp::create(rewriter, loc, qubit, itemPtr); } { - auto itemPtr = rewriter.create(loc, ptrType, boolType, valuePtr, - llvm::ArrayRef{i}, - LLVM::GEPNoWrapFlags::inbounds); + auto itemPtr = LLVM::GEPOp::create(rewriter, loc, ptrType, boolType, valuePtr, + llvm::ArrayRef{i}, + LLVM::GEPNoWrapFlags::inbounds); auto value = controlledValues[i]; - rewriter.create(loc, value, itemPtr); + LLVM::StoreOp::create(rewriter, loc, value, itemPtr); } } } - rewriter.create(loc, adjointVal, adjointPtr); - auto ctrlQubits = - rewriter.create(loc, rewriter.getI64IntegerAttr(controlledQubits.size())); - rewriter.create(loc, ctrlQubits, numControlledPtr); - rewriter.create(loc, ctrlPtr, controlledWiresPtr); - rewriter.create(loc, valuePtr, controlledValuesPtr); + LLVM::StoreOp::create(rewriter, loc, adjointVal, adjointPtr); + auto ctrlQubits = LLVM::ConstantOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(controlledQubits.size())); + LLVM::StoreOp::create(rewriter, loc, ctrlQubits, numControlledPtr); + LLVM::StoreOp::create(rewriter, loc, ctrlPtr, controlledWiresPtr); + LLVM::StoreOp::create(rewriter, loc, valuePtr, controlledValuesPtr); return modifiersPtr; } @@ -155,15 +154,15 @@ template struct RTBasedPattern : public OpConversionPattern { IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint(); OpBuilder::InsertionGuard guard(rewriter); // to reset the insertion point rewriter.setInsertionPointToStart(mod.getBody()); - LLVM::GlobalOp seed_glb = rewriter.create( - loc, IntegerType::get(ctx, 32), true, LLVM::Linkage::Internal, "seed", + LLVM::GlobalOp seed_glb = LLVM::GlobalOp::create( + rewriter, loc, IntegerType::get(ctx, 32), true, LLVM::Linkage::Internal, "seed", cast(op->getAttr("seed"))); rewriter.restoreInsertionPoint(ip); - seed_val = rewriter.create(loc, seed_glb); + seed_val = LLVM::AddressOfOp::create(rewriter, loc, seed_glb); } else { // Set seed argument to nullptr for unseeded runs - seed_val = rewriter.create(loc, intPtrType); + seed_val = LLVM::ZeroOp::create(rewriter, loc, intPtrType); } LLVM::LLVMFuncOp fnDecl = catalyst::ensureFunctionDeclaration( rewriter, op, qirName, qirSignature); @@ -221,18 +220,18 @@ struct DeviceInitOpPattern : public OpConversionPattern { Value shots = op.getShots(); if (!shots) { - auto zeroShots = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + auto zeroShots = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); operands.push_back(zeroShots); } else { operands.push_back(shots); } - Value autoQubitManagement = rewriter.create(loc, rewriter.getI1Type(), - op.getAutoQubitManagement()); + Value autoQubitManagement = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), + op.getAutoQubitManagement()); operands.push_back(autoQubitManagement); - rewriter.create(loc, fnDecl, operands); + LLVM::CallOp::create(rewriter, loc, fnDecl, operands); rewriter.eraseOp(op); @@ -305,7 +304,7 @@ struct AllocOpPattern : public OpConversionPattern { Value nQubits = adaptor.getNqubits(); if (!nQubits) { - nQubits = rewriter.create(loc, op.getNqubitsAttrAttr()); + nQubits = LLVM::ConstantOp::create(rewriter, loc, op.getNqubitsAttrAttr()); } rewriter.replaceOpWithNewOp(op, fnDecl, nQubits); @@ -399,11 +398,11 @@ struct ExtractOpPattern : public OpConversionPattern { Value index = adaptor.getIdx(); if (!index) { - index = rewriter.create(loc, op.getIdxAttrAttr()); + index = LLVM::ConstantOp::create(rewriter, loc, op.getIdxAttrAttr()); } SmallVector operands = {adaptor.getQreg(), index}; - Value elemPtr = rewriter.create(loc, fnDecl, operands).getResult(); + Value elemPtr = LLVM::CallOp::create(rewriter, loc, fnDecl, operands).getResult(); rewriter.replaceOpWithNewOp(op, conv->convertType(QubitType::get(ctx)), elemPtr); @@ -444,11 +443,11 @@ struct InsertOpArrayBackedPattern : public OpConversionPattern { Value index = adaptor.getIdx(); if (!index) { - index = rewriter.create(loc, op.getIdxAttrAttr()); + index = LLVM::ConstantOp::create(rewriter, loc, op.getIdxAttrAttr()); } SmallVector operands = {adaptor.getInQreg(), index, adaptor.getQubit()}; - rewriter.create(loc, fnDecl, operands); + LLVM::CallOp::create(rewriter, loc, fnDecl, operands); SmallVector values = {adaptor.getInQreg()}; rewriter.replaceOp(op, values); @@ -505,8 +504,8 @@ struct CustomOpPattern : public OpConversionPattern { // get the number of qbuits and place the input qubits at the end of the arguments. int64_t numQubits = op.getOutQubits().size(); args.insert(args.end(), modifiersPtr); - args.insert(args.end(), rewriter.create( - loc, rewriter.getI64IntegerAttr(numQubits))); + args.insert(args.end(), LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(numQubits))); args.insert(args.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); } else { @@ -514,7 +513,7 @@ struct CustomOpPattern : public OpConversionPattern { args.insert(args.end(), modifiersPtr); } - rewriter.create(loc, fnDecl, args); + LLVM::CallOp::create(rewriter, loc, fnDecl, args); SmallVector values; values.insert(values.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); values.insert(values.end(), adaptor.getInCtrlQubits().begin(), @@ -547,7 +546,7 @@ struct GlobalPhaseOpPattern : public OpConversionPattern { args.insert(args.end(), adaptor.getParams()); args.insert(args.end(), modifiersPtr); - rewriter.create(loc, fnDecl, args); + LLVM::CallOp::create(rewriter, loc, fnDecl, args); rewriter.eraseOp(op); return success(); @@ -580,9 +579,9 @@ struct MultiRZOpPattern : public OpConversionPattern { args.insert(args.end(), adaptor.getTheta()); args.insert(args.end(), modifiersPtr); args.insert(args.end(), - rewriter.create(loc, rewriter.getI64IntegerAttr(numQubits))); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(numQubits))); args.insert(args.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); - rewriter.create(loc, fnDecl, args); + LLVM::CallOp::create(rewriter, loc, fnDecl, args); SmallVector values; values.insert(values.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); @@ -622,9 +621,9 @@ struct PCPhaseOpPattern : public OpConversionPattern { args.insert(args.end(), adaptor.getDim()); args.insert(args.end(), modifiersPtr); args.insert(args.end(), - rewriter.create(loc, rewriter.getI64IntegerAttr(numQubits))); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(numQubits))); args.insert(args.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); - rewriter.create(loc, fnDecl, args); + LLVM::CallOp::create(rewriter, loc, fnDecl, args); SmallVector values; values.insert(values.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); @@ -667,13 +666,13 @@ struct QubitUnitaryOpPattern : public OpConversionPattern { int64_t numQubits = adaptor.getInQubits().size(); SmallVector args = adaptor.getOperands(); args.insert(args.begin() + 1, - rewriter.create(loc, rewriter.getI64IntegerAttr(numQubits))); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(numQubits))); args.insert(args.begin() + 1, modifiersPtr); // Replace the memref argument (LLVM struct) with a pointer to memref. args[0] = catalyst::getStaticAlloca(loc, rewriter, matrixType, 1); - rewriter.create(loc, adaptor.getMatrix(), args[0]); + LLVM::StoreOp::create(rewriter, loc, adaptor.getMatrix(), args[0]); - rewriter.create(loc, fnDecl, args); + LLVM::CallOp::create(rewriter, loc, fnDecl, args); SmallVector values; values.insert(values.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); @@ -736,7 +735,7 @@ struct NamedObsOpPattern : public OpConversionPattern { auto obsTypeInt = static_cast(op.getType()); Value obsType = - rewriter.create(loc, rewriter.getI64IntegerAttr(obsTypeInt)); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(obsTypeInt)); SmallVector args = {obsType, adaptor.getQubit()}; rewriter.replaceOpWithNewOp(op, fnDecl, args); @@ -773,10 +772,10 @@ struct HermitianOpPattern : public OpConversionPattern { int64_t numQubits = op.getQubits().size(); SmallVector args = adaptor.getOperands(); args.insert(args.begin() + 1, - rewriter.create(loc, rewriter.getI64IntegerAttr(numQubits))); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(numQubits))); // Replace the memref argument (LLVM struct) with a pointer to memref. args[0] = catalyst::getStaticAlloca(loc, rewriter, matrixType, 1); - rewriter.create(loc, adaptor.getMatrix(), args[0]); + LLVM::StoreOp::create(rewriter, loc, adaptor.getMatrix(), args[0]); rewriter.replaceOpWithNewOp(op, fnDecl, args); @@ -805,7 +804,7 @@ struct TensorOpPattern : public OpConversionPattern { int64_t numTerms = op.getTerms().size(); SmallVector args = adaptor.getOperands(); args.insert(args.begin(), - rewriter.create(loc, rewriter.getI64IntegerAttr(numTerms))); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(numTerms))); rewriter.replaceOpWithNewOp(op, fnDecl, args); @@ -840,10 +839,10 @@ struct HamiltonianOpPattern : public OpConversionPattern { int64_t numTerms = op.getTerms().size(); SmallVector args = adaptor.getOperands(); args.insert(args.begin() + 1, - rewriter.create(loc, rewriter.getI64IntegerAttr(numTerms))); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(numTerms))); // Replace the memref argument (LLVM struct) with a pointer to memref. args[0] = catalyst::getStaticAlloca(loc, rewriter, vectorType, 1); - rewriter.create(loc, adaptor.getCoeffs(), args[0]); + LLVM::StoreOp::create(rewriter, loc, adaptor.getCoeffs(), args[0]); rewriter.replaceOpWithNewOp(op, fnDecl, args); @@ -878,15 +877,16 @@ struct MeasureOpPattern : public OpConversionPattern { rewriter, op, qirName, qirSignature); // Create the postselect value. If not given, it defaults to NO_POSTSELECT - LLVM::ConstantOp postselect = rewriter.create( - loc, op.getPostselect() ? op.getPostselectAttr() - : rewriter.getI32IntegerAttr(NO_POSTSELECT)); + LLVM::ConstantOp postselect = LLVM::ConstantOp::create( + rewriter, loc, + op.getPostselect() ? op.getPostselectAttr() + : rewriter.getI32IntegerAttr(NO_POSTSELECT)); // Add qubit and postselect values as arguments of the CallOp SmallVector args = {adaptor.getInQubit(), postselect}; - Value resultPtr = rewriter.create(loc, fnDecl, args).getResult(); - Value mres = rewriter.create(loc, IntegerType::get(ctx, 1), resultPtr); + Value resultPtr = LLVM::CallOp::create(rewriter, loc, fnDecl, args).getResult(); + Value mres = LLVM::LoadOp::create(rewriter, loc, IntegerType::get(ctx, 1), resultPtr); rewriter.replaceOp(op, {mres, adaptor.getInQubit()}); return success(); @@ -922,23 +922,23 @@ template class SampleBasedPattern : public OpConversionPattern { ValueRange qubits = adaptor.getObs().getDefiningOp()->getOperands(); Value numQubits = - rewriter.create(loc, rewriter.getI64IntegerAttr(qubits.size())); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(qubits.size())); SmallVector args = {structPtr, numQubits}; args.insert(args.end(), qubits.begin(), qubits.end()); if constexpr (std::is_same_v) { - rewriter.create(loc, adaptor.getInData(), structPtr); + LLVM::StoreOp::create(rewriter, loc, adaptor.getInData(), structPtr); } else if constexpr (std::is_same_v) { - auto aStruct = rewriter.create(loc, structType); - auto bStruct = - rewriter.create(loc, aStruct, adaptor.getInEigvals(), 0); - auto cStruct = - rewriter.create(loc, bStruct, adaptor.getInCounts(), 1); - rewriter.create(loc, cStruct, structPtr); + auto aStruct = LLVM::UndefOp::create(rewriter, loc, structType); + auto bStruct = LLVM::InsertValueOp::create( + rewriter, loc, aStruct, adaptor.getInEigvals(), SmallVector{0}); + auto cStruct = LLVM::InsertValueOp::create( + rewriter, loc, bStruct, adaptor.getInCounts(), SmallVector{1}); + LLVM::StoreOp::create(rewriter, loc, cStruct, structPtr); } - rewriter.create(loc, fnDecl, args); + LLVM::CallOp::create(rewriter, loc, fnDecl, args); return structPtr; }; @@ -1056,7 +1056,7 @@ template struct StateBasedPattern : public OpConversionPattern { // We need to handle the C ABI convention of passing the result memref // as a struct pointer in the first argument to the C function. Value structPtr = catalyst::getStaticAlloca(loc, rewriter, vectorType, 1); - rewriter.create(loc, adaptor.getStateIn(), structPtr); + LLVM::StoreOp::create(rewriter, loc, adaptor.getStateIn(), structPtr); // For now obtain the qubit values from an unrealized cast created by the // ComputationalBasisOp lowering. Improve this once the runtime interface changes to @@ -1067,18 +1067,19 @@ template struct StateBasedPattern : public OpConversionPattern { SmallVector args = {structPtr}; if constexpr (std::is_same_v) { Value numQubits = - rewriter.create(loc, rewriter.getI64IntegerAttr(qubits.size())); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(qubits.size())); args.push_back(numQubits); args.insert(args.end(), qubits.begin(), qubits.end()); } else { // __catalyst__qis__State does not support individual qubit measurements yet, so it must // be invoked without specific specific qubits (i.e. measure the whole register). - Value numQubits = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value numQubits = + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); args.push_back(numQubits); } - rewriter.create(loc, fnDecl, args); + LLVM::CallOp::create(rewriter, loc, fnDecl, args); rewriter.eraseOp(op); return success(); @@ -1111,15 +1112,15 @@ struct SetStateOpPattern : public OpConversionPattern { auto allocaPtr = allocaOp.getResult(); auto size = adaptor.getInQubits().size(); - Value c = rewriter.create(loc, rewriter.getI64IntegerAttr(size)); + Value c = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(size)); args.push_back(allocaPtr); args.push_back(c); args.insert(args.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); - rewriter.create(loc, adaptor.getInState(), allocaPtr); + LLVM::StoreOp::create(rewriter, loc, adaptor.getInState(), allocaPtr); SmallVector values; values.insert(values.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); - rewriter.create(loc, func, args); + LLVM::CallOp::create(rewriter, loc, func, args); rewriter.replaceOp(op, values); return success(); } @@ -1152,15 +1153,15 @@ struct SetBasisStateOpPattern : public OpConversionPattern { auto allocaPtr = allocaOp.getResult(); auto size = adaptor.getInQubits().size(); - Value c = rewriter.create(loc, rewriter.getI64IntegerAttr(size)); + Value c = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(size)); args.push_back(allocaPtr); args.push_back(c); args.insert(args.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); - rewriter.create(loc, adaptor.getBasisState(), allocaPtr); + LLVM::StoreOp::create(rewriter, loc, adaptor.getBasisState(), allocaPtr); SmallVector values; values.insert(values.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); - rewriter.create(loc, func, args); + LLVM::CallOp::create(rewriter, loc, func, args); rewriter.replaceOp(op, values); return success(); } @@ -1226,16 +1227,16 @@ template struct PPRotationBasedPattern : public OpConversionPattern // rotation_kind can be ±1, ±2, ±4, ±8 int16_t rotationKind = static_cast(op.getRotationKind()); double theta = 2 * (llvm::numbers::pi / static_cast(rotationKind)); - thetaValue = rewriter.create(loc, rewriter.getF64FloatAttr(theta)); + thetaValue = LLVM::ConstantOp::create(rewriter, loc, rewriter.getF64FloatAttr(theta)); } else if constexpr (std::is_same_v) { if (op.getCondition()) { return op.emitOpError("PPRotationArbitraryOp with condition is not supported."); } // multiply by 2 to get the rotation angle - thetaValue = rewriter.create( - loc, adaptor.getArbitraryAngle(), - rewriter.create(loc, rewriter.getF64FloatAttr(2.0))); + thetaValue = LLVM::FMulOp::create( + rewriter, loc, adaptor.getArbitraryAngle(), + LLVM::ConstantOp::create(rewriter, loc, rewriter.getF64FloatAttr(2.0))); } else if constexpr (std::is_same_v) { // Use the arbitrary angle directly @@ -1247,12 +1248,12 @@ template struct PPRotationBasedPattern : public OpConversionPattern SmallVector args; args.push_back(pauliWordPtr); args.push_back(thetaValue); - args.push_back(rewriter.create(loc, ptrType)); + args.push_back(LLVM::ZeroOp::create(rewriter, loc, ptrType)); args.push_back( - rewriter.create(loc, rewriter.getI64IntegerAttr(numQubits))); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(numQubits))); args.insert(args.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); - rewriter.create(loc, fnDecl, args); + LLVM::CallOp::create(rewriter, loc, fnDecl, args); // Replace the op with the input qubits SmallVector values; @@ -1293,20 +1294,20 @@ struct PPMeasurementOpPattern : public OpConversionPattern { SmallVector args; args.push_back(pauliWordPtr); args.push_back( - rewriter.create(loc, rewriter.getI64IntegerAttr(numQubits))); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(numQubits))); args.insert(args.end(), adaptor.getInQubits().begin(), adaptor.getInQubits().end()); // Call the function and get the result pointer - Value resultPtr = rewriter.create(loc, fnDecl, args).getResult(); + Value resultPtr = LLVM::CallOp::create(rewriter, loc, fnDecl, args).getResult(); // Load the measurement result (i1) from the result pointer - Value mres = rewriter.create(loc, IntegerType::get(ctx, 1), resultPtr); + Value mres = LLVM::LoadOp::create(rewriter, loc, IntegerType::get(ctx, 1), resultPtr); // if the uint16_t rotation_sign is -1, we need to negate the measurement result if (static_cast(op.getRotationSign()) == -1) { - Value one = rewriter.create(loc, rewriter.getI1Type(), - rewriter.getBoolAttr(true)); - mres = rewriter.create(loc, mres, one); + Value one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), + rewriter.getBoolAttr(true)); + mres = LLVM::XOrOp::create(rewriter, loc, mres, one); } // Replace the op with the measurement result and the input qubits diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp index 6715808aae..b641fc228c 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp @@ -184,8 +184,8 @@ class BaseSignatureAnalyzer { for (auto [i, qubit] : llvm::enumerate(signature.inQubits)) { const QubitIndex &index = signature.inWireIndices[i]; updatedQreg = - rewriter.create(loc, updatedQreg.getType(), updatedQreg, - index.getValue(), index.getAttr(), qubit); + quantum::InsertOp::create(rewriter, loc, updatedQreg.getType(), updatedQreg, + index.getValue(), index.getAttr(), qubit); } operands[operandIdx++] = updatedQreg; @@ -253,9 +253,9 @@ class BaseSignatureAnalyzer { for (const auto &indices : {signature.outQubitIndices, signature.outCtrlQubitIndices}) { for (const auto &index : indices) { - auto extractOp = rewriter.create( - callOp.getLoc(), rewriter.getType(), qreg, index.getValue(), - index.getAttr()); + auto extractOp = quantum::ExtractOp::create( + rewriter, callOp.getLoc(), rewriter.getType(), qreg, + index.getValue(), index.getAttr()); newResults.emplace_back(extractOp.getResult()); } } @@ -267,7 +267,7 @@ class BaseSignatureAnalyzer { Value fromTensorOrAsIs(ValueRange values, Type type, PatternRewriter &rewriter, Location loc) { if (isa(type)) { - return rewriter.create(loc, type, values); + return tensor::FromElementsOp::create(rewriter, loc, type, values); } return values.front(); } @@ -332,13 +332,13 @@ class BaseSignatureAnalyzer { } else if (index.isAttr()) { auto attr = index.getAttr(); - auto constantValue = rewriter.create(loc, attr.getType(), attr); + auto constantValue = arith::ConstantOp::create(rewriter, loc, attr.getType(), attr); values.emplace_back(constantValue); } } if (isa(type)) { - return rewriter.create(loc, type, values); + return tensor::FromElementsOp::create(rewriter, loc, type, values); } assert(values.size() == 1 && "number of values should be 1 for non-tensor type"); diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp index 3d5ccc8775..0b4e2d18e0 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp @@ -75,8 +75,8 @@ struct DLCustomOpPattern : public OpRewritePattern { auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc()); auto callOp = - rewriter.create(op.getLoc(), decompFunc.getFunctionType().getResults(), - decompFunc.getSymName(), callOperands); + func::CallOp::create(rewriter, op.getLoc(), decompFunc.getFunctionType().getResults(), + decompFunc.getSymName(), callOperands); // Replace the op with the call op and adjust the insert ops for the qreg mode if (callOp.getNumResults() == 1 && isa(callOp.getResult(0).getType())) { @@ -149,8 +149,8 @@ struct DLMultiRZOpPattern : public OpRewritePattern { auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc()); auto callOp = - rewriter.create(op.getLoc(), decompFunc.getFunctionType().getResults(), - decompFunc.getSymName(), callOperands); + func::CallOp::create(rewriter, op.getLoc(), decompFunc.getFunctionType().getResults(), + decompFunc.getSymName(), callOperands); // Replace the op with the call op and adjust the insert ops for the qreg mode if (callOp.getNumResults() == 1 && isa(callOp.getResult(0).getType())) { diff --git a/mlir/lib/Quantum/Transforms/DisentangleCNOT.cpp b/mlir/lib/Quantum/Transforms/DisentangleCNOT.cpp index f0b9faa0f3..d3bba91c3b 100644 --- a/mlir/lib/Quantum/Transforms/DisentangleCNOT.cpp +++ b/mlir/lib/Quantum/Transforms/DisentangleCNOT.cpp @@ -82,8 +82,8 @@ void disentangleCNOTs(FunctionOpInterface &func, bool verbose) else { builder.setInsertionPoint(op); quantum::CustomOp xgate = - builder.create(loc, /*gate_name=*/"PauliX", - /*in_qubits=*/mlir::ValueRange({targetIn})); + quantum::CustomOp::create(builder, loc, /*gate_name=*/"PauliX", + /*in_qubits=*/mlir::ValueRange({targetIn})); builder.replaceAllUsesWith(targetOut, xgate->getResult(0)); builder.eraseOp(op); return; @@ -111,8 +111,8 @@ void disentangleCNOTs(FunctionOpInterface &func, bool verbose) else { builder.setInsertionPoint(op); quantum::CustomOp zgate = - builder.create(loc, /*gate_name=*/"PauliZ", - /*in_qubits=*/mlir::ValueRange({controlIn})); + quantum::CustomOp::create(builder, loc, /*gate_name=*/"PauliZ", + /*in_qubits=*/mlir::ValueRange({controlIn})); builder.replaceAllUsesWith(controlOut, zgate->getResult(0)); builder.eraseOp(op); return; diff --git a/mlir/lib/Quantum/Transforms/DisentangleSWAP.cpp b/mlir/lib/Quantum/Transforms/DisentangleSWAP.cpp index 220e66076b..b018d62148 100644 --- a/mlir/lib/Quantum/Transforms/DisentangleSWAP.cpp +++ b/mlir/lib/Quantum/Transforms/DisentangleSWAP.cpp @@ -52,15 +52,15 @@ struct DisentangleSWAPPass : public impl::DisentangleSWAPPassBase(loc, - /*out_qubits=*/mlir::TypeRange({outQubit.getType()}), - /*out_ctrl_qubits=*/mlir::TypeRange(), - /*params=*/mlir::ValueRange(), - /*in_qubits=*/mlir::ValueRange({inQubit}), - /*gate_name=*/gateName, - /*adjoint=*/false, - /*in_ctrl_qubits=*/mlir::ValueRange(), - /*in_ctrl_values=*/mlir::ValueRange()); + quantum::CustomOp::create(builder, loc, + /*out_qubits=*/mlir::TypeRange({outQubit.getType()}), + /*out_ctrl_qubits=*/mlir::TypeRange(), + /*params=*/mlir::ValueRange(), + /*in_qubits=*/mlir::ValueRange({inQubit}), + /*gate_name=*/gateName, + /*adjoint=*/false, + /*in_ctrl_qubits=*/mlir::ValueRange(), + /*in_ctrl_values=*/mlir::ValueRange()); return newGate; } @@ -74,15 +74,15 @@ struct DisentangleSWAPPass : public impl::DisentangleSWAPPassBase(loc, - /*out_qubits=*/mlir::TypeRange({inQubit.getType()}), - /*out_ctrl_qubits=*/mlir::TypeRange(), - /*params=*/mlir::ValueRange(), - /*in_qubits=*/mlir::ValueRange({inQubit}), - /*gate_name=*/gateName, - /*adjoint=*/false, - /*in_ctrl_qubits=*/mlir::ValueRange(), - /*in_ctrl_values=*/mlir::ValueRange()); + quantum::CustomOp::create(builder, loc, + /*out_qubits=*/mlir::TypeRange({inQubit.getType()}), + /*out_ctrl_qubits=*/mlir::TypeRange(), + /*params=*/mlir::ValueRange(), + /*in_qubits=*/mlir::ValueRange({inQubit}), + /*gate_name=*/gateName, + /*adjoint=*/false, + /*in_ctrl_qubits=*/mlir::ValueRange(), + /*in_ctrl_values=*/mlir::ValueRange()); return newGate; } @@ -97,8 +97,8 @@ struct DisentangleSWAPPass : public impl::DisentangleSWAPPassBase( - loc, + quantum::CustomOp newGate = quantum::CustomOp::create( + builder, loc, /*out_qubits=*/mlir::TypeRange({controlOut.getType(), targetOut.getType()}), /*out_ctrl_qubits=*/mlir::TypeRange({}), /*params=*/mlir::ValueRange(), @@ -120,8 +120,8 @@ struct DisentangleSWAPPass : public impl::DisentangleSWAPPassBase( - loc, + quantum::CustomOp newGate = quantum::CustomOp::create( + builder, loc, /*out_qubits=*/mlir::TypeRange({controlIn.getType(), targetIn.getType()}), /*out_ctrl_qubits=*/mlir::TypeRange({}), /*params=*/mlir::ValueRange(), diff --git a/mlir/lib/Quantum/Transforms/GridsynthPatterns.cpp b/mlir/lib/Quantum/Transforms/GridsynthPatterns.cpp index ea8eb22530..88353fe84d 100644 --- a/mlir/lib/Quantum/Transforms/GridsynthPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/GridsynthPatterns.cpp @@ -63,7 +63,7 @@ Value createGateChain(PatternRewriter &rewriter, Location loc, Value qbitIn, rewriter.getNamedAttr("resultSegmentSizes", rewriter.getDenseI32ArrayAttr({1, 0}))); auto newOp = - rewriter.create(loc, qbitType, ValueRange{currentQbit}, newAttrs.getAttrs()); + CustomOp::create(rewriter, loc, qbitType, ValueRange{currentQbit}, newAttrs.getAttrs()); currentQbit = newOp.getResult(0); } @@ -110,7 +110,7 @@ void populateCliffordTSwitchCases(PatternRewriter &rewriter, Location loc, // Pass the qubit through the chain Value qbitOut = createGateChain(rewriter, loc, qbitIn, config.first, config.second); - rewriter.create(loc, qbitOut); + scf::YieldOp::create(rewriter, loc, qbitOut); } // Populate Default Case @@ -119,7 +119,7 @@ void populateCliffordTSwitchCases(PatternRewriter &rewriter, Location loc, rewriter.setInsertionPointToStart(&defaultRegion.front()); static StringRef gatesDefault[] = {"Identity"}; Value qbitDefault = createGateChain(rewriter, loc, qbitIn, gatesDefault, /*isAdjoint=*/false); - rewriter.create(loc, qbitDefault); + scf::YieldOp::create(rewriter, loc, qbitDefault); } /** @@ -141,8 +141,8 @@ void populatePPRBasisSwitchCases(PatternRewriter &rewriter, Location loc, // We need to cast back to uint16_t for the C++ builder signature uint16_t finalRotationArg = static_cast(signedRotation); - auto pprOp = builder.create(loc, pauliWord, finalRotationArg, - ValueRange{currentQbit}, nullptr); + auto pprOp = catalyst::qec::PPRotationOp::create(builder, loc, pauliWord, finalRotationArg, + ValueRange{currentQbit}, nullptr); return pprOp->getResult(0); }; @@ -187,14 +187,14 @@ void populatePPRBasisSwitchCases(PatternRewriter &rewriter, Location loc, : createPPROp(rewriter, config.pauli, config.n, config.isAdjoint, qbitIn); - rewriter.create(loc, qbitOut); + scf::YieldOp::create(rewriter, loc, qbitOut); } // Default Case Region &defaultRegion = switchOp.getDefaultRegion(); defaultRegion.push_back(new Block()); rewriter.setInsertionPointToStart(&defaultRegion.front()); - rewriter.create(loc, qbitIn); + scf::YieldOp::create(rewriter, loc, qbitIn); } struct DecompositionExternalFuncs { @@ -235,12 +235,12 @@ Value buildDecompositionLoop(PatternRewriter &rewriter, Location loc, Value qbit Value gatesMemref, Value numGates, double epsilon, bool pprBasis) { auto qbitType = QubitType::get(rewriter.getContext()); - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); + Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1); // Create the scf.for loop over gate sequence indices // The loop carries the Qubit as an argument - auto forOp = rewriter.create(loc, c0, numGates, c1, ValueRange{qbitIn}); + auto forOp = scf::ForOp::create(rewriter, loc, c0, numGates, c1, ValueRange{qbitIn}); // Add attribute to the for op to indicate the estimated iterations of the loop auto estimatedRanges = static_cast(std::ceil(10 * std::log2(1 / epsilon))); @@ -253,7 +253,7 @@ Value buildDecompositionLoop(PatternRewriter &rewriter, Location loc, Value qbit Value iv = forOp.getInductionVar(); Value currentQbit = forOp.getRegionIterArg(0); - Value currentGateIndex = rewriter.create(loc, gatesMemref, ValueRange{iv}); + Value currentGateIndex = memref::LoadOp::create(rewriter, loc, gatesMemref, ValueRange{iv}); // 19 cases for PPR basis: Identity + (X, Y, Z) x (2, 4, 8) x (normal, adjoint) // 10 cases for Clifford+T basis: {T, H T, S H T, I, X, Y, Z, H, S, adjS} @@ -266,8 +266,9 @@ Value buildDecompositionLoop(PatternRewriter &rewriter, Location loc, Value qbit // Create the switch operation inside the loop DenseI64ArrayAttr caseValuesAttr = rewriter.getDenseI64ArrayAttr(caseValues); - auto switchOp = rewriter.create( - loc, TypeRange{qbitType}, currentGateIndex, caseValuesAttr, caseValues.size()); + auto switchOp = + scf::IndexSwitchOp::create(rewriter, loc, TypeRange{qbitType}, currentGateIndex, + caseValuesAttr, caseValues.size()); // Populate Switch Cases if (pprBasis) { @@ -279,7 +280,7 @@ Value buildDecompositionLoop(PatternRewriter &rewriter, Location loc, Value qbit // Yield the result of the switch op from the for loop rewriter.setInsertionPointAfter(switchOp); - rewriter.create(loc, switchOp->getResults()); + scf::YieldOp::create(rewriter, loc, switchOp->getResults()); } // Return the result of the loop (the final qubit state) @@ -314,7 +315,7 @@ func::FuncOp getOrCreateDecompositionFunc(ModuleOp module, PatternRewriter &rewr OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); - func = rewriter.create(module.getLoc(), funcName, funcType); + func = func::FuncOp::create(rewriter, module.getLoc(), funcName, funcType); func.setPrivate(); // Get or declare external functions (GetSize, GetGates, GetPhase) @@ -330,25 +331,25 @@ func::FuncOp getOrCreateDecompositionFunc(ModuleOp module, PatternRewriter &rewr Value angle = entryBlock->getArgument(1); // Parameters for compilation - Value epsilonVal = rewriter.create(loc, rewriter.getF64FloatAttr(epsilon)); - Value pprBasisVal = rewriter.create(loc, rewriter.getBoolAttr(pprBasis)); + Value epsilonVal = arith::ConstantOp::create(rewriter, loc, rewriter.getF64FloatAttr(epsilon)); + Value pprBasisVal = arith::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(pprBasis)); // Call GetSize - auto callGetSizeOp = rewriter.create(loc, extFuncs.getSize, - ValueRange{angle, epsilonVal, pprBasisVal}); + auto callGetSizeOp = func::CallOp::create(rewriter, loc, extFuncs.getSize, + ValueRange{angle, epsilonVal, pprBasisVal}); Value num_gates = callGetSizeOp->getResult(0); // Call GetGates // Use memref.alloc (Heap) instead of alloca (Stack) because num_gates is dynamic. auto gatesMemRefType = MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); - Value gatesMemref = rewriter.create(loc, gatesMemRefType, num_gates); + Value gatesMemref = memref::AllocOp::create(rewriter, loc, gatesMemRefType, num_gates); - rewriter.create(loc, extFuncs.getGates, - ValueRange{gatesMemref, angle, epsilonVal, pprBasisVal}); + func::CallOp::create(rewriter, loc, extFuncs.getGates, + ValueRange{gatesMemref, angle, epsilonVal, pprBasisVal}); // Call GetPhase - auto callGetPhaseOp = rewriter.create(loc, extFuncs.getPhase, - ValueRange{angle, epsilonVal, pprBasisVal}); + auto callGetPhaseOp = func::CallOp::create(rewriter, loc, extFuncs.getPhase, + ValueRange{angle, epsilonVal, pprBasisVal}); Value runtimePhase = callGetPhaseOp->getResult(0); // Build the Loop logic @@ -356,10 +357,10 @@ func::FuncOp getOrCreateDecompositionFunc(ModuleOp module, PatternRewriter &rewr buildDecompositionLoop(rewriter, loc, qbitIn, gatesMemref, num_gates, epsilon, pprBasis); // Clean up heap memory - rewriter.create(loc, gatesMemref); + memref::DeallocOp::create(rewriter, loc, gatesMemref); // Return the final qubit and the computed runtime phase - rewriter.create(loc, ValueRange{finalQbit, runtimePhase}); + func::ReturnOp::create(rewriter, loc, ValueRange{finalQbit, runtimePhase}); return func; } @@ -400,7 +401,7 @@ struct DecomposeCustomOpPattern : public OpRewritePattern { // Call the function using the qubit directly auto callDecompOp = - rewriter.create(loc, decompFunc, ValueRange{qbitOperand, angle}); + func::CallOp::create(rewriter, loc, decompFunc, ValueRange{qbitOperand, angle}); Value finalQbitResult = callDecompOp->getResult(0); Value runtimePhase = callDecompOp.getResult(1); @@ -409,9 +410,9 @@ struct DecomposeCustomOpPattern : public OpRewritePattern { Value finalPhase; if (isPhaseShift) { // PhaseShift(phi) = RZ(phi) * GlobalPhase(-phi/2) - Value c2 = rewriter.create(loc, rewriter.getF64FloatAttr(2.0)); - Value halfAngle = rewriter.create(loc, angle, c2); - finalPhase = rewriter.create(loc, runtimePhase, halfAngle); + Value c2 = arith::ConstantOp::create(rewriter, loc, rewriter.getF64FloatAttr(2.0)); + Value halfAngle = arith::DivFOp::create(rewriter, loc, angle, c2); + finalPhase = arith::SubFOp::create(rewriter, loc, runtimePhase, halfAngle); } else { finalPhase = runtimePhase; @@ -420,8 +421,8 @@ struct DecomposeCustomOpPattern : public OpRewritePattern { NamedAttrList gphaseAttrs; gphaseAttrs.append( rewriter.getNamedAttr("operandSegmentSizes", rewriter.getDenseI32ArrayAttr({1, 0, 0}))); - rewriter.create(loc, TypeRange{}, ValueRange{finalPhase}, - gphaseAttrs.getAttrs()); + GlobalPhaseOp::create(rewriter, loc, TypeRange{}, ValueRange{finalPhase}, + gphaseAttrs.getAttrs()); // Replace the RZ/PhaseShift op with the resulting qubit rewriter.replaceOp(op, finalQbitResult); @@ -464,13 +465,13 @@ struct DecomposePPRArbitraryOpPattern // PPR(theta, Z) = exp(-i * theta * Z) // RZ(phi) = exp(-i * phi/2 * Z) // phi = 2 * theta - Value cMinus2 = rewriter.create(loc, rewriter.getF64FloatAttr(2.0)); - Value rzAngle = rewriter.create(loc, angle, cMinus2); + Value cMinus2 = arith::ConstantOp::create(rewriter, loc, rewriter.getF64FloatAttr(2.0)); + Value rzAngle = arith::MulFOp::create(rewriter, loc, angle, cMinus2); func::FuncOp decompFunc = getOrCreateDecompositionFunc(mod, rewriter, epsilon, pprBasis); auto callDecompOp = - rewriter.create(loc, decompFunc, ValueRange{qbitOperand, rzAngle}); + func::CallOp::create(rewriter, loc, decompFunc, ValueRange{qbitOperand, rzAngle}); Value finalQbitResult = callDecompOp->getResult(0); Value runtimePhase = callDecompOp.getResult(1); @@ -478,8 +479,8 @@ struct DecomposePPRArbitraryOpPattern NamedAttrList gphaseAttrs; gphaseAttrs.append( rewriter.getNamedAttr("operandSegmentSizes", rewriter.getDenseI32ArrayAttr({1, 0, 0}))); - rewriter.create(loc, TypeRange{}, ValueRange{runtimePhase}, - gphaseAttrs.getAttrs()); + GlobalPhaseOp::create(rewriter, loc, TypeRange{}, ValueRange{runtimePhase}, + gphaseAttrs.getAttrs()); rewriter.replaceOp(op, finalQbitResult); return success(); diff --git a/mlir/lib/Quantum/Transforms/IonsDecompositionPatterns.cpp b/mlir/lib/Quantum/Transforms/IonsDecompositionPatterns.cpp index 96efc974d9..613a417282 100644 --- a/mlir/lib/Quantum/Transforms/IonsDecompositionPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/IonsDecompositionPatterns.cpp @@ -38,7 +38,7 @@ void oneQubitDecomp(catalyst::quantum::CustomOp op, mlir::PatternRewriter &rewri ValueRange inQubits = op.getInQubits(); TypedAttr phiAttr = rewriter.getF64FloatAttr(phi); - mlir::Value phiValue = rewriter.create(op.getLoc(), phiAttr); + mlir::Value phiValue = arith::ConstantOp::create(rewriter, op.getLoc(), phiAttr); mlir::Value thetaValue; if (std::holds_alternative(theta)) { @@ -46,19 +46,19 @@ void oneQubitDecomp(catalyst::quantum::CustomOp op, mlir::PatternRewriter &rewri } else if (std::holds_alternative(theta)) { TypedAttr thetaAttr = rewriter.getF64FloatAttr(std::get(theta)); - thetaValue = rewriter.create(op.getLoc(), thetaAttr); + thetaValue = arith::ConstantOp::create(rewriter, op.getLoc(), thetaAttr); } TypedAttr lambdaAttr = rewriter.getF64FloatAttr(lambda); - mlir::Value lambdaValue = rewriter.create(op.getLoc(), lambdaAttr); + mlir::Value lambdaValue = arith::ConstantOp::create(rewriter, op.getLoc(), lambdaAttr); - auto rxPhi = rewriter.create(op.getLoc(), outQubitsTypes, TypeRange{}, phiValue, - inQubits, "RX", false, ValueRange{}, ValueRange{}); - auto ryTheta = rewriter.create(op.getLoc(), outQubitsTypes, TypeRange{}, thetaValue, - rxPhi.getOutQubits(), "RY", false, - rxPhi.getInCtrlQubits(), rxPhi.getInCtrlValues()); + auto rxPhi = CustomOp::create(rewriter, op.getLoc(), outQubitsTypes, TypeRange{}, phiValue, + inQubits, "RX", false, ValueRange{}, ValueRange{}); + auto ryTheta = CustomOp::create(rewriter, op.getLoc(), outQubitsTypes, TypeRange{}, thetaValue, + rxPhi.getOutQubits(), "RY", false, rxPhi.getInCtrlQubits(), + rxPhi.getInCtrlValues()); auto rxLambda = - rewriter.create(op.getLoc(), outQubitsTypes, TypeRange{}, lambdaValue, - ryTheta.getOutQubits(), "RX", false, ValueRange{}, ValueRange{}); + CustomOp::create(rewriter, op.getLoc(), outQubitsTypes, TypeRange{}, lambdaValue, + ryTheta.getOutQubits(), "RX", false, ValueRange{}, ValueRange{}); op.replaceAllUsesWith(rxLambda); } @@ -110,31 +110,30 @@ void cnotDecomp(catalyst::quantum::CustomOp op, mlir::PatternRewriter &rewriter) mlir::Value inQubit1 = op.getInQubits().back(); TypedAttr piOver2Attr = rewriter.getF64FloatAttr(PI / 2); - mlir::Value piOver2 = rewriter.create(op.getLoc(), piOver2Attr); - auto ryPiOver2 = - rewriter.create(op.getLoc(), outQubitsTypes.front(), TypeRange{}, piOver2, - inQubit0, "RY", false, ValueRange{}, ValueRange{}); + mlir::Value piOver2 = arith::ConstantOp::create(rewriter, op.getLoc(), piOver2Attr); + auto ryPiOver2 = CustomOp::create(rewriter, op.getLoc(), outQubitsTypes.front(), TypeRange{}, + piOver2, inQubit0, "RY", false, ValueRange{}, ValueRange{}); SmallVector qubitsAfterRy; qubitsAfterRy.push_back(ryPiOver2.getOutQubits().front()); qubitsAfterRy.push_back(inQubit1); - auto ms = rewriter.create(op.getLoc(), outQubitsTypes, TypeRange{}, piOver2, - qubitsAfterRy, "MS", false, ValueRange{}, ValueRange{}); + auto ms = CustomOp::create(rewriter, op.getLoc(), outQubitsTypes, TypeRange{}, piOver2, + qubitsAfterRy, "MS", false, ValueRange{}, ValueRange{}); mlir::Value qubit0AfterMs = ms.getOutQubits().front(); mlir::Value qubit1AfterMs = ms.getOutQubits().back(); TypedAttr minusPiOver2Attr = rewriter.getF64FloatAttr(-PI / 2); - mlir::Value minusPiOver2 = rewriter.create(op.getLoc(), minusPiOver2Attr); + mlir::Value minusPiOver2 = arith::ConstantOp::create(rewriter, op.getLoc(), minusPiOver2Attr); auto rxMinusPiOver2 = - rewriter.create(op.getLoc(), outQubitsTypes.front(), TypeRange{}, minusPiOver2, - qubit0AfterMs, "RX", false, ValueRange{}, ValueRange{}); + CustomOp::create(rewriter, op.getLoc(), outQubitsTypes.front(), TypeRange{}, minusPiOver2, + qubit0AfterMs, "RX", false, ValueRange{}, ValueRange{}); auto firstRyMinusPiOver2 = - rewriter.create(op.getLoc(), outQubitsTypes.front(), TypeRange{}, minusPiOver2, - qubit1AfterMs, "RY", false, ValueRange{}, ValueRange{}); + CustomOp::create(rewriter, op.getLoc(), outQubitsTypes.front(), TypeRange{}, minusPiOver2, + qubit1AfterMs, "RY", false, ValueRange{}, ValueRange{}); mlir::Value qubit0AfterRY = rxMinusPiOver2.getOutQubits().front(); auto secondRyMinusPiOver2 = - rewriter.create(op.getLoc(), outQubitsTypes.front(), TypeRange{}, minusPiOver2, - qubit0AfterRY, "RY", false, ValueRange{}, ValueRange{}); + CustomOp::create(rewriter, op.getLoc(), outQubitsTypes.front(), TypeRange{}, minusPiOver2, + qubit0AfterRY, "RY", false, ValueRange{}, ValueRange{}); SmallVector qubitsEnd; qubitsEnd.push_back(firstRyMinusPiOver2.getOutQubits().front()); diff --git a/mlir/lib/Quantum/Transforms/LoopBoundaryOptimizationPatterns.cpp b/mlir/lib/Quantum/Transforms/LoopBoundaryOptimizationPatterns.cpp index 564108ead2..d4c59826ea 100644 --- a/mlir/lib/Quantum/Transforms/LoopBoundaryOptimizationPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/LoopBoundaryOptimizationPatterns.cpp @@ -191,7 +191,7 @@ quantum::ExtractOp createExtractOp(Value qreg, const QubitOrigin &qubit, Pattern auto loc = qubit.qubitOrRegister.getLoc(); auto idxAttr = rewriter.getI64IntegerAttr(qubit.position); auto type = rewriter.getType(); - return rewriter.create(loc, type, qreg, nullptr, idxAttr); + return quantum::ExtractOp::create(rewriter, loc, type, qreg, nullptr, idxAttr); } // Creates a quantum.insert operation. @@ -201,7 +201,8 @@ quantum::InsertOp createInsertOp(Value qreg, const QubitOrigin &qubit, Value ele assert(element && "InsertOp requires an element value!"); auto loc = qubit.qubitOrRegister.getLoc(); auto idxAttr = rewriter.getI64IntegerAttr(qubit.position); - return rewriter.create(loc, qreg.getType(), qreg, nullptr, idxAttr, element); + return quantum::InsertOp::create(rewriter, loc, qreg.getType(), qreg, nullptr, idxAttr, + element); } // Finds the initial value of a quantum register in the for loop. @@ -585,7 +586,8 @@ void handleParams(QuantumOpInfo topEdgeOp, QuantumOpInfo bottomEdgeOp, scf::ForO // Update the param of topEdgeOp to negative value for (auto [idx, param] : llvm::enumerate(topEdgeParams)) { - Value negParam = rewriter.create(cloneTopOp.getLoc(), param).getResult(); + Value negParam = + arith::NegFOp::create(rewriter, cloneTopOp.getLoc(), param).getResult(); rewriter.moveOpBefore(negParam.getDefiningOp(), cloneTopOp); bottomEdgeOp.op.setOperand(idx, negParam); } diff --git a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp index 6c5c0f1d54..fcaf668456 100644 --- a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp @@ -95,12 +95,12 @@ struct MergeRotationsRewritePattern : public OpRewritePattern { auto loc = op.getLoc(); SmallVector sumParams; for (auto [param, parentParam] : llvm::zip(params, parentParams)) { - Value sumParam = rewriter.create(loc, parentParam, param).getResult(); + Value sumParam = arith::AddFOp::create(rewriter, loc, parentParam, param).getResult(); sumParams.push_back(sumParam); } - auto mergeOp = rewriter.create(loc, outQubitsTypes, outQubitsCtrlTypes, sumParams, - parentInQubits, op.getGateName(), false, - parentInCtrlQubits, parentInCtrlValues); + auto mergeOp = CustomOp::create(rewriter, loc, outQubitsTypes, outQubitsCtrlTypes, + sumParams, parentInQubits, op.getGateName(), false, + parentInCtrlQubits, parentInCtrlValues); rewriter.replaceOp(op, mergeOp); rewriter.eraseOp(parentOp); @@ -155,79 +155,83 @@ struct MergeRotationsRewritePattern : public OpRewritePattern { // 2a. if (θ1 == 0 && θ2 == 0) { ϕF = ϕ1 + ϕ2 + ω1 + ω2; θF = 0; ωF = 0; } // 2b. if (θ1 == 0) { ϕF = ϕ1 + ϕ2 + ω1; θF = θ2; ωF = ω2; } // 2c. if (θ2 == 0) { ϕF = ϕ1; θF = θ1; ωF = ω1 + ω2 + ϕ2; } - auto zeroConst = rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + auto zeroConst = arith::ConstantOp::create(rewriter, loc, rewriter.getF64FloatAttr(0.0)); if (omega1IsZero && phi2IsZero) { phiF = phi1; - thetaF = rewriter.create(loc, theta1, theta2); + thetaF = arith::AddFOp::create(rewriter, loc, theta1, theta2); omegaF = omega2; } else if (theta1IsZero && theta2IsZero) { - phiF = - rewriter.create(loc, rewriter.create(loc, phi1, phi2), - rewriter.create(loc, omega1, omega2)); + phiF = arith::AddFOp::create(rewriter, loc, + arith::AddFOp::create(rewriter, loc, phi1, phi2), + arith::AddFOp::create(rewriter, loc, omega1, omega2)); thetaF = zeroConst; omegaF = zeroConst; } else if (theta1IsZero) { - phiF = rewriter.create( - loc, rewriter.create(loc, phi1, phi2), omega1); + phiF = arith::AddFOp::create(rewriter, loc, + arith::AddFOp::create(rewriter, loc, phi1, phi2), omega1); thetaF = theta2; omegaF = omega2; } else if (theta2IsZero) { phiF = phi1; thetaF = theta1; - omegaF = rewriter.create( - loc, rewriter.create(loc, omega1, omega2), phi2); + omegaF = arith::AddFOp::create( + rewriter, loc, arith::AddFOp::create(rewriter, loc, omega1, omega2), phi2); } else { - auto halfConst = rewriter.create(loc, rewriter.getF64FloatAttr(0.5)); - auto twoConst = rewriter.create(loc, rewriter.getF64FloatAttr(2.0)); + auto halfConst = + arith::ConstantOp::create(rewriter, loc, rewriter.getF64FloatAttr(0.5)); + auto twoConst = arith::ConstantOp::create(rewriter, loc, rewriter.getF64FloatAttr(2.0)); // α1 = (ϕ1 + ω1)/2, α2 = (ϕ2 + ω2)/2 // β1 = (ϕ1 - ω1)/2, β2 = (ϕ2 - ω2)/2 - auto alpha1 = rewriter.create( - loc, rewriter.create(loc, phi1, omega1), halfConst); - auto alpha2 = rewriter.create( - loc, rewriter.create(loc, phi2, omega2), halfConst); - auto beta1 = rewriter.create( - loc, rewriter.create(loc, phi1, omega1), halfConst); - auto beta2 = rewriter.create( - loc, rewriter.create(loc, phi2, omega2), halfConst); + auto alpha1 = arith::MulFOp::create( + rewriter, loc, arith::AddFOp::create(rewriter, loc, phi1, omega1), halfConst); + auto alpha2 = arith::MulFOp::create( + rewriter, loc, arith::AddFOp::create(rewriter, loc, phi2, omega2), halfConst); + auto beta1 = arith::MulFOp::create( + rewriter, loc, arith::SubFOp::create(rewriter, loc, phi1, omega1), halfConst); + auto beta2 = arith::MulFOp::create( + rewriter, loc, arith::SubFOp::create(rewriter, loc, phi2, omega2), halfConst); // c1 = cos(θ1/2), c2 = cos(θ2/2) // s1 = sin(θ1/2), s2 = sin(θ2/2) - auto theta1Half = rewriter.create(loc, theta1, halfConst); - auto c1 = rewriter.create(loc, theta1Half); - auto s1 = rewriter.create(loc, theta1Half); - auto theta2Half = rewriter.create(loc, theta2, halfConst); - auto c2 = rewriter.create(loc, theta2Half); - auto s2 = rewriter.create(loc, theta2Half); + auto theta1Half = arith::MulFOp::create(rewriter, loc, theta1, halfConst); + auto c1 = math::CosOp::create(rewriter, loc, theta1Half); + auto s1 = math::SinOp::create(rewriter, loc, theta1Half); + auto theta2Half = arith::MulFOp::create(rewriter, loc, theta2, halfConst); + auto c2 = math::CosOp::create(rewriter, loc, theta2Half); + auto s2 = math::SinOp::create(rewriter, loc, theta2Half); // cF = sqrt(c1^2 * c2^2 + // s1^2 * s2^2 - // 2 * c1 * c2 * s1 * s2 * cos(ω1 + ϕ2)) - auto c1TimesC2 = rewriter.create(loc, c1, c2); - auto s1TimesS2 = rewriter.create(loc, s1, s2); + auto c1TimesC2 = arith::MulFOp::create(rewriter, loc, c1, c2); + auto s1TimesS2 = arith::MulFOp::create(rewriter, loc, s1, s2); auto firstAddend = - rewriter.create(loc, rewriter.create(loc, c1, c1), - rewriter.create(loc, c2, c2)); + arith::MulFOp::create(rewriter, loc, arith::MulFOp::create(rewriter, loc, c1, c1), + arith::MulFOp::create(rewriter, loc, c2, c2)); auto secondAddend = - rewriter.create(loc, rewriter.create(loc, s1, s1), - rewriter.create(loc, s2, s2)); - auto thirdAddend = rewriter.create( - loc, rewriter.create( - loc, twoConst, - rewriter.create( - loc, c1TimesC2, - rewriter.create( - loc, s1TimesS2, - rewriter.create( - loc, rewriter.create(loc, omega1, phi2)))))); - auto cF = rewriter.create( - loc, rewriter.create( - loc, firstAddend, - rewriter.create(loc, secondAddend, thirdAddend))); + arith::MulFOp::create(rewriter, loc, arith::MulFOp::create(rewriter, loc, s1, s1), + arith::MulFOp::create(rewriter, loc, s2, s2)); + auto thirdAddend = arith::NegFOp::create( + rewriter, loc, + arith::MulFOp::create( + rewriter, loc, twoConst, + arith::MulFOp::create( + rewriter, loc, c1TimesC2, + arith::MulFOp::create( + rewriter, loc, s1TimesS2, + math::CosOp::create( + rewriter, loc, + arith::AddFOp::create(rewriter, loc, omega1, phi2)))))); + auto cF = math::SqrtOp::create( + rewriter, loc, + arith::AddFOp::create( + rewriter, loc, firstAddend, + arith::AddFOp::create(rewriter, loc, secondAddend, thirdAddend))); // TODO: can we check these problematic scenarios for differentiability by code? // Problematic scenarios for differentiability: @@ -236,62 +240,70 @@ struct MergeRotationsRewritePattern : public OpRewritePattern { // 2. if (cF == 1) { /* acos not differentiable at 1 */ return failure(); } // θF = 2 * acos(cF) - auto acosCF = rewriter.create(loc, cF); - thetaF = rewriter.create(loc, twoConst, acosCF); + auto acosCF = math::AcosOp::create(rewriter, loc, cF); + thetaF = arith::MulFOp::create(rewriter, loc, twoConst, acosCF); // αF = - atan((- c1 * c2 * sin(α1 + α2) - s1 * s2 * sin(β2 - β1)) / // ( c1 * c2 * cos(α1 + α2) - s1 * s2 * cos(β2 - β1))) - auto alpha1PlusAlpha2 = rewriter.create(loc, alpha1, alpha2); - auto beta2MinusBeta1 = rewriter.create(loc, beta2, beta1); - auto term1 = rewriter.create( - loc, rewriter.create( - loc, c1TimesC2, rewriter.create(loc, alpha1PlusAlpha2))); - auto term2 = rewriter.create( - loc, rewriter.create( - loc, s1TimesS2, rewriter.create(loc, beta2MinusBeta1))); - auto term3 = rewriter.create( - loc, c1TimesC2, rewriter.create(loc, alpha1PlusAlpha2)); - auto term4 = rewriter.create( - loc, rewriter.create( - loc, s1TimesS2, rewriter.create(loc, beta2MinusBeta1))); - auto alphaF = rewriter.create( - loc, rewriter.create( - loc, rewriter.create( - loc, rewriter.create(loc, term1, term2), - rewriter.create(loc, term3, term4)))); + auto alpha1PlusAlpha2 = arith::AddFOp::create(rewriter, loc, alpha1, alpha2); + auto beta2MinusBeta1 = arith::SubFOp::create(rewriter, loc, beta2, beta1); + auto term1 = arith::NegFOp::create( + rewriter, loc, + arith::MulFOp::create(rewriter, loc, c1TimesC2, + math::SinOp::create(rewriter, loc, alpha1PlusAlpha2))); + auto term2 = arith::NegFOp::create( + rewriter, loc, + arith::MulFOp::create(rewriter, loc, s1TimesS2, + math::SinOp::create(rewriter, loc, beta2MinusBeta1))); + auto term3 = arith::MulFOp::create( + rewriter, loc, c1TimesC2, math::CosOp::create(rewriter, loc, alpha1PlusAlpha2)); + auto term4 = arith::NegFOp::create( + rewriter, loc, + arith::MulFOp::create(rewriter, loc, s1TimesS2, + math::CosOp::create(rewriter, loc, beta2MinusBeta1))); + auto alphaF = arith::NegFOp::create( + rewriter, loc, + math::AtanOp::create( + rewriter, loc, + arith::DivFOp::create(rewriter, loc, + arith::AddFOp::create(rewriter, loc, term1, term2), + arith::AddFOp::create(rewriter, loc, term3, term4)))); // βF = - atan((- c1 * s2 * sin(α1 + β2) + s1 * c2 * sin(α2 - β1)) / // ( c1 * s2 * cos(α1 + β2) + s1 * c2 * cos(α2 - β1))) - auto c1TimesS2 = rewriter.create(loc, c1, s2); - auto s1TimesC2 = rewriter.create(loc, s1, c2); - auto alpha1PlusBeta2 = rewriter.create(loc, alpha1, beta2); - auto alpha2MinusBeta1 = rewriter.create(loc, alpha2, beta1); - auto term5 = rewriter.create( - loc, rewriter.create( - loc, c1TimesS2, rewriter.create(loc, alpha1PlusBeta2))); - auto term6 = rewriter.create( - loc, s1TimesC2, rewriter.create(loc, alpha2MinusBeta1)); - auto term7 = rewriter.create( - loc, c1TimesS2, rewriter.create(loc, alpha1PlusBeta2)); - auto term8 = rewriter.create( - loc, s1TimesC2, rewriter.create(loc, alpha2MinusBeta1)); - auto betaF = rewriter.create( - loc, rewriter.create( - loc, rewriter.create( - loc, rewriter.create(loc, term5, term6), - rewriter.create(loc, term7, term8)))); + auto c1TimesS2 = arith::MulFOp::create(rewriter, loc, c1, s2); + auto s1TimesC2 = arith::MulFOp::create(rewriter, loc, s1, c2); + auto alpha1PlusBeta2 = arith::AddFOp::create(rewriter, loc, alpha1, beta2); + auto alpha2MinusBeta1 = arith::SubFOp::create(rewriter, loc, alpha2, beta1); + auto term5 = arith::NegFOp::create( + rewriter, loc, + arith::MulFOp::create(rewriter, loc, c1TimesS2, + math::SinOp::create(rewriter, loc, alpha1PlusBeta2))); + auto term6 = arith::MulFOp::create( + rewriter, loc, s1TimesC2, math::SinOp::create(rewriter, loc, alpha2MinusBeta1)); + auto term7 = arith::MulFOp::create(rewriter, loc, c1TimesS2, + math::CosOp::create(rewriter, loc, alpha1PlusBeta2)); + auto term8 = arith::MulFOp::create( + rewriter, loc, s1TimesC2, math::CosOp::create(rewriter, loc, alpha2MinusBeta1)); + auto betaF = arith::NegFOp::create( + rewriter, loc, + math::AtanOp::create( + rewriter, loc, + arith::DivFOp::create(rewriter, loc, + arith::AddFOp::create(rewriter, loc, term5, term6), + arith::AddFOp::create(rewriter, loc, term7, term8)))); // ϕF = αF + βF - phiF = rewriter.create(loc, alphaF, betaF); + phiF = arith::AddFOp::create(rewriter, loc, alphaF, betaF); // ωF = αF - βF - omegaF = rewriter.create(loc, alphaF, betaF); + omegaF = arith::SubFOp::create(rewriter, loc, alphaF, betaF); } auto sumParams = SmallVector{phiF, thetaF, omegaF}; - auto mergeOp = rewriter.create(loc, outQubitsTypes, outQubitsCtrlTypes, sumParams, - parentInQubits, op.getGateName(), false, - parentInCtrlQubits, parentInCtrlValues); + auto mergeOp = CustomOp::create(rewriter, loc, outQubitsTypes, outQubitsCtrlTypes, + sumParams, parentInQubits, op.getGateName(), false, + parentInCtrlQubits, parentInCtrlValues); rewriter.replaceOp(op, mergeOp); rewriter.eraseOp(parentOp); @@ -441,12 +453,11 @@ struct MergePPRRewritePattern : public OpRewritePattern { mergeOpOutQubits = parentOpInQubits; } else { - mergeOpOutQubits = rewriter - .create(loc, - /*pauli_product=*/newPauliProduct, - /*rotationKind=*/newAngle, - /*in_qubits=*/newInQubits, - /*condition=*/opCondition) + mergeOpOutQubits = PPRotationOp::create(rewriter, loc, + /*pauli_product=*/newPauliProduct, + /*rotationKind=*/newAngle, + /*in_qubits=*/newInQubits, + /*condition=*/opCondition) .getOutQubits(); } } @@ -454,14 +465,13 @@ struct MergePPRRewritePattern : public OpRewritePattern { Value opRotation = op.getArbitraryAngle(); Value parentOpRotation = parentOp.getArbitraryAngle(); auto newAngle = - rewriter.create(loc, opRotation, parentOpRotation).getResult(); - - mergeOpOutQubits = rewriter - .create(loc, - /*pauli_product=*/newPauliProduct, - /*arbitrary_angle=*/newAngle, - /*in_qubits=*/newInQubits, - /*condition=*/opCondition) + arith::AddFOp::create(rewriter, loc, opRotation, parentOpRotation).getResult(); + + mergeOpOutQubits = PPRotationArbitraryOp::create(rewriter, loc, + /*pauli_product=*/newPauliProduct, + /*arbitrary_angle=*/newAngle, + /*in_qubits=*/newInQubits, + /*condition=*/opCondition) .getOutQubits(); } @@ -500,11 +510,11 @@ struct MergeMultiRZRewritePattern : public OpRewritePattern { auto parentTheta = parentOp.getTheta(); auto theta = op.getTheta(); - Value sumParam = rewriter.create(loc, parentTheta, theta).getResult(); + Value sumParam = arith::AddFOp::create(rewriter, loc, parentTheta, theta).getResult(); - auto mergeOp = rewriter.create(loc, outQubitsTypes, outQubitsCtrlTypes, sumParam, - parentInQubits, nullptr, parentInCtrlQubits, - parentInCtrlValues); + auto mergeOp = + MultiRZOp::create(rewriter, loc, outQubitsTypes, outQubitsCtrlTypes, sumParam, + parentInQubits, nullptr, parentInCtrlQubits, parentInCtrlValues); rewriter.replaceOp(op, mergeOp); rewriter.eraseOp(parentOp); diff --git a/mlir/lib/Quantum/Transforms/SplitMultipleTapes.cpp b/mlir/lib/Quantum/Transforms/SplitMultipleTapes.cpp index e17872f2ae..633b924cd4 100644 --- a/mlir/lib/Quantum/Transforms/SplitMultipleTapes.cpp +++ b/mlir/lib/Quantum/Transforms/SplitMultipleTapes.cpp @@ -163,7 +163,7 @@ struct SplitMultipleTapesPass : public impl::SplitMultipleTapesPassBase(loc, ArrayRef(RetTypes)); + scf::ExecuteRegionOp::create(builder, loc, ArrayRef(RetTypes)); builder.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock()); mlir::Block::iterator it = executeRegionOp.getRegion().front().end(); @@ -172,7 +172,7 @@ struct SplitMultipleTapesPass : public impl::SplitMultipleTapesPassBase(loc, ArrayRef(RetValues)); + scf::YieldOp y = scf::YieldOp::create(builder, loc, ArrayRef(RetValues)); return std::make_pair(executeRegionOp, y); } // wrapTapeOpsInSCFRegion() diff --git a/mlir/lib/Quantum/Transforms/cp_global_buffers.cpp b/mlir/lib/Quantum/Transforms/cp_global_buffers.cpp index 118c2e7fbd..1aa8eeac1f 100644 --- a/mlir/lib/Quantum/Transforms/cp_global_buffers.cpp +++ b/mlir/lib/Quantum/Transforms/cp_global_buffers.cpp @@ -94,15 +94,15 @@ Value allocCopyMemrefDyn(Location loc, Value memref, PatternRewriter &rewriter) int64_t ndim = 0; for (auto dim : memrefType.getShape()) { if (dim < 0) { - Value dynValue = rewriter.create(loc, memref, ndim); + Value dynValue = memref::DimOp::create(rewriter, loc, memref, ndim); dynDims.push_back(dynValue); } ndim++; } } - Value newMemRef = rewriter.create(loc, memrefType, dynDims); - rewriter.create(loc, memref, newMemRef); + Value newMemRef = memref::AllocOp::create(rewriter, loc, memrefType, dynDims); + memref::CopyOp::create(rewriter, loc, memref, newMemRef); return newMemRef; } @@ -122,28 +122,28 @@ void applyCopyGlobalMemRefToReturnOp(func::ReturnOp op, PatternRewriter &rewrite Type mlirIndex = rewriter.getIndexType(); Type llvmIndex = typeConverter.convertType(mlirIndex); auto deadbeefAttr = rewriter.getIntegerAttr(mlirIndex, 0xdeadbeef); - Value deadbeef = rewriter.create(op->getLoc(), llvmIndex, deadbeefAttr); + Value deadbeef = LLVM::ConstantOp::create(rewriter, op->getLoc(), llvmIndex, deadbeefAttr); for (Value memref : memrefs) { Type ty = memref.getType(); Type llvmTy = typeConverter.convertType(ty); Value llvmMemRef = - rewriter.create(op->getLoc(), llvmTy, memref).getResult(0); + UnrealizedConversionCastOp::create(rewriter, op->getLoc(), llvmTy, memref).getResult(0); - Value allocatedPtr = rewriter.create(op->getLoc(), llvmMemRef, 0); + Value allocatedPtr = LLVM::ExtractValueOp::create(rewriter, op->getLoc(), llvmMemRef, 0); Value allocatedPtrToInt = - rewriter.create(op->getLoc(), llvmIndex, allocatedPtr); - Value comparison = rewriter.create(op->getLoc(), LLVM::ICmpPredicate::eq, - deadbeef, allocatedPtrToInt); + LLVM::PtrToIntOp::create(rewriter, op->getLoc(), llvmIndex, allocatedPtr); + Value comparison = LLVM::ICmpOp::create(rewriter, op->getLoc(), LLVM::ICmpPredicate::eq, + deadbeef, allocatedPtrToInt); - scf::IfOp ifOp = rewriter.create( - op->getLoc(), comparison, + scf::IfOp ifOp = scf::IfOp::create( + rewriter, op->getLoc(), comparison, [&](OpBuilder &builder, Location loc) { // then Value newMemRef = allocCopyMemrefDyn(loc, memref, rewriter); - builder.create(loc, newMemRef); + scf::YieldOp::create(builder, loc, newMemRef); }, [&](OpBuilder &builder, Location loc) { // else - builder.create(loc, memref); + scf::YieldOp::create(builder, loc, memref); }); newMemRefs.push_back(ifOp.getResult(0)); diff --git a/mlir/lib/Quantum/Transforms/emit_catalyst_pyface.cpp b/mlir/lib/Quantum/Transforms/emit_catalyst_pyface.cpp index 7cf6c04a27..58d63b290c 100644 --- a/mlir/lib/Quantum/Transforms/emit_catalyst_pyface.cpp +++ b/mlir/lib/Quantum/Transforms/emit_catalyst_pyface.cpp @@ -137,9 +137,9 @@ void wrapResultsAndArgsInTwoStructs(LLVM::LLVMFuncOp op, PatternRewriter &rewrit convertFunctionTypeCatalystWrapper(rewriter, functionType, hasReturns, hasInputs); Location loc = op.getLoc(); - auto wrapperFuncOp = rewriter.create( - loc, llvm::formatv("_catalyst_pyface_{0}", nameWithoutPrefix).str(), wrapperFuncType, - LLVM::Linkage::External, /*dsoLocal*/ false, + auto wrapperFuncOp = LLVM::LLVMFuncOp::create( + rewriter, loc, llvm::formatv("_catalyst_pyface_{0}", nameWithoutPrefix).str(), + wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false, /*cconv*/ LLVM::CConv::C); OpBuilder::InsertionGuard guard(rewriter); @@ -158,17 +158,17 @@ void wrapResultsAndArgsInTwoStructs(LLVM::LLVMFuncOp op, PatternRewriter &rewrit if (hasInputs) { Value arg = wrapperFuncOp.getArgument(1); auto argType = inputType; - Value structOfMemrefs = rewriter.create(loc, argType, arg); + Value structOfMemrefs = LLVM::LoadOp::create(rewriter, loc, argType, arg); for (size_t idx = 0; idx < params.size(); idx++) { - Value pointer = rewriter.create(loc, structOfMemrefs, idx); + Value pointer = LLVM::ExtractValueOp::create(rewriter, loc, structOfMemrefs, idx); args.push_back(pointer); } } - auto call = rewriter.create(loc, op, args); + auto call = LLVM::CallOp::create(rewriter, loc, op, args); - rewriter.create(loc, call.getResults()); + LLVM::ReturnOp::create(rewriter, loc, call.getResults()); } struct EmitCatalystPyInterfaceTransform : public OpRewritePattern { diff --git a/mlir/lib/Quantum/Utils/QuantumSplitting.cpp b/mlir/lib/Quantum/Utils/QuantumSplitting.cpp index 50aa650a16..5296b47de5 100644 --- a/mlir/lib/Quantum/Utils/QuantumSplitting.cpp +++ b/mlir/lib/Quantum/Utils/QuantumSplitting.cpp @@ -88,14 +88,14 @@ QuantumCache QuantumCache::initialize(Region ®ion, OpBuilder &builder, Locati auto paramVectorType = ArrayListType::get(ctx, builder.getF64Type()); auto wireVectorType = ArrayListType::get(ctx, builder.getI64Type()); auto controlFlowTapeType = ArrayListType::get(ctx, builder.getIndexType()); - auto paramVector = builder.create(loc, paramVectorType); - auto wireVector = builder.create(loc, wireVectorType); + auto paramVector = ListInitOp::create(builder, loc, paramVectorType); + auto wireVector = ListInitOp::create(builder, loc, wireVectorType); // Initialize the tapes that store the structure of control flow. DenseMap> controlFlowTapes; region.walk([&](Operation *op) { if (isa(op)) { - auto tape = builder.create(loc, controlFlowTapeType); + auto tape = catalyst::ListInitOp::create(builder, loc, controlFlowTapeType); controlFlowTapes.insert({op, tape}); } }); @@ -105,10 +105,10 @@ QuantumCache QuantumCache::initialize(Region ®ion, OpBuilder &builder, Locati void QuantumCache::emitDealloc(OpBuilder &builder, Location loc) { - builder.create(loc, paramVector); - builder.create(loc, wireVector); + ListDeallocOp::create(builder, loc, paramVector); + ListDeallocOp::create(builder, loc, wireVector); for (const auto &[_key, controlFlowTape] : controlFlowTapes) { - builder.create(loc, controlFlowTape); + ListDeallocOp::create(builder, loc, controlFlowTape); } } @@ -124,7 +124,7 @@ void AugmentedCircuitGenerator::cacheGate(quantum::ParametrizedGate gate, OpBuil verifyTypeIsCacheable(paramType, op); if (paramType.isF64()) { - builder.create(loc, clonedParam, cache.paramVector); + ListPushOp::create(builder, loc, clonedParam, cache.paramVector); continue; } @@ -139,14 +139,14 @@ void AugmentedCircuitGenerator::cacheGate(quantum::ParametrizedGate gate, OpBuil auto aTensor = cast(paramType); ArrayRef shape = aTensor.getShape(); - Value c0 = builder.create(loc, 0); - Value c1 = builder.create(loc, 1); + Value c0 = index::ConstantOp::create(builder, loc, 0); + Value c1 = index::ConstantOp::create(builder, loc, 1); bool isDim0Static = ShapedType::kDynamic != shape[0]; bool isDim1Static = ShapedType::kDynamic != shape[1]; - Value dim0Length = isDim0Static ? (Value)builder.create(loc, shape[0]) - : (Value)builder.create(loc, param, c0); - Value dim1Length = isDim1Static ? (Value)builder.create(loc, shape[1]) - : (Value)builder.create(loc, param, c1); + Value dim0Length = isDim0Static ? (Value)index::ConstantOp::create(builder, loc, shape[0]) + : (Value)tensor::DimOp::create(builder, loc, param, c0); + Value dim1Length = isDim1Static ? (Value)index::ConstantOp::create(builder, loc, shape[1]) + : (Value)tensor::DimOp::create(builder, loc, param, c1); Value lowerBoundDim0 = c0; Value upperBoundDim0 = dim0Length; @@ -157,27 +157,27 @@ void AugmentedCircuitGenerator::cacheGate(quantum::ParametrizedGate gate, OpBuil Value matrix = clonedParam; scf::ForOp iForLoop = - builder.create(loc, lowerBoundDim0, upperBoundDim0, stepDim0); + scf::ForOp::create(builder, loc, lowerBoundDim0, upperBoundDim0, stepDim0); { OpBuilder::InsertionGuard afterIForLoop(builder); builder.setInsertionPointToStart(iForLoop.getBody()); Value i_index = iForLoop.getInductionVar(); scf::ForOp jForLoop = - builder.create(loc, lowerBoundDim1, upperBoundDim1, stepDim1); + scf::ForOp::create(builder, loc, lowerBoundDim1, upperBoundDim1, stepDim1); { OpBuilder::InsertionGuard afterJForLoop(builder); builder.setInsertionPointToStart(jForLoop.getBody()); Value j_index = jForLoop.getInductionVar(); SmallVector indices = {i_index, j_index}; - Value element = builder.create(loc, matrix, indices); + Value element = tensor::ExtractOp::create(builder, loc, matrix, indices); // element is complex! // So we need to convert into {f64, f64} - Value real = builder.create(loc, element); - Value imag = builder.create(loc, element); + Value real = complex::ReOp::create(builder, loc, element); + Value imag = complex::ImOp::create(builder, loc, element); // Again, take note of the order. - builder.create(loc, real, cache.paramVector); - builder.create(loc, imag, cache.paramVector); + ListPushOp::create(builder, loc, real, cache.paramVector); + ListPushOp::create(builder, loc, imag, cache.paramVector); } } } @@ -261,11 +261,11 @@ void AugmentedCircuitGenerator::visitOperation(scf::ForOp forOp, OpBuilder &buil // Store the start, stop, and step to this op's control flow tape. Value tape = cache.controlFlowTapes.at(forOp); for (Value param : {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()}) { - builder.create(forOp.getLoc(), oldToCloned.lookupOrDefault(param), tape); + ListPushOp::create(builder, forOp.getLoc(), oldToCloned.lookupOrDefault(param), tape); } - auto newForOp = builder.create( - forOp.getLoc(), oldToCloned.lookupOrDefault(forOp.getLowerBound()), + auto newForOp = scf::ForOp::create( + builder, forOp.getLoc(), oldToCloned.lookupOrDefault(forOp.getLowerBound()), oldToCloned.lookupOrDefault(forOp.getUpperBound()), oldToCloned.lookupOrDefault(forOp.getStep()), classicalInits, [&](OpBuilder &builder, Location loc, Value inductionVar, ValueRange iterArgs) { @@ -299,10 +299,10 @@ void AugmentedCircuitGenerator::visitOperation(scf::WhileOp whileOp, OpBuilder & // Augment the classical loop by counting the number of iterations. auto counterType = MemRefType::get({}, builder.getIndexType()); Location loc = whileOp.getLoc(); - Value idx0 = builder.create(loc, 0); - Value idx1 = builder.create(loc, 1); - Value counter = builder.create(loc, counterType); - builder.create(loc, idx0, counter); + Value idx0 = index::ConstantOp::create(builder, loc, 0); + Value idx1 = index::ConstantOp::create(builder, loc, 1); + Value counter = memref::AllocaOp::create(builder, loc, counterType); + memref::StoreOp::create(builder, loc, idx0, counter); auto getRegionBuilder = [&](Region &oldRegion, bool incrementCounter) { return [&, incrementCounter](OpBuilder &builder, Location loc, ValueRange newRegionArgs) { @@ -311,9 +311,9 @@ void AugmentedCircuitGenerator::visitOperation(scf::WhileOp whileOp, OpBuilder & } if (incrementCounter) { - Value countVal = builder.create(loc, counter); - countVal = builder.create(loc, countVal, idx1); - builder.create(loc, countVal, counter); + Value countVal = memref::LoadOp::create(builder, loc, counter); + countVal = index::AddOp::create(builder, loc, countVal, idx1); + memref::StoreOp::create(builder, loc, countVal, counter); } // Recursively clone the region @@ -322,19 +322,19 @@ void AugmentedCircuitGenerator::visitOperation(scf::WhileOp whileOp, OpBuilder & }; }; - auto newWhileOp = builder.create( - whileOp.getLoc(), classicalResultTypes, classicalInits, - getRegionBuilder(whileOp.getBefore(), /*incrementCounter=*/false), - // We only care about the number of times the "After" region executes. The frontend - // does not support putting quantum operations in the "Before" region, which only - // computes the iteration condition. - getRegionBuilder(whileOp.getAfter(), /*incrementCounter=*/true)); + auto newWhileOp = + scf::WhileOp::create(builder, whileOp.getLoc(), classicalResultTypes, classicalInits, + getRegionBuilder(whileOp.getBefore(), /*incrementCounter=*/false), + // We only care about the number of times the "After" region executes. + // The frontend does not support putting quantum operations in the + // "Before" region, which only computes the iteration condition. + getRegionBuilder(whileOp.getAfter(), /*incrementCounter=*/true)); mapResults(whileOp, newWhileOp, argIdxMapping); - Value numIters = builder.create(whileOp.getLoc(), counter); + Value numIters = memref::LoadOp::create(builder, whileOp.getLoc(), counter); Value tape = cache.controlFlowTapes.at(whileOp); - builder.create(whileOp.getLoc(), numIters, tape); + ListPushOp::create(builder, whileOp.getLoc(), numIters, tape); } void AugmentedCircuitGenerator::visitOperation(scf::IfOp ifOp, OpBuilder &builder) @@ -352,12 +352,12 @@ void AugmentedCircuitGenerator::visitOperation(scf::IfOp ifOp, OpBuilder &builde Value condition = oldToCloned.lookupOrDefault(ifOp.getCondition()); Value tape = cache.controlFlowTapes.at(ifOp); Value castedCondition = - builder.create(ifOp.getLoc(), builder.getIndexType(), condition); - builder.create(ifOp.getLoc(), castedCondition, tape); + index::CastSOp::create(builder, ifOp.getLoc(), builder.getIndexType(), condition); + ListPushOp::create(builder, ifOp.getLoc(), castedCondition, tape); auto newIfOp = - builder.create(ifOp.getLoc(), condition, getRegionBuilder(ifOp.getThenRegion()), - getRegionBuilder(ifOp.getElseRegion())); + scf::IfOp::create(builder, ifOp.getLoc(), condition, getRegionBuilder(ifOp.getThenRegion()), + getRegionBuilder(ifOp.getElseRegion())); mapResults(ifOp, newIfOp, argIdxMapping); } diff --git a/mlir/lib/Quantum/Utils/RemoveQuantum.cpp b/mlir/lib/Quantum/Utils/RemoveQuantum.cpp index c1e5ddb7c5..9544f8c1c5 100644 --- a/mlir/lib/Quantum/Utils/RemoveQuantum.cpp +++ b/mlir/lib/Quantum/Utils/RemoveQuantum.cpp @@ -73,18 +73,18 @@ void replaceQuantumMeasurements(func::FuncOp &function, PatternRewriter &rewrite if (auto tensorType = dyn_cast(type)) { auto shape = tensorType.getShape(); auto elemType = tensorType.getElementType(); - auto res = rewriter.create(loc, shape, elemType); + auto res = tensor::EmptyOp::create(rewriter, loc, shape, elemType); results.push_back(res); } else { if (type.isInteger()) { - auto res = rewriter.create(loc, type, - rewriter.getIntegerAttr(type, 0)); + auto res = arith::ConstantOp::create(rewriter, loc, type, + rewriter.getIntegerAttr(type, 0)); results.push_back(res); } else if (type.isIntOrFloat()) { - auto res = rewriter.create(loc, type, - rewriter.getFloatAttr(type, 0.0)); + auto res = arith::ConstantOp::create(rewriter, loc, type, + rewriter.getFloatAttr(type, 0.0)); results.push_back(res); } else { diff --git a/mlir/lib/RTIO/Transforms/ARTIQRuntimeBuilder.hpp b/mlir/lib/RTIO/Transforms/ARTIQRuntimeBuilder.hpp index 2ce9974906..7de92926fa 100644 --- a/mlir/lib/RTIO/Transforms/ARTIQRuntimeBuilder.hpp +++ b/mlir/lib/RTIO/Transforms/ARTIQRuntimeBuilder.hpp @@ -85,7 +85,7 @@ class ARTIQRuntimeBuilder { Value nowMu() { auto func = ensureFunc(ARTIQFuncNames::nowMu, LLVM::LLVMFunctionType::get(i64Ty, {})); - auto call = builder.create(getLoc(), func, ValueRange{}); + auto call = LLVM::CallOp::create(builder, getLoc(), func, ValueRange{}); call.setTailCallKind(LLVM::TailCallKind::Tail); return call.getResult(); } @@ -93,7 +93,7 @@ class ARTIQRuntimeBuilder { void atMu(Value time) { auto func = ensureFunc(ARTIQFuncNames::atMu, LLVM::LLVMFunctionType::get(voidTy, {i64Ty})); - auto call = builder.create(getLoc(), func, ValueRange{time}); + auto call = LLVM::CallOp::create(builder, getLoc(), func, ValueRange{time}); call.setTailCallKind(LLVM::TailCallKind::Tail); } @@ -101,7 +101,7 @@ class ARTIQRuntimeBuilder { { auto func = ensureFunc(ARTIQFuncNames::delayMu, LLVM::LLVMFunctionType::get(voidTy, {i64Ty})); - auto call = builder.create(getLoc(), func, ValueRange{duration}); + auto call = LLVM::CallOp::create(builder, getLoc(), func, ValueRange{duration}); call.setCConv(LLVM::CConv::Fast); call.setTailCallKind(LLVM::TailCallKind::Tail); } @@ -111,14 +111,14 @@ class ARTIQRuntimeBuilder { { auto func = ensureFunc(ARTIQFuncNames::rtioOutput, LLVM::LLVMFunctionType::get(voidTy, {i32Ty, i32Ty})); - auto call = builder.create(getLoc(), func, ValueRange{addr, val}); + auto call = LLVM::CallOp::create(builder, getLoc(), func, ValueRange{addr, val}); call.setTailCallKind(LLVM::TailCallKind::Tail); } void rtioInit() { auto func = ensureFunc(ARTIQFuncNames::rtioInit, LLVM::LLVMFunctionType::get(voidTy, {})); - auto call = builder.create(getLoc(), func, ValueRange{}); + auto call = LLVM::CallOp::create(builder, getLoc(), func, ValueRange{}); call.setCConv(LLVM::CConv::Fast); call.setTailCallKind(LLVM::TailCallKind::Tail); } @@ -127,7 +127,7 @@ class ARTIQRuntimeBuilder { { auto func = ensureFunc(ARTIQFuncNames::rtioGetCounter, LLVM::LLVMFunctionType::get(i64Ty, {})); - auto call = builder.create(getLoc(), func, ValueRange{}); + auto call = LLVM::CallOp::create(builder, getLoc(), func, ValueRange{}); call.setCConv(LLVM::CConv::Fast); call.setTailCallKind(LLVM::TailCallKind::Tail); return call.getResult(); @@ -138,7 +138,7 @@ class ARTIQRuntimeBuilder { { ensureSecToMuFunc(); auto func = getModule().lookupSymbol(ARTIQFuncNames::secToMu); - auto call = builder.create(getLoc(), func, ValueRange{durationSec}); + auto call = LLVM::CallOp::create(builder, getLoc(), func, ValueRange{durationSec}); call.setCConv(LLVM::CConv::Fast); call.setTailCallKind(LLVM::TailCallKind::Tail); return call.getResult(); @@ -149,8 +149,8 @@ class ARTIQRuntimeBuilder { { ensureConfigSpiFunc(); auto func = getModule().lookupSymbol(ARTIQFuncNames::configSpi); - auto call = - builder.create(getLoc(), func, ValueRange{baseAddr, cs, len, div, flags}); + auto call = LLVM::CallOp::create(builder, getLoc(), func, + ValueRange{baseAddr, cs, len, div, flags}); call.setCConv(LLVM::CConv::Fast); call.setTailCallKind(LLVM::TailCallKind::Tail); } @@ -169,8 +169,8 @@ class ARTIQRuntimeBuilder { { ensureSetFrequencyFunc(); auto func = getModule().lookupSymbol(ARTIQFuncNames::setFrequency); - builder.create(getLoc(), func, - ValueRange{channelId, freqHz, phaseTurns, amplitude}); + LLVM::CallOp::create(builder, getLoc(), func, + ValueRange{channelId, freqHz, phaseTurns, amplitude}); return nowMu(); } @@ -182,17 +182,17 @@ class ARTIQRuntimeBuilder { // Constant creation helpers Value constI32(int32_t val) { - return builder.create(getLoc(), builder.getI32IntegerAttr(val)); + return arith::ConstantOp::create(builder, getLoc(), builder.getI32IntegerAttr(val)); } Value constI64(int64_t val) { - return builder.create(getLoc(), builder.getI64IntegerAttr(val)); + return arith::ConstantOp::create(builder, getLoc(), builder.getI64IntegerAttr(val)); } Value constF64(double val) { - return builder.create(getLoc(), builder.getF64FloatAttr(val)); + return arith::ConstantOp::create(builder, getLoc(), builder.getF64FloatAttr(val)); } // Accessors @@ -234,8 +234,8 @@ class ARTIQRuntimeBuilder { builder.setInsertionPointToStart(module.getBody()); auto funcTy = LLVM::LLVMFunctionType::get(i64Ty, {f64Ty}); - auto func = builder.create(getLoc(), ARTIQFuncNames::secToMu, funcTy, - LLVM::Linkage::Internal); + auto func = LLVM::LLVMFuncOp::create(builder, getLoc(), ARTIQFuncNames::secToMu, funcTy, + LLVM::Linkage::Internal); func.setCConv(LLVM::CConv::Fast); Block *entry = func.addEntryBlock(builder); @@ -244,10 +244,10 @@ class ARTIQRuntimeBuilder { // duration_mu = round(duration_sec / 1e-9) Value nsPerMu = constF64(ARTIQHardwareConfig::nanosecondPeriod); - Value durationNs = builder.create(getLoc(), durationSec, nsPerMu); - Value rounded = builder.create(getLoc(), durationNs); - Value result = builder.create(getLoc(), i64Ty, rounded); - builder.create(getLoc(), result); + Value durationNs = arith::DivFOp::create(builder, getLoc(), durationSec, nsPerMu); + Value rounded = math::RoundOp::create(builder, getLoc(), durationNs); + Value result = arith::FPToSIOp::create(builder, getLoc(), i64Ty, rounded); + LLVM::ReturnOp::create(builder, getLoc(), result); } void ensureConfigSpiFunc() @@ -261,8 +261,8 @@ class ARTIQRuntimeBuilder { builder.setInsertionPointToStart(module.getBody()); auto funcTy = LLVM::LLVMFunctionType::get(voidTy, {i32Ty, i32Ty, i32Ty, i32Ty, i32Ty}); - auto func = builder.create(getLoc(), ARTIQFuncNames::configSpi, funcTy, - LLVM::Linkage::Internal); + auto func = LLVM::LLVMFuncOp::create(builder, getLoc(), ARTIQFuncNames::configSpi, funcTy, + LLVM::Linkage::Internal); Block *entry = func.addEntryBlock(builder); builder.setInsertionPointToStart(entry); @@ -274,21 +274,21 @@ class ARTIQRuntimeBuilder { Value flags = entry->getArgument(4); // Config register address = Base | 1 - Value configAddr = builder.create(getLoc(), baseAddr, constI32(1)); + Value configAddr = arith::OrIOp::create(builder, getLoc(), baseAddr, constI32(1)); // Pack: (CS << 24) | ((div - 2) << 16) | ((len - 1) << 8) | flags - Value csShifted = builder.create(getLoc(), cs, constI32(24)); - Value divOffset = builder.create(getLoc(), div, constI32(2)); - Value divShifted = builder.create(getLoc(), divOffset, constI32(16)); - Value lenOffset = builder.create(getLoc(), len, constI32(1)); - Value lenShifted = builder.create(getLoc(), lenOffset, constI32(8)); + Value csShifted = arith::ShLIOp::create(builder, getLoc(), cs, constI32(24)); + Value divOffset = arith::SubIOp::create(builder, getLoc(), div, constI32(2)); + Value divShifted = arith::ShLIOp::create(builder, getLoc(), divOffset, constI32(16)); + Value lenOffset = arith::SubIOp::create(builder, getLoc(), len, constI32(1)); + Value lenShifted = arith::ShLIOp::create(builder, getLoc(), lenOffset, constI32(8)); - Value packed = builder.create(getLoc(), csShifted, divShifted); - packed = builder.create(getLoc(), packed, lenShifted); - packed = builder.create(getLoc(), packed, flags); + Value packed = arith::OrIOp::create(builder, getLoc(), csShifted, divShifted); + packed = arith::OrIOp::create(builder, getLoc(), packed, lenShifted); + packed = arith::OrIOp::create(builder, getLoc(), packed, flags); rtioOutput(configAddr, packed); - builder.create(getLoc(), ValueRange{}); + LLVM::ReturnOp::create(builder, getLoc(), ValueRange{}); } void ensureSetFrequencyFunc() @@ -304,8 +304,8 @@ class ARTIQRuntimeBuilder { builder.setInsertionPointToStart(module.getBody()); auto funcTy = LLVM::LLVMFunctionType::get(voidTy, {i32Ty, f64Ty, f64Ty, f64Ty}); - auto func = builder.create(getLoc(), ARTIQFuncNames::setFrequency, funcTy, - LLVM::Linkage::Internal); + auto func = LLVM::LLVMFuncOp::create(builder, getLoc(), ARTIQFuncNames::setFrequency, + funcTy, LLVM::Linkage::Internal); Block *entry = func.addEntryBlock(builder); builder.setInsertionPointToStart(entry); @@ -322,21 +322,21 @@ class ARTIQRuntimeBuilder { // CS calculation: csBase is the chip_select for ch0 (typically 4) // For Urukul: ch0->CS=4, ch1->CS=5, ch2->CS=6, ch3->CS=7 // So CS = csBase + channelId - Value cs = builder.create(getLoc(), constI32(csBase), channelId); + Value cs = arith::AddIOp::create(builder, getLoc(), constI32(csBase), channelId); Value spiBase = constI32(spiBaseAddr); Value ioUpdate = constI32(ioUpdateAddr); // Calculate FTW: round(frequency * (2^32 / sys_clk)) Value ftwScale = constF64(ARTIQHardwareConfig::ftwScaleFactor); - Value ftwDouble = builder.create(getLoc(), freqHz, ftwScale); - Value ftwRounded = builder.create(getLoc(), ftwDouble); - Value ftw = builder.create(getLoc(), i32Ty, ftwRounded); + Value ftwDouble = arith::MulFOp::create(builder, getLoc(), freqHz, ftwScale); + Value ftwRounded = math::RoundOp::create(builder, getLoc(), ftwDouble); + Value ftw = arith::FPToUIOp::create(builder, getLoc(), i32Ty, ftwRounded); // Calculate POW: round(phaseTurns * 65536) Value powScale = constF64(ARTIQHardwareConfig::powScaleFactor); - Value powDouble = builder.create(getLoc(), phaseTurns, powScale); - Value powRounded = builder.create(getLoc(), powDouble); - Value pow = builder.create(getLoc(), i32Ty, powRounded); + Value powDouble = arith::MulFOp::create(builder, getLoc(), phaseTurns, powScale); + Value powRounded = math::RoundOp::create(builder, getLoc(), powDouble); + Value pow = arith::FPToUIOp::create(builder, getLoc(), i32Ty, powRounded); // SPI Transfer: Write instruction to profile 7 (0x15) configSpi(spiBase, cs, constI32(ARTIQHardwareConfig::spiLen8), @@ -354,11 +354,11 @@ class ARTIQRuntimeBuilder { constI32(ARTIQHardwareConfig::spiFlagsKeepCS)); delayMu(constI64(ARTIQHardwareConfig::refPeriodMu)); Value asfScale = constF64(static_cast(ARTIQHardwareConfig::maxAmplitude)); - Value asfDouble = builder.create(getLoc(), amplitude, asfScale); - Value asfRounded = builder.create(getLoc(), asfDouble); - Value asf = builder.create(getLoc(), i32Ty, asfRounded); - Value asfShifted = builder.create(getLoc(), asf, constI32(16)); - Value ampPhase = builder.create(getLoc(), asfShifted, pow); + Value asfDouble = arith::MulFOp::create(builder, getLoc(), amplitude, asfScale); + Value asfRounded = math::RoundOp::create(builder, getLoc(), asfDouble); + Value asf = arith::FPToUIOp::create(builder, getLoc(), i32Ty, asfRounded); + Value asfShifted = arith::ShLIOp::create(builder, getLoc(), asf, constI32(16)); + Value ampPhase = arith::OrIOp::create(builder, getLoc(), asfShifted, pow); rtioOutput(spiBase, ampPhase); // Wait for SPI transmission to complete waitForSpi(ARTIQHardwareConfig::spiLen32, ARTIQHardwareConfig::spiDiv); @@ -377,7 +377,7 @@ class ARTIQRuntimeBuilder { delayMu(constI64(ARTIQHardwareConfig::ioUpdatePulseWidth)); ttlOff(ioUpdate); - builder.create(getLoc(), ValueRange{}); + LLVM::ReturnOp::create(builder, getLoc(), ValueRange{}); } /// Returns hardware addresses: (spiBaseAddr, csBase, ioUpdateAddr) diff --git a/mlir/lib/RTIO/Transforms/RTIOEventToARTIQ.cpp b/mlir/lib/RTIO/Transforms/RTIOEventToARTIQ.cpp index 19b0154e62..6005b9b58f 100644 --- a/mlir/lib/RTIO/Transforms/RTIOEventToARTIQ.cpp +++ b/mlir/lib/RTIO/Transforms/RTIOEventToARTIQ.cpp @@ -265,7 +265,7 @@ class PulseScheduler { auto eventType = rtio::EventType::get(builder.getContext()); Value syncEvent = - builder.create(anyPulse.getLoc(), eventType, eventsToSync); + rtio::RTIOSyncOp::create(builder, anyPulse.getLoc(), eventType, eventsToSync); // Update boundaries and consumers for (auto &[_, pulse] : channelBoundary) { @@ -410,11 +410,11 @@ void decomposeFrequencyPulses(ScheduleGroupsMap &pulseGroups) builder.setInsertionPoint(firstRoot); // Create sync - Value chainStart = - originalWaits.size() > 1 - ? builder.create( - firstRoot.getLoc(), rtio::EventType::get(builder.getContext()), originalWaits) - : originalWaits[0]; + Value chainStart = originalWaits.size() > 1 + ? rtio::RTIOSyncOp::create( + builder, firstRoot.getLoc(), + rtio::EventType::get(builder.getContext()), originalWaits) + : originalWaits[0]; // Create frequency setting chain Value lastFreqEvent = chainStart; @@ -532,7 +532,7 @@ struct RTIOEventToARTIQPass : public impl::RTIOEventToARTIQPassBase(kernelFunc.getLoc(), counter, slack); + Value initialTime = arith::AddIOp::create(builder, kernelFunc.getLoc(), counter, slack); artiq.atMu(initialTime); return success(); diff --git a/mlir/lib/RTIO/Transforms/RTIOEventToARTIQPatterns.cpp b/mlir/lib/RTIO/Transforms/RTIOEventToARTIQPatterns.cpp index 755bd20b34..8c9a8e0917 100644 --- a/mlir/lib/RTIO/Transforms/RTIOEventToARTIQPatterns.cpp +++ b/mlir/lib/RTIO/Transforms/RTIOEventToARTIQPatterns.cpp @@ -65,9 +65,9 @@ struct PulseOpLowering : public OpConversionPattern { } Value amplitude = artiq.constF64(1.0); - rewriter.create(op.getLoc(), setFreqFunc, - ValueRange{adaptor.getChannel(), adaptor.getFrequency(), - adaptor.getPhase(), amplitude}); + LLVM::CallOp::create(rewriter, op.getLoc(), setFreqFunc, + ValueRange{adaptor.getChannel(), adaptor.getFrequency(), + adaptor.getPhase(), amplitude}); Value newTime = artiq.nowMu(); rewriter.replaceOp(op, newTime); @@ -92,7 +92,7 @@ struct PulseOpLowering : public OpConversionPattern { // Enforce minimum pulse duration to avoid 0 duratoin events Value minDuration = artiq.constI64(ARTIQHardwareConfig::minTTLPulseMu); - durationMu = rewriter.create(op.getLoc(), durationMu, minDuration); + durationMu = arith::MaxSIOp::create(rewriter, op.getLoc(), durationMu, minDuration); artiq.ttlOn(channelAddr); artiq.delayMu(durationMu); @@ -120,7 +120,7 @@ struct SyncOpLowering : public OpConversionPattern { // Compute maximum timestamp Value maxTime = events[0]; for (size_t i = 1; i < events.size(); ++i) { - maxTime = rewriter.create(op.getLoc(), maxTime, events[i]); + maxTime = arith::MaxSIOp::create(rewriter, op.getLoc(), maxTime, events[i]); } ARTIQRuntimeBuilder artiq(rewriter, op); @@ -150,8 +150,8 @@ struct ChannelOpLowering : public OpConversionPattern { { int32_t channelId = extractChannelId(op.getChannel()); Type resultType = getTypeConverter()->convertType(op.getChannel().getType()); - Value result = rewriter.create( - op.getLoc(), rewriter.getIntegerAttr(resultType, channelId)); + Value result = arith::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getIntegerAttr(resultType, channelId)); rewriter.replaceOp(op, result); return success(); } @@ -184,8 +184,8 @@ struct DecomposePulsePattern : public OpRewritePattern { // Sync both pulses auto eventType = EventType::get(rewriter.getContext()); - Value syncEvent = rewriter.create( - loc, eventType, ValueRange{controlPulse.getEvent(), slackPulse.getEvent()}); + Value syncEvent = RTIOSyncOp::create( + rewriter, loc, eventType, ValueRange{controlPulse.getEvent(), slackPulse.getEvent()}); rewriter.replaceOp(op, syncEvent); return success(); diff --git a/mlir/lib/RTIO/Transforms/Utils.hpp b/mlir/lib/RTIO/Transforms/Utils.hpp index b411f518c7..ce2f84a84a 100644 --- a/mlir/lib/RTIO/Transforms/Utils.hpp +++ b/mlir/lib/RTIO/Transforms/Utils.hpp @@ -57,7 +57,7 @@ inline mlir::Value computeChannelDeviceAddr(mlir::OpBuilder &builder, mlir::Oper "only static channels are supported"); int64_t channelId = channelIdAPInt.getSExtValue(); int32_t addr = static_cast((channelId + channelBase) << 8); - return builder.create(loc, builder.getI32IntegerAttr(addr)); + return mlir::arith::ConstantOp::create(builder, loc, builder.getI32IntegerAttr(addr)); } } // namespace rtio diff --git a/mlir/lib/hlo-extensions/Transforms/HloCustomCallPatterns.cpp b/mlir/lib/hlo-extensions/Transforms/HloCustomCallPatterns.cpp index bf2b5bdb3d..55365839b2 100644 --- a/mlir/lib/hlo-extensions/Transforms/HloCustomCallPatterns.cpp +++ b/mlir/lib/hlo-extensions/Transforms/HloCustomCallPatterns.cpp @@ -50,7 +50,7 @@ struct HloCustomCallOpRewritePattern : public mlir::OpRewritePattern Value { auto type = RankedTensorType::get({}, rewriter.getI32Type()); auto attr = DenseElementsAttr::get(type, APInt(32, static_cast(val))); - return rewriter.create(loc, attr); + return arith::ConstantOp::create(rewriter, loc, attr); }; if (operands.empty()) { @@ -72,18 +72,18 @@ struct HloCustomCallOpRewritePattern : public mlir::OpRewritePattern(attrValue)) { auto type = RankedTensorType::get({}, intAttr.getType()); - constVal = rewriter.create( - loc, DenseElementsAttr::get(type, intAttr.getValue())); + constVal = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(type, intAttr.getValue())); } else if (auto floatAttr = llvm::dyn_cast(attrValue)) { auto type = RankedTensorType::get({}, floatAttr.getType()); - constVal = rewriter.create( - loc, DenseElementsAttr::get(type, floatAttr.getValue())); + constVal = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(type, floatAttr.getValue())); } else if (auto boolAttr = llvm::dyn_cast(attrValue)) { auto type = RankedTensorType::get({}, rewriter.getI1Type()); - constVal = rewriter.create( - loc, DenseElementsAttr::get(type, boolAttr.getValue())); + constVal = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(type, boolAttr.getValue())); } else { LLVM_DEBUG(llvm::dbgs() << "Unsupported attribute type for: " diff --git a/mlir/lib/hlo-extensions/Transforms/ScatterPatterns.cpp b/mlir/lib/hlo-extensions/Transforms/ScatterPatterns.cpp index 906d8ba6ea..70ac53b98f 100644 --- a/mlir/lib/hlo-extensions/Transforms/ScatterPatterns.cpp +++ b/mlir/lib/hlo-extensions/Transforms/ScatterPatterns.cpp @@ -243,12 +243,12 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern(loc, scatterDimIndex); + Value scatterDimVal = index::ConstantOp::create(rewriter, loc, scatterDimIndex); auto extractOp = - rewriter.create(loc, scatterIndices, scatterDimVal) + tensor::ExtractOp::create(rewriter, loc, scatterIndices, scatterDimVal) .getResult(); auto indexCastOp = - rewriter.create(loc, rewriter.getIndexType(), extractOp) + arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(), extractOp) .getResult(); dynOffsets.push_back(indexCastOp); staticOffsets.push_back(ShapedType::kDynamic); @@ -256,12 +256,12 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern(loc, scatterDimIndex); + Value scatterDimVal = index::ConstantOp::create(rewriter, loc, scatterDimIndex); auto extractOp = - rewriter.create(loc, scatterIndices, scatterDimVal) + tensor::ExtractOp::create(rewriter, loc, scatterIndices, scatterDimVal) .getResult(); auto indexCastOp = - rewriter.create(loc, rewriter.getIndexType(), extractOp) + arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(), extractOp) .getResult(); dynOffsets.push_back(indexCastOp); staticOffsets.push_back(ShapedType::kDynamic); @@ -328,112 +328,110 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern(loc, 0); - Value sizeAllUpdatesIndices = rewriter.create(loc, variables.size); - Value c1 = rewriter.create(loc, 1); + Value c0 = index::ConstantOp::create(rewriter, loc, 0); + Value sizeAllUpdatesIndices = index::ConstantOp::create(rewriter, loc, variables.size); + Value c1 = index::ConstantOp::create(rewriter, loc, 1); // Create a SCF for op, the initial value for args is the results Value resultValue = - rewriter - .create( - loc, c0, sizeAllUpdatesIndices, c1, /*iterArgsInit=*/variables.resultsValue, - [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { - // Get the results - auto results = iterArgs.front(); - - // Extract from the all indices tensor the right configuration - // with the value i as index: allUpdatesIndices[i] - Value updatesIndices; - if (variables.allUpdatesIndicesTensor) { - updatesIndices = extractUpdateIndices(variables.allUpdatesIndicesTensor, - i, loc, builder); + scf::ForOp::create( + rewriter, loc, c0, sizeAllUpdatesIndices, c1, + /*iterArgsInit=*/variables.resultsValue, + [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { + // Get the results + auto results = iterArgs.front(); + + // Extract from the all indices tensor the right configuration + // with the value i as index: allUpdatesIndices[i] + Value updatesIndices; + if (variables.allUpdatesIndicesTensor) { + updatesIndices = extractUpdateIndices(variables.allUpdatesIndicesTensor, i, + loc, builder); + } + + // Scatter update + SmallVector updateScatterIndices; + if (variables.allUpdatesIndicesTensor) { + for (int64_t index : variables.updatedScatterDims) { + Value indexValue = index::ConstantOp::create(builder, loc, index); + Value updateScatterIndex = + tensor::ExtractOp::create(builder, loc, updatesIndices, indexValue); + updateScatterIndices.push_back(updateScatterIndex); } - - // Scatter update - SmallVector updateScatterIndices; - if (variables.allUpdatesIndicesTensor) { - for (int64_t index : variables.updatedScatterDims) { - Value indexValue = builder.create(loc, index); - Value updateScatterIndex = builder.create( - loc, updatesIndices, indexValue); - updateScatterIndices.push_back(updateScatterIndex); - } - } - - // Windows update - SmallVector updateWindowsIndices; - if (variables.allUpdatesIndicesTensor) { - for (int64_t index : variables.updatedWindowsDims) { - Value indexValue = builder.create(loc, index); - Value updateWindowsIndex = builder.create( - loc, updatesIndices, indexValue); - updateWindowsIndices.push_back(updateWindowsIndex); - } + } + + // Windows update + SmallVector updateWindowsIndices; + if (variables.allUpdatesIndicesTensor) { + for (int64_t index : variables.updatedWindowsDims) { + Value indexValue = index::ConstantOp::create(builder, loc, index); + Value updateWindowsIndex = + tensor::ExtractOp::create(builder, loc, updatesIndices, indexValue); + updateWindowsIndices.push_back(updateWindowsIndex); } - - // Get results indices from update indices. - // The results indices are used to store the computed update of one element. - SmallVector resultsIndicesValue = - getResultsIndices(updateScatterIndices, updateWindowsIndices, - variables.inputsShape, variables.insertedWindowsDims, - variables.scatterIndices, variables.indexVectorDim, - variables.scatterDimsToOperandDims, builder, loc); - - // Right now the indices are stored in an IR tensor. - // We need to extract them all to pass them to the tensor.extract op. - SmallVector updatesIndicesValue; - if (updatesIndices) { - if (isa(updatesIndices.getType())) { - RankedTensorType updateType = - cast(updatesIndices.getType()); - - for (int64_t index = 0; index < updateType.getShape()[0]; ++index) { - Value indexValue = - builder.create(loc, index); - Value value = builder.create( - loc, updatesIndices, indexValue); - updatesIndicesValue.push_back(value); - } + } + + // Get results indices from update indices. + // The results indices are used to store the computed update of one element. + SmallVector resultsIndicesValue = getResultsIndices( + updateScatterIndices, updateWindowsIndices, variables.inputsShape, + variables.insertedWindowsDims, variables.scatterIndices, + variables.indexVectorDim, variables.scatterDimsToOperandDims, builder, loc); + + // Right now the indices are stored in an IR tensor. + // We need to extract them all to pass them to the tensor.extract op. + SmallVector updatesIndicesValue; + if (updatesIndices) { + if (isa(updatesIndices.getType())) { + RankedTensorType updateType = + cast(updatesIndices.getType()); + + for (int64_t index = 0; index < updateType.getShape()[0]; ++index) { + Value indexValue = index::ConstantOp::create(builder, loc, index); + Value value = tensor::ExtractOp::create(builder, loc, + updatesIndices, indexValue); + updatesIndicesValue.push_back(value); } } - // Set the arguments of the update function - Value updateValue = builder.create( - loc, variables.updatesValue, updatesIndicesValue); - Value resultValue = - builder.create(loc, results, resultsIndicesValue); - // The update function from JAX always expects tensors. - // Convert f64 -> tensor if necessary - if (!isa(updateValue.getType())) { - Type resultTy = RankedTensorType::get({}, updateValue.getType()); - updateValue = - builder.create(loc, resultTy, updateValue); - } - if (!isa(resultValue.getType())) { - Type resultTy = RankedTensorType::get({}, resultValue.getType()); - resultValue = - builder.create(loc, resultTy, resultValue); - } - - // Set the arguments for the call op - std::vector args{resultValue, updateValue}; - - // Call the function that computes the update - Value updated = - builder.create(loc, updateFnOp, args).getResult(0); - // The update function from JAX always produces tensors. - // Convert tensor -> f64 if necessary - Value updatedExtracted; - if (isa(updated.getType())) { - updatedExtracted = builder.create(loc, updated); - } - else { - updatedExtracted = updated; - } - // Insert the computed update in the results and replace the previous value - Value res = builder.create(loc, updatedExtracted, results, - resultsIndicesValue); - builder.create(loc, res); - }) + } + // Set the arguments of the update function + Value updateValue = tensor::ExtractOp::create( + builder, loc, variables.updatesValue, updatesIndicesValue); + Value resultValue = + tensor::ExtractOp::create(builder, loc, results, resultsIndicesValue); + // The update function from JAX always expects tensors. + // Convert f64 -> tensor if necessary + if (!isa(updateValue.getType())) { + Type resultTy = RankedTensorType::get({}, updateValue.getType()); + updateValue = + tensor::FromElementsOp::create(builder, loc, resultTy, updateValue); + } + if (!isa(resultValue.getType())) { + Type resultTy = RankedTensorType::get({}, resultValue.getType()); + resultValue = + tensor::FromElementsOp::create(builder, loc, resultTy, resultValue); + } + + // Set the arguments for the call op + std::vector args{resultValue, updateValue}; + + // Call the function that computes the update + Value updated = + func::CallOp::create(builder, loc, updateFnOp, args).getResult(0); + // The update function from JAX always produces tensors. + // Convert tensor -> f64 if necessary + Value updatedExtracted; + if (isa(updated.getType())) { + updatedExtracted = tensor::ExtractOp::create(builder, loc, updated); + } + else { + updatedExtracted = updated; + } + // Insert the computed update in the results and replace the previous value + Value res = tensor::InsertOp::create(builder, loc, updatedExtracted, results, + resultsIndicesValue); + scf::YieldOp::create(builder, loc, res); + }) .getResult(0); // Replace the results with the updated one rewriter.replaceOp(op, resultValue); @@ -519,7 +517,7 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern(loc, resultTy, allUpdatesIndices); + tensor::FromElementsOp::create(rewriter, loc, resultTy, allUpdatesIndices); data.size = allUpdatesIndices.size() / updatesShapeVector.size(); } return data; @@ -549,7 +547,7 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePatterngetOperandTypes()); - func::FuncOp updateFn = builder.create(loc, funcName, updateFnType); + func::FuncOp updateFn = func::FuncOp::create(builder, loc, funcName, updateFnType); updateFn.setPrivate(); // Create the block of the function @@ -563,8 +561,8 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern(loc, originalTerminator->getResultTypes(), - originalTerminator->getOperands()); + func::ReturnOp::create(builder, loc, originalTerminator->getResultTypes(), + originalTerminator->getOperands()); rewriter.eraseOp(originalTerminator); return SymbolRefAttr::get(ctx, funcName); } @@ -579,7 +577,7 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern(loc, elem); + auto valueCurrentIndex = index::ConstantOp::create(rewriter, loc, elem); // Add to configuration configurations.push_back(valueCurrentIndex); } @@ -622,15 +620,15 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern(loc, innerIndex); + Value indexConstantOp = index::ConstantOp::create(builder, loc, innerIndex); auto indexScatter = - builder.create(loc, scatterIndices, indexConstantOp); + tensor::ExtractOp::create(builder, loc, scatterIndices, indexConstantOp); auto indexUpdateCasted = - builder.create(loc, indexScatter.getType(), indexUpdate); + index::CastSOp::create(builder, loc, indexScatter.getType(), indexUpdate); Value addValue = - builder.create(loc, indexScatter, indexUpdateCasted); + arith::AddIOp::create(builder, loc, indexScatter, indexUpdateCasted); Value addValueCasted = - builder.create(loc, builder.getIndexType(), addValue); + arith::IndexCastOp::create(builder, loc, builder.getIndexType(), addValue); results.push_back(addValueCasted); } else { @@ -647,21 +645,21 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern(loc, innerIndex); + Value indexConstantOp = index::ConstantOp::create(builder, loc, innerIndex); auto indexScatter = - builder.create(loc, scatterIndices, indexConstantOp); + tensor::ExtractOp::create(builder, loc, scatterIndices, indexConstantOp); fullStartIndex.push_back(indexScatter); } else { TypedAttr indexAttr = builder.getI32IntegerAttr(0); - Value index = builder.create(loc, indexAttr); + Value index = arith::ConstantOp::create(builder, loc, indexAttr); fullStartIndex.push_back(index); } } // Full windows indices SmallVector fullWindowIndex = updateWindowsIndices; for (auto insertedDim : insertedWindowsDims) { - auto c0 = builder.create(loc, 0); + auto c0 = index::ConstantOp::create(builder, loc, 0); fullWindowIndex.insert(fullWindowIndex.begin() + insertedDim, c0); } // Add @@ -670,11 +668,11 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern(loc, indexScatter.getType(), indexUpdate); + index::CastSOp::create(builder, loc, indexScatter.getType(), indexUpdate); Value addValue = - builder.create(loc, indexScatter, indexUpdateCasted); + arith::AddIOp::create(builder, loc, indexScatter, indexUpdateCasted); Value addValueCasted = - builder.create(loc, builder.getIndexType(), addValue); + arith::IndexCastOp::create(builder, loc, builder.getIndexType(), addValue); results.push_back(addValueCasted); } return results; @@ -717,11 +715,10 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern(loc, resultType, scatterIndices, dynOffsets, - dynSizes, dynStrides, offsets, sizes, - strides); + return tensor::ExtractSliceOp::create(builder, loc, resultType, scatterIndices, dynOffsets, + dynSizes, dynStrides, offsets, sizes, strides); } // From a index value i it extracts the update indices from the tensor of values. Value extractUpdateIndices(Value allUpdatesIndicesTensor, Value i, Location loc, @@ -747,11 +744,11 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern(loc, resultType, allUpdatesIndicesTensor, - dynOffsets, dynSizes, dynStrides, offsets, - sizes, strides); + return tensor::ExtractSliceOp::create(builder, loc, resultType, allUpdatesIndicesTensor, + dynOffsets, dynSizes, dynStrides, offsets, sizes, + strides); } }; diff --git a/mlir/lib/hlo-extensions/Transforms/stablehlo_legalize_control_flow.cpp b/mlir/lib/hlo-extensions/Transforms/stablehlo_legalize_control_flow.cpp index 88c7febae6..c039d23b68 100644 --- a/mlir/lib/hlo-extensions/Transforms/stablehlo_legalize_control_flow.cpp +++ b/mlir/lib/hlo-extensions/Transforms/stablehlo_legalize_control_flow.cpp @@ -88,9 +88,9 @@ Value extractTensorValue(OpBuilder &b, Value tensor) if (mlir::cast(tensor.getType()).hasRank() && mlir::cast(tensor.getType()).getRank() != 0) { tensor = - b.create(loc, tensor, SmallVector()); + tensor::CollapseShapeOp::create(b, loc, tensor, SmallVector()); } - return b.create(loc, tensor, ValueRange()); + return tensor::ExtractOp::create(b, loc, tensor, ValueRange()); } struct ScfForBounds { @@ -151,7 +151,7 @@ struct WhileOpPattern : public OpConversionPattern { auto loc = op.getLoc(); if (auto bounds = extractForBounds(op)) { - auto newForOp = rewriter.create( + auto newForOp = scf::ForOp::create(rewriter, loc, extractTensorValue(rewriter, bounds->lb), extractTensorValue(rewriter, bounds->ub), extractTensorValue(rewriter, bounds->step), adaptor.getOperands()); @@ -164,7 +164,7 @@ struct WhileOpPattern : public OpConversionPattern { auto oldIndexArg = newForOp.getRegion().getArgument(1 + bounds->indexArgIndex); rewriter.setInsertionPointToStart(&newForOp.getRegion().front()); auto indexArgTensor = - rewriter.create(loc, oldIndexArg.getType(), indexArg); + tensor::FromElementsOp::create(rewriter, loc, oldIndexArg.getType(), indexArg); oldIndexArg.replaceAllUsesWith(indexArgTensor); rewriter.replaceOp(op, newForOp.getResults()); @@ -172,7 +172,7 @@ struct WhileOpPattern : public OpConversionPattern { } auto newWhileOp = - rewriter.create(loc, op.getResultTypes(), adaptor.getOperands()); + scf::WhileOp::create(rewriter, loc, op.getResultTypes(), adaptor.getOperands()); // Inline while condition. The block is the same, except the boolean result // needs to be extracted and used with an scf.condition. @@ -200,7 +200,7 @@ struct IfOpPattern : public OpConversionPattern { LogicalResult matchAndRewrite(stablehlo::IfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto scfIf = rewriter.create(op.getLoc(), op.getResultTypes(), + auto scfIf = scf::IfOp::create(rewriter, op.getLoc(), op.getResultTypes(), extractTensorValue(rewriter, adaptor.getPred()), /*withElseRegion=*/true); inlineStablehloRegionIntoSCFRegion(rewriter, op.getTrueBranch(), scfIf.getThenRegion()); @@ -228,12 +228,12 @@ struct CaseOpPattern : public OpConversionPattern { auto constAttr = DenseElementsAttr::get( shapedType, {mlir::cast(outerBuilder.getI32IntegerAttr(currentIdx))}); Value currentIdxVal = - outerBuilder.create(loc, idxValue.getType(), constAttr); + stablehlo::ConstantOp::create(outerBuilder, loc, idxValue.getType(), constAttr); - auto scfIf = outerBuilder.create( + auto scfIf = scf::IfOp::create(outerBuilder, loc, op.getResultTypes(), extractTensorValue(outerBuilder, - outerBuilder.create( + stablehlo::CompareOp::create(outerBuilder, loc, idxValue, currentIdxVal, ComparisonDirection::EQ)), /*withElseRegion=*/true); inlineStablehloRegionIntoSCFRegion(outerBuilder, op.getBranches()[currentIdx], @@ -248,7 +248,7 @@ struct CaseOpPattern : public OpConversionPattern { PatternRewriter::InsertionGuard guard(outerBuilder); outerBuilder.setInsertionPointToEnd(&scfIf.getElseRegion().back()); auto innerIf = createNestedCases(nextIdx, op, adaptor, outerBuilder); - outerBuilder.create(op.getLoc(), innerIf.getResults()); + scf::YieldOp::create(outerBuilder, op.getLoc(), innerIf.getResults()); } return scfIf; } diff --git a/mlir/lib/hlo-extensions/Transforms/stablehlo_legalize_sort.cpp b/mlir/lib/hlo-extensions/Transforms/stablehlo_legalize_sort.cpp index 98a93bb21e..42fbbb114d 100644 --- a/mlir/lib/hlo-extensions/Transforms/stablehlo_legalize_sort.cpp +++ b/mlir/lib/hlo-extensions/Transforms/stablehlo_legalize_sort.cpp @@ -90,14 +90,14 @@ Value emitComparison(ImplicitLocOpBuilder &b, SmallVector &lhs, SmallVect for (auto [idx, arg] : llvm::enumerate(comparator.getArguments())) { Value value = idx % 2 == 0 ? lhs[idx / 2] : rhs[idx / 2]; Type type = RankedTensorType::get({}, value.getType()); - mapping.map(arg, b.create(type, value)); + mapping.map(arg, tensor::FromElementsOp::create(b, b.getLoc(), type, value)); } for (Operation &op : block.without_terminator()) b.clone(op, mapping); Value result = mapping.lookup(block.getTerminator()->getOperands().front()); - return b.create(result, ValueRange()); + return tensor::ExtractOp::create(b, b.getLoc(), result, ValueRange()); } // Emits a binary search of `pivots` in `arrayMemrefs` (all rank 1) in the range @@ -109,7 +109,7 @@ Value emitBinarySearch(ImplicitLocOpBuilder &b, Value leftInit, Value rightInit, ArithBuilder arith(b, b.getLoc()); // while ( - auto whileOp = b.create(types, SmallVector{leftInit, rightInit}); + auto whileOp = scf::WhileOp::create(b, b.getLoc(), types, SmallVector{leftInit, rightInit}); OpBuilder::InsertionGuard guard(b); // left < right) { @@ -118,7 +118,7 @@ Value emitBinarySearch(ImplicitLocOpBuilder &b, Value leftInit, Value rightInit, { Value left = before->getArgument(0), right = before->getArgument(1); b.setInsertionPointToEnd(before); - b.create(arith.slt(left, right), before->getArguments()); + scf::ConditionOp::create(b, b.getLoc(), arith.slt(left, right), before->getArguments()); } Block *after = @@ -127,13 +127,14 @@ Value emitBinarySearch(ImplicitLocOpBuilder &b, Value leftInit, Value rightInit, Value left = after->getArgument(0), right = after->getArgument(1); b.setInsertionPointToEnd(after); // int mid = (left + right) >> 1; - Value one = b.create(1); - Value mid = b.create(arith.add(left, right), one); - Value midPlusOne = b.create(mid, one); + Value one = arith::ConstantIndexOp::create(b, b.getLoc(), 1); + Value mid = arith::ShRUIOp::create(b, b.getLoc(), arith.add(left, right), one); + Value midPlusOne = AddIOp::create(b, b.getLoc(), mid, one); auto arraysAtMid = llvm::to_vector(llvm::map_range(arrayMemrefs, [&](Value arrayMemref) -> Value { - return b.create(arrayMemref, mid); + Type type = mlir::cast(arrayMemref.getType()).getElementType(); + return memref::LoadOp::create(b, b.getLoc(), type, arrayMemref, mid); })); Value cond = emitComparison(b, pivots, arraysAtMid, comparator); // if (comparator(pivot, array[mid])) @@ -144,7 +145,7 @@ Value emitBinarySearch(ImplicitLocOpBuilder &b, Value leftInit, Value rightInit, Value newRight = arith.select(cond, mid, right); // } - b.create(ValueRange{newLeft, newRight}); + scf::YieldOp::create(b, b.getLoc(), ValueRange{newLeft, newRight}); } return whileOp.getResult(0); @@ -153,7 +154,7 @@ Value emitBinarySearch(ImplicitLocOpBuilder &b, Value leftInit, Value rightInit, SmallVector loadTensorElements(ImplicitLocOpBuilder &b, ValueRange tensors, Value index) { return llvm::to_vector(llvm::map_range(tensors, [&](Value tensor) -> Value { - return b.create(tensor, index); + return tensor::ExtractOp::create(b, b.getLoc(), tensor, index); })); } @@ -161,7 +162,7 @@ SmallVector loadMemrefElements(ImplicitLocOpBuilder &b, ValueRange memref { return llvm::to_vector(llvm::map_range(memrefs, [&](Value memref) -> Value { Type type = mlir::cast(memref.getType()).getElementType(); - return b.create(type, memref, index); + return memref::LoadOp::create(b, b.getLoc(), type, memref, index); })); } @@ -169,7 +170,7 @@ void storeMemrefElements(ImplicitLocOpBuilder &b, ValueRange memrefs, Value inde ValueRange values) { for (auto [value, memref] : llvm::zip(values, memrefs)) { - b.create(value, memref, index); + memref::StoreOp::create(b, b.getLoc(), value, memref, index); } } @@ -180,15 +181,15 @@ void emitInsertionSort(ImplicitLocOpBuilder &b, Value lo, Value hi, ValueRange i ValueRange outputMemrefs, mlir::Region &comparator) { ArithBuilder arith(b, b.getLoc()); - Value zero = b.create(0); - Value one = b.create(1); + Value zero = arith::ConstantIndexOp::create(b, b.getLoc(), 0); + Value one = arith::ConstantIndexOp::create(b, b.getLoc(), 1); // array[lo] = tensors[lo]; storeMemrefElements(b, outputMemrefs, lo, loadTensorElements(b, inputTensors, lo)); // for (int start = lo + 1; start < hi; ++start) { - auto forOp = b.create(arith.add(lo, one), hi, one); + auto forOp = scf::ForOp::create(b, b.getLoc(), arith.add(lo, one), hi, one); OpBuilder::InsertionGuard outerGuard(b); b.setInsertionPointToStart(forOp.getBody()); Value start = forOp.getInductionVar(); @@ -208,7 +209,7 @@ void emitInsertionSort(ImplicitLocOpBuilder &b, Value lo, Value hi, ValueRange i // (strides != 1). // 2. It implements memcpy semantics, but we need memmove here. // So we go with a loop instead. - auto copyForOp = b.create(zero, n, one); + auto copyForOp = scf::ForOp::create(b, b.getLoc(), zero, n, one); { OpBuilder::InsertionGuard innerGuard(b); b.setInsertionPointToStart(copyForOp.getBody()); @@ -238,7 +239,7 @@ void emitMerge(ImplicitLocOpBuilder &b, Value lo, Value mid, Value hi, ValueRang SmallVector whileArgLocs(whileArgTypes.size(), b.getLoc()); // while( - auto whileOp = b.create(whileArgTypes, whileInitArgs); + auto whileOp = scf::WhileOp::create(b, b.getLoc(), whileArgTypes, whileInitArgs); { OpBuilder::InsertionGuard guard(b); { @@ -250,7 +251,7 @@ void emitMerge(ImplicitLocOpBuilder &b, Value lo, Value mid, Value hi, ValueRang Value inbounds0 = arith.slt(i0, mid); Value inbounds1 = arith.slt(i1, hi); - b.create(arith._and(inbounds0, inbounds1), before->getArguments()); + scf::ConditionOp::create(b, b.getLoc(), arith._and(inbounds0, inbounds1), before->getArguments()); } { @@ -268,16 +269,16 @@ void emitMerge(ImplicitLocOpBuilder &b, Value lo, Value mid, Value hi, ValueRang Value cmp = emitComparison(b, vals1, vals0, comparator); SmallVector pickedVals; for (auto [val0, val1] : llvm::zip(vals0, vals1)) { - pickedVals.push_back(b.create(cmp, val1, val0)); + pickedVals.push_back(SelectOp::create(b, b.getLoc(), cmp, val1, val0)); } storeMemrefElements(b, writeBufs, iOut, pickedVals); - Value one = b.create(1); - Value nexti0 = b.create(cmp, i0, arith.add(i0, one)); - Value nexti1 = b.create(cmp, arith.add(i1, one), i1); + Value one = arith::ConstantIndexOp::create(b, b.getLoc(), 1); + Value nexti0 = SelectOp::create(b, b.getLoc(), cmp, i0, arith.add(i0, one)); + Value nexti1 = SelectOp::create(b, b.getLoc(), cmp, arith.add(i1, one), i1); // ++iOut; - Value nextIOut = b.create(iOut, one); - b.create(ValueRange{nextIOut, nexti0, nexti1}); + Value nextIOut = AddIOp::create(b, b.getLoc(), iOut, one); + scf::YieldOp::create(b, b.getLoc(), ValueRange{nextIOut, nexti0, nexti1}); } } @@ -293,9 +294,9 @@ void emitMerge(ImplicitLocOpBuilder &b, Value lo, Value mid, Value hi, ValueRang Value end = arith.select(leftoverIn0, mid, hi); Value n = arith.sub(end, start); - Value zero = b.create(0); - Value one = b.create(1); - auto forOp = b.create(zero, n, one); + Value zero = arith::ConstantIndexOp::create(b, b.getLoc(), 0); + Value one = arith::ConstantIndexOp::create(b, b.getLoc(), 1); + auto forOp = scf::ForOp::create(b, b.getLoc(), zero, n, one); b.setInsertionPointToStart(forOp.getBody()); Value copyIndex = forOp.getBody()->getArgument(0); @@ -315,21 +316,21 @@ Value emitBottomUpMergeSort(ImplicitLocOpBuilder &b, Value lo, Value hi, int64_t ArithBuilder arith(b, b.getLoc()); Value size = arith.sub(hi, lo); - Value zero = b.create(0); - Value insertionSortSize = b.create(kInsertionSortSize); + Value zero = arith::ConstantIndexOp::create(b, b.getLoc(), 0); + Value insertionSortSize = arith::ConstantIndexOp::create(b, b.getLoc(), kInsertionSortSize); // Run insertion sort on blocks of size kInsertionSortSize. // for (int start = 0; start < size; start += kInsertionSortSize) { { - auto forOp = b.create(zero, size, insertionSortSize); + auto forOp = scf::ForOp::create(b, b.getLoc(), zero, size, insertionSortSize); OpBuilder::InsertionGuard guard(b); b.setInsertionPointToStart(forOp.getBody()); Value start = forOp.getBody()->getArgument(0); - Value end = arith.add(b.create(arith.add(start, insertionSortSize), size), lo); + Value end = arith.add(MinSIOp::create(b, b.getLoc(), arith.add(start, insertionSortSize), size), lo); emitInsertionSort(b, start, end, inputTensors, outputs0, comparator); } - Value initParity = b.create(0, 1); + Value initParity = arith::ConstantIntOp::create(b, b.getLoc(), 0, 1); if (staticSortDimSize >= 0 && staticSortDimSize < kInsertionSortSize) { return initParity; } @@ -354,7 +355,7 @@ Value emitBottomUpMergeSort(ImplicitLocOpBuilder &b, Value lo, Value hi, int64_t SmallVector whileArgLocs(whileArgTypes.size(), b.getLoc()); // while ( - auto whileOp = b.create(whileArgTypes, whileInitArgs); + auto whileOp = scf::WhileOp::create(b, b.getLoc(), whileArgTypes, whileInitArgs); OpBuilder::InsertionGuard guard(b); // currentSize < totalSize) @@ -362,7 +363,7 @@ Value emitBottomUpMergeSort(ImplicitLocOpBuilder &b, Value lo, Value hi, int64_t Block *before = b.createBlock(&whileOp.getBefore(), {}, whileArgTypes, whileArgLocs); Value currentSize = before->getArgument(0); b.setInsertionPointToEnd(before); - b.create(arith.slt(currentSize, size), before->getArguments()); + scf::ConditionOp::create(b, b.getLoc(), arith.slt(currentSize, size), before->getArguments()); } size_t numArgs = inputTensors.size(); @@ -379,25 +380,25 @@ Value emitBottomUpMergeSort(ImplicitLocOpBuilder &b, Value lo, Value hi, int64_t // for (int start = 0; start < size; start += 2*currentSize) { { - auto forOp = b.create(zero, size, twoCurrentSize); + auto forOp = scf::ForOp::create(b, b.getLoc(), zero, size, twoCurrentSize); b.setInsertionPointToStart(forOp.getBody()); Value start = forOp.getBody()->getArgument(0); - Value mid = b.create(size, arith.add(start, currentSize)); - Value end = b.create(size, arith.add(start, twoCurrentSize)); + Value mid = MinSIOp::create(b, b.getLoc(), size, arith.add(start, currentSize)); + Value end = MinSIOp::create(b, b.getLoc(), size, arith.add(start, twoCurrentSize)); emitMerge(b, start, mid, end, readBufs, writeBufs, comparator); b.setInsertionPointAfter(forOp); } // } // parity = !parity; - Value one = b.create(1, 1); + Value one = arith::ConstantIntOp::create(b, b.getLoc(), 1, 1); Value notParity = arith.sub(one, parity); // currentSize *= 2; SmallVector nextWhileArgs{twoCurrentSize, notParity}; llvm::copy(writeBufs, std::back_inserter(nextWhileArgs)); llvm::copy(readBufs, std::back_inserter(nextWhileArgs)); - b.create(nextWhileArgs); + scf::YieldOp::create(b, b.getLoc(), nextWhileArgs); } // } @@ -425,7 +426,7 @@ struct Slicer { RankedTensorType toSlicedType(RankedTensorType sourceType) { return tensor::ExtractSliceOp::inferCanonicalRankReducedResultType( - /*resultRank=*/1, sourceType, offsets, sizes, strides); + /*resultRank=*/1, sourceType, sizes); } MemRefType toSlicedType(MemRefType sourceType) @@ -437,7 +438,7 @@ struct Slicer { template Value slice(ImplicitLocOpBuilder &b, Value input) { Ty ty = mlir::cast(input.getType()); - return b.create(toSlicedType(ty), input, offsets, sizes, strides).getResult(); + return Op::create(b, b.getLoc(), toSlicedType(ty), input, offsets, sizes, strides).getResult(); } Value apply(ImplicitLocOpBuilder &b, Value input) @@ -486,16 +487,16 @@ struct SortOpPattern : public OpRewritePattern { auto firstOperandType = mlir::cast(firstOperand.getType()); int64_t inputRank = firstOperandType.getRank(); - Value sortDimSize = b.createOrFold( - firstOperand, b.create(op.getDimension())); + Value sortDimSize = tensor::DimOp::create(b, b.getLoc(), + firstOperand, arith::ConstantIndexOp::create(b, b.getLoc(), op.getDimension())); int64_t staticSortDimSize = firstOperandType.getDimSize(op.getDimension()); SmallVector dynamicDims; for (int i = 0; i < inputRank; ++i) { if (!firstOperandType.isDynamicDim(i)) continue; - Value index = b.create(i); - Value dimOp = b.create(firstOperand, index); + Value index = arith::ConstantIndexOp::create(b, b.getLoc(), i); + Value dimOp = tensor::DimOp::create(b, b.getLoc(), firstOperand, index); dynamicDims.push_back(dimOp); } @@ -506,25 +507,25 @@ struct SortOpPattern : public OpRewritePattern { auto inputType = mlir::cast(input.getType()); auto memRefType = MemRefType::get(inputType.getShape(), inputType.getElementType()); - outputMemrefs.push_back(b.create(memRefType, dynamicDims)); - scratchMemrefs.push_back(b.create(memRefType, dynamicDims)); + outputMemrefs.push_back(memref::AllocOp::create(b, b.getLoc(), memRefType, dynamicDims)); + scratchMemrefs.push_back(memref::AllocOp::create(b, b.getLoc(), memRefType, dynamicDims)); } b.setInsertionPoint(op); - Value zero = b.create(0); - Value one = b.create(1); + Value zero = arith::ConstantIndexOp::create(b, b.getLoc(), 0); + Value one = arith::ConstantIndexOp::create(b, b.getLoc(), 1); - Value forInitArg = b.create(0, 1); + Value forInitArg = arith::ConstantIntOp::create(b, b.getLoc(), 0, 1); SmallVector forOps; SmallVector ivs; forOps.reserve(inputRank - 1); ivs.reserve(inputRank - 1); for (int64_t i = 0; i < inputRank; ++i) { if (i != static_cast(op.getDimension())) { - Value dim = b.create(i); - Value ub = b.create(firstOperand, dim); + Value dim = arith::ConstantIndexOp::create(b, b.getLoc(), i); + Value ub = tensor::DimOp::create(b, b.getLoc(), firstOperand, dim); scf::ForOp &forOp = forOps.emplace_back( - b.create(zero, ub, one, ValueRange{forInitArg})); + scf::ForOp::create(b, b.getLoc(), zero, ub, one, ValueRange{forInitArg})); ivs.push_back(forOp.getInductionVar()); b.setInsertionPointToStart(&forOp.getRegion().front()); } @@ -541,15 +542,15 @@ struct SortOpPattern : public OpRewritePattern { // Pass the parity bit through the for loops. for (auto i = static_cast(forOps.size() - 1); i >= 0; --i) { b.setInsertionPointToEnd(&forOps[i].getRegion().front()); - b.create(ValueRange{parity}); + scf::YieldOp::create(b, b.getLoc(), ValueRange{parity}); parity = forOps[i]->getResult(0); } b.setInsertionPoint(op); SmallVector outputTensors; for (auto [out0, out1] : llvm::zip(outputMemrefs, scratchMemrefs)) { - Value s = b.create(parity, out1, out0).getResult(); - outputTensors.push_back(b.create( + Value s = SelectOp::create(b, b.getLoc(), parity, out1, out0).getResult(); + outputTensors.push_back(bufferization::ToTensorOp::create(b, b.getLoc(), memref::getTensorTypeFromMemRefType(s.getType()), s, /*restrict=*/true)); } diff --git a/mlir/lib/hlo-extensions/Transforms/stablehlo_legalize_to_standard.cpp b/mlir/lib/hlo-extensions/Transforms/stablehlo_legalize_to_standard.cpp index 469bae1654..dbb186a208 100644 --- a/mlir/lib/hlo-extensions/Transforms/stablehlo_legalize_to_standard.cpp +++ b/mlir/lib/hlo-extensions/Transforms/stablehlo_legalize_to_standard.cpp @@ -196,12 +196,12 @@ class ConvertIotaOp : public OpRewritePattern { auto intShapeType = RankedTensorType::get( outputType.getShape(), IntegerType::get(rewriter.getContext(), bitwidth)); auto loc = op.getLoc(); - auto integerConst = rewriter.create( + auto integerConst = mlir::arith::ConstantOp::create(rewriter, loc, DenseIntElementsAttr::get(intShapeType, values)); auto intOrFloatShapeTy = RankedTensorType::get(outputType.getShape(), intOrFloatTy); - auto iotaConst = rewriter.create(loc, intOrFloatShapeTy, integerConst); + auto iotaConst = ConvertOp::create(rewriter, loc, intOrFloatShapeTy, integerConst); // For int/float types we are done, replace op and return. if (!complexTy) { @@ -211,9 +211,9 @@ class ConvertIotaOp : public OpRewritePattern { // For complex types, generate a constant tensor of zeroes for the imaginary // part and use iota_const for real part. - auto zeroes = rewriter.create( + auto zeroes = mlir::arith::ConstantOp::create(rewriter, loc, DenseIntElementsAttr::get(intShapeType, APInt(bitwidth, 0))); - auto imagZeroes = rewriter.create(loc, intOrFloatShapeTy, zeroes); + auto imagZeroes = ConvertOp::create(rewriter, loc, intOrFloatShapeTy, zeroes); rewriter.replaceOpWithNewOp(op, iotaConst, imagZeroes); return success(); } diff --git a/mlir/llvm-project b/mlir/llvm-project index 113f01aa82..8f264586d7 160000 --- a/mlir/llvm-project +++ b/mlir/llvm-project @@ -1 +1 @@ -Subproject commit 113f01aa82d055410f22a9d03b3468fa68600589 +Subproject commit 8f264586d7521b0e305ca7bb78825aa3382ffef7 diff --git a/mlir/patches/llvm-python-bindinggen-annotations.patch b/mlir/patches/llvm-python-bindinggen-annotations.patch new file mode 100644 index 0000000000..8f568f808d --- /dev/null +++ b/mlir/patches/llvm-python-bindinggen-annotations.patch @@ -0,0 +1,13 @@ +diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +index 2c33f4efac3..78b9989b092 100644 +--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp ++++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +@@ -34,6 +34,8 @@ using llvm::RecordKeeper; + constexpr const char *fileHeader = R"Py( + # Autogenerated by mlir-tblgen; don't manually edit. + ++from __future__ import annotations ++ + from ._ods_common import _cext as _ods_cext + from ._ods_common import ( + equally_sized_accessor as _ods_equally_sized_accessor, diff --git a/mlir/stablehlo b/mlir/stablehlo index 0a4440a5c8..d496423cdb 160000 --- a/mlir/stablehlo +++ b/mlir/stablehlo @@ -1 +1 @@ -Subproject commit 0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d +Subproject commit d496423cdb7f7d5272f14d517681202a0b9cbe41