
[BUG] topi.sparse.csrmv only accepts float32, but not other data types #8406

Description

@learning-chip

Problem description

topi.sparse.csrmv has "float32" hard-coded in both the ir_builder allocation and the te.extern call, so it only accepts float32 inputs and fails for float64 and other data types:

```python
with irb.for_range(0, num_rows, kind="parallel", name="row") as row:
    dot = irb.allocate("float32", (1,), name="dot", scope="local")
    out_ptr[row] = 0.0
```

```python
matmul = te.extern(
    oshape,
    [data, indices, indptr, weight],
    lambda ins, outs: csrmv_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
    tag="csrmv",
    dtype="float32",
    name="csrmv",
)
```

Same problem for topi.sparse.csrmm.
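In TVM terms, the fix presumably means taking the element type from the input (something like `data.dtype`) for both the `irb.allocate` call and the `dtype=` argument of `te.extern`; the bare `0.0` literals probably also need an explicit type, e.g. `tvm.tir.const(0, data.dtype)`. That is my reading of the issue, not a merged patch, and the same change would apply to csrmm. Because TVM may not be at hand, the stdlib-only sketch below mirrors the loop nest of `csrmv_default_ir` in plain Python, with `array` typecodes (`'f'` for float32, `'d'` for float64) standing in for TVM dtypes. It illustrates the principle only and is not the TVM implementation.

```python
from array import array


def csrmv(data, indices, indptr, x):
    """Sparse matrix-vector product y = A @ x, with A in CSR form.

    Mirrors the loop nest of csrmv_default_ir: an outer loop over rows and
    an inner loop over that row's nonzeros. The element type of the
    accumulator and of the output comes from data.typecode instead of a
    hard-coded 'f' (float32).
    """
    if data.typecode != x.typecode:
        # Reject mixed precisions, analogous to TVM's
        # "Cannot match type float64 vs float32".
        raise TypeError(f"Cannot match type {data.typecode} vs {x.typecode}")
    dtype = data.typecode
    num_rows = len(indptr) - 1
    out = array(dtype, [0.0] * num_rows)
    for row in range(num_rows):
        # One-element accumulator, like irb.allocate(dtype, (1,), ...).
        dot = array(dtype, [0.0])
        for elem in range(indptr[row], indptr[row + 1]):
            dot[0] += data[elem] * x[indices[elem]]
        out[row] = dot[0]
    return out


# The same 2 x 3 matrix [[0.1, 0, 0.2], [0, 0.3, 0]] in both precisions.
indices = array("i", [0, 2, 1])
indptr = array("i", [0, 2, 3])
for code in ("f", "d"):
    vals = array(code, [0.1, 0.2, 0.3])
    y = csrmv(vals, indices, indptr, array(code, [1.0, 1.0, 1.0]))
    print(code, y.typecode, list(y))
```

With a hard-coded `'f'` accumulator, float64 inputs would be silently rounded to float32; TVM instead refuses to mix the two widths, which is the error reported here.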

Steps to reproduce

Build TVM 0.8dev from the latest master branch, and then run:

```python
# extracted from tests/python/topi/python/test_topi_sparse.py
from tvm import te
from tvm import topi
import tvm.contrib.sparse as tvmsp

dtype = "float64"  # "float32" works fine
nr, nc = (3, 5)
nnz = 6

A = tvmsp.placeholder(shape=(nr, nc), nonzeros=nnz, dtype=dtype, name="A")
B = te.placeholder((nc, 1), name="B")
out = topi.sparse.csrmv(A, B)  # TVMError: Cannot match type float64 vs float32
```
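As the traceback shows, `csrmv` unpacks the sparse placeholder into `a.data`, `a.indices` and `a.indptr`, the three arrays of the CSR format. The stdlib-only sketch below builds them for a 3 x 5 matrix with 6 nonzeros, matching `nr`, `nc` and `nnz` above; `dense_to_csr` is a hypothetical helper, not part of `tvm.contrib.sparse`, and `array` typecodes stand in for TVM dtypes. Only `data` carries the floating-point type, while `indices` and `indptr` are integer arrays, which makes `data` the natural source for the output dtype.

```python
from array import array


def dense_to_csr(dense, typecode):
    """Encode a dense row-major matrix (list of lists) in CSR form.

    Returns (data, indices, indptr):
      data    -- the nonzero values, stored with the requested typecode
      indices -- the column index of each value in data
      indptr  -- row r's values are data[indptr[r]:indptr[r + 1]]
    For an nr x nc matrix with nnz nonzeros, data and indices hold nnz
    entries and indptr holds nr + 1 entries.
    """
    data = array(typecode)
    indices = array("i")
    indptr = array("i", [0])
    for row in dense:
        for col, value in enumerate(row):
            if value != 0:
                data.append(value)
                indices.append(col)
        indptr.append(len(data))
    return data, indices, indptr


# A 3 x 5 matrix with 6 nonzeros, matching nr, nc and nnz in the reproducer.
dense = [
    [1.0, 0.0, 2.0, 0.0, 0.0],
    [0.0, 3.0, 0.0, 4.0, 0.0],
    [5.0, 0.0, 0.0, 0.0, 6.0],
]
data, indices, indptr = dense_to_csr(dense, "d")
print(len(data), len(indices), len(indptr))  # 6 6 4
print(data.typecode, indices.typecode)  # d i
```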

Full error message:

---------------------------------------------------------------------------
TVMError                                  Traceback (most recent call last)
<ipython-input-1-6daa8fd2fb08> in <module>
      9 A = tvmsp.placeholder(shape=(nr, nc), nonzeros=nnz, dtype=dtype, name="A")
     10 B = te.placeholder((nc, 1), name="B")
---> 11 out = topi.sparse.csrmv(A, B)  # TVMError: Cannot match type float64 vs float32

/tvm_install/tvm/python/tvm/topi/sparse/csrmv.py in csrmv(a, x, y)
    111         2-D dense matrix with shape [m, 1]
    112     """
--> 113     return csrmv_default(a.data, a.indices, a.indptr, x, y)

/tvm_install/tvm/python/tvm/topi/sparse/csrmv.py in csrmv_default(data, indices, indptr, weight, bias)
     78
     79     oshape = (batch, 1)
---> 80     matmul = te.extern(
     81         oshape,
     82         [data, indices, indptr, weight],

/tvm_install/tvm/python/tvm/te/operation.py in extern(shape, inputs, fcompute, name, dtype, in_buffers, out_buffers, tag, attrs)
    315         for shp, dt in zip(shape, dtype):
    316             output_placeholders.append(tvm.tir.decl_buffer(shp, dt, name))
--> 317     body = fcompute(input_placeholders, output_placeholders)
    318     if isinstance(body, tvm.tir.PrimExpr):
    319         body = tvm.tir.Evaluate(body)

/tvm_install/tvm/python/tvm/topi/sparse/csrmv.py in <lambda>(ins, outs)
     81         oshape,
     82         [data, indices, indptr, weight],
---> 83         lambda ins, outs: csrmv_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
     84         tag="csrmv",
     85         dtype="float32",

/tvm_install/tvm/python/tvm/topi/sparse/csrmv.py in csrmv_default_ir(data, indices, indptr, weight, out)
     73             with irb.for_range(0, row_elems, name="elemidx") as elemidx:
     74                 elem = row_start + elemidx
---> 75                 dot[0] += data_ptr[elem] * weight_ptr[indices_ptr[elem]]
     76             out_ptr[row] += dot[0]
     77         return irb.get()

/tvm_install/tvm/python/tvm/tir/expr.py in __mul__(self, other)
     75
     76     def __mul__(self, other):
---> 77         return _generic.multiply(self, other)
     78
     79     def __rmul__(self, other):

/tvm_install/tvm/python/tvm/topi/generic_op_impl.py in _tensor_bop_impl(lhs, rhs)
     81         """
     82         if not isinstance(lhs, te.tensor.Tensor) and not isinstance(rhs, te.tensor.Tensor):
---> 83             return orig_bop(lhs, rhs)
     84         return broadcast_bop(lhs, rhs)
     85

/tvm_install/tvm/python/tvm/tir/generic.py in multiply(lhs, rhs, span)
     84         The result Expr of multiply operaton.
     85     """
---> 86     return _ffi_api._OpMul(lhs, rhs, span)
     87
     88

/tvm_install/tvm/python/tvm/_ffi/_ctypes/packed_func.py in __call__(self, *args)
    235             != 0
    236         ):
--> 237             raise get_last_ffi_error()
    238         _ = temp_args
    239         _ = args

TVMError: Traceback (most recent call last):
  3: TVMFuncCall
  2: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::PrimExpr (tvm::PrimExpr, tvm::PrimExpr, tvm::Span)>::AssignTypedLambda<tvm::{lambda(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)#5}>(tvm::{lambda(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)#5}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  1: tvm::mul(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)
  0: tvm::BinaryOpMatchTypes(tvm::PrimExpr&, tvm::PrimExpr&, tvm::Span)
  File "/tvm_install/tvm/src/tir/op/op.cc", line 144
TVMError: Cannot match type float64 vs float32

Desired fix

  • topi.sparse.{csrmv, csrmm} should work for any data type, taking the dtype from their inputs instead of hard-coding float32.
  • Add unit tests to tests/python/topi/python/test_topi_sparse.py that cover multiple data types.
