Skip to content

Commit f88898f

Browse files
committed
Pipelining
Better Error Handling Streams hold underlying CacheableData<?> ? Generalized primitives for many to many joins Primitive for (multi-)aggregations Various bugfixes Added cache features to merge/prioritize deferred requests
1 parent 3f841b7 commit f88898f

56 files changed

Lines changed: 3857 additions & 786 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/main/java/org/apache/sysds/hops/AggBinaryOp.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
5151
import org.apache.sysds.runtime.util.UtilFunctions;
5252

53-
5453
/* Aggregate binary (cell operations): Sum (aij + bij)
5554
* Properties:
5655
* Inner Symbol: *, -, +, ...
@@ -515,14 +514,14 @@ private void constructCPLopsMMChain(ChainType chain) {
515514
if (chain == ChainType.XtXv) {
516515
Hop hX = getInput().get(0).getInput().get(0);
517516
Hop hv = getInput().get(1).getInput().get(1);
518-
mapmmchain = new MapMultChain(hX.constructLops(), hv.constructLops(), getDataType(), getValueType(), ExecType.CP);
517+
mapmmchain = new MapMultChain(hX.constructLops(), hv.constructLops(), getDataType(), getValueType(), DMLScript.USE_OOC ? ExecType.OOC : ExecType.CP);
519518
} else { //ChainType.XtwXv / ChainType.XtwXvy
520519
int wix = (chain == ChainType.XtwXv) ? 0 : 1;
521520
int vix = (chain == ChainType.XtwXv) ? 1 : 0;
522521
Hop hX = getInput().get(0).getInput().get(0);
523522
Hop hw = getInput().get(1).getInput().get(wix);
524523
Hop hv = getInput().get(1).getInput().get(vix).getInput().get(1);
525-
mapmmchain = new MapMultChain(hX.constructLops(), hv.constructLops(), hw.constructLops(), chain, getDataType(), getValueType(), ExecType.CP);
524+
mapmmchain = new MapMultChain(hX.constructLops(), hv.constructLops(), hw.constructLops(), chain, getDataType(), getValueType(), DMLScript.USE_OOC ? ExecType.OOC : ExecType.CP);
526525
}
527526

528527
//set degree of parallelism

src/main/java/org/apache/sysds/runtime/DMLRuntimeException.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,15 @@ public class DMLRuntimeException extends DMLException
2828
{
2929
private static final long serialVersionUID = 1L;
3030

31+
public static DMLRuntimeException of(Throwable t) {
32+
return t instanceof DMLRuntimeException ? (DMLRuntimeException) t : new DMLRuntimeException(t);
33+
}
34+
3135
public DMLRuntimeException(String string) {
3236
super(string);
3337
}
3438

35-
public DMLRuntimeException(Exception e) {
39+
public DMLRuntimeException(Throwable e) {
3640
super(e);
3741
}
3842

src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -469,10 +469,11 @@ public BroadcastObject<T> getBroadcastHandle() {
469469
public boolean hasBroadcastHandle() {
470470
return _bcHandle != null && _bcHandle.hasBackReference();
471471
}
472-
472+
473473
public synchronized OOCStream<IndexedMatrixValue> getStreamHandle() {
474474
if( !hasStreamHandle() ) {
475475
final SubscribableTaskQueue<IndexedMatrixValue> _mStream = new SubscribableTaskQueue<>();
476+
_mStream.setData(this);
476477
DataCharacteristics dc = getDataCharacteristics();
477478
MatrixBlock src = (MatrixBlock)acquireReadAndRelease();
478479
_streamHandle = _mStream;
@@ -489,7 +490,7 @@ public synchronized OOCStream<IndexedMatrixValue> getStreamHandle() {
489490
}
490491

491492
OOCStream<IndexedMatrixValue> stream = _streamHandle.getReadStream();
492-
if (!stream.hasStreamCache())
493+
if(!stream.hasStreamCache())
493494
_streamHandle = null; // To ensure read once
494495
return stream;
495496
}
@@ -539,6 +540,7 @@ public synchronized void removeGPUObject(GPUContext gCtx) {
539540
}
540541

541542
public synchronized void setStreamHandle(OOCStreamable<IndexedMatrixValue> q) {
543+
q.setData(this);
542544
_streamHandle = q;
543545
}
544546

src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.apache.commons.logging.Log;
2323
import org.apache.commons.logging.LogFactory;
2424
import org.apache.sysds.common.InstructionType;
25+
import org.apache.sysds.common.Opcodes;
2526
import org.apache.sysds.runtime.DMLRuntimeException;
2627
import org.apache.sysds.runtime.instructions.ooc.AggregateTernaryOOCInstruction;
2728
import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction;
@@ -37,7 +38,8 @@
3738
import org.apache.sysds.runtime.instructions.ooc.TernaryOOCInstruction;
3839
import org.apache.sysds.runtime.instructions.ooc.TSMMOOCInstruction;
3940
import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction;
40-
import org.apache.sysds.runtime.instructions.ooc.MatrixVectorBinaryOOCInstruction;
41+
import org.apache.sysds.runtime.instructions.ooc.MMultOOCInstruction;
42+
import org.apache.sysds.runtime.instructions.ooc.MapMMChainOOCInstruction;
4143
import org.apache.sysds.runtime.instructions.ooc.ReorgOOCInstruction;
4244
import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;
4345

@@ -72,11 +74,22 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str
7274
return UnaryOOCInstruction.parseInstruction(str);
7375
case Binary:
7476
return BinaryOOCInstruction.parseInstruction(str);
77+
case Builtin:
78+
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
79+
if(parts[0].equals(Opcodes.LOG.toString()) || parts[0].equals(Opcodes.LOGNZ.toString())) {
80+
if(parts.length == 3)
81+
return UnaryOOCInstruction.parseInstruction(str);
82+
else if(parts.length == 4)
83+
return BinaryOOCInstruction.parseInstruction(str);
84+
}
85+
throw new DMLRuntimeException("Invalid Builtin Instruction: " + str);
7586
case Ternary:
7687
return TernaryOOCInstruction.parseInstruction(str);
7788
case AggregateBinary:
7889
case MAPMM:
79-
return MatrixVectorBinaryOOCInstruction.parseInstruction(str);
90+
return MMultOOCInstruction.parseInstruction(str);
91+
case MAPMMCHAIN:
92+
return MapMMChainOOCInstruction.parseInstruction(str);
8093
case MMTSJ:
8194
return TSMMOOCInstruction.parseInstruction(str);
8295
case Reorg:

src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,10 @@ private void processMoveInstruction(ExecutionContext ec) {
794794
// cleanup matrix/frame/list data if necessary
795795
if( srcData.getDataType().isMatrix() || srcData.getDataType().isFrame() ) {
796796
Data tgtData = ec.removeVariable(getInput2().getName());
797+
798+
if (DMLScript.USE_OOC && tgtData instanceof MatrixObject)
799+
TeeOOCInstruction.incrRef(((MatrixObject) tgtData).getStreamable(), -1);
800+
797801
if( tgtData != null && srcData != tgtData )
798802
ec.cleanupDataObject(tgtData);
799803
}

src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateTernaryOOCInstruction.java

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,13 @@
3939
import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator;
4040
import org.apache.sysds.runtime.matrix.operators.Operator;
4141
import org.apache.sysds.runtime.meta.DataCharacteristics;
42+
import org.apache.sysds.runtime.util.IndexRange;
4243

4344
import java.util.ArrayList;
4445
import java.util.HashMap;
4546
import java.util.List;
4647
import java.util.Map;
4748
import java.util.concurrent.CompletableFuture;
48-
import java.util.function.Function;
4949

5050
public class AggregateTernaryOOCInstruction extends ComputationOOCInstruction {
5151

@@ -111,17 +111,13 @@ private void processReduceAll(ExecutionContext ec, AggregateTernaryOperator abOp
111111
if(qIn3 != null)
112112
streams.add(qIn3);
113113

114-
List<Function<IndexedMatrixValue, MatrixIndexes>> keyFns = new ArrayList<>();
115-
for(int i = 0; i < streams.size(); i++)
116-
keyFns.add(IndexedMatrixValue::getIndexes);
117-
118114
CompletableFuture<Void> fut = joinOOC(streams, qMid, blocks -> {
119115
MatrixBlock b1 = (MatrixBlock) blocks.get(0).getValue();
120116
MatrixBlock b2 = (MatrixBlock) blocks.get(1).getValue();
121117
MatrixBlock b3 = blocks.size() == 3 ? (MatrixBlock) blocks.get(2).getValue() : null;
122118
MatrixBlock partial = MatrixBlock.aggregateTernaryOperations(b1, b2, b3, new MatrixBlock(), abOp, false);
123119
return new IndexedMatrixValue(blocks.get(0).getIndexes(), partial);
124-
}, keyFns);
120+
}, IndexedMatrixValue::getIndexes);
125121

126122
try {
127123
IndexedMatrixValue imv;
@@ -159,17 +155,26 @@ private void processReduceRow(ExecutionContext ec, AggregateTernaryOperator abOp
159155
if(qIn3 != null)
160156
streams.add(qIn3);
161157

162-
List<Function<IndexedMatrixValue, MatrixIndexes>> keyFns = new ArrayList<>();
163-
for(int i = 0; i < streams.size(); i++)
164-
keyFns.add(IndexedMatrixValue::getIndexes);
158+
for (OOCStream<IndexedMatrixValue> stream : streams)
159+
stream.setDownstreamMessageRelay(qOut::messageDownstream);
160+
161+
qOut.setUpstreamMessageRelay(msg ->
162+
streams.forEach(stream -> stream.messageUpstream(streams.size() > 1 ? msg.split() : msg)));
163+
164+
qOut.setIXTransform((downstream, range) -> {
165+
if (downstream)
166+
return new IndexRange(1, 1, range.colStart, range.colEnd);
167+
else
168+
return new IndexRange(1, dc.getRows(), range.colStart, range.colEnd);
169+
});
165170

166171
CompletableFuture<Void> fut = joinOOC(streams, qMid, blocks -> {
167172
MatrixBlock b1 = (MatrixBlock) blocks.get(0).getValue();
168173
MatrixBlock b2 = (MatrixBlock) blocks.get(1).getValue();
169174
MatrixBlock b3 = blocks.size() == 3 ? (MatrixBlock) blocks.get(2).getValue() : null;
170175
MatrixBlock partial = MatrixBlock.aggregateTernaryOperations(b1, b2, b3, new MatrixBlock(), abOp, false);
171176
return new IndexedMatrixValue(blocks.get(0).getIndexes(), partial);
172-
}, keyFns);
177+
}, IndexedMatrixValue::getIndexes);
173178

174179
final Map<Long, MatrixBlock> aggMap = new HashMap<>();
175180
final Map<Long, MatrixBlock> corrMap = new HashMap<>();

src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
3636
import org.apache.sysds.runtime.matrix.operators.Operator;
3737
import org.apache.sysds.runtime.meta.DataCharacteristics;
38+
import org.apache.sysds.runtime.ooc.stream.StreamContext;
39+
import org.apache.sysds.runtime.util.IndexRange;
3840

3941
import java.util.HashMap;
4042

@@ -90,6 +92,21 @@ public void processInstruction( ExecutionContext ec ) {
9092

9193
ec.getMatrixObject(output).setStreamHandle(qOut);
9294

95+
qIn.setDownstreamMessageRelay(qOut::messageDownstream);
96+
qOut.setUpstreamMessageRelay(qIn::messageUpstream);
97+
qOut.setIXTransform((downstream, range) -> {
98+
if (downstream) {
99+
if (aggun.isRowAggregate())
100+
return new IndexRange(range.rowStart, range.rowEnd, 1, 1);
101+
else
102+
return new IndexRange(1, 1, range.colStart, range.colEnd);
103+
}
104+
if (aggun.isRowAggregate())
105+
return new IndexRange(range.rowStart, range.rowEnd, 1, min.getNumColumns() - 1);
106+
else
107+
return new IndexRange(1, min.getNumRows() - 1, range.colStart, range.colEnd);
108+
});
109+
93110
// per-block aggregation (parallel map)
94111
mapOOC(qIn, qLocal, tmp -> {
95112
MatrixIndexes midx = aggun.isRowAggregate() ?
@@ -134,7 +151,7 @@ public void processInstruction( ExecutionContext ec ) {
134151
}
135152
}
136153
qOut.closeInput();
137-
});
154+
}, new StreamContext().addOutStream(qOut));
138155
}
139156
// full aggregation
140157
else {

src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ protected void processMatrixMatrixInstruction(ExecutionContext ec) {
6868
OOCStream<IndexedMatrixValue> qIn2 = m2.getStreamHandle();
6969
OOCStream<IndexedMatrixValue> qOut = new SubscribableTaskQueue<>();
7070
ec.getMatrixObject(output).setStreamHandle(qOut);
71+
qIn1.setDownstreamMessageRelay(qOut::messageDownstream);
72+
qIn2.setDownstreamMessageRelay(qOut::messageDownstream);
73+
qOut.setUpstreamMessageRelay(msg -> {
74+
qIn1.messageUpstream(msg.split());
75+
qIn2.messageUpstream(msg.split());
76+
});
7177

7278
if (m1.getNumRows() < 0 || m1.getNumColumns() < 0 || m2.getNumRows() < 0 || m2.getNumColumns() < 0)
7379
throw new DMLRuntimeException("Cannot process (matrix, matrix) BinaryOOCInstruction with unknown dimensions.");
@@ -116,8 +122,6 @@ else if (isRowBroadcast && !isColBroadcast) {
116122
return tmpOut;
117123
}, IndexedMatrixValue::getIndexes);
118124
}
119-
120-
121125
}
122126

123127
protected void processScalarMatrixInstruction(ExecutionContext ec) {
@@ -131,6 +135,8 @@ protected void processScalarMatrixInstruction(ExecutionContext ec) {
131135
OOCStream<IndexedMatrixValue> qIn = min.getStreamHandle();
132136
OOCStream<IndexedMatrixValue> qOut = createWritableStream();
133137
ec.getMatrixObject(output).setStreamHandle(qOut);
138+
qIn.setDownstreamMessageRelay(qOut::messageDownstream);
139+
qOut.setUpstreamMessageRelay(qIn::messageUpstream);
134140

135141
mapOOC(qIn, qOut, tmp -> {
136142
IndexedMatrixValue tmpOut = new IndexedMatrixValue();

src/main/java/org/apache/sysds/runtime/instructions/ooc/CSVReblockOOCInstruction.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.apache.sysds.runtime.io.ReaderTextCSVParallel;
3232
import org.apache.sysds.runtime.matrix.operators.Operator;
3333
import org.apache.sysds.runtime.meta.DataCharacteristics;
34+
import org.apache.sysds.runtime.ooc.stream.StreamContext;
3435

3536
public class CSVReblockOOCInstruction extends ComputationOOCInstruction {
3637
private final int blen;
@@ -80,7 +81,7 @@ public void processInstruction(ExecutionContext ec) {
8081
catch(Exception ex) {
8182
throw (ex instanceof DMLRuntimeException) ? (DMLRuntimeException) ex : new DMLRuntimeException(ex);
8283
}
83-
}, qOut);
84+
}, new StreamContext().addOutStream(qOut));
8485

8586
MatrixObject mout = ec.getMatrixObject(output);
8687
mout.setStreamHandle(qOut);

0 commit comments

Comments
 (0)