diff --git a/apps/android_camera/models/prepare_model.py b/apps/android_camera/models/prepare_model.py index 9f2cbbdd6d1f..5fd99967aea3 100644 --- a/apps/android_camera/models/prepare_model.py +++ b/apps/android_camera/models/prepare_model.py @@ -15,18 +15,16 @@ # specific language governing permissions and limitations # under the License. -import logging -import pathlib -from pathlib import Path -from typing import Union +import json import os from os import environ -import json +from pathlib import Path +from typing import Union import tvm import tvm.relay as relay -from tvm.contrib import utils, ndk, graph_executor as runtime -from tvm.contrib.download import download_testdata, download +from tvm.contrib import ndk +from tvm.contrib.download import download, download_testdata target = "llvm -mtriple=arm64-linux-android" target_host = None @@ -50,15 +48,18 @@ def del_dir(target: Union[Path, str], only_if_empty: bool = False): def get_model(model_name, batch_size=1): if model_name == "resnet18_v1": - import mxnet as mx - from mxnet import gluon - from mxnet.gluon.model_zoo import vision + import torch + import torchvision - gluon_model = vision.get_model(model_name, pretrained=True) - img_size = 224 - data_shape = (batch_size, 3, img_size, img_size) - net, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape}) - return (net, params) + weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1 + torch_model = torchvision.models.resnet18(weights=weights).eval() + input_shape = [1, 3, 224, 224] + input_data = torch.randn(input_shape) + scripted_model = torch.jit.trace(torch_model, input_data) + + input_infos = [("data", input_data.shape)] + mod, params = relay.frontend.from_pytorch(scripted_model, input_infos) + return (mod, params) elif model_name == "mobilenet_v2": import keras from keras.applications.mobilenet_v2 import MobileNetV2 diff --git a/apps/android_camera/models/requirements.txt b/apps/android_camera/models/requirements.txt index dbf496b2d968..3e35efdeb66e 100644 --- a/apps/android_camera/models/requirements.txt +++ b/apps/android_camera/models/requirements.txt @@ -1,4 +1,5 @@ keras==2.9 -mxnet scipy tensorflow==2.9.3 +torch +torchvision