From dc71ca535b469781727e600ac02eaa14f5b27cf4 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 14 Nov 2022 22:30:07 -0800 Subject: [PATCH 01/13] Unify output function. --- .../testing/validate_database.py | 170 ++++++++++++------ 1 file changed, 117 insertions(+), 53 deletions(-) diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index 5e48bfb6b04e..7c5ad3a01432 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -15,22 +15,25 @@ # specific language governing permissions and limitations # under the License. """JSON Database validation script""" -from typing import Union, Callable, List -from distutils.util import strtobool import argparse import logging import warnings +from distutils.util import strtobool +from typing import Callable, List, Tuple, Union + import numpy as np # type: ignore import tvm -from tvm.target import Target -from tvm.ir import IRModule -from tvm.tir import Schedule from tvm import meta_schedule as ms +from tvm._ffi import get_global_func, register_func +from tvm.ir import IRModule from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc from tvm.meta_schedule.testing.tune_utils import create_calculator, generate_input_data -from tvm._ffi import get_global_func, register_func from tvm.support import describe +from tvm.target import Target +from tvm.tir import Schedule +from tvm.tir.schedule import Trace +from tvm.tir.tensor_intrin import cuda, x86 # type: ignore # pylint: disable=unused-import DELIMITOR = "\n" + "-" * 30 + "\n" @@ -133,6 +136,64 @@ def default_check_metric(a: List[tvm.nd.NDArray], b: List[tvm.nd.NDArray]) -> bo return True +def is_failed_record(record: ms.database.TuningRecord) -> bool: + """Check if a tuning record is failed.""" + return len(record.run_secs) == 1 and record.run_secs[0] == 1e9 + + +def print_validation_result( + idx: int, + total: int, + result: str, + time: float, + *, + original_mod: IRModule = None, + scheduled_mod: IRModule = None, + inputs: List[np.ndarray] = None, + original_res: List[np.ndarray] = None, + scheduled_res: List[np.ndarray] = None, + exception: Exception = None, + trace: Trace = None, +) -> None: + """Print the validation result.""" + output = [ + f"Progress {idx: 6d} / {total: 6d} checked, used {float(time): 3.3f} sec. Result: {result}" + ] + if result not in ["pass", "skip"]: + output.extend( + [ + "Original IRModule:" + DELIMITOR + original_mod.script(), + "Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(), + "Trace" + DELIMITOR + str(trace), + "Input:" + DELIMITOR + str(inputs), + ] + ) + if result == "wrong answer": + output.extend( + [ + "Original Result:" + DELIMITOR + str(original_res), + "Scheduled Result:" + DELIMITOR + str(scheduled_res), + "Max Diff:" + + DELIMITOR + + str( + [ + np.max(np.abs(original_res[i] - scheduled_res[i])) + for i in range(len(original_res)) + ] + ), + ] + ) + elif result == "exception": + output.extend( + [ + "Exception:" + DELIMITOR + str(exception), + ] + ) + else: + raise ValueError(f"Unknown result: {result}") + print("\n\n".join(output)) + + def validate_correctness( original_mod: IRModule, # compiled for "baseline_target" scheduled_mod: IRModule, # compiled for "target" @@ -147,7 +208,7 @@ def validate_correctness( f_check_metric: Union[ str, Callable[[tvm.nd.NDArray, tvm.nd.NDArray], bool] ] = default_check_metric, -) -> bool: +) -> Tuple[bool, List[np.ndarray], List[np.ndarray], List[np.ndarray]]: """Function to validate the correctness of a scheduled module. Parameters @@ -185,7 +246,7 @@ def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]: assert a is not None, "Empty result cannot be converted to TVM NDArray" return [tvm.nd.array(x) for x in a] - def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray: + def build_and_run(mod: IRModule, target: Target, dev_type: str) -> List[tvm.nd.NDArray]: """Build and run the module on the target device.""" rt_mod = tvm.build(mod, target=target) return run_module_via_rpc( @@ -197,32 +258,26 @@ def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray: backend="tir", ) - # fetch functions & prepare inputs + # fetch input function & prepare inputs if isinstance(f_input_generator, str): f_input_generator = get_global_func(f_input_generator) - if isinstance(f_check_metric, str): - f_check_metric = get_global_func(f_check_metric) inputs = to_numpy(f_input_generator(original_mod)) # type: ignore + # build & run original result original_res = to_numpy(build_and_run(original_mod, target=baseline_target, dev_type="cpu")) scheduled_res = to_numpy(build_and_run(scheduled_mod, target=target, dev_type=dev_type)) + + # fetch comparison function + if isinstance(f_check_metric, str): + f_check_metric = get_global_func(f_check_metric) + # check metric - if f_check_metric(to_tvm_ndarray(original_res), to_tvm_ndarray(scheduled_res)): # type: ignore - return True - else: - print( - ("\n\n").join( - [ - "Validation failed!", - "Original Result:" + DELIMITOR + str(original_res), - "Scheduled Result:" + DELIMITOR + str(scheduled_res), - "Input:" + DELIMITOR + str(inputs), - "Original IRModule:" + DELIMITOR + original_mod.script(), - "Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(), - ] - ) - ) - return False + return ( + f_check_metric(to_tvm_ndarray(original_res), to_tvm_ndarray(scheduled_res)), # type: ignore + inputs, + original_res, + scheduled_res, + ) def main(): @@ -240,14 +295,18 @@ def main(): with ms.Profiler() as profiler: for i, record in enumerate(records): scope_name = f"validate #{i}" - with profiler.timeit(scope_name): - original_mod = record.workload.mod - sch = Schedule(original_mod) - record.trace.apply_to_schedule(sch=sch, remove_postproc=False) - scheduled_mod = sch.mod - is_success = False - try: - is_success = validate_correctness( + if is_failed_record(record): + print_validation_result(i + 1, total=len(records), result="skip", time=0.0) + continue + try: + with profiler.timeit(scope_name): + original_mod = record.workload.mod + sch = Schedule(original_mod) + record.trace.apply_to_schedule(sch=sch, remove_postproc=False) + scheduled_mod = sch.mod + passed = False + + passed, inputs, original_res, scheduled_res = validate_correctness( original_mod=original_mod, scheduled_mod=scheduled_mod, target=target, @@ -255,26 +314,31 @@ def main(): dev_type=dev_type, rpc_config=ARGS.rpc_config, ) - except Exception as e: # pylint: disable=broad-except, invalid-name - print( - ("\n\n").join( - [ - "Validation failed!", - "Original IRModule:" + DELIMITOR + original_mod.script(), - "Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(), - "Exception" + DELIMITOR + str(e), - ] - ) - ) - if is_success: - print( - f"Progress {i+1: 6d} / {len(records): 6d} checked," - f" used {float(profiler.get()[scope_name]): 3.3f} sec." + print_validation_result( + i + 1, + total=len(records), + result="pass" if passed else "wrong answer", + time=profiler.get()[scope_name], + original_mod=original_mod, + scheduled_mod=scheduled_mod, + trace=record.trace, + inputs=inputs, + original_res=original_res, + scheduled_res=scheduled_res, + ) + except Exception as e: # pylint: disable=broad-except, invalid-name + print_validation_result( + i + 1, + total=len(records), + result="exception", + time=profiler.get()[scope_name], + original_mod=original_mod, + scheduled_mod=scheduled_mod, + trace=record.trace, + exception=e, ) - else: - return - print("Validation passed!") + print("Validation finished!") print(f"Total time spent: {float(profiler.get()['Total']): 3.3f} sec.") From b4b1ba1ce7779fbde0deb1c2f1de49ebb9c9927d Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 14 Nov 2022 23:29:39 -0800 Subject: [PATCH 02/13] Support reuse input & local mod results. --- .../testing/validate_database.py | 255 ++++++++++++------ 1 file changed, 173 insertions(+), 82 deletions(-) diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index 7c5ad3a01432..25c851787045 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -19,7 +19,7 @@ import logging import warnings from distutils.util import strtobool -from typing import Callable, List, Tuple, Union +from typing import Callable, Tuple, Union, List, Any import numpy as np # type: ignore @@ -88,6 +88,12 @@ def _parse_args(): type=int, default=100, ) + args.add_argument( + "--top-k", + type=int, + default=10**9, + help="The number of top-k tuning records to validate for each unique original workload.", + ) args.add_argument( "--cpu-flush", type=lambda x: bool(strtobool(x)), @@ -136,13 +142,25 @@ def default_check_metric(a: List[tvm.nd.NDArray], b: List[tvm.nd.NDArray]) -> bo return True +def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]: + """Convert a list of TVM NDArray to a list of numpy array""" + assert a is not None, "Empty result cannot be converted to numpy" + return [x.numpy() for x in a] + + +def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]: + """Convert a list of numpy array to a list of TVM NDArray""" + assert a is not None, "Empty result cannot be converted to TVM NDArray" + return [tvm.nd.array(x) for x in a] + + def is_failed_record(record: ms.database.TuningRecord) -> bool: """Check if a tuning record is failed.""" return len(record.run_secs) == 1 and record.run_secs[0] == 1e9 -def print_validation_result( - idx: int, +def print_result( + counter: int, total: int, result: str, time: float, @@ -157,7 +175,8 @@ def print_validation_result( ) -> None: """Print the validation result.""" output = [ - f"Progress {idx: 6d} / {total: 6d} checked, used {float(time): 3.3f} sec. Result: {result}" + f"Progress {counter: 6d} / {total: 6d} checked, " + f"used {float(time): 3.3f} sec. Result: {result}" ] if result not in ["pass", "skip"]: output.extend( @@ -194,6 +213,34 @@ def print_validation_result( print("\n\n".join(output)) +def check_and_run(func: Union[str, Callable], *args, **kwargs) -> Any: + """Check if the function is a string or a callable, and run it.""" + if isinstance(func, str): + func = get_global_func(func) + return func(*args, **kwargs) # type: ignore + + +def build_and_run( + mod: IRModule, + target: Target, + rpc_config: ms.runner.RPCConfig, + dev_type: str, + inputs: List[np.ndarray], +) -> List[np.ndarray]: + """Build and run the module on the target device.""" + rt_mod = tvm.build(mod, target=target) + return to_numpy( + run_module_via_rpc( + rpc_config=rpc_config, + lib=rt_mod, + dev_type=dev_type, + args={i: v for i, v in enumerate(inputs)}, # pylint: disable=unnecessary-comprehension + continuation=create_calculator(backend="tir"), + backend="tir", + ) + ) + + def validate_correctness( original_mod: IRModule, # compiled for "baseline_target" scheduled_mod: IRModule, # compiled for "target" @@ -202,6 +249,8 @@ def validate_correctness( target: Target, dev_type: str, rpc_config: ms.runner.RPCConfig, + inputs: List[np.ndarray] = None, # for input reuse + original_res: List[np.ndarray] = None, # for original mod results reuse f_input_generator: Union[ str, Callable[[IRModule], List[tvm.nd.NDArray]] ] = default_input_generator, @@ -225,6 +274,10 @@ def validate_correctness( The device type to run the module via rpc. rpc_config : RPCConfig The RPCConfig to run the scheduled module. + inputs : List[np.ndarray] + The input data to be reused, if None, generate new inputs. + original_res : List[np.ndarray] + The original module results to be reused, if None, run the original module. f_input_generator : Union[str, Callable] The function to generate the input data. f_check_metric : Union[str, Callable] @@ -232,48 +285,39 @@ def validate_correctness( Returns ------- - result : bool - The result of the validation. + passed: bool + Whether the validation passed. + inputs: List[np.ndarray] + The input data used for validation in numpy array. + original_res: List[np.ndarray] + The original module results in numpy array. + scheduled_res: List[np.ndarray] + The scheduled module results in numpy array. """ - def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]: - """Convert a list of TVM NDArray to a list of numpy array""" - assert a is not None, "Empty result cannot be converted to numpy" - return [x.numpy() for x in a] - - def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]: - """Convert a list of numpy array to a list of TVM NDArray""" - assert a is not None, "Empty result cannot be converted to TVM NDArray" - return [tvm.nd.array(x) for x in a] - - def build_and_run(mod: IRModule, target: Target, dev_type: str) -> List[tvm.nd.NDArray]: - """Build and run the module on the target device.""" - rt_mod = tvm.build(mod, target=target) - return run_module_via_rpc( - rpc_config=rpc_config, - lib=rt_mod, - dev_type=dev_type, - args={i: v for i, v in enumerate(inputs)}, # pylint: disable=unnecessary-comprehension - continuation=create_calculator(backend="tir"), - backend="tir", - ) - # fetch input function & prepare inputs - if isinstance(f_input_generator, str): - f_input_generator = get_global_func(f_input_generator) - inputs = to_numpy(f_input_generator(original_mod)) # type: ignore + if inputs is None: + inputs = to_numpy(check_and_run(f_input_generator, original_mod)) # build & run original result - original_res = to_numpy(build_and_run(original_mod, target=baseline_target, dev_type="cpu")) - scheduled_res = to_numpy(build_and_run(scheduled_mod, target=target, dev_type=dev_type)) + if original_res is None: + original_res = build_and_run( + original_mod, + target=baseline_target, + rpc_config=rpc_config, + dev_type="cpu", + inputs=inputs, + ) + scheduled_res = build_and_run( + scheduled_mod, target=target, rpc_config=rpc_config, dev_type=dev_type, inputs=inputs + ) # fetch comparison function - if isinstance(f_check_metric, str): - f_check_metric = get_global_func(f_check_metric) + validation_res = check_and_run(f_check_metric, original_res, scheduled_res) # check metric return ( - f_check_metric(to_tvm_ndarray(original_res), to_tvm_ndarray(scheduled_res)), # type: ignore + validation_res, inputs, original_res, scheduled_res, @@ -283,63 +327,110 @@ def build_and_run(mod: IRModule, target: Target, dev_type: str) -> List[tvm.nd.N def main(): """Main function""" describe() - database = ms.database.create(work_dir=ARGS.work_dir) target = ARGS.target + database = ms.database.create(work_dir=ARGS.work_dir) + + # determine target kind if target.kind.name == "llvm": dev_type = "cpu" elif target.kind.name == "cuda": dev_type = "cuda" else: raise RuntimeError(f"Unsupported target kind: {target.kind.name}") - records = database.get_all_tuning_records() + + # start profiling with ms.Profiler() as profiler: - for i, record in enumerate(records): - scope_name = f"validate #{i}" - if is_failed_record(record): - print_validation_result(i + 1, total=len(records), result="skip", time=0.0) - continue - try: - with profiler.timeit(scope_name): - original_mod = record.workload.mod - sch = Schedule(original_mod) - record.trace.apply_to_schedule(sch=sch, remove_postproc=False) - scheduled_mod = sch.mod - passed = False - - passed, inputs, original_res, scheduled_res = validate_correctness( + # collect records + with profiler.timeit("collect records"): + records = database.get_all_tuning_records() + print( + f"Total {len(records)} records to be validated. " + f"Collected in {float(profiler.get()['collect records']): 3.3f} sec." + ) + + # collect unique original TIR + with profiler.timeit("deduplicate records"): + workloads = dict() + for record in records: + mod = record.workload.mod + s_hash = tvm.ir.structural_hash(mod) + if s_hash not in workloads: + workloads[s_hash] = [mod] + else: + duplicate = False + for previous_mod in workloads[hash]: + if tvm.ir.structural_equal(mod, previous_mod): + duplicate = True + break + if not duplicate: + workloads[hash].append(mod) + # put the workload into a list + unique_workloads = [] + for _, mods in workloads.items(): + unique_workloads.extend(mods) + print( + f"Total {len(unique_workloads)} unique original TIR to be validated. " + f"Deduplicated in {float(profiler.get()['deduplicate records']): 3.3f} sec." + ) + + # validate correctness + counter = 0 + for original_mod in unique_workloads: + records = database.get_top_k(workload=database.commit_workload(mod), top_k=10**9) + inputs = None + original_res = None + for record in records: + counter += 1 + scope_name = f"validate #{counter}" + if is_failed_record(record): + # skip failed records where run_secs is 1e9 + # these records are only negative samples for cost model + print_result(counter + 1, total=len(records), result="skip", time=0.0) + continue + try: + with profiler.timeit(scope_name): + # prepare scheduled module + sch = Schedule(original_mod) + record.trace.apply_to_schedule(sch=sch, remove_postproc=False) + scheduled_mod = sch.mod + # validate correctness + passed, inputs, original_res, scheduled_res = validate_correctness( + original_mod=original_mod, + scheduled_mod=scheduled_mod, + target=target, + baseline_target=ARGS.baseline_target, + dev_type=dev_type, + rpc_config=ARGS.rpc_config, + inputs=inputs, + original_res=original_res, + ) + # validation finished + print_result( + counter + 1, + total=len(records), + result="pass" if passed else "wrong answer", + time=profiler.get()[scope_name], + original_mod=original_mod, + scheduled_mod=scheduled_mod, + trace=record.trace, + inputs=inputs, + original_res=original_res, + scheduled_res=scheduled_res, + ) + except Exception as e: # pylint: disable=broad-except, invalid-name + # validation failed with exception + print_result( + counter + 1, + total=len(records), + result="exception", + time=profiler.get()[scope_name], original_mod=original_mod, scheduled_mod=scheduled_mod, - target=target, - baseline_target=ARGS.baseline_target, - dev_type=dev_type, - rpc_config=ARGS.rpc_config, + trace=record.trace, + exception=e, ) - print_validation_result( - i + 1, - total=len(records), - result="pass" if passed else "wrong answer", - time=profiler.get()[scope_name], - original_mod=original_mod, - scheduled_mod=scheduled_mod, - trace=record.trace, - inputs=inputs, - original_res=original_res, - scheduled_res=scheduled_res, - ) - except Exception as e: # pylint: disable=broad-except, invalid-name - print_validation_result( - i + 1, - total=len(records), - result="exception", - time=profiler.get()[scope_name], - original_mod=original_mod, - scheduled_mod=scheduled_mod, - trace=record.trace, - exception=e, - ) - - print("Validation finished!") - print(f"Total time spent: {float(profiler.get()['Total']): 3.3f} sec.") + print("Validation finished!") + print(f"Total time spent: {float(profiler.get()['Total']): 3.3f} sec.") if __name__ == "__main__": From 0b5014e8a6f951976f6ed7280f33eee571458a19 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 15 Nov 2022 00:02:28 -0800 Subject: [PATCH 03/13] Fix issues. --- .../testing/validate_database.py | 75 ++++++++++--------- 1 file changed, 41 insertions(+), 34 deletions(-) diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index 25c851787045..e1014597a38f 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -123,6 +123,19 @@ def _parse_args(): ARGS = _parse_args() +class OriginalModule: + """Original module class""" + + def __init__(self, mod: IRModule): + self.mod = mod + + def __eq__(self, __o: "OriginalModule") -> bool: # type: ignore + return tvm.ir.structural_equal(self.mod, __o.mod) + + def __hash__(self) -> int: + return tvm.ir.structural_hash(self.mod) + + @register_func("tvm.meta_schedule.testing.default_input_generator") def default_input_generator(mod: IRModule) -> List[tvm.nd.NDArray]: args_info = ms.arg_info.TensorInfo.from_prim_func(mod["main"]) @@ -313,7 +326,9 @@ def validate_correctness( ) # fetch comparison function - validation_res = check_and_run(f_check_metric, original_res, scheduled_res) + validation_res = check_and_run( + f_check_metric, to_tvm_ndarray(original_res), to_tvm_ndarray(scheduled_res) + ) # check metric return ( @@ -343,40 +358,32 @@ def main(): # collect records with profiler.timeit("collect records"): records = database.get_all_tuning_records() - print( - f"Total {len(records)} records to be validated. " - f"Collected in {float(profiler.get()['collect records']): 3.3f} sec." - ) + total = len(records) + print( + f"Total {total} records to be validated. " + f"Collected in {float(profiler.get()['collect records']): 3.3f} sec." + ) # collect unique original TIR with profiler.timeit("deduplicate records"): - workloads = dict() + workloads = set() for record in records: - mod = record.workload.mod - s_hash = tvm.ir.structural_hash(mod) - if s_hash not in workloads: - workloads[s_hash] = [mod] - else: - duplicate = False - for previous_mod in workloads[hash]: - if tvm.ir.structural_equal(mod, previous_mod): - duplicate = True - break - if not duplicate: - workloads[hash].append(mod) - # put the workload into a list - unique_workloads = [] - for _, mods in workloads.items(): - unique_workloads.extend(mods) - print( - f"Total {len(unique_workloads)} unique original TIR to be validated. " - f"Deduplicated in {float(profiler.get()['deduplicate records']): 3.3f} sec." - ) + workloads.add(OriginalModule(record.workload.mod)) + print( + f"Total {len(workloads)} unique original TIR to validate. " + f"Deduplicated in {float(profiler.get()['deduplicate records']): 3.3f} sec." + ) + if ARGS.top_k < 10**9: + print(f"Top {ARGS.top_k} records for each original TIR will be validated.") + total = len(workloads) * ARGS.top_k # validate correctness counter = 0 - for original_mod in unique_workloads: - records = database.get_top_k(workload=database.commit_workload(mod), top_k=10**9) + for item in workloads: + original_mod = item.mod + records = database.get_top_k( + workload=database.commit_workload(original_mod), top_k=ARGS.top_k + ) inputs = None original_res = None for record in records: @@ -406,8 +413,8 @@ def main(): ) # validation finished print_result( - counter + 1, - total=len(records), + counter, + total=total, result="pass" if passed else "wrong answer", time=profiler.get()[scope_name], original_mod=original_mod, @@ -420,8 +427,8 @@ def main(): except Exception as e: # pylint: disable=broad-except, invalid-name # validation failed with exception print_result( - counter + 1, - total=len(records), + counter, + total=total, result="exception", time=profiler.get()[scope_name], original_mod=original_mod, @@ -429,8 +436,8 @@ def main(): trace=record.trace, exception=e, ) - print("Validation finished!") - print(f"Total time spent: {float(profiler.get()['Total']): 3.3f} sec.") + print("Validation finished!") + print(f"Total time spent: {float(profiler.get()['Total']): 3.3f} sec.") if __name__ == "__main__": From af3523879cf2c5581dc5c9c80ce26a2b075a102e Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 15 Nov 2022 09:43:14 -0800 Subject: [PATCH 04/13] Remove exception inputs in results. --- python/tvm/meta_schedule/testing/validate_database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index e1014597a38f..6a7eb3f6c11d 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -197,12 +197,12 @@ def print_result( "Original IRModule:" + DELIMITOR + original_mod.script(), "Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(), "Trace" + DELIMITOR + str(trace), - "Input:" + DELIMITOR + str(inputs), ] ) if result == "wrong answer": output.extend( [ + "Input:" + DELIMITOR + str(inputs), "Original Result:" + DELIMITOR + str(original_res), "Scheduled Result:" + DELIMITOR + str(scheduled_res), "Max Diff:" From 81b57612ddb4cc5686532471fe1ebbf2970c8d0e Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 15 Nov 2022 17:05:24 -0800 Subject: [PATCH 05/13] Check point. --- .../testing/validate_database.py | 220 +++++++++++++++--- 1 file changed, 182 insertions(+), 38 deletions(-) diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index 6a7eb3f6c11d..3ee43a2d6124 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -18,6 +18,8 @@ import argparse import logging import warnings +import itertools +from statistics import mean from distutils.util import strtobool from typing import Callable, Tuple, Union, List, Any @@ -27,8 +29,7 @@ from tvm import meta_schedule as ms from tvm._ffi import get_global_func, register_func from tvm.ir import IRModule -from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc -from tvm.meta_schedule.testing.tune_utils import create_calculator, generate_input_data +from tvm.meta_schedule.testing.tune_utils import generate_input_data from tvm.support import describe from tvm.target import Target from tvm.tir import Schedule @@ -58,6 +59,13 @@ def _parse_args(): required=False, help="The baseline target to compile the original module.", ) + args.add_argument( + "--top-k", + type=int, + default=10**9, + required=False, + help="The number of top-k tuning records to validate for each unique original workload.", + ) args.add_argument( "--rpc-host", type=str, @@ -88,12 +96,6 @@ def _parse_args(): type=int, default=100, ) - args.add_argument( - "--top-k", - type=int, - default=10**9, - help="The number of top-k tuning records to validate for each unique original workload.", - ) args.add_argument( "--cpu-flush", type=lambda x: bool(strtobool(x)), @@ -113,14 +115,21 @@ def _parse_args(): return parsed +# arg parser +ARGS = _parse_args() + # logging logging.basicConfig( format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) -# arg parser -ARGS = _parse_args() + +def check_and_run(func: Union[str, Callable], *args, **kwargs) -> Any: + """Check if the function is a string or a callable, and run it.""" + if isinstance(func, str): + func = get_global_func(func) + return func(*args, **kwargs) # type: ignore class OriginalModule: @@ -183,14 +192,24 @@ def print_result( inputs: List[np.ndarray] = None, original_res: List[np.ndarray] = None, scheduled_res: List[np.ndarray] = None, + original_run_secs: List[float] = None, + scheduled_run_secs: List[float] = None, exception: Exception = None, trace: Trace = None, ) -> None: """Print the validation result.""" - output = [ + status = ( f"Progress {counter: 6d} / {total: 6d} checked, " f"used {float(time): 3.3f} sec. Result: {result}" - ] + ) + + if result in ["pass", "wrong answer"]: + status += ( + f"original: {mean(original_run_secs): 3.3f} sec, " + f"scheduled: {mean(scheduled_run_secs): 3.3f} sec" + ) + + output = [status] if result not in ["pass", "skip"]: output.extend( [ @@ -216,21 +235,84 @@ def print_result( ] ) elif result == "exception": - output.extend( - [ - "Exception:" + DELIMITOR + str(exception), - ] - ) + output.extend(["Exception:" + DELIMITOR + str(exception) + "\n"]) else: raise ValueError(f"Unknown result: {result}") print("\n\n".join(output)) -def check_and_run(func: Union[str, Callable], *args, **kwargs) -> Any: - """Check if the function is a string or a callable, and run it.""" - if isinstance(func, str): - func = get_global_func(func) - return func(*args, **kwargs) # type: ignore +def make_alloc_arg_and_check( + args: List[np.ndarray], results: List[List[np.ndarray]] +) -> Tuple[Callable, Callable]: + """Make alloc_arg and check functions for the given inputs and collect results.""" + + def f_with_args_alloc_argument( + # pylint: disable=unused-argument + session: tvm.rpc.RPCSession, + device: tvm.runtime.Device, + args_info: ms.runner.rpc_runner.T_ARG_INFO_JSON_OBJ_LIST, + alloc_repeat: int, + # pylint: enable=unused-argument + ) -> List[ms.runner.rpc_runner.T_ARGUMENT_LIST]: + return [[tvm.nd.array(arg, device=device) for arg in args] for _ in range(alloc_repeat)] + + def run_evaluator_with_args( + rt_mod: tvm.runtime.Module, + device: tvm.runtime.Device, + evaluator_config: ms.runner.EvaluatorConfig, + repeated_args: List[ms.runner.rpc_runner.T_ARGUMENT_LIST], + ) -> List[float]: + """With args function to run the evaluator + + Parameters + ---------- + rt_mod: Module + The runtime module + device: Device + The device to run the evaluator + evaluator_config: EvaluatorConfig + The evaluator config + repeated_args: List[T_ARGUMENT_LIST] + The repeated arguments + + Returns + ------- + costs: List[float] + The evaluator results + """ + evaluator = rt_mod.time_evaluator( + func_name=rt_mod.entry_name, + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + f_preproc="cache_flush_cpu_non_first_arg" + if evaluator_config.enable_cpu_cache_flush + else "", + ) + + results.append([[arg.numpy() for arg in args] for args in repeated_args]) # type: ignore + repeated_costs: List[List[float]] = [] + for args in repeated_args: + device.sync() + profile_result = evaluator(*args) + repeated_costs.append(profile_result.results) + costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + return costs + + def f_with_args_run_evaluator( + session: tvm.rpc.RPCSession, # pylint: disable=unused-argument + rt_mod: tvm.runtime.Module, + device: tvm.runtime.Device, + evaluator_config: ms.runner.EvaluatorConfig, + repeated_args: List[ms.runner.rpc_runner.T_ARGUMENT_LIST], + ) -> List[float]: + # run remote module + # pull remote args back using `arg.numpy() for arg in remote_args` + # check the results + return run_evaluator_with_args(rt_mod, device, evaluator_config, repeated_args) + + return f_with_args_alloc_argument, f_with_args_run_evaluator def build_and_run( @@ -239,19 +321,49 @@ def build_and_run( rpc_config: ms.runner.RPCConfig, dev_type: str, inputs: List[np.ndarray], -) -> List[np.ndarray]: + builder: ms.builder.Builder, +) -> Tuple[List[np.ndarray], List[float]]: """Build and run the module on the target device.""" - rt_mod = tvm.build(mod, target=target) - return to_numpy( - run_module_via_rpc( - rpc_config=rpc_config, - lib=rt_mod, - dev_type=dev_type, - args={i: v for i, v in enumerate(inputs)}, # pylint: disable=unnecessary-comprehension - continuation=create_calculator(backend="tir"), - backend="tir", - ) + builder_results = builder.build([ms.builder.BuilderInput(mod, target)]) + assert ( + len(builder_results) == 1 + ), f"Unexpected number of build results, expected 1 got {len(builder_results)}" + (builder_result,) = builder_results # pylint: disable=unbalanced-tuple-unpacking + assert builder_result.error_msg is None, "Builder failed: " + str( + builder_result.error_msg if builder_result.error_msg else "Empty error message" + ) + + results: List[List[np.ndarray]] = [] + + f_with_args_alloc_argument, f_with_args_run_evaluator = make_alloc_arg_and_check( + inputs, results + ) + runner = ms.runner.RPCRunner( + rpc_config=rpc_config, + evaluator_config=ms.runner.EvaluatorConfig( + number=ARGS.number, + repeat=ARGS.repeat, + min_repeat_ms=ARGS.min_repeat_ms, + enable_cpu_cache_flush=ARGS.cpu_flush, + ), + alloc_repeat=1, + f_alloc_argument=f_with_args_alloc_argument, + f_run_evaluator=f_with_args_run_evaluator, + ) + runner_futures = runner.run( + # arginfo is not used in this case so we can pass an empty list + [ms.runner.RunnerInput(builder_result.artifact_path, device_type=dev_type, args_info=[])] + ) + assert ( + len(runner_futures) == 1 + ), f"Unexpected number of runner futures, expected 1 got {len(runner_futures)}" + (runner_future,) = runner_futures # pylint: disable=unbalanced-tuple-unpacking + runner_res = runner_future.result() + assert runner_res.error_msg is None, "Runner failed: " + ( + runner_res.error_msg if runner_res.error_msg else "Empty error message" ) + assert len(results) == 1, f"Unexpected number of repeat results, expected 1 got {len(results)}" + return results[0], runner_res.run_secs def validate_correctness( @@ -262,15 +374,17 @@ def validate_correctness( target: Target, dev_type: str, rpc_config: ms.runner.RPCConfig, + builder: ms.builder.Builder, inputs: List[np.ndarray] = None, # for input reuse original_res: List[np.ndarray] = None, # for original mod results reuse + original_run_secs: List[float] = None, # for original mod run secs reuse f_input_generator: Union[ str, Callable[[IRModule], List[tvm.nd.NDArray]] ] = default_input_generator, f_check_metric: Union[ str, Callable[[tvm.nd.NDArray, tvm.nd.NDArray], bool] ] = default_check_metric, -) -> Tuple[bool, List[np.ndarray], List[np.ndarray], List[np.ndarray]]: +) -> Tuple[bool, List[np.ndarray], List[np.ndarray], List[np.ndarray], List[float], List[float]]: """Function to validate the correctness of a scheduled module. Parameters @@ -287,10 +401,14 @@ def validate_correctness( The device type to run the module via rpc. rpc_config : RPCConfig The RPCConfig to run the scheduled module. + builder : Builder + The builder to build the original and scheduled modules. inputs : List[np.ndarray] The input data to be reused, if None, generate new inputs. original_res : List[np.ndarray] The original module results to be reused, if None, run the original module. + original_run_secs : List[float] + The original module run secs to be reused, if None, run the original module. f_input_generator : Union[str, Callable] The function to generate the input data. f_check_metric : Union[str, Callable] @@ -306,6 +424,10 @@ def validate_correctness( The original module results in numpy array. scheduled_res: List[np.ndarray] The scheduled module results in numpy array. + original_run_secs: List[float] + The running time of the original module via rpc runner. + scheduled_run_secs: List[float] + The running time of the scheduled module via rpc runner. """ # fetch input function & prepare inputs @@ -314,15 +436,21 @@ def validate_correctness( # build & run original result if original_res is None: - original_res = build_and_run( + original_res, original_run_secs = build_and_run( original_mod, + builder=builder, target=baseline_target, rpc_config=rpc_config, dev_type="cpu", inputs=inputs, ) - scheduled_res = build_and_run( - scheduled_mod, target=target, rpc_config=rpc_config, dev_type=dev_type, inputs=inputs + scheduled_res, scheduled_run_secs = build_and_run( + scheduled_mod, + builder=builder, + target=target, + rpc_config=rpc_config, + dev_type=dev_type, + inputs=inputs, ) # fetch comparison function @@ -336,6 +464,8 @@ def validate_correctness( inputs, original_res, scheduled_res, + original_run_secs, + scheduled_run_secs, ) @@ -343,6 +473,7 @@ def main(): """Main function""" describe() target = ARGS.target + builder = ms.builder.LocalBuilder() database = ms.database.create(work_dir=ARGS.work_dir) # determine target kind @@ -386,6 +517,7 @@ def main(): ) inputs = None original_res = None + original_run_secs = None for record in records: counter += 1 scope_name = f"validate #{counter}" @@ -401,15 +533,24 @@ def main(): record.trace.apply_to_schedule(sch=sch, remove_postproc=False) scheduled_mod = sch.mod # validate correctness - passed, inputs, original_res, scheduled_res = validate_correctness( + ( + passed, + inputs, + original_res, + scheduled_res, + original_run_secs, + scheduled_run_secs, + ) = validate_correctness( original_mod=original_mod, scheduled_mod=scheduled_mod, target=target, baseline_target=ARGS.baseline_target, dev_type=dev_type, rpc_config=ARGS.rpc_config, + builder=builder, # type: ignore inputs=inputs, original_res=original_res, + original_run_secs=original_run_secs, ) # validation finished print_result( @@ -423,8 +564,11 @@ def main(): inputs=inputs, original_res=original_res, scheduled_res=scheduled_res, + original_run_secs=original_run_secs, + scheduled_run_secs=scheduled_run_secs, ) except Exception as e: # pylint: disable=broad-except, invalid-name + raise e # todo remove this line # validation failed with exception print_result( counter, From 5dcec8a914fe96cd8feecc2f95fab0d7d435e760 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 18 Nov 2022 11:09:47 -0800 Subject: [PATCH 06/13] Check point. --- .../testing/validate_database.py | 427 ++++++++---------- 1 file changed, 183 insertions(+), 244 deletions(-) diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index 3ee43a2d6124..bf970ac4b25c 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -22,19 +22,20 @@ from statistics import mean from distutils.util import strtobool from typing import Callable, Tuple, Union, List, Any - import numpy as np # type: ignore import tvm from tvm import meta_schedule as ms from tvm._ffi import get_global_func, register_func from tvm.ir import IRModule -from tvm.meta_schedule.testing.tune_utils import generate_input_data from tvm.support import describe from tvm.target import Target from tvm.tir import Schedule from tvm.tir.schedule import Trace -from tvm.tir.tensor_intrin import cuda, x86 # type: ignore # pylint: disable=unused-import +from tvm.meta_schedule.testing.tune_utils import generate_input_data + +# todo add tensor intrinsics +# from tvm.tir.tensor_intrin import cuda, x86 # type: ignore # pylint: disable=unused-import DELIMITOR = "\n" + "-" * 30 + "\n" @@ -102,6 +103,16 @@ def _parse_args(): help="example: True / False", required=True, ) + args.add_argument( + "--input-generator-func", + type=str, + default="tvm.meta_schedule.testing.default_input_generator", + ) + args.add_argument( + "--check-metric-func", + type=str, + default="tvm.meta_schedule.testing.default_check_metric", + ) parsed = args.parse_args() parsed.target = tvm.target.Target(parsed.target) parsed.rpc_config = ms.runner.RPCConfig( @@ -125,6 +136,25 @@ def _parse_args(): logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) +def get_device_type(target: Target) -> str: + """Get the device type string from a target.""" + if target.kind.name == "llvm": + return "cpu" + elif target.kind.name == "cuda": + return "cuda" + else: + raise RuntimeError(f"Unsupported target kind for device type: {target.kind.name}") + + +def get_runtime_device(target: Target) -> tvm.runtime.Device: + if target.kind.name == "llvm": + return tvm.cpu() + elif target.kind.name == "cuda": + return tvm.cuda() + else: + raise RuntimeError(f"Unsupported target kind for runtime device: {target.kind.name}") + + def check_and_run(func: Union[str, Callable], *args, **kwargs) -> Any: """Check if the function is a string or a callable, and run it.""" if isinstance(func, str): @@ -181,68 +211,73 @@ def is_failed_record(record: ms.database.TuningRecord) -> bool: return len(record.run_secs) == 1 and record.run_secs[0] == 1e9 -def print_result( - counter: int, - total: int, - result: str, - time: float, - *, - original_mod: IRModule = None, - scheduled_mod: IRModule = None, - inputs: List[np.ndarray] = None, - original_res: List[np.ndarray] = None, - scheduled_res: List[np.ndarray] = None, - original_run_secs: List[float] = None, - scheduled_run_secs: List[float] = None, - exception: Exception = None, - trace: Trace = None, -) -> None: - """Print the validation result.""" - status = ( - f"Progress {counter: 6d} / {total: 6d} checked, " - f"used {float(time): 3.3f} sec. Result: {result}" - ) - - if result in ["pass", "wrong answer"]: - status += ( - f"original: {mean(original_run_secs): 3.3f} sec, " - f"scheduled: {mean(scheduled_run_secs): 3.3f} sec" - ) +def print_with_counter_func(counter: int, total: int) -> Callable: + """Print with counter""" + + def print_result( + result: str, + *, + original_mod: IRModule = None, + scheduled_mod: IRModule = None, + inputs: List[np.ndarray] = None, + original_res: List[np.ndarray] = None, + scheduled_res: List[np.ndarray] = None, + original_run_secs: List[float] = None, + scheduled_run_secs: List[float] = None, + exception: Exception = None, + trace: Trace = None, + ) -> None: + """Print the validation result.""" + status = f"Progress {counter: 6d} / {total: 6d} checked, result: {result}" + + if result in ["pass", "wrong answer"]: + status += ( + f"original: {mean(original_run_secs): 3.3f} sec, " + f"scheduled: {mean(scheduled_run_secs): 3.3f} sec" + ) - output = [status] - if result not in ["pass", "skip"]: - output.extend( - [ - "Original IRModule:" + DELIMITOR + original_mod.script(), - "Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(), - "Trace" + DELIMITOR + str(trace), - ] - ) - if result == "wrong answer": + output = [status] + if result not in ["pass", "skip"]: output.extend( [ - "Input:" + DELIMITOR + str(inputs), - "Original Result:" + DELIMITOR + str(original_res), - "Scheduled Result:" + DELIMITOR + str(scheduled_res), - "Max Diff:" - + DELIMITOR - + str( - [ - np.max(np.abs(original_res[i] - scheduled_res[i])) - for i in range(len(original_res)) - ] - ), + "Original IRModule:" + DELIMITOR + original_mod.script(), + "Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(), + "Trace" + DELIMITOR + str(trace), ] ) - elif result == "exception": - output.extend(["Exception:" + DELIMITOR + str(exception) + "\n"]) - else: - raise ValueError(f"Unknown result: {result}") - print("\n\n".join(output)) + if result == "wrong answer": + output.extend( + [ + "Input:" + DELIMITOR + str(inputs), + "Original Result:" + DELIMITOR + str(original_res), + "Scheduled Result:" + DELIMITOR + str(scheduled_res), + "Max Diff:" + + DELIMITOR + + str( + [ + np.max(np.abs(original_res[i] - scheduled_res[i])) + for i in range(len(original_res)) + ] + ), + ] + ) + elif result == "exception": + output.extend(["Exception:" + DELIMITOR + str(exception) + "\n"]) + else: + raise ValueError(f"Unknown result: {result}") + print("\n\n".join(output)) + + return print_result def make_alloc_arg_and_check( - args: List[np.ndarray], results: List[List[np.ndarray]] + inputs: List[np.ndarray], + original_mod: IRModule, + scheduled_mod: IRModule, + trace: Trace, + original_res: List[np.ndarray], + original_run_secs: List[float], + print_result: Callable, ) -> Tuple[Callable, Callable]: """Make alloc_arg and check functions for the given inputs and collect results.""" @@ -254,7 +289,7 @@ def f_with_args_alloc_argument( alloc_repeat: int, # pylint: enable=unused-argument ) -> List[ms.runner.rpc_runner.T_ARGUMENT_LIST]: - return [[tvm.nd.array(arg, device=device) for arg in args] for _ in range(alloc_repeat)] + return [[tvm.nd.array(arg, device=device) for arg in inputs] for _ in range(alloc_repeat)] def run_evaluator_with_args( rt_mod: tvm.runtime.Module, @@ -291,13 +326,34 @@ def run_evaluator_with_args( else "", ) - results.append([[arg.numpy() for arg in args] for args in repeated_args]) # type: ignore repeated_costs: List[List[float]] = [] for args in repeated_args: device.sync() profile_result = evaluator(*args) repeated_costs.append(profile_result.results) costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + + assert len(repeated_args) == 1, "Only support one set of arguments" + scheduled_res = [arg.numpy() for arg in repeated_args[0]] # type: ignore + # fetch comparison function + passed = check_and_run( + ARGS.check_metric_func, + to_tvm_ndarray(original_res), + to_tvm_ndarray(scheduled_res), + ) + + print_result( + result="pass" if passed else "wrong answer", + original_mod=original_mod, + scheduled_mod=scheduled_mod, + trace=trace, + inputs=inputs, + original_res=original_res, + scheduled_res=scheduled_res, + original_run_secs=original_run_secs, + scheduled_run_secs=mean(costs), + ) + return costs def f_with_args_run_evaluator( @@ -315,15 +371,26 @@ def f_with_args_run_evaluator( return f_with_args_alloc_argument, f_with_args_run_evaluator -def build_and_run( +def local_build_and_run( mod: IRModule, target: Target, - rpc_config: ms.runner.RPCConfig, - dev_type: str, + device: tvm.runtime.Device, inputs: List[np.ndarray], - builder: ms.builder.Builder, ) -> Tuple[List[np.ndarray], List[float]]: - """Build and run the module on the target device.""" + """Build and run the module locally.""" + # potential memory leak https://github.com/apache/tvm/issues/11096 + lib = tvm.build(mod, target=target) + tvm_inputs = [tvm.nd.array(inp, device=device) for inp in inputs] + device.sync() + func = lib.time_evaluator(lib.entry_name, dev=device, number=ARGS.number, repeat=ARGS.repeat) + benchmark_res = func(*tvm_inputs) + device.sync() + return [arg.numpy() for arg in tvm_inputs], benchmark_res + + +def _build_single_mod( + mod: IRModule, builder: ms.builder.Builder, target: Target +) -> ms.builder.BuilderResult: builder_results = builder.build([ms.builder.BuilderInput(mod, target)]) assert ( len(builder_results) == 1 @@ -332,24 +399,14 @@ def build_and_run( assert builder_result.error_msg is None, "Builder failed: " + str( builder_result.error_msg if builder_result.error_msg else "Empty error message" ) + return builder_results[0] - results: List[List[np.ndarray]] = [] - f_with_args_alloc_argument, f_with_args_run_evaluator = make_alloc_arg_and_check( - inputs, results - ) - runner = ms.runner.RPCRunner( - rpc_config=rpc_config, - evaluator_config=ms.runner.EvaluatorConfig( - number=ARGS.number, - repeat=ARGS.repeat, - min_repeat_ms=ARGS.min_repeat_ms, - enable_cpu_cache_flush=ARGS.cpu_flush, - ), - alloc_repeat=1, - f_alloc_argument=f_with_args_alloc_argument, - f_run_evaluator=f_with_args_run_evaluator, - ) +def _run_single_mod( + builder_result: ms.builder.BuilderResult, + runner: ms.runner.Runner, + dev_type: str, +) -> None: runner_futures = runner.run( # arginfo is not used in this case so we can pass an empty list [ms.runner.RunnerInput(builder_result.artifact_path, device_type=dev_type, args_info=[])] @@ -362,130 +419,18 @@ def build_and_run( assert runner_res.error_msg is None, "Runner failed: " + ( runner_res.error_msg if runner_res.error_msg else "Empty error message" ) - assert len(results) == 1, f"Unexpected number of repeat results, expected 1 got {len(results)}" - return results[0], runner_res.run_secs - - -def validate_correctness( - original_mod: IRModule, # compiled for "baseline_target" - scheduled_mod: IRModule, # compiled for "target" - *, - baseline_target: Target, - target: Target, - dev_type: str, - rpc_config: ms.runner.RPCConfig, - builder: ms.builder.Builder, - inputs: List[np.ndarray] = None, # for input reuse - original_res: List[np.ndarray] = None, # for original mod results reuse - original_run_secs: List[float] = None, # for original mod run secs reuse - f_input_generator: Union[ - str, Callable[[IRModule], List[tvm.nd.NDArray]] - ] = default_input_generator, - f_check_metric: Union[ - str, Callable[[tvm.nd.NDArray, tvm.nd.NDArray], bool] - ] = default_check_metric, -) -> Tuple[bool, List[np.ndarray], List[np.ndarray], List[np.ndarray], List[float], List[float]]: - """Function to validate the correctness of a scheduled module. - - Parameters - ---------- - original_mod : IRModule - The original module to be compiled. - scheduled_mod : IRModule - The scheduled module to be compiled. - baseline_target : Target - The baseline target to compile the original module. - target : Target - The target to compile the scheduled module. - dev_type : str - The device type to run the module via rpc. - rpc_config : RPCConfig - The RPCConfig to run the scheduled module. - builder : Builder - The builder to build the original and scheduled modules. - inputs : List[np.ndarray] - The input data to be reused, if None, generate new inputs. - original_res : List[np.ndarray] - The original module results to be reused, if None, run the original module. - original_run_secs : List[float] - The original module run secs to be reused, if None, run the original module. - f_input_generator : Union[str, Callable] - The function to generate the input data. - f_check_metric : Union[str, Callable] - The function to check the metric. - - Returns - ------- - passed: bool - Whether the validation passed. - inputs: List[np.ndarray] - The input data used for validation in numpy array. - original_res: List[np.ndarray] - The original module results in numpy array. - scheduled_res: List[np.ndarray] - The scheduled module results in numpy array. - original_run_secs: List[float] - The running time of the original module via rpc runner. - scheduled_run_secs: List[float] - The running time of the scheduled module via rpc runner. - """ - - # fetch input function & prepare inputs - if inputs is None: - inputs = to_numpy(check_and_run(f_input_generator, original_mod)) - - # build & run original result - if original_res is None: - original_res, original_run_secs = build_and_run( - original_mod, - builder=builder, - target=baseline_target, - rpc_config=rpc_config, - dev_type="cpu", - inputs=inputs, - ) - scheduled_res, scheduled_run_secs = build_and_run( - scheduled_mod, - builder=builder, - target=target, - rpc_config=rpc_config, - dev_type=dev_type, - inputs=inputs, - ) - - # fetch comparison function - validation_res = check_and_run( - f_check_metric, to_tvm_ndarray(original_res), to_tvm_ndarray(scheduled_res) - ) - - # check metric - return ( - validation_res, - inputs, - original_res, - scheduled_res, - original_run_secs, - scheduled_run_secs, - ) def main(): """Main function""" describe() - target = ARGS.target - builder = ms.builder.LocalBuilder() - database = ms.database.create(work_dir=ARGS.work_dir) - - # determine target kind - if target.kind.name == "llvm": - dev_type = "cpu" - elif target.kind.name == "cuda": - dev_type = "cuda" - else: - raise RuntimeError(f"Unsupported target kind: {target.kind.name}") - - # start profiling with ms.Profiler() as profiler: + # initialize + target = ARGS.target + dev_type = get_device_type(target) + builder = ms.builder.LocalBuilder() + database = ms.database.create(work_dir=ARGS.work_dir) + # collect records with profiler.timeit("collect records"): records = database.get_all_tuning_records() @@ -515,66 +460,60 @@ def main(): records = database.get_top_k( workload=database.commit_workload(original_mod), top_k=ARGS.top_k ) - inputs = None - original_res = None - original_run_secs = None + inputs = to_numpy(check_and_run(ARGS.input_generator_func, original_mod)) + original_res, original_run_secs = local_build_and_run( + original_mod, + target=ARGS.baseline_target, + inputs=inputs, + device=get_runtime_device(ARGS.baseline_target), + ) for record in records: counter += 1 - scope_name = f"validate #{counter}" + print_result = print_with_counter_func(counter=counter, total=len(records)) if is_failed_record(record): # skip failed records where run_secs is 1e9 # these records are only negative samples for cost model - print_result(counter + 1, total=len(records), result="skip", time=0.0) + print_result(result="skip") continue try: - with profiler.timeit(scope_name): - # prepare scheduled module - sch = Schedule(original_mod) - record.trace.apply_to_schedule(sch=sch, remove_postproc=False) - scheduled_mod = sch.mod - # validate correctness - ( - passed, - inputs, - original_res, - scheduled_res, - original_run_secs, - scheduled_run_secs, - ) = validate_correctness( - original_mod=original_mod, - scheduled_mod=scheduled_mod, - target=target, - baseline_target=ARGS.baseline_target, - dev_type=dev_type, - rpc_config=ARGS.rpc_config, - builder=builder, # type: ignore - inputs=inputs, - original_res=original_res, - original_run_secs=original_run_secs, - ) - # validation finished - print_result( - counter, - total=total, - result="pass" if passed else "wrong answer", - time=profiler.get()[scope_name], - original_mod=original_mod, - scheduled_mod=scheduled_mod, - trace=record.trace, - inputs=inputs, + # prepare scheduled module + sch = Schedule(original_mod) + record.trace.apply_to_schedule(sch=sch, remove_postproc=False) + scheduled_mod = sch.mod + # build the scheduled module locally + builder_result = _build_single_mod( + scheduled_mod, builder, target # type: ignore + ) + ( + f_with_args_alloc_argument, + f_with_args_run_evaluator, + ) = make_alloc_arg_and_check( + inputs, + original_mod, + scheduled_mod, + record.trace, original_res=original_res, - scheduled_res=scheduled_res, original_run_secs=original_run_secs, - scheduled_run_secs=scheduled_run_secs, + print_result=print_result, + ) + runner = ms.runner.RPCRunner( + rpc_config=ARGS.rpc_config, + evaluator_config=ms.runner.EvaluatorConfig( + number=ARGS.number, + repeat=ARGS.repeat, + min_repeat_ms=ARGS.min_repeat_ms, + enable_cpu_cache_flush=ARGS.cpu_flush, + ), + alloc_repeat=1, + f_alloc_argument=f_with_args_alloc_argument, + f_run_evaluator=f_with_args_run_evaluator, ) + _run_single_mod(builder_result, runner, dev_type) # type: ignore except Exception as e: # pylint: disable=broad-except, invalid-name - raise e # todo remove this line + raise e # validation failed with exception print_result( - counter, - total=total, result="exception", - time=profiler.get()[scope_name], original_mod=original_mod, scheduled_mod=scheduled_mod, trace=record.trace, From 0ad19f59369427a0a979295c0b2ab41a8749bdfc Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 21 Nov 2022 15:00:07 -0800 Subject: [PATCH 07/13] Check point. --- .../testing/validate_database.py | 64 +++++++++++-------- 1 file changed, 36 insertions(+), 28 deletions(-) diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index bf970ac4b25c..a691e635b97a 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -175,23 +175,28 @@ def __hash__(self) -> int: return tvm.ir.structural_hash(self.mod) -@register_func("tvm.meta_schedule.testing.default_input_generator") -def default_input_generator(mod: IRModule) -> List[tvm.nd.NDArray]: - args_info = ms.arg_info.TensorInfo.from_prim_func(mod["main"]) - inputs = [ - tvm.nd.array(generate_input_data(input_shape=arg_info.shape, input_dtype=arg_info.dtype)) - for arg_info in args_info - ] - return inputs - +def initializer(): + """Initializer function to register the functions""" + + @register_func("tvm.meta_schedule.testing.default_input_generator") + def default_input_generator(mod: IRModule) -> List[tvm.nd.NDArray]: + args_info = ms.arg_info.TensorInfo.from_prim_func(mod["main"]) + inputs = [ + tvm.nd.array( + generate_input_data(input_shape=arg_info.shape, input_dtype=arg_info.dtype) + ) + for arg_info in args_info + ] + return inputs -@register_func("tvm.meta_schedule.testing.default_check_metric") -def default_check_metric(a: List[tvm.nd.NDArray], b: List[tvm.nd.NDArray]) -> bool: - assert len(a) == len(b), "Different number of outputs from two modules" - for i, _ in enumerate(a): - if not np.allclose(a[i].numpy(), b[i].numpy(), rtol=1e-3, atol=2e-3): - return False - return True + @register_func("tvm.meta_schedule.testing.default_check_metric") + def default_check_metric(a: List[tvm.nd.NDArray], b: List[tvm.nd.NDArray]) -> bool: + raise Exception("Not implemented") + assert len(a) == len(b), "Different number of outputs from two modules" + for i, _ in enumerate(a): + if not np.allclose(a[i].numpy(), b[i].numpy(), rtol=1e-3, atol=2e-3): + return False + return True def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]: @@ -225,15 +230,15 @@ def print_result( original_run_secs: List[float] = None, scheduled_run_secs: List[float] = None, exception: Exception = None, - trace: Trace = None, + trace: str = None, ) -> None: """Print the validation result.""" - status = f"Progress {counter: 6d} / {total: 6d} checked, result: {result}" + status = f"Progress {counter: 6d} / {total: 6d} checked, result: {result}, " if result in ["pass", "wrong answer"]: status += ( - f"original: {mean(original_run_secs): 3.3f} sec, " - f"scheduled: {mean(scheduled_run_secs): 3.3f} sec" + f"original: {mean(original_run_secs) * 1e3: 10.3f} ms, " + f"scheduled: {mean(scheduled_run_secs) * 1e3: 10.3f} ms" ) output = [status] @@ -258,7 +263,8 @@ def print_result( np.max(np.abs(original_res[i] - scheduled_res[i])) for i in range(len(original_res)) ] - ), + ) + + "\n", ] ) elif result == "exception": @@ -274,7 +280,7 @@ def make_alloc_arg_and_check( inputs: List[np.ndarray], original_mod: IRModule, scheduled_mod: IRModule, - trace: Trace, + trace: str, original_res: List[np.ndarray], original_run_secs: List[float], print_result: Callable, @@ -351,7 +357,7 @@ def run_evaluator_with_args( original_res=original_res, scheduled_res=scheduled_res, original_run_secs=original_run_secs, - scheduled_run_secs=mean(costs), + scheduled_run_secs=costs, ) return costs @@ -385,7 +391,7 @@ def local_build_and_run( func = lib.time_evaluator(lib.entry_name, dev=device, number=ARGS.number, repeat=ARGS.repeat) benchmark_res = func(*tvm_inputs) device.sync() - return [arg.numpy() for arg in tvm_inputs], benchmark_res + return [arg.numpy() for arg in tvm_inputs], list(benchmark_res.results) def _build_single_mod( @@ -424,6 +430,7 @@ def _run_single_mod( def main(): """Main function""" describe() + initializer() with ms.Profiler() as profiler: # initialize target = ARGS.target @@ -452,6 +459,7 @@ def main(): if ARGS.top_k < 10**9: print(f"Top {ARGS.top_k} records for each original TIR will be validated.") total = len(workloads) * ARGS.top_k + print() # validate correctness counter = 0 @@ -469,7 +477,7 @@ def main(): ) for record in records: counter += 1 - print_result = print_with_counter_func(counter=counter, total=len(records)) + print_result = print_with_counter_func(counter=counter, total=total) if is_failed_record(record): # skip failed records where run_secs is 1e9 # these records are only negative samples for cost model @@ -491,7 +499,7 @@ def main(): inputs, original_mod, scheduled_mod, - record.trace, + str(record.trace), original_res=original_res, original_run_secs=original_run_secs, print_result=print_result, @@ -507,16 +515,16 @@ def main(): alloc_repeat=1, f_alloc_argument=f_with_args_alloc_argument, f_run_evaluator=f_with_args_run_evaluator, + initializer=initializer, ) _run_single_mod(builder_result, runner, dev_type) # type: ignore except Exception as e: # pylint: disable=broad-except, invalid-name - raise e # validation failed with exception print_result( result="exception", original_mod=original_mod, scheduled_mod=scheduled_mod, - trace=record.trace, + trace=str(record.trace), exception=e, ) print("Validation finished!") From e3814d75f5ded9a6df5c2357b6aae2418c9ec30a Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 21 Nov 2022 15:53:47 -0800 Subject: [PATCH 08/13] Resolve issues. --- .../testing/validate_database.py | 316 +++++++++++++++--- 1 file changed, 266 insertions(+), 50 deletions(-) diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index a691e635b97a..5d05e9edd748 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -32,10 +32,9 @@ from tvm.target import Target from tvm.tir import Schedule from tvm.tir.schedule import Trace +from tvm.meta_schedule.utils import remove_build_dir from tvm.meta_schedule.testing.tune_utils import generate_input_data - -# todo add tensor intrinsics -# from tvm.tir.tensor_intrin import cuda, x86 # type: ignore # pylint: disable=unused-import +from tvm.tir.tensor_intrin import * # type: ignore # pylint: disable=unused-import DELIMITOR = "\n" + "-" * 30 + "\n" @@ -134,10 +133,22 @@ def _parse_args(): format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) +logging.getLogger("tvm.meta_schedule.runner.rpc_runner").setLevel(logging.WARN) def get_device_type(target: Target) -> str: - """Get the device type string from a target.""" + """Get the device type string from a target. + + Parameters + ---------- + target : Target + The target to get the device type from. + + Returns + ------- + device_type : str + The device type string. + """ if target.kind.name == "llvm": return "cpu" elif target.kind.name == "cuda": @@ -147,6 +158,18 @@ def get_device_type(target: Target) -> str: def get_runtime_device(target: Target) -> tvm.runtime.Device: + """Get the runtime device from a target. + + Parameters + ---------- + target : Target + The target to get the runtime device from. + + Returns + ------- + device : tvm.runtime.Device + The runtime device. + """ if target.kind.name == "llvm": return tvm.cpu() elif target.kind.name == "cuda": @@ -163,7 +186,7 @@ def check_and_run(func: Union[str, Callable], *args, **kwargs) -> Any: class OriginalModule: - """Original module class""" + """Original module class for deduplication.""" def __init__(self, mod: IRModule): self.mod = mod @@ -175,11 +198,26 @@ def __hash__(self) -> int: return tvm.ir.structural_hash(self.mod) -def initializer(): - """Initializer function to register the functions""" +def initializer() -> None: + """Initializer function to register the functions on PopenWorker and locally.""" @register_func("tvm.meta_schedule.testing.default_input_generator") - def default_input_generator(mod: IRModule) -> List[tvm.nd.NDArray]: + def default_input_generator( # pylint: disable=unused-variable + mod: IRModule, + ) -> List[tvm.nd.NDArray]: + """Default input generator function + + Parameters + ---------- + mod : IRModule + The IRModule to generate the input data for. + + Returns + ------- + inputs : List[tvm.nd.NDArray] + The generated input data. + """ + args_info = ms.arg_info.TensorInfo.from_prim_func(mod["main"]) inputs = [ tvm.nd.array( @@ -190,8 +228,24 @@ def default_input_generator(mod: IRModule) -> List[tvm.nd.NDArray]: return inputs @register_func("tvm.meta_schedule.testing.default_check_metric") - def default_check_metric(a: List[tvm.nd.NDArray], b: List[tvm.nd.NDArray]) -> bool: - raise Exception("Not implemented") + def default_check_metric( # pylint: disable=unused-variable,unreachable-code + a: List[tvm.nd.NDArray], b: List[tvm.nd.NDArray] + ) -> bool: + """Check if the outputs are equal + + Parameters + ---------- + a : List[tvm.nd.NDArray] + The first list of NDArrays to compare. + + b : List[tvm.nd.NDArray] + The second list of NDArrays to compare. + + Returns + ------- + is_equal : bool + Whether the two lists of NDArrays are equal. + """ assert len(a) == len(b), "Different number of outputs from two modules" for i, _ in enumerate(a): if not np.allclose(a[i].numpy(), b[i].numpy(), rtol=1e-3, atol=2e-3): @@ -200,24 +254,69 @@ def default_check_metric(a: List[tvm.nd.NDArray], b: List[tvm.nd.NDArray]) -> bo def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]: - """Convert a list of TVM NDArray to a list of numpy array""" + """Convert a list of TVM NDArray to a list of numpy array + + Parameters + ---------- + a : List[tvm.nd.NDArray] + The list of TVM NDArray to be converted + + Returns + ------- + b : List[np.ndarray] + The list of numpy array + """ assert a is not None, "Empty result cannot be converted to numpy" return [x.numpy() for x in a] def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]: - """Convert a list of numpy array to a list of TVM NDArray""" + """Convert a list of numpy array to a list of TVM NDArray + + Parameters + ---------- + a : List[np.ndarray] + The list of numpy array to be converted. + + Returns + ------- + b : List[tvm.nd.NDArray] + The list of TVM NDArray. + """ assert a is not None, "Empty result cannot be converted to TVM NDArray" return [tvm.nd.array(x) for x in a] def is_failed_record(record: ms.database.TuningRecord) -> bool: - """Check if a tuning record is failed.""" + """Check if a tuning record is failed. + + Parameters + ---------- + record : TuningRecord + The tuning record to check. + + Returns + ------- + is_failed : bool + """ return len(record.run_secs) == 1 and record.run_secs[0] == 1e9 def print_with_counter_func(counter: int, total: int) -> Callable: - """Print with counter""" + """Print with counter + + Parameters + ---------- + counter : int + The counter to print with. + total : int + The total number of items to print with. + + Returns + ------- + print_result : Callable + The print result function. + """ def print_result( result: str, @@ -233,7 +332,7 @@ def print_result( trace: str = None, ) -> None: """Print the validation result.""" - status = f"Progress {counter: 6d} / {total: 6d} checked, result: {result}, " + status = f"Progress {counter: 6d} / {total: 6d} (estimated) checked, result: {result:>10}, " if result in ["pass", "wrong answer"]: status += ( @@ -285,7 +384,33 @@ def make_alloc_arg_and_check( original_run_secs: List[float], print_result: Callable, ) -> Tuple[Callable, Callable]: - """Make alloc_arg and check functions for the given inputs and collect results.""" + """Make alloc_arg and check functions for the given inputs and collect results. + + Parameters + ---------- + inputs : List[np.ndarray] + The inputs to the two modules. + original_mod : IRModule + The original IRModule. + scheduled_mod : IRModule + The scheduled IRModule. + trace : str + The trace of the scheduled IRModule. + original_res : List[np.ndarray] + The original results. + original_run_secs : List[float] + The original run times. + print_result : Callable + The print result function. + + Returns + ------- + f_with_args_alloc_argument : Callable + The function to allocate arguments. + + f_with_args_run_evaluator : Callable + The function to run evaluator. + """ def f_with_args_alloc_argument( # pylint: disable=unused-argument @@ -295,9 +420,28 @@ def f_with_args_alloc_argument( alloc_repeat: int, # pylint: enable=unused-argument ) -> List[ms.runner.rpc_runner.T_ARGUMENT_LIST]: + """Allocate arguments using the given inputs. + + Parameters + ---------- + session : RPCSession + The RPC session. + device : Device + The device. + args_info : T_ARG_INFO_JSON_OBJ_LIST + argument information. + alloc_repeat : int + The number of times to repeat the allocation. + + Returns + ------- + args_list : List[T_ARGUMENT_LIST] + The list of argument lists. + """ return [[tvm.nd.array(arg, device=device) for arg in inputs] for _ in range(alloc_repeat)] - def run_evaluator_with_args( + def f_with_args_run_evaluator( + session: tvm.rpc.RPCSession, # pylint: disable=unused-argument rt_mod: tvm.runtime.Module, device: tvm.runtime.Device, evaluator_config: ms.runner.EvaluatorConfig, @@ -307,6 +451,8 @@ def run_evaluator_with_args( Parameters ---------- + session : tvm.rpc.RPCSession + The RPC session rt_mod: Module The runtime module device: Device @@ -362,18 +508,6 @@ def run_evaluator_with_args( return costs - def f_with_args_run_evaluator( - session: tvm.rpc.RPCSession, # pylint: disable=unused-argument - rt_mod: tvm.runtime.Module, - device: tvm.runtime.Device, - evaluator_config: ms.runner.EvaluatorConfig, - repeated_args: List[ms.runner.rpc_runner.T_ARGUMENT_LIST], - ) -> List[float]: - # run remote module - # pull remote args back using `arg.numpy() for arg in remote_args` - # check the results - return run_evaluator_with_args(rt_mod, device, evaluator_config, repeated_args) - return f_with_args_alloc_argument, f_with_args_run_evaluator @@ -383,7 +517,26 @@ def local_build_and_run( device: tvm.runtime.Device, inputs: List[np.ndarray], ) -> Tuple[List[np.ndarray], List[float]]: - """Build and run the module locally.""" + """Build and run the module locally. + + Parameters + ---------- + mod: IRModule + The module to build and run + target: Target + The target to build the module + device: Device + The device to run the module + inputs: List[np.ndarray] + The inputs to run the module + + Returns + ------- + res: List[np.ndarray] + The results of running the module + run_secs: List[float] + The running time of running the module + """ # potential memory leak https://github.com/apache/tvm/issues/11096 lib = tvm.build(mod, target=target) tvm_inputs = [tvm.nd.array(inp, device=device) for inp in inputs] @@ -394,18 +547,63 @@ def local_build_and_run( return [arg.numpy() for arg in tvm_inputs], list(benchmark_res.results) -def _build_single_mod( - mod: IRModule, builder: ms.builder.Builder, target: Target -) -> ms.builder.BuilderResult: - builder_results = builder.build([ms.builder.BuilderInput(mod, target)]) - assert ( - len(builder_results) == 1 - ), f"Unexpected number of build results, expected 1 got {len(builder_results)}" - (builder_result,) = builder_results # pylint: disable=unbalanced-tuple-unpacking +def _check_builder_result(builder_result: ms.builder.BuilderResult) -> None: + """Check if the builder result is defined. + + Parameters + ---------- + builder_result: BuilderResult + The builder result + """ assert builder_result.error_msg is None, "Builder failed: " + str( builder_result.error_msg if builder_result.error_msg else "Empty error message" ) - return builder_results[0] + + +def _apply_trace(mod: IRModule, trace: Trace) -> IRModule: + """Apply the trace to the module. + + Parameters + ---------- + mod: IRModule + The module to apply the trace to + trace: Trace + The trace to apply + + Returns + ------- + mod: IRModule + The module with the trace applied + """ + sch = Schedule(mod) + trace.apply_to_schedule(sch, remove_postproc=False) + return sch.mod + + +def _build_all_mods( + mods: List[IRModule], builder: ms.builder.Builder, target: Target +) -> List[ms.builder.BuilderResult]: + """Build all the modules. + + Parameters + ---------- + mods: List[IRModule] + The modules to build + builder: Builder + The builder to build the modules + target: Target + The target to build the modules + + Returns + ------- + builder_results: List[BuilderResult] + The builder results + """ + builder_results = builder.build([ms.builder.BuilderInput(mod, target) for mod in mods]) + assert len(builder_results) == len( + mods + ), f"Unexpected number of build results, expected {len(mods)} got {len(builder_results)}" + return builder_results def _run_single_mod( @@ -413,6 +611,17 @@ def _run_single_mod( runner: ms.runner.Runner, dev_type: str, ) -> None: + """Run a single module. + + Parameters + ---------- + builder_result: BuilderResult + The builder result + runner: Runner + The runner to run the module + dev_type: str + The device type + """ runner_futures = runner.run( # arginfo is not used in this case so we can pass an empty list [ms.runner.RunnerInput(builder_result.artifact_path, device_type=dev_type, args_info=[])] @@ -430,7 +639,7 @@ def _run_single_mod( def main(): """Main function""" describe() - initializer() + initializer() # for local input generation with ms.Profiler() as profiler: # initialize target = ARGS.target @@ -468,6 +677,8 @@ def main(): records = database.get_top_k( workload=database.commit_workload(original_mod), top_k=ARGS.top_k ) + if len(records) < ARGS.top_k: + total -= ARGS.top_k - len(records) inputs = to_numpy(check_and_run(ARGS.input_generator_func, original_mod)) original_res, original_run_secs = local_build_and_run( original_mod, @@ -475,7 +686,11 @@ def main(): inputs=inputs, device=get_runtime_device(ARGS.baseline_target), ) - for record in records: + scheduled_mods = [_apply_trace(original_mod, record.trace) for record in records] + builder_results = _build_all_mods( + [scheduled_mod for scheduled_mod in scheduled_mods], builder, target # type: ignore + ) + for i, record in enumerate(records): counter += 1 print_result = print_with_counter_func(counter=counter, total=total) if is_failed_record(record): @@ -485,13 +700,11 @@ def main(): continue try: # prepare scheduled module - sch = Schedule(original_mod) - record.trace.apply_to_schedule(sch=sch, remove_postproc=False) - scheduled_mod = sch.mod - # build the scheduled module locally - builder_result = _build_single_mod( - scheduled_mod, builder, target # type: ignore - ) + scheduled_mod = scheduled_mods[i] + # check build result + builder_result = builder_results[i] + _check_builder_result(builder_result) + # fetch functions ( f_with_args_alloc_argument, f_with_args_run_evaluator, @@ -504,6 +717,7 @@ def main(): original_run_secs=original_run_secs, print_result=print_result, ) + # create rpc runner runner = ms.runner.RPCRunner( rpc_config=ARGS.rpc_config, evaluator_config=ms.runner.EvaluatorConfig( @@ -517,6 +731,7 @@ def main(): f_run_evaluator=f_with_args_run_evaluator, initializer=initializer, ) + # run and validate _run_single_mod(builder_result, runner, dev_type) # type: ignore except Exception as e: # pylint: disable=broad-except, invalid-name # validation failed with exception @@ -527,8 +742,9 @@ def main(): trace=str(record.trace), exception=e, ) - print("Validation finished!") - print(f"Total time spent: {float(profiler.get()['Total']): 3.3f} sec.") + # clean up + remove_build_dir(builder_result.artifact_path) + print(f"Validation finished! Total time spent: {float(profiler.get()['Total']): 3.3f} sec.") if __name__ == "__main__": From 674a9e1f77a63d64a1692191e7bc58ef071c0b22 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 21 Nov 2022 15:57:28 -0800 Subject: [PATCH 09/13] Avoid nullptr. --- src/meta_schedule/measure_callback/remove_build_artifact.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/meta_schedule/measure_callback/remove_build_artifact.cc b/src/meta_schedule/measure_callback/remove_build_artifact.cc index 0abbebf3b484..41e52adbae99 100644 --- a/src/meta_schedule/measure_callback/remove_build_artifact.cc +++ b/src/meta_schedule/measure_callback/remove_build_artifact.cc @@ -28,7 +28,7 @@ class RemoveBuildArtifactNode : public MeasureCallbackNode { const Array& builder_results, const Array& runner_results) final { static const PackedFunc* f_rm = runtime::Registry::Get("meta_schedule.remove_build_dir"); - ICHECK(*f_rm != nullptr) << "The `remove_build_dir` func is not in tvm registry."; + ICHECK(f_rm != nullptr) << "The `remove_build_dir` func is not in tvm registry."; auto _ = Profiler::TimedScope("MeasureCallback/RemoveBuildArtifact"); for (const BuilderResult& build_result : builder_results) { if (Optional path = build_result->artifact_path) { From eb3702606245f486c4ee0f08a2809053671b33b2 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 21 Nov 2022 16:07:50 -0800 Subject: [PATCH 10/13] Linting. --- python/tvm/meta_schedule/testing/validate_database.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index 5d05e9edd748..71cb71406f06 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -34,7 +34,7 @@ from tvm.tir.schedule import Trace from tvm.meta_schedule.utils import remove_build_dir from tvm.meta_schedule.testing.tune_utils import generate_input_data -from tvm.tir.tensor_intrin import * # type: ignore # pylint: disable=unused-import +from tvm.tir.tensor_intrin import * # type: ignore # pylint: disable=unused-import,wildcard-import DELIMITOR = "\n" + "-" * 30 + "\n" @@ -687,9 +687,7 @@ def main(): device=get_runtime_device(ARGS.baseline_target), ) scheduled_mods = [_apply_trace(original_mod, record.trace) for record in records] - builder_results = _build_all_mods( - [scheduled_mod for scheduled_mod in scheduled_mods], builder, target # type: ignore - ) + builder_results = _build_all_mods(scheduled_mods, builder, target) # type: ignore for i, record in enumerate(records): counter += 1 print_result = print_with_counter_func(counter=counter, total=total) From 3bcafb0e62aaa3f3f9ee1a1af88dbc483fce48c1 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 21 Nov 2022 16:16:13 -0800 Subject: [PATCH 11/13] Linting. --- python/tvm/meta_schedule/testing/validate_database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index 71cb71406f06..ccb6b0d99dca 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -34,7 +34,7 @@ from tvm.tir.schedule import Trace from tvm.meta_schedule.utils import remove_build_dir from tvm.meta_schedule.testing.tune_utils import generate_input_data -from tvm.tir.tensor_intrin import * # type: ignore # pylint: disable=unused-import,wildcard-import +from tvm.tir.tensor_intrin import * # type: ignore # pylint: disable=wildcard-import,unused-wildcard-import DELIMITOR = "\n" + "-" * 30 + "\n" From 6e59a36ee9123b1a5fe78818bd6ef2010c7d46c2 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 21 Nov 2022 16:33:03 -0800 Subject: [PATCH 12/13] Move function out of initializer. --- .../testing/validate_database.py | 66 +++++++++---------- 1 file changed, 32 insertions(+), 34 deletions(-) diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index ccb6b0d99dca..c763f4fb514d 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -199,46 +199,20 @@ def __hash__(self) -> int: def initializer() -> None: - """Initializer function to register the functions on PopenWorker and locally.""" - - @register_func("tvm.meta_schedule.testing.default_input_generator") - def default_input_generator( # pylint: disable=unused-variable - mod: IRModule, - ) -> List[tvm.nd.NDArray]: - """Default input generator function - - Parameters - ---------- - mod : IRModule - The IRModule to generate the input data for. - - Returns - ------- - inputs : List[tvm.nd.NDArray] - The generated input data. - """ - - args_info = ms.arg_info.TensorInfo.from_prim_func(mod["main"]) - inputs = [ - tvm.nd.array( - generate_input_data(input_shape=arg_info.shape, input_dtype=arg_info.dtype) - ) - for arg_info in args_info - ] - return inputs + """Initializer function to register the functions on PopenWorker.""" @register_func("tvm.meta_schedule.testing.default_check_metric") def default_check_metric( # pylint: disable=unused-variable,unreachable-code - a: List[tvm.nd.NDArray], b: List[tvm.nd.NDArray] + lhs: List[tvm.nd.NDArray], rhs: List[tvm.nd.NDArray] ) -> bool: """Check if the outputs are equal Parameters ---------- - a : List[tvm.nd.NDArray] + lhs : List[tvm.nd.NDArray] The first list of NDArrays to compare. - b : List[tvm.nd.NDArray] + rhs : List[tvm.nd.NDArray] The second list of NDArrays to compare. Returns @@ -246,13 +220,38 @@ def default_check_metric( # pylint: disable=unused-variable,unreachable-code is_equal : bool Whether the two lists of NDArrays are equal. """ - assert len(a) == len(b), "Different number of outputs from two modules" - for i, _ in enumerate(a): - if not np.allclose(a[i].numpy(), b[i].numpy(), rtol=1e-3, atol=2e-3): + assert len(lhs) == len(rhs), "Different number of outputs from two modules" + for i in range(len(lhs)): # pylint: disable=consider-using-enumerate + if not np.allclose(lhs[i].numpy(), rhs[i].numpy(), rtol=1e-3, atol=2e-3): return False return True +@register_func("tvm.meta_schedule.testing.default_input_generator") +def default_input_generator( # pylint: disable=unused-variable + mod: IRModule, +) -> List[tvm.nd.NDArray]: + """Default input generator function + + Parameters + ---------- + mod : IRModule + The IRModule to generate the input data for. + + Returns + ------- + inputs : List[tvm.nd.NDArray] + The generated input data. + """ + + args_info = ms.arg_info.TensorInfo.from_prim_func(mod["main"]) + inputs = [ + tvm.nd.array(generate_input_data(input_shape=arg_info.shape, input_dtype=arg_info.dtype)) + for arg_info in args_info + ] + return inputs + + def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]: """Convert a list of TVM NDArray to a list of numpy array @@ -639,7 +638,6 @@ def _run_single_mod( def main(): """Main function""" describe() - initializer() # for local input generation with ms.Profiler() as profiler: # initialize target = ARGS.target From e29a4df69fdfe508054c5e75039e6bd9d83b022c Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 21 Nov 2022 18:19:26 -0800 Subject: [PATCH 13/13] Support local runner. --- .../testing/validate_database.py | 90 ++++++++++++------- 1 file changed, 59 insertions(+), 31 deletions(-) diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index c763f4fb514d..a5981a78d645 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -69,17 +69,14 @@ def _parse_args(): args.add_argument( "--rpc-host", type=str, - required=True, ) args.add_argument( "--rpc-port", type=int, - required=True, ) args.add_argument( "--rpc-key", type=str, - required=True, ) args.add_argument( "--number", @@ -114,12 +111,16 @@ def _parse_args(): ) parsed = args.parse_args() parsed.target = tvm.target.Target(parsed.target) - parsed.rpc_config = ms.runner.RPCConfig( - tracker_host=parsed.rpc_host, - tracker_port=parsed.rpc_port, - tracker_key=parsed.rpc_key, - session_timeout_sec=600, - ) + if parsed.rpc_host is not None and parsed.rpc_port is not None and parsed.rpc_key is not None: + parsed.rpc_config = ms.runner.RPCConfig( + tracker_host=parsed.rpc_host, + tracker_port=parsed.rpc_port, + tracker_key=parsed.rpc_key, + session_timeout_sec=600, + ) + else: + parsed.rpc_config = None + warnings.warn("RPC config is not provided, will use local runner.") if parsed.cpu_flush and parsed.target.kind.name != "llvm": warnings.warn("cpu_flush is only supported on llvm target") return parsed @@ -133,7 +134,7 @@ def _parse_args(): format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) -logging.getLogger("tvm.meta_schedule.runner.rpc_runner").setLevel(logging.WARN) +logging.getLogger("tvm.meta_schedule.runner").setLevel(logging.WARN) def get_device_type(target: Target) -> str: @@ -411,13 +412,10 @@ def make_alloc_arg_and_check( The function to run evaluator. """ - def f_with_args_alloc_argument( - # pylint: disable=unused-argument - session: tvm.rpc.RPCSession, + def f_with_args_alloc_argument_common( device: tvm.runtime.Device, - args_info: ms.runner.rpc_runner.T_ARG_INFO_JSON_OBJ_LIST, + args_info: ms.runner.rpc_runner.T_ARG_INFO_JSON_OBJ_LIST, # pylint: disable=unused-argument alloc_repeat: int, - # pylint: enable=unused-argument ) -> List[ms.runner.rpc_runner.T_ARGUMENT_LIST]: """Allocate arguments using the given inputs. @@ -439,8 +437,7 @@ def f_with_args_alloc_argument( """ return [[tvm.nd.array(arg, device=device) for arg in inputs] for _ in range(alloc_repeat)] - def f_with_args_run_evaluator( - session: tvm.rpc.RPCSession, # pylint: disable=unused-argument + def f_with_args_run_evaluator_common( rt_mod: tvm.runtime.Module, device: tvm.runtime.Device, evaluator_config: ms.runner.EvaluatorConfig, @@ -507,7 +504,27 @@ def f_with_args_run_evaluator( return costs - return f_with_args_alloc_argument, f_with_args_run_evaluator + def f_with_args_alloc_argument_rpc( + rpc_session: ms.runner.rpc_runner.RPCSession, # pylint: disable=unused-argument + device: tvm.runtime.Device, + args_info: ms.runner.rpc_runner.T_ARG_INFO_JSON_OBJ_LIST, + alloc_repeat: int, + ) -> List[ms.runner.rpc_runner.T_ARGUMENT_LIST]: + return f_with_args_alloc_argument_common(device, args_info, alloc_repeat) + + def f_with_args_run_evaluator_rpc( + rpc_session: ms.runner.rpc_runner.RPCSession, # pylint: disable=unused-argument + rt_mod: tvm.runtime.Module, + device: tvm.runtime.Device, + evaluator_config: ms.runner.EvaluatorConfig, + repeated_args: List[ms.runner.rpc_runner.T_ARGUMENT_LIST], + ) -> List[float]: + return f_with_args_run_evaluator_common(rt_mod, device, evaluator_config, repeated_args) + + if ARGS.rpc_config is None: + return f_with_args_alloc_argument_common, f_with_args_run_evaluator_common + else: + return f_with_args_alloc_argument_rpc, f_with_args_run_evaluator_rpc def local_build_and_run( @@ -713,20 +730,31 @@ def main(): original_run_secs=original_run_secs, print_result=print_result, ) - # create rpc runner - runner = ms.runner.RPCRunner( - rpc_config=ARGS.rpc_config, - evaluator_config=ms.runner.EvaluatorConfig( - number=ARGS.number, - repeat=ARGS.repeat, - min_repeat_ms=ARGS.min_repeat_ms, - enable_cpu_cache_flush=ARGS.cpu_flush, - ), - alloc_repeat=1, - f_alloc_argument=f_with_args_alloc_argument, - f_run_evaluator=f_with_args_run_evaluator, - initializer=initializer, + # create runner + evaluator_config = ms.runner.EvaluatorConfig( + number=ARGS.number, + repeat=ARGS.repeat, + min_repeat_ms=ARGS.min_repeat_ms, + enable_cpu_cache_flush=ARGS.cpu_flush, ) + if ARGS.rpc_config is not None: + runner: ms.Runner = ms.runner.RPCRunner( # type: ignore + ARGS.rpc_config, + evaluator_config=evaluator_config, + alloc_repeat=1, + f_alloc_argument=f_with_args_alloc_argument, + f_run_evaluator=f_with_args_run_evaluator, + initializer=initializer, + ) + else: + runner: ms.Runner = ms.runner.LocalRunner( # type: ignore + evaluator_config=evaluator_config, + alloc_repeat=1, + f_alloc_argument=f_with_args_alloc_argument, + f_run_evaluator=f_with_args_run_evaluator, + initializer=initializer, + ) + # run and validate _run_single_mod(builder_result, runner, dev_type) # type: ignore except Exception as e: # pylint: disable=broad-except, invalid-name