Skip to content
1 change: 1 addition & 0 deletions python/tvm/autotvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,6 @@
FallbackContext,
ApplyHistoryBest as apply_history_best,
ApplyGraphBest as apply_graph_best,
ApplyFixedConfig as apply_fixed_config,
)
from .env import GLOBAL_SCOPE
1 change: 1 addition & 0 deletions python/tvm/autotvm/task/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from .dispatcher import (
DispatchContext,
ApplyConfig,
ApplyFixedConfig,
ApplyHistoryBest,
FallbackContext,
clear_fallback_cache,
Expand Down
53 changes: 53 additions & 0 deletions python/tvm/autotvm/task/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from __future__ import absolute_import as _abs

import logging
import typing
from typing import Union
from collections.abc import Iterable

import numpy as np
Expand Down Expand Up @@ -179,6 +181,57 @@ def update(self, target, workload, cfg):
self._config = cfg


class ApplyFixedConfig(DispatchContext):
"""Apply a config of a deterministic schedule.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you explain why this is different from ApplyConfig here?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added more details.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think the comment no longer makes sense now that we are accepting an array of schedule names, but can fix that in a follow-on.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will fix that in a follow up PR.

This is used for building a single Relay operator with deterministic schedule
for testing schedules at Relay level.

Parameters
----------
tasks : list[tvm.autotvm.task.task.Task]
List of autoTVM tasks.
schedule_names : str, List[str]
Name of schedules to use.
"""

def __init__(self, tasks, schedule_names: Union[str, typing.List[str]]):
super(ApplyFixedConfig, self).__init__()
if isinstance(schedule_names, str):
self._schedule_names = list(schedule_names)
elif isinstance(schedule_names, list):
self._schedule_names = schedule_names
else:
raise RuntimeError("Incorrect type: " + schedule_names)
self._tasks = tasks
self.workload = None

def _query_inside(self, target, workload):
"""Override query"""
self.workload = workload

# Create a config from correct task
for task in self._tasks:
if task.name == workload[0]:
config = task.config_space.get(0)
break

if not config:
raise RuntimeError(
"workload: %s does not exist in %s" % (str(workload), str(self._tasks))
)
# Add low cost to the target schedule and high cost to others.
if workload[0] in self._schedule_names:
config.cost = 1e-6
else:
config.cost = 100000
return config

def update(self, target, workload, cfg):
"""Override update"""
self.workload = workload
self._config = cfg


class ApplyHistoryBest(DispatchContext):
"""
Apply the history best config
Expand Down
105 changes: 105 additions & 0 deletions python/tvm/micro/testing/aot_test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import logging
import itertools
import shutil

import pytest

pytest.importorskip("tvm.micro")

import tvm
from tvm.testing.aot import AOTTestRunner

_LOG = logging.getLogger(__name__)


AOT_DEFAULT_RUNNER = AOTTestRunner()

# AOT Test Runner using the Arm® Corstone™-300 Reference Systems
# see: https://developer.arm.com/ip-products/subsystem/corstone/corstone-300
AOT_CORSTONE300_RUNNER = AOTTestRunner(
makefile="corstone300",
prologue="""
uart_init();
""",
includes=["uart.h"],
pass_config={
"relay.ext.cmsisnn.options": {
"mcpu": "cortex-m55",
}
},
)

AOT_USMP_CORSTONE300_RUNNER = AOTTestRunner(
makefile="corstone300",
prologue="""
uart_init();
""",
includes=["uart.h"],
pass_config={
"relay.ext.cmsisnn.options": {
"mcpu": "cortex-m55",
},
"tir.usmp.enable": True,
},
)


def parametrize_aot_options(test):
"""Parametrize over valid option combinations"""

requires_arm_eabi = pytest.mark.skipif(
shutil.which("arm-none-eabi-gcc") is None, reason="ARM embedded toolchain unavailable"
)

interface_api = ["packed", "c"]
use_unpacked_api = [True, False]
test_runner = [AOT_DEFAULT_RUNNER, AOT_CORSTONE300_RUNNER]

all_combinations = itertools.product(interface_api, use_unpacked_api, test_runner)

# Filter out packed operators with c interface
valid_combinations = filter(
lambda parameters: not (parameters[0] == "c" and not parameters[1]),
all_combinations,
)

# Only use reference system for C interface and unpacked API calls
valid_combinations = filter(
lambda parameters: not (
parameters[2] == AOT_CORSTONE300_RUNNER
and (parameters[0] == "packed" or not parameters[1])
),
valid_combinations,
)

# Skip reference system tests if running in i386 container
marked_combinations = map(
lambda parameters: pytest.param(*parameters, marks=[requires_arm_eabi])
if parameters[2] == AOT_CORSTONE300_RUNNER
else parameters,
valid_combinations,
)

fn = pytest.mark.parametrize(
["interface_api", "use_unpacked_api", "test_runner"],
marked_combinations,
)(test)

return tvm.testing.skip_if_32bit(reason="Reference system unavailable in i386 container")(fn)
File renamed without changes.
Loading