diff --git a/.github/workflows/regtest.yml b/.github/workflows/regtest.yml index d2306713b..bd2996c97 100644 --- a/.github/workflows/regtest.yml +++ b/.github/workflows/regtest.yml @@ -44,5 +44,5 @@ jobs: run: pip install . - name: Run tests via test.py - run: ./pyro/test.py + run: ./pyro/test.py --nproc 0 diff --git a/pyro/test.py b/pyro/test.py index b9513f006..d403980d0 100755 --- a/pyro/test.py +++ b/pyro/test.py @@ -2,9 +2,13 @@ import argparse +import contextlib import datetime +import io import os import sys +from multiprocessing import Pool +from pathlib import Path import pyro.pyro_sim as pyro from pyro.multigrid.examples import (mg_test_general_inhomogeneous, @@ -23,9 +27,64 @@ def __str__(self): return f"{self.solver}-{self.problem}" +@contextlib.contextmanager +def avoid_interleaved_output(nproc): + """Collect all the printed output and print it all at once to avoid interleaving.""" + if nproc == 1: + # not running in parallel, so we don't have to worry about interleaving + yield + else: + output_buffer = io.StringIO() + try: + with contextlib.redirect_stdout(output_buffer), \ + contextlib.redirect_stderr(output_buffer): + yield + finally: + # a single print call probably won't get interleaved + print(output_buffer.getvalue(), end="", flush=True) + + +def run_test(t, reset_fails, store_all_benchmarks, rtol, nproc): + orig_cwd = Path.cwd() + # run each test in its own directory, since some of the output file names + # overlap between tests, and h5py needs exclusive access when writing + test_dir = orig_cwd / f"test_outputs/{t}" + test_dir.mkdir(parents=True, exist_ok=True) + try: + os.chdir(test_dir) + with avoid_interleaved_output(nproc): + p = pyro.PyroBenchmark(t.solver, comp_bench=True, + reset_bench_on_fail=reset_fails, + make_bench=store_all_benchmarks) + p.initialize_problem(t.problem, t.inputs, t.options) + start_n = p.sim.n + err = p.run_sim(rtol) + finally: + os.chdir(orig_cwd) + if err == 0: + # the test passed; clean up the output files for developer use + basename = p.rp.get_param("io.basename") + (test_dir / f"{basename}{start_n:04d}.h5").unlink() + (test_dir / f"{basename}{p.sim.n:04d}.h5").unlink() + (test_dir / "inputs.auto").unlink() + test_dir.rmdir() + # try removing the top-level output directory + try: + test_dir.parent.rmdir() + except OSError: + pass + + return str(t), err + + +def run_test_star(args): + """multiprocessing doesn't like lambdas, so this needs to be a full function""" + return run_test(*args) + + def do_tests(out_file, reset_fails=False, store_all_benchmarks=False, - single=None, solver=None, rtol=1e-12): + single=None, solver=None, rtol=1e-12, nproc=1): opts = {"driver.verbose": 0, "vis.dovis": 0, "io.do_io": 0} @@ -59,13 +118,16 @@ def do_tests(out_file, else: tests_to_run = tests - for t in tests_to_run: - p = pyro.PyroBenchmark(t.solver, comp_bench=True, - reset_bench_on_fail=reset_fails, make_bench=store_all_benchmarks) - p.initialize_problem(t.problem, t.inputs, t.options) - err = p.run_sim(rtol) - - results[str(t)] = err + if nproc == 0: + nproc = os.cpu_count() + # don't create more processes than needed + nproc = min(nproc, len(tests_to_run)) + with Pool(processes=nproc) as pool: + tasks = ((t, reset_fails, store_all_benchmarks, rtol, nproc) for t in tests_to_run) + imap_it = pool.imap_unordered(run_test_star, tasks) + # collect run results + for name, err in imap_it: + results[name] = err # standalone tests if single is None: @@ -120,9 +182,9 @@ def do_tests(out_file, p = argparse.ArgumentParser() - p.add_argument("-o", - help="name of file to output the report to (otherwise output to the screen", - type=str, nargs=1) + p.add_argument("--outfile", "-o", + help="name of file to output the report to (in addition to the screen)", + type=str, default=None) p.add_argument("--single", help="name of a single test (solver-problem) to run", @@ -142,23 +204,18 @@ def do_tests(out_file, p.add_argument("--rtol", help="relative tolerance to use when comparing data to benchmarks", - type=float, nargs=1) + type=float, default=1.e-12) - args = p.parse_args() - - try: - outfile = args.o[0] - except TypeError: - outfile = None + p.add_argument("--nproc", "-n", + help="maximum number of parallel processes to run, or 0 to use all cores", + type=int, default=1) - try: - rtol = args.rtol[0] - except TypeError: - rtol = 1.e-12 + args = p.parse_args() - failed = do_tests(outfile, + failed = do_tests(args.outfile, reset_fails=args.reset_failures, store_all_benchmarks=args.store_all_benchmarks, - single=args.single, solver=args.solver, rtol=rtol) + single=args.single, solver=args.solver, rtol=args.rtol, + nproc=args.nproc) sys.exit(failed)