Skip to content
Merged
29 changes: 20 additions & 9 deletions python/tvm/auto_scheduler/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,13 @@ class ApplyHistoryBest(DispatchContext):

Parameters
----------
records : str or iterator of (auto_scheduler.measure.MeasureInput,\
auto_scheduler.measure.MeasureResult)
records : str, list of str, or iterator of (auto_scheduler.measure.MeasureInput,\
auto_scheduler.measure.MeasureResult)
Collection of tuning records.
If is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair. Otherwise, it is an iterator.
Each row of this file is an encoded record pair. If it is an iterator,
it can either be a set of str filenames which will be applied jointly,
or a set of (input, result) tuples.
n_lines: Optional[int]
if it is not None, only load the first `n_lines` lines of log.
include_compatible: bool
Expand Down Expand Up @@ -196,20 +198,29 @@ def load(self, records, n_lines=None):
n_lines: Optional[int]
if it is not None, only load the first `n_lines` lines of log
"""
if isinstance(records, pathlib.Path):
records = str(records)
joint_records = []
if not isinstance(records, (list, tuple)):
records = [records]

if isinstance(records, str):
records = load_records(records)
for rec in records:
if isinstance(rec, pathlib.Path):
rec = str(rec)

if isinstance(rec, str):
rec = load_records(rec)
joint_records += rec
else:
if rec is not None:
joint_records.append(rec)

if not records:
if not joint_records:
return

best_by_targetkey = self.best_by_targetkey
best_by_model = self.best_by_model

counter = 0
for inp, res in records:
for inp, res in joint_records:
if n_lines is not None and counter >= n_lines:
break
counter += 1
Expand Down
35 changes: 25 additions & 10 deletions python/tvm/autotvm/task/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,12 @@ class ApplyHistoryBest(DispatchContext):

Parameters
----------
records : str or iterator of (autotvm.measure.MeasureInput, autotvm.measure.MeasureResult)
records : str, list of str, or iterator of (autotvm.measure.MeasureInput,\
autotvm.measure.MeasureResult)
Collection of tuning records.
If is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair. Otherwise, it is an iterator.
Each row of this file is an encoded record pair. If it is a list, it can either be
a list of paths to log files that will be loaded jointly or an iterator or records.
"""

def __init__(self, records):
Expand All @@ -205,28 +207,41 @@ def load(self, records):

Parameters
----------
records : str or iterator of (autotvm.measure.MeasureInput, autotvm.measure.MeasureResult)
records : str, list of str, or iterator of (autotvm.measure.MeasureInput,\
autotvm.measure.MeasureResult)
Collection of tuning records.
If is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair. Otherwise, it is an iterator.
Each row of this file is an encoded record pair. If it is a list
it can either be a list of paths to logs that will loaded jointly or
an iterator of measurement results.
"""
# pylint: disable=import-outside-toplevel
from pathlib import Path
from ..record import load_from_file

if isinstance(records, Path):
records = str(records)
joint_records = []
if not isinstance(records, (list, tuple)):
records = [records]

if isinstance(records, str):
records = load_from_file(records)
if not records:
for rec in records:
if isinstance(rec, Path):
rec = str(rec)

if isinstance(rec, str):
rec = load_from_file(rec)
joint_records += rec
else:
if rec is not None:
joint_records.append(rec)

if not joint_records:
return

best_by_targetkey = self.best_by_targetkey
best_by_model = self.best_by_model

counter = 0
for inp, res in records:
for inp, res in joint_records:
counter += 1
if res.error_no != 0:
continue
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
from .transform import *
from .recast import recast
from . import fake_quantization_to_integer, mixed_precision
from .flexible_shape import FlexibleShapeDispatch
Loading