Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions apps/android_camera/models/prepare_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,16 @@
# specific language governing permissions and limitations
# under the License.

import logging
import pathlib
from pathlib import Path
from typing import Union
import json
import os
from os import environ
import json
from pathlib import Path
from typing import Union

import tvm
import tvm.relay as relay
from tvm.contrib import utils, ndk, graph_executor as runtime
from tvm.contrib.download import download_testdata, download
from tvm.contrib import ndk
from tvm.contrib.download import download, download_testdata

target = "llvm -mtriple=arm64-linux-android"
target_host = None
Expand All @@ -50,15 +48,18 @@ def del_dir(target: Union[Path, str], only_if_empty: bool = False):

def get_model(model_name, batch_size=1):
if model_name == "resnet18_v1":
import mxnet as mx
from mxnet import gluon
from mxnet.gluon.model_zoo import vision
import torch
import torchvision

gluon_model = vision.get_model(model_name, pretrained=True)
img_size = 224
data_shape = (batch_size, 3, img_size, img_size)
net, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
return (net, params)
weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
torch_model = torchvision.models.resnet18(weights=weights).eval()
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(torch_model, input_data)

input_infos = [("data", input_data.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, input_infos)
return (mod, params)
elif model_name == "mobilenet_v2":
import keras
from keras.applications.mobilenet_v2 import MobileNetV2
Expand Down
3 changes: 2 additions & 1 deletion apps/android_camera/models/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
keras==2.9
mxnet
scipy
tensorflow==2.9.3
torch
torchvision