Skip to content

Commit 833947f

Browse files
committed
[Synth] Add an operation for declarative Cut rewrite pattern
1 parent 7aa41cc commit 833947f

4 files changed

Lines changed: 192 additions & 0 deletions

File tree

include/circt/Dialect/Synth/SynthOps.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,27 @@ def GambleOp : SymmetricThreeInputOp<"gamble", "evaluateGambleLogic"> {
305305
}];
306306
}
307307

308+
def CutRewritePatternOp : SynthOp<"cut_rewrite_pattern", [
309+
IsolatedFromAbove,
310+
SingleBlockImplicitTerminator<"YieldOp">
311+
]> {
312+
let summary = "Declarative cut rewrite pattern";
313+
314+
let arguments = (ins
315+
TypeAttrOf<FunctionType>:$function_type);
316+
317+
let regions = (region SizedRegion<1>:$body);
318+
let hasVerifier = 1;
319+
let hasCustomAssemblyFormat = 1;
320+
}
321+
322+
def YieldOp : SynthOp<"yield",
323+
[Pure, Terminator]> {
324+
let summary = "Yield synth operations";
325+
326+
let arguments = (ins Variadic<AnyType>:$operands);
327+
let assemblyFormat = "$operands attr-dict `:` type($operands)";
328+
}
308329

309330

310331
#endif // CIRCT_DIALECT_SYNTH_SYNTHOPS_TD

lib/Dialect/Synth/SynthOps.cpp

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/IR/OpDefinition.h"
2020
#include "mlir/IR/PatternMatch.h"
2121
#include "mlir/IR/Value.h"
22+
#include "mlir/Interfaces/FunctionImplementation.h"
2223
#include "llvm/ADT/APInt.h"
2324
#include "llvm/ADT/SmallVector.h"
2425
#include "llvm/Support/Casting.h"
@@ -626,3 +627,121 @@ void GambleOp::emitCNFWithoutInversion(
626627
// out = allSet | ~orSet
627628
circt::addOrClauses(outVar, {allSet, -orSet}, addClause);
628629
}
630+
631+
//===----------------------------------------------------------------------===//
632+
// CutRewritePatternOp
633+
//===----------------------------------------------------------------------===//
634+
635+
ParseResult CutRewritePatternOp::parse(OpAsmParser &parser,
636+
OperationState &result) {
637+
638+
SmallVector<OpAsmParser::Argument> entryArgs;
639+
SmallVector<Type> inputTypes;
640+
SmallVector<Type> resultTypes;
641+
SmallVector<DictionaryAttr> resultAttrs;
642+
bool isVariadic = false;
643+
644+
if (function_interface_impl::parseFunctionSignatureWithArguments(
645+
parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
646+
resultAttrs))
647+
return failure();
648+
649+
inputTypes.reserve(entryArgs.size());
650+
for (auto &arg : entryArgs)
651+
inputTypes.push_back(arg.type);
652+
653+
auto functionType =
654+
parser.getBuilder().getFunctionType(inputTypes, resultTypes);
655+
656+
NamedAttrList parsedAttributes;
657+
auto attrDictLoc = parser.getCurrentLocation();
658+
if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes))
659+
return failure();
660+
661+
if (parsedAttributes.get(getFunctionTypeAttrName(result.name)))
662+
return parser.emitError(attrDictLoc, "'function_type' is an inferred "
663+
"attribute and should not be "
664+
"specified in the explicit attribute "
665+
"dictionary");
666+
667+
result.attributes.append(parsedAttributes);
668+
result.addAttribute(getFunctionTypeAttrName(result.name),
669+
TypeAttr::get(functionType));
670+
671+
auto *body = result.addRegion();
672+
if (parser.parseRegion(*body, entryArgs,
673+
/*enableNameShadowing=*/false))
674+
return failure();
675+
676+
return success();
677+
}
678+
679+
void CutRewritePatternOp::print(OpAsmPrinter &p) {
680+
681+
ArrayRef<Type> resultTypes = getFunctionType().getResults();
682+
683+
p << " (";
684+
llvm::interleaveComma(getBody().front().getArguments(), p, [&](auto arg) {
685+
p.printRegionArgument(arg);
686+
p << ": ";
687+
p.printType(arg.getType());
688+
});
689+
p << ") -> ";
690+
691+
if (resultTypes.size() == 1) {
692+
p.printType(resultTypes.front());
693+
} else {
694+
p << '(';
695+
llvm::interleaveComma(resultTypes, p,
696+
[&](auto type) { p.printType(type); });
697+
p << ')';
698+
}
699+
700+
p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
701+
{getFunctionTypeAttrName()});
702+
703+
p << ' ';
704+
p.printRegion(getBody(), /*printEntryBlockArgs=*/false,
705+
/*printBlockTerminators=*/true);
706+
}
707+
708+
LogicalResult CutRewritePatternOp::verify() {
709+
auto functionType = getFunctionType();
710+
711+
if (functionType.getNumResults() != 1)
712+
return emitError() << "requires exactly one result";
713+
714+
for (auto type : functionType.getInputs())
715+
if (!type.isInteger(1))
716+
return emitError() << "expected i1 input type but got " << type;
717+
718+
for (auto type : functionType.getResults())
719+
if (!type.isInteger(1))
720+
return emitError() << "expected i1 result type but got " << type;
721+
722+
// Check outputs.
723+
auto *terminator = this->getBody().front().getTerminator();
724+
auto yield = dyn_cast<YieldOp>(terminator);
725+
if (!yield)
726+
return emitError() << "body must terminate with synth.yield";
727+
728+
if (terminator->getOperands().size() != functionType.getNumResults())
729+
return emitError() << "result type doesn't match with the terminator";
730+
731+
for (auto [lhs, rhs] : llvm::zip(terminator->getOperands().getTypes(),
732+
functionType.getResults()))
733+
if (rhs != lhs)
734+
return emitError() << rhs << " is expected but got " << lhs;
735+
736+
auto blockArgs = this->getBody().front().getArguments();
737+
if (blockArgs.size() != functionType.getNumInputs())
738+
return emitError() << "operand type doesn't match with the block arg";
739+
740+
for (auto [blockArg, inputType] :
741+
llvm::zip(blockArgs, functionType.getInputs()))
742+
if (blockArg.getType() != inputType)
743+
return emitError() << inputType << " is expected but got "
744+
<< blockArg.getType();
745+
746+
return success();
747+
}

test/Dialect/Synth/errors.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,41 @@ hw.module @test(out result : i1) {
55
%0 = synth.choice : i1
66
hw.output %0 : i1
77
}
8+
9+
// -----
10+
11+
// expected-error @below {{expected i1 input type but got 'i2'}}
12+
synth.cut_rewrite_pattern (%a: i2) -> i1 {
13+
%0 = comb.extract %a from 0 : (i2) -> i1
14+
synth.yield %0 : i1
15+
}
16+
17+
// -----
18+
19+
// expected-error @below {{expected i1 result type but got 'i2'}}
20+
synth.cut_rewrite_pattern (%a: i1) -> i2 {
21+
%0 = hw.constant 0 : i2
22+
synth.yield %0 : i2
23+
}
24+
25+
// -----
26+
27+
// expected-error @below {{requires exactly one result}}
28+
synth.cut_rewrite_pattern (%a: i1) -> (i1, i1) {
29+
synth.yield %a, %a : i1, i1
30+
}
31+
32+
// -----
33+
34+
// expected-error @below {{result type doesn't match with the terminator}}
35+
synth.cut_rewrite_pattern (%a: i1) -> i1 {
36+
"synth.yield"() : () -> ()
37+
}
38+
39+
// -----
40+
41+
// expected-error @below {{'i1' is expected but got 'i2'}}
42+
synth.cut_rewrite_pattern (%a: i1) -> i1 {
43+
%0 = hw.constant 0 : i2
44+
synth.yield %0 : i2
45+
}

test/Dialect/Synth/round-trip.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,17 @@ hw.module @mux_inv(in %c: i4, in %a: i4, in %b: i4) {
5454
hw.module @gamble(in %x: i1, in %y: i1, in %z: i1) {
5555
%0 = synth.gamble %x, not %y, %z : i1
5656
}
57+
58+
// CHECK-LABEL: synth.cut_rewrite_pattern
59+
// CHECK-SAME: (%{{.*}}: i1, %{{.*}}: i1, %{{.*}}: i1) -> i1
60+
synth.cut_rewrite_pattern (%a: i1, %b: i1, %c: i1) -> i1 {
61+
%0 = synth.aig.and_inv %a, not %b, %c : i1
62+
synth.yield %0 : i1
63+
}
64+
65+
// CHECK-LABEL: synth.cut_rewrite_pattern
66+
// CHECK-SAME: (%{{.*}}: i1, %{{.*}}: i1) -> i1 attributes {synth.mapping_cost = #synth.mapping_cost<area =
67+
synth.cut_rewrite_pattern (%a: i1, %b: i1) -> i1 attributes {synth.mapping_cost = #synth.mapping_cost<area = 1.0 : f64, arcs = [], input_caps = {}>} {
68+
%0 = synth.aig.and_inv %a, %b : i1
69+
synth.yield %0 : i1
70+
}

0 commit comments

Comments
 (0)