diff --git a/scripts/builtin/splitBalanced.dml b/scripts/builtin/splitBalanced.dml index bb1d86bce87..3caad9e491a 100644 --- a/scripts/builtin/splitBalanced.dml +++ b/scripts/builtin/splitBalanced.dml @@ -43,9 +43,9 @@ return (Matrix[Double] X_train, Matrix[Double] y_train, Matrix[Double] X_test, Matrix[Double] y_test) { + classes = table(Y, 1) XY = order(target = cbind(Y, X), by = 1, decreasing=FALSE, index.return=FALSE) # get the class count - classes = table(XY[, 1], 1) split = floor(nrow(X) * splitRatio) start_class = 1 train_row_s = 1 @@ -70,13 +70,14 @@ return (Matrix[Double] X_train, Matrix[Double] y_train, Matrix[Double] X_test, { end_class = end_class + as.scalar(classes[i]) class_t = XY[start_class:end_class, ] + ratio = as.scalar(classes_ratio_train[i]) - train_row_e = train_row_e + as.scalar(classes_ratio_train[i]) + train_row_e = train_row_e + ratio test_row_e = test_row_e + as.scalar(classes_ratio_test[i]) - outTrain[train_row_s:train_row_e, ] = class_t[1:as.scalar(classes_ratio_train[i]), ] + outTrain[train_row_s:train_row_e, ] = class_t[1:ratio, ] - outTest[test_row_s:test_row_e, ] = class_t[as.scalar(classes_ratio_train[i])+1:nrow(class_t), ] + outTest[test_row_s:test_row_e, ] = class_t[ratio+1:nrow(class_t), ] train_row_s = train_row_e + 1 test_row_s = test_row_e + 1