From 8675477bce0f405ccfa62c2bbeadf83b1d50b920 Mon Sep 17 00:00:00 2001 From: driazati Date: Fri, 2 Sep 2022 11:46:37 -0700 Subject: [PATCH] rebase --- ci/scripts/check_pr.py | 17 ++++++--------- tests/python/ci/test_ci.py | 43 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 11 deletions(-) mode change 100644 => 100755 ci/scripts/check_pr.py diff --git a/ci/scripts/check_pr.py b/ci/scripts/check_pr.py old mode 100644 new mode 100755 index 45d502c6a72e..9af5ec5580a3 --- a/ci/scripts/check_pr.py +++ b/ci/scripts/check_pr.py @@ -18,6 +18,7 @@ import argparse import re import os +import json import textwrap from dataclasses import dataclass from typing import Any, List, Callable @@ -108,10 +109,7 @@ def run_checks(checks: List[Check], s: str, name: str) -> bool: parser.add_argument("--pr", required=True) parser.add_argument("--remote", default="origin", help="ssh remote to parse") parser.add_argument( - "--pr-body", help="(testing) PR body to use instead of fetching from GitHub" - ) - parser.add_argument( - "--pr-title", help="(testing) PR title to use instead of fetching from GitHub" + "--pr-data", help="(testing) PR data to use instead of fetching from GitHub" ) args = parser.parse_args() @@ -121,20 +119,17 @@ def run_checks(checks: List[Check], s: str, name: str) -> bool: print(f"PR was not a number: {args.pr}") exit(0) - if args.pr_body: - body = args.pr_body - title = args.pr_title + if args.pr_data: + pr = json.loads(args.pr_data) else: remote = git(["config", "--get", f"remote.{args.remote}.url"]) user, repo = parse_remote(remote) github = GitHubRepo(token=os.environ["GITHUB_TOKEN"], user=user, repo=repo) pr = github.get(f"pulls/{args.pr}") - body = pr["body"] - title = pr["title"] - body = body.strip() - title = title.strip() + body = "" if pr["body"] is None else pr["body"].strip() + title = "" if pr["title"] is None else pr["title"].strip() title_passed = run_checks(checks=title_checks, s=title, name="PR title") print("") diff --git a/tests/python/ci/test_ci.py b/tests/python/ci/test_ci.py index f2e686d1e582..79c72ce988c3 100644 --- a/tests/python/ci/test_ci.py +++ b/tests/python/ci/test_ci.py @@ -1144,5 +1144,48 @@ def test_should_rebuild_docker(tmpdir_factory, changed_files, name, check, expec assert proc.returncode == expected_code +@parameterize_named( + passing=dict( + title="[something] a change", + body="something", + expected="All checks passed", + expected_code=0, + ), + period=dict( + title="[something] a change.", + body="something", + expected="trailing_period: FAILED", + expected_code=1, + ), + empty_body=dict( + title="[something] a change", + body=None, + expected="non_empty: FAILED", + expected_code=1, + ), +) +def test_pr_linter(title, body, expected, expected_code): + """ + Test the PR linter + """ + tag_script = REPO_ROOT / "ci" / "scripts" / "check_pr.py" + pr_data = { + "title": title, + "body": body, + } + proc = run_script( + [ + tag_script, + "--pr", + 1234, + "--pr-data", + json.dumps(pr_data), + ], + check=False, + ) + assert proc.returncode == expected_code + assert_in(expected, proc.stdout) + + if __name__ == "__main__": tvm.testing.main()