Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 96 additions & 127 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1273,84 +1273,107 @@ bool CodeGenLLVM::HasAlignmentPadding(DataType dtype) {
return bytes != bytes_scalar * dtype.lanes();
}

llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers.";

DataType t = op->dtype;
DataType buffer_element_dtype = op->buffer->dtype;
Var buffer_var = op->buffer->data;
PrimExpr buffer_index = op->indices[0];
void CodeGenLLVM::BufferAccessHelper(
Buffer buffer, PrimExpr index, DataType value_dtype,
std::function<llvm::Instruction*(TypedPointer buffer_ptr, int subelement_i, int alignment,
bool is_volatile)>
make_instruction) {
DataType buffer_element_dtype = buffer->dtype;

ICHECK_EQ(value_dtype.lanes(), index.dtype().lanes() * buffer_element_dtype.lanes());

bool is_volatile = volatile_buf_.count(buffer->data.get());

// If the buffer index is a contiguous ramp node, we only need to
// access the first element, then cast to the value type.
if (const RampNode* ramp_index = index.as<RampNode>()) {
if (ramp_index && is_one(ramp_index->stride)) {
index = ramp_index->base;
}
}

bool is_volatile = volatile_buf_.count(buffer_var.get());
// All TVM arrays are densely packed. If the vectorized LLVM type
// contains padding for alignment, we need to index based on the
// size of the scalar type to avoid introducing that padding.
if (index.dtype().lanes() == 1 && HasAlignmentPadding(buffer_element_dtype)) {
index = buffer_element_dtype.lanes() * index;
buffer_element_dtype = buffer_element_dtype.element_of();
}

if (t.lanes() == buffer_element_dtype.lanes()) {
int alignment, native_bits;
GetAlignment(t, buffer_var.get(), buffer_index, &alignment, &native_bits);
int alignment;
if (index.dtype().lanes() == 1) {
// If we are accessing with a single index, then the vectorized
// element being accessed may require more alignment than the
// underlying data type.
int native_bits;
GetAlignment(value_dtype, buffer->data.get(), index, &alignment, &native_bits);
} else {
// Otherwise, alignment is based on the return value's scalar
// type.
ICHECK_GE(value_dtype.bits(), 8);
alignment = value_dtype.bits() / 8;
}

TypedPointer buffer_ptr;
if (HasAlignmentPadding(buffer_element_dtype)) {
buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype.element_of(),
MakeValue(buffer_element_dtype.lanes() * buffer_index), t);
llvm::Value* cached_vector_index = nullptr;
for (int i = 0; i < index.dtype().lanes(); ++i) {
llvm::Value* index_value;
int subelement_i = i;
if (const RampNode* ramp = index.as<RampNode>()) {
PrimExpr offset = ramp->base + (ramp->stride * i);
index_value = MakeValue(offset);
} else if (index.dtype().lanes() > 1) {
if (i == 0) {
cached_vector_index = MakeValue(index);
}
index_value = builder_->CreateExtractElement(cached_vector_index, i);
} else {
buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype,
MakeValue(buffer_index), t);
index_value = MakeValue(index);
subelement_i = -1;
}

TypedPointer buffer_ptr =
CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, index_value,
value_dtype.with_lanes(value_dtype.lanes() / index.dtype().lanes()));
auto instruction = make_instruction(buffer_ptr, subelement_i, alignment, is_volatile);
AddAliasInfo(instruction, buffer->data.get(), index);
}
}

llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers.";

DataType value_dtype = op->dtype;
PrimExpr index = op->indices[0];

std::vector<llvm::Value*> loads;

auto make_load = [this, &loads](TypedPointer buffer_ptr, int /* subelement_i */, int alignment,
bool is_volatile) {
#if TVM_LLVM_VERSION >= 110
llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr,
llvm::Align(alignment), is_volatile);
auto load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr,
llvm::Align(alignment), is_volatile);
#elif TVM_LLVM_VERSION >= 80
llvm::LoadInst* load =
auto load =
builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile);
#else
llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile);
auto load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile);
#endif
AddAliasInfo(load, buffer_var.get(), buffer_index);

loads.push_back(load);
return load;
};

BufferAccessHelper(op->buffer, index, value_dtype, make_load);

if (loads.size() == 1) {
return loads[0];
} else {
// vector load
if (const RampNode* ramp = buffer_index.as<RampNode>()) {
if (is_one(ramp->stride)) {
int alignment, native_bits;
GetAlignment(t, buffer_var.get(), ramp->base, &alignment, &native_bits);
ICHECK_EQ(ramp->lanes * buffer_element_dtype.lanes(), t.lanes());
// The index argument is element-based, to create buffer pointer for t's element type.
TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), op->buffer->dtype,
MakeValue(ramp->base), t);
#if TVM_LLVM_VERSION >= 110
llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr,
llvm::Align(alignment), is_volatile);
#elif TVM_LLVM_VERSION >= 80
llvm::LoadInst* load =
builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile);
#else
llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile);
#endif
AddAliasInfo(load, buffer_var.get(), buffer_index);
return load;
}
llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(value_dtype));
for (size_t i = 0; i < loads.size(); i++) {
ret = builder_->CreateInsertElement(ret, loads[i], ConstInt32(i));
}
return ret;
}
// scalarized load.
int basic_align = t.bits() / 8;
llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(t));
auto f = [&](int i, llvm::Value* index) {
TypedPointer buffer_ptr =
CreateBufferPtr(MakeValue(op->buffer->data), op->buffer->dtype, index, t.element_of());
#if TVM_LLVM_VERSION >= 110
llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr,
llvm::Align(basic_align), is_volatile);
#elif TVM_LLVM_VERSION >= 80
llvm::LoadInst* load =
builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, basic_align, is_volatile);
#else
llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, basic_align, is_volatile);
#endif
ret = builder_->CreateInsertElement(ret, load, ConstInt32(i));
AddAliasInfo(load, buffer_var.get(), PrimExpr());
};
this->Scalarize(buffer_index, f);
return ret;
}

llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
Expand Down Expand Up @@ -1421,80 +1444,26 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {
ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers.";

DataType value_dtype = op->value.dtype();
DataType buffer_element_dtype = op->buffer->dtype;
Var buffer_var = op->buffer->data;
PrimExpr buffer_index = op->indices[0];

bool is_volatile = volatile_buf_.count(buffer_var.get());
llvm::Value* buffer = MakeValue(buffer_var);
llvm::Value* value = MakeValue(op->value);

if (value_dtype.lanes() == buffer_element_dtype.lanes()) {
int alignment, native_bits;
GetAlignment(value_dtype, buffer_var.get(), buffer_index, &alignment, &native_bits);

TypedPointer buffer_ptr;
if (HasAlignmentPadding(buffer_element_dtype)) {
buffer_ptr =
CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype.element_of(),
MakeValue(buffer_element_dtype.lanes() * buffer_index), value_dtype);
} else {
buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype,
MakeValue(buffer_index), value_dtype);
auto make_store = [this, value](TypedPointer buffer_ptr, int subelement_i, int alignment,
bool is_volatile) {
llvm::Value* to_store = value;
if (subelement_i != -1) {
to_store = builder_->CreateExtractElement(value, subelement_i);
}
#if TVM_LLVM_VERSION >= 110
llvm::StoreInst* store =
builder_->CreateAlignedStore(value, buffer_ptr.addr, llvm::Align(alignment), is_volatile);
return builder_->CreateAlignedStore(to_store, buffer_ptr.addr, llvm::Align(alignment),
is_volatile);
#else
llvm::StoreInst* store =
builder_->CreateAlignedStore(value, buffer_ptr.addr, alignment, is_volatile);
return builder_->CreateAlignedStore(to_store, buffer_ptr.addr, alignment, is_volatile);
#endif
AddAliasInfo(store, buffer_var.get(), buffer_index);
return;
} else {
// vector store
if (const RampNode* ramp = buffer_index.as<RampNode>()) {
if (is_one(ramp->stride)) {
int alignment, native_bits;
GetAlignment(value_dtype, buffer_var.get(), ramp->base, &alignment, &native_bits);
ICHECK_EQ(ramp->lanes * buffer_element_dtype.lanes(), value_dtype.lanes());
// The index argument is element-based, to create buffer pointer for t's element type.
TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype,
MakeValue(ramp->base), value_dtype);
unsigned addrspace =
llvm::dyn_cast<llvm::PointerType>(buffer->getType())->getAddressSpace();
buffer_ptr.type = DTypeToLLVMType(value_dtype);
buffer_ptr.addr =
builder_->CreatePointerCast(buffer_ptr.addr, buffer_ptr.type->getPointerTo(addrspace));
#if TVM_LLVM_VERSION >= 110
llvm::StoreInst* store = builder_->CreateAlignedStore(value, buffer_ptr.addr,
llvm::Align(alignment), is_volatile);
#else
llvm::StoreInst* store =
builder_->CreateAlignedStore(value, buffer_ptr.addr, alignment, is_volatile);
#endif
AddAliasInfo(store, buffer_var.get(), buffer_index);
return;
}
}
}
ICHECK_GE(value_dtype.bits(), 8);
// scalarized store.
int basic_align = value_dtype.bits() / 8;
auto f = [&](int i, llvm::Value* index) {
TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(op->buffer->data), buffer_element_dtype,
index, value_dtype.element_of());
#if TVM_LLVM_VERSION >= 110
llvm::StoreInst* store =
builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i), buffer_ptr.addr,
llvm::Align(basic_align), is_volatile);
#else
llvm::StoreInst* store = builder_->CreateAlignedStore(
builder_->CreateExtractElement(value, i), buffer_ptr.addr, basic_align, is_volatile);
#endif
AddAliasInfo(store, buffer_var.get(), PrimExpr());
};
this->Scalarize(buffer_index, f);

BufferAccessHelper(op->buffer, buffer_index, value_dtype, make_store);
}

void CodeGenLLVM::VisitStmt_(const ForNode* op) {
Expand Down
32 changes: 31 additions & 1 deletion src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,37 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder);
// Scalarize by iterating elements of e.
// f is a callback that takes index and v.
virtual void Scalarize(const PrimExpr& e, std::function<void(int i, llvm::Value* v)> f);
void Scalarize(const PrimExpr& e, std::function<void(int i, llvm::Value* v)> f);

/*! \brief Helper function for handling buffer access
*
* Shared by the BufferLoad and BufferStore code paths: the helper
* computes the pointer, alignment and volatility for the access, and
* delegates creation of the actual load/store to make_instruction.
*
* \param buffer The buffer being accessed
*
* \param index The index at which the buffer is being accessed.
* May be a scalar or a vectorized index.
*
* \param value_dtype The datatype to be read from (BufferLoad) or
* written to (BufferStore) the buffer.
*
* \param make_instruction A callback function that generates the
* actual load/store instruction. It may be invoked more than
* once for a single access when the index is vectorized.
*
* - buffer_ptr: A typed pointer to the element being accessed
*
* - subelement_i: The index of the lane within a vectorized
* value to be stored/loaded. If -1, indicates that the entire
* value, vector or scalar, should be read/written.
*
* - alignment: The alignment to be used for the read/write.
*
* - is_volatile: Whether the read/write should be volatile.
*
* - Should return the generated instruction.
*/
void BufferAccessHelper(
Buffer buffer, PrimExpr index, DataType value_dtype,
std::function<llvm::Instruction*(TypedPointer buffer_ptr, int subelement_i, int alignment,
bool is_volatile)>
make_instruction);
// Initialize target
virtual void InitTarget(llvm::TargetMachine* tm);
// Add module startup function if needed.
Expand Down