diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java index ec092260b40..62fcb1f86da 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java @@ -41,18 +41,18 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType; -import org.apache.sysds.runtime.compress.colgroup.AColGroupValue; import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; import org.apache.sysds.runtime.compress.colgroup.ColGroupIO; import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; import org.apache.sysds.runtime.compress.lib.CLALibAppend; import org.apache.sysds.runtime.compress.lib.CLALibBinaryCellOp; +import org.apache.sysds.runtime.compress.lib.CLALibCMOps; import org.apache.sysds.runtime.compress.lib.CLALibCompAgg; import org.apache.sysds.runtime.compress.lib.CLALibDecompress; import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy; import org.apache.sysds.runtime.compress.lib.CLALibMMChain; import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult; -import org.apache.sysds.runtime.compress.lib.CLALibReExpand; +import org.apache.sysds.runtime.compress.lib.CLALibRexpand; import org.apache.sysds.runtime.compress.lib.CLALibScalar; import org.apache.sysds.runtime.compress.lib.CLALibSlice; import org.apache.sysds.runtime.compress.lib.CLALibSquash; @@ -275,10 +275,8 @@ public long recomputeNonZeros() { nonZeros = nnz; } - if(nonZeros == 0) { - ColGroupEmpty cg = ColGroupEmpty.generate(getNumColumns()); - allocateColGroup(cg); - } + if(nonZeros == 0) // If there is no nonzeros then reallocate into single empty column group. 
+ allocateColGroup(ColGroupEmpty.create(getNumColumns())); return nonZeros; } @@ -468,7 +466,8 @@ public MatrixBlock chainMatrixMultOperations(MatrixBlock v, MatrixBlock w, Matri } @Override - public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) { + public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, + AggregateBinaryOperator op) { checkAggregateBinaryOperations(m1, m2, op); return CLALibMatrixMult.matrixMultiply(m1, m2, ret, op.getNumThreads(), false, false); } @@ -636,13 +635,7 @@ public double mean() { @Override public MatrixBlock rexpandOperations(MatrixBlock ret, double max, boolean rows, boolean cast, boolean ignore, int k) { - if(rows) { - printDecompressWarning("rexpandOperations"); - MatrixBlock tmp = getUncompressed(); - return tmp.rexpandOperations(ret, max, rows, cast, ignore, k); - } - else - return CLALibReExpand.reExpand(this, ret, max, cast, ignore, k); + return CLALibRexpand.rexpand(this, ret, max, rows, cast, ignore, k); } @Override @@ -713,29 +706,7 @@ public MatrixBlock zeroOutOperations(MatrixValue result, IndexRange range, boole @Override public CM_COV_Object cmOperations(CMOperator op) { - if(isEmpty()) - return super.cmOperations(op); - else if(_colGroups.size() == 1 && _colGroups.get(0) instanceof AColGroupValue) { - AColGroupValue g = (AColGroupValue) _colGroups.get(0); - MatrixBlock vals = g.getValuesAsBlock(); - MatrixBlock counts = getCountsAsBlock(g.getCounts()); - if(counts.isEmpty()) - return vals.cmOperations(op); - return vals.cmOperations(op, counts); - } - else - return getUncompressed("cmOperations").cmOperations(op); - } - - private static MatrixBlock getCountsAsBlock(int[] counts) { - if(counts != null) { - MatrixBlock ret = new MatrixBlock(counts.length, 1, false); - for(int i = 0; i < counts.length; i++) - ret.quickSetValue(i, 0, counts[i]); - return ret; - } - else - return new MatrixBlock(1, 1, false); + return 
CLALibCMOps.centralMoment(this, op); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java index 97f6f0975d1..cb8eabb3be4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java @@ -29,10 +29,11 @@ import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.cocode.CoCoderFactory; import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.AColGroupValue; +import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory; import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; -import org.apache.sysds.runtime.compress.colgroup.AColGroupValue; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.cost.CostEstimatorBuilder; import org.apache.sysds.runtime.compress.cost.CostEstimatorFactory; @@ -205,7 +206,7 @@ public static CompressedMatrixBlock genUncompressedCompressedMatrixBlock(MatrixB */ public static CompressedMatrixBlock createConstant(int numRows, int numCols, double value) { CompressedMatrixBlock block = new CompressedMatrixBlock(numRows, numCols); - AColGroup cg = ColGroupFactory.genColGroupConst(numCols, value); + AColGroup cg = ColGroupConst.create(numCols, value); block.allocateColGroup(cg); block.recomputeNonZeros(); if(block.getNumRows() == 0 || block.getNumColumns() == 0) { @@ -223,7 +224,7 @@ private Pair compressMatrix() { else if(mb.isEmpty()) { LOG.info("Empty input to compress, returning a compressed Matrix block with empty column group"); CompressedMatrixBlock ret = new CompressedMatrixBlock(mb.getNumRows(), 
mb.getNumColumns()); - ColGroupEmpty cg = ColGroupEmpty.generate(mb.getNumColumns()); + ColGroupEmpty cg = ColGroupEmpty.create(mb.getNumColumns()); ret.allocateColGroup(cg); ret.setNonZeros(0); return new ImmutablePair<>(ret, null); @@ -440,7 +441,7 @@ private void logPhase() { DMLCompressionStatistics.addCompressionTime(getLastTimePhase(), phase); if(LOG.isDebugEnabled()) { if(compSettings.isInSparkInstruction) { - if(phase == 5) + if(phase == 4) LOG.debug(_stats); } else { diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java index 3b59fef14b4..ec3c2c667ba 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java @@ -20,13 +20,14 @@ package org.apache.sysds.runtime.compress.cocode; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; -import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressionSettings; +import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.cost.ICostEstimate; import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimator; @@ -132,7 +133,6 @@ private static List coCodeBruteForce(List workSet, Memorizer mem, ICostEstimate cEst, int k) { try { - ExecutorService pool = CommonThreadPool.get(k); List tasks = new ArrayList<>(); for(int i = 0; i < workSet.size(); i++) @@ -154,7 +154,7 @@ protected static void parallelFirstJoin(List workSet, Memorizer mem, pool.shutdown(); } catch(Exception e) { - throw new DMLRuntimeException("failed to join column groups", e); + throw new DMLCompressionException("Failed parallelize first level 
all join all", e); } } @@ -170,8 +170,14 @@ protected JoinTask(ColIndexes c1, ColIndexes c2, Memorizer m) { @Override public Object call() { - _m.getOrCreate(_c1, _c2); - return null; + try { + _m.getOrCreate(_c1, _c2); + return null; + } + catch(Exception e) { + throw new DMLCompressionException( + "Failed to join columns : " + Arrays.toString(_c1._indexes) + " + " + Arrays.toString(_c2._indexes), e); + } } } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java index 23674b1d709..1697ebdf71b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java @@ -30,9 +30,11 @@ import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.CMOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; import org.apache.sysds.utils.MemoryEstimates; @@ -501,8 +503,8 @@ public double get(int r, int c) { public abstract double getMax(); /** - * Get a copy of this column group. Depending on which column group is copied it is a deep or shallow copy. If the - * primitives for the underlying column groups is Immutable then only shallow copies is performed. + * Get a copy of this column group note this is only a shallow copy. Meaning only the object wrapping index + * structures, column indexes and dictionaries are copied. * * @return Get a copy of this column group. 
*/ @@ -542,6 +544,26 @@ public double get(int r, int c) { */ public abstract void computeColSums(double[] c, int nRows); + /** + * Central Moment instruction executed on a column group. + * + * @param op The Operator to use. + * @param nRows The number of rows contained in the ColumnGroup. + * @return A Central Moment object. + */ + public abstract CM_COV_Object centralMoment(CMOperator op, int nRows); + + /** + * Expand the column group to multiple columns. (one hot encode the column group) + * + * @param max The number of columns to expand to and cutoff values at. + * @param ignore If zero and negative values should be ignored. + * @param cast If the double values contained should be cast to whole numbers. + * @param nRows The number of rows in the column group. + * @return A new column group containing max number of columns. + */ + public abstract AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows); + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupCompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupCompressed.java index 1fbb8431d17..2464bbcfd67 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupCompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupCompressed.java @@ -189,7 +189,7 @@ protected static void tsmmDense(double[] result, int numColumns, double[] values if(values == null) return; final int nCol = colIndexes.length; - final int nRow = values.length / colIndexes.length; + final int nRow = counts.length; for(int k = 0; k < nRow; k++) { final int offTmp = nCol * k; final int scale = counts[k]; @@ -204,7 +204,7 @@ protected static void tsmmDense(double[] result, int numColumns, double[] values } protected static void tsmmSparse(double[] result, int numColumns, SparseBlock sb, int[] counts, int[] colIndexes) { - for(int row = 0; row < 
sb.numRows(); row++) { + for(int row = 0; row < counts.length; row++) { if(sb.isEmpty(row)) continue; final int apos = sb.pos(row); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java index 2eaeb998425..064a62b8eb7 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java @@ -36,7 +36,9 @@ import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Builtin; +import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.CMOperator; /** * Base class for column groups encoded with value dictionary. This include column groups such as DDC OLE and RLE. @@ -51,6 +53,8 @@ public abstract class AColGroupValue extends AColGroupCompressed implements Clon /** * ColGroup Implementation Contains zero tuple. Note this is not if it contains a zero value. If false then the * stored values are filling the ColGroup making it a dense representation, that can be leveraged in operations. + * + * TODO remove */ protected boolean _zeros = false; @@ -173,21 +177,10 @@ public int getNumValues() { return _dict.getNumberOfValues(_colIndexes.length); } - public final ADictionary getDictionary() { + public ADictionary getDictionary() { return _dict; } - public final MatrixBlock getValuesAsBlock() { - _dict = _dict.getMBDict(_colIndexes.length); - MatrixBlock ret = ((MatrixBlockDictionary) _dict).getMatrixBlock(); - if(_zeros) { - MatrixBlock tmp = new MatrixBlock(); - ret.append(new MatrixBlock(1, _colIndexes.length, 0), tmp, false); - return tmp; - } - return ret; - } - /** * Returns the counts of values inside the dictionary. 
If already calculated it will return the previous counts. This * produce an overhead in cases where the count is calculated, but the overhead will be limited to number of distinct @@ -202,7 +195,7 @@ public final int[] getCounts() { int[] ret = getCachedCounts(); if(ret == null) { - ret = getCounts(new int[getNumValues() + (_zeros ? 1 : 0)]); + ret = getCounts(new int[getNumValues()]); counts = new SoftReference<>(ret); } @@ -318,7 +311,6 @@ public long getExactSizeOnDisk() { long ret = super.getExactSizeOnDisk(); ret += 1; // zeros boolean ret += _dict.getExactSizeOnDisk(); - return ret; } @@ -349,6 +341,11 @@ protected void computeProduct(double[] c, int nRows) { c[0] *= _dict.product(getCounts(), _colIndexes.length); } + @Override + protected void computeColProduct(double[] c, int nRows) { + _dict.colProduct(c, getCounts(), _colIndexes); + } + @Override protected void computeRowProduct(double[] c, int rl, int ru, double[] preAgg) { throw new NotImplementedException(); @@ -375,10 +372,6 @@ protected double[] preAggBuiltinRows(Builtin builtin) { } @Override - protected void computeColProduct(double[] c, int nRows) { - _dict.colProduct(c, getCounts(), _colIndexes); - } - protected Object clone() { try { return super.clone(); @@ -388,21 +381,17 @@ protected Object clone() { } } - public AColGroup copyAndSet(double[] newDictionary) { - return copyAndSet(new Dictionary(newDictionary)); - } - - public AColGroup copyAndSet(ADictionary newDictionary) { + protected AColGroup copyAndSet(ADictionary newDictionary) { AColGroupValue clone = (AColGroupValue) this.clone(); clone._dict = newDictionary; return clone; } - public AColGroup copyAndSet(int[] colIndexes, double[] newDictionary) { + private AColGroup copyAndSet(int[] colIndexes, double[] newDictionary) { return copyAndSet(colIndexes, new Dictionary(newDictionary)); } - public AColGroup copyAndSet(int[] colIndexes, ADictionary newDictionary) { + private AColGroup copyAndSet(int[] colIndexes, ADictionary newDictionary) { 
AColGroupValue clone = (AColGroupValue) this.clone(); clone._dict = newDictionary; clone.setColIndices(colIndexes); @@ -416,7 +405,7 @@ public AColGroupValue copy() { @Override protected AColGroup sliceSingleColumn(int idx) { - final AColGroupValue ret = (AColGroupValue) copy(); + final AColGroupValue ret = (AColGroupValue) this.clone(); ret._colIndexes = new int[] {0}; if(_colIndexes.length == 1) ret._dict = ret._dict.clone(); @@ -428,7 +417,7 @@ protected AColGroup sliceSingleColumn(int idx) { @Override protected AColGroup sliceMultiColumns(int idStart, int idEnd, int[] outputCols) { - final AColGroupValue ret = (AColGroupValue) copy(); + final AColGroupValue ret = (AColGroupValue) this.clone(); ret._dict = ret._dict.sliceOutColumnRange(idStart, idEnd, _colIndexes.length); ret._colIndexes = outputCols; return ret; @@ -504,6 +493,20 @@ public AColGroup replace(double pattern, double replace) { return copyAndSet(replaced); } + @Override + public CM_COV_Object centralMoment(CMOperator op, int nRows) { + return _dict.centralMoment(op.fn, getCounts(), nRows); + } + + @Override + public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows){ + ADictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.length); + if(d == null) + return ColGroupEmpty.create(max); + else + return copyAndSet(d); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java index 336ce08137b..922d9c9ffc5 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java @@ -29,12 +29,15 @@ import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; import 
org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; +import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Builtin; +import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.matrix.data.LibMatrixMult; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.CMOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; public class ColGroupConst extends AColGroupCompressed { @@ -74,6 +77,80 @@ protected static AColGroup create(int[] colIndices, ADictionary dict) { return new ColGroupConst(colIndices, dict); } + /** + * Generate a constant column group. + * + * @param values The value vector that contains all the unique values for each column in the matrix. + * @return A Constant column group. + */ + public static AColGroup create(double[] values) { + final int[] colIndices = Util.genColsIndices(values.length); + return create(colIndices, values); + } + + /** + * Generate a constant column group. + * + * It is assumed that the column group is intended for use, therefore zero value is allowed. + * + * @param cols The specific column indexes that is contained in this constant group. + * @param value The value contained in all cells. + * @return A Constant column group. + */ + public static AColGroup create(int[] cols, double value) { + final int numCols = cols.length; + double[] values = new double[numCols]; + for(int i = 0; i < numCols; i++) + values[i] = value; + return create(cols, values); + } + + /** + * Generate a constant column group. + * + * @param cols The specific column indexes that is contained in this constant group. + * @param values The value vector that contains all the unique values for each column in the matrix. 
+ * @return A Constant column group. + */ + public static AColGroup create(int[] cols, double[] values) { + if(cols.length != values.length) + throw new DMLCompressionException("Invalid size of values compared to columns"); + ADictionary dict = new Dictionary(values); + return ColGroupConst.create(cols, dict); + } + + /** + * Generate a constant column group. + * + * @param numCols The number of columns. + * @param dict The dictionary to contain int the Constant group. + * @return A Constant column group. + */ + public static AColGroup create(int numCols, ADictionary dict) { + if(numCols != dict.getValues().length) + throw new DMLCompressionException( + "Invalid construction of const column group with different number of columns in arguments"); + final int[] colIndices = Util.genColsIndices(numCols); + return ColGroupConst.create(colIndices, dict); + } + + /** + * Generate a constant column group. + * + * @param numCols The number of columns + * @param value The value contained in all cells. + * @return A Constant column group. 
+ */ + public static AColGroup create(int numCols, double value) { + if(numCols <= 0) + throw new DMLCompressionException("Invalid construction of constant column group with cols: " + numCols); + final int[] colIndices = Util.genColsIndices(numCols); + + if(value == 0) + return new ColGroupEmpty(colIndices); + return ColGroupConst.create(colIndices, value); + } + @Override protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, double[] preAgg) { double v = preAgg[0]; @@ -167,7 +244,6 @@ public void computeColSums(double[] c, int nRows) { @Override protected void computeSumSq(double[] c, int nRows) { - c[0] += _dict.sumSq(new int[] {nRows}, _colIndexes.length); } @@ -178,7 +254,7 @@ protected void computeColSumsSq(double[] c, int nRows) { @Override protected void computeRowSums(double[] c, int rl, int ru, double[] preAgg) { - double vals = preAgg[0]; + final double vals = preAgg[0]; for(int rix = rl; rix < ru; rix++) c[rix] += vals; } @@ -209,7 +285,7 @@ public AColGroup rightMultByMatrix(MatrixBlock right) { if(ret.isEmpty()) return null; ADictionary d = new MatrixBlockDictionary(ret, cr); - return ColGroupFactory.genColGroupConst(cr, d); + return create(cr, d); } else { throw new NotImplementedException(); @@ -323,7 +399,6 @@ protected void computeRowProduct(double[] c, int rl, int ru, double[] preAgg) { @Override protected void computeColProduct(double[] c, int nRows) { throw new NotImplementedException(); - } @Override @@ -334,7 +409,6 @@ protected double[] preAggSumRows() { @Override protected double[] preAggSumSqRows() { return _dict.sumAllRowsToDoubleSq(_colIndexes.length); - } @Override @@ -354,4 +428,20 @@ public long estimateInMemorySize() { size += 8; // dict reference return size; } + + @Override + public CM_COV_Object centralMoment(CMOperator op, int nRows) { + CM_COV_Object ret = new CM_COV_Object(); + op.fn.execute(ret, _dict.getValue(0), nRows); + return ret; + } + + @Override + public AColGroup rexpandCols(int max, boolean ignore, 
boolean cast, int nRows) { + ADictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.length); + if(d == null) + return ColGroupEmpty.create(max); + else + return create(max, d); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java index b8a0b880c1c..bdf78f32a4a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java @@ -146,9 +146,7 @@ protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, double @Override public int[] getCounts(int[] counts) { - for(int i = 0; i < _numRows; i++) - counts[_data.getIndex(i)]++; - return counts; + return _data.getCounts(counts); } @Override @@ -182,9 +180,7 @@ private void preAggregateSparse(SparseBlock sb, MatrixBlock preAgg, int rl, int @Override public void preAggregateThatDDCStructure(ColGroupDDC that, Dictionary ret) { - final int nCol = that._colIndexes.length; - for(int r = 0; r < _numRows; r++) - that._dict.addToEntry(ret, that._data.getIndex(r), this._data.getIndex(r), nCol); + _data.preAggregateDDC(that._data, that._dict, ret, that._colIndexes.length); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java index bcd139ab0e6..ff5f6f374f7 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java @@ -21,13 +21,17 @@ import java.util.Arrays; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; +import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import 
org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.ValueFunction; +import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.CMOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; public class ColGroupEmpty extends AColGroupCompressed { @@ -47,12 +51,8 @@ public ColGroupEmpty(int[] colIndices) { super(colIndices); } - public static ColGroupEmpty generate(int nCol) { - int[] cols = new int[nCol]; - for(int i = 0; i < nCol; i++) { - cols[i] = i; - } - return new ColGroupEmpty(cols); + public static ColGroupEmpty create(int nCol) { + return new ColGroupEmpty(Util.genColsIndices(nCol)); } @Override @@ -174,7 +174,7 @@ public AColGroup rightMultByMatrix(MatrixBlock right) { @Override public AColGroup replace(double pattern, double replace) { if(pattern == 0) - return ColGroupFactory.genColGroupConst(_colIndexes, replace); + return ColGroupConst.create(_colIndexes, replace); else return new ColGroupEmpty(_colIndexes); } @@ -270,4 +270,20 @@ protected double[] preAggProductRows() { protected double[] preAggBuiltinRows(Builtin builtin) { return null; } + + @Override + public CM_COV_Object centralMoment(CMOperator op, int nRows) { + CM_COV_Object ret = new CM_COV_Object(); + op.fn.execute(ret, 0.0, nRows); + return ret; + } + + @Override + public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) { + if(!ignore) + throw new DMLRuntimeException( + "Invalid input to rexpand since it contains zero use ignore flag to encode anyway"); + else + return create(max); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java index dbc08a81081..3544bb04b83 100644 --- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java @@ -55,7 +55,6 @@ import org.apache.sysds.runtime.compress.utils.DblArrayCountHashMap; import org.apache.sysds.runtime.compress.utils.DoubleCountHashMap; import org.apache.sysds.runtime.compress.utils.IntArrayList; -import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; @@ -85,94 +84,16 @@ public static List compressColGroups(MatrixBlock in, CompressedSizeIn for(CompressedSizeInfoColGroup g : csi.getInfo()) g.clearMap(); - if(in.isEmpty()) - return genEmpty(in, compSettings); + if(in.isEmpty()) { + AColGroup empty = ColGroupEmpty.create(compSettings.transposed ? in.getNumRows() : in.getNumColumns()); + return Collections.singletonList(empty); + } else if(k <= 1) return compressColGroupsSingleThreaded(in, csi, compSettings); else return compressColGroupsParallel(in, csi, compSettings, k); } - /** - * Generate a constant column group. - * - * @param numCols The number of columns - * @param value The value contained in all cells. - * @return A Constant column group. - */ - public static AColGroup genColGroupConst(int numCols, double value) { - if(numCols <= 0) - throw new DMLCompressionException("Invalid construction of constant column group with cols: " + numCols); - final int[] colIndices = Util.genColsIndices(numCols); - - if(value == 0) - return new ColGroupEmpty(colIndices); - return genColGroupConst(colIndices, value); - } - - /** - * Generate a constant column group. - * - * @param values The value vector that contains all the unique values for each column in the matrix. - * @return A Constant column group. 
- */ - public static AColGroup genColGroupConst(double[] values) { - final int[] colIndices = Util.genColsIndices(values.length); - return genColGroupConst(colIndices, values); - } - - /** - * Generate a constant column group. - * - * It is assumed that the column group is intended for use, therefore zero value is allowed. - * - * @param cols The specific column indexes that is contained in this constant group. - * @param value The value contained in all cells. - * @return A Constant column group. - */ - public static AColGroup genColGroupConst(int[] cols, double value) { - final int numCols = cols.length; - double[] values = new double[numCols]; - for(int i = 0; i < numCols; i++) - values[i] = value; - return genColGroupConst(cols, values); - } - - /** - * Generate a constant column group. - * - * @param cols The specific column indexes that is contained in this constant group. - * @param values The value vector that contains all the unique values for each column in the matrix. - * @return A Constant column group. - */ - public static AColGroup genColGroupConst(int[] cols, double[] values) { - if(cols.length != values.length) - throw new DMLCompressionException("Invalid size of values compared to columns"); - ADictionary dict = new Dictionary(values); - return ColGroupConst.create(cols, dict); - } - - /** - * Generate a constant column group. - * - * @param numCols The number of columns. - * @param dict The dictionary to contain int the Constant group. - * @return A Constant column group. 
- */ - public static AColGroup genColGroupConst(int numCols, ADictionary dict) { - if(numCols != dict.getValues().length) - throw new DMLCompressionException( - "Invalid construction of const column group with different number of columns in arguments"); - final int[] colIndices = Util.genColsIndices(numCols); - return ColGroupConst.create(colIndices, dict); - } - - private static List genEmpty(MatrixBlock in, CompressionSettings compSettings) { - List ret = new ArrayList<>(1); - ret.add(genColGroupConst(compSettings.transposed ? in.getNumRows() : in.getNumColumns(), 0)); - return ret; - } - private static List compressColGroupsSingleThreaded(MatrixBlock in, CompressedSizeInfo csi, CompressionSettings compSettings) { List ret = new ArrayList<>(csi.getNumberColGroups()); @@ -220,36 +141,6 @@ private static List> makeGroups(List> { - private final MatrixBlock _in; - private final List _groups; - private final CompressionSettings _compSettings; - private final int _k; - - protected CompressTask(MatrixBlock in, List groups, CompressionSettings compSettings, - int k) { - _in = in; - _groups = groups; - _compSettings = compSettings; - _k = k; - } - - @Override - public Collection call() { - try { - ArrayList res = new ArrayList<>(); - Tmp tmpMap = new Tmp(); - for(CompressedSizeInfoColGroup g : _groups) - res.addAll(compressColGroup(_in, _compSettings, tmpMap, g, _k)); - return res; - } - catch(Exception e) { - e.printStackTrace(); - throw e; - } - } - } - private static Collection compressColGroup(MatrixBlock in, CompressionSettings compSettings, Tmp tmpMap, CompressedSizeInfoColGroup cg, int k) { final int inCols = compSettings.transposed ? 
in.getNumRows() : in.getNumColumns(); @@ -330,7 +221,7 @@ private static AColGroup compressColGroupForced(MatrixBlock in, CompressionSetti CompressionType estimatedBestCompressionType = cg.getBestCompressionType(); if(estimatedBestCompressionType == CompressionType.UNCOMPRESSED) // don't construct mapping if uncompressed - return new ColGroupUncompressed(colIndexes, in, cs.transposed); + return ColGroupUncompressed.create(colIndexes, in, cs.transposed); else if(estimatedBestCompressionType == CompressionType.SDC && colIndexes.length == 1 && in.isInSparseFormat() && cs.transposed) // Leverage the Sparse matrix, to construct SDC group return compressSDCFromSparseTransposedBlock(in, colIndexes, in.getNumColumns(), @@ -372,10 +263,12 @@ private static AColGroup compress(int[] colIndexes, int rlen, ABitmap ubm, Compr return compressRLE(colIndexes, rlen, ubm, cs, tupleSparsity); case OLE: return compressOLE(colIndexes, rlen, ubm, cs, tupleSparsity); + case CONST: // in case somehow one requested const, but it was not const fall back to SDC. 
+ LOG.warn("Requested const on non constant column, fallback to SDC"); case SDC: return compressSDC(colIndexes, rlen, ubm, cs, tupleSparsity); default: - throw new DMLCompressionException("Not implemented compression of " + compType + "in factory."); + throw new DMLCompressionException("Not implemented compression of " + compType + " in factory."); } } @@ -501,34 +394,6 @@ private static MatrixBlock deltaEncodeMatrixBlock(MatrixBlock mb) { return DataConverter.convertToMatrixBlock(ret); } - static class readToMapDDCTask implements Callable { - private final int[] _colIndexes; - private final MatrixBlock _raw; - private final DblArrayCountHashMap _map; - private final CompressionSettings _cs; - private final AMapToData _data; - private final int _rl; - private final int _ru; - private final int _fill; - - protected readToMapDDCTask(int[] colIndexes, MatrixBlock raw, DblArrayCountHashMap map, CompressionSettings cs, - AMapToData data, int rl, int ru, int fill) { - _colIndexes = colIndexes; - _raw = raw; - _map = map; - _cs = cs; - _data = data; - _rl = rl; - _ru = ru; - _fill = fill; - } - - @Override - public Boolean call() { - return Boolean.valueOf(readToMapDDC(_colIndexes, _raw, _map, _cs, _data, _rl, _ru, _fill)); - } - } - private static AColGroup compressSDC(int[] colIndexes, int rlen, ABitmap ubm, CompressionSettings cs, double tupleSparsity) { @@ -555,55 +420,43 @@ private static AColGroup compressSDC(int[] colIndexes, int rlen, ABitmap ubm, Co return new ColGroupSDCSingleZeros(colIndexes, rlen, dict, off, null); } else { - LOG.warn("fix three dictionary allocations"); - ADictionary dict = DictionaryFactory.create(ubm, 1.0); - dict = DictionaryFactory.moveFrequentToLastDictionaryEntry(dict, ubm, rlen, largestIndex); - if(tupleSparsity < 0.4) - dict = dict.getMBDict(colIndexes.length); - return setupSingleValueSDCColGroup(colIndexes, rlen, ubm, dict); + double[] defaultTuple = new double[colIndexes.length]; + ADictionary dict = DictionaryFactory.create(ubm, 
largestIndex, defaultTuple, 1.0, numZeros > 0); + return compressSDCSingle(colIndexes, rlen, ubm, dict, defaultTuple); } } else if(numZeros >= largestOffset) { ADictionary dict = DictionaryFactory.create(ubm, tupleSparsity); - return setupMultiValueZeroColGroup(colIndexes, rlen, ubm, dict, cs); + return compressSDCZero(colIndexes, rlen, ubm, dict, cs); } else { - LOG.warn("fix three dictionary allocations"); - ADictionary dict = DictionaryFactory.create(ubm, 1.0); - dict = DictionaryFactory.moveFrequentToLastDictionaryEntry(dict, ubm, rlen, largestIndex); - if(tupleSparsity < 0.4 && colIndexes.length > 4) - dict = dict.getMBDict(colIndexes.length); - return setupMultiValueColGroup(colIndexes, numZeros, rlen, ubm, largestIndex, dict, cs); + double[] defaultTuple = new double[colIndexes.length]; + ADictionary dict = DictionaryFactory.create(ubm, largestIndex, defaultTuple, 1.0, numZeros > 0); + return compressSDCNormal(colIndexes, numZeros, rlen, ubm, largestIndex, dict, defaultTuple, cs); } } - private static AColGroup setupMultiValueZeroColGroup(int[] colIndexes, int rlen, ABitmap ubm, ADictionary dict, + private static AColGroup compressSDCZero(int[] colIndexes, int rlen, ABitmap ubm, ADictionary dict, CompressionSettings cs) { IntArrayList[] offsets = ubm.getOffsetList(); AInsertionSorter s = InsertionSorterFactory.create(rlen, offsets, cs.sdcSortType); AOffset indexes = OffsetFactory.createOffset(s.getIndexes()); AMapToData data = s.getData(); - int[] counts = new int[offsets.length + 1]; - int sum = 0; - for(int i = 0; i < offsets.length; i++) { - counts[i] = offsets[i].size(); - sum += counts[i]; - } - counts[offsets.length] = rlen - sum; - return ColGroupSDCZeros.create(colIndexes, rlen, dict, indexes, data, counts); + return ColGroupSDCZeros.create(colIndexes, rlen, dict, indexes, data, null); } - private static AColGroup setupMultiValueColGroup(int[] colIndexes, int numZeros, int rlen, ABitmap ubm, - int largestIndex, ADictionary dict, CompressionSettings 
cs) { + private static AColGroup compressSDCNormal(int[] colIndexes, int numZeros, int rlen, ABitmap ubm, int largestIndex, + ADictionary dict, double[] defaultTuple, CompressionSettings cs) { IntArrayList[] offsets = ubm.getOffsetList(); AInsertionSorter s = InsertionSorterFactory.createNegative(rlen, offsets, largestIndex, cs.sdcSortType); AOffset indexes = OffsetFactory.createOffset(s.getIndexes()); AMapToData _data = s.getData(); _data = MapToFactory.resize(_data, _data.getUnique() - 1); - return ColGroupSDC.create(colIndexes, rlen, dict, indexes, _data, null); + return ColGroupSDC.create(colIndexes, rlen, dict, defaultTuple, indexes, _data, null); } - private static AColGroup setupSingleValueSDCColGroup(int[] colIndexes, int rlen, ABitmap ubm, ADictionary dict) { + private static AColGroup compressSDCSingle(int[] colIndexes, int rlen, ABitmap ubm, ADictionary dict, + double[] defaultTuple) { IntArrayList inv = ubm.getOffsetsList(0); int[] indexes = new int[rlen - inv.size()]; int p = 0; @@ -620,7 +473,7 @@ private static AColGroup setupSingleValueSDCColGroup(int[] colIndexes, int rlen, indexes[p++] = v++; AOffset off = OffsetFactory.createOffset(indexes); - return new ColGroupSDCSingle(colIndexes, rlen, dict, off, null); + return new ColGroupSDCSingle(colIndexes, rlen, dict, defaultTuple, off, null); } private static AColGroup compressDDC(int[] colIndexes, int rlen, ABitmap ubm, CompressionSettings cs, @@ -714,18 +567,15 @@ private static AColGroup compressSDCFromSparseTransposedBlock(MatrixBlock mb, in if(entries[0].count < rlen - sb.size(sbRow)) { // If the zero is the default value. 
- final int[] counts = new int[entries.length + 1]; + final int[] counts = new int[entries.length]; final double[] dict = new double[entries.length]; - int sum = 0; for(int i = 0; i < entries.length; i++) { final DCounts x = entries[i]; counts[i] = x.count; - sum += x.count; dict[i] = x.key; x.count = i; } - counts[entries.length] = rlen - sum; final AOffset offsets = OffsetFactory.createOffset(sb.indexes(sbRow), apos, alen); if(entries.length <= 1) return new ColGroupSDCSingleZeros(cols, rlen, new Dictionary(dict), offsets, counts); @@ -743,6 +593,65 @@ private static AColGroup compressSDCFromSparseTransposedBlock(MatrixBlock mb, in } } + static class CompressTask implements Callable> { + private final MatrixBlock _in; + private final List _groups; + private final CompressionSettings _compSettings; + private final int _k; + + protected CompressTask(MatrixBlock in, List groups, CompressionSettings compSettings, + int k) { + _in = in; + _groups = groups; + _compSettings = compSettings; + _k = k; + } + + @Override + public Collection call() { + try { + ArrayList res = new ArrayList<>(); + Tmp tmpMap = new Tmp(); + for(CompressedSizeInfoColGroup g : _groups) + res.addAll(compressColGroup(_in, _compSettings, tmpMap, g, _k)); + return res; + } + catch(Exception e) { + e.printStackTrace(); + throw e; + } + } + } + + static class readToMapDDCTask implements Callable { + private final int[] _colIndexes; + private final MatrixBlock _raw; + private final DblArrayCountHashMap _map; + private final CompressionSettings _cs; + private final AMapToData _data; + private final int _rl; + private final int _ru; + private final int _fill; + + protected readToMapDDCTask(int[] colIndexes, MatrixBlock raw, DblArrayCountHashMap map, CompressionSettings cs, + AMapToData data, int rl, int ru, int fill) { + _colIndexes = colIndexes; + _raw = raw; + _map = map; + _cs = cs; + _data = data; + _rl = rl; + _ru = ru; + _fill = fill; + } + + @Override + public Boolean call() { + return 
Boolean.valueOf(readToMapDDC(_colIndexes, _raw, _map, _cs, _data, _rl, _ru, _fill)); + } + } + + /** * Temp reuse object, to contain intermediates for compressing column groups that can be used by the same thread * again for subsequent compressions. diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java index 67b364ac575..78011cf68fb 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java @@ -160,14 +160,18 @@ public AColGroup scalarOperation(ScalarOperator op) { return new ColGroupOLE(_colIndexes, _numRows, false, _dict.applyScalarOp(op), _data, _ptr, getCachedCounts()); } - ADictionary rvalues = _dict.applyScalarOp(op, val0, getNumCols()); - char[] lbitmap = genOffsetBitmap(loff, loff.length); - char[] rbitmaps = Arrays.copyOf(_data, _data.length + lbitmap.length); - System.arraycopy(lbitmap, 0, rbitmaps, _data.length, lbitmap.length); - int[] rbitmapOffs = Arrays.copyOf(_ptr, _ptr.length + 1); - rbitmapOffs[rbitmapOffs.length - 1] = rbitmaps.length; - - return new ColGroupOLE(_colIndexes, _numRows, false, rvalues, rbitmaps, rbitmapOffs, getCachedCounts()); + throw new NotImplementedException( + "Not implemented because dictionaries no longer should support extending by a tuple" + + " Ideally implement a modification such that OLE becomes SDC group when materializing Zero tuples"); + + // ADictionary rvalues = _dict.applyScalarOp(op, val0, getNumCols()); + // char[] lbitmap = genOffsetBitmap(loff, loff.length); + // char[] rbitmaps = Arrays.copyOf(_data, _data.length + lbitmap.length); + // System.arraycopy(lbitmap, 0, rbitmaps, _data.length, lbitmap.length); + // int[] rbitmapOffs = Arrays.copyOf(_ptr, _ptr.length + 1); + // rbitmapOffs[rbitmapOffs.length - 1] = rbitmaps.length; + + // return new ColGroupOLE(_colIndexes, _numRows, false, rvalues, 
rbitmaps, rbitmapOffs, getCachedCounts()); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupPFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupPFOR.java index a39b17f2de1..3a38514becb 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupPFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupPFOR.java @@ -36,7 +36,9 @@ import org.apache.sysds.runtime.functionobjects.Minus; import org.apache.sysds.runtime.functionobjects.Multiply; import org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.CMOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; /** @@ -57,7 +59,7 @@ public class ColGroupPFOR extends AMorphingMMColGroup { protected AOffset _indexes; /** Pointers to row indexes in the dictionary. 
*/ - protected transient AMapToData _data; + protected AMapToData _data; /** Reference values in this column group */ protected double[] _reference; @@ -94,7 +96,7 @@ protected static AColGroup create(int[] colIndices, int numRows, ADictionary dic if(allZero) return new ColGroupEmpty(colIndices); else - return ColGroupFactory.genColGroupConst(colIndices, reference); + return ColGroupConst.create(colIndices, reference); } return new ColGroupPFOR(colIndices, numRows, dict, indexes, data, cachedCounts, reference); } @@ -118,7 +120,7 @@ public ColGroupType getColGroupType() { @Override public int[] getCounts(int[] counts) { - return _data.getCounts(counts, _numRows); + return _data.getCounts(counts); } private final double refSum() { @@ -128,6 +130,13 @@ private final double refSum() { return ret; } + private final double refSumSq() { + double ret = 0; + for(double d : _reference) + ret += d * d; + return ret; + } + @Override protected void computeRowSums(double[] c, int rl, int ru, double[] preAgg) { ColGroupSDC.computeRowSums(c, rl, ru, preAgg, _data, _indexes, _numRows); @@ -161,7 +170,7 @@ else if(op.fn instanceof Multiply || op.fn instanceof Divide) { return create(_colIndexes, _numRows, newDict, _indexes, _data, getCachedCounts(), newRef); } else { - final ADictionary newDict = _dict.applyScalarOp(op, _reference, newRef); + final ADictionary newDict = _dict.applyScalarOpWithReference(op, _reference, newRef); return create(_colIndexes, _numRows, newDict, _indexes, _data, getCachedCounts(), newRef); } } @@ -172,14 +181,15 @@ public AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean isRowSaf for(int i = 0; i < _reference.length; i++) newRef[i] = op.fn.execute(v[_colIndexes[i]], _reference[i]); - if(op.fn instanceof Plus || op.fn instanceof Minus) + if(op.fn instanceof Plus || op.fn instanceof Minus) // only edit reference return create(_colIndexes, _numRows, _dict, _indexes, _data, getCachedCounts(), newRef); else if(op.fn instanceof Multiply || op.fn 
instanceof Divide) { + // possible to simply process on dict and keep reference final ADictionary newDict = _dict.binOpLeft(op, v, _colIndexes); return create(_colIndexes, _numRows, newDict, _indexes, _data, getCachedCounts(), newRef); } - else { - final ADictionary newDict = _dict.binOpLeft(op, v, _colIndexes, _reference, newRef); + else { // have to apply reference while processing + final ADictionary newDict = _dict.binOpLeftWithReference(op, v, _colIndexes, _reference, newRef); return create(_colIndexes, _numRows, newDict, _indexes, _data, getCachedCounts(), newRef); } } @@ -189,14 +199,16 @@ public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, boolean isRowSa final double[] newRef = new double[_reference.length]; for(int i = 0; i < _reference.length; i++) newRef[i] = op.fn.execute(_reference[i], v[_colIndexes[i]]); - if(op.fn instanceof Plus || op.fn instanceof Minus) + + if(op.fn instanceof Plus || op.fn instanceof Minus)// only edit reference return new ColGroupPFOR(_colIndexes, _numRows, _dict, _indexes, _data, getCachedCounts(), newRef); else if(op.fn instanceof Multiply || op.fn instanceof Divide) { + // possible to simply process on dict and keep reference final ADictionary newDict = _dict.binOpRight(op, v, _colIndexes); return new ColGroupPFOR(_colIndexes, _numRows, newDict, _indexes, _data, getCachedCounts(), newRef); } - else { - final ADictionary newDict = _dict.binOpRight(op, v, _colIndexes, _reference, newRef); + else { // have to apply reference while processing + final ADictionary newDict = _dict.binOpRightWithReference(op, v, _colIndexes, _reference, newRef); return new ColGroupPFOR(_colIndexes, _numRows, newDict, _indexes, _data, getCachedCounts(), newRef); } } @@ -229,6 +241,15 @@ public long getExactSizeOnDisk() { return ret; } + @Override + public long estimateInMemorySize() { + long size = super.estimateInMemorySize(); + size += _indexes.getInMemorySize(); + size += _data.getInMemorySize(); + size += 8 * _colIndexes.length; + 
return size; + } + @Override public AColGroup replace(double pattern, double replace) { boolean patternInReference = false; @@ -243,7 +264,7 @@ public AColGroup replace(double pattern, double replace) { // _dict.replace(pattern, replace, _reference, _newReplace); } else { - final ADictionary newDict = _dict.replace(pattern, replace, _reference); + final ADictionary newDict = _dict.replaceWithReference(pattern, replace, _reference); return create(_colIndexes, _numRows, newDict, _indexes, _data, getCachedCounts(), _reference); } @@ -264,46 +285,58 @@ public String toString() { @Override protected double computeMxx(double c, Builtin builtin) { - return _dict.aggregate(c, builtin, _reference); + return _dict.aggregateWithReference(c, builtin, _reference); } @Override protected void computeColMxx(double[] c, Builtin builtin) { - _dict.aggregateCols(c, builtin, _colIndexes, _reference); + _dict.aggregateColsWithReference(c, builtin, _colIndexes, _reference); } @Override protected void computeSum(double[] c, int nRows) { + // trick,use normal sum super.computeSum(c, nRows); + // and add all sum of reference multiplied with nrows. final double refSum = refSum(); c[0] += refSum * nRows; } @Override public void computeColSums(double[] c, int nRows) { + // trick, use the normal sum super.computeColSums(c, nRows); + // and add reference multiplied with number of rows. for(int i = 0; i < _colIndexes.length; i++) c[_colIndexes[i]] += _reference[i] * nRows; } @Override protected void computeSumSq(double[] c, int nRows) { - c[0] += _dict.sumSq(getCounts(), _reference); + // square sum the dictionary. + c[0] += _dict.sumSqWithReference(getCounts(), _reference); + final double refSum = refSumSq(); + // Square sum of the reference values only for the rows that is not represented in the Offsets. 
+ c[0] += refSum * (_numRows - _data.size()); } @Override protected void computeColSumsSq(double[] c, int nRows) { - _dict.colSumSq(c, getCounts(), _colIndexes, _reference); + // square sum the dictionary + _dict.colSumSqWithReference(c, getCounts(), _colIndexes, _reference); + // Square sum of the reference values only for the rows that is not represented in the Offsets. + for(int i = 0; i < _colIndexes.length; i++) // correct for the reference sum. + c[_colIndexes[i]] += _reference[i] * _reference[i] * (_numRows - _data.size()); } @Override protected double[] preAggSumRows() { - return _dict.sumAllRowsToDouble(_reference); + return _dict.sumAllRowsToDoubleWithReference(_reference); } @Override protected double[] preAggSumSqRows() { - return _dict.sumAllRowsToDoubleSq(_reference); + return _dict.sumAllRowsToDoubleSqWithReference(_reference); } @Override @@ -313,7 +346,7 @@ protected double[] preAggProductRows() { @Override protected double[] preAggBuiltinRows(Builtin builtin) { - return _dict.aggregateRows(builtin, _reference); + return _dict.aggregateRowsWithReference(builtin, _reference); } @Override @@ -358,7 +391,7 @@ public boolean containsValue(double pattern) { else if(Double.isNaN(pattern) || Double.isInfinite(pattern)) return containsInfOrNan(pattern) || _dict.containsValue(pattern); else - return _dict.containsValue(pattern, _reference); + return _dict.containsValueWithReference(pattern, _reference); } private boolean containsInfOrNan(double pattern) { @@ -378,8 +411,12 @@ private boolean containsInfOrNan(double pattern) { @Override public long getNumberNonZeros(int nRows) { - int[] counts = getCounts(); - return (long) _dict.getNumberNonZeros(counts, _reference, nRows); + final int[] counts = getCounts(); + final int count = _numRows - _data.size(); + long c = _dict.getNumberNonZerosWithReference(counts, _reference, nRows); + for(int x = 0; x < _colIndexes.length; x++) + c += _reference[x] != 0 ? 
count : 0; + return c; } @Override @@ -388,4 +425,19 @@ public AColGroup extractCommon(double[] constV) { constV[_colIndexes[i]] += _reference[i]; return ColGroupSDCZeros.create(_colIndexes, _numRows, _dict, _indexes, _data, getCounts()); } + + @Override + public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) { + ADictionary d = _dict.rexpandColsWithReference(max, ignore, cast, _reference[0]); + return ColGroupSDC.rexpandCols(max, ignore, cast, nRows, d, _indexes, _data, getCachedCounts(), _reference[0]); + } + + @Override + public CM_COV_Object centralMoment(CMOperator op, int nRows) { + // should be guaranteed to be one column therefore only one reference value. + CM_COV_Object ret = _dict.centralMomentWithReference(op.fn, getCounts(), _reference[0], nRows); + int count = _numRows - _data.size(); + op.fn.execute(ret, _reference[0], count); + return ret; + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java index ed2896813e7..1a6254233b4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java @@ -156,14 +156,18 @@ public AColGroup scalarOperation(ScalarOperator op) { return new ColGroupRLE(_colIndexes, _numRows, false, _dict.applyScalarOp(op), _data, _ptr, getCachedCounts()); } - ADictionary rvalues = _dict.applyScalarOp(op, val0, getNumCols()); - char[] lbitmap = genRLEBitmap(loff, loff.length); - - char[] rbitmaps = Arrays.copyOf(_data, _data.length + lbitmap.length); - System.arraycopy(lbitmap, 0, rbitmaps, _data.length, lbitmap.length); - int[] rbitmapOffs = Arrays.copyOf(_ptr, _ptr.length + 1); - rbitmapOffs[rbitmapOffs.length - 1] = rbitmaps.length; - return new ColGroupRLE(_colIndexes, _numRows, false, rvalues, rbitmaps, rbitmapOffs, getCachedCounts()); + throw new NotImplementedException( + "Not implemented 
because dictionaries no longer should support extending by a tuple" + + " Ideally implement a modification such that RLE becomes SDC group when materializing Zero tuples"); + + // ADictionary rvalues = _dict.applyScalarOp(op, val0, getNumCols()); + // char[] lbitmap = genRLEBitmap(loff, loff.length); + + // char[] rbitmaps = Arrays.copyOf(_data, _data.length + lbitmap.length); + // System.arraycopy(lbitmap, 0, rbitmaps, _data.length, lbitmap.length); + // int[] rbitmapOffs = Arrays.copyOf(_ptr, _ptr.length + 1); + // rbitmapOffs[rbitmapOffs.length - 1] = rbitmaps.length; + // return new ColGroupRLE(_colIndexes, _numRows, false, rvalues, rbitmaps, rbitmapOffs, getCachedCounts()); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java index e825756daf3..539f9e71689 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java @@ -22,21 +22,27 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; +import java.util.Arrays; +import org.apache.commons.lang.NotImplementedException; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; +import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.functionobjects.Builtin; +import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import 
org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.CMOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; /** * Column group that sparsely encodes the dictionary values. The idea is that all values is encoded with indexes except - * the most common one. the most common one can be inferred by not being included in the indexes. If the values are very - * sparse then the most common one is zero. + * the most common one. the most common one can be inferred by not being included in the indexes. * * This column group is handy in cases where sparse unsafe operations is executed on very sparse columns. Then the zeros * would be materialized in the group without any overhead. @@ -45,9 +51,11 @@ public class ColGroupSDC extends AMorphingMMColGroup { private static final long serialVersionUID = 769993538831949086L; /** Sparse row indexes for the data */ - protected transient AOffset _indexes; - /** Pointers to row indexes in the dictionary. Note the dictionary has one extra entry. */ - protected transient AMapToData _data; + protected AOffset _indexes; + /** Pointers to row indexes in the dictionary. 
*/ + protected AMapToData _data; + /** The default value stored in this column group */ + protected double[] _defaultTuple; /** * Constructor for serialization @@ -58,20 +66,28 @@ protected ColGroupSDC(int numRows) { super(numRows); } - private ColGroupSDC(int[] colIndices, int numRows, ADictionary dict, AOffset offsets, AMapToData data, - int[] cachedCounts) { + private ColGroupSDC(int[] colIndices, int numRows, ADictionary dict, double[] defaultTuple, AOffset offsets, + AMapToData data, int[] cachedCounts) { super(colIndices, numRows, dict, cachedCounts); _indexes = offsets; _data = data; _zeros = false; + _defaultTuple = defaultTuple; } - protected static AColGroup create(int[] colIndices, int numRows, ADictionary dict, AOffset offsets, AMapToData data, - int[] cachedCounts) { + protected static AColGroup create(int[] colIndices, int numRows, ADictionary dict, double[] defaultTuple, + AOffset offsets, AMapToData data, int[] cachedCounts) { if(dict == null) return new ColGroupEmpty(colIndices); - else - return new ColGroupSDC(colIndices, numRows, dict, offsets, data, cachedCounts); + else { + boolean allZero = true; + for(double d : defaultTuple) + allZero &= d == 0; + if(allZero) + return ColGroupSDCZeros.create(colIndices, numRows, dict, offsets, data, cachedCounts); + else + return new ColGroupSDC(colIndices, numRows, dict, defaultTuple, offsets, data, cachedCounts); + } } @Override @@ -87,9 +103,55 @@ public ColGroupType getColGroupType() { @Override public double getIdx(int r, int colIdx) { final AIterator it = _indexes.getIterator(r); - final int rowOff = it == null || it.value() != r ? 
getNumValues() - 1 : _data.getIndex(it.getDataIndex()); - final int nCol = _colIndexes.length; - return _dict.getValue(rowOff * nCol + colIdx); + if(it == null || it.value() != r) + return _defaultTuple[colIdx]; + + else { + final int rowOff = _data.getIndex(it.getDataIndex()); + final int nCol = _colIndexes.length; + return _dict.getValue(rowOff * nCol + colIdx); + } + } + + @Override + public ADictionary getDictionary() { + throw new NotImplementedException( + "Not implemented getting the dictionary out, and i think we should consider removing the option"); + } + + @Override + protected double[] preAggSumRows() { + return _dict.sumAllRowsToDoubleWithDefault(_defaultTuple); + } + + @Override + protected double[] preAggSumSqRows() { + return _dict.sumAllRowsToDoubleSqWithDefault(_defaultTuple); + } + + @Override + protected double[] preAggProductRows() { + throw new NotImplementedException("Should implement preAgg with extra cell"); + } + + @Override + protected double[] preAggBuiltinRows(Builtin builtin) { + return _dict.aggregateRowsWithDefault(builtin, _defaultTuple); + } + + @Override + protected double computeMxx(double c, Builtin builtin) { + double ret = _dict.aggregate(c, builtin); + for(int i = 0; i < _defaultTuple.length; i++) + ret = builtin.execute(ret, _defaultTuple[i]); + return ret; + } + + @Override + protected void computeColMxx(double[] c, Builtin builtin) { + _dict.aggregateCols(c, builtin, _colIndexes); + for(int x = 0; x < _colIndexes.length; x++) + c[_colIndexes[x]] = builtin.execute(c[_colIndexes[x]], _defaultTuple[x]); } @Override @@ -123,7 +185,7 @@ else if(it != null && ru >= indexes.getOffsetToLast()) { } else if(it != null) { while(r < ru) { - if(it.value() == r){ + if(it.value() == r) { c[r] += preAgg[data.getIndex(it.getDataIndex())]; it.next(); } @@ -145,7 +207,7 @@ protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, double computeRowMxx(c, builtin, rl, ru, preAgg, _data, _indexes, _numRows, preAgg[preAgg.length 
- 1]); } - protected static final void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, double[] vals, + protected static final void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, double[] preAgg, AMapToData data, AOffset indexes, int nRows, double def) { int r = rl; final AIterator it = indexes.getIterator(rl); @@ -155,7 +217,7 @@ else if(it != null && ru >= indexes.getOffsetToLast()) { final int maxId = data.size() - 1; while(true) { if(it.value() == r) { - c[r] = builtin.execute(c[r], vals[data.getIndex(it.getDataIndex())]); + c[r] = builtin.execute(c[r], preAgg[data.getIndex(it.getDataIndex())]); if(it.getDataIndex() < maxId) it.next(); else { @@ -170,8 +232,8 @@ else if(it != null && ru >= indexes.getOffsetToLast()) { } else if(it != null) { while(r < ru) { - if(it.value() == r){ - c[r] = builtin.execute(c[r], vals[data.getIndex(it.getDataIndex())]); + if(it.value() == r) { + c[r] = builtin.execute(c[r], preAgg[data.getIndex(it.getDataIndex())]); it.next(); } else @@ -187,9 +249,69 @@ else if(it != null) { } } + @Override + protected void computeSum(double[] c, int nRows) { + super.computeSum(c, nRows); + int count = _numRows - _data.size(); + for(int x = 0; x < _defaultTuple.length; x++) + c[0] += _defaultTuple[x] * count; + } + + @Override + public void computeColSums(double[] c, int nRows) { + super.computeColSums(c, nRows); + int count = _numRows - _data.size(); + for(int x = 0; x < _colIndexes.length; x++) + c[_colIndexes[x]] += _defaultTuple[x] * count; + } + + @Override + protected void computeSumSq(double[] c, int nRows) { + super.computeSumSq(c, nRows); + int count = _numRows - _data.size(); + for(int x = 0; x < _colIndexes.length; x++) + c[0] += _defaultTuple[x] * _defaultTuple[x] * count; + } + + @Override + protected void computeColSumsSq(double[] c, int nRows) { + super.computeColSumsSq(c, nRows); + int count = _numRows - _data.size(); + for(int x = 0; x < _colIndexes.length; x++) + c[_colIndexes[x]] += _defaultTuple[x] * 
_defaultTuple[x] * count; + } + + @Override + protected void computeProduct(double[] c, int nRows) { + super.computeProduct(c, nRows); + for(int x = 0; x < _colIndexes.length; x++) + c[0] *= _defaultTuple[x]; + } + + @Override + protected void computeColProduct(double[] c, int nRows) { + super.computeColProduct(c, nRows); + for(int x = 0; x < _colIndexes.length; x++) + c[_colIndexes[x]] *= _defaultTuple[x]; + } + + @Override + protected void computeRowProduct(double[] c, int rl, int ru, double[] preAgg) { + throw new NotImplementedException(); + } + @Override public int[] getCounts(int[] counts) { - return _data.getCounts(counts, _numRows); + return _data.getCounts(counts); + } + + @Override + public long getNumberNonZeros(int nRows) { + long c = super.getNumberNonZeros(nRows); + int count = _numRows - _data.size(); + for(int x = 0; x < _colIndexes.length; x++) + c += _defaultTuple[x] != 0 ? count : 0; + return c; } @Override @@ -197,24 +319,36 @@ public long estimateInMemorySize() { long size = super.estimateInMemorySize(); size += _indexes.getInMemorySize(); size += _data.getInMemorySize(); + size += 8 * _colIndexes.length; return size; } @Override public AColGroup scalarOperation(ScalarOperator op) { - return create(_colIndexes, _numRows, _dict.applyScalarOp(op), _indexes, _data, getCachedCounts()); + final double[] newDefaultTuple = new double[_defaultTuple.length]; + for(int i = 0; i < _defaultTuple.length; i++) + newDefaultTuple[i] = op.executeScalar(_defaultTuple[i]); + + return create(_colIndexes, _numRows, _dict.applyScalarOp(op), newDefaultTuple, _indexes, _data, + getCachedCounts()); } @Override public AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean isRowSafe) { - ADictionary ret = _dict.binOpLeft(op, v, _colIndexes); - return create(_colIndexes, _numRows, ret, _indexes, _data, getCachedCounts()); + final double[] newDefaultTuple = new double[_defaultTuple.length]; + for(int i = 0; i < _defaultTuple.length; i++) + newDefaultTuple[i] = 
op.fn.execute(v[_colIndexes[i]], _defaultTuple[i]); + final ADictionary newDict = _dict.binOpLeft(op, v, _colIndexes); + return create(_colIndexes, _numRows, newDict, newDefaultTuple, _indexes, _data, getCachedCounts()); } @Override public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, boolean isRowSafe) { - ADictionary ret = _dict.binOpRight(op, v, _colIndexes); - return create(_colIndexes, _numRows, ret, _indexes, _data, getCachedCounts()); + final double[] newDefaultTuple = new double[_defaultTuple.length]; + for(int i = 0; i < _defaultTuple.length; i++) + newDefaultTuple[i] = op.fn.execute(_defaultTuple[i], v[_colIndexes[i]]); + final ADictionary newDict = _dict.binOpRight(op, v, _colIndexes); + return create(_colIndexes, _numRows, newDict, newDefaultTuple, _indexes, _data, getCachedCounts()); } @Override @@ -222,6 +356,8 @@ public void write(DataOutput out) throws IOException { super.write(out); _indexes.write(out); _data.write(out); + for(double d : _defaultTuple) + out.writeDouble(d); } @Override @@ -229,6 +365,9 @@ public void readFields(DataInput in) throws IOException { super.readFields(in); _indexes = OffsetFactory.readIn(in); _data = MapToFactory.readIn(in); + _defaultTuple = new double[_colIndexes.length]; + for(int i = 0; i < _colIndexes.length; i++) + _defaultTuple[i] = in.readDouble(); } @Override @@ -236,26 +375,82 @@ public long getExactSizeOnDisk() { long ret = super.getExactSizeOnDisk(); ret += _data.getExactSizeOnDisk(); ret += _indexes.getExactSizeOnDisk(); + ret += 8 * _colIndexes.length; // _default tuple values. return ret; } @Override - public AColGroup extractCommon(double[] constV) { - double[] commonV = _dict.getTuple(getNumValues() - 1, _colIndexes.length); - if(commonV == null) // The common tuple was all zero. Therefore this column group should never have been SDC. 
- return ColGroupSDCZeros.create(_colIndexes, _numRows, _dict, _indexes, _data, getCounts()); + public AColGroup replace(double pattern, double replace) { + ADictionary replaced = _dict.replace(pattern, replace, _colIndexes.length); + double[] newDefaultTuple = new double[_defaultTuple.length]; + for(int i = 0; i < _defaultTuple.length; i++) + if(_defaultTuple[i] == pattern) + newDefaultTuple[i] = replace; + else + newDefaultTuple[i] = _defaultTuple[i]; + + return create(_colIndexes, _numRows, replaced, newDefaultTuple, _indexes, _data, getCachedCounts()); + } + @Override + public AColGroup extractCommon(double[] constV) { for(int i = 0; i < _colIndexes.length; i++) - constV[_colIndexes[i]] += commonV[i]; + constV[_colIndexes[i]] += _defaultTuple[i]; - ADictionary subtractedDict = _dict.subtractTuple(commonV); + ADictionary subtractedDict = _dict.subtractTuple(_defaultTuple); return ColGroupSDCZeros.create(_colIndexes, _numRows, subtractedDict, _indexes, _data, getCounts()); } + @Override + public CM_COV_Object centralMoment(CMOperator op, int nRows) { + CM_COV_Object ret = super.centralMoment(op, nRows); + int count = _numRows - _data.size(); + op.fn.execute(ret, _defaultTuple[0], count); + return ret; + } + + @Override + public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) { + ADictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.length); + return rexpandCols(max, ignore, cast, nRows, d, _indexes, _data, getCachedCounts(), _defaultTuple[0]); + } + + protected static AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows, ADictionary d, + AOffset indexes, AMapToData data, int[] counts, double def) { + // final double def = _defaultTuple[0]; + if(d == null) { + if(def <= 0 || def > max) + return ColGroupEmpty.create(max); + else { + double[] retDef = new double[max]; + retDef[((int) def) - 1] = 1; + return new ColGroupSDCSingle(Util.genColsIndices(max), nRows, new Dictionary(new double[max]), retDef, + indexes,
null); + } + } + else { + if(def <= 0) { + if(ignore) + return ColGroupSDCZeros.create(Util.genColsIndices(max), nRows, d, indexes, data, counts); + else + throw new DMLRuntimeException("Invalid content of zero in rexpand"); + } + else if(def > max) + return ColGroupSDCZeros.create(Util.genColsIndices(max), nRows, d, indexes, data, counts); + else { + double[] retDef = new double[max]; + retDef[((int) def) - 1] = 1; + return ColGroupSDC.create(Util.genColsIndices(max), nRows, d, retDef, indexes, data, counts); + } + } + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); sb.append(super.toString()); + sb.append(String.format("\n%15s", "Default: ")); + sb.append(Arrays.toString(_defaultTuple)); sb.append(String.format("\n%15s", "Indexes: ")); sb.append(_indexes.toString()); sb.append(String.format("\n%15s", "Data: ")); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java index a00841dfefd..ee31b71dfe2 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java @@ -23,12 +23,18 @@ import java.io.DataOutput; import java.io.IOException; +import org.apache.commons.lang.NotImplementedException; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; +import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.functionobjects.Builtin; +import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import 
org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.CMOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; /** @@ -42,7 +48,9 @@ public class ColGroupSDCSingle extends AMorphingMMColGroup { private static final long serialVersionUID = 3883228464052204200L; /** Sparse row indexes for the data */ - protected transient AOffset _indexes; + protected AOffset _indexes; + /** The default value stored in this column group */ + protected double[] _defaultTuple; /** * Constructor for serialization @@ -53,10 +61,12 @@ protected ColGroupSDCSingle(int numRows) { super(numRows); } - protected ColGroupSDCSingle(int[] colIndices, int numRows, ADictionary dict, AOffset offsets, int[] cachedCounts) { + protected ColGroupSDCSingle(int[] colIndices, int numRows, ADictionary dict, double[] defaultTuple, AOffset offsets, + int[] cachedCounts) { super(colIndices, numRows, dict, cachedCounts); _indexes = offsets; _zeros = false; + _defaultTuple = defaultTuple; } @Override @@ -73,20 +83,66 @@ public ColGroupType getColGroupType() { public double getIdx(int r, int colIdx) { final AIterator it = _indexes.getIterator(r); if(it == null || it.value() != r) - return _dict.getValue(_colIndexes.length + colIdx); - return _dict.getValue(colIdx); + return _defaultTuple[colIdx]; + else + return _dict.getValue(colIdx); } @Override - protected void computeRowSums(double[] c, int rl, int ru, double[] vals) { + public ADictionary getDictionary() { + throw new NotImplementedException( + "Not implemented getting the dictionary out, and i think we should consider removing the option"); + } + + @Override + protected double[] preAggSumRows() { + return _dict.sumAllRowsToDoubleWithDefault(_defaultTuple); + } + + @Override + protected double[] preAggSumSqRows() { + return _dict.sumAllRowsToDoubleSqWithDefault(_defaultTuple); + } + + @Override + protected double[] preAggProductRows() { + throw new NotImplementedException("Should 
implement preAgg with extra cell"); + } + + @Override + protected double[] preAggBuiltinRows(Builtin builtin) { + return _dict.aggregateRowsWithDefault(builtin, _defaultTuple); + } + + @Override + protected double computeMxx(double c, Builtin builtin) { + double ret = _dict.aggregate(c, builtin); + for(int i = 0; i < _defaultTuple.length; i++) + ret = builtin.execute(ret, _defaultTuple[i]); + return ret; + } + + @Override + protected void computeColMxx(double[] c, Builtin builtin) { + _dict.aggregateCols(c, builtin, _colIndexes); + for(int x = 0; x < _colIndexes.length; x++) + c[_colIndexes[x]] = builtin.execute(c[_colIndexes[x]], _defaultTuple[x]); + } + + @Override + protected void computeRowSums(double[] c, int rl, int ru, double[] preAgg) { + computeRowSums(c, rl, ru, preAgg, _indexes, _numRows); + } + + protected static final void computeRowSums(double[] c, int rl, int ru, double[] preAgg, AOffset indexes, int nRows) { int r = rl; - final AIterator it = _indexes.getIterator(rl); - final double def = vals[1]; - final double norm = vals[0]; + final AIterator it = indexes.getIterator(rl); + final double def = preAgg[1]; + final double norm = preAgg[0]; if(it != null && it.value() > ru) - _indexes.cacheIterator(it, ru); - else if(it != null && ru >= _indexes.getOffsetToLast()) { - final int maxOff = _indexes.getOffsetToLast(); + indexes.cacheIterator(it, ru); + else if(it != null && ru >= indexes.getOffsetToLast()) { + final int maxOff = indexes.getOffsetToLast(); while(true) { if(it.value() == r) { c[r] += norm; @@ -110,7 +166,7 @@ else if(it != null) { c[r] += def; r++; } - _indexes.cacheIterator(it, ru); + indexes.cacheIterator(it, ru); } while(r < ru) { @@ -166,10 +222,60 @@ else if(it != null) { } } + @Override + protected void computeSum(double[] c, int nRows) { + super.computeSum(c, nRows); + int count = _numRows - getCounts()[0]; + for(int x = 0; x < _defaultTuple.length; x++) + c[0] += _defaultTuple[x] * count; + } + + @Override + public void 
computeColSums(double[] c, int nRows) { + super.computeColSums(c, nRows); + int count = _numRows - getCounts()[0]; + for(int x = 0; x < _colIndexes.length; x++) + c[_colIndexes[x]] += _defaultTuple[x] * count; + } + + @Override + protected void computeSumSq(double[] c, int nRows) { + super.computeSumSq(c, nRows); + int count = _numRows - getCounts()[0]; + for(int x = 0; x < _colIndexes.length; x++) + c[0] += _defaultTuple[x] * _defaultTuple[x] * count; + } + + @Override + protected void computeColSumsSq(double[] c, int nRows) { + super.computeColSumsSq(c, nRows); + int count = _numRows - getCounts()[0]; + for(int x = 0; x < _colIndexes.length; x++) + c[_colIndexes[x]] += _defaultTuple[x] * _defaultTuple[x] * count; + } + + @Override + protected void computeProduct(double[] c, int nRows) { + super.computeProduct(c, nRows); + for(int x = 0; x < _colIndexes.length; x++) + c[0] *= _defaultTuple[x]; + } + + @Override + protected void computeColProduct(double[] c, int nRows) { + super.computeColProduct(c, nRows); + for(int x = 0; x < _colIndexes.length; x++) + c[_colIndexes[x]] *= _defaultTuple[x]; + } + + @Override + protected void computeRowProduct(double[] c, int rl, int ru, double[] preAgg) { + throw new NotImplementedException(); + } + @Override public int[] getCounts(int[] counts) { counts[0] = _indexes.getSize(); - counts[1] = _numRows - counts[0]; return counts; } @@ -177,59 +283,117 @@ public int[] getCounts(int[] counts) { public long estimateInMemorySize() { long size = super.estimateInMemorySize(); size += _indexes.getInMemorySize(); + size += 8 * _colIndexes.length; return size; } @Override public AColGroup scalarOperation(ScalarOperator op) { - return new ColGroupSDCSingle(_colIndexes, _numRows, _dict.applyScalarOp(op), _indexes, getCachedCounts()); + final double[] newDefaultTuple = new double[_defaultTuple.length]; + for(int i = 0; i < _defaultTuple.length; i++) + newDefaultTuple[i] = op.executeScalar(_defaultTuple[i]); + return new 
ColGroupSDCSingle(_colIndexes, _numRows, _dict.applyScalarOp(op), newDefaultTuple, _indexes, + getCachedCounts()); } @Override public AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean isRowSafe) { - ADictionary ret = _dict.binOpLeft(op, v, _colIndexes); - return new ColGroupSDCSingle(_colIndexes, _numRows, ret, _indexes, getCachedCounts()); + final double[] newDefaultTuple = new double[_defaultTuple.length]; + for(int i = 0; i < _defaultTuple.length; i++) + newDefaultTuple[i] = op.fn.execute(v[_colIndexes[i]], _defaultTuple[i]); + final ADictionary newDict = _dict.binOpLeft(op, v, _colIndexes); + return new ColGroupSDCSingle(_colIndexes, _numRows, newDict, newDefaultTuple, _indexes, getCachedCounts()); } @Override public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, boolean isRowSafe) { - ADictionary ret = _dict.binOpRight(op, v, _colIndexes); - return new ColGroupSDCSingle(_colIndexes, _numRows, ret, _indexes, getCachedCounts()); + final double[] newDefaultTuple = new double[_defaultTuple.length]; + for(int i = 0; i < _defaultTuple.length; i++) + newDefaultTuple[i] = op.fn.execute(_defaultTuple[i], v[_colIndexes[i]]); + final ADictionary newDict = _dict.binOpRight(op, v, _colIndexes); + return new ColGroupSDCSingle(_colIndexes, _numRows, newDict, newDefaultTuple, _indexes, getCachedCounts()); } @Override public void write(DataOutput out) throws IOException { super.write(out); _indexes.write(out); + for(double d : _defaultTuple) + out.writeDouble(d); + } @Override public void readFields(DataInput in) throws IOException { super.readFields(in); _indexes = OffsetFactory.readIn(in); + _defaultTuple = new double[_colIndexes.length]; + for(int i = 0; i < _colIndexes.length; i++) + _defaultTuple[i] = in.readDouble(); } @Override public long getExactSizeOnDisk() { long ret = super.getExactSizeOnDisk(); ret += _indexes.getExactSizeOnDisk(); + ret += 8 * _colIndexes.length; // _default tuple values. 
return ret; } @Override public ColGroupSDCSingleZeros extractCommon(double[] constV) { - double[] commonV = _dict.getTuple(getNumValues() - 1, _colIndexes.length); - - if(commonV == null) // The common tuple was all zero. Therefore this column group should never have been SDC. - return new ColGroupSDCSingleZeros(_colIndexes, _numRows, _dict, _indexes, getCachedCounts()); - for(int i = 0; i < _colIndexes.length; i++) - constV[_colIndexes[i]] += commonV[i]; + constV[_colIndexes[i]] += _defaultTuple[i]; - ADictionary subtractedDict = _dict.subtractTuple(commonV); + ADictionary subtractedDict = _dict.subtractTuple(_defaultTuple); return new ColGroupSDCSingleZeros(_colIndexes, _numRows, subtractedDict, _indexes, getCachedCounts()); } + @Override + public long getNumberNonZeros(int nRows) { + long nnz = super.getNumberNonZeros(nRows); + nnz += _numRows - getCounts()[0]; + return nnz; + } + + @Override + public CM_COV_Object centralMoment(CMOperator op, int nRows) { + CM_COV_Object ret = super.centralMoment(op, nRows); + int count = _numRows - getCounts()[0]; + op.fn.execute(ret, _defaultTuple[0], count); + return ret; + } + + @Override + public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) { + ADictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.length); + final double def = _defaultTuple[0]; + if(d == null) { + if(def <= 0 || def > max) + return ColGroupEmpty.create(max); + else { + double[] retDef = new double[max]; + retDef[((int) _defaultTuple[0]) - 1] = 1; + return new ColGroupSDCSingle(Util.genColsIndices(max), nRows, new Dictionary(new double[max]), retDef, _indexes, null); + } + } + else { + if(def <= 0) { + if(ignore) + return new ColGroupSDCSingleZeros(Util.genColsIndices(max), nRows, d, _indexes, getCachedCounts()); + else + throw new DMLRuntimeException("Invalid content of zero in rexpand"); + } + else if(def > max) + return new ColGroupSDCSingleZeros(Util.genColsIndices(max), nRows, d, _indexes, getCachedCounts()); + 
else { + double[] retDef = new double[max]; + retDef[((int) _defaultTuple[0]) - 1] = 1; + return new ColGroupSDCSingle(Util.genColsIndices(max), nRows, d, retDef, _indexes, getCachedCounts()); + } + } + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java index dd419e540c9..82e200129a0 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java @@ -46,9 +46,9 @@ */ public class ColGroupSDCSingleZeros extends APreAgg { private static final long serialVersionUID = 8033235615964315078L; - + /** Sparse row indexes for the data */ - protected transient AOffset _indexes; + protected AOffset _indexes; /** * Constructor for serialization @@ -249,7 +249,6 @@ protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, double @Override public int[] getCounts(int[] counts) { counts[0] = _indexes.getSize(); - counts[1] = _numRows - counts[0]; return counts; } @@ -275,7 +274,7 @@ else if(cu < _indexes.getOffsetToLast() + 1) { while(it.value() < cu) { final int start = it.value() + nCol * rl; final int end = it.value() + nCol * ru; - for(int offOut = 0, off = start; off < end; offOut ++, off += nCol) + for(int offOut = 0, off = start; off < end; offOut++, off += nCol) preAV[offOut] += vals[off]; it.next(); } @@ -285,14 +284,14 @@ else if(cu < _indexes.getOffsetToLast() + 1) { int of = it.value(); int start = of + nCol * rl; int end = of + nCol * ru; - for(int offOut = 0, off = start; off < end; offOut ++, off += nCol) + for(int offOut = 0, off = start; off < end; offOut++, off += nCol) preAV[offOut] += vals[off]; while(of < _indexes.getOffsetToLast()) { it.next(); of = it.value(); start = of + nCol * rl; end = of + nCol * ru; - 
for(int offOut = 0, off = start; off < end; offOut ++, off += nCol) + for(int offOut = 0, off = start; off < end; offOut++, off += nCol) preAV[offOut] += vals[off]; } } @@ -351,9 +350,11 @@ public AColGroup scalarOperation(ScalarOperator op) { if(isSparseSafeOp) return new ColGroupSDCSingleZeros(_colIndexes, _numRows, _dict.applyScalarOp(op), _indexes, getCachedCounts()); else { - ADictionary aDictionary = _dict.applyScalarOp(op, val0, getNumCols());// swapEntries(); - // ADictionary aDictionary = applyScalarOp(op, val0, getNumCols()); - return new ColGroupSDCSingle(_colIndexes, _numRows, aDictionary, _indexes, null); + ADictionary aDictionary = _dict.applyScalarOp(op); + double[] defaultTuple = new double[_colIndexes.length]; + for(int i = 0; i < _colIndexes.length; i++) + defaultTuple[i] = val0; + return new ColGroupSDCSingle(_colIndexes, _numRows, aDictionary, defaultTuple, _indexes, null); } } @@ -364,8 +365,11 @@ public AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean isRowSaf return new ColGroupSDCSingleZeros(_colIndexes, _numRows, ret, _indexes, getCachedCounts()); } else { - ADictionary ret = _dict.applyBinaryRowOpLeftAppendNewEntry(op, v, _colIndexes); - return new ColGroupSDCSingle(_colIndexes, _numRows, ret, _indexes, getCachedCounts()); + ADictionary newDict = _dict.binOpLeft(op, v, _colIndexes); + double[] defaultTuple = new double[_colIndexes.length]; + for(int i = 0; i < _colIndexes.length; i++) + defaultTuple[i] = op.fn.execute(v[_colIndexes[i]], 0); + return new ColGroupSDCSingle(_colIndexes, _numRows, newDict, defaultTuple, _indexes, getCachedCounts()); } } @@ -376,8 +380,11 @@ public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, boolean isRowSa return new ColGroupSDCSingleZeros(_colIndexes, _numRows, ret, _indexes, getCachedCounts()); } else { - ADictionary ret = _dict.applyBinaryRowOpRightAppendNewEntry(op, v, _colIndexes); - return new ColGroupSDCSingle(_colIndexes, _numRows, ret, _indexes, getCachedCounts()); + 
ADictionary newDict = _dict.binOpRight(op, v, _colIndexes); + double[] defaultTuple = new double[_colIndexes.length]; + for(int i = 0; i < _colIndexes.length; i++) + defaultTuple[i] = op.fn.execute(0, v[_colIndexes[i]]); + return new ColGroupSDCSingle(_colIndexes, _numRows, newDict, defaultTuple, _indexes, getCachedCounts()); } } @@ -493,28 +500,26 @@ else if(itThat.value() < itThis.value()) { else itThis.next(); } - } } @Override - public int getPreAggregateSize(){ + public int getPreAggregateSize() { return 1; } @Override public AColGroup replace(double pattern, double replace) { - if(pattern == 0) - return replaceZero(replace); ADictionary replaced = _dict.replace(pattern, replace, _colIndexes.length); + if(pattern == 0) { + double[] defaultTuple = new double[_colIndexes.length]; + for(int i = 0; i < _colIndexes.length; i++) + defaultTuple[i] = replace; + return new ColGroupSDCSingle(_colIndexes, _numRows, replaced, defaultTuple, _indexes, getCachedCounts()); + } return copyAndSet(replaced); } - private AColGroup replaceZero(double replace) { - ADictionary replaced = _dict.replaceZeroAndExtend(replace, _colIndexes.length); - return new ColGroupSDCSingle(_colIndexes, _numRows, replaced, _indexes, getCachedCounts()); - } - @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java index f1dcf2ed842..abb2c888ed8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java @@ -51,10 +51,10 @@ public class ColGroupSDCZeros extends APreAgg { private static final long serialVersionUID = -3703199743391937991L; /** Sparse row indexes for the data */ - protected transient AOffset _indexes; + protected AOffset _indexes; /** Pointers to row indexes in the dictionary. 
Note the dictionary has one extra entry. */ - protected transient AMapToData _data; + protected AMapToData _data; /** * Constructor for serialization @@ -393,7 +393,7 @@ protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, double @Override public int[] getCounts(int[] counts) { - return _data.getCounts(counts, _numRows); + return _data.getCounts(counts); } @Override @@ -427,25 +427,37 @@ public AColGroup scalarOperation(ScalarOperator op) { boolean isSparseSafeOp = op.sparseSafe || val0 == 0; if(isSparseSafeOp) return create(_colIndexes, _numRows, _dict.applyScalarOp(op), _indexes, _data, getCachedCounts()); + else if(op.fn instanceof Plus) { + double[] reference = new double[_colIndexes.length]; + for(int i = 0; i < _colIndexes.length; i++) + reference[i] = val0; + return ColGroupPFOR.create(_colIndexes, _numRows, _dict, _indexes, _data, getCachedCounts(), reference); + } else { - ADictionary rValues = _dict.applyScalarOp(op, val0, getNumCols()); - return ColGroupSDC.create(_colIndexes, _numRows, rValues, _indexes, _data, getCachedCounts()); + ADictionary newDict = _dict.applyScalarOp(op); + double[] defaultTuple = new double[_colIndexes.length]; + for(int i = 0; i < _colIndexes.length; i++) + defaultTuple[i] = val0; + return ColGroupSDC.create(_colIndexes, _numRows, newDict, defaultTuple, _indexes, _data, getCachedCounts()); } } @Override public AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean isRowSafe) { if(isRowSafe) { - ADictionary ret = _dict.binOpLeft(op, v, _colIndexes); - return create(_colIndexes, _numRows, ret, _indexes, _data, getCachedCounts()); + ADictionary newDict = _dict.binOpLeft(op, v, _colIndexes); + return create(_colIndexes, _numRows, newDict, _indexes, _data, getCachedCounts()); } else if(op.fn instanceof Plus) { - double[] def = ColGroupUtils.binaryDefRowLeft(op, v, _colIndexes); - return ColGroupPFOR.create(_colIndexes, _numRows, _dict, _indexes, _data, getCachedCounts(), def); + double[] reference = 
ColGroupUtils.binaryDefRowLeft(op, v, _colIndexes); + return ColGroupPFOR.create(_colIndexes, _numRows, _dict, _indexes, _data, getCachedCounts(), reference); } else { - ADictionary ret = _dict.applyBinaryRowOpLeftAppendNewEntry(op, v, _colIndexes); - return ColGroupSDC.create(_colIndexes, _numRows, ret, _indexes, _data, getCachedCounts()); + ADictionary newDict = _dict.binOpLeft(op, v, _colIndexes); + double[] defaultTuple = new double[_colIndexes.length]; + for(int i = 0; i < _colIndexes.length; i++) + defaultTuple[i] = op.fn.execute(v[_colIndexes[i]], 0); + return ColGroupSDC.create(_colIndexes, _numRows, newDict, defaultTuple, _indexes, _data, getCachedCounts()); } } @@ -460,8 +472,11 @@ else if(op.fn instanceof Plus) { return ColGroupPFOR.create(_colIndexes, _numRows, _dict, _indexes, _data, getCachedCounts(), def); } else { - ADictionary ret = _dict.applyBinaryRowOpRightAppendNewEntry(op, v, _colIndexes); - return ColGroupSDC.create(_colIndexes, _numRows, ret, _indexes, _data, getCachedCounts()); + ADictionary newDict = _dict.binOpRight(op, v, _colIndexes); + double[] defaultTuple = new double[_colIndexes.length]; + for(int i = 0; i < _colIndexes.length; i++) + defaultTuple[i] = op.fn.execute(0, v[_colIndexes[i]]); + return ColGroupSDC.create(_colIndexes, _numRows, newDict, defaultTuple, _indexes, _data, getCachedCounts()); } } @@ -591,15 +606,15 @@ else if(itThat.value() < itThis.value()) { @Override public AColGroup replace(double pattern, double replace) { - if(pattern == 0) - return replaceZero(replace); ADictionary replaced = _dict.replace(pattern, replace, _colIndexes.length); - return copyAndSet(replaced); - } - - private AColGroup replaceZero(double replace) { - ADictionary replaced = _dict.replaceZeroAndExtend(replace, _colIndexes.length); - return ColGroupSDC.create(_colIndexes, _numRows, replaced, _indexes, _data, getCachedCounts()); + if(pattern == 0) { + double[] defaultTuple = new double[_colIndexes.length]; + for(int i = 0; i < 
_colIndexes.length; i++) + defaultTuple[i] = replace; + return ColGroupSDC.create(_colIndexes, _numRows, replaced, defaultTuple, _indexes, _data, getCachedCounts()); + } + else + return copyAndSet(replaced); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java index ad5bdbe7b9a..4452a9d0dc1 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java @@ -32,14 +32,15 @@ import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.ReduceAll; import org.apache.sysds.runtime.functionobjects.ReduceRow; +import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.matrix.data.LibMatrixMult; import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.CMOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; -import org.apache.sysds.runtime.util.SortUtils; /** * Column group type for columns that are stored as dense arrays of doubles. Uses a MatrixBlock internally to store the @@ -58,65 +59,71 @@ protected ColGroupUncompressed() { super(); } + private ColGroupUncompressed(MatrixBlock mb, int[] colIndexes) { + super(colIndexes); + _data = mb; + } + + protected static AColGroup create(MatrixBlock mb, int[] colIndexes) { + if(mb.isEmpty()) + return new ColGroupEmpty(colIndexes); + else + return new ColGroupUncompressed(mb, colIndexes); + } + /** * Main constructor for Uncompressed ColGroup. 
* - * @param colIndicesList Indices (relative to the current block) of the columns that this column group represents. - * @param rawBlock The uncompressed block; uncompressed data must be present at the time that the constructor - * is called - * @param transposed Says if the input matrix raw block have been transposed. + * @param colIndexes Indices (relative to the current block) of the columns that this column group represents. + * @param rawBlock The uncompressed block; uncompressed data must be present at the time that the constructor is + * called + * @param transposed Says if the input matrix raw block have been transposed. + * @return AColGroup. */ - public ColGroupUncompressed(int[] colIndicesList, MatrixBlock rawBlock, boolean transposed) { - super(colIndicesList); + public static AColGroup create(int[] colIndexes, MatrixBlock rawBlock, boolean transposed) { + + // special cases + if(rawBlock.isEmptyBlock(false)) // empty input + return new ColGroupEmpty(colIndexes); + else if(!transposed && colIndexes.length == rawBlock.getNumColumns()) + // full input to uncompressedColumnGroup + return new ColGroupUncompressed(rawBlock, colIndexes); + + MatrixBlock mb; final int _numRows = transposed ? 
rawBlock.getNumColumns() : rawBlock.getNumRows(); - if(colIndicesList.length == 1) { - final int col = colIndicesList[0]; + + if(colIndexes.length == 1) { + final int col = colIndexes[0]; if(transposed) { - _data = rawBlock.slice(col, col, 0, rawBlock.getNumColumns() - 1); - _data = LibMatrixReorg.transposeInPlace(_data, InfrastructureAnalyzer.getLocalParallelism()); + mb = rawBlock.slice(col, col, 0, rawBlock.getNumColumns() - 1); + mb = LibMatrixReorg.transposeInPlace(mb, InfrastructureAnalyzer.getLocalParallelism()); } else - _data = rawBlock.slice(0, rawBlock.getNumRows() - 1, col, col); + mb = rawBlock.slice(0, rawBlock.getNumRows() - 1, col, col); - return; - } - - if(rawBlock.isInSparseFormat() && transposed) { - _data = new MatrixBlock(); - _data.setNumRows(_numRows); - _data.setNumColumns(colIndicesList.length); + return create(mb, colIndexes); } // Create a matrix with just the requested rows of the original block - _data = new MatrixBlock(_numRows, _colIndexes.length, rawBlock.isInSparseFormat()); + mb = new MatrixBlock(_numRows, colIndexes.length, rawBlock.isInSparseFormat()); - // ensure sorted col indices - if(!SortUtils.isSorted(0, _colIndexes.length, _colIndexes)) - Arrays.sort(_colIndexes); + final int m = _numRows; + final int n = colIndexes.length; - // special cases empty blocks - if(rawBlock.isEmptyBlock(false)) - return; + if(transposed) + for(int i = 0; i < m; i++) + for(int j = 0; j < n; j++) + mb.appendValue(i, j, rawBlock.quickGetValue(colIndexes[j], i)); + else + for(int i = 0; i < m; i++) + for(int j = 0; j < n; j++) + mb.appendValue(i, j, rawBlock.quickGetValue(i, colIndexes[j])); - // special cases full blocks - if(!transposed && _data.getNumColumns() == rawBlock.getNumColumns()) { - _data.copy(rawBlock); - _data.recomputeNonZeros(); - return; - } + mb.recomputeNonZeros(); + mb.examSparsity(); + + return create(mb, colIndexes); - // dense implementation for dense and sparse matrices to avoid linear search - int m = _numRows; - int n = 
_colIndexes.length; - for(int i = 0; i < m; i++) { - for(int j = 0; j < n; j++) { - double val = transposed ? rawBlock.quickGetValue(_colIndexes[j], i) : rawBlock.quickGetValue(i, - _colIndexes[j]); - _data.appendValue(i, j, val); - } - } - _data.recomputeNonZeros(); - _data.examSparsity(); } /** @@ -137,18 +144,11 @@ protected ColGroupUncompressed(int[] colIndices, MatrixBlock data) { * @param data matrix block */ public ColGroupUncompressed(MatrixBlock data) { - super(generateColumnList(data.getNumColumns())); + super(Util.genColsIndices(data.getNumColumns())); _data = data; _data.recomputeNonZeros(); } - private static int[] generateColumnList(int nCol) { - int[] cols = new int[nCol]; - for(int i = 0; i < nCol; i++) - cols[i] = i; - return cols; - } - @Override public CompressionType getCompType() { return CompressionType.UNCOMPRESSED; @@ -434,9 +434,9 @@ public final void tsmm(MatrixBlock ret, int nRows) { @Override public AColGroup copy() { - MatrixBlock newData = new MatrixBlock(_data.getNumRows(), _data.getNumColumns(), _data.isInSparseFormat()); - newData.copy(_data); - return new ColGroupUncompressed(_colIndexes, newData); + // MatrixBlock newData = new MatrixBlock(_data.getNumRows(), _data.getNumColumns(), _data.isInSparseFormat()); + // newData.copy(_data); + return new ColGroupUncompressed(_colIndexes, _data); } @Override @@ -506,11 +506,11 @@ else if(lhs instanceof APreAgg) { + "\nCurrently solved by t(t(Uncompressed) %*% AColGroup)"); final MatrixBlock ucCGT = LibMatrixReorg.transpose(getData(), InfrastructureAnalyzer.getLocalParallelism()); - + final APreAgg paCG = (APreAgg) lhs; final MatrixBlock preAgg = new MatrixBlock(1, lhs.getNumValues(), false); final MatrixBlock tmpRes = new MatrixBlock(1, this.getNumCols(), false); - final MatrixBlock dictM = paCG._dict.getMBDict(paCG.getNumCols()).getMatrixBlock(); + final MatrixBlock dictM = paCG._dict.getMBDict(paCG.getNumCols()).getMatrixBlock(); preAgg.allocateDenseBlock(); tmpRes.allocateDenseBlock(); 
final int nRows = ucCGT.getNumRows(); @@ -609,7 +609,7 @@ protected AColGroup sliceMultiColumns(int idStart, int idEnd, int[] outputCols) @Override public AColGroup rightMultByMatrix(MatrixBlock right) { final int nColR = right.getNumColumns(); - final int[] outputCols = generateColumnList(nColR); + final int[] outputCols = Util.genColsIndices(nColR); if(_data.isEmpty() || right.isEmpty()) return new ColGroupEmpty(outputCols); @@ -661,9 +661,15 @@ public AColGroup replace(double pattern, double replace) { @Override public void computeColSums(double[] c, int nRows) { - MatrixBlock colSum = _data.colSum(); - if(colSum.isInSparseFormat()) { - throw new NotImplementedException(); + final MatrixBlock colSum = _data.colSum(); + if(colSum.isEmpty()) + return; + else if(colSum.isInSparseFormat()) { + SparseBlock sb = colSum.getSparseBlock(); + double[] rv = sb.values(0); + int[] idx = sb.indexes(0); + for(int i = 0; i < idx.length; i++) + c[_colIndexes[idx[i]]] += rv[i]; } else { double[] dv = colSum.getDenseBlockValues(); @@ -671,4 +677,18 @@ public void computeColSums(double[] c, int nRows) { c[_colIndexes[i]] += dv[i]; } } + + @Override + public CM_COV_Object centralMoment(CMOperator op, int nRows) { + return _data.cmOperations(op); + } + + @Override + public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) { + MatrixBlock nd = LibMatrixReorg.rexpand(_data, new MatrixBlock(), max, false, cast, ignore, 1); + if(nd.isEmpty()) + return ColGroupEmpty.create(max); + else + return new ColGroupUncompressed(nd); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java index cca2a89452a..39fcd7e83f8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java @@ -26,6 +26,8 @@ import 
org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.functionobjects.Builtin; +import org.apache.sysds.runtime.functionobjects.ValueFunction; +import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; @@ -78,7 +80,7 @@ public abstract class ADictionary implements Serializable { * @param reference The reference offset to each value in the dictionary * @return The aggregated value as a double. */ - public abstract double aggregate(double init, Builtin fn, double[] reference); + public abstract double aggregateWithReference(double init, Builtin fn, double[] reference); /** * Aggregate all entries in the rows. @@ -89,6 +91,16 @@ public abstract class ADictionary implements Serializable { */ public abstract double[] aggregateRows(Builtin fn, int nCol); + /** + * Aggregate all entries in the rows of the dictionary with a extra cell in the end that contains the aggregate of + * the given defaultTuple. + * + * @param fn The aggregate function + * @param defaultTuple The default tuple to aggregate in last cell + * @return Aggregates for this dictionary tuples. + */ + public abstract double[] aggregateRowsWithDefault(Builtin fn, double[] defaultTuple); + /** * Aggregate all entries in the rows with an offset value reference added. * @@ -96,7 +108,7 @@ public abstract class ADictionary implements Serializable { * @param reference The reference offset to each value in the dictionary * @return Aggregates for this dictionary tuples. */ - public abstract double[] aggregateRows(Builtin fn, double[] reference); + public abstract double[] aggregateRowsWithReference(Builtin fn, double[] reference); /** * Aggregates the columns into the target double array provided. 
@@ -117,10 +129,10 @@ public abstract class ADictionary implements Serializable { * @param reference The reference offset values to add to each cell. * @param colIndexes The mapping to the target columns from the individual columns */ - public abstract void aggregateCols(double[] c, Builtin fn, int[] colIndexes, double[] reference); + public abstract void aggregateColsWithReference(double[] c, Builtin fn, int[] colIndexes, double[] reference); /** - * Allocate a new dictionary and applies the scalar operation on each cell of the to then return the new. + * Allocate a new dictionary and applies the scalar operation on each cell of the to then return the new dictionary. * * @param op The operator. * @return The new dictionary to return. @@ -137,7 +149,7 @@ public abstract class ADictionary implements Serializable { * @param newReference The reference value to subtract after the operator. * @return A New Dictionary. */ - public abstract ADictionary applyScalarOp(ScalarOperator op, double[] reference, double[] newReference); + public abstract ADictionary applyScalarOpWithReference(ScalarOperator op, double[] reference, double[] newReference); /** * Applies the scalar operation on the dictionary. Note that this operation modifies the underlying data, and @@ -148,17 +160,6 @@ public abstract class ADictionary implements Serializable { */ public abstract ADictionary inplaceScalarOp(ScalarOperator op); - /** - * Applies the scalar operation on the dictionary. The returned dictionary should contain a new instance of the - * underlying data. Therefore it will not modify the previous object. - * - * @param op The operator to apply to the dictionary values. - * @param newVal The value to append to the dictionary. - * @param numCols The number of columns stored in the dictionary. - * @return Another dictionary with modified values. 
- */ - public abstract ADictionary applyScalarOp(ScalarOperator op, double newVal, int numCols); - /** * Apply binary row operation on the left side in place * @@ -183,8 +184,8 @@ public abstract class ADictionary implements Serializable { * @param newReference The reference value to subtract after operator. * @return A new dictionary. */ - public abstract ADictionary binOpLeft(BinaryOperator op, double[] v, int[] colIndexes, double[] reference, - double[] newReference); + public abstract ADictionary binOpLeftWithReference(BinaryOperator op, double[] v, int[] colIndexes, + double[] reference, double[] newReference); /** * Apply binary row operation on the right side. @@ -209,44 +210,14 @@ public abstract ADictionary binOpLeft(BinaryOperator op, double[] v, int[] colIn * @param newReference The reference value to subtract after operator. * @return A new dictionary. */ - public abstract ADictionary binOpRight(BinaryOperator op, double[] v, int[] colIndexes, double[] reference, - double[] newReference); - - /** - * Apply binary row operation on the left side and allocate a new dictionary. - * - * While adding a new tuple, where the operation is applied with zero values. - * - * @param op The operation to this dictionary - * @param v The values to use on the left hand side. - * @param colIndexes The column indexes to consider inside v. - * @return A new dictionary containing the updated values. - */ - public abstract ADictionary applyBinaryRowOpLeftAppendNewEntry(BinaryOperator op, double[] v, int[] colIndexes); - - /** - * Apply binary row operation on this dictionary on the right side. - * - * @param op The operation to this dictionary - * @param v The values to use on the right hand side. - * @param colIndexes The column indexes to consider inside v. - * @return A new dictionary containing the updated values. 
- */ - public abstract ADictionary applyBinaryRowOpRightAppendNewEntry(BinaryOperator op, double[] v, int[] colIndexes); + public abstract ADictionary binOpRightWithReference(BinaryOperator op, double[] v, int[] colIndexes, + double[] reference, double[] newReference); /** * Returns a deep clone of the dictionary. */ public abstract ADictionary clone(); - /** - * Clone the dictionary, and extend size of the dictionary by a given length - * - * @param len The length to extend the dictionary, it is assumed this value is positive. - * @return a clone of the dictionary, extended by len. - */ - public abstract ADictionary cloneAndExtend(int len); - /** * Write the dictionary to a DataOutput. * @@ -288,13 +259,22 @@ public abstract ADictionary binOpRight(BinaryOperator op, double[] v, int[] colI */ public abstract double[] sumAllRowsToDouble(int nrColumns); + /** + * Do exactly the same as the sumAllRowsToDouble but also sum the array given to a extra index in the end of the + * array. + * + * @param defaultTuple The default row to sum in the end index returned. + * @return a double array containing the row sums from this dictionary. + */ + public abstract double[] sumAllRowsToDoubleWithDefault(double[] defaultTuple); + /** * Method used as a pre-aggregate of each tuple in the dictionary, to single double values with a reference. * * @param reference The reference values to add to each cell. * @return a double array containing the row sums from this dictionary. */ - public abstract double[] sumAllRowsToDouble(double[] reference); + public abstract double[] sumAllRowsToDoubleWithReference(double[] reference); /** * Method used as a pre-aggregate of each tuple in the dictionary, to single double values. @@ -306,13 +286,22 @@ public abstract ADictionary binOpRight(BinaryOperator op, double[] v, int[] colI */ public abstract double[] sumAllRowsToDoubleSq(int nrColumns); + /** + * Method used as a pre-aggregate of each tuple in the dictionary, to single double values. 
But adds another cell to + * the return with an extra value that is the sum of the given defaultTuple. + * + * @param defaultTuple The default row to sum in the end index returned. + * @return a double array containing the row sums from this dictionary. + */ + public abstract double[] sumAllRowsToDoubleSqWithDefault(double[] defaultTuple); + /** * Method used as a pre-aggregate of each tuple in the dictionary, to single double values. * * @param reference The reference values to add to each cell. * @return a double array containing the row sums from this dictionary. */ - public abstract double[] sumAllRowsToDoubleSq(double[] reference); + public abstract double[] sumAllRowsToDoubleSqWithReference(double[] reference); /** * Sum the values at a specific row. @@ -340,10 +329,10 @@ public abstract ADictionary binOpRight(BinaryOperator op, double[] v, int[] colI * @param reference The reference vector to add to each cell processed. * @return The sum of the row. */ - public abstract double sumRowSq(int k, int nrColumns, double[] reference); + public abstract double sumRowSqWithReference(int k, int nrColumns, double[] reference); /** - * get the column sum of this dictionary only. + * Get the column sum of this dictionary only. * * @param counts the counts of the values contained * @param nCol The number of columns contained in each tuple. @@ -380,7 +369,7 @@ public abstract ADictionary binOpRight(BinaryOperator op, double[] v, int[] colI * the c output. * @param reference The reference values to add to each cell. 
*/ - public abstract void colSumSq(double[] c, int[] counts, int[] colIndexes, double[] reference); + public abstract void colSumSqWithReference(double[] c, int[] counts, int[] colIndexes, double[] reference); /** * Get the sum of the values contained in the dictionary @@ -407,7 +396,7 @@ public abstract ADictionary binOpRight(BinaryOperator op, double[] v, int[] colI * @param reference The reference value * @return The square sum scaled by the counts and reference. */ - public abstract double sumSq(int[] counts, double[] reference); + public abstract double sumSqWithReference(int[] counts, double[] reference); /** * Get a string representation of the dictionary, that considers the layout of the data. @@ -417,16 +406,6 @@ public abstract ADictionary binOpRight(BinaryOperator op, double[] v, int[] colI */ public abstract String getString(int colIndexes); - /** - * This method adds the max and min values contained in the dictionary to corresponding cells in the ret variable. - * - * One use case for this method is the squash operation, to go from an overlapping state to normal compression. - * - * @param ret The double array that contains all columns min and max. - * @param colIndexes The column indexes contained in this dictionary. - */ - public abstract void addMaxAndMin(double[] ret, int[] colIndexes); - /** * Modify the dictionary by removing columns not within the index range. * @@ -437,14 +416,6 @@ public abstract ADictionary binOpRight(BinaryOperator op, double[] v, int[] colI */ public abstract ADictionary sliceOutColumnRange(int idxStart, int idxEnd, int previousNumberOfColumns); - /** - * return a new Dictionary that have re expanded all values, based on the entries already contained. - * - * @param max The number of output columns possible. - * @return The re expanded Dictionary. - */ - public abstract ADictionary reExpandColumns(int max); - /** * Detect if the dictionary contains a specific value. 
* @@ -460,7 +431,7 @@ public abstract ADictionary binOpRight(BinaryOperator op, double[] v, int[] colI * @param reference The reference double array. * @return true if the value is contained else false. */ - public abstract boolean containsValue(double pattern, double[] reference); + public abstract boolean containsValueWithReference(double pattern, double[] reference); /** * Calculate the number of non zeros in the dictionary. The number of non zeros should be scaled with the counts @@ -484,7 +455,16 @@ public abstract ADictionary binOpRight(BinaryOperator op, double[] v, int[] colI * @param nRows The number of rows in the input. * @return The NonZero Count. */ - public abstract long getNumberNonZeros(int[] counts, double[] reference, int nRows); + public abstract long getNumberNonZerosWithReference(int[] counts, double[] reference, int nRows); + + /** + * Single column version of copy add to entry. + * + * @param d target dictionary to add to + * @param fr Take from this index + * @param to put into index in d. + */ + public abstract void addToEntry(Dictionary d, int fr, int to); /** * Copies and adds the dictionary entry from this dictionary to the d dictionary @@ -496,17 +476,6 @@ public abstract ADictionary binOpRight(BinaryOperator op, double[] v, int[] colI */ public abstract void addToEntry(Dictionary d, int fr, int to, int nCol); - /** - * Get the values contained in a specific tuple of the dictionary. - * - * If the entire row is zero return null. - * - * @param index The index where the values are located - * @param nCol The number of columns contained in this dictionary - * @return a materialized double array containing the tuple. - */ - public abstract double[] getTuple(int index, int nCol); - /** * Allocate a new dictionary where the tuple given is subtracted from all tuples in the previous dictionary. 
* @@ -516,7 +485,7 @@ public abstract ADictionary binOpRight(BinaryOperator op, double[] v, int[] colI public abstract ADictionary subtractTuple(double[] tuple); /** - * Get this dictionary as a matrixBlock dictionary. This allows us to use optimized kernels coded elsewhere in the + * Get this dictionary as a MatrixBlock dictionary. This allows us to use optimized kernels coded elsewhere in the * system, such as matrix multiplication. * * @param nCol The number of columns contained in this column group. @@ -527,14 +496,14 @@ public abstract ADictionary binOpRight(BinaryOperator op, double[] v, int[] colI /** * Scale all tuples contained in the dictionary by the scaling factor given in the int list. * - * @param scaling The ammout to multiply the given tuples with + * @param scaling The amount to multiply the given tuples with * @param nCol The number of columns contained in this column group. * @return A New dictionary (since we don't want to modify the underlying dictionary) */ public abstract ADictionary scaleTuples(int[] scaling, int nCol); /** - * Pre Aggregate values for right Matrix Multiplication. + * Pre Aggregate values for Right Matrix Multiplication. * * @param numVals The number of values contained in this dictionary * @param colIndexes The column indexes that is associated with the parent column group @@ -558,11 +527,110 @@ public abstract ADictionary preaggValuesFromDense(final int numVals, final int[] */ public abstract ADictionary replace(double pattern, double replace, int nCol); - public abstract ADictionary replace(double pattern, double replace, double[] reference); + /** + * Make a copy of the values, and replace all values that match pattern with replacement value. If needed add a new + * column index. With reference such that each value in the dict is considered offset by the values contained in the + * reference. 
+ * + * @param pattern The value to look for + * @param replace The value to replace the other value with + * @param reference The reference tuple to add to all entries when replacing + * @return A new Column Group, reusing the index structure but with new values. + */ + public abstract ADictionary replaceWithReference(double pattern, double replace, double[] reference); - public abstract ADictionary replaceZeroAndExtend(double replace, int nCol); + // public abstract ADictionary replaceZeroAndExtend(double replace, int nCol); + /** + * Calculate the product of the dictionary weighted by counts. + * + * @param counts The count of individual tuples + * @param nCol Number of columns in the dictionary. + * @return The product of the dictionary + */ public abstract double product(int[] counts, int nCol); + /** + * Calculate the column product of the dictionary weighted by counts. + * + * @param res The result vector to put the result into + * @param counts The weighted count of individual tuples + * @param colIndexes The column indexes. + */ public abstract void colProduct(double[] res, int[] counts, int[] colIndexes); + + /** + * Central moment function to calculate the central moment of this column group. MUST be on a single column + * dictionary. + * + * @param fn The value function to apply + * @param counts The weight of individual tuples + * @param nRows The number of rows in total of the column group + * @return The central moment Object + */ + public CM_COV_Object centralMoment(ValueFunction fn, int[] counts, int nRows) { + return centralMoment(new CM_COV_Object(), fn, counts, nRows); + } + + /** + * Central moment function to calculate the central moment of this column group. MUST be on a single column + * dictionary. 
+ * + * @param ret The Central Moment object to be modified and returned + * @param fn The value function to apply + * @param counts The weight of individual tuples + * @param nRows The number of rows in total of the column group + * @return The central moment Object + */ + public abstract CM_COV_Object centralMoment(CM_COV_Object ret, ValueFunction fn, int[] counts, int nRows); + + /** + * Central moment function to calculate the central moment of this column group with a reference offset on each + * tuple. MUST be on a single column dictionary. + * + * @param fn The value function to apply + * @param counts The weight of individual tuples + * @param reference The reference values to offset the tuples with + * @param nRows The number of rows in total of the column group + * @return The central moment Object + */ + public CM_COV_Object centralMomentWithReference(ValueFunction fn, int[] counts, double reference, int nRows) { + return centralMomentWithReference(new CM_COV_Object(), fn, counts, reference, nRows); + } + + /** + * Central moment function to calculate the central moment of this column group with a reference offset on each + * tuple. MUST be on a single column dictionary. 
+ * + * @param ret The Central Moment object to be modified and returned + * @param fn The value function to apply + * @param counts The weight of individual tuples + * @param reference The reference values to offset the tuples with + * @param nRows The number of rows in total of the column group + * @return The central moment Object + */ + public abstract CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction fn, int[] counts, + double reference, int nRows); + + /** + * Rexpand the dictionary (one hot encode) + * + * @param max the tuple width of the output + * @param ignore If we should ignore zero and negative values + * @param cast If we should cast all double values to whole integer values + * @param nCol The number of columns in the dictionary already (should be 1) + * @return A new dictionary + */ + public abstract ADictionary rexpandCols(int max, boolean ignore, boolean cast, int nCol); + + /** + * Rexpand the dictionary (one hot encode) + * + * @param max the tuple width of the output + * @param ignore If we should ignore zero and negative values + * @param cast If we should cast all double values to whole integer values + * @param reference A reference value to add to all tuples before expanding + * @return A new dictionary + */ + public abstract ADictionary rexpandColsWithReference(int max, boolean ignore, boolean cast, double reference); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java index 363e22d3fbb..4db4d5d249f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java @@ -27,8 +27,11 @@ import org.apache.commons.lang.NotImplementedException; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.functionobjects.Builtin; +import 
org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.functionobjects.ValueFunction; +import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; import org.apache.sysds.utils.MemoryEstimates; @@ -80,7 +83,7 @@ public double aggregate(double init, Builtin fn) { } @Override - public double aggregate(double init, Builtin fn, double[] reference) { + public double aggregateWithReference(double init, Builtin fn, double[] reference) { final int nCol = reference.length; double ret = init; for(int i = 0; i < _values.length; i++) @@ -107,7 +110,26 @@ public double[] aggregateRows(Builtin fn, int nCol) { } @Override - public double[] aggregateRows(Builtin fn, double[] reference) { + public double[] aggregateRowsWithDefault(Builtin fn, double[] defaultTuple) { + final int nCol = defaultTuple.length; + final int nRows = _values.length / nCol; + double[] res = new double[nRows + 1]; + for(int i = 0; i < nRows; i++) { + final int off = i * nCol; + res[i] = _values[off]; + for(int j = off + 1; j < off + nCol; j++) + res[i] = fn.execute(res[i], _values[j]); + } + final int def = res.length - 1; + res[def] = defaultTuple[0]; + for(int i = 1; i < nCol; i++) + res[def] = fn.execute(res[def], defaultTuple[i]); + + return res; + } + + @Override + public double[] aggregateRowsWithReference(Builtin fn, double[] reference) { final int nCol = reference.length; final int nRows = _values.length / nCol; double[] res = new double[nRows + 1]; @@ -132,7 +154,7 @@ public Dictionary applyScalarOp(ScalarOperator op) { } @Override - public Dictionary applyScalarOp(ScalarOperator op, double[] reference, double[] newReference) { + public Dictionary applyScalarOpWithReference(ScalarOperator op, double[] reference, double[] newReference) { final double[] retV = new 
double[_values.length]; final int nCol = reference.length; final int nRow = _values.length / nCol; @@ -154,17 +176,6 @@ public Dictionary inplaceScalarOp(ScalarOperator op) { return this; } - @Override - public Dictionary applyScalarOp(ScalarOperator op, double newVal, int numCols) { - // allocate new array just once because we need to add the newVal. - double[] values = new double[_values.length + numCols]; - for(int i = 0; i < _values.length; i++) - values[i] = op.executeScalar(_values[i]); - - Arrays.fill(values, _values.length, _values.length + numCols, newVal); - return new Dictionary(values); - } - @Override public Dictionary binOpRight(BinaryOperator op, double[] v, int[] colIndexes) { final ValueFunction fn = op.fn; @@ -177,7 +188,7 @@ public Dictionary binOpRight(BinaryOperator op, double[] v, int[] colIndexes) { } @Override - public Dictionary binOpRight(BinaryOperator op, double[] v, int[] colIndexes, double[] reference, + public Dictionary binOpRightWithReference(BinaryOperator op, double[] v, int[] colIndexes, double[] reference, double[] newReference) { final ValueFunction fn = op.fn; final double[] retV = new double[_values.length]; @@ -205,7 +216,7 @@ public final Dictionary binOpLeft(BinaryOperator op, double[] v, int[] colIndexe } @Override - public Dictionary binOpLeft(BinaryOperator op, double[] v, int[] colIndexes, double[] reference, + public Dictionary binOpLeftWithReference(BinaryOperator op, double[] v, int[] colIndexes, double[] reference, double[] newReference) { final ValueFunction fn = op.fn; final double[] retV = new double[_values.length]; @@ -221,45 +232,11 @@ public Dictionary binOpLeft(BinaryOperator op, double[] v, int[] colIndexes, dou return new Dictionary(retV); } - @Override - public Dictionary applyBinaryRowOpRightAppendNewEntry(BinaryOperator op, double[] v, int[] colIndexes) { - final ValueFunction fn = op.fn; - final int len = size(); - final int lenV = colIndexes.length; - final double[] values = new double[len + lenV]; - 
int i = 0; - for(; i < len; i++) - values[i] = fn.execute(_values[i], v[colIndexes[i % lenV]]); - for(; i < len + lenV; i++) - values[i] = fn.execute(0, v[colIndexes[i % lenV]]); - return new Dictionary(values); - } - - @Override - public final Dictionary applyBinaryRowOpLeftAppendNewEntry(BinaryOperator op, double[] v, int[] colIndexes) { - final ValueFunction fn = op.fn; - final int len = size(); - final int lenV = colIndexes.length; - final double[] values = new double[len + lenV]; - int i = 0; - for(; i < len; i++) - values[i] = fn.execute(v[colIndexes[i % lenV]], _values[i]); - for(; i < len + lenV; i++) - values[i] = fn.execute(v[colIndexes[i % lenV]], 0); - return new Dictionary(values); - } - @Override public Dictionary clone() { return new Dictionary(_values.clone()); } - @Override - public Dictionary cloneAndExtend(int len) { - double[] ret = Arrays.copyOf(_values, _values.length + len); - return new Dictionary(ret); - } - public static Dictionary read(DataInput in) throws IOException { int numVals = in.readInt(); // read distinct values @@ -306,12 +283,24 @@ public double[] sumAllRowsToDouble(int nrColumns) { } @Override - public double[] sumAllRowsToDouble(double[] reference) { + public double[] sumAllRowsToDoubleWithDefault(double[] defaultTuple) { + final int nCol = defaultTuple.length; + final int numVals = getNumberOfValues(nCol); + final double[] ret = new double[numVals + 1]; + for(int k = 0; k < numVals; k++) + ret[k] = sumRow(k, nCol); + for(int i = 0; i < nCol; i++) + ret[ret.length - 1] += defaultTuple[i]; + return ret; + } + + @Override + public double[] sumAllRowsToDoubleWithReference(double[] reference) { final int nCol = reference.length; final int numVals = getNumberOfValues(nCol); double[] ret = new double[numVals + 1]; for(int k = 0; k < numVals; k++) - ret[k] = sumRow(k, nCol, reference); + ret[k] = sumRowWithReference(k, nCol, reference); for(int i = 0; i < nCol; i++) ret[numVals] += reference[i]; return ret; @@ -321,7 +310,7 @@ 
public double[] sumAllRowsToDouble(double[] reference) { public double[] sumAllRowsToDoubleSq(int nrColumns) { // pre-aggregate value tuple final int numVals = getNumberOfValues(nrColumns); - double[] ret = new double[numVals]; + final double[] ret = new double[numVals]; for(int k = 0; k < numVals; k++) ret[k] = sumRowSq(k, nrColumns); @@ -329,12 +318,24 @@ public double[] sumAllRowsToDoubleSq(int nrColumns) { } @Override - public double[] sumAllRowsToDoubleSq(double[] reference) { + public double[] sumAllRowsToDoubleSqWithDefault(double[] defaultTuple) { + final int nCol = defaultTuple.length; + final int numVals = getNumberOfValues(nCol); + final double[] ret = new double[numVals + 1]; + for(int k = 0; k < numVals; k++) + ret[k] = sumRowSq(k, nCol); + for(int i = 0; i < nCol; i++) + ret[ret.length - 1] += defaultTuple[i] * defaultTuple[i]; + return ret; + } + + @Override + public double[] sumAllRowsToDoubleSqWithReference(double[] reference) { final int nCol = reference.length; final int numVals = getNumberOfValues(nCol); double[] ret = new double[numVals + 1]; for(int k = 0; k < numVals; k++) - ret[k] = sumRowSq(k, nCol, reference); + ret[k] = sumRowSqWithReference(k, nCol, reference); for(int i = 0; i < nCol; i++) ret[numVals] += reference[i] * reference[i]; return ret; @@ -349,7 +350,7 @@ public double sumRow(int k, int nrColumns) { return res; } - public double sumRow(int k, int nrColumns, double[] reference) { + public double sumRowWithReference(int k, int nrColumns, double[] reference) { final int valOff = k * nrColumns; double res = 0.0; for(int i = 0; i < nrColumns; i++) @@ -367,7 +368,7 @@ public double sumRowSq(int k, int nrColumns) { } @Override - public double sumRowSq(int k, int nrColumns, double[] reference) { + public double sumRowSqWithReference(int k, int nrColumns, double[] reference) { final int valOff = k * nrColumns; double res = 0.0; for(int i = 0; i < nrColumns; i++) { @@ -381,7 +382,7 @@ public double sumRowSq(int k, int nrColumns, 
double[] reference) { public double[] colSum(int[] counts, int nCol) { final double[] res = new double[nCol]; int idx = 0; - for(int k = 0; k < _values.length / nCol; k++) { + for(int k = 0; k < counts.length; k++) { final int cntk = counts[k]; for(int j = 0; j < nCol; j++) res[j] += _values[idx++] * cntk; @@ -392,7 +393,7 @@ public double[] colSum(int[] counts, int nCol) { @Override public void colSum(double[] c, int[] counts, int[] colIndexes) { final int nCol = colIndexes.length; - for(int k = 0; k < _values.length / nCol; k++) { + for(int k = 0; k < counts.length; k++) { final int cntk = counts[k]; final int off = k * nCol; for(int j = 0; j < nCol; j++) @@ -403,7 +404,7 @@ public void colSum(double[] c, int[] counts, int[] colIndexes) { @Override public void colSumSq(double[] c, int[] counts, int[] colIndexes) { final int nCol = colIndexes.length; - final int nRow = _values.length / nCol; + final int nRow = counts.length; int off = 0; for(int k = 0; k < nRow; k++) { final int cntk = counts[k]; @@ -415,9 +416,9 @@ public void colSumSq(double[] c, int[] counts, int[] colIndexes) { } @Override - public void colSumSq(double[] c, int[] counts, int[] colIndexes, double[] reference) { + public void colSumSqWithReference(double[] c, int[] counts, int[] colIndexes, double[] reference) { final int nCol = colIndexes.length; - final int nRow = _values.length / nCol; + final int nRow = counts.length; int off = 0; for(int k = 0; k < nRow; k++) { final int cntk = counts[k]; @@ -426,15 +427,13 @@ public void colSumSq(double[] c, int[] counts, int[] colIndexes, double[] refere c[colIndexes[j]] += v * v * cntk; } } - for(int i = 0; i < nCol; i++) - c[colIndexes[i]] += reference[i] * reference[i] * counts[nRow]; } @Override public double sum(int[] counts, int nCol) { double out = 0; int valOff = 0; - for(int k = 0; k < _values.length / nCol; k++) { + for(int k = 0; k < counts.length; k++) { int countK = counts[k]; for(int j = 0; j < nCol; j++) { out += _values[valOff++] * countK; 
@@ -447,7 +446,7 @@ public double sum(int[] counts, int nCol) { public double sumSq(int[] counts, int nCol) { double out = 0; int valOff = 0; - for(int k = 0; k < _values.length / nCol; k++) { + for(int k = 0; k < counts.length; k++) { final int countK = counts[k]; for(int j = 0; j < nCol; j++) { final double val = _values[valOff++]; @@ -458,9 +457,9 @@ public double sumSq(int[] counts, int nCol) { } @Override - public double sumSq(int[] counts, double[] reference) { + public double sumSqWithReference(int[] counts, double[] reference) { final int nCol = reference.length; - final int nRow = _values.length / nCol; + final int nRow = counts.length; double out = 0; int valOff = 0; for(int k = 0; k < nRow; k++) { @@ -470,9 +469,6 @@ public double sumSq(int[] counts, double[] reference) { out += val * val * countK; } } - for(int i = 0; i < nCol; i++) - out += reference[i] * reference[i] * counts[nRow]; - return out; } @@ -485,27 +481,6 @@ public String toString() { return sb.toString(); } - @Override - public void addMaxAndMin(double[] ret, int[] colIndexes) { - - double[] mins = new double[colIndexes.length]; - double[] maxs = new double[colIndexes.length]; - for(int i = 0; i < colIndexes.length; i++) { - mins[i] = _values[i]; - maxs[i] = _values[i]; - } - for(int i = colIndexes.length; i < _values.length; i++) { - int idx = i % colIndexes.length; - mins[idx] = Math.min(_values[i], mins[idx]); - maxs[idx] = Math.max(_values[i], maxs[idx]); - } - for(int i = 0; i < colIndexes.length; i++) { - int idy = colIndexes[i] * 2; - ret[idy] += mins[i]; - ret[idy + 1] += maxs[i]; - } - } - public String getString(int colIndexes) { StringBuilder sb = new StringBuilder(); if(colIndexes == 1) @@ -537,17 +512,6 @@ public ADictionary sliceOutColumnRange(int idxStart, int idxEnd, int previousNum return new Dictionary(newDictValues); } - public ADictionary reExpandColumns(int max) { - double[] newDictValues = new double[_values.length * max]; - - for(int i = 0, offset = 0; i < 
_values.length; i++, offset += max) { - int val = (int) Math.floor(_values[i]) - 1; - newDictValues[offset + val] = 1; - } - - return new Dictionary(newDictValues); - } - @Override public boolean containsValue(double pattern) { boolean NaNpattern = Double.isNaN(pattern); @@ -567,7 +531,7 @@ public boolean containsValue(double pattern) { } @Override - public boolean containsValue(double pattern, double[] reference) { + public boolean containsValueWithReference(double pattern, double[] reference) { final int nCol = reference.length; for(int i = 0; i < _values.length; i++) if(_values[i] + reference[i % nCol] == pattern) @@ -578,7 +542,7 @@ public boolean containsValue(double pattern, double[] reference) { @Override public long getNumberNonZeros(int[] counts, int nCol) { long nnz = 0; - final int nRow = _values.length / nCol; + final int nRow = counts.length; for(int i = 0; i < nRow; i++) { long rowCount = 0; final int off = i * nCol; @@ -592,10 +556,11 @@ public long getNumberNonZeros(int[] counts, int nCol) { } @Override - public long getNumberNonZeros(int[] counts, double[] reference, int nRows) { + public long getNumberNonZerosWithReference(int[] counts, double[] reference, int nRows) { long nnz = 0; final int nCol = reference.length; - final int nRow = _values.length / nCol; + final int nRow = counts.length; + for(int i = 0; i < nRow; i++) { long rowCount = 0; final int off = i * nCol; @@ -605,21 +570,23 @@ public long getNumberNonZeros(int[] counts, double[] reference, int nRows) { } nnz += rowCount * counts[i]; } - for(int i = 0; i < nCol; i++) - if(reference[i] != 0) - nnz += counts[nRow]; return nnz; } + @Override + public void addToEntry(Dictionary d, int fr, int to){ + d.getValues()[to] += _values[fr]; + } + @Override public void addToEntry(Dictionary d, int fr, int to, int nCol) { - final int sf = nCol * fr; // start from + final int sf = fr * nCol; // start from final int ef = sf + nCol; // end from - double[] v = d.getValues(); - for(int i = sf, j = nCol 
* to; i < ef; i++, j++) { + final int st = to * nCol; // start to + final double[] v = d.getValues(); + for(int i = sf, j = st; i < ef; i++, j++) v[j] += _values[i]; - } } @Override @@ -627,27 +594,12 @@ public boolean isLossy() { return false; } - @Override - public double[] getTuple(int index, int nCol) { - - final double[] tuple = new double[nCol]; - boolean allZero = true; - for(int i = index * nCol, off = 0; i < (index + 1) * nCol && i < _values.length; i++, off++) { - final double v = _values[i]; - if(v != 0) { - tuple[off] = v; - allZero = false; - } - } - - return allZero ? null : tuple; - } - @Override public ADictionary subtractTuple(double[] tuple) { - double[] newValues = new double[_values.length - tuple.length]; - for(int i = 0; i < _values.length - tuple.length; i++) - newValues[i] = _values[i] - tuple[i % tuple.length]; + double[] newValues = new double[_values.length]; + for(int i = 0; i < _values.length;) + for(int j = 0; j < tuple.length; i++, j++) + newValues[i] = _values[i] - tuple[j]; return new Dictionary(newValues); } @@ -667,7 +619,7 @@ public void aggregateCols(double[] c, Builtin fn, int[] colIndexes) { } @Override - public void aggregateCols(double[] c, Builtin fn, int[] colIndexes, double[] reference) { + public void aggregateColsWithReference(double[] c, Builtin fn, int[] colIndexes, double[] reference) { final int nCol = reference.length; final int rlen = _values.length / nCol; for(int k = 0; k < rlen; k++) @@ -717,7 +669,7 @@ public ADictionary replace(double pattern, double replace, int nCol) { } @Override - public ADictionary replace(double pattern, double replace, double[] reference) { + public ADictionary replaceWithReference(double pattern, double replace, double[] reference) { final double[] retV = new double[_values.length]; final int nCol = reference.length; final int nRow = _values.length / nCol; @@ -732,26 +684,10 @@ public ADictionary replace(double pattern, double replace, double[] reference) { return new 
Dictionary(retV); } - @Override - public ADictionary replaceZeroAndExtend(double replace, int nCol) { - double[] retV = new double[_values.length + nCol]; - for(int i = 0; i < _values.length; i++) { - final double v = _values[i]; - if(v == 0) - retV[i] = replace; - else - retV[i] = v; - } - for(int i = _values.length; i < _values.length + nCol; i++) - retV[i] = replace; - - return new Dictionary(retV); - } - @Override public double product(int[] counts, int nCol) { double ret = 1; - final int len = _values.length / nCol; + final int len = counts.length; for(int i = 0; i < len; i++) { for(int j = i * nCol; j < (i + 1) * nCol; j++) { double v = _values[j]; @@ -768,4 +704,32 @@ public double product(int[] counts, int nCol) { public void colProduct(double[] res, int[] counts, int[] colIndexes) { throw new NotImplementedException(); } + + @Override + public CM_COV_Object centralMoment(CM_COV_Object ret, ValueFunction fn, int[] counts, int nRows) { + // should be guaranteed to only contain one value per tuple in dictionary. + for(int i = 0; i < _values.length; i++) + fn.execute(ret, _values[i], counts[i]); + return ret; + } + + @Override + public CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction fn, int[] counts, double reference, + int nRows) { + // should be guaranteed to only contain one value per tuple in dictionary. 
+ for(int i = 0; i < _values.length; i++) + fn.execute(ret, _values[i] + reference, counts[i]); + return ret; + } + + @Override + public ADictionary rexpandCols(int max, boolean ignore, boolean cast, int nCol) { + return getMBDict(nCol).rexpandCols(max, ignore, cast, nCol); + } + + @Override + public ADictionary rexpandColsWithReference(int max, boolean ignore, boolean cast, double reference) { + return getMBDict(1).applyScalarOp(new LeftScalarOperator(Plus.getPlusFnObject(), reference)).rexpandCols(max, + ignore, cast, 1); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java index 610458e723f..fd8be9c98f4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java @@ -86,28 +86,27 @@ public static ADictionary create(ABitmap ubm, double sparsity, boolean withZeroT } public static ADictionary create(ABitmap ubm, double sparsity) { + final int nCol = ubm.getNumColumns(); if(ubm instanceof Bitmap) return new Dictionary(((Bitmap) ubm).getValues()); - else if(sparsity < 0.4 && ubm instanceof MultiColBitmap) { - final int nCols = ubm.getNumColumns(); + else if(sparsity < 0.4 && nCol > 4 && ubm instanceof MultiColBitmap) { final MultiColBitmap mcbm = (MultiColBitmap) ubm; - final MatrixBlock m = new MatrixBlock(ubm.getNumValues(), nCols, true); + final MatrixBlock m = new MatrixBlock(ubm.getNumValues(), nCol, true); m.allocateSparseRowsBlock(); final SparseBlock sb = m.getSparseBlock(); final int nVals = ubm.getNumValues(); for(int i = 0; i < nVals; i++) { final double[] tuple = mcbm.getValues(i); - for(int col = 0; col < nCols; col++) + for(int col = 0; col < nCol; col++) sb.append(i, col, tuple[col]); } m.recomputeNonZeros(); - return new MatrixBlockDictionary(m, nCols); + return new 
MatrixBlockDictionary(m, nCol); } else if(ubm instanceof MultiColBitmap) { MultiColBitmap mcbm = (MultiColBitmap) ubm; - final int nCol = ubm.getNumColumns(); final int nVals = ubm.getNumValues(); double[] resValues = new double[nVals * nCol]; for(int i = 0; i < nVals; i++) @@ -118,6 +117,37 @@ else if(ubm instanceof MultiColBitmap) { throw new NotImplementedException("Not implemented creation of bitmap type : " + ubm.getClass().getSimpleName()); } + public static ADictionary create(ABitmap ubm, int defaultIndex, double[] defaultTuple, double sparsity, + boolean addZero) { + final int nCol = ubm.getNumColumns(); + final int nVal = ubm.getNumValues() - (addZero ? 0 : 1); + if(nCol > 4 && sparsity < 0.4) { + // return sparse + throw new NotImplementedException("Not supported sparse allocation yet"); + } + else { + double[] dict = new double[nCol * nVal]; + if(ubm instanceof Bitmap) { + final double[] bmv = ((Bitmap) ubm).getValues(); + System.arraycopy(bmv, 0, dict, 0, defaultIndex); + defaultTuple[0] = bmv[defaultIndex]; + System.arraycopy(bmv, defaultIndex + 1, dict, defaultIndex, bmv.length - defaultIndex - 1); + } + else if(ubm instanceof MultiColBitmap) { + MultiColBitmap mcbm = (MultiColBitmap) ubm; + for(int i = 0; i < defaultIndex; i++) + System.arraycopy(mcbm.getValues(i), 0, dict, i * nCol, nCol); + System.arraycopy(mcbm.getValues(defaultIndex), 0, defaultTuple, 0, nCol); + for(int i = defaultIndex; i < ubm.getNumValues() - 1; i++) + System.arraycopy(mcbm.getValues(i + 1), 0, dict, i * nCol, nCol); + } + else + throw new NotImplementedException("not supported ABitmap of type:" + ubm.getClass().getSimpleName()); + + return new Dictionary(dict); + } + } + public static ADictionary createWithAppendedZeroTuple(ABitmap ubm, double sparsity) { final int nVals = ubm.getNumValues(); final int nRows = nVals + 1; @@ -131,7 +161,7 @@ public static ADictionary createWithAppendedZeroTuple(ABitmap ubm, double sparsi } final MultiColBitmap mcbm = (MultiColBitmap) ubm; - 
if(sparsity < 0.4) { + if(sparsity < 0.4 && nCols > 4) { final MatrixBlock m = new MatrixBlock(nRows, nCols, true); m.allocateSparseRowsBlock(); final SparseBlock sb = m.getSparseBlock(); @@ -150,56 +180,5 @@ public static ADictionary createWithAppendedZeroTuple(ABitmap ubm, double sparsi System.arraycopy(mcbm.getValues(i), 0, resValues, i * nCols, nCols); return new Dictionary(resValues); - - } - - public static ADictionary moveFrequentToLastDictionaryEntry(ADictionary dict, ABitmap ubm, int nRow, - int largestIndex) { - LOG.warn("Inefficient moving of tuples."); - final int zeros = nRow - (int) ubm.getNumOffsets(); - final int nCol = ubm.getNumColumns(); - final int largestIndexSize = ubm.getOffsetsList(largestIndex).size(); - if(dict instanceof MatrixBlockDictionary) { - MatrixBlockDictionary mbd = (MatrixBlockDictionary) dict; - MatrixBlock mb = mbd.getMatrixBlock(); - if(mb.isEmpty()) - throw new DMLCompressionException("Should not construct or use a empty dictionary ever."); - else if(mb.isInSparseFormat()) { - throw new NotImplementedException(); // and should not be - } - else - return moveToLastDictionaryEntryDense(mb.getDenseBlockValues(), largestIndex, zeros, nCol, - largestIndexSize); - } - else - return moveToLastDictionaryEntryDense(dict.getValues(), largestIndex, zeros, nCol, largestIndexSize); - - } - - private static ADictionary moveToLastDictionaryEntryDense(double[] values, int indexToMove, int zeros, int nCol, - int largestIndexSize) { - final int offsetToLargest = indexToMove * nCol; - - if(zeros == 0) { - final double[] swap = new double[nCol]; - System.arraycopy(values, offsetToLargest, swap, 0, nCol); - for(int i = offsetToLargest; i < values.length - nCol; i++) - values[i] = values[i + nCol]; - - System.arraycopy(swap, 0, values, values.length - nCol, nCol); - return new Dictionary(values); - } - - final double[] newDict = new double[values.length + nCol]; - - if(zeros > largestIndexSize) - System.arraycopy(values, 0, newDict, 0, 
values.length); - else { - System.arraycopy(values, 0, newDict, 0, offsetToLargest); - System.arraycopy(values, offsetToLargest + nCol, newDict, offsetToLargest, - values.length - offsetToLargest - nCol); - System.arraycopy(values, offsetToLargest, newDict, newDict.length - nCol, nCol); - } - return new Dictionary(newDict); } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java index 6c56ca68472..55b74261415 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java @@ -32,9 +32,14 @@ import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode; +import org.apache.sysds.runtime.functionobjects.Minus; +import org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.functionobjects.ValueFunction; +import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; public class MatrixBlockDictionary extends ADictionary { @@ -56,7 +61,8 @@ public MatrixBlockDictionary(MatrixBlock data, int nCol) { throw new DMLCompressionException("Invalid construction of empty dictionary"); if(_data.getNumColumns() != nCol) - throw new DMLCompressionException("Invalid construction expected nCol: "+ nCol + " but matrix block contains: " + _data.getNumColumns()); + throw new DMLCompressionException( + "Invalid construction expected nCol: " + nCol + 
" but matrix block contains: " + _data.getNumColumns()); } public MatrixBlock getMatrixBlock() { @@ -65,9 +71,10 @@ public MatrixBlock getMatrixBlock() { @Override public double[] getValues() { - LOG.warn("Inefficient call to getValues for a MatrixBlockDictionary"); - if(_data.isInSparseFormat()) + if(_data.isInSparseFormat()) { + LOG.warn("Inefficient call to getValues for a MatrixBlockDictionary because it was sparse"); _data.sparseToDense(); + } return _data.getDenseBlockValues(); } @@ -103,7 +110,7 @@ else if(fn.getBuiltinCode() == BuiltinCode.MIN) } @Override - public double aggregate(double init, Builtin fn, double[] reference) { + public double aggregateWithReference(double init, Builtin fn, double[] reference) { final int nCol = reference.length; final int nRows = _data.getNumRows(); double ret = init; @@ -178,7 +185,12 @@ else if(nCol == 1) } @Override - public double[] aggregateRows(Builtin fn, double[] reference) { + public double[] aggregateRowsWithDefault(Builtin fn, double[] defaultTuple) { + throw new NotImplementedException(); + } + + @Override + public double[] aggregateRowsWithReference(Builtin fn, double[] reference) { final int nCol = reference.length; final int nRows = _data.getNumRows(); final double[] ret = new double[nRows + 1]; @@ -267,7 +279,7 @@ else if(_data.isInSparseFormat()) { } @Override - public void aggregateCols(double[] c, Builtin fn, int[] colIndexes, double[] reference) { + public void aggregateColsWithReference(double[] c, Builtin fn, int[] colIndexes, double[] reference) { final int nCol = _data.getNumColumns(); final int nRow = _data.getNumRows(); @@ -313,7 +325,7 @@ public ADictionary applyScalarOp(ScalarOperator op) { } @Override - public ADictionary applyScalarOp(ScalarOperator op, double[] reference, double[] newReference) { + public ADictionary applyScalarOpWithReference(ScalarOperator op, double[] reference, double[] newReference) { final int nCol = _data.getNumColumns(); final int nRow = _data.getNumRows(); final 
MatrixBlock ret = new MatrixBlock(nRow, nCol, false); @@ -365,44 +377,18 @@ public ADictionary inplaceScalarOp(ScalarOperator op) { throw new NotImplementedException(); } - @Override - public ADictionary applyScalarOp(ScalarOperator op, double newVal, int numCols) { - MatrixBlock res = _data.scalarOperations(op, new MatrixBlock()); - final int lastRow = res.getNumRows(); - MatrixBlock res2 = new MatrixBlock(lastRow + 1, res.getNumColumns(), true); - if(res.isEmpty()) - for(int i = 0; i < numCols; i++) - res2.appendValue(lastRow, i, newVal); - else - res.append(new MatrixBlock(1, numCols, newVal), res2, false); - - if(res2.isEmpty()) - return null; - else - return new MatrixBlockDictionary(res2, _data.getNumColumns()); - } - @Override public ADictionary binOpLeft(BinaryOperator op, double[] v, int[] colIndexes) { throw new NotImplementedException("Binary row op left is not supported for Uncompressed Matrix, " + "Implement support for VMr in MatrixBLock Binary Cell operations"); - // MatrixBlock rowVector = Util.extractValues(v, colIndexes); - // return new MatrixBlockDictionary(rowVector.binaryOperations(op, _data, null)); } @Override - public Dictionary binOpLeft(BinaryOperator op, double[] v, int[] colIndexes, double[] reference, + public Dictionary binOpLeftWithReference(BinaryOperator op, double[] v, int[] colIndexes, double[] reference, double[] newReference) { throw new NotImplementedException(); } - @Override - public ADictionary applyBinaryRowOpLeftAppendNewEntry(BinaryOperator op, double[] v, int[] colIndexes) { - MatrixBlock rowVector = Util.extractValues(v, colIndexes); - MatrixBlock tmp = _data.append(new MatrixBlock(1, _data.getNumColumns(), 0), null, false); - return new MatrixBlockDictionary(rowVector.binaryOperations(op, tmp, null), _data.getNumColumns()); - } - @Override public ADictionary binOpRight(BinaryOperator op, double[] v, int[] colIndexes) { MatrixBlock rowVector = Util.extractValues(v, colIndexes); @@ -410,18 +396,11 @@ public ADictionary 
binOpRight(BinaryOperator op, double[] v, int[] colIndexes) { } @Override - public Dictionary binOpRight(BinaryOperator op, double[] v, int[] colIndexes, double[] reference, + public Dictionary binOpRightWithReference(BinaryOperator op, double[] v, int[] colIndexes, double[] reference, double[] newReference) { throw new NotImplementedException(); } - @Override - public ADictionary applyBinaryRowOpRightAppendNewEntry(BinaryOperator op, double[] v, int[] colIndexes) { - MatrixBlock rowVector = Util.extractValues(v, colIndexes); - MatrixBlock tmp = _data.append(new MatrixBlock(1, _data.getNumColumns(), 0), null, false); - return new MatrixBlockDictionary(tmp.binaryOperations(op, rowVector, null), _data.getNumColumns()); - } - @Override public ADictionary clone() { MatrixBlock ret = new MatrixBlock(); @@ -429,11 +408,6 @@ public ADictionary clone() { return new MatrixBlockDictionary(ret, _data.getNumColumns()); } - @Override - public ADictionary cloneAndExtend(int len) { - throw new NotImplementedException(); - } - @Override public boolean isLossy() { return false; @@ -477,7 +451,12 @@ else if(_data.isInSparseFormat()) { } @Override - public double[] sumAllRowsToDouble(double[] reference){ + public double[] sumAllRowsToDoubleWithDefault(double[] defaultTuple) { + throw new NotImplementedException(); + } + + @Override + public double[] sumAllRowsToDoubleWithReference(double[] reference) { final int nCol = reference.length; final int numVals = _data.getNumRows(); final double[] ret = new double[numVals + 1]; @@ -500,7 +479,7 @@ public double[] sumAllRowsToDouble(double[] reference){ int j = 0; for(; j < _data.getNumColumns() && k < alen; j++) { final double v = aix[k] == j ? 
avals[k++] + reference[j] : reference[j]; - ret[i] += v ; + ret[i] += v; } for(; j < _data.getNumColumns(); j++) ret[i] += reference[j]; @@ -514,7 +493,7 @@ else if(!_data.isEmpty()) { for(int k = 0; k < numVals; k++) { for(int j = 0; j < _data.getNumColumns(); j++) { final double v = values[off++] + reference[j]; - ret[k] += v ; + ret[k] += v; } } } @@ -555,7 +534,12 @@ else if(_data.isInSparseFormat()) { } @Override - public double[] sumAllRowsToDoubleSq(double[] reference) { + public double[] sumAllRowsToDoubleSqWithDefault(double[] defaultTuple) { + throw new NotImplementedException(); + } + + @Override + public double[] sumAllRowsToDoubleSqWithReference(double[] reference) { final int nCol = reference.length; final int numVals = _data.getNumRows(); final double[] ret = new double[numVals + 1]; @@ -611,7 +595,7 @@ public double sumRowSq(int k, int nrColumns) { } @Override - public double sumRowSq(int k, int nrColumns, double[] reference) { + public double sumRowSqWithReference(int k, int nrColumns, double[] reference) { throw new NotImplementedException(); } @@ -622,7 +606,7 @@ public double[] colSum(int[] counts, int nCol) { double[] ret = new double[nCol]; if(_data.isInSparseFormat()) { SparseBlock sb = _data.getSparseBlock(); - for(int i = 0; i < _data.getNumRows(); i++) { + for(int i = 0; i < counts.length; i++) { if(!sb.isEmpty(i)) { // double tmpSum = 0; final int count = counts[i]; @@ -639,7 +623,7 @@ public double[] colSum(int[] counts, int nCol) { else { double[] values = _data.getDenseBlockValues(); int off = 0; - for(int k = 0; k < _data.getNumRows(); k++) { + for(int k = 0; k < counts.length; k++) { final int countK = counts[k]; for(int j = 0; j < _data.getNumColumns(); j++) { final double v = values[off++]; @@ -656,7 +640,7 @@ public void colSum(double[] c, int[] counts, int[] colIndexes) { return; if(_data.isInSparseFormat()) { SparseBlock sb = _data.getSparseBlock(); - for(int i = 0; i < _data.getNumRows(); i++) { + for(int i = 0; i < 
counts.length; i++) { if(!sb.isEmpty(i)) { // double tmpSum = 0; final int count = counts[i]; @@ -673,7 +657,7 @@ public void colSum(double[] c, int[] counts, int[] colIndexes) { else { double[] values = _data.getDenseBlockValues(); int off = 0; - for(int k = 0; k < _data.getNumRows(); k++) { + for(int k = 0; k < counts.length; k++) { final int countK = counts[k]; for(int j = 0; j < _data.getNumColumns(); j++) { final double v = values[off++]; @@ -689,7 +673,7 @@ public void colSumSq(double[] c, int[] counts, int[] colIndexes) { return; if(_data.isInSparseFormat()) { SparseBlock sb = _data.getSparseBlock(); - for(int i = 0; i < _data.getNumRows(); i++) { + for(int i = 0; i < counts.length; i++) { if(!sb.isEmpty(i)) { // double tmpSum = 0; final int count = counts[i]; @@ -706,7 +690,7 @@ public void colSumSq(double[] c, int[] counts, int[] colIndexes) { else { double[] values = _data.getDenseBlockValues(); int off = 0; - for(int k = 0; k < _data.getNumRows(); k++) { + for(int k = 0; k < counts.length; k++) { final int countK = counts[k]; for(int j = 0; j < _data.getNumColumns(); j++) { final double v = values[off++]; @@ -717,13 +701,12 @@ public void colSumSq(double[] c, int[] counts, int[] colIndexes) { } @Override - public void colSumSq(double[] c, int[] counts, int[] colIndexes, double[] reference) { + public void colSumSqWithReference(double[] c, int[] counts, int[] colIndexes, double[] reference) { final int nCol = reference.length; - final int nRow = _data.getNumRows(); - for(int i = 0; i < nCol; i++) - c[colIndexes[i]] += reference[i] * reference[i] * counts[nRow]; - - if(!_data.isEmpty() && _data.isInSparseFormat()) { + final int nRow = counts.length; + if(_data.isEmpty()) + return; + else if(_data.isInSparseFormat()) { final SparseBlock sb = _data.getSparseBlock(); for(int i = 0; i < nRow; i++) { final int countK = counts[i]; @@ -746,7 +729,7 @@ public void colSumSq(double[] c, int[] counts, int[] colIndexes, double[] refere } } } - else 
if(!_data.isEmpty()) { + else { double[] values = _data.getDenseBlockValues(); int off = 0; for(int k = 0; k < nRow; k++) { @@ -766,7 +749,7 @@ public double sum(int[] counts, int ncol) { return tmpSum; if(_data.isInSparseFormat()) { SparseBlock sb = _data.getSparseBlock(); - for(int i = 0; i < _data.getNumRows(); i++) { + for(int i = 0; i < counts.length; i++) { if(!sb.isEmpty(i)) { final int count = counts[i]; final int apos = sb.pos(i); @@ -781,7 +764,7 @@ public double sum(int[] counts, int ncol) { else { double[] values = _data.getDenseBlockValues(); int off = 0; - for(int k = 0; k < _data.getNumRows(); k++) { + for(int k = 0; k < counts.length; k++) { final int countK = counts[k]; for(int j = 0; j < _data.getNumColumns(); j++) { final double v = values[off++]; @@ -797,9 +780,9 @@ public double sumSq(int[] counts, int ncol) { double tmpSum = 0; if(_data.isEmpty()) return tmpSum; - if(_data.isInSparseFormat()) { + else if(_data.isInSparseFormat()) { SparseBlock sb = _data.getSparseBlock(); - for(int i = 0; i < _data.getNumRows(); i++) { + for(int i = 0; i < counts.length; i++) { if(!sb.isEmpty(i)) { final int count = counts[i]; final int apos = sb.pos(i); @@ -814,7 +797,7 @@ public double sumSq(int[] counts, int ncol) { else { double[] values = _data.getDenseBlockValues(); int off = 0; - for(int k = 0; k < _data.getNumRows(); k++) { + for(int k = 0; k < counts.length; k++) { final int countK = counts[k]; for(int j = 0; j < _data.getNumColumns(); j++) { final double v = values[off++]; @@ -826,16 +809,17 @@ public double sumSq(int[] counts, int ncol) { } @Override - public double sumSq(int[] counts, double[] reference) { + public double sumSqWithReference(int[] counts, double[] reference) { + if(_data.isEmpty()) + return 0; final int nCol = reference.length; - final int numVals = _data.getNumRows(); + final int numVals = counts.length; double ret = 0; - for(int i = 0; i < nCol; i++) - ret += reference[i] * reference[i]; - final double ref = ret; - ret *= 
counts[numVals]; - if(!_data.isEmpty() && _data.isInSparseFormat()) { + if(_data.isInSparseFormat()) { + double ref = 0; + for(int i = 0; i < nCol; i++) + ref += reference[i] * reference[i]; final SparseBlock sb = _data.getSparseBlock(); for(int i = 0; i < numVals; i++) { final int countK = counts[i]; @@ -858,7 +842,7 @@ public double sumSq(int[] counts, double[] reference) { } } - else if(!_data.isEmpty()) { + else { double[] values = _data.getDenseBlockValues(); int off = 0; for(int k = 0; k < numVals; k++) { @@ -873,20 +857,10 @@ else if(!_data.isEmpty()) { return ret; } - @Override - public void addMaxAndMin(double[] ret, int[] colIndexes) { - throw new NotImplementedException(); - } - @Override public ADictionary sliceOutColumnRange(int idxStart, int idxEnd, int previousNumberOfColumns) { MatrixBlock retBlock = _data.slice(0, _data.getNumRows() - 1, idxStart, idxEnd - 1); - return new MatrixBlockDictionary(retBlock, idxEnd - idxStart ); - } - - @Override - public ADictionary reExpandColumns(int max) { - throw new NotImplementedException(); + return new MatrixBlockDictionary(retBlock, idxEnd - idxStart); } @Override @@ -895,7 +869,7 @@ public boolean containsValue(double pattern) { } @Override - public boolean containsValue(double pattern, double[] reference) { + public boolean containsValueWithReference(double pattern, double[] reference) { if(_data.isEmpty()) { for(double d : reference) @@ -949,14 +923,14 @@ public long getNumberNonZeros(int[] counts, int nCol) { long nnz = 0; if(_data.isInSparseFormat()) { SparseBlock sb = _data.getSparseBlock(); - for(int i = 0; i < _data.getNumRows(); i++) + for(int i = 0; i < counts.length; i++) if(!sb.isEmpty(i)) nnz += sb.size(i) * counts[i]; } else { double[] values = _data.getDenseBlockValues(); int off = 0; - for(int i = 0; i < _data.getNumRows(); i++) { + for(int i = 0; i < counts.length; i++) { int countThisTuple = 0; for(int j = 0; j < _data.getNumColumns(); j++) { double v = values[off++]; @@ -970,20 +944,14 @@ 
public long getNumberNonZeros(int[] counts, int nCol) { } @Override - public long getNumberNonZeros(int[] counts, double[] reference, int nRows) { + public long getNumberNonZerosWithReference(int[] counts, double[] reference, int nRows) { long nnz = 0; - for(double d : reference) - if(d != 0) - nnz++; - if(_data.isEmpty()) { - // sum counts - return nnz * nRows; - } + if(_data.isEmpty()) + return nnz; else if(_data.isInSparseFormat()) { SparseBlock sb = _data.getSparseBlock(); long emptyRowNNZ = nnz; - nnz *= counts[counts.length - 1]; // multiply count with the common value count in reference. - for(int i = 0; i < _data.getNumRows(); i++) { + for(int i = 0; i < counts.length; i++) { if(sb.isEmpty(i)) nnz += emptyRowNNZ * counts[i]; else { @@ -1013,10 +981,9 @@ else if(_data.isInSparseFormat()) { } } else { - nnz *= counts[counts.length - 1]; // multiply count with the common value count in reference. final double[] values = _data.getDenseBlockValues(); int off = 0; - for(int i = 0; i < _data.getNumRows(); i++) { + for(int i = 0; i < counts.length; i++) { int countThisTuple = 0; for(int j = 0; j < _data.getNumColumns(); j++) if(values[off++] + reference[j] != 0) @@ -1027,6 +994,12 @@ else if(_data.isInSparseFormat()) { return nnz; } + @Override + public void addToEntry(Dictionary d, int fr, int to) { + double[] v = d.getValues(); + v[to] += _data.getDouble(fr, 1); + } + @Override public void addToEntry(Dictionary d, int fr, int to, int nCol) { double[] v = d.getValues(); @@ -1055,57 +1028,14 @@ else if(_data.isInSparseFormat()) { } } - @Override - public double[] getTuple(int index, int nCol) { - if(_data.isEmpty() || index >= _data.getNumRows()) - return null; - - final double[] tuple = new double[nCol]; - if(_data.isInSparseFormat()) { - SparseBlock sb = _data.getSparseBlock(); - if(sb.isEmpty(index)) - return null; - final int apos = sb.pos(index); - final int alen = sb.size(index) + apos; - final int[] aix = sb.indexes(index); - final double[] avals = 
sb.values(index); - for(int j = apos; j < alen; j++) - tuple[aix[j]] = avals[j]; - - return tuple; - } - else { - double[] values = _data.getDenseBlockValues(); - int offset = index * nCol; - for(int i = 0; i < nCol; i++, offset++) - tuple[i] = values[offset]; - return tuple; - } - } - @Override public ADictionary subtractTuple(double[] tuple) { - if(_data.isEmpty()) - throw new NotImplementedException("Should not extract from empty matrix"); - else if(_data.isInSparseFormat()) { - throw new NotImplementedException("Not supporting extracting from sparse matrix yet"); - } - else { - final int nRow = _data.getNumRows() - 1; - final int nCol = _data.getNumColumns(); - double[] values = _data.getDenseBlockValues(); - MatrixBlock res = new MatrixBlock(nRow, nCol, false); - res.allocateBlock(); - double[] resVals = res.getDenseBlockValues(); - for(int i = 0, off = 0; i < nRow; i++) - for(int j = 0; j < nCol; j++, off++) - resVals[off] = values[off] - tuple[j]; - - res.examSparsity(); - if(res.isEmpty()) - return null; - return new MatrixBlockDictionary(res, nCol); - } + MatrixBlock v = new MatrixBlock(1, tuple.length, tuple); + BinaryOperator op = new BinaryOperator(Minus.getMinusFnObject()); + MatrixBlock ret = _data.binaryOperations(op, v, null); + if(ret.isEmpty()) + return null; + return new MatrixBlockDictionary(ret, _data.getNumColumns()); } @Override @@ -1247,7 +1177,7 @@ public ADictionary replace(double pattern, double replace, int nCol) { } @Override - public ADictionary replace(double pattern, double replace, double[] reference) { + public ADictionary replaceWithReference(double pattern, double replace, double[] reference) { final int nRow = _data.getNumRows(); final int nCol = _data.getNumColumns(); final MatrixBlock ret = new MatrixBlock(nRow, nCol, false); @@ -1294,50 +1224,6 @@ public ADictionary replace(double pattern, double replace, double[] reference) { } - @Override - public ADictionary replaceZeroAndExtend(double replace, int nCol) { - final int 
nRows = _data.getNumRows(); - final int nCols = _data.getNumColumns(); - final long nonZerosOut = (nRows + 1) * nCols; - final MatrixBlock ret = new MatrixBlock(nRows + 1, nCols, false); - ret.allocateBlock(); - ret.setNonZeros(nonZerosOut); - final double[] retValues = ret.getDenseBlockValues(); - if(_data.isEmpty()) - Arrays.fill(retValues, replace); - else if(_data.isInSparseFormat()) { - final SparseBlock sb = _data.getSparseBlock(); - for(int i = 0; i < nRows; i++) { - for(int h = i * nCols; h < i * nCols + nCols; h++) - retValues[h] = replace; - if(sb.isEmpty(i)) - continue; - final int off = nCol * i; - final int apos = sb.pos(i); - final int alen = sb.size(i) + apos; - final double[] avals = sb.values(i); - final int[] aix = sb.indexes(i); - for(int j = apos; j < alen; j++) { - final int idb = aix[j]; - final double v = avals[j]; - retValues[off + idb] = v; - } - } - for(int h = nRows * nCols; h < nonZerosOut; h++) - retValues[h] = replace; - } - else { - final double[] values = _data.getDenseBlockValues(); - for(int k = 0; k < nRows; k++) - for(int h = k * nCols; h < k * nCols + nCols; h++) - retValues[h] = values[h] == 0 ? replace : values[h]; - - for(int h = nRows * nCols; h < nonZerosOut; h++) - retValues[h] = replace; - } - return new MatrixBlockDictionary(ret, _data.getNumColumns()); - } - @Override public double product(int[] counts, int nCol) { throw new NotImplementedException(); @@ -1347,4 +1233,41 @@ public double product(int[] counts, int nCol) { public void colProduct(double[] res, int[] counts, int[] colIndexes) { throw new NotImplementedException(); } + + @Override + public CM_COV_Object centralMoment(CM_COV_Object ret, ValueFunction fn, int[] counts, int nRows) { + // should be guaranteed to only contain one value per tuple in dictionary. 
+ if(_data.isInSparseFormat()) + throw new DMLCompressionException("The dictionary should not be sparse with one column"); + double[] vals = _data.getDenseBlockValues(); + for(int i = 0; i < vals.length; i++) + fn.execute(ret, vals[i], counts[i]); + return ret; + } + + @Override + public CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction fn, int[] counts, double reference, + int nRows) { + // should be guaranteed to only contain one value per tuple in dictionary. + if(_data.isInSparseFormat()) + throw new DMLCompressionException("The dictionary should not be sparse with one column"); + double[] vals = _data.getDenseBlockValues(); + for(int i = 0; i < vals.length; i++) + fn.execute(ret, vals[i] + reference, counts[i]); + return ret; + } + + @Override + public ADictionary rexpandCols(int max, boolean ignore, boolean cast, int nCol) { + MatrixBlock ex = LibMatrixReorg.rexpand(_data, new MatrixBlock(), max, false, cast, ignore, 1); + if(ex.isEmpty()) + return null; + else + return new MatrixBlockDictionary(ex, max); + } + + @Override + public ADictionary rexpandColsWithReference(int max, boolean ignore, boolean cast, double reference) { + return applyScalarOp(new LeftScalarOperator(Plus.getPlusFnObject(), reference)).rexpandCols(max, ignore, cast, 1); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java index 32ed014717a..7199903ac13 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java @@ -22,13 +22,14 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; -import java.util.Arrays; import org.apache.commons.lang.NotImplementedException; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.Divide; 
import org.apache.sysds.runtime.functionobjects.Multiply; import org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.functionobjects.ValueFunction; +import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; import org.apache.sysds.utils.MemoryEstimates; @@ -101,7 +102,7 @@ public double aggregate(double init, Builtin fn) { } @Override - public double aggregate(double init, Builtin fn, double[] reference) { + public double aggregateWithReference(double init, Builtin fn, double[] reference) { throw new NotImplementedException(); } @@ -121,7 +122,12 @@ public double[] aggregateRows(Builtin fn, final int nCol) { } @Override - public double[] aggregateRows(Builtin fn, double[] reference) { + public double[] aggregateRowsWithDefault(Builtin fn, double[] defaultTuple) { + throw new NotImplementedException(); + } + + @Override + public double[] aggregateRowsWithReference(Builtin fn, double[] reference) { throw new NotImplementedException(); } @@ -169,26 +175,6 @@ public QDictionary applyScalarOp(ScalarOperator op) { throw new NotImplementedException(); } - @Override - public QDictionary applyScalarOp(ScalarOperator op, double newVal, int numCols) { - double[] temp = getValues(); - double max = Math.abs(newVal); - for(int i = 0; i < size(); i++) { - temp[i] = op.executeScalar(temp[i]); - double absTemp = Math.abs(temp[i]); - if(absTemp > max) { - max = absTemp; - } - } - double scale = max / (double) (Byte.MAX_VALUE); - byte[] res = new byte[size() + numCols]; - for(int i = 0; i < size(); i++) { - res[i] = (byte) Math.round(temp[i] / scale); - } - Arrays.fill(res, size(), size() + numCols, (byte) Math.round(newVal / scale)); - return new QDictionary(res, scale); - } - private int size() { return _values.length; } @@ -198,12 +184,6 @@ public QDictionary clone() { return new QDictionary(_values.clone(), _scale); } - 
@Override - public QDictionary cloneAndExtend(int len) { - byte[] ret = Arrays.copyOf(_values, _values.length + len); - return new QDictionary(ret, _scale); - } - @Override public void write(DataOutput out) throws IOException { out.writeByte(DictionaryFactory.Type.INT8_DICT.ordinal()); @@ -246,6 +226,16 @@ public double[] sumAllRowsToDouble(int nrColumns) { return ret; } + @Override + public double[] sumAllRowsToDoubleWithDefault(double[] defaultTuple) { + throw new NotImplementedException(); + } + + @Override + public double[] sumAllRowsToDoubleWithReference(double[] reference) { + throw new NotImplementedException(); + } + @Override public double[] sumAllRowsToDoubleSq(int nrColumns) { final int numVals = getNumberOfValues(nrColumns); @@ -256,7 +246,12 @@ public double[] sumAllRowsToDoubleSq(int nrColumns) { } @Override - public double[] sumAllRowsToDoubleSq(double[] reference) { + public double[] sumAllRowsToDoubleSqWithDefault(double[] defaultTuple) { + throw new NotImplementedException(); + } + + @Override + public double[] sumAllRowsToDoubleSqWithReference(double[] reference) { throw new NotImplementedException(); } @@ -286,7 +281,7 @@ public double sumRowSq(int k, int nrColumns) { } @Override - public double sumRowSq(int k, int nrColumns, double[] reference) { + public double sumRowSqWithReference(int k, int nrColumns, double[] reference) { throw new NotImplementedException(); } @@ -306,7 +301,7 @@ public void colSumSq(double[] c, int[] counts, int[] colIndexes) { } @Override - public void colSumSq(double[] c, int[] counts, int[] colIndexes, double[] reference) { + public void colSumSqWithReference(double[] c, int[] counts, int[] colIndexes, double[] reference) { throw new NotImplementedException(); } @@ -321,30 +316,10 @@ public double sumSq(int[] counts, int ncol) { } @Override - public double sumSq(int[] counts, double[] reference) { + public double sumSqWithReference(int[] counts, double[] reference) { throw new NotImplementedException("Not 
Implemented"); } - @Override - public void addMaxAndMin(double[] ret, int[] colIndexes) { - byte[] mins = new byte[colIndexes.length]; - byte[] maxs = new byte[colIndexes.length]; - for(int i = 0; i < colIndexes.length; i++) { - mins[i] = _values[i]; - maxs[i] = _values[i]; - } - for(int i = colIndexes.length; i < _values.length; i++) { - int idx = i % colIndexes.length; - mins[idx] = (byte) Math.min(_values[i], mins[idx]); - maxs[idx] = (byte) Math.max(_values[i], maxs[idx]); - } - for(int i = 0; i < colIndexes.length; i++) { - int idy = colIndexes[i] * 2; - ret[idy] += mins[i] * _scale; - ret[idy + 1] += maxs[i] * _scale; - } - } - public String getString(int colIndexes) { StringBuilder sb = new StringBuilder(); for(int i = 0; i < size(); i++) { @@ -374,17 +349,6 @@ public ADictionary sliceOutColumnRange(int idxStart, int idxEnd, int previousNum return new QDictionary(newDictValues, _scale); } - public ADictionary reExpandColumns(int max) { - byte[] newDictValues = new byte[_values.length * max]; - - for(int i = 0, offset = 0; i < _values.length; i++, offset += max) { - int val = _values[i] - 1; - newDictValues[offset + val] = 1; - } - - return new QDictionary(newDictValues, 1.0); - } - @Override public boolean containsValue(double pattern) { if(Double.isNaN(pattern) || Double.isInfinite(pattern)) @@ -393,7 +357,7 @@ public boolean containsValue(double pattern) { } @Override - public boolean containsValue(double pattern, double[] reference) { + public boolean containsValueWithReference(double pattern, double[] reference) { throw new NotImplementedException(); } @@ -414,23 +378,23 @@ public long getNumberNonZeros(int[] counts, int nCol) { } @Override - public long getNumberNonZeros(int[] counts, double[] reference, int nRows) { + public long getNumberNonZerosWithReference(int[] counts, double[] reference, int nRows) { throw new NotImplementedException("not implemented yet"); } @Override - public void addToEntry(Dictionary d, int fr, int to, int nCol) { + public 
void addToEntry(Dictionary d, int fr, int to) { throw new NotImplementedException("Not implemented yet"); } @Override - public boolean isLossy() { - return false; + public void addToEntry(Dictionary d, int fr, int to, int nCol) { + throw new NotImplementedException("Not implemented yet"); } @Override - public double[] getTuple(int index, int nCol) { - return null; + public boolean isLossy() { + return false; } @Override @@ -449,7 +413,7 @@ public void aggregateCols(double[] c, Builtin fn, int[] colIndexes) { } @Override - public void aggregateCols(double[] c, Builtin fn, int[] colIndexes, double[] reference) { + public void aggregateColsWithReference(double[] c, Builtin fn, int[] colIndexes, double[] reference) { throw new NotImplementedException(); } @@ -470,64 +434,72 @@ public ADictionary replace(double pattern, double replace, int nCol) { } @Override - public ADictionary replace(double pattern, double replace, double[] reference) { + public ADictionary replaceWithReference(double pattern, double replace, double[] reference) { throw new NotImplementedException(); } @Override - public ADictionary replaceZeroAndExtend(double replace, int nCol) { + public double product(int[] counts, int nCol) { throw new NotImplementedException(); } @Override - public double product(int[] counts, int nCol) { + public void colProduct(double[] res, int[] counts, int[] colIndexes) { throw new NotImplementedException(); } @Override - public void colProduct(double[] res, int[] counts, int[] colIndexes) { + public ADictionary binOpLeft(BinaryOperator op, double[] v, int[] colIndexes) { throw new NotImplementedException(); } @Override - public ADictionary applyBinaryRowOpLeftAppendNewEntry(BinaryOperator op, double[] v, int[] colIndexes) { + public ADictionary binOpRight(BinaryOperator op, double[] v, int[] colIndexes) { throw new NotImplementedException(); } @Override - public ADictionary binOpLeft(BinaryOperator op, double[] v, int[] colIndexes) { + public ADictionary 
applyScalarOpWithReference(ScalarOperator op, double[] reference, double[] newReference) { throw new NotImplementedException(); } @Override - public ADictionary binOpRight(BinaryOperator op, double[] v, int[] colIndexes) { + public ADictionary binOpLeftWithReference(BinaryOperator op, double[] v, int[] colIndexes, double[] reference, + double[] newReference) { throw new NotImplementedException(); } @Override - public ADictionary applyBinaryRowOpRightAppendNewEntry(BinaryOperator op, double[] v, int[] colIndexes) { + public ADictionary binOpRightWithReference(BinaryOperator op, double[] v, int[] colIndexes, double[] reference, + double[] newReference) { throw new NotImplementedException(); } @Override - public ADictionary applyScalarOp(ScalarOperator op, double[] reference, double[] newReference) { + public CM_COV_Object centralMoment(CM_COV_Object ret, ValueFunction fn, int[] counts, int nRows) { throw new NotImplementedException(); } @Override - public ADictionary binOpLeft(BinaryOperator op, double[] v, int[] colIndexes, double[] reference, - double[] newReference) { + public CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction fn, int[] counts, double reference, + int nRows) { throw new NotImplementedException(); } @Override - public ADictionary binOpRight(BinaryOperator op, double[] v, int[] colIndexes, double[] reference, - double[] newReference) { + public ADictionary rexpandCols(int max, boolean ignore, boolean cast, int nCol) { throw new NotImplementedException(); + // byte[] newDictValues = new byte[_values.length * max]; + // for(int i = 0, offset = 0; i < _values.length; i++, offset += max) { + // int val = _values[i] - 1; + // newDictValues[offset + val] = 1; + // } + + // return new QDictionary(newDictValues, 1.0); } @Override - public double[] sumAllRowsToDouble(double[] reference) { + public ADictionary rexpandColsWithReference(int max, boolean ignore, boolean cast, double reference) { throw new NotImplementedException(); } } 
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/insertionsort/MaterializeSort.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/insertionsort/MaterializeSort.java index e3d533d130c..f37c03aab35 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/insertionsort/MaterializeSort.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/insertionsort/MaterializeSort.java @@ -19,44 +19,58 @@ package org.apache.sysds.runtime.compress.colgroup.insertionsort; +import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; import org.apache.sysds.runtime.compress.utils.IntArrayList; public class MaterializeSort extends AInsertionSorter { - public static int CACHE_BLOCK = 50000; + + /** The block size to materialize at a time */ + public static int CACHE_BLOCK = 16000; /** a dense mapToData, that have a value for each row in the input. */ private final AMapToData md; private final int[] skip; + + private final int placeholder; private int off = 0; protected MaterializeSort(int endLength, int numRows, IntArrayList[] offsets) { super(endLength, numRows, offsets); - - md = MapToFactory.create(Math.min(_numRows, CACHE_BLOCK), Math.max(_numLabels, 3)); + placeholder = _numLabels + 1; + // + 1 to ensure that the _numLabels is exceeded. 
+ md = MapToFactory.create(Math.min(_numRows, CACHE_BLOCK), Math.max(placeholder, 3)); skip = new int[offsets.length]; - for(int block = 0; block < _numRows; block += CACHE_BLOCK) { - md.fill(_numLabels); + for(int block = 0; block < _numRows; block += CACHE_BLOCK) insert(block, Math.min(block + CACHE_BLOCK, _numRows)); - } + } protected MaterializeSort(int endLength, int numRows, IntArrayList[] offsets, int negativeIndex) { super(endLength, numRows, offsets, negativeIndex); - md = MapToFactory.create(Math.min(_numRows, CACHE_BLOCK), Math.max(_numLabels, 3)); + placeholder = _numLabels; + md = MapToFactory.create(Math.min(_numRows, CACHE_BLOCK), Math.max(placeholder, 3)); skip = new int[offsets.length]; - for(int block = 0; block < _numRows; block += CACHE_BLOCK) { - md.fill(_numLabels); + for(int block = 0; block < _numRows; block += CACHE_BLOCK) insertWithNegative(block, Math.min(block + CACHE_BLOCK, _numRows)); - } + } private void insert(int rl, int ru) { - materializeInsert(rl, ru); - filterInsert(rl, ru); + try { + md.fill(placeholder); + materializeInsert(rl, ru); + filterInsert(rl, ru); + } + catch(Exception e) { + int sum = 0; + for(IntArrayList o : _offsets) + sum += o.size(); + throw new DMLCompressionException("Failed normal materialize sorting with list of " + _offsets.length + " with sum (aka output size): " + sum + " requested Size: " + _indexes.length + " range: " + rl + " " + ru , e); + } } private void materializeInsert(int rl, int ru) { @@ -71,14 +85,17 @@ private void materializeInsert(int rl, int ru) { } private void filterInsert(int rl, int ru) { - for(int i = rl; i < ru; i++) { - final int idx = md.getIndex(i - rl); - if(idx != _numLabels) - set(off++, i, idx); + final int len = ru - rl; + for(int i = 0; i < len; i++) { + final int idx = md.getIndex(i); + if(idx != placeholder) + set(off++, i + rl, idx); } } private void insertWithNegative(int rl, int ru) { + md.fill(placeholder); + for(int i = 0; i < _offsets.length; i++) { IntArrayList of = 
_offsets[i]; int k = skip[i]; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java index a6d048bb2e6..e804abcd2a9 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java @@ -25,6 +25,8 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; @@ -219,20 +221,48 @@ public abstract void preAggregateDense(MatrixBlock m, double[] preAV, int rl, in public abstract void preAggregateSparse(SparseBlock sb, double[] preAV, int rl, int ru, AOffset indexes); /** - * Get the number of counts of each unique value contained in this map. + * Get the number of counts of each unique value contained in this map. Note that in the case the mapping is shorter + * than number of rows the counts sum to the number of mapped values not the number of rows. * * @param counts The object to return. - * @param nRows The number of rows in the calling column group. * @return the Counts */ - public int[] getCounts(int[] counts, int nRows) { - final int nonDefaultLength = size(); - for(int i = 0; i < nonDefaultLength; i++) - counts[getIndex(i)]++; - counts[counts.length - 1] += nRows - nonDefaultLength; - return counts; + public abstract int[] getCounts(int[] counts); + + /** + * PreAggregate into dictionary. 
+ * + * @param tm Map of other side + * @param td Dictionary to take values from (other side dictionary) + * @param ret The output dictionary to aggregate into + * @param nCol The number of columns + */ + public final void preAggregateDDC(AMapToData tm, ADictionary td, Dictionary ret, int nCol) { + if(nCol == 1) + preAggregateDDCSingleCol(tm, td, ret); + else + preAggregateDDCMultiCol(tm, td, ret, nCol); } + /** + * PreAggregate into dictionary guaranteed to only have single-column tuples. + * + * @param tm Map of other side + * @param td Dictionary to take values from (other side dictionary) + * @param ret The output dictionary to aggregate into + */ + protected abstract void preAggregateDDCSingleCol(AMapToData tm, ADictionary td, Dictionary ret); + + /** + * PreAggregate into dictionary guaranteed to have multi-column tuples. + * + * @param tm Map of other side + * @param td Dictionary to take values from (other side dictionary) + * @param ret The output dictionary to aggregate into + * @param nCol The number of columns + */ + protected abstract void preAggregateDDCMultiCol(AMapToData tm, ADictionary td, Dictionary ret, int nCol); + /** * Copy the values in this map into another mapping object. 
* diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.java index be742f55aa4..74145fedc26 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.java @@ -25,6 +25,8 @@ import java.util.BitSet; import org.apache.commons.lang.NotImplementedException; +import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory.MAP_TYPE; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.data.DenseBlock; @@ -163,4 +165,113 @@ public int getUpperBoundValue() { return 1; } + @Override + public int[] getCounts(int[] counts) { + final int sz = size(); + + if(counts.length == 1) + counts[0] = sz; + else { + counts[1] = _data.cardinality(); + counts[0] = sz - counts[1]; + } + + return counts; + } + + @Override + public void preAggregateDDCSingleCol(AMapToData tm, ADictionary td, Dictionary ret) { + if(tm instanceof MapToBit) + preAggregateDDCSingleColBitBit((MapToBit) tm, td, ret); + else { + final int nRows = size(); + for(int r = 0; r < nRows; r++) + td.addToEntry(ret, tm.getIndex(r), getIndex(r)); + } + } + + private void preAggregateDDCSingleColBitBit(MapToBit tmb, ADictionary td, Dictionary ret) { + + JoinBitSets j = new JoinBitSets(tmb._data, _data, _size); + + final double[] tv = td.getValues(); + final double[] rv = ret.getValues(); + + // multiply and scale with actual values + rv[1] += tv[1] * j.tt; + rv[0] += tv[1] * j.ft; + rv[1] += tv[0] * j.tf; + rv[0] += tv[0] * j.ff; + } + + @Override + public void preAggregateDDCMultiCol(AMapToData tm, ADictionary td, Dictionary ret, int nCol) { + if(tm instanceof MapToBit) + 
preAggregateDDCMultiColBitBit((MapToBit) tm, td, ret, nCol); + else { + final int nRows = size(); + for(int r = 0; r < nRows; r++) + td.addToEntry(ret, tm.getIndex(r), getIndex(r), nCol); + } + } + + private void preAggregateDDCMultiColBitBit(MapToBit tmb, ADictionary td, Dictionary ret, int nCol) { + + JoinBitSets j = new JoinBitSets(tmb._data, _data, _size); + + final double[] tv = td.getValues(); + final double[] rv = ret.getValues(); + + // multiply and scale with actual values + for(int i = 0; i < nCol; i++) { + final int off = nCol + i; + rv[i] += tv[i] * j.ff; + rv[off] += tv[i] * j.tf; + rv[off] += tv[off] * j.tt; + rv[i] += tv[off] * j.ft; + } + } + + private static class JoinBitSets { + int tt = 0; + int ft = 0; + int tf = 0; + int ff = 0; + + protected JoinBitSets(BitSet t_data, BitSet o_data, int size) { + + // This naively rely on JDK implementation using long arrays to encode bit Arrays. + final long[] t_longs = t_data.toLongArray(); + final long[] _longs = o_data.toLongArray(); + + final int common = Math.min(t_longs.length, _longs.length); + + for(int i = 0; i < common; i++) { + long t = t_longs[i]; + long v = _longs[i]; + tt += Long.bitCount(t & v); + ft += Long.bitCount(t & ~v); + tf += Long.bitCount(~t & v); + ff += Long.bitCount(~t & ~v); + } + + if(t_longs.length > common) { + for(int i = common; i < t_longs.length; i++) { + int v = Long.bitCount(t_longs[i]); + ft += v; + ff += 64 - v; + } + } + else if(_longs.length > common) { + for(int i = common; i < _longs.length; i++) { + int v = Long.bitCount(_longs[i]); + tf += v; + ff += 64 - v; + } + } + + final int longest = Math.max(t_longs.length, _longs.length); + ff += size - (longest * 64); // remainder + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java index 993cd32070c..0894c866cea 100644 --- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java @@ -25,6 +25,8 @@ import java.util.Arrays; import org.apache.commons.lang.NotImplementedException; +import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory.MAP_TYPE; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.data.DenseBlock; @@ -225,4 +227,32 @@ public void preAggregateSparse(SparseBlock sb, double[] preAV, int rl, int ru, A public int getUpperBoundValue() { return 255; } + + @Override + public int[] getCounts(int[] counts){ + final int sz = size(); + if(getUnique() < 127){ + for(int i = 0; i < sz; i++) + counts[_data[i]]++; + } + else{ + for(int i = 0; i < sz; i++) + counts[getIndex(i)]++; + } + return counts; + } + + @Override + public void preAggregateDDCSingleCol(AMapToData tm, ADictionary td, Dictionary ret) { + final int nRows = size(); + for(int r = 0; r < nRows; r++) + td.addToEntry(ret, tm.getIndex(r), getIndex(r)); + } + + @Override + public void preAggregateDDCMultiCol(AMapToData tm, ADictionary td, Dictionary ret, int nCol) { + final int nRows = size(); + for(int r = 0; r < nRows; r++) + td.addToEntry(ret, tm.getIndex(r), getIndex(r), nCol); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java index 1c8c8e5d027..38c274eb451 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java @@ -25,6 +25,8 @@ import java.util.Arrays; import org.apache.commons.lang.NotImplementedException; +import 
org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory.MAP_TYPE; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.data.DenseBlock; @@ -180,4 +182,25 @@ public int getUpperBoundValue() { return Character.MAX_VALUE; } + @Override + public int[] getCounts(int[] counts){ + final int sz = size(); + for(int i = 0; i < sz; i++) + counts[_data[i]]++; + return counts; + } + + @Override + public void preAggregateDDCSingleCol(AMapToData tm, ADictionary td, Dictionary ret) { + final int nRows = size(); + for(int r = 0; r < nRows; r++) + td.addToEntry(ret, tm.getIndex(r), getIndex(r)); + } + + @Override + public void preAggregateDDCMultiCol(AMapToData tm, ADictionary td, Dictionary ret, int nCol) { + final int nRows = size(); + for(int r = 0; r < nRows; r++) + td.addToEntry(ret, tm.getIndex(r), getIndex(r), nCol); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java index 9712d112efa..8f113e2e9f9 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java @@ -25,6 +25,8 @@ import java.util.Arrays; import org.apache.commons.lang.NotImplementedException; +import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory.MAP_TYPE; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.data.DenseBlock; @@ -151,4 +153,26 @@ public void preAggregateSparse(SparseBlock sb, double[] preAV, int rl, int ru, A public int getUpperBoundValue() { return 
Integer.MAX_VALUE; } + + @Override + public int[] getCounts(int[] counts){ + final int sz = size(); + for(int i = 0; i < sz; i++) + counts[_data[i]]++; + return counts; + } + + @Override + public void preAggregateDDCSingleCol(AMapToData tm, ADictionary td, Dictionary ret) { + final int nRows = size(); + for(int r = 0; r < nRows; r++) + td.addToEntry(ret, tm.getIndex(r), getIndex(r)); + } + + @Override + public void preAggregateDDCMultiCol(AMapToData tm, ADictionary td, Dictionary ret, int nCol) { + final int nRows = size(); + for(int r = 0; r < nRows; r++) + td.addToEntry(ret, tm.getIndex(r), getIndex(r), nCol); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimator.java b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimator.java index 109d8b37057..e47bb49ffdb 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimator.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimator.java @@ -137,16 +137,6 @@ private CompressedSizeInfoColGroup[] estimateIndividualColumnGroupSizes(int k) { } - /** - * Method used for compressing into one type of colGroup - * - * @return CompressedSizeInfo on a compressed colGroup compressing the entire matrix into a single colGroup type. 
- */ - public CompressedSizeInfoColGroup estimateCompressedColGroupSize() { - int[] colIndexes = makeColIndexes(); - return estimateCompressedColGroupSize(colIndexes); - } - /** * Method for extracting Compressed Size Info of specified columns, together in a single ColGroup * @@ -300,8 +290,4 @@ public CompressedSizeInfoColGroup call() { return _estimator.estimateCompressedColGroupSize(_cols); } } - - private int[] makeColIndexes() { - return Util.genColsIndices(getNumColumns()); - } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorSample.java b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorSample.java index 07bd84485c8..5bf2838d684 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorSample.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorSample.java @@ -25,6 +25,7 @@ import org.apache.commons.lang.NotImplementedException; import org.apache.sysds.runtime.compress.CompressionSettings; +import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.bitmap.ABitmap; import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; @@ -119,16 +120,25 @@ protected CompressedSizeInfoColGroup estimateJoinCompressedSize(int[] joined, Co CompressedSizeInfoColGroup g2, int joinedMaxDistinct) { if((long) g1.getNumVals() * g2.getNumVals() > (long) Integer.MAX_VALUE) return null; + try { - final IEncode map = g1.getMap().join(g2.getMap()); - final EstimationFactors sampleFacts = map.computeSizeEstimation(joined, _sampleSize, _data.getSparsity(), - _data.getSparsity()); - // EstimationFactors.computeSizeEstimation(joined, map, - // _cs.validCompressions.contains(CompressionType.RLE), map.size(), false); + final IEncode map = g1.getMap().join(g2.getMap()); + final EstimationFactors sampleFacts = 
map.computeSizeEstimation(joined, _sampleSize, _data.getSparsity(), + _data.getSparsity()); + try { + final EstimationFactors em = estimateCompressionFactors(sampleFacts, map, joined, joinedMaxDistinct); + return new CompressedSizeInfoColGroup(joined, em, _cs.validCompressions, map); + } + catch(Exception e) { + throw new DMLCompressionException("failed to scale Estimation factors with :\n" + map + "\n" + sampleFacts, + e); + } + } + catch(Exception e) { + throw new DMLCompressionException("failed to join compression estimation groups", e); + } // result facts - final EstimationFactors em = estimateCompressionFactors(sampleFacts, map, joined, joinedMaxDistinct); - return new CompressedSizeInfoColGroup(joined, em, _cs.validCompressions, map); } private EstimationFactors estimateCompressionFactors(EstimationFactors sampleFacts, IEncode map, int[] colIndexes, @@ -171,7 +181,7 @@ private int calculateOffs(EstimationFactors sampleFacts, int sampleSize, int num final int numCols = getNumColumns(); if(numCols == 1) return (int) _data.getNonZeros(); - else + else return numRows - (int) Math.floor(numZerosInSample * scalingFactor); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/ConstEncoding.java b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/ConstEncoding.java index 3f183507a43..dcb2a02df26 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/ConstEncoding.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/ConstEncoding.java @@ -62,5 +62,4 @@ public EstimationFactors computeSizeEstimation(int[] cols, int nRows, double tup return new EstimationFactors(cols.length, 1, nRows, nRows, counts, 0, 0, nRows, false, false, matrixSparsity, tupleSparsity); } - } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java index 45d28791a6c..5693609cbd2 100644 --- 
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java @@ -35,23 +35,16 @@ public class DenseEncoding implements IEncode { protected DenseEncoding(AMapToData map, int[] counts) { this.map = map; this.counts = counts; - if(map.getUnique() == 0) - throw new DMLCompressionException("Invalid Dense Encoding"); - } - - /** - * Protected constructor that also counts the frequencies of the values. - * - * @param map The Map. - */ - protected DenseEncoding(AMapToData map) { - this.map = map; - final int nUnique = map.getUnique(); - if(nUnique == 0) - throw new DMLCompressionException("Invalid Dense Encoding"); - this.counts = new int[nUnique]; - for(int i = 0; i < map.size(); i++) - counts[map.getIndex(i)]++; + // for debugging correctness and efficiency but should be guaranteed by implementations creating the Dense encoding: + // if(map.getUnique() == 0) + // throw new DMLCompressionException("Invalid Dense Encoding"); + // if(map.getUnique() != counts.length) + // throw new DMLCompressionException( + // "Invalid number of counts not matching map:" + map.getUnique() + " " + counts.length); + // int u = map.getUnique(); + // for(int i = 0; i < map.size(); i++) + // if(map.getIndex(i) >= u) + // throw new DMLCompressionException("Invalid values contained in map:" + map.getUnique() + " " + map); } @Override @@ -78,53 +71,60 @@ protected DenseEncoding joinSparse(SparseEncoding e) { final int[] m = new int[(int) maxUnique]; final AMapToData d = MapToFactory.create(nRows, (int) maxUnique); + // iterate through indexes that are in the sparse encoding final AIterator itr = e.off.getIterator(); final int fr = e.off.getOffsetToLast(); int newUID = 1; int r = 0; - for(; r < fr; r++) { + for(; r <= fr; r++) { final int ir = itr.value(); if(ir == r) { - - final int nv = map.getIndex(r) + e.map.getIndex(itr.getDataIndex()) * nVl; - itr.next(); + final int nv = 
map.getIndex(ir) + e.map.getIndex(itr.getDataIndex()) * nVl; newUID = addVal(nv, r, m, newUID, d); + if(ir >= fr) { + r++; + break; + } + else { + itr.next(); + } } else { final int nv = map.getIndex(r) + defR; newUID = addVal(nv, r, m, newUID, d); } } - // add last offset - newUID = addVal(map.getIndex(r) + e.map.getIndex(itr.getDataIndex()) * nVl, r++, m, newUID, d); - // add remaining rows. for(; r < nRows; r++) { final int nv = map.getIndex(r) + defR; newUID = addVal(nv, r, m, newUID, d); } // set unique. - d.setUnique(newUID - 1); - return new DenseEncoding(d); + d.setUnique(newUID-1); + return joinDenseCount(d); } - protected static int addVal(int nv, int r, int[] m, int newUID, AMapToData d) { - final int mapV = m[nv]; - if(mapV == 0) { - d.set(r, newUID - 1); - m[nv] = newUID++; - } + private static int addVal(int nv, int r, int[] m, int newId, AMapToData d) { + if(m[nv] == 0) + d.set(r, (m[nv] = newId++) - 1); else - d.set(r, mapV - 1); - return newUID; + d.set(r, m[nv] - 1); + return newId; } protected DenseEncoding joinDense(DenseEncoding e) { if(map == e.map) return this; // unlikely to happen but cheap to compute - return new DenseEncoding(MapToFactory.join(map, e.map)); + final AMapToData d = MapToFactory.join(map, e.map); + return joinDenseCount(d); + } + + protected static DenseEncoding joinDenseCount(AMapToData d) { + int[] counts = new int[d.getUnique()]; + d.getCounts(counts); + return new DenseEncoding(d, counts); } @Override @@ -152,7 +152,6 @@ public EstimationFactors computeSizeEstimation(int[] cols, int nRows, double tup return new EstimationFactors(cols.length, counts.length, nRows, largestOffs, counts, 0, 0, nRows, false, false, matrixSparsity, tupleSparsity); - } @Override @@ -167,5 +166,4 @@ public String toString() { sb.append(Arrays.toString(counts)); return sb.toString(); } - } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EmptyEncoding.java 
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EmptyEncoding.java index 6fa7863a336..1aa41d47884 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EmptyEncoding.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EmptyEncoding.java @@ -29,7 +29,6 @@ public class EmptyEncoding implements IEncode { // empty constructor protected EmptyEncoding() { - } @Override @@ -63,5 +62,4 @@ public String toString() { public EstimationFactors computeSizeEstimation(int[] cols, int nRows, double tupleSparsity, double matrixSparsity) { return new EstimationFactors(cols.length, 0, 0, nRows, counts, 0, 0, nRows, false, true, 0, 0); } - } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/IEncode.java b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/IEncode.java index dfcda0c208e..a7eba644206 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/IEncode.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/IEncode.java @@ -158,7 +158,6 @@ public static IEncode createFromSparseTransposed(MatrixBlock m, int row) { AOffset o = OffsetFactory.createOffset(sb.indexes(row), apos, alen); final int zero = m.getNumColumns() - o.getSize(); - return new SparseEncoding(d, o, zero, counts, m.getNumColumns()); } @@ -208,7 +207,6 @@ public static IEncode createFromDense(MatrixBlock m, int col) { // Iteration 2, make final map for(int i = off, r = 0; i < end; i += nCol, r++) d.set(r, map.get(vals[i])); - return new DenseEncoding(d, counts); } } @@ -263,7 +261,6 @@ public static IEncode createFromSparse(MatrixBlock m, int col) { AOffset o = OffsetFactory.createOffset(offsets); final int zero = m.getNumRows() - sumCounts; - return new SparseEncoding(d, o, zero, counts, m.getNumRows()); } @@ -306,6 +303,7 @@ public static IEncode createWithReader(MatrixBlock m, int[] rowCols, boolean tra else { // TODO add Common group, that allows to allocate with one of the 
map entries as the common value. // the input was fully dense. + final int[] counts = map.getUnorderedCountsAndReplaceWithUIDs(); return createWithReaderDense(m, map, counts, rowCols, nRows, transposed); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java index 51520b37d39..8d0bea2819d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java @@ -30,11 +30,16 @@ import org.apache.sysds.runtime.compress.estim.EstimationFactors; import org.apache.sysds.runtime.compress.utils.IntArrayList; -/** Most common is zero */ +/** Most common is zero encoding */ public class SparseEncoding implements IEncode { + /** A map to the distinct values contained */ protected final AMapToData map; + + /** A Offset index structure to indicate space of zero values */ protected final AOffset off; + + /** Total number of rows encoded */ protected final int nRows; /** Count of Zero tuples in this encoding */ @@ -49,13 +54,6 @@ protected SparseEncoding(AMapToData map, AOffset off, int zeroCount, int[] count this.zeroCount = zeroCount; this.counts = counts; this.nRows = nRows; - - // final int u = getUnique(); - // for(int i = 0; i < map.size();i ++){ - // if(map.getIndex(i) > u){ - // throw new DMLCompressionException("Invalid allocation"); - // } - // } } @Override @@ -90,10 +88,32 @@ protected IEncode joinSparse(SparseEncoding e) { final int nVl = getUnique(); final int nVr = e.getUnique(); + + final int unique = joinSparse(map, e.map, itl, itr, retOff, tmpVals, fl, fr, nVl, nVr, d); + + if(retOff.size() < nRows * 0.3) { + final AOffset o = OffsetFactory.createOffset(retOff); + final AMapToData retMap = MapToFactory.create(tmpVals.size(), tmpVals.extractValues(), unique); + return new SparseEncoding(retMap, o, nRows - retOff.size(), 
retMap.getCounts(new int[unique - 1]), nRows); + } + else { + final AMapToData retMap = MapToFactory.create(nRows, unique); + retMap.fill(unique - 1); + for(int i = 0; i < retOff.size(); i++) + retMap.set(retOff.get(i), tmpVals.get(i)); + + // add values. + IEncode ret = DenseEncoding.joinDenseCount(retMap); + return ret; + } + } + + private static int joinSparse(AMapToData lMap, AMapToData rMap, AIterator itl, AIterator itr, IntArrayList retOff, + IntArrayList tmpVals, int fl, int fr, int nVl, int nVr, int[] d) { + final int defR = (nVr - 1) * nVl; final int defL = nVl - 1; - boolean doneL = false; boolean doneR = false; int newUID = 1; @@ -102,7 +122,7 @@ protected IEncode joinSparse(SparseEncoding e) { final int ir = itr.value(); if(il == ir) { // Both sides have a value. - final int nv = map.getIndex(itl.getDataIndex()) + e.map.getIndex(itr.getDataIndex()) * nVl; + final int nv = lMap.getIndex(itl.getDataIndex()) + rMap.getIndex(itr.getDataIndex()) * nVl; newUID = addVal(nv, il, d, newUID, tmpVals, retOff); if(il >= fl || ir >= fr) { if(il < fl) @@ -121,7 +141,7 @@ protected IEncode joinSparse(SparseEncoding e) { } else if(il < ir) { // left side have a value before right - final int nv = map.getIndex(itl.getDataIndex()) + defR; + final int nv = lMap.getIndex(itl.getDataIndex()) + defR; newUID = addVal(nv, il, d, newUID, tmpVals, retOff); if(il >= fl) { doneL = true; @@ -131,7 +151,7 @@ else if(il < ir) { } else { // right side have a value before left - final int nv = e.map.getIndex(itr.getDataIndex()) * nVl + defL; + final int nv = rMap.getIndex(itr.getDataIndex()) * nVl + defL; newUID = addVal(nv, ir, d, newUID, tmpVals, retOff); if(ir >= fr) { doneR = true; @@ -148,9 +168,9 @@ else if(il < ir) { final int ir = itr.value(); int nv; if(ir == il) - nv = map.getIndex(itl.getDataIndex()) + e.map.getIndex(itr.getDataIndex()) * nVl; + nv = lMap.getIndex(itl.getDataIndex()) + rMap.getIndex(itr.getDataIndex()) * nVl; else - nv = map.getIndex(itl.getDataIndex()) + 
defR; + nv = lMap.getIndex(itl.getDataIndex()) + defR; newUID = addVal(nv, il, d, newUID, tmpVals, retOff); if(il >= fl) break; @@ -163,9 +183,9 @@ else if(!doneR) {// If there is stragglers in the right side final int ir = itr.value(); int nv; if(ir == il) - nv = map.getIndex(itl.getDataIndex()) + e.map.getIndex(itr.getDataIndex()) * nVl; + nv = lMap.getIndex(itl.getDataIndex()) + rMap.getIndex(itr.getDataIndex()) * nVl; else - nv = e.map.getIndex(itr.getDataIndex()) * nVl + defL; + nv = rMap.getIndex(itr.getDataIndex()) * nVl + defL; newUID = addVal(nv, ir, d, newUID, tmpVals, retOff); if(ir >= fr) @@ -174,22 +194,7 @@ else if(!doneR) {// If there is stragglers in the right side } } - if(retOff.size() < nRows * 0.4) { - final AOffset o = OffsetFactory.createOffset(retOff); - final AMapToData retMap = MapToFactory.create(tmpVals.size(), tmpVals.extractValues(), newUID); - return new SparseEncoding(retMap, o, nRows - retOff.size(), - retMap.getCounts(new int[newUID - 1], retOff.size()), nRows); - } - else { - final AMapToData retMap = MapToFactory.create(nRows, newUID); - retMap.fill(newUID - 1); - for(int i = 0; i < retOff.size(); i++) - retMap.set(retOff.get(i), tmpVals.get(i)); - - // add values. 
- IEncode ret = new DenseEncoding(retMap); - return ret; - } + return newUID; } private static int addVal(int nv, int offset, int[] d, int newUID, IntArrayList tmpVals, IntArrayList offsets) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/sample/HassAndStokes.java b/src/main/java/org/apache/sysds/runtime/compress/estim/sample/HassAndStokes.java index 614e3325a68..937dec2ca00 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/sample/HassAndStokes.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/sample/HassAndStokes.java @@ -159,6 +159,7 @@ private static double getMethodOfMomentsEstimate(int nj, double q, double min, d // NOTE the cache does not work currently since the number of rows considered each call can change now // This happens because the sampled estimator now considers nonzeros or offsets calculated and therefore know upper // bounds of the number of offsets that are lower than the maximum number of rows. + // if(solveCache.containsKey(nj)) // synchronized(solveCache) { // return solveCache.get(nj); diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/sample/SampleEstimatorFactory.java b/src/main/java/org/apache/sysds/runtime/compress/estim/sample/SampleEstimatorFactory.java index 78c0bc21e75..83934373933 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/sample/SampleEstimatorFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/sample/SampleEstimatorFactory.java @@ -96,7 +96,8 @@ private static int[] getInvertedFrequencyHistogram(int[] frequencies) { // create frequency histogram int[] freqCounts = new int[maxCount]; for(int i = 0; i < numVals; i++) - freqCounts[frequencies[i] - 1]++; + if(frequencies[i] != 0) + freqCounts[frequencies[i] - 1]++; return freqCounts; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/sample/SmoothedJackknifeEstimator.java 
b/src/main/java/org/apache/sysds/runtime/compress/estim/sample/SmoothedJackknifeEstimator.java index ea40a70ad1b..c1aa5344b3e 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/sample/SmoothedJackknifeEstimator.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/sample/SmoothedJackknifeEstimator.java @@ -54,8 +54,8 @@ public static int distinctCount(int numVals, int[] freqCounts, int nRows, int sa * However, for large values of nRows, Gamma.gamma returns NAN * (factorial of a very large number). * - * The following implementation solves this problem by levaraging the - * cancelations that show up when expanding the factorials in the + * The following implementation solves this problem by leveraging the + * cancellations that show up when expanding the factorials in the * numerator and the denominator. * * diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibAppend.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibAppend.java index 68eca8045af..f9986eb141f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibAppend.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibAppend.java @@ -88,7 +88,7 @@ public static MatrixBlock append(CompressedMatrixBlock left, CompressedMatrixBlo private static MatrixBlock appendRightEmpty(CompressedMatrixBlock left, MatrixBlock right, int m, int n) { CompressedMatrixBlock ret = new CompressedMatrixBlock(m, n); List newGroup = new ArrayList<>(1); - newGroup.add(ColGroupEmpty.generate(right.getNumColumns())); + newGroup.add(ColGroupEmpty.create(right.getNumColumns())); ret = appendColGroups(ret, left.getColGroups(), newGroup, left.getNumColumns()); ret.setOverlapping(left.isOverlapping()); return ret; @@ -97,7 +97,7 @@ private static MatrixBlock appendRightEmpty(CompressedMatrixBlock left, MatrixBl private static MatrixBlock appendLeftEmpty(MatrixBlock left, CompressedMatrixBlock right, int m, int n) { CompressedMatrixBlock ret = new 
CompressedMatrixBlock(m, n); List newGroup = new ArrayList<>(1); - newGroup.add(ColGroupEmpty.generate(left.getNumColumns())); + newGroup.add(ColGroupEmpty.create(left.getNumColumns())); ret = appendColGroups(ret, newGroup, right.getColGroups(), left.getNumColumns()); ret.setOverlapping(right.isOverlapping()); return ret; @@ -134,5 +134,4 @@ private static MatrixBlock uc(MatrixBlock mb) { // get uncompressed return CompressedMatrixBlock.getUncompressed(mb); } - } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java index 7fb0118602c..cd4360de606 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java @@ -34,7 +34,7 @@ import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory; import org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.colgroup.AColGroup; -import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory; +import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; import org.apache.sysds.runtime.data.SparseBlock; @@ -300,7 +300,7 @@ protected static CompressedMatrixBlock binaryMVPlusStack(CompressedMatrixBlock m if(smallestSize == Integer.MAX_VALUE) { // if there was no smallest colgroup ADictionary newDict = new MatrixBlockDictionary(m2, nCol); - newColGroups.add(ColGroupFactory.genColGroupConst(nCol, newDict)); + newColGroups.add(ColGroupConst.create(nCol, newDict)); } else { // apply to the found group @@ -331,7 +331,7 @@ private static MatrixBlock binaryMVCol(CompressedMatrixBlock m1, MatrixBlock m2, CompressedMatrixBlock mf1 = new CompressedMatrixBlock(m1); double[] constV = new double[nCols]; final List 
filteredGroups = CLALibUtils.filterGroups(groups, constV); - filteredGroups.add(ColGroupFactory.genColGroupConst(constV)); + filteredGroups.add(ColGroupConst.create(constV)); mf1.allocateColGroupList(filteredGroups); m1 = mf1; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCMOps.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCMOps.java new file mode 100644 index 00000000000..a400e6ad96b --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCMOps.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.compress.lib; + +import java.util.List; + +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; +import org.apache.sysds.runtime.matrix.data.LibMatrixAgg; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.CMOperator; + +public class CLALibCMOps { + public static CM_COV_Object centralMoment(CompressedMatrixBlock cmb, CMOperator op) { + MatrixBlock.checkCMOperations(cmb, op); + if(cmb.isEmpty()) + return LibMatrixAgg.aggregateCmCov(cmb, null, null, op.fn); + else if(cmb.isOverlapping()) + return cmb.getUncompressed("cmOperations on overlapping state").cmOperations(op); + else { + final List groups = cmb.getColGroups(); + if(groups.size() == 1) + return groups.get(0).centralMoment(op, cmb.getNumRows()); + else + throw new DMLCompressionException("Unsupported case for cmOperations"); + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java index 9e01b8488bd..db99dc5ef08 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java @@ -35,7 +35,7 @@ import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.AColGroupCompressed; -import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory; +import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysds.runtime.functionobjects.Builtin; import 
org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode; @@ -434,7 +434,7 @@ private static List> generateUnaryAggregateOverlappingFuture if(shouldFilter) { final double[] constV = new double[nCol]; final List filteredGroups = CLALibUtils.filterGroups(groups, constV); - final AColGroup cRet = ColGroupFactory.genColGroupConst(constV); + final AColGroup cRet = ColGroupConst.create(constV); filteredGroups.add(cRet); for(int i = 0; i < nRow; i += blklen) tasks.add(new UAOverlappingTask(filteredGroups, ret, i, Math.min(i + blklen, nRow), op, nCol)); diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java index a1ac8ffcbf8..55bc0fc8e01 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java @@ -32,7 +32,7 @@ import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.AColGroup; -import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory; +import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysds.runtime.data.DenseBlock; @@ -106,7 +106,7 @@ private static void decompressToDenseBlock(CompressedMatrixBlock cmb, DenseBlock final List filteredGroups = CLALibUtils.filterGroups(groups, constV); for(AColGroup g : filteredGroups) g.decompressToDenseBlock(ret, 0, nRows, rowOffset, colOffset); - AColGroup cRet = ColGroupFactory.genColGroupConst(constV); + AColGroup cRet = ColGroupConst.create(constV); cRet.decompressToDenseBlock(ret, 0, nRows, rowOffset, colOffset); } else { diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java index 9d6b889d4d1..750dd332395 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java @@ -115,7 +115,7 @@ public static MatrixBlock leftMultByMatrix(CompressedMatrixBlock right, MatrixBl } public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock ret, int k) { - // final boolean overlapping = cmb.isOverlapping(); + final List groups = cmb.getColGroups(); final int numColumns = cmb.getNumColumns(); final int numRows = cmb.getNumRows(); @@ -518,7 +518,8 @@ private static List preFilterAndMultiply(List colGroups, Mat if(a instanceof APreAgg) { APreAgg g = (APreAgg) a; - g.forceMatrixBlockDictionary(); + //TODO remove call to force matrix block + g.forceMatrixBlockDictionary(); ColGroupValues.add(g); } else @@ -535,8 +536,9 @@ private static double[] getColSum(List groups, int nCols, int nRows) private static void MMPreaggregate(APreAgg cg, MatrixBlock preAgg, MatrixBlock tmpRes, MatrixBlock ret, int rl, int ru) { preAgg.recomputeNonZeros(); + // TODO remove call to getDictionary(). 
final MatrixBlock dict = ((MatrixBlockDictionary) cg.getDictionary()).getMatrixBlock(); - tmpRes.reset(ru - rl, dict.getNumColumns(), false); + tmpRes.reset(ru - rl, cg.getNumCols(), false); try { LibMatrixMult.matrixMult(preAgg, dict, tmpRes); addMatrixToResult(tmpRes, ret, cg.getColIndices(), rl, ru); @@ -544,7 +546,7 @@ private static void MMPreaggregate(APreAgg cg, MatrixBlock preAgg, MatrixBlock t } catch(Exception e) { throw new DMLCompressionException("Failed MM with preAggregate:\n" + preAgg.getNumRows() + " " - + preAgg.getNumColumns() + "\n dict:\n" + dict.getNumRows() + " " + dict.getNumColumns(), e); + + preAgg.getNumColumns() + "\n dict:\n" + dict.getNumRows() + " " + dict.getNumColumns() + "\n" + cg, e); } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java index 5dfe25f8d02..c8752b2e693 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java @@ -26,7 +26,7 @@ import org.apache.sysds.lops.MapMultChain.ChainType; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.colgroup.AColGroup; -import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory; +import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.functionobjects.Multiply; import org.apache.sysds.runtime.matrix.data.LibMatrixBincell; import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; @@ -101,7 +101,7 @@ private static CompressedMatrixBlock filterColGroups(CompressedMatrixBlock x) { final double[] constV = new double[nCol]; final List filteredGroups = CLALibUtils.filterGroups(groups, constV); - AColGroup c = ColGroupFactory.genColGroupConst(constV); + AColGroup c = ColGroupConst.create(constV); filteredGroups.add(c); x.allocateColGroupList(filteredGroups); return x; diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReExpand.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReExpand.java deleted file mode 100644 index 3ef5629885a..00000000000 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReExpand.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.sysds.runtime.compress.lib; - -import java.util.ArrayList; -import java.util.List; - -import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.compress.CompressedMatrixBlock; -import org.apache.sysds.runtime.compress.colgroup.AColGroup; -import org.apache.sysds.runtime.compress.colgroup.AColGroupValue; -import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; -import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.util.UtilFunctions; - -public class CLALibReExpand { - - // private static final Log LOG = LogFactory.getLog(CLALibReExpand.class.getName()); - - public static MatrixBlock reExpand(CompressedMatrixBlock in, MatrixBlock ret, double max, boolean cast, - boolean ignore, int k) { - int iMax = UtilFunctions.toInt(max); - if (!ignore && in.min() == 0.0) - throw new DMLRuntimeException("Invalid input w/ zeros for rexpand ignore=false " + "(rlen=" - + in.getNumRows() + ", nnz=" + in.getNonZeros() + ")."); - - if (in.isOverlapping() || in.getColGroups().size() > 1) { - throw new DMLRuntimeException( - "Invalid input for re expand operations, currently not supporting overlapping or multi column groups"); - } - - // check for empty inputs (for ignore=true) - if (in.isEmptyBlock(false)) { - ret.reset(in.getNumRows(), iMax, true); - return ret; - } - CompressedMatrixBlock retC = ret instanceof CompressedMatrixBlock ? 
(CompressedMatrixBlock) ret - : new CompressedMatrixBlock(in.getNumRows(), iMax); - - return reExpandRows(in, retC, iMax, cast, k); - } - - private static CompressedMatrixBlock reExpandRows(CompressedMatrixBlock in, CompressedMatrixBlock ret, int max, - boolean cast, int k) { - AColGroupValue oldGroup = ((AColGroupValue) in.getColGroups().get(0)); - - ADictionary newDictionary = oldGroup.getDictionary().reExpandColumns(max); - AColGroup newGroup = oldGroup.copyAndSet(getColIndexes(max), newDictionary); - List newColGroups = new ArrayList<>(1); - newColGroups.add(newGroup); - - ret.allocateColGroupList(newColGroups); - ret.setOverlapping(true); - - ret.recomputeNonZeros(); - return ret; - } - - private static int[] getColIndexes(int max) { - int[] ret = new int[max]; - for (int i = 0; i < max; i++) { - ret[i] = i; - } - return ret; - } -} diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java new file mode 100644 index 00000000000..c9f8b496adb --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.compress.lib; + +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.UtilFunctions; + +public class CLALibRexpand { + + // private static final Log LOG = LogFactory.getLog(CLALibReExpand.class.getName()); + + public static MatrixBlock rexpand(CompressedMatrixBlock in, MatrixBlock ret, double max, boolean rows, boolean cast, + boolean ignore, int k) { + if(rows) + return in.getUncompressed("Rexpand in rows direction (one hot encode)").rexpandOperations(ret, max, rows, cast, + ignore, k); + else + return rexpandCols(in, max, cast, ignore, k); + } + + private static MatrixBlock rexpandCols(CompressedMatrixBlock in, double max, boolean cast, boolean ignore, int k) { + return rexpandCols(in, UtilFunctions.toInt(max), cast, ignore, k); + } + + private static MatrixBlock rexpandCols(CompressedMatrixBlock in, int max, boolean cast, boolean ignore, int k) { + LibMatrixReorg.checkRexpand(in, ignore); + + final int nRows = in.getNumRows(); + if(in.isEmptyBlock(false)) + return new MatrixBlock(nRows, max, true); + else if(in.isOverlapping() || in.getColGroups().size() > 1) + return LibMatrixReorg.rexpand(in.getUncompressed("Rexpand (one hot encode)"), new MatrixBlock(), max, false, + cast, ignore, k); + else { + CompressedMatrixBlock retC = new CompressedMatrixBlock(nRows, max); + retC.allocateColGroup(in.getColGroups().get(0).rexpandCols(max, ignore, cast, nRows)); + retC.recomputeNonZeros(); + return retC; + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java index be43e585510..7f80b8e0651 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java +++ 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java @@ -38,7 +38,6 @@ import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; -import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory; import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.CommonThreadPool; @@ -116,7 +115,7 @@ private static CompressedMatrixBlock RMMOverlapping(CompressedMatrixBlock m1, Ma containsNull = RMMParallel(filteredGroups, that, retCg, k); if(constV != null) { - AColGroup cRet = ColGroupFactory.genColGroupConst(constV).rightMultByMatrix(that); + AColGroup cRet = ColGroupConst.create(constV).rightMultByMatrix(that); if(cRet != null) retCg.add(cRet); } @@ -156,7 +155,7 @@ private static MatrixBlock RMM(CompressedMatrixBlock m1, MatrixBlock that, int k RMMParallel(filteredGroups, that, retCg, k); if(constV != null) { - ColGroupConst cRet = (ColGroupConst) ColGroupFactory.genColGroupConst(constV).rightMultByMatrix(that); + ColGroupConst cRet = (ColGroupConst) ColGroupConst.create(constV).rightMultByMatrix(that); constV = cRet.getValues(); // overwrite constV } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java index d8a384e2307..9bacffaf3f6 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java @@ -34,7 +34,6 @@ import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; -import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory; import 
org.apache.sysds.runtime.compress.colgroup.ColGroupOLE; import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; import org.apache.sysds.runtime.functionobjects.Divide; @@ -104,7 +103,7 @@ private static CompressedMatrixBlock setupRet(CompressedMatrixBlock m1, MatrixVa } private static ColGroupConst constOverlap(CompressedMatrixBlock m1, ScalarOperator sop) { - return (ColGroupConst) ColGroupFactory.genColGroupConst(m1.getNumColumns(), sop.executeScalar(0)); + return (ColGroupConst) ColGroupConst.create(m1.getNumColumns(), sop.executeScalar(0)); } private static List copyGroups(CompressedMatrixBlock m1, ScalarOperator sop, ColGroupConst c, diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSlice.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSlice.java index 4f06fd48552..51077d427c6 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSlice.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSlice.java @@ -26,7 +26,7 @@ import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.colgroup.AColGroup; -import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory; +import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -66,7 +66,7 @@ private static MatrixBlock sliceRows(CompressedMatrixBlock cmb, int rl, int ru) final List filteredGroups = CLALibUtils.filterGroups(groups, constV); for(AColGroup g : filteredGroups) g.decompressToDenseBlock(db, rl, rue, -rl, 0); - AColGroup cRet = ColGroupFactory.genColGroupConst(constV); + AColGroup cRet = ColGroupConst.create(constV); cRet.decompressToDenseBlock(db, rl, rue, -rl, 0); } else diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java index b1c5274137b..17cd05fc01d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java @@ -31,7 +31,6 @@ import org.apache.sysds.runtime.compress.colgroup.AMorphingMMColGroup; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; -import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory; public final class CLALibUtils { protected static final Log LOG = LogFactory.getLog(CLALibUtils.class.getName()); @@ -134,7 +133,7 @@ private static AColGroup combineConst(List c) { values[outId] = colVals[i]; } } - return ColGroupFactory.genColGroupConst(resCols, values); + return ColGroupConst.create(resCols, values); } private static int[] combineColIndexes(List gs) { diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java index 864c7f2e5f6..8ae1ad2d290 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java @@ -419,10 +419,23 @@ public static MatrixBlock cumaggregateUnaryMatrix(MatrixBlock in, MatrixBlock ou return out; } + /** + * Single threaded Covariance and Central Moment operations + * + * CM = Central Moment + * + * COV = Covariance + * + * @param in1 Main input matrix + * @param in2 Second input matrix + * @param in3 Third input matrix (not output since output is returned) + * @param fn Value function to apply + * @return Central Moment or Covariance object + */ public static CM_COV_Object aggregateCmCov(MatrixBlock in1, MatrixBlock in2, MatrixBlock in3, ValueFunction fn) { CM_COV_Object cmobj = new CM_COV_Object(); - // empty block handling (important for result corretness, otherwise + // empty block handling (important 
for result correctness, otherwise // we get a NaN due to 0/0 on reading out the required result) if( in1.isEmptyBlock(false) && fn instanceof CM ) { fn.execute(cmobj, 0.0, in1.getNumRows()); @@ -432,6 +445,20 @@ public static CM_COV_Object aggregateCmCov(MatrixBlock in1, MatrixBlock in2, Mat return aggregateCmCov(in1, in2, in3, fn, 0, in1.getNumRows()); } + /** + * Multi threaded Covariance and Central Moment operations + * + * CM = Central Moment + * + * COV = Covariance + * + * @param in1 Main input matrix + * @param in2 Second input matrix + * @param in3 Third input matrix (not output since output is returned) + * @param fn Value function to apply + * @param k Parallelization degree + * @return Central Moment or Covariance object + */ public static CM_COV_Object aggregateCmCov(MatrixBlock in1, MatrixBlock in2, MatrixBlock in3, ValueFunction fn, int k) { if( in1.isEmptyBlock(false) || !satisfiesMultiThreadingConstraints(in1, k) ) return aggregateCmCov(in1, in2, in3, fn); diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java index 596beda055e..a288e9eab94 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java @@ -719,52 +719,77 @@ public static void rmempty(IndexedMatrixValue data, IndexedMatrixValue offset, b } /** - * CP rexpand operation (single input, single output) + * CP rexpand operation (single input, single output), the classic example of this operation is one hot encoding of a + * column to multiple columns. * - * @param in input matrix - * @param ret output matrix - * @param max ? - * @param rows ? - * @param cast ? - * @param ignore ? 
- * @param k degree of parallelism - * @return output matrix + * @param in Input matrix + * @param ret Output matrix + * @param max Number of rows/cols of the output + * @param rows If the expansion is in rows direction + * @param cast If the values contained should be cast to double (rounded up and down) + * @param ignore Ignore if the input contain values below zero that technically is incorrect input. + * @param k Degree of parallelism + * @return Output matrix rexpanded */ public static MatrixBlock rexpand(MatrixBlock in, MatrixBlock ret, double max, boolean rows, boolean cast, boolean ignore, int k) { - //prepare parameters - int lmax = (int)UtilFunctions.toLong(max); - + return rexpand(in, ret, UtilFunctions.toInt(max), rows, cast, ignore, k); + } + + /** + * CP rexpand operation (single input, single output), the classic example of this operation is one hot encoding of a + * column to multiple columns. + * + * @param in Input matrix + * @param ret Output matrix + * @param max Number of rows/cols of the output + * @param rows If the expansion is in rows direction + * @param cast If the values contained should be cast to double (rounded up and down) + * @param ignore Ignore if the input contain values below zero that technically is incorrect input. 
+ * @param k Degree of parallelism + * @return Output matrix rexpanded + */ + public static MatrixBlock rexpand(MatrixBlock in, MatrixBlock ret, int max, boolean rows, boolean cast, boolean ignore, int k){ //sanity check for input nnz (incl implicit handling of empty blocks) - if( !ignore && in.getNonZeros() outList) { //prepare parameters diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java index 683cd14edb6..2a8d74da8a0 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java @@ -52,6 +52,7 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.data.DenseBlockFP64; import org.apache.sysds.runtime.data.DenseBlockFactory; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockCOO; @@ -221,6 +222,13 @@ public MatrixBlock(int rl, int cl, DenseBlock dBlock){ denseBlock = dBlock; } + public MatrixBlock(int rl, int cl, double[] vals){ + rlen = rl; + clen = cl; + sparse = false; + denseBlock = new DenseBlockFP64(new int[] {rl,cl}, vals); + nonZeros = vals.length; + } protected MatrixBlock(boolean empty){ // do nothing @@ -4708,13 +4716,16 @@ else if( !sparse && denseBlock!=null ) { //DENSE } public CM_COV_Object cmOperations(CMOperator op) { + checkCMOperations(this, op); + return LibMatrixAgg.aggregateCmCov(this, null, null, op.fn, op.getNumThreads()); + } + + public static void checkCMOperations(MatrixBlock mb, CMOperator op){ // dimension check for input column vectors - if ( this.getNumColumns() != 1) { + if ( mb.getNumColumns() != 1) { throw new DMLRuntimeException("Central Moment cannot be computed on [" - + this.getNumRows() + "," + 
this.getNumColumns() + "] matrix."); + + mb.getNumRows() + "," + mb.getNumColumns() + "] matrix."); } - - return LibMatrixAgg.aggregateCmCov(this, null, null, op.fn, op.getNumThreads()); } public CM_COV_Object cmOperations(CMOperator op, MatrixBlock weights) { diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java index 88acf523249..4312d103d53 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java @@ -19,7 +19,6 @@ package org.apache.sysds.test.component.compress; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -45,14 +44,11 @@ import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.functionobjects.PlusMultiply; import org.apache.sysds.runtime.functionobjects.ReduceAll; -import org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.AggregateOperator; import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator; import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; -import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator; -import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator.CountDistinctTypes; import org.apache.sysds.runtime.matrix.operators.RightScalarOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; import org.apache.sysds.runtime.matrix.operators.TernaryOperator; @@ -114,7 +110,7 @@ public void testQuickGetValue() { Random r = new Random(); final int min = r.nextInt(rows); final int max = Math.min(r.nextInt(rows - min) + min, min + 1000); - + 
for(int i = min; i < max; i++) for(int j = 0; j < cols; j++) { final double ulaVal = mb.quickGetValue(i, j); @@ -133,39 +129,6 @@ else if(OverLapping.effectOnOutput(overlappingType)) } } - @Test - @Ignore - public void testCountDistinct() { - try { - // Counting distinct is potentially wrong in cases with overlapping, resulting in a few to many or few - // elements. - if(!(cmb instanceof CompressedMatrixBlock) || (overlappingType == OverLapping.MATRIX_MULT_NEGATIVE)) - return; // Input was not compressed then just pass test - - CountDistinctOperator op = new CountDistinctOperator(CountDistinctTypes.COUNT); - int ret1 = LibMatrixCountDistinct.estimateDistinctValues(mb, op); - int ret2 = LibMatrixCountDistinct.estimateDistinctValues(cmb, op); - - String base = bufferedToString + "\n"; - if(_cs != null && _cs.lossy) { - // The number of distinct values should be same or lower in lossy mode. - // assertTrue(base + "lossy distinct count " +ret2+ "is less than full " + ret1, ret1 >= ret2); - - // above assumption is false, since the distinct count when using multiple different scales becomes - // larger due to differences in scale. 
- assertTrue(base + "lossy distinct count " + ret2 + "is greater than 0", 0 < ret2); - } - else { - assertEquals(base, ret1, ret2); - } - - } - catch(Exception e) { - e.printStackTrace(); - throw new RuntimeException(bufferedToString + "\n" + e.getMessage(), e); - } - } - @Override public void testUnaryOperators(AggType aggType, boolean inCP) { AggregateUnaryOperator auop = super.getUnaryOperator(aggType, 1); diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedSingleTests.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedSingleTests.java index cca005c02b4..ab966526e66 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/CompressedSingleTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedSingleTests.java @@ -139,9 +139,9 @@ public void test_settingsBuilder() { CompressionSettingsBuilder b = new CompressionSettingsBuilder(); b = b.addValidCompression(CompressionType.CONST).setLossy(true).setLossy(false).setSortValuesByLength(true) .setAllowSharedDictionary(true).setColumnPartitioner(CoCoderFactory.PartitionerType.BIN_PACKING) - .setMaxColGroupCoCode(3).setEstimationType(EstimationType.ShlosserJackknifeEstimator) - .clearValidCompression().setSamplingRatio(0.2).setSeed(1342).setCoCodePercentage(0.22) - .setMinimumSampleSize(1342).setCostType(CostEstimatorFactory.CostType.MEMORY); + .setMaxColGroupCoCode(3).setEstimationType(EstimationType.ShlosserJackknifeEstimator).clearValidCompression() + .setSamplingRatio(0.2).setSeed(1342).setCoCodePercentage(0.22).setMinimumSampleSize(1342) + .setCostType(CostEstimatorFactory.CostType.MEMORY); CompressionSettings s = b.create(); b = b.copySettings(s); } diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java index aef68e620d1..acd58fba1e3 100644 --- 
a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java @@ -19,7 +19,6 @@ package org.apache.sysds.test.component.compress; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; import java.util.ArrayList; @@ -192,7 +191,7 @@ else if(_cs == null) { for(int i = 0; i < colIndexes.length; i++) colIndexes[i] = i; cmb = new CompressedMatrixBlock(mb.getNumRows(), mb.getNumColumns()); - ((CompressedMatrixBlock) cmb).allocateColGroup(new ColGroupUncompressed(colIndexes, mb, false)); + ((CompressedMatrixBlock) cmb).allocateColGroup(ColGroupUncompressed.create(colIndexes, mb, false)); } } else { @@ -390,7 +389,9 @@ public void testDecompress() { ((CompressedMatrixBlock) cmb).clearSoftReferenceToDecompressed(); MatrixBlock decompressedMatrixBlock = ((CompressedMatrixBlock) cmb).decompress(_k); compareResultMatrices(mb, decompressedMatrixBlock, 1); - assertEquals(bufferedToString, mb.getNonZeros(), decompressedMatrixBlock.getNonZeros()); + if(mb.getNonZeros() != decompressedMatrixBlock.getNonZeros()) + fail(bufferedToString + "\n NonZeros not equivalent: expected:" + mb.getNonZeros() + " was: " + + decompressedMatrixBlock.getNonZeros()); } catch(Exception e) { e.printStackTrace(); @@ -450,7 +451,6 @@ public void testLeftMatrixMatrixMultSmall() { testLeftMatrixMatrix(matrix); } - @Test public void testLeftMatrixMatrixMultConst() { MatrixBlock matrix = TestUtils.generateTestMatrixBlock(3, rows, 1.0, 1.0, 1.0, 3); @@ -634,7 +634,7 @@ public void testLeftMatrixMatrixMultiplicationTransposed(MatrixBlock matrix, boo ucRet = right.aggregateBinaryOperations(left, right, ucRet, abopSingle); MatrixBlock ret2 = ((CompressedMatrixBlock) cmb).aggregateBinaryOperations(compMatrix, cmb, new MatrixBlock(), - abopSingle, transposeLeft, transposeRight); + abopSingle, transposeLeft, transposeRight); compareResultMatrices(ucRet, ret2, 100); } diff --git 
a/src/test/java/org/apache/sysds/test/component/compress/CompressedVectorTest.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedVectorTest.java index a3bfccc296e..7ffe3bb705f 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/CompressedVectorTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedVectorTest.java @@ -138,4 +138,30 @@ public void testSortOperations() { throw new RuntimeException(bufferedToString + "\n" + e.getMessage(), e); } } + + @Test + public void testReExpandRow() { + // does not make much sense since it would entail the compression was on a matrix with one row. + // but here is a test. + testReExpand(false); + } + + @Test + public void testReExpandCol() { + testReExpand(true); + } + + public void testReExpand(boolean col) { + try { + if(cmb instanceof CompressedMatrixBlock) { + MatrixBlock ret1 = cmb.rexpandOperations(new MatrixBlock(), 50, !col, true, true, _k); + MatrixBlock ret2 = mb.rexpandOperations(new MatrixBlock(), 50, !col, true, true, _k); + compareResultMatrices(ret2, ret1, 0); + } + } + catch(Exception e) { + e.printStackTrace(); + throw e; + } + } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressibleInputGenerator.java b/src/test/java/org/apache/sysds/test/component/compress/CompressibleInputGenerator.java index 4a4aab961a4..89fa9405784 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/CompressibleInputGenerator.java +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressibleInputGenerator.java @@ -155,31 +155,30 @@ private static void rle(MatrixBlock output, int nrUnique, int max, int min, doub } } } - + private static void ole(MatrixBlock output, int nrUnique, int max, int min, double sparsity, int seed, - boolean transpose) { - + boolean transpose) { + // chose some random values - final Random r = new Random(seed); - final List values = getNRandomValues(nrUnique, r, max, min); - if(transpose && 
output.isInSparseFormat() && output.getNumRows() == 1){ - int nV = (int)Math.round((double)output.getNumColumns() * sparsity); + final Random r = new Random(seed); + final List values = getNRandomValues(nrUnique, r, max, min); + if(transpose && output.isInSparseFormat() && output.getNumRows() == 1) { + int nV = (int) Math.round((double) output.getNumColumns() * sparsity); - for(int i = 0 ; i < nV; i ++){ - double d = values.get(r.nextInt(nrUnique)); + for(int i = 0; i < nV; i++) { + double d = values.get(r.nextInt(nrUnique)); output.appendValue(0, r.nextInt(output.getNumColumns()), d); } output.getSparseBlock().sort(); return; } - final int cols = transpose ? output.getNumRows() : output.getNumColumns(); final int rows = transpose ? output.getNumColumns() : output.getNumRows(); // Generate the first column. for(int x = 0; x < rows; x++) { - double d = values.get(r.nextInt(nrUnique)); + double d = values.get(r.nextInt(nrUnique)); if(transpose && output.isInSparseFormat()) output.appendValue(0, x, d); else if(transpose) @@ -194,7 +193,7 @@ else if(transpose) for(int y = 1; y < cols; y++) { for(int x = 0; x < rows; x++) { if(r.nextDouble() < sparsity) { - if(transpose && output.isInSparseFormat()){ + if(transpose && output.isInSparseFormat()) { int v = (int) (output.getValue(0, x) * (double) y); double d = Math.abs(v % ((int) (diff))) + min; output.appendValue(y, x, d); @@ -213,17 +212,17 @@ else if(transpose) { } } - if(transpose && output.isInSparseFormat()){ + if(transpose && output.isInSparseFormat()) { SparseBlock sb = output.getSparseBlock(); double[] r0 = sb.values(0); - for(int i = 0; i < r0.length; i++){ + for(int i = 0; i < r0.length; i++) { if(r.nextDouble() > sparsity) { r0[i] = 0; } } sb.get(0).compact(); } - else{ + else { for(int x = 0; x < rows; x++) { if(r.nextDouble() > sparsity) { if(transpose) @@ -257,7 +256,6 @@ private static List getNRandomValues(int nrUnique, Random r, int max, in double v = Math.round(((r.nextDouble() * (max - min)) + min) * 
100) / 100; values.add(Math.floor(v)); } - // LOG.debug(values); return values; } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/ExtendedMatrixTests.java b/src/test/java/org/apache/sysds/test/component/compress/ExtendedMatrixTests.java index 36da6485b8d..55b303077dd 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/ExtendedMatrixTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/ExtendedMatrixTests.java @@ -81,7 +81,7 @@ public static Collection data() { OverLapping ov = OverLapping.NONE; // empty matrix compression ... (technically not a compressed matrix.) - tests.add(new Object[]{SparsityType.EMPTY, ValueType.RAND, vr, csb(), mt, ov, 1, null }); + tests.add(new Object[] {SparsityType.EMPTY, ValueType.RAND, vr, csb(), mt, ov, 1, null}); for(CompressionSettingsBuilder cs : usedCompressionSettings) tests.add(new Object[] {st, vt, vr, cs, mt, ov, 1, null}); @@ -89,7 +89,7 @@ public static Collection data() { ov = OverLapping.PLUS_ROW_VECTOR; for(CompressionSettingsBuilder cs : usedCompressionSettings) tests.add(new Object[] {st, vt, vr, cs, mt, ov, 1, null}); - + return tests; } @@ -323,7 +323,7 @@ public void testScalarLeftOpEqual() { } @Test - @Ignore + @Ignore // Currently ignored because of division with zero. 
public void testScalarLeftOpDivide() { double addValue = 14.0; @@ -465,12 +465,12 @@ public void testLeftMatrixMatrixMultMedium() { MatrixBlock matrix = TestUtils.generateTestMatrixBlock(50, rows, 0.9, 1.5, 1.0, 3); testLeftMatrixMatrix(matrix); } - + @Test - public void testCompactEmptyBlock(){ - if(cmb instanceof CompressedMatrixBlock){ + public void testCompactEmptyBlock() { + if(cmb instanceof CompressedMatrixBlock) { cmb.compactEmptyBlock(); - if(cmb.isEmpty()){ + if(cmb.isEmpty()) { CompressedMatrixBlock cm = (CompressedMatrixBlock) cmb; assertTrue(null == cm.getSoftReferenceToDecompressed()); } diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/JolEstimateTest.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/JolEstimateTest.java index c5219bf832f..27ee991e36b 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/JolEstimateTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/JolEstimateTest.java @@ -151,7 +151,7 @@ public void compressedSizeInfoEstimatorSample(double ratio, double tolerance) { if(cg == null) return; try { - if( mbt.getNumColumns() > 10000 ) + if(mbt.getNumColumns() > 10000) tolerance = tolerance * 0.95; final CompressionSettings cs = csb.setSamplingRatio(ratio).setMinimumSampleSize(10) .setValidCompressions(EnumSet.of(getCT())).create(); @@ -161,7 +161,7 @@ public void compressedSizeInfoEstimatorSample(double ratio, double tolerance) { Math.max(10, (int) (mbt.getNumColumns() * ratio)), 1); final int sampleSize = est.getSampleSize(); - final CompressedSizeInfoColGroup cInfo = est.estimateCompressedColGroupSize(); + final CompressedSizeInfoColGroup cInfo = est.estimateCompressedColGroupSize(colIndexes); // LOG.error(cg); final int estimateNUniques = cInfo.getNumVals(); final long estimateCSI = cInfo.getCompressionSize(cg.getCompType()); @@ -177,9 +177,7 @@ public void compressedSizeInfoEstimatorSample(double ratio, double tolerance) { fail("CSI 
Sampled estimate size is not in tolerance range \n" + rangeString + "\nActual number uniques:" + actualNumberUnique + " estimated Uniques: " + estimateNUniques + "\nSampleSize of total rows:: " - + sampleSize + " " + mbt.getNumColumns() + "\n" + cInfo - // + "\n" + mbt + "\n" + cg - ); + + sampleSize + " " + mbt.getNumColumns() + "\n" + cInfo + "\n" + mbt + "\n" + cg); } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/NegativeConstTests.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/NegativeConstTests.java index c1a94bf8ceb..b2e252b09c7 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/NegativeConstTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/NegativeConstTests.java @@ -20,7 +20,7 @@ package org.apache.sysds.test.component.compress.colgroup; import org.apache.sysds.runtime.compress.DMLCompressionException; -import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory; +import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.junit.Test; @@ -28,71 +28,71 @@ public class NegativeConstTests { @Test(expected = DMLCompressionException.class) public void testConstConstruction_01() { - ColGroupFactory.genColGroupConst(-1, 14); + ColGroupConst.create(-1, 14); } @Test(expected = DMLCompressionException.class) public void testConstConstruction_02() { - ColGroupFactory.genColGroupConst(0, 14); + ColGroupConst.create(0, 14); } @Test(expected = DMLCompressionException.class) public void testConstConstruction_03() { - ColGroupFactory.genColGroupConst(new int[] {}, 0); + ColGroupConst.create(new int[] {}, 0); } @Test(expected = DMLCompressionException.class) public void testConstConstruction_05() { - ColGroupFactory.genColGroupConst(new int[] {0, 1, 2}, new double[] {1, 2}); + ColGroupConst.create(new int[] {0, 1, 2}, new double[] {1, 2}); } @Test(expected = 
DMLCompressionException.class) public void testConstConstruction_06() { - ColGroupFactory.genColGroupConst(new int[] {0, 1}, new double[] {1, 2, 4}); + ColGroupConst.create(new int[] {0, 1}, new double[] {1, 2, 4}); } @Test(expected = DMLCompressionException.class) public void testConstConstruction_07() { - ColGroupFactory.genColGroupConst(2, new Dictionary(new double[] {1, 2, 4})); + ColGroupConst.create(2, new Dictionary(new double[] {1, 2, 4})); } @Test(expected = DMLCompressionException.class) public void testConstConstruction_08() { - ColGroupFactory.genColGroupConst(4, new Dictionary(new double[] {1, 2, 4})); + ColGroupConst.create(4, new Dictionary(new double[] {1, 2, 4})); } @Test public void testConstConstruction_allowed_01() { - ColGroupFactory.genColGroupConst(new int[] {0, 1, 2, 3}, 0); + ColGroupConst.create(new int[] {0, 1, 2, 3}, 0); } @Test public void testConstConstruction_allowed_02() { - ColGroupFactory.genColGroupConst(3, new Dictionary(new double[] {1, 2, 4})); + ColGroupConst.create(3, new Dictionary(new double[] {1, 2, 4})); } @Test public void testConstConstruction_allowed_03() { - ColGroupFactory.genColGroupConst(new double[] {1, 2, 4}); + ColGroupConst.create(new double[] {1, 2, 4}); } @Test(expected = NullPointerException.class) public void testConstConstruction_null_01() { - ColGroupFactory.genColGroupConst(null, 0); + ColGroupConst.create(null, 0); } @Test(expected = NullPointerException.class) public void testConstConstruction_null_02() { - ColGroupFactory.genColGroupConst(null, null); + ColGroupConst.create(null, null); } @Test(expected = NullPointerException.class) public void testConstConstruction_null_03() { - ColGroupFactory.genColGroupConst(0, null); + ColGroupConst.create(0, null); } @Test(expected = NullPointerException.class) public void testConstConstruction_null_04() { - ColGroupFactory.genColGroupConst(null); + ColGroupConst.create(null); } } diff --git 
a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleMultiColTest.java b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleMultiColTest.java index ef4fdfed15f..6f31e987352 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleMultiColTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleMultiColTest.java @@ -127,15 +127,24 @@ public void testJoinWithSecondSubpartLeft() { } private void partJoinVerification(IEncode er) { - if(e.getUnique() != er.getUnique() || e.size() != er.size()) { + boolean sameClass = e.getClass() == er.getClass(); + boolean incorrectSize = sameClass && e.size() != er.size(); + + boolean incorrectUnique = e.getUnique() != er.getUnique(); + + if(incorrectUnique || incorrectSize) { StringBuilder sb = new StringBuilder(); - sb.append("\nFailed joining sub parts to recreate whole.\nRead:"); + sb.append("\nFailed joining sub parts to recreate whole."); + sb.append("\nexpected unique:" + e.getUnique() + " got:" + er.getUnique()); + sb.append("\nexpected Size:" + e.size() + " got:" + er.size()); + + sb.append("\n\nRead:"); sb.append(e); sb.append("\nJoined:"); sb.append(er); sb.append("\n"); sb.append(m); - sb.append("\nsubParts:\n"); + sb.append("\n\nsubParts:\n"); sb.append(sh); sb.append("\n"); sb.append(fh); diff --git a/src/test/java/org/apache/sysds/test/component/compress/insertionsort/TestInsertionSorters.java b/src/test/java/org/apache/sysds/test/component/compress/insertionsort/TestInsertionSorters.java index 756aff828d3..746f5ed21f0 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/insertionsort/TestInsertionSorters.java +++ b/src/test/java/org/apache/sysds/test/component/compress/insertionsort/TestInsertionSorters.java @@ -41,6 +41,8 @@ @RunWith(value = Parameterized.class) public class TestInsertionSorters { + private static final int materializeSizeDef = 
MaterializeSort.CACHE_BLOCK; + public final int[][] data; public final SORT_TYPE st; public final int numRows; @@ -185,6 +187,6 @@ public static void setCacheSize() { @AfterClass public static void setCacheAfter() { - MaterializeSort.CACHE_BLOCK = 1000; + MaterializeSort.CACHE_BLOCK = materializeSizeDef; } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/mapping/MappingTests.java b/src/test/java/org/apache/sysds/test/component/compress/mapping/MappingTests.java index 7edc48536ce..ee47f18ad44 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/mapping/MappingTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/mapping/MappingTests.java @@ -206,19 +206,15 @@ public void replaceMax() { } } - @Test - public void getCountsWithDefault() { - int nVal = m.getUnique(); - int[] counts = m.getCounts(new int[nVal + 1], size + 10); - if(10 != counts[nVal]) - fail("Incorrect number of unique values:" + m + "\n" + Arrays.toString(counts)); - - } - @Test public void getCountsNoDefault() { int nVal = m.getUnique(); - m.getCounts(new int[nVal], size); + int[] counts = m.getCounts(new int[nVal]); + int sum = 0; + for(int v : counts) + sum += v; + if(sum != size) + fail("Incorrect number of unique values."); } @Test