From 6fd7da6df98aed5cfdf3ebab501f21ebcbb5cc96 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 8 Mar 2022 15:12:38 -0600 Subject: [PATCH 1/3] [Refactor] Reduced repetition in CodeGenLLVM's buffer access Previously, the majority of the BufferLoad and BufferStore visitors were duplicate logic to handle the indexing. After this commit, the shared logic is extracted out into a helper function. --- src/target/llvm/codegen_llvm.cc | 223 ++++++++++++++------------------ src/target/llvm/codegen_llvm.h | 32 ++++- 2 files changed, 128 insertions(+), 127 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index ebe91b2504a6..ab741cc7a8a6 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1273,84 +1273,109 @@ bool CodeGenLLVM::HasAlignmentPadding(DataType dtype) { return bytes != bytes_scalar * dtype.lanes(); } -llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { - ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers."; +void CodeGenLLVM::BufferAccessHelper( + Buffer buffer, PrimExpr index, DataType value_dtype, + std::function + make_instruction) { + DataType buffer_element_dtype = buffer->dtype; + + ICHECK_EQ(value_dtype.lanes(), index.dtype().lanes() * buffer_element_dtype.lanes()); + + bool is_volatile = volatile_buf_.count(buffer->data.get()); + + // If the buffer index is a contiguous ramp node, we only need to + // access the first element, then cast to the value type. + if (const RampNode* ramp_index = index.as()) { + if (ramp_index && is_one(ramp_index->stride)) { + index = ramp_index->base; + } + } - DataType t = op->dtype; - DataType buffer_element_dtype = op->buffer->dtype; - Var buffer_var = op->buffer->data; - PrimExpr buffer_index = op->indices[0]; + // All TVM arrays are densely packed. If the vectorized LLVM type + // contains padding for alignment, we need to index based on the + // size of the scalar type to avoid introducing that padding. + if (index.dtype().lanes() == 1 && HasAlignmentPadding(buffer_element_dtype)) { + index = buffer_element_dtype.lanes() * index; + buffer_element_dtype = buffer_element_dtype.element_of(); + } - bool is_volatile = volatile_buf_.count(buffer_var.get()); + int alignment; + if (index.dtype().lanes() == 1) { + // If we are accessing with a single index, then the vectorized + // element being accessed may require more alignment than the + // underlying data type. + int native_bits; + GetAlignment(value_dtype, buffer->data.get(), index, &alignment, &native_bits); + } else { + // Otherwise, alignment is based on the return value's scalar + // type. + ICHECK_GE(value_dtype.bits(), 8); + alignment = value_dtype.bits() / 8; + } - if (t.lanes() == buffer_element_dtype.lanes()) { - int alignment, native_bits; - GetAlignment(t, buffer_var.get(), buffer_index, &alignment, &native_bits); + std::vector instructions; - TypedPointer buffer_ptr; - if (HasAlignmentPadding(buffer_element_dtype)) { - buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype.element_of(), - MakeValue(buffer_element_dtype.lanes() * buffer_index), t); + llvm::Value* cached_vector_index = nullptr; + for (int i = 0; i < index.dtype().lanes(); ++i) { + llvm::Value* index_value; + int subelement_i = i; + if (const RampNode* ramp = index.as()) { + PrimExpr offset = ramp->base + (ramp->stride * i); + index_value = MakeValue(offset); + } else if (index.dtype().lanes() > 1) { + if (i == 0) { + cached_vector_index = MakeValue(index); + } + index_value = builder_->CreateExtractElement(cached_vector_index, i); } else { - buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype, - MakeValue(buffer_index), t); + index_value = MakeValue(index); + subelement_i = -1; } + TypedPointer buffer_ptr = + CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, index_value, + value_dtype.with_lanes(value_dtype.lanes() / index.dtype().lanes())); + auto instruction = make_instruction(buffer_ptr, subelement_i, alignment, is_volatile); + AddAliasInfo(instruction, buffer->data.get(), index); + } +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { + ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers."; + + DataType value_dtype = op->dtype; + PrimExpr index = op->indices[0]; + + std::vector loads; + + auto make_load = [this, &loads](TypedPointer buffer_ptr, int /* subelement_i */, int alignment, + bool is_volatile) { #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, - llvm::Align(alignment), is_volatile); + auto load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, + llvm::Align(alignment), is_volatile); #elif TVM_LLVM_VERSION >= 80 - llvm::LoadInst* load = + auto load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile); #else - llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile); + auto load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile); #endif - AddAliasInfo(load, buffer_var.get(), buffer_index); + + loads.push_back(load); return load; + }; + + BufferAccessHelper(op->buffer, index, value_dtype, make_load); + + if (loads.size() == 1) { + return loads[0]; } else { - // vector load - if (const RampNode* ramp = buffer_index.as()) { - if (is_one(ramp->stride)) { - int alignment, native_bits; - GetAlignment(t, buffer_var.get(), ramp->base, &alignment, &native_bits); - ICHECK_EQ(ramp->lanes * buffer_element_dtype.lanes(), t.lanes()); - // The index argument is element-based, to create buffer pointer for t's element type. - TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), op->buffer->dtype, - MakeValue(ramp->base), t); -#if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, - llvm::Align(alignment), is_volatile); -#elif TVM_LLVM_VERSION >= 80 - llvm::LoadInst* load = - builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile); -#else - llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile); -#endif - AddAliasInfo(load, buffer_var.get(), buffer_index); - return load; - } + llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(value_dtype)); + for (size_t i = 0; i < loads.size(); i++) { + ret = builder_->CreateInsertElement(ret, loads[i], ConstInt32(i)); } + return ret; } - // scalarized load. - int basic_align = t.bits() / 8; - llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(t)); - auto f = [&](int i, llvm::Value* index) { - TypedPointer buffer_ptr = - CreateBufferPtr(MakeValue(op->buffer->data), op->buffer->dtype, index, t.element_of()); -#if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, - llvm::Align(basic_align), is_volatile); -#elif TVM_LLVM_VERSION >= 80 - llvm::LoadInst* load = - builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, basic_align, is_volatile); -#else - llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, basic_align, is_volatile); -#endif - ret = builder_->CreateInsertElement(ret, load, ConstInt32(i)); - AddAliasInfo(load, buffer_var.get(), PrimExpr()); - }; - this->Scalarize(buffer_index, f); - return ret; } llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { @@ -1421,80 +1446,26 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers."; DataType value_dtype = op->value.dtype(); - DataType buffer_element_dtype = op->buffer->dtype; Var buffer_var = op->buffer->data; PrimExpr buffer_index = op->indices[0]; - bool is_volatile = volatile_buf_.count(buffer_var.get()); - llvm::Value* buffer = MakeValue(buffer_var); llvm::Value* value = MakeValue(op->value); - if (value_dtype.lanes() == buffer_element_dtype.lanes()) { - int alignment, native_bits; - GetAlignment(value_dtype, buffer_var.get(), buffer_index, &alignment, &native_bits); - - TypedPointer buffer_ptr; - if (HasAlignmentPadding(buffer_element_dtype)) { - buffer_ptr = - CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype.element_of(), - MakeValue(buffer_element_dtype.lanes() * buffer_index), value_dtype); - } else { - buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype, - MakeValue(buffer_index), value_dtype); + auto make_store = [this, value](TypedPointer buffer_ptr, int subelement_i, int alignment, + bool is_volatile) { + llvm::Value* to_store = value; + if (subelement_i != -1) { + to_store = builder_->CreateExtractElement(value, subelement_i); } #if TVM_LLVM_VERSION >= 110 - llvm::StoreInst* store = - builder_->CreateAlignedStore(value, buffer_ptr.addr, llvm::Align(alignment), is_volatile); + return builder_->CreateAlignedStore(to_store, buffer_ptr.addr, llvm::Align(alignment), + is_volatile); #else - llvm::StoreInst* store = - builder_->CreateAlignedStore(value, buffer_ptr.addr, alignment, is_volatile); + return builder_->CreateAlignedStore(to_store, buffer_ptr.addr, alignment, is_volatile); #endif - AddAliasInfo(store, buffer_var.get(), buffer_index); - return; - } else { - // vector store - if (const RampNode* ramp = buffer_index.as()) { - if (is_one(ramp->stride)) { - int alignment, native_bits; - GetAlignment(value_dtype, buffer_var.get(), ramp->base, &alignment, &native_bits); - ICHECK_EQ(ramp->lanes * buffer_element_dtype.lanes(), value_dtype.lanes()); - // The index argument is element-based, to create buffer pointer for t's element type. - TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype, - MakeValue(ramp->base), value_dtype); - unsigned addrspace = - llvm::dyn_cast(buffer->getType())->getAddressSpace(); - buffer_ptr.type = DTypeToLLVMType(value_dtype); - buffer_ptr.addr = - builder_->CreatePointerCast(buffer_ptr.addr, buffer_ptr.type->getPointerTo(addrspace)); -#if TVM_LLVM_VERSION >= 110 - llvm::StoreInst* store = builder_->CreateAlignedStore(value, buffer_ptr.addr, - llvm::Align(alignment), is_volatile); -#else - llvm::StoreInst* store = - builder_->CreateAlignedStore(value, buffer_ptr.addr, alignment, is_volatile); -#endif - AddAliasInfo(store, buffer_var.get(), buffer_index); - return; - } - } - } - ICHECK_GE(value_dtype.bits(), 8); - // scalarized store. - int basic_align = value_dtype.bits() / 8; - auto f = [&](int i, llvm::Value* index) { - TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype, - index, value_dtype.element_of()); -#if TVM_LLVM_VERSION >= 110 - llvm::StoreInst* store = - builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i), buffer_ptr.addr, - llvm::Align(basic_align), is_volatile); -#else - llvm::StoreInst* store = builder_->CreateAlignedStore( - builder_->CreateExtractElement(value, i), buffer_ptr.addr, basic_align, is_volatile); -#endif - AddAliasInfo(store, buffer_var.get(), PrimExpr()); }; - this->Scalarize(buffer_index, f); + + BufferAccessHelper(op->buffer, buffer_index, value_dtype, make_store); } void CodeGenLLVM::VisitStmt_(const ForNode* op) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index e8cbe7ae445f..3ec0881d5251 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -259,7 +259,37 @@ class CodeGenLLVM : public ExprFunctor, virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder); // Scalarize by iterating elements of e. // f is a callback that takes index and v. - virtual void Scalarize(const PrimExpr& e, std::function f); + void Scalarize(const PrimExpr& e, std::function f); + + /* \brief Helper function for handling buffer access + * + * \param buffer The buffer being accessed + * + * \param index The index at which the buffer is being accessed. + * + * \param value_dtype The datatype to be read from (BufferLoad) or + * written to (BufferStore) the buffer. + * + * \param make_instruction A callback function that generates that + * actual call. + * + * - buffer_ptr: A typed pointer to the element being accessed + * + * - subelement_i: The index of a vectorized type to be + * stored/loaded. If -1, indicates that the entire type, + * vector or scalar, should be written. + * + * - alignment: The alignment to be used for the read/write. + * + * - is_volatile: Whether the read/write should be volatile. + * + * - Should return the generated expression. + */ + void BufferAccessHelper( + Buffer buffer, PrimExpr index, DataType value_dtype, + std::function + make_instruction); // Initialize target virtual void InitTarget(llvm::TargetMachine* tm); // Add module startup function if needed. From 879bd4facc1d5d8a27297aeb49a1c6c25fed2538 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 10 Mar 2022 10:36:12 -0600 Subject: [PATCH 2/3] Fixup, remove declaration of unused variable. --- src/target/llvm/codegen_llvm.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index ab741cc7a8a6..26aadd4ff881 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1314,8 +1314,6 @@ void CodeGenLLVM::BufferAccessHelper( alignment = value_dtype.bits() / 8; } - std::vector instructions; - llvm::Value* cached_vector_index = nullptr; for (int i = 0; i < index.dtype().lanes(); ++i) { llvm::Value* index_value; From 09dba71202ee10cb2c2af875fdde659430ceb8e8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 10 Mar 2022 16:21:58 -0600 Subject: [PATCH 3/3] Bump to CI