This repository was archived by the owner on Feb 24, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 59
Expand file tree
/
Copy path target_detector.py
More file actions
83 lines (65 loc) · 2.59 KB
/
target_detector.py
File metadata and controls
83 lines (65 loc) · 2.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import subprocess
from typing import List
from thefuzz import process
from tvm.target import Target
from tvm.target.tag import list_tags
import logging
logger = logging.getLogger(__name__)
# Warning text logged when the fuzzy-matched NVIDIA tag is rejected and the
# generic "cuda" target is used instead.
TARGET_MISSING_ERROR = (
    "TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, "
    "where <target> is one of the available targets, which can be found in the output of "
    "`tools/get_available_targets.py`."
)
def get_gpu_model_from_nvidia_smi():
    """
    Executes the 'nvidia-smi' command to fetch the name of the first available NVIDIA GPU.

    Returns:
        str: The name of the first GPU listed by nvidia-smi, or None if the command
        cannot be launched (e.g. it is not on PATH), exits with an error, or lists
        no GPUs.
    """
    try:
        # Query only the GPU names, one per line, without a CSV header.
        output = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv,noheader"],
            encoding="utf-8",
        ).strip()
    except subprocess.CalledProcessError as e:
        # nvidia-smi ran but exited with a non-zero status.
        logger.info("nvidia-smi failed with error: %s", e)
        return None
    except OSError as e:
        # nvidia-smi could not be started at all, e.g. FileNotFoundError when
        # the executable is not on PATH. Previously this escaped to callers.
        logger.info("nvidia-smi could not be executed: %s", e)
        return None
    # Return the name of the first GPU if multiple are present; None when the
    # command produced no output, matching the documented failure value.
    lines = output.splitlines()
    return lines[0].strip() if lines else None
def find_best_match(
    tags: List[str],
    query: str,
    *,
    threshold: int = 25,
    default: str = "cuda",
) -> str:
    """
    Finds the best match for a query within a list of tags using fuzzy string matching.

    Args:
        tags: Candidate TVM target tag names (in this module, the NVIDIA tags
            from ``list_tags()``).
        query: The string to match; in this module, the GPU model name reported
            by nvidia-smi.
        threshold: Minimum ``thefuzz`` match score required to accept the best
            match. Defaults to 25, the previously hard-coded value.
        default: Target returned when no acceptable match exists. Defaults to "cuda".

    Returns:
        str: The best-matching tag, or ``default`` when ``tags`` is empty, the best
        match's architecture differs from the one TVM reports for ``Target(default)``,
        or the match score is below ``threshold``.
    """
    # process.extractOne returns None for an empty choice list, which would
    # break the tuple unpacking below; fall back to the default target instead.
    if not tags:
        return default
    result = process.extractOne(query, tags)
    if result is None:
        return default
    best_match, score = result
    # Only accept the fuzzy match when its architecture agrees with the one TVM
    # reports for the default target, so a textually similar tag for a
    # different GPU architecture is not selected.
    if Target(best_match).arch != Target(default).arch:
        logger.warning(TARGET_MISSING_ERROR)
        return default
    return best_match if score >= threshold else default
def get_all_nvidia_targets() -> List[str]:
    """
    Returns all available NVIDIA targets.

    Returns:
        List[str]: Names of the TVM target tags whose name contains "nvidia".
        Empty when ``list_tags()`` returns None (per TVM's documentation this
        happens in runtime-only builds; confirm for the pinned TVM version).
    """
    # list_tags() is typed Optional; treat None as "no tags" instead of
    # failing with a TypeError when iterating it.
    all_tags = list_tags() or {}
    return [tag for tag in all_tags if "nvidia" in tag]
def auto_detect_nvidia_target() -> str:
    """
    Automatically detects the NVIDIA GPU architecture to set the appropriate TVM target.

    Returns:
        str: The TVM target tag that best matches the local GPU (as chosen by
        ``find_best_match``), or "cuda" when no GPU model can be read from
        nvidia-smi or no NVIDIA tags are available.
    """
    # NOTE: this function does not consult the TVM_TARGET environment variable
    # (that override was deliberately left disabled); it always auto-detects.

    # Fetch all available tags and keep only the NVIDIA ones. list_tags() may
    # return None, which is treated as "no tags" rather than crashing.
    nvidia_tags = [tag for tag in (list_tags() or {}) if "nvidia" in tag]
    # Get the current GPU model and find the best matching target.
    gpu_model = get_gpu_model_from_nvidia_smi()
    # Fuzzy matching needs both a query and at least one candidate; without
    # them fall back to the generic CUDA target.
    if not gpu_model or not nvidia_tags:
        return "cuda"
    return find_best_match(nvidia_tags, gpu_model)