diff --git a/python/tvm/auto_scheduler/dispatcher.py b/python/tvm/auto_scheduler/dispatcher.py index eceeba38e081..98566f863650 100644 --- a/python/tvm/auto_scheduler/dispatcher.py +++ b/python/tvm/auto_scheduler/dispatcher.py @@ -25,6 +25,7 @@ import logging import pathlib +from collections.abc import Iterable import numpy as np @@ -199,7 +200,7 @@ def load(self, records, n_lines=None): if it is not None, only load the first `n_lines` lines of log """ joint_records = [] - if not isinstance(records, (list, tuple)): + if not isinstance(records, Iterable) or isinstance(records, str): records = [records] for rec in records: diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py index ffff50b9dc0b..6c072dc1fa17 100644 --- a/python/tvm/autotvm/task/dispatcher.py +++ b/python/tvm/autotvm/task/dispatcher.py @@ -31,6 +31,7 @@ from __future__ import absolute_import as _abs import logging +from collections.abc import Iterable import numpy as np @@ -212,7 +213,7 @@ def load(self, records): Collection of tuning records. If is str, then it should be the filename of a records log file. Each row of this file is an encoded record pair. If it is a list - it can either be a list of paths to logs that will loaded jointly or + it can either be a list of paths to logs that will be loaded jointly or an iterator of measurement results. """ # pylint: disable=import-outside-toplevel @@ -220,7 +221,7 @@ def load(self, records): from ..record import load_from_file joint_records = [] - if not isinstance(records, (list, tuple)): + if not isinstance(records, Iterable) or isinstance(records, str): records = [records] for rec in records: diff --git a/tests/python/relay/test_auto_scheduler_tuning.py b/tests/python/relay/test_auto_scheduler_tuning.py index c9ce5b59ff09..735486ef27c6 100644 --- a/tests/python/relay/test_auto_scheduler_tuning.py +++ b/tests/python/relay/test_auto_scheduler_tuning.py @@ -62,6 +62,13 @@ def tune_network(network, target): best, auto_scheduler.dispatcher.ApplyHistoryBest ), "Unable to load multiple log files jointly." + # Confirm iterables can be directly loaded. + loaded_recs = auto_scheduler.dispatcher.load_records(log_file) + with auto_scheduler.ApplyHistoryBest(iter(loaded_recs)) as best: + assert isinstance( + best, auto_scheduler.dispatcher.ApplyHistoryBest + ), "Unable to ingest logs from an interator." + # Sample a schedule when missing with auto_scheduler.ApplyHistoryBestOrSample(None, num_measure=2): with tvm.transform.PassContext( diff --git a/tests/python/unittest/test_autotvm_record.py b/tests/python/unittest/test_autotvm_record.py index 2ee75cf18c0e..147122ff10d6 100644 --- a/tests/python/unittest/test_autotvm_record.py +++ b/tests/python/unittest/test_autotvm_record.py @@ -91,6 +91,11 @@ def test_apply_history_best(): x = hist_best.query(target, tsk.workload) assert str(x) == str(tsk.config_space.get(2)) + # Confirm same functionality for iterators. + hist_best = ApplyHistoryBest(iter(records)) + x = hist_best.query(target, tsk.workload) + assert str(x) == str(tsk.config_space.get(2)) + if __name__ == "__main__": test_load_dump()