diff --git a/compiler/packages/babel-plugin-react-compiler/src/ReactiveIR/BuildReactiveGraph.ts b/compiler/packages/babel-plugin-react-compiler/src/ReactiveIR/BuildReactiveGraph.ts index 61c5de321675..92c95323cd56 100644 --- a/compiler/packages/babel-plugin-react-compiler/src/ReactiveIR/BuildReactiveGraph.ts +++ b/compiler/packages/babel-plugin-react-compiler/src/ReactiveIR/BuildReactiveGraph.ts @@ -9,6 +9,7 @@ import {CompilerError, SourceLocation} from '..'; import { BlockId, DeclarationId, + GotoVariant, HIRFunction, Identifier, IdentifierId, @@ -23,9 +24,10 @@ import { eachInstructionValueOperand, terminalFallthrough, } from '../HIR/visitors'; -import {getOrInsertWith} from '../Utils/utils'; +import {assertExhaustive, getOrInsertWith} from '../Utils/utils'; import { BranchNode, + GotoNode, ControlNode, EntryNode, InstructionNode, @@ -42,18 +44,19 @@ import { ReturnNode, reversePostorderReactiveGraph, StoreNode, + eachNodeDependency, } from './ReactiveIR'; export function buildReactiveGraph(fn: HIRFunction): ReactiveGraph { const builder = new Builder(); - const context = new ControlContext(); + const context = new ControlContext(builder, {kind: 'Function'}); const entryNode: EntryNode = { kind: 'Entry', id: builder.nextReactiveId, loc: fn.loc, outputs: [], }; - builder.putNode(entryNode); + context.putNode(entryNode); for (const param of fn.params) { const place = param.kind === 'Identifier' ? param : param.place; const node: LoadArgumentNode = { @@ -64,17 +67,11 @@ export function buildReactiveGraph(fn: HIRFunction): ReactiveGraph { place: {...place}, control: entryNode.id, }; - builder.putNode(node); + context.putNode(node); context.recordDeclarationWrite(place.identifier.declarationId, node.id); } - const exitNode = buildBlockScope( - fn, - builder, - context, - fn.body.entry, - entryNode.id, - ); + const exitNode = buildBlockScope(fn, context, fn.body.entry, entryNode.id); const graph = builder.build(fn, exitNode); @@ -109,19 +106,6 @@ class Builder { return makeReactiveId(this.#nextNodeId++); } - controlNode(control: ReactiveId, loc: SourceLocation): ReactiveId { - const node: ControlNode = { - kind: 'Control', - id: this.nextReactiveId, - loc, - outputs: [], - control, - dependencies: [], - }; - this.putNode(node); - return node.id; - } - putNode(node: ReactiveNode): void { this.#nodes.set(node.id, node); } @@ -140,18 +124,27 @@ class Builder { } } +type Fallthrough = + | {kind: 'Function'} + | {kind: 'If'; block: BlockId; fallthrough: ReactiveId}; + class ControlContext { constructor( + private builder: Builder, + private fallthrough: Fallthrough, private declarations: Map< DeclarationId, {write: ReactiveId | null; read: ReactiveId | null} > = new Map(), private scopes: Map = new Map(), + private nodes: Set = new Set(), private parent: ControlContext | null = null, ) {} - fork(): ControlContext { + fork(fallthrough: Fallthrough): ControlContext { return new ControlContext( + this.builder, + fallthrough, /* * We reset these maps so that the first reference of each declaration/scope within the branch * will depend on the branch's control. subsequent references can then refer to branch-local @@ -159,10 +152,72 @@ class ControlContext { */ new Map(), new Map(), + new Set(), this, ); } + controlNode(control: ReactiveId, loc: SourceLocation): ReactiveId { + const node: ControlNode = { + kind: 'Control', + id: this.nextReactiveId, + loc, + outputs: [], + control, + dependencies: [], + }; + this.putNode(node); + return node.id; + } + + putNode(node: ReactiveNode): void { + for (const dep of eachNodeDependency(node)) { + this.nodes.delete(dep); + } + this.nodes.add(node.id); + this.builder.putNode(node); + } + + /* + * Returns the nodes added to this context which are not yet depended on + * by other nodes + */ + uncontolledNodes(): Array { + return Array.from(this.nodes); + } + + get nextReactiveId(): ReactiveId { + return this.builder.nextReactiveId; + } + + storeTemporary(place: Place, node: ReactiveId): void { + this.builder.storeTemporary(place, node); + } + + lookupTemporary(identifier: Identifier, loc: SourceLocation): ReactiveId { + return this.builder.lookupTemporary(identifier, loc); + } + + loadBreakTarget(target: BlockId, loc: SourceLocation): ReactiveId { + if (this.fallthrough.kind === 'If' && this.fallthrough.block === target) { + return this.fallthrough.fallthrough; + } + if (this.parent != null) { + return this.parent.loadBreakTarget(target, loc); + } + CompilerError.invariant(false, { + reason: `Cannot find break target for bb${target}`, + loc, + }); + } + + loadContinueTarget(target: BlockId, loc: SourceLocation): ReactiveId { + CompilerError.invariant(false, { + reason: `Cannot find continue target for bb${target}`, + loc, + }); + } + // Scopes *eachScope(): Iterable<[ScopeId, ReactiveId]> { @@ -266,7 +321,6 @@ class ControlContext { function buildBlockScope( fn: HIRFunction, - builder: Builder, context: ControlContext, entry: BlockId, control: ReactiveId, @@ -308,7 +362,7 @@ function buildBlockScope( const node: LoadNode = { kind: 'Load', control: instructionControl, - id: builder.nextReactiveId, + id: context.nextReactiveId, loc: value.loc, outputs: [], value: { @@ -317,9 +371,9 @@ function buildBlockScope( as: lvalue, }, }; - builder.putNode(node); + context.putNode(node); lastNode = node.id; - builder.storeTemporary(lvalue, node.id); + context.storeTemporary(lvalue, node.id); // Record that we read so that subsequent writes can be sequenced after context.recordDeclarationRead( value.place.identifier.declarationId, @@ -337,7 +391,7 @@ function buildBlockScope( ) ?? control; // Lookup the node that defines the temporary we're storing - const valueNode = builder.lookupTemporary( + const valueNode = context.lookupTemporary( value.value.identifier, value.value.loc, ); @@ -345,7 +399,7 @@ function buildBlockScope( const node: StoreNode = { kind: 'Store', control: instructionControl, - id: builder.nextReactiveId, + id: context.nextReactiveId, instructionKind: value.lvalue.kind, loc: value.loc, lvalue: value.lvalue.place, @@ -356,9 +410,9 @@ function buildBlockScope( as: value.value, }, }; - builder.putNode(node); + context.putNode(node); lastNode = node.id; - builder.storeTemporary(lvalue, node.id); + context.storeTemporary(lvalue, node.id); // Record that the value was written so subsequent reads/writes can be sequenced after context.recordDeclarationWrite( value.lvalue.place.identifier.declarationId, @@ -403,7 +457,7 @@ function buildBlockScope( const dependencies: NodeDependencies = new Map(); for (const operand of eachInstructionValueOperand(instr.value)) { - const dep = builder.lookupTemporary(operand.identifier, operand.loc); + const dep = context.lookupTemporary(operand.identifier, operand.loc); dependencies.set(dep, { from: {...operand}, as: {...operand}, @@ -413,14 +467,14 @@ function buildBlockScope( kind: 'Value', control: instructionControl, dependencies, - id: builder.nextReactiveId, + id: context.nextReactiveId, loc: instr.loc, outputs: [], value: instr, }; - builder.putNode(node); + context.putNode(node); lastNode = node.id; - builder.storeTemporary(lvalue, node.id); + context.storeTemporary(lvalue, node.id); if (instructionScope != null) { context.recordScopeMutation(instructionScope.id, node.id); } @@ -431,7 +485,7 @@ function buildBlockScope( const terminal = block.terminal; switch (terminal.kind) { case 'if': { - const testDep = builder.lookupTemporary( + const testDep = context.lookupTemporary( terminal.test.identifier, terminal.test.loc, ); @@ -444,27 +498,37 @@ function buildBlockScope( kind: 'Branch', control, dependencies: [], - id: builder.nextReactiveId, + id: context.nextReactiveId, loc: terminal.loc, outputs: [], }; - builder.putNode(branch); - const consequentContext = context.fork(); - const consequentControl = builder.controlNode(branch.id, terminal.loc); + context.putNode(branch); + const joinNodeId = context.nextReactiveId; + const joinFallthrough = { + kind: 'If', + block: terminal.fallthrough, + fallthrough: joinNodeId, + } as const; + const consequentContext = context.fork(joinFallthrough); + const consequentControl = consequentContext.controlNode( + branch.id, + terminal.loc, + ); const consequent = buildBlockScope( fn, - builder, consequentContext, terminal.consequent, consequentControl, ); - const alternateContext = context.fork(); - const alternateControl = builder.controlNode(branch.id, terminal.loc); + const alternateContext = context.fork(joinFallthrough); + const alternateControl = alternateContext.controlNode( + branch.id, + terminal.loc, + ); const alternate = terminal.alternate !== terminal.fallthrough ? buildBlockScope( fn, - builder, alternateContext, terminal.alternate, alternateControl, @@ -473,7 +537,7 @@ function buildBlockScope( const ifNode: JoinNode = { kind: 'Join', control: branch.id, - id: builder.nextReactiveId, + id: joinNodeId, loc: terminal.loc, outputs: [], terminal: { @@ -568,12 +632,12 @@ function buildBlockScope( } branch.dependencies = Array.from(controlDependencies); - builder.putNode(ifNode); + context.putNode(ifNode); lastNode = ifNode.id; break; } case 'return': { - const valueDep = builder.lookupTemporary( + const valueDep = context.lookupTemporary( terminal.value.identifier, terminal.value.loc, ); @@ -584,17 +648,55 @@ function buildBlockScope( }; const returnNode: ReturnNode = { kind: 'Return', - id: builder.nextReactiveId, + id: context.nextReactiveId, loc: terminal.loc, outputs: [], value, + dependencies: context + .uncontolledNodes() + .filter(id => id !== valueDep), control, }; - builder.putNode(returnNode); + context.putNode(returnNode); lastNode = returnNode.id; break; } case 'goto': { + let target: ReactiveId; + switch (terminal.variant) { + case GotoVariant.Break: { + target = context.loadBreakTarget(terminal.block, terminal.loc); + break; + } + case GotoVariant.Continue: { + target = context.loadContinueTarget(terminal.block, terminal.loc); + break; + } + case GotoVariant.Try: { + CompilerError.throwTodo({ + reason: 'Support break with try variant', + loc: terminal.loc, + }); + } + default: { + assertExhaustive( + terminal.variant, + `Unexpected goto variant ${terminal.variant}`, + ); + } + } + const node: GotoNode = { + kind: 'Goto', + id: context.nextReactiveId, + loc: terminal.loc, + outputs: [], + target, + variant: terminal.variant, + dependencies: context.uncontolledNodes(), + control, + }; + context.putNode(node); + lastNode = node.id; break; } default: { diff --git a/compiler/packages/babel-plugin-react-compiler/src/ReactiveIR/ReactiveIR.ts b/compiler/packages/babel-plugin-react-compiler/src/ReactiveIR/ReactiveIR.ts index 6e5881639a39..1bc41ee52c39 100644 --- a/compiler/packages/babel-plugin-react-compiler/src/ReactiveIR/ReactiveIR.ts +++ b/compiler/packages/babel-plugin-react-compiler/src/ReactiveIR/ReactiveIR.ts @@ -8,6 +8,7 @@ import {CompilerError} from '..'; import { Environment, + GotoVariant, Instruction, InstructionKind, Place, @@ -58,7 +59,8 @@ export type ReactiveNode = | BranchNode | JoinNode | ControlNode - | ReturnNode; + | ReturnNode + | GotoNode; export type NodeReference = { node: ReactiveId; @@ -122,7 +124,19 @@ export type ReturnNode = { loc: SourceLocation; value: NodeReference; outputs: Array; + dependencies: Array; + control: ReactiveId; +}; + +export type GotoNode = { + kind: 'Goto'; + id: ReactiveId; + loc: SourceLocation; + outputs: Array; + dependencies: Array; control: ReactiveId; + target: ReactiveId; + variant: GotoVariant; }; export type BranchNode = { @@ -230,6 +244,7 @@ export function* eachNodeDependency(node: ReactiveNode): Iterable { case 'LoadArgument': { break; } + case 'Goto': case 'Control': case 'Branch': { yield* node.dependencies; @@ -250,6 +265,7 @@ export function* eachNodeDependency(node: ReactiveNode): Iterable { break; } case 'Return': { + yield* node.dependencies; yield node.value.node; break; } @@ -270,6 +286,7 @@ export function* eachNodeReference( node: ReactiveNode, ): Iterable { switch (node.kind) { + case 'Goto': case 'Entry': case 'Control': case 'LoadArgument': { @@ -353,9 +370,6 @@ function writeReactiveNodes( nodes: Map, ): void { for (const [id, node] of nodes) { - const deps = [...eachNodeReference(node)] - .map(id => printNodeReference(id)) - .join(' '); const control = node.kind !== 'Entry' && node.control != null ? ` control=£${node.control}` @@ -365,12 +379,20 @@ function writeReactiveNodes( buffer.push(`£${id} Entry`); break; } + case 'Goto': { + buffer.push( + `£${id} Goto(${node.variant}) target=£${node.target} deps=[${node.dependencies.map(id => `£${id}`).join(', ')}]${control}`, + ); + break; + } case 'LoadArgument': { buffer.push(`£${id} LoadArgument ${printPlace(node.place)}${control}`); break; } case 'Control': { - buffer.push(`£${id} Control${control}`); + buffer.push( + `£${id} Control${control} deps=[${node.dependencies.map(id => `£${id}`).join(', ')}]`, + ); break; } case 'Load': { @@ -385,7 +407,7 @@ function writeReactiveNodes( } case 'Return': { buffer.push( - `£${id} Return ${printNodeReference(node.value)}${control}`, + `£${id} Return ${printNodeReference(node.value)} deps=[${node.dependencies.map(id => `£${id}`).join(', ')}]${control}`, ); break; } @@ -411,6 +433,9 @@ function writeReactiveNodes( break; } case 'Value': { + const deps = [...eachNodeReference(node)] + .map(id => printNodeReference(id)) + .join(' '); buffer.push(`£${id} Value deps=[${deps}]${control}`); buffer.push(' ' + printInstruction(node.value)); break;