From 6bc1bea77a89236fe6172999dab00d2258622951 Mon Sep 17 00:00:00 2001 From: Shafaq Siddiqi Date: Mon, 7 Feb 2022 15:56:47 +0100 Subject: [PATCH] [MINOR] Refactoring input and output parameters of dbscanApply and dbscan - This commit apply the consistency between the input and output parameters of dbscanApply and dbscan respectively --- scripts/builtin/dbscan.dml | 11 ++++++++--- scripts/builtin/dbscanApply.dml | 17 +++++++++-------- .../builtin/part1/BuiltinDbscanApplyTest.java | 2 +- src/test/scripts/functions/builtin/dbscan.dml | 2 +- .../scripts/functions/builtin/dbscanApply.dml | 4 ++-- 5 files changed, 21 insertions(+), 15 deletions(-) diff --git a/scripts/builtin/dbscan.dml b/scripts/builtin/dbscan.dml index ecb414d088e..9e1d5101f02 100644 --- a/scripts/builtin/dbscan.dml +++ b/scripts/builtin/dbscan.dml @@ -39,11 +39,15 @@ # ---------------------------------------------------------------------------------------------------------------------- m_dbscan = function (Matrix[Double] X, Double eps = 0.5, Integer minPts = 5) - return (Matrix[Double] clusterMembers, Matrix[Double] clusterModel) + return (Matrix[Double] X, Matrix[Double] clusterModel, Double eps) { #check input parameter assertions if(minPts < 0) { stop("DBSCAN: Stopping due to invalid inputs: minPts should be greater than 0"); } - if(eps < 0) { stop("DBSCAN: Stopping due to invalid inputs: Epsilon (eps) should be greater than 0"); } + if(eps < 0) + { + print("DBSCAN: Epsilon (eps) should be greater than 0. Setting eps = 0.5"); + eps = 0.5 + } UNASSIGNED = 0; @@ -76,6 +80,7 @@ m_dbscan = function (Matrix[Double] X, Double eps = 0.5, Integer minPts = 5) clusterMembers = components(G=adjacency, verbose=FALSE); # noise to 0 clusterMembers = clusterMembers * (rowSums(adjacency) > 0); - clusterModel = removeEmpty(target=X, margin="rows", select = (clusterMembers > 0)) + clusterModel = removeEmpty(target=X, margin="rows", select = (clusterMembers > 0)) + X = clusterMembers } } diff --git a/scripts/builtin/dbscanApply.dml b/scripts/builtin/dbscanApply.dml index 08ec109e6fd..e55ee8bd99c 100644 --- a/scripts/builtin/dbscanApply.dml +++ b/scripts/builtin/dbscanApply.dml @@ -25,7 +25,7 @@ # ---------------------------------------------------------------------------- # NAME TYPE DEFAULT MEANING # ---------------------------------------------------------------------------- -# Xtest Matrix[Double] --- The input Matrix to do outlier detection on. +# X Matrix[Double] --- The input Matrix to do outlier detection on. # clusterModel Matrix[Double] --- Model of clusters to predict outliers against. # eps Double 0.5 Maximum distance between two points for one to be considered reachable for the other. @@ -36,11 +36,11 @@ # outlierPoints Matrix[Double] --- Predicted outliers -m_dbscanApply = function (Matrix[Double] Xtest, Matrix[Double] clusterModel, Double eps = 0.5) - return (Matrix[double] outlierPoints) +m_dbscanApply = function (Matrix[Double] X, Matrix[Double] clusterModel, Double eps) + return (Matrix[Double] cluster, Matrix[Double] outlierPoints) { - num_features_Xtest = ncol(Xtest); - num_rows_Xtest = nrow(Xtest); + num_features_Xtest = ncol(X); + num_rows_Xtest = nrow(X); num_features_model = ncol(clusterModel); num_rows_model = nrow(clusterModel); @@ -48,11 +48,12 @@ m_dbscanApply = function (Matrix[Double] Xtest, Matrix[Double] clusterModel, Dou if(eps < 0) { stop("DBSCAN Outlier: Stopping due to invalid inputs: Epsilon (eps) should be greater than 0"); } if(num_rows_model <= 0) { stop("DBSCAN Outlier: Stopping due to invalid inputs: Model is empty"); } - X = rbind(clusterModel, Xtest); - neighbors = dist(X); + Xall = rbind(clusterModel, X); + neighbors = dist(Xall); neighbors = replace(target = neighbors, pattern = 0, replacement = 2.225e-307); neighbors = neighbors - diag(diag(neighbors)); - Xtest_dists = neighbors[(num_rows_model+1):nrow(X), 1:num_rows_model]; + Xtest_dists = neighbors[(num_rows_model+1):nrow(Xall), 1:num_rows_model]; withinEps = ((Xtest_dists <= eps) * (0 < Xtest_dists)); outlierPoints = rowSums(withinEps) >= 1; + cluster = removeEmpty(target=outlierPoints, margin="rows", select=outlierPoints) } diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDbscanApplyTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDbscanApplyTest.java index 4ec22718798..efd3e28c7fb 100644 --- a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDbscanApplyTest.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDbscanApplyTest.java @@ -86,7 +86,7 @@ private void runOutlierByDBSCAN(boolean defaultProb, int seedA, int seedB, doubl String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + TEST_NAME + ".dml"; - programArgs = new String[]{"-explain","-nvargs", + programArgs = new String[]{"-nvargs", "X=" + input("A"), "Y=" + input("B"),"Z=" + output("C"), "eps=" + epsDB, "minPts=" + minPts}; fullRScriptName = HOME + TEST_NAME + ".R"; rCmd = getRCmd(inputDir(), inputDir(), Double.toString(epsDB), Integer.toString(minPts), expectedDir()); diff --git a/src/test/scripts/functions/builtin/dbscan.dml b/src/test/scripts/functions/builtin/dbscan.dml index bde15f58357..c28eb59a554 100644 --- a/src/test/scripts/functions/builtin/dbscan.dml +++ b/src/test/scripts/functions/builtin/dbscan.dml @@ -22,5 +22,5 @@ X = read($X); eps = as.double($eps); minPts = as.integer($minPts); -[Y, model] = dbscan(X, eps, minPts); +[Y, model, eps] = dbscan(X, eps, minPts); write(Y, $Y); \ No newline at end of file diff --git a/src/test/scripts/functions/builtin/dbscanApply.dml b/src/test/scripts/functions/builtin/dbscanApply.dml index a7849028d7f..07a1f0d6b16 100644 --- a/src/test/scripts/functions/builtin/dbscanApply.dml +++ b/src/test/scripts/functions/builtin/dbscanApply.dml @@ -24,6 +24,6 @@ Y = read($Y) eps = as.double($eps); minPts = as.integer($minPts); -[indices, clusterModel] = dbscan(X = X, eps = eps, minPts = minPts); -Z = dbscanApply(Xtest=Y, clusterModel = clusterModel, eps = eps); +[indices, clusterModel, eps] = dbscan(X = X, eps = eps, minPts = minPts); +[C, Z] = dbscanApply(X=Y, clusterModel = clusterModel, eps = eps); write(Z, $Z);