From 97079b070d947c38032bedaf637cb7b15d2ff49e Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Mon, 30 Jul 2018 10:05:18 -0700 Subject: [PATCH] [tvm4j] support kNDArrayContainer --- jvm/core/src/main/java/ml/dmlc/tvm/Function.java | 7 +++++-- jvm/core/src/main/java/ml/dmlc/tvm/TypeCode.java | 2 +- jvm/native/src/main/native/jni_helper_func.h | 10 ++++++---- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/Function.java b/jvm/core/src/main/java/ml/dmlc/tvm/Function.java index 5b2008a757ed..2e21f439300e 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/Function.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/Function.java @@ -187,7 +187,8 @@ public Function pushArg(String arg) { * @return this */ public Function pushArg(NDArrayBase arg) { - Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.ARRAY_HANDLE.id); + int id = arg.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id; + Base._LIB.tvmFuncPushArgHandle(arg.handle, id); return this; } @@ -247,7 +248,9 @@ private static void pushArgToStack(Object arg) { } else if (arg instanceof byte[]) { Base._LIB.tvmFuncPushArgBytes((byte[]) arg); } else if (arg instanceof NDArrayBase) { - Base._LIB.tvmFuncPushArgHandle(((NDArrayBase) arg).handle, TypeCode.ARRAY_HANDLE.id); + NDArrayBase nd = (NDArrayBase) arg; + int id = nd.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id; + Base._LIB.tvmFuncPushArgHandle(nd.handle, id); } else if (arg instanceof Module) { Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, TypeCode.MODULE_HANDLE.id); } else if (arg instanceof Function) { diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/TypeCode.java b/jvm/core/src/main/java/ml/dmlc/tvm/TypeCode.java index 0b28746f9555..1f01fde6d307 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/TypeCode.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/TypeCode.java @@ -21,7 +21,7 @@ public enum TypeCode { INT(0), UINT(1), FLOAT(2), HANDLE(3), NULL(4), TVM_TYPE(5), TVM_CONTEXT(6), ARRAY_HANDLE(7), NODE_HANDLE(8), MODULE_HANDLE(9), - FUNC_HANDLE(10), STR(11), BYTES(12); + FUNC_HANDLE(10), STR(11), BYTES(12), NDARRAY_CONTAINER(13); public final int id; diff --git a/jvm/native/src/main/native/jni_helper_func.h b/jvm/native/src/main/native/jni_helper_func.h index d4435bdaaba8..181d9de040f1 100644 --- a/jvm/native/src/main/native/jni_helper_func.h +++ b/jvm/native/src/main/native/jni_helper_func.h @@ -134,10 +134,10 @@ jobject newFunction(JNIEnv *env, jlong value) { return object; } -jobject newNDArray(JNIEnv *env, jlong value) { +jobject newNDArray(JNIEnv *env, jlong handle, jboolean isview) { jclass cls = env->FindClass("ml/dmlc/tvm/NDArrayBase"); - jmethodID constructor = env->GetMethodID(cls, "", "(J)V"); - jobject object = env->NewObject(cls, constructor, value); + jmethodID constructor = env->GetMethodID(cls, "", "(JZ)V"); + jobject object = env->NewObject(cls, constructor, handle, isview); env->DeleteLocalRef(cls); return object; } @@ -181,7 +181,9 @@ jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) { case kFuncHandle: return newFunction(env, reinterpret_cast(value.v_handle)); case kArrayHandle: - return newNDArray(env, reinterpret_cast(value.v_handle)); + return newNDArray(env, reinterpret_cast(value.v_handle), true); + case kNDArrayContainer: + return newNDArray(env, reinterpret_cast(value.v_handle), false); case kStr: return newTVMValueString(env, value.v_str); case kBytes: