Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 6 additions & 11 deletions ci/scripts/check_pr.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import argparse
import re
import os
import json
import textwrap
from dataclasses import dataclass
from typing import Any, List, Callable
Expand Down Expand Up @@ -108,10 +109,7 @@ def run_checks(checks: List[Check], s: str, name: str) -> bool:
parser.add_argument("--pr", required=True)
parser.add_argument("--remote", default="origin", help="ssh remote to parse")
parser.add_argument(
"--pr-body", help="(testing) PR body to use instead of fetching from GitHub"
)
parser.add_argument(
"--pr-title", help="(testing) PR title to use instead of fetching from GitHub"
"--pr-data", help="(testing) PR data to use instead of fetching from GitHub"
)
args = parser.parse_args()

Expand All @@ -121,20 +119,17 @@ def run_checks(checks: List[Check], s: str, name: str) -> bool:
print(f"PR was not a number: {args.pr}")
exit(0)

if args.pr_body:
body = args.pr_body
title = args.pr_title
if args.pr_data:
pr = json.loads(args.pr_data)
else:
remote = git(["config", "--get", f"remote.{args.remote}.url"])
user, repo = parse_remote(remote)

github = GitHubRepo(token=os.environ["GITHUB_TOKEN"], user=user, repo=repo)
pr = github.get(f"pulls/{args.pr}")
body = pr["body"]
title = pr["title"]

body = body.strip()
title = title.strip()
body = "" if pr["body"] is None else pr["body"].strip()
title = "" if pr["title"] is None else pr["title"].strip()

title_passed = run_checks(checks=title_checks, s=title, name="PR title")
print("")
Expand Down
43 changes: 43 additions & 0 deletions tests/python/ci/test_ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,5 +1144,48 @@ def test_should_rebuild_docker(tmpdir_factory, changed_files, name, check, expec
assert proc.returncode == expected_code


@parameterize_named(
passing=dict(
title="[something] a change",
body="something",
expected="All checks passed",
expected_code=0,
),
period=dict(
title="[something] a change.",
body="something",
expected="trailing_period: FAILED",
expected_code=1,
),
empty_body=dict(
title="[something] a change",
body=None,
expected="non_empty: FAILED",
expected_code=1,
),
)
def test_pr_linter(title, body, expected, expected_code):
"""
Test the PR linter
"""
tag_script = REPO_ROOT / "ci" / "scripts" / "check_pr.py"
pr_data = {
"title": title,
"body": body,
}
proc = run_script(
[
tag_script,
"--pr",
1234,
"--pr-data",
json.dumps(pr_data),
],
check=False,
)
assert proc.returncode == expected_code
assert_in(expected, proc.stdout)


if __name__ == "__main__":
tvm.testing.main()