This repository was archived by the owner on Feb 24, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 58
Expand file tree
/
Copy path gemv_dequantize.py
More file actions
370 lines (301 loc) · 15.7 KB
/
gemv_dequantize.py
File metadata and controls
370 lines (301 loc) · 15.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""A rule for GEMV and DecodeGEMV."""
from functools import reduce
from typing import List, Dict
from tvm.target import Target
from tvm.tir.function import PrimFunc
from tvm import DataType, tir
import logging
from ..base import (
normalize_prim_func,
get_output_blocks,
get_block,
)
from .base import GPUScheduleRule
from .matmul_analysis import auto_inline_producers, auto_inline_consumers
# Module logger; the schedule rule reports why it declines to apply via debug messages.
logger: logging.Logger = logging.getLogger(name=__name__)
class GEMVWithDequantizeInfo(GPUScheduleRule):
    """A rule for Dequantized GEMV.

    The rule reads the ``dequantize_info`` attribute of the PrimFunc, caches the
    dequantized weight in a ``local`` stage of the reduction block, inlines the
    original decode block into it, and maps the GEMV onto ``blockIdx``/``threadIdx``
    axes. After a successful schedule, ``self.block_size`` and ``self.grid_size``
    hold the extents of the bound thread/block loops.
    """

    # Accepted values of ``weight_decode_info["source_format"]["format"]``.
    _SUPPORTED_SOURCE_FORMATS = ("uint", "int", "fp", "nf", "fp_e5m2", "fp_e4m3")
    # Accepted values of ``weight_decode_info["source_format"]["bits"]``.
    _SUPPORTED_SOURCE_BITS = (1, 2, 4, 8)
    # Accepted values of ``weight_decode_info["target_format"]``.
    _SUPPORTED_TARGET_FORMATS = ("float16", "int8")
    # Defaults used by ``apply`` when no tuned hint config is available.
    _DEFAULT_TRANSACTION_BITS = 128
    _DEFAULT_NUM_WARPS = 1
    _DEFAULT_WARP_SIZE = 32

    def apply(  # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements
        self,
        func: tir.PrimFunc,
        target: Target,  # pylint: disable=unused-argument
        _: bool,
    ):
        """Schedule ``func`` with the default (non-tuned) thread configuration.

        Parameters
        ----------
        func : tir.PrimFunc
            Function to schedule; it must carry a ``dequantize_info`` attribute.
        target : Target
            Compilation target (not used by this rule).
        _ : bool
            Unused positional flag, kept for interface compatibility.

        Returns
        -------
        Optional[tir.Schedule]
            The scheduled function, or ``None`` when the rule does not apply.
        """
        prepared = self._prepare_dequantize_schedule(func)
        if prepared is None:
            return None
        sch, block_infos, weight_decode_info, block_b = prepared
        return self._schedule_dequantize_gemv(
            sch,
            block_infos,
            weight_decode_info,
            block_b,
            transaction_bits=self._DEFAULT_TRANSACTION_BITS,
            num_warps=self._DEFAULT_NUM_WARPS,
            warp_size=self._DEFAULT_WARP_SIZE,
        )

    def sch_inner_reduction_with_config(  # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements
        self,
        func: tir.PrimFunc,
        config,
    ):
        """Schedule ``func`` using a tuned hint ``config``.

        ``prod(config.thread)`` is the split factor of the spatial axis (bound to
        ``threadIdx.y``), ``prod(config.reduce_thread)`` the split factor of the
        reduction axis (bound to ``threadIdx.x``), and
        ``config.arch.transaction_size[-1]`` the vectorized access width in bits.

        Returns
        -------
        Optional[tir.Schedule]
            The scheduled function, or ``None`` when the rule does not apply.
        """
        prepared = self._prepare_dequantize_schedule(func)
        if prepared is None:
            return None
        sch, block_infos, weight_decode_info, block_b = prepared
        return self._schedule_dequantize_gemv(
            sch,
            block_infos,
            weight_decode_info,
            block_b,
            transaction_bits=config.arch.transaction_size[-1],
            num_warps=int(self._prod(config.thread)),
            warp_size=int(self._prod(config.reduce_thread)),
            split_k_thread=config.thread,
        )

    def apply_config(self, func: PrimFunc, config):
        """Use the config-driven schedule when any reduce-thread extent exceeds 1.

        Returns ``None`` (rule not applicable) otherwise.
        """
        if any(t > 1 for t in config.reduce_thread):
            return self.sch_inner_reduction_with_config(func, config)
        return None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _prod(iterable):
        """Return the product of ``iterable`` (1 for an empty iterable)."""
        return reduce(lambda x, y: x * y, iterable, 1)

    @staticmethod
    def _get_dequantize_buffer_idx(weight_decode_info: Dict) -> int:
        """Return the read-buffer index of the quantized weight in the decode block."""
        # for LUT dequantize, the expr is LUT(w), the idx is 1
        # maybe we can use a more general and structural based way
        # to analysis the idx
        if weight_decode_info["source_format"]["format"] == "nf":
            return 1
        return 0

    @classmethod
    def _is_supported_weight_decode_info(cls, weight_decode_info) -> bool:
        """Return True if the weight-decode entry uses a supported source/target format."""
        # Check key presence first so that a malformed entry is rejected rather
        # than raising KeyError while being validated.
        if "source_format" not in weight_decode_info or "target_format" not in weight_decode_info:
            return False
        source_format = weight_decode_info["source_format"]
        return (source_format["format"] in cls._SUPPORTED_SOURCE_FORMATS
                and source_format["bits"] in cls._SUPPORTED_SOURCE_BITS
                and weight_decode_info["target_format"] in cls._SUPPORTED_TARGET_FORMATS)

    @classmethod
    def _get_weight_decode_info(cls, func: tir.PrimFunc):
        """Return the single, supported weight-decode entry of ``func``, or ``None``."""
        dequantize_info = func.attrs["dequantize_info"]
        # currently only support weight only dequantization
        # TODO(@lei) check if the dequantize value name is weight
        if len(dequantize_info) != 1:
            logger.debug("Dequantize info is not valid")
            return None
        (weight_decode_info,) = list(dequantize_info.values())
        if not cls._is_supported_weight_decode_info(weight_decode_info):
            logger.debug("Weight Dequantize info is not valid")
            return None
        return weight_decode_info

    @staticmethod
    def _find_reduction_block(sch: tir.Schedule, block_infos):
        """Return the last not-yet-thread-bound block that has a reduction loop.

        Blocks with no spatial loop get a unit loop added, which mutates ``sch``.
        Returns ``None`` when no such block exists.
        """
        reduction_block = None
        for info in block_infos:
            block = info.block_rv
            loops = sch.get_loops(block)
            # Skip loop-less blocks and blocks whose loops are already bound to threads.
            if not loops or any(sch.get(loop).thread_binding is not None for loop in loops):
                continue
            # One iteration-kind character per loop, truncated to the loop count.
            loop_kinds = info.dom_kind()[:len(loops)]
            if "S" not in loop_kinds:
                sch.add_unit_loop(block)
            if "R" in loop_kinds:
                reduction_block = block
        return reduction_block

    def _prepare_dequantize_schedule(self, func: tir.PrimFunc):
        """Validate ``func`` and locate its reduction block.

        Returns ``(sch, block_infos, weight_decode_info, reduction_block)``, or
        ``None`` if the rule does not apply.
        """
        sch = tir.Schedule(func)
        weight_decode_info = self._get_weight_decode_info(func)
        if weight_decode_info is None:
            return None
        block_infos = normalize_prim_func(sch)
        if block_infos is None:
            return None
        reduction_block = self._find_reduction_block(sch, block_infos)
        if reduction_block is None:
            # Nothing to schedule; decline instead of failing later in cache_read.
            logger.debug("No reduction block found")
            return None
        return sch, block_infos, weight_decode_info, reduction_block

    def _schedule_dequantize_gemv(  # pylint: disable=too-many-arguments,too-many-locals
        self,
        sch: tir.Schedule,
        block_infos,
        weight_decode_info,
        block_b: tir.schedule.BlockRV,
        transaction_bits: int,
        num_warps: int,
        warp_size: int,
        split_k_thread=None,
    ) -> tir.Schedule:
        """Apply the shared dequantize-GEMV schedule to reduction block ``block_b``.

        Parameters
        ----------
        transaction_bits : int
            Width in bits of one vectorized access; divided by the bit width of the
            target format to obtain the vector length along the reduction axis.
        num_warps : int
            Split factor of the spatial axis, bound to ``threadIdx.y``.
        warp_size : int
            Split factor of the reduction axis, bound to ``threadIdx.x``.
        split_k_thread : sequence of int, optional
            When given and ``block_b`` has four loops (split-K), ``num_warps`` is
            divided by ``split_k_thread[0]``; must then have exactly two entries.

        Side effects: sets ``self.block_size`` and ``self.grid_size``.
        """
        target_format = weight_decode_info["target_format"]
        # coalesced access requires the vectorize factor to be the same as the transaction size
        vec = transaction_bits // DataType(target_format).bits

        decode_block = get_block(sch, block_infos, weight_decode_info["decode_block"])
        block_decode_B = sch.cache_read(block_b, 1, "local")
        sch.compute_inline(decode_block)

        loops = sch.get_loops(block_b)
        j, k = loops[-2:]
        if len(loops) == 3:
            sch.bind(loops[0], "blockIdx.z")
        elif len(loops) == 4:
            # splitk case
            sk, i = loops[:2]
            sch.bind(sk, "blockIdx.y")
            sch.bind(i, "blockIdx.z")
            if split_k_thread is not None:
                assert len(split_k_thread) == 2, "SplitK only support 2D thread config"
                num_warps = int(num_warps // split_k_thread[0])

        # get target dequantize buffer's idx
        buffer_idx = self._get_dequantize_buffer_idx(weight_decode_info)
        block_shared_local_A = sch.cache_read(block_b, 0, "local")
        block_shared_local_B = sch.cache_read(block_decode_B, buffer_idx, "local")
        block_local_C = sch.cache_write(block_b, 0, "local")
        auto_inline_producers(sch, block_shared_local_B)
        auto_inline_consumers(sch, block_local_C)

        bx, j = sch.split(j, factors=[None, num_warps])
        k, tx, vk = sch.split(k, factors=[None, warp_size, vec])
        # for dp4a/hfma2
        inst_factor = 2 if target_format == "float16" else 4
        sch.split(vk, factors=[None, inst_factor])
        sch.reorder(bx, j, k, tx)
        sch.bind(bx, "blockIdx.x")
        sch.bind(tx, "threadIdx.x")
        sch.bind(j, "threadIdx.y")

        self.block_size = [sch.get(tx).extent, sch.get(j).extent, 1]
        self.grid_size = [sch.get(bx).extent, 1, 1]

        sch.compute_at(block_decode_B, tx, preserve_unit_loops=True)
        sch.compute_at(block_shared_local_A, tx, preserve_unit_loops=True)
        sch.compute_at(block_shared_local_B, tx, preserve_unit_loops=True)
        sch.reverse_compute_at(block_local_C, j, preserve_unit_loops=True)

        sch.vectorize(sch.get_loops(block_shared_local_A)[-1])
        sch.vectorize(sch.get_loops(block_shared_local_B)[-1])

        skip_blocks: List[tir.schedule.BlockRV] = [block_shared_local_B]
        if "zeros_mode" in weight_decode_info and weight_decode_info["zeros_mode"] == "quantized":
            # Scales and zeros are read at fixed offsets after the weight buffer
            # in the decode block; scales are cached before zeros.
            for offset, flag in ((1, "with_scaling"), (2, "with_zeros")):
                if flag in weight_decode_info and weight_decode_info[flag]:
                    block_local_param = sch.cache_read(block_decode_B, buffer_idx + offset,
                                                       "local")
                    sch.compute_at(block_local_param, tx, preserve_unit_loops=True)
                    auto_inline_producers(sch, block_local_param)
                    skip_blocks.append(block_local_param)

        auto_inline_producers(sch, block_decode_B, skip_blocks)

        if "fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]:
            from .intrin import get_lop3_intrin_group

            intrin_info = get_lop3_intrin_group(
                out_dtype=target_format,
                storage_dtype=weight_decode_info["storage_dtype"],
                source_format=weight_decode_info["source_format"]["format"],
                source_bit=weight_decode_info["source_format"]["bits"],
                with_scaling=weight_decode_info["with_scaling"],
                with_zeros=weight_decode_info["with_zeros"],
                zeros_mode=weight_decode_info["zeros_mode"],
            )
            sch.tensorize(sch.get_loops(block_decode_B)[-1], intrin_info["compute"])
            sch.annotate(block_b, ann_key="pragma_import_c", ann_val=intrin_info["c_source"])
        return sch