|
19 | 19 | #include "mlir/IR/OpDefinition.h" |
20 | 20 | #include "mlir/IR/PatternMatch.h" |
21 | 21 | #include "mlir/IR/Value.h" |
| 22 | +#include "mlir/Interfaces/FunctionImplementation.h" |
22 | 23 | #include "llvm/ADT/APInt.h" |
23 | 24 | #include "llvm/ADT/SmallVector.h" |
24 | 25 | #include "llvm/Support/Casting.h" |
@@ -626,3 +627,121 @@ void GambleOp::emitCNFWithoutInversion( |
626 | 627 | // out = allSet | ~orSet |
627 | 628 | circt::addOrClauses(outVar, {allSet, -orSet}, addClause); |
628 | 629 | } |
| 630 | + |
| 631 | +//===----------------------------------------------------------------------===// |
| 632 | +// CutRewritePatternOp |
| 633 | +//===----------------------------------------------------------------------===// |
| 634 | + |
| 635 | +ParseResult CutRewritePatternOp::parse(OpAsmParser &parser, |
| 636 | + OperationState &result) { |
| 637 | + |
| 638 | + SmallVector<OpAsmParser::Argument> entryArgs; |
| 639 | + SmallVector<Type> inputTypes; |
| 640 | + SmallVector<Type> resultTypes; |
| 641 | + SmallVector<DictionaryAttr> resultAttrs; |
| 642 | + bool isVariadic = false; |
| 643 | + |
| 644 | + if (function_interface_impl::parseFunctionSignatureWithArguments( |
| 645 | + parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, |
| 646 | + resultAttrs)) |
| 647 | + return failure(); |
| 648 | + |
| 649 | + inputTypes.reserve(entryArgs.size()); |
| 650 | + for (auto &arg : entryArgs) |
| 651 | + inputTypes.push_back(arg.type); |
| 652 | + |
| 653 | + auto functionType = |
| 654 | + parser.getBuilder().getFunctionType(inputTypes, resultTypes); |
| 655 | + |
| 656 | + NamedAttrList parsedAttributes; |
| 657 | + auto attrDictLoc = parser.getCurrentLocation(); |
| 658 | + if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes)) |
| 659 | + return failure(); |
| 660 | + |
| 661 | + if (parsedAttributes.get(getFunctionTypeAttrName(result.name))) |
| 662 | + return parser.emitError(attrDictLoc, "'function_type' is an inferred " |
| 663 | + "attribute and should not be " |
| 664 | + "specified in the explicit attribute " |
| 665 | + "dictionary"); |
| 666 | + |
| 667 | + result.attributes.append(parsedAttributes); |
| 668 | + result.addAttribute(getFunctionTypeAttrName(result.name), |
| 669 | + TypeAttr::get(functionType)); |
| 670 | + |
| 671 | + auto *body = result.addRegion(); |
| 672 | + if (parser.parseRegion(*body, entryArgs, |
| 673 | + /*enableNameShadowing=*/false)) |
| 674 | + return failure(); |
| 675 | + |
| 676 | + return success(); |
| 677 | +} |
| 678 | + |
| 679 | +void CutRewritePatternOp::print(OpAsmPrinter &p) { |
| 680 | + |
| 681 | + ArrayRef<Type> resultTypes = getFunctionType().getResults(); |
| 682 | + |
| 683 | + p << " ("; |
| 684 | + llvm::interleaveComma(getBody().front().getArguments(), p, [&](auto arg) { |
| 685 | + p.printRegionArgument(arg); |
| 686 | + p << ": "; |
| 687 | + p.printType(arg.getType()); |
| 688 | + }); |
| 689 | + p << ") -> "; |
| 690 | + |
| 691 | + if (resultTypes.size() == 1) { |
| 692 | + p.printType(resultTypes.front()); |
| 693 | + } else { |
| 694 | + p << '('; |
| 695 | + llvm::interleaveComma(resultTypes, p, |
| 696 | + [&](auto type) { p.printType(type); }); |
| 697 | + p << ')'; |
| 698 | + } |
| 699 | + |
| 700 | + p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), |
| 701 | + {getFunctionTypeAttrName()}); |
| 702 | + |
| 703 | + p << ' '; |
| 704 | + p.printRegion(getBody(), /*printEntryBlockArgs=*/false, |
| 705 | + /*printBlockTerminators=*/true); |
| 706 | +} |
| 707 | + |
| 708 | +LogicalResult CutRewritePatternOp::verify() { |
| 709 | + auto functionType = getFunctionType(); |
| 710 | + |
| 711 | + if (functionType.getNumResults() != 1) |
| 712 | + return emitError() << "requires exactly one result"; |
| 713 | + |
| 714 | + for (auto type : functionType.getInputs()) |
| 715 | + if (!type.isInteger(1)) |
| 716 | + return emitError() << "expected i1 input type but got " << type; |
| 717 | + |
| 718 | + for (auto type : functionType.getResults()) |
| 719 | + if (!type.isInteger(1)) |
| 720 | + return emitError() << "expected i1 result type but got " << type; |
| 721 | + |
| 722 | + // Check outputs. |
| 723 | + auto *terminator = this->getBody().front().getTerminator(); |
| 724 | + auto yield = dyn_cast<YieldOp>(terminator); |
| 725 | + if (!yield) |
| 726 | + return emitError() << "body must terminate with synth.yield"; |
| 727 | + |
| 728 | + if (terminator->getOperands().size() != functionType.getNumResults()) |
| 729 | + return emitError() << "result type doesn't match with the terminator"; |
| 730 | + |
| 731 | + for (auto [lhs, rhs] : llvm::zip(terminator->getOperands().getTypes(), |
| 732 | + functionType.getResults())) |
| 733 | + if (rhs != lhs) |
| 734 | + return emitError() << rhs << " is expected but got " << lhs; |
| 735 | + |
| 736 | + auto blockArgs = this->getBody().front().getArguments(); |
| 737 | + if (blockArgs.size() != functionType.getNumInputs()) |
| 738 | + return emitError() << "operand type doesn't match with the block arg"; |
| 739 | + |
| 740 | + for (auto [blockArg, inputType] : |
| 741 | + llvm::zip(blockArgs, functionType.getInputs())) |
| 742 | + if (blockArg.getType() != inputType) |
| 743 | + return emitError() << inputType << " is expected but got " |
| 744 | + << blockArg.getType(); |
| 745 | + |
| 746 | + return success(); |
| 747 | +} |
0 commit comments