diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index ebe91b2504a6..26aadd4ff881 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1273,84 +1273,107 @@ bool CodeGenLLVM::HasAlignmentPadding(DataType dtype) { return bytes != bytes_scalar * dtype.lanes(); } -llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { - ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers."; - - DataType t = op->dtype; - DataType buffer_element_dtype = op->buffer->dtype; - Var buffer_var = op->buffer->data; - PrimExpr buffer_index = op->indices[0]; +void CodeGenLLVM::BufferAccessHelper( + Buffer buffer, PrimExpr index, DataType value_dtype, + std::function + make_instruction) { + DataType buffer_element_dtype = buffer->dtype; + + ICHECK_EQ(value_dtype.lanes(), index.dtype().lanes() * buffer_element_dtype.lanes()); + + bool is_volatile = volatile_buf_.count(buffer->data.get()); + + // If the buffer index is a contiguous ramp node, we only need to + // access the first element, then cast to the value type. + if (const RampNode* ramp_index = index.as()) { + if (ramp_index && is_one(ramp_index->stride)) { + index = ramp_index->base; + } + } - bool is_volatile = volatile_buf_.count(buffer_var.get()); + // All TVM arrays are densely packed. If the vectorized LLVM type + // contains padding for alignment, we need to index based on the + // size of the scalar type to avoid introducing that padding. + if (index.dtype().lanes() == 1 && HasAlignmentPadding(buffer_element_dtype)) { + index = buffer_element_dtype.lanes() * index; + buffer_element_dtype = buffer_element_dtype.element_of(); + } - if (t.lanes() == buffer_element_dtype.lanes()) { - int alignment, native_bits; - GetAlignment(t, buffer_var.get(), buffer_index, &alignment, &native_bits); + int alignment; + if (index.dtype().lanes() == 1) { + // If we are accessing with a single index, then the vectorized + // element being accessed may require more alignment than the + // underlying data type. + int native_bits; + GetAlignment(value_dtype, buffer->data.get(), index, &alignment, &native_bits); + } else { + // Otherwise, alignment is based on the return value's scalar + // type. + ICHECK_GE(value_dtype.bits(), 8); + alignment = value_dtype.bits() / 8; + } - TypedPointer buffer_ptr; - if (HasAlignmentPadding(buffer_element_dtype)) { - buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype.element_of(), - MakeValue(buffer_element_dtype.lanes() * buffer_index), t); + llvm::Value* cached_vector_index = nullptr; + for (int i = 0; i < index.dtype().lanes(); ++i) { + llvm::Value* index_value; + int subelement_i = i; + if (const RampNode* ramp = index.as()) { + PrimExpr offset = ramp->base + (ramp->stride * i); + index_value = MakeValue(offset); + } else if (index.dtype().lanes() > 1) { + if (i == 0) { + cached_vector_index = MakeValue(index); + } + index_value = builder_->CreateExtractElement(cached_vector_index, i); } else { - buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype, - MakeValue(buffer_index), t); + index_value = MakeValue(index); + subelement_i = -1; } + TypedPointer buffer_ptr = + CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, index_value, + value_dtype.with_lanes(value_dtype.lanes() / index.dtype().lanes())); + auto instruction = make_instruction(buffer_ptr, subelement_i, alignment, is_volatile); + AddAliasInfo(instruction, buffer->data.get(), index); + } +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { + ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers."; + + DataType value_dtype = op->dtype; + PrimExpr index = op->indices[0]; + + std::vector loads; + + auto make_load = [this, &loads](TypedPointer buffer_ptr, int /* subelement_i */, int alignment, + bool is_volatile) { #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, - llvm::Align(alignment), is_volatile); + auto load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, + llvm::Align(alignment), is_volatile); #elif TVM_LLVM_VERSION >= 80 - llvm::LoadInst* load = + auto load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile); #else - llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile); + auto load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile); #endif - AddAliasInfo(load, buffer_var.get(), buffer_index); + + loads.push_back(load); return load; + }; + + BufferAccessHelper(op->buffer, index, value_dtype, make_load); + + if (loads.size() == 1) { + return loads[0]; } else { - // vector load - if (const RampNode* ramp = buffer_index.as()) { - if (is_one(ramp->stride)) { - int alignment, native_bits; - GetAlignment(t, buffer_var.get(), ramp->base, &alignment, &native_bits); - ICHECK_EQ(ramp->lanes * buffer_element_dtype.lanes(), t.lanes()); - // The index argument is element-based, to create buffer pointer for t's element type. - TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), op->buffer->dtype, - MakeValue(ramp->base), t); -#if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, - llvm::Align(alignment), is_volatile); -#elif TVM_LLVM_VERSION >= 80 - llvm::LoadInst* load = - builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile); -#else - llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile); -#endif - AddAliasInfo(load, buffer_var.get(), buffer_index); - return load; - } + llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(value_dtype)); + for (size_t i = 0; i < loads.size(); i++) { + ret = builder_->CreateInsertElement(ret, loads[i], ConstInt32(i)); } + return ret; } - // scalarized load. - int basic_align = t.bits() / 8; - llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(t)); - auto f = [&](int i, llvm::Value* index) { - TypedPointer buffer_ptr = - CreateBufferPtr(MakeValue(op->buffer->data), op->buffer->dtype, index, t.element_of()); -#if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, - llvm::Align(basic_align), is_volatile); -#elif TVM_LLVM_VERSION >= 80 - llvm::LoadInst* load = - builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, basic_align, is_volatile); -#else - llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, basic_align, is_volatile); -#endif - ret = builder_->CreateInsertElement(ret, load, ConstInt32(i)); - AddAliasInfo(load, buffer_var.get(), PrimExpr()); - }; - this->Scalarize(buffer_index, f); - return ret; } llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { @@ -1421,80 +1444,26 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers."; DataType value_dtype = op->value.dtype(); - DataType buffer_element_dtype = op->buffer->dtype; Var buffer_var = op->buffer->data; PrimExpr buffer_index = op->indices[0]; - bool is_volatile = volatile_buf_.count(buffer_var.get()); - llvm::Value* buffer = MakeValue(buffer_var); llvm::Value* value = MakeValue(op->value); - if (value_dtype.lanes() == buffer_element_dtype.lanes()) { - int alignment, native_bits; - GetAlignment(value_dtype, buffer_var.get(), buffer_index, &alignment, &native_bits); - - TypedPointer buffer_ptr; - if (HasAlignmentPadding(buffer_element_dtype)) { - buffer_ptr = - CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype.element_of(), - MakeValue(buffer_element_dtype.lanes() * buffer_index), value_dtype); - } else { - buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype, - MakeValue(buffer_index), value_dtype); + auto make_store = [this, value](TypedPointer buffer_ptr, int subelement_i, int alignment, + bool is_volatile) { + llvm::Value* to_store = value; + if (subelement_i != -1) { + to_store = builder_->CreateExtractElement(value, subelement_i); } #if TVM_LLVM_VERSION >= 110 - llvm::StoreInst* store = - builder_->CreateAlignedStore(value, buffer_ptr.addr, llvm::Align(alignment), is_volatile); + return builder_->CreateAlignedStore(to_store, buffer_ptr.addr, llvm::Align(alignment), + is_volatile); #else - llvm::StoreInst* store = - builder_->CreateAlignedStore(value, buffer_ptr.addr, alignment, is_volatile); + return builder_->CreateAlignedStore(to_store, buffer_ptr.addr, alignment, is_volatile); #endif - AddAliasInfo(store, buffer_var.get(), buffer_index); - return; - } else { - // vector store - if (const RampNode* ramp = buffer_index.as()) { - if (is_one(ramp->stride)) { - int alignment, native_bits; - GetAlignment(value_dtype, buffer_var.get(), ramp->base, &alignment, &native_bits); - ICHECK_EQ(ramp->lanes * buffer_element_dtype.lanes(), value_dtype.lanes()); - // The index argument is element-based, to create buffer pointer for t's element type. - TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype, - MakeValue(ramp->base), value_dtype); - unsigned addrspace = - llvm::dyn_cast(buffer->getType())->getAddressSpace(); - buffer_ptr.type = DTypeToLLVMType(value_dtype); - buffer_ptr.addr = - builder_->CreatePointerCast(buffer_ptr.addr, buffer_ptr.type->getPointerTo(addrspace)); -#if TVM_LLVM_VERSION >= 110 - llvm::StoreInst* store = builder_->CreateAlignedStore(value, buffer_ptr.addr, - llvm::Align(alignment), is_volatile); -#else - llvm::StoreInst* store = - builder_->CreateAlignedStore(value, buffer_ptr.addr, alignment, is_volatile); -#endif - AddAliasInfo(store, buffer_var.get(), buffer_index); - return; - } - } - } - ICHECK_GE(value_dtype.bits(), 8); - // scalarized store. - int basic_align = value_dtype.bits() / 8; - auto f = [&](int i, llvm::Value* index) { - TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype, - index, value_dtype.element_of()); -#if TVM_LLVM_VERSION >= 110 - llvm::StoreInst* store = - builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i), buffer_ptr.addr, - llvm::Align(basic_align), is_volatile); -#else - llvm::StoreInst* store = builder_->CreateAlignedStore( - builder_->CreateExtractElement(value, i), buffer_ptr.addr, basic_align, is_volatile); -#endif - AddAliasInfo(store, buffer_var.get(), PrimExpr()); }; - this->Scalarize(buffer_index, f); + + BufferAccessHelper(op->buffer, buffer_index, value_dtype, make_store); } void CodeGenLLVM::VisitStmt_(const ForNode* op) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index e8cbe7ae445f..3ec0881d5251 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -259,7 +259,37 @@ class CodeGenLLVM : public ExprFunctor, virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder); // Scalarize by iterating elements of e. // f is a callback that takes index and v. - virtual void Scalarize(const PrimExpr& e, std::function f); + void Scalarize(const PrimExpr& e, std::function f); + + /* \brief Helper function for handling buffer access + * + * \param buffer The buffer being accessed + * + * \param index The index at which the buffer is being accessed. + * + * \param value_dtype The datatype to be read from (BufferLoad) or + * written to (BufferStore) the buffer. + * + * \param make_instruction A callback function that generates that + * actual call. + * + * - buffer_ptr: A typed pointer to the element being accessed + * + * - subelement_i: The index of a vectorized type to be + * stored/loaded. If -1, indicates that the entire type, + * vector or scalar, should be written. + * + * - alignment: The alignment to be used for the read/write. + * + * - is_volatile: Whether the read/write should be volatile. + * + * - Should return the generated expression. + */ + void BufferAccessHelper( + Buffer buffer, PrimExpr index, DataType value_dtype, + std::function + make_instruction); // Initialize target virtual void InitTarget(llvm::TargetMachine* tm); // Add module startup function if needed.