Matrix multiplication is outputting unexpected values if W is transposed through PyTorch's t() function #33
Description
Greetings!
I've been trying to multiply small matrices. The layout="nt" setting tells me that W should be transposed, and initializing the weights with the values already transposed produces the correct output. However, if I initialize W in its natural order and transpose it afterwards (with W.t() or torch.transpose(W, 0, 1)), the output is no longer correct.
config = bitblas.MatmulConfig(layout="nt", ...)
matmul = bitblas.Matmul(config) # int8 input/output, with int32 accumulation
a = torch.Tensor(...).cuda()
w = torch.Tensor(...).cuda() # W
wt = torch.Tensor(...).cuda() # W with values already transposed
# This prints the correct answer
c = matmul(matmul.transform_input(a.to(torch.int8)),
matmul.transform_weight(wt.to(torch.int8)))
print(c)
# This prints a different answer
c = matmul(matmul.transform_input(a.to(torch.int8)),
matmul.transform_weight(w.t().to(torch.int8)))
print(c)

Is my understanding of how I should be using the library correct?
System Specs
- Ubuntu 22.04 LTS
- Python 3.10.12
- BitBLAS from PyPI, version 0.0.1.dev3
- CUDA 12.3
- RTX 3070 TI (laptop)
Code Sample
Matrix multiplication with int8 values and int32 accumulation.
import bitblas
import torch
matmul_config = bitblas.MatmulConfig(
    M=1,  # M dimension
    N=3,  # N dimension
    K=2,  # K dimension
    A_dtype="int8",  # activation A dtype
    W_dtype="int8",  # weight W dtype
    accum_dtype="int32",  # accumulation dtype
    out_dtype="int8",  # output dtype
    layout="nt",  # matrix layout; "nt" means A is non-transposed and W is transposed
    with_bias=False,  # bias
    # configs for weight-only quantization
    group_size=None,  # setting for grouped quantization
    with_scaling=False,  # setting for scaling factor
    with_zeros=False,  # setting for zeros
    zeros_mode=None,  # setting for how to calculate zeros
)
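For reference, my understanding of what layout="nt" computes (C = A @ W.T, with W stored as an [N, K] matrix) can be checked with a plain-Python sketch using the same M=1, N=3, K=2 shapes; the list-comprehension dot product here is just for illustration:

```python
# Plain-Python sketch of the "nt" layout: C[M, N] = A[M, K] @ W[N, K]^T,
# accumulated in Python ints, mirroring int32 accumulation.
A = [[2, 3]]                   # M=1, K=2
W = [[4, 2], [2, 1], [3, 2]]   # N=3, K=2 -- already-"transposed" storage

C = [[sum(a_k * w_k for a_k, w_k in zip(row_a, row_w)) for row_w in W]
     for row_a in A]
print(C)  # [[14, 7, 12]] -- matches torch.matmul(a, wt.t())
```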
matmul = bitblas.Matmul(config=matmul_config)

PyTorch's matmul gives the expected answer:
a = torch.Tensor([[2, 3]]).cuda()
wt = torch.Tensor([[4, 2], [2, 1], [3, 2]]).cuda()
print(torch.matmul(a, wt.t()))

tensor([[14., 7., 12.]], device='cuda:0')
Likewise, using BitBLAS with int8 gives the correct answer:
c = matmul(matmul.transform_input(a.to(torch.int8)),
matmul.transform_weight(wt.to(torch.int8)))
print(c)

tensor([[14, 7, 12]], device='cuda:0', dtype=torch.int8)
However, if I initialize w with the values in their "natural" order and then transpose afterwards, the output is no longer the same:
w = torch.Tensor([[4, 2, 3], [2, 1, 2]]).cuda()
print(torch.matmul(a, w))
c = matmul(matmul.transform_input(a.to(torch.int8)),
matmul.transform_weight(w.t().to(torch.int8)))
print(c)

tensor([[14., 7., 12.]], device='cuda:0')
tensor([[14, 12, 8]], device='cuda:0', dtype=torch.int8)
w.t() and wt should be the same, unless there are some memory shenanigans I'm not aware of:
print(wt == w.t())

tensor([[True, True],
[True, True],
[True, True]], device='cuda:0')
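If I had to guess (this is an assumption, not something I have verified in the BitBLAS source), transform_weight may read W's underlying contiguous buffer and ignore the strides of the w.t() view. A plain-Python sketch of that failure mode reproduces the exact wrong answer:

```python
# w is stored row-major as [[4, 2, 3], [2, 1, 2]] -> flat buffer [4, 2, 3, 2, 1, 2].
# w.t() is a view over the SAME buffer; if a kernel reinterprets that buffer as a
# contiguous [3, 2] matrix, it sees [[4, 2], [3, 2], [1, 2]] instead of
# [[4, 2], [2, 1], [3, 2]].
flat_w = [4, 2, 3, 2, 1, 2]  # w's row-major storage
K = 2
misread = [flat_w[i:i + K] for i in range(0, len(flat_w), K)]
print(misread)  # [[4, 2], [3, 2], [1, 2]]

a = [2, 3]
c = [sum(x * y for x, y in zip(a, row)) for row in misread]
print(c)  # [14, 12, 8] -- exactly the unexpected BitBLAS output above
```

If that guess is right, materializing the transpose with w.t().contiguous() before calling transform_weight should give the same result as wt.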
Is it a known issue? Or am I missing something?
Thanks in advance!