Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 171 additions & 0 deletions src/runtime/vulkan/vulkan_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "vulkan_context.h"

#include <algorithm>
#include <unordered_map>

#include "vulkan_common.h"
Expand All @@ -29,6 +30,176 @@ namespace tvm {
namespace runtime {
namespace vulkan {

/*!
 * \brief Query a physical device and cache its capabilities and limits.
 *
 * Performs the Vulkan 1.0 property query unconditionally, then, if the
 * instance extension VK_KHR_get_physical_device_properties2 is listed,
 * re-queries through the "2" entry points so that the chained
 * extension structures (driver, subgroup, 8/16-bit storage,
 * float16/int8) are filled in as well.  Structures that are never
 * queried stay zero-initialized, so the corresponding capabilities
 * are reported as unsupported.
 *
 * \param instance The Vulkan instance, used to look up the
 * vkGetPhysicalDevice*2KHR entry points.
 *
 * \param phy_dev The physical device to query.
 *
 * \param instance_extensions Names of the instance extensions to
 * treat as available (presumably the enabled set; confirm against
 * the caller).
 *
 * \param device_extensions Names of the device extensions to treat as
 * available (presumably the enabled set; confirm against the caller).
 */
VulkanDeviceProperties::VulkanDeviceProperties(VkInstance instance, VkPhysicalDevice phy_dev,
                                               const std::vector<const char*> instance_extensions,
                                               const std::vector<const char*> device_extensions) {
  // Returns true if `query` matches one of the names in `extensions`.
  auto contains = [](const std::vector<const char*>& extensions, const char* query) {
    return std::any_of(extensions.begin(), extensions.end(),
                       [&](const char* extension) { return std::strcmp(query, extension) == 0; });
  };

  auto has_instance_extension = [&](const char* query) {
    return contains(instance_extensions, query);
  };

  auto has_device_extension = [&](const char* query) {
    return contains(device_extensions, query);
  };

  // The chained structures below are only filled in when the
  // "properties2" entry points are usable.
  const bool can_query_properties2 =
      has_instance_extension("VK_KHR_get_physical_device_properties2");
  const bool has_driver_properties = has_device_extension("VK_KHR_driver_properties");

  ///////////////////////////////////////////////////////////////
  //           Query properties from Vulkan API                //
  ///////////////////////////////////////////////////////////////

  // Declare output locations for properties
  VkPhysicalDeviceProperties2 properties = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2};
  VkPhysicalDeviceDriverProperties driver = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES};
  VkPhysicalDeviceSubgroupProperties subgroup = {
      VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES};

  // Need to do initial query in order to check the apiVersion.
  vkGetPhysicalDeviceProperties(phy_dev, &properties.properties);

  // Set up linked list for property query
  {
    void** pp_next = &properties.pNext;
    if (has_driver_properties) {
      *pp_next = &driver;
      pp_next = &driver.pNext;
    }
    if (properties.properties.apiVersion >= VK_API_VERSION_1_1) {
      *pp_next = &subgroup;
      pp_next = &subgroup.pNext;
    }
  }

  // Declare output locations for features
  VkPhysicalDeviceFeatures2 features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2};
  VkPhysicalDevice8BitStorageFeatures storage_8bit = {
      VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES};
  VkPhysicalDevice16BitStorageFeatures storage_16bit = {
      VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES};
  VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = {
      VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES};

  // Set up linked list for feature query
  {
    void** pp_next = &features.pNext;
    if (has_device_extension("VK_KHR_8bit_storage")) {
      *pp_next = &storage_8bit;
      pp_next = &storage_8bit.pNext;
    }
    if (has_device_extension("VK_KHR_16bit_storage")) {
      *pp_next = &storage_16bit;
      pp_next = &storage_16bit.pNext;
    }
    if (has_device_extension("VK_KHR_shader_float16_int8")) {
      *pp_next = &float16_int8;
      pp_next = &float16_int8.pNext;
    }
  }

  if (can_query_properties2) {
    // Preferred method, call to get all properties that can be queried.
    auto vkGetPhysicalDeviceProperties2KHR =
        reinterpret_cast<PFN_vkGetPhysicalDeviceProperties2KHR>(
            ICHECK_NOTNULL(vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceProperties2KHR")));
    vkGetPhysicalDeviceProperties2KHR(phy_dev, &properties);

    auto vkGetPhysicalDeviceFeatures2KHR = reinterpret_cast<PFN_vkGetPhysicalDeviceFeatures2KHR>(
        ICHECK_NOTNULL(vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceFeatures2KHR")));
    vkGetPhysicalDeviceFeatures2KHR(phy_dev, &features);
  } else {
    // Fallback, get as many features as we can from the Vulkan1.0
    // API.  Corresponding vkGetPhysicalDeviceProperties was already done earlier.
    vkGetPhysicalDeviceFeatures(phy_dev, &features.features);
  }

  ///////////////////////////////////////////////////////////////
  //     Fill member variables from Vulkan structures          //
  ///////////////////////////////////////////////////////////////

  supports_float16 = float16_int8.shaderFloat16;
  supports_float32 = true;
  supports_float64 = features.features.shaderFloat64;
  supports_int8 = float16_int8.shaderInt8;
  supports_int16 = features.features.shaderInt16;
  supports_int32 = true;
  supports_int64 = features.features.shaderInt64;
  supports_8bit_buffer = storage_8bit.storageBuffer8BitAccess;
  supports_16bit_buffer = storage_16bit.storageBuffer16BitAccess;
  supports_storage_buffer_storage_class =
      has_device_extension("VK_KHR_storage_buffer_storage_class");

  // Support is available based on these extensions, but allow it to
  // be disabled based on an environment variable.  Any non-empty
  // value of the variable disables the feature.
  supports_push_descriptor = has_device_extension("VK_KHR_push_descriptor") &&
                             has_device_extension("VK_KHR_descriptor_update_template");
  {
    const char* disable = std::getenv("TVM_VULKAN_DISABLE_PUSH_DESCRIPTOR");
    if (disable && *disable) {
      supports_push_descriptor = false;
    }
  }

  // Support is available based on these extensions, but allow it to
  // be disabled based on an environment variable.  Any non-empty
  // value of the variable disables the feature.
  supports_dedicated_allocation = has_device_extension("VK_KHR_get_memory_requirements2") &&
                                  has_device_extension("VK_KHR_dedicated_allocation");
  {
    const char* disable = std::getenv("TVM_VULKAN_DISABLE_DEDICATED_ALLOCATION");
    if (disable && *disable) {
      supports_dedicated_allocation = false;
    }
  }

  // The check of VK_SHADER_STAGE_COMPUTE_BIT isn't technically
  // needed, since it will be set so long as at least one queue has
  // VK_QUEUE_COMPUTE_BIT.  Including it to avoid potential future
  // confusion.
  supported_subgroup_operations =
      (subgroup.supportedStages & VK_SHADER_STAGE_COMPUTE_BIT) ? subgroup.supportedOperations : 0;

  max_num_threads = properties.properties.limits.maxComputeWorkGroupInvocations;

  // Even if we can't query it, warp size must be at least 1.
  thread_warp_size = std::max(subgroup.subgroupSize, 1U);

  max_block_size_x = properties.properties.limits.maxComputeWorkGroupSize[0];
  max_block_size_y = properties.properties.limits.maxComputeWorkGroupSize[1];
  max_block_size_z = properties.properties.limits.maxComputeWorkGroupSize[2];
  max_push_constants_size = properties.properties.limits.maxPushConstantsSize;
  max_uniform_buffer_range = properties.properties.limits.maxUniformBufferRange;
  max_storage_buffer_range = properties.properties.limits.maxStorageBufferRange;
  max_per_stage_descriptor_storage_buffer =
      properties.properties.limits.maxPerStageDescriptorStorageBuffers;
  max_shared_memory_per_block = properties.properties.limits.maxComputeSharedMemorySize;
  device_name = properties.properties.deviceName;
  driver_version = properties.properties.driverVersion;

  // By default, use the maximum API version that the driver allows,
  // so that any supported features can be used by TVM shaders.
  // However, if we can query the conformance version, then limit to
  // only using the api version that passes the vulkan conformance
  // tests.
  //
  // The `driver` struct is only populated by the properties2 query.
  // On the Vulkan 1.0 fallback path it is still zero-initialized, and
  // clamping against it would overwrite the API version with 0.
  vulkan_api_version = properties.properties.apiVersion;
  if (can_query_properties2 && has_driver_properties) {
    auto api_major = VK_VERSION_MAJOR(vulkan_api_version);
    auto api_minor = VK_VERSION_MINOR(vulkan_api_version);
    if ((api_major > driver.conformanceVersion.major) ||
        ((api_major == driver.conformanceVersion.major) &&
         (api_minor > driver.conformanceVersion.minor))) {
      // A driver may report a conformance version below 1.0 (e.g. all
      // zeros).  No Vulkan API version lower than 1.0 exists, so never
      // clamp below it.
      vulkan_api_version = std::max<uint32_t>(
          VK_MAKE_VERSION(driver.conformanceVersion.major, driver.conformanceVersion.minor, 0),
          VK_API_VERSION_1_0);
    }
  }

  // From "Versions and Formats" section of Vulkan spec.
  max_spirv_version = 0x10000;
  if (vulkan_api_version >= VK_API_VERSION_1_2) {
    max_spirv_version = 0x10500;
  } else if (has_device_extension("VK_KHR_spirv_1_4")) {
    max_spirv_version = 0x10400;
  } else if (vulkan_api_version >= VK_API_VERSION_1_1) {
    max_spirv_version = 0x10300;
  }
}

VulkanDescriptorTemplateKHRFunctions::VulkanDescriptorTemplateKHRFunctions(VkDevice device) {
vkCreateDescriptorUpdateTemplateKHR = (PFN_vkCreateDescriptorUpdateTemplateKHR)ICHECK_NOTNULL(
vkGetDeviceProcAddr(device, "vkCreateDescriptorUpdateTemplateKHR"));
Expand Down
50 changes: 48 additions & 2 deletions src/runtime/vulkan/vulkan_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include <tvm/target/target.h>

#include <memory>
#include <string>
#include <vector>

#include "vulkan/vulkan_core.h"
#include "vulkan_buffer.h"
Expand All @@ -47,14 +49,58 @@ struct VulkanGetBufferMemoryRequirements2Functions {
PFN_vkGetBufferMemoryRequirements2KHR vkGetBufferMemoryRequirements2KHR{nullptr};
};

/*!
 * \brief Capabilities and limits of a Vulkan physical device.
 *
 * Each member corresponds 1-1 to a Target parameter when
 * target->kind->device_type==kDLVulkan.  The values live in their
 * own struct so that the Vulkan runtime (libtvm_runtime.so) stays
 * decoupled from the Target object (libtvm.so).
 *
 * A default-constructed instance holds the fallback values given by
 * the member initializers below.
 */
struct VulkanDeviceProperties {
  /*! \brief Construct with the fallback value for every member. */
  VulkanDeviceProperties() {}

  /*!
   * \brief Construct by querying a physical device through the Vulkan API.
   *
   * \param instance The Vulkan instance used for the query.
   * \param phy_device The physical device to query.
   * \param instance_extensions Names of the instance extensions to consult.
   * \param device_extensions Names of the device extensions to consult.
   */
  VulkanDeviceProperties(VkInstance instance, VkPhysicalDevice phy_device,
                         const std::vector<const char*> instance_extensions,
                         const std::vector<const char*> device_extensions);

  // ---- Data types usable in shaders ----
  // float32 and int32 are always reported as supported.
  bool supports_float16 = false;
  bool supports_float32 = true;
  bool supports_float64 = false;
  bool supports_int8 = false;
  bool supports_int16 = false;
  bool supports_int32 = true;
  bool supports_int64 = false;

  // ---- Storage buffer access ----
  bool supports_8bit_buffer = false;
  bool supports_16bit_buffer = false;
  bool supports_storage_buffer_storage_class = false;

  // ---- Optional runtime features ----
  bool supports_push_descriptor = false;
  bool supports_dedicated_allocation = false;

  // ---- Subgroup information ----
  // Bitmask of supported subgroup operations; 0 if none are known.
  uint32_t supported_subgroup_operations = 0;

  // ---- Compute limits ----
  uint32_t max_num_threads = 1;
  uint32_t thread_warp_size = 1;
  uint32_t max_block_size_x = 1;
  uint32_t max_block_size_y = 1;
  uint32_t max_block_size_z = 1;
  uint32_t max_push_constants_size = 128;
  uint32_t max_uniform_buffer_range = 16384;
  uint32_t max_storage_buffer_range = 1 << 27;
  uint32_t max_per_stage_descriptor_storage_buffer = 4;
  uint32_t max_shared_memory_per_block = 16384;

  // ---- Device / driver identification ----
  std::string device_name = "unknown device name";
  uint32_t driver_version = 0;

  // ---- Versions ----
  // Vulkan API version, encoded as by VK_MAKE_VERSION.
  uint32_t vulkan_api_version = VK_API_VERSION_1_0;
  // Highest SPIR-V version that may be used (0x10000 is SPIR-V 1.0).
  uint32_t max_spirv_version = 0x10000;
};

struct VulkanContext {
// physical device
VkPhysicalDevice phy_device{nullptr};

// Cached device properties, queried through Vulkan API.
VulkanDeviceProperties device_properties;

// Physical device property
VkPhysicalDeviceProperties phy_device_prop;
// Target that best represents this physical device
Target target;
// Memory type index for staging.
uint32_t staging_mtype_index{0};
// whether staging is coherent
Expand Down
Loading