Skip to content

Commit 8e7d483

Browse files
committed
[Synth] Add an operation for declarative Cut rewrite pattern
1 parent 7aa41cc commit 8e7d483

9 files changed

Lines changed: 229 additions & 12 deletions

File tree

include/circt/Dialect/Synth/SynthAttributes.td

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,14 @@ def MappingCostAttr : AttrDef<Synth_Dialect, "MappingCost"> {
7676
let summary = "Simplified timing and area cost for tech mapping";
7777
let parameters = (ins
7878
"::mlir::FloatAttr":$area,
79-
"::mlir::ArrayAttr":$arcs,
80-
"::mlir::DictionaryAttr":$inputCaps
79+
OptionalParameter<"::mlir::ArrayAttr">:$arcs,
80+
OptionalParameter<"::mlir::DictionaryAttr">:$inputCaps
8181
);
82+
let genVerifyDecl = 1;
8283
let assemblyFormat =
83-
"`<` `area` `=` $area `,` `arcs` `=` $arcs `,` "
84-
"`input_caps` `=` $inputCaps `>`";
84+
"`<` `area` `=` $area "
85+
"(`,` `arcs` `=` $arcs^)? "
86+
"(`,` `input_caps` `=` $inputCaps^)? `>`";
8587
}
8688

8789
#endif // CIRCT_DIALECT_SYNTH_SYNTHATTRIBUTES_TD

include/circt/Dialect/Synth/SynthOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef CIRCT_DIALECT_SYNTH_SYNTHOPS_H
1414
#define CIRCT_DIALECT_SYNTH_SYNTHOPS_H
1515

16+
#include "circt/Dialect/Synth/SynthAttributes.h"
1617
#include "circt/Dialect/Synth/SynthDialect.h"
1718
#include "circt/Dialect/Synth/SynthOpInterfaces.h"
1819
#include "circt/Support/LLVM.h"

include/circt/Dialect/Synth/SynthOps.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,29 @@ def GambleOp : SymmetricThreeInputOp<"gamble", "evaluateGambleLogic"> {
305305
}];
306306
}
307307

308+
def CutRewritePatternOp : SynthOp<"cut_rewrite_pattern", [
309+
IsolatedFromAbove,
310+
SingleBlockImplicitTerminator<"YieldOp">
311+
]> {
312+
let summary = "Declarative cut rewrite pattern";
313+
314+
let arguments = (ins
315+
TypeAttrOf<FunctionType>:$function_type,
316+
MappingCostAttr:$cost
317+
);
318+
319+
let regions = (region SizedRegion<1>:$body);
320+
let hasVerifier = 1;
321+
let hasCustomAssemblyFormat = 1;
322+
}
323+
324+
def YieldOp : SynthOp<"yield",
325+
[Pure, Terminator]> {
326+
let summary = "Yield synth operations";
327+
328+
let arguments = (ins Variadic<AnyType>:$operands);
329+
let assemblyFormat = "$operands attr-dict `:` type($operands)";
330+
}
308331

309332

310333
#endif // CIRCT_DIALECT_SYNTH_SYNTHOPS_TD

lib/Dialect/Synth/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
##===----------------------------------------------------------------------===//
66

77
add_circt_dialect_library(CIRCTSynth
8+
SynthAttributes.cpp
89
SynthDialect.cpp
910
SynthOpInterfaces.cpp
1011
SynthOps.cpp
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "circt/Dialect/Synth/SynthAttributes.h"
10+
#include "mlir/IR/Builders.h"
11+
#include "mlir/IR/DialectImplementation.h"
12+
#include "llvm/ADT/TypeSwitch.h"
13+
14+
using namespace circt;
15+
using namespace circt::synth;
16+
using namespace mlir;
17+
18+
//===----------------------------------------------------------------------===//
19+
// MappingCostAttr
20+
//===----------------------------------------------------------------------===//
21+
22+
LogicalResult
23+
MappingCostAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
24+
FloatAttr area, ArrayAttr arcs,
25+
DictionaryAttr inputCaps) {
26+
if (arcs)
27+
for (auto attr : arcs)
28+
if (!isa<LinearTimingArcAttr>(attr))
29+
return emitError()
30+
<< "expected arcs to contain synth.linear_timing_arc";
31+
32+
if (inputCaps)
33+
for (auto entry : inputCaps)
34+
if (!isa<FloatAttr>(entry.getValue()))
35+
return emitError()
36+
<< "expected input_caps values to be floating-point attributes";
37+
38+
return success();
39+
}

lib/Dialect/Synth/SynthOps.cpp

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
#include "mlir/IR/OpDefinition.h"
2020
#include "mlir/IR/PatternMatch.h"
2121
#include "mlir/IR/Value.h"
22+
#include "mlir/Interfaces/CallInterfaces.h"
23+
#include "mlir/Interfaces/FunctionImplementation.h"
2224
#include "llvm/ADT/APInt.h"
25+
#include "llvm/ADT/STLExtras.h"
2326
#include "llvm/ADT/SmallVector.h"
2427
#include "llvm/Support/Casting.h"
2528
#include "llvm/Support/LogicalResult.h"
@@ -626,3 +629,98 @@ void GambleOp::emitCNFWithoutInversion(
626629
// out = allSet | ~orSet
627630
circt::addOrClauses(outVar, {allSet, -orSet}, addClause);
628631
}
632+
633+
//===----------------------------------------------------------------------===//
634+
// CutRewritePatternOp
635+
//===----------------------------------------------------------------------===//
636+
637+
ParseResult CutRewritePatternOp::parse(OpAsmParser &parser,
638+
OperationState &result) {
639+
640+
SmallVector<OpAsmParser::Argument> entryArgs;
641+
SmallVector<Type> resultTypes;
642+
SmallVector<DictionaryAttr> resultAttrs;
643+
bool isVariadic = false;
644+
645+
if (function_interface_impl::parseFunctionSignatureWithArguments(
646+
parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
647+
resultAttrs))
648+
return failure();
649+
650+
auto inputTypes = llvm::map_to_vector(
651+
entryArgs, [](auto &arg) -> Type { return arg.type; });
652+
auto functionType =
653+
parser.getBuilder().getFunctionType(inputTypes, resultTypes);
654+
655+
result.addAttribute(getFunctionTypeAttrName(result.name),
656+
TypeAttr::get(functionType));
657+
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
658+
return failure();
659+
660+
return parser.parseRegion(*result.addRegion(), entryArgs,
661+
/*enableNameShadowing=*/false);
662+
}
663+
664+
void CutRewritePatternOp::print(OpAsmPrinter &p) {
665+
auto functionType = getFunctionType();
666+
call_interface_impl::printFunctionSignature(
667+
p, functionType.getInputs(), /*argAttrs=*/{}, /*isVariadic=*/false,
668+
functionType.getResults(), /*resultAttrs=*/{}, &getBody(),
669+
/*printEmptyResult=*/false);
670+
671+
p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
672+
{getFunctionTypeAttrName()});
673+
674+
p << ' ';
675+
p.printRegion(getBody(), /*printEntryBlockArgs=*/false,
676+
/*printBlockTerminators=*/true);
677+
}
678+
679+
LogicalResult CutRewritePatternOp::verify() {
680+
auto functionType = getFunctionType();
681+
682+
if (functionType.getNumResults() != 1)
683+
return emitError() << "requires exactly one result";
684+
685+
for (auto type : functionType.getInputs())
686+
if (!type.isInteger(1))
687+
return emitError() << "argument type must be i1, but got " << type;
688+
689+
for (auto type : functionType.getResults())
690+
if (!type.isInteger(1))
691+
return emitError() << "result type must be i1, but got " << type;
692+
693+
// Check outputs.
694+
auto *terminator = this->getBody().front().getTerminator();
695+
if (terminator->getOperands().size() != functionType.getNumResults())
696+
return emitError() << "result type doesn't match with the terminator";
697+
698+
for (auto [lhs, rhs] : llvm::zip(terminator->getOperands().getTypes(),
699+
functionType.getResults()))
700+
if (rhs != lhs)
701+
return emitError() << rhs << " is expected but got " << lhs;
702+
703+
auto blockArgs = this->getBody().front().getArguments();
704+
if (blockArgs.size() != functionType.getNumInputs())
705+
return emitError() << "operand type doesn't match with the block arg";
706+
707+
for (auto [blockArg, inputType] :
708+
llvm::zip(blockArgs, functionType.getInputs()))
709+
if (blockArg.getType() != inputType)
710+
return emitError() << inputType << " is expected but got "
711+
<< blockArg.getType();
712+
713+
auto cost = getCost();
714+
if (auto arcs = cost.getArcs())
715+
if (!arcs.empty())
716+
return emitError()
717+
<< "mapping cost arcs for cut rewrite patterns must not use "
718+
"input/output names";
719+
720+
if (auto inputCaps = cost.getInputCaps())
721+
if (inputCaps.size() != functionType.getNumInputs())
722+
return emitError()
723+
<< "input_caps size must match the number of arguments";
724+
725+
return success();
726+
}

lib/Dialect/Synth/Transforms/TechMapper.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -222,14 +222,15 @@ struct TechMapperPass : public impl::TechMapperBase<TechMapperPass> {
222222
}
223223

224224
llvm::DenseMap<StringAttr, DelayType> delayByInput;
225-
for (auto attr : mappingCost.getArcs()) {
226-
auto arc = cast<LinearTimingArcAttr>(attr);
227-
if (!arc) {
228-
hwModule.emitError(
229-
"expected synth.linear_timing_arc in synth.mapping_cost arcs");
230-
signalPassFailure();
231-
return;
232-
}
225+
auto arcs = mappingCost.getArcs();
226+
if (!arcs) {
227+
hwModule.emitError(
228+
"expected synth.linear_timing_arc in synth.mapping_cost arcs");
229+
signalPassFailure();
230+
return;
231+
}
232+
for (auto attr : arcs) {
233+
auto arc = dyn_cast<LinearTimingArcAttr>(attr);
233234

234235
if (arc.getPin() != outputName) {
235236
hwModule.emitError("mapping cost arc output '")

test/Dialect/Synth/errors.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,41 @@ hw.module @test(out result : i1) {
55
%0 = synth.choice : i1
66
hw.output %0 : i1
77
}
8+
9+
// -----
10+
11+
// expected-error @below {{argument type must be i1, but got 'i2'}}
12+
synth.cut_rewrite_pattern (%a: i2) -> i1 attributes {cost = #synth.mapping_cost<area = 1.0 : f64>} {
13+
%0 = comb.extract %a from 0 : (i2) -> i1
14+
synth.yield %0 : i1
15+
}
16+
17+
// -----
18+
19+
// expected-error @below {{result type must be i1, but got 'i2'}}
20+
synth.cut_rewrite_pattern (%a: i1) -> i2 attributes {cost = #synth.mapping_cost<area = 1.0 : f64>} {
21+
%0 = hw.constant 0 : i2
22+
synth.yield %0 : i2
23+
}
24+
25+
// -----
26+
27+
// expected-error @below {{requires exactly one result}}
28+
synth.cut_rewrite_pattern (%a: i1) -> (i1, i1) attributes {cost = #synth.mapping_cost<area = 1.0 : f64>} {
29+
synth.yield %a, %a : i1, i1
30+
}
31+
32+
// -----
33+
34+
// expected-error @below {{result type doesn't match with the terminator}}
35+
synth.cut_rewrite_pattern (%a: i1) -> i1 attributes {cost = #synth.mapping_cost<area = 1.0 : f64>} {
36+
"synth.yield"() : () -> ()
37+
}
38+
39+
// -----
40+
41+
// expected-error @below {{'i1' is expected but got 'i2'}}
42+
synth.cut_rewrite_pattern (%a: i1) -> i1 attributes {cost = #synth.mapping_cost<area = 1.0 : f64>} {
43+
%0 = hw.constant 0 : i2
44+
synth.yield %0 : i2
45+
}

test/Dialect/Synth/round-trip.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,17 @@ hw.module @mux_inv(in %c: i4, in %a: i4, in %b: i4) {
5454
hw.module @gamble(in %x: i1, in %y: i1, in %z: i1) {
5555
%0 = synth.gamble %x, not %y, %z : i1
5656
}
57+
58+
// CHECK-LABEL: synth.cut_rewrite_pattern
59+
// CHECK-SAME: (%{{.*}}: i1, %{{.*}}: i1, %{{.*}}: i1) -> i1
60+
synth.cut_rewrite_pattern (%a: i1, %b: i1, %c: i1) -> i1 attributes {cost = #synth.mapping_cost<area = 1.0 : f64>} {
61+
%0 = synth.aig.and_inv %a, not %b, %c : i1
62+
synth.yield %0 : i1
63+
}
64+
65+
// CHECK-LABEL: synth.cut_rewrite_pattern
66+
// CHECK-SAME: (%{{.*}}: i1, %{{.*}}: i1) -> i1 attributes {cost = #synth.mapping_cost<area =
67+
synth.cut_rewrite_pattern (%a: i1, %b: i1) -> i1 attributes {cost = #synth.mapping_cost<area = 1.0 : f64>} {
68+
%0 = synth.aig.and_inv %a, %b : i1
69+
synth.yield %0 : i1
70+
}

0 commit comments

Comments
 (0)