//===- vulkan-runtime-wrappers.cpp - MLIR Vulkan runner wrapper library ---===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Implements C runtime wrappers around the VulkanRuntime. // //===----------------------------------------------------------------------===// #include #include #include #include "VulkanRuntime.h" // Explicitly export entry points to the vulkan-runtime-wrapper. #define VULKAN_WRAPPER_SYMBOL_EXPORT __attribute__((visibility("default"))) namespace { class VulkanRuntimeManager { public: VulkanRuntimeManager() = default; VulkanRuntimeManager(const VulkanRuntimeManager &) = delete; VulkanRuntimeManager operator=(const VulkanRuntimeManager &) = delete; ~VulkanRuntimeManager() = default; void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex, const VulkanHostMemoryBuffer &memBuffer) { std::lock_guard lock(mutex); vulkanRuntime.setResourceData(setIndex, bindIndex, memBuffer); } void setEntryPoint(const char *entryPoint) { std::lock_guard lock(mutex); vulkanRuntime.setEntryPoint(entryPoint); } void setNumWorkGroups(NumWorkGroups numWorkGroups) { std::lock_guard lock(mutex); vulkanRuntime.setNumWorkGroups(numWorkGroups); } void setShaderModule(uint8_t *shader, uint32_t size) { std::lock_guard lock(mutex); vulkanRuntime.setShaderModule(shader, size); } void runOnVulkan() { std::lock_guard lock(mutex); if (failed(vulkanRuntime.initRuntime()) || failed(vulkanRuntime.run()) || failed(vulkanRuntime.updateHostMemoryBuffers()) || failed(vulkanRuntime.destroy())) { std::cerr << "runOnVulkan failed"; } } private: VulkanRuntime vulkanRuntime; std::mutex mutex; }; } // namespace template struct MemRefDescriptor { T *allocated; T *aligned; int64_t offset; int64_t sizes[N]; int64_t strides[N]; }; template void bindMemRef(void *vkRuntimeManager, DescriptorSetIndex setIndex, BindingIndex bindIndex, MemRefDescriptor *ptr) { uint32_t size = sizeof(T); for (unsigned i = 0; i < S; i++) size *= ptr->sizes[i]; VulkanHostMemoryBuffer memBuffer{ptr->allocated, size}; reinterpret_cast(vkRuntimeManager) ->setResourceData(setIndex, bindIndex, memBuffer); } extern "C" { /// Initializes `VulkanRuntimeManager` and returns a pointer to it. VULKAN_WRAPPER_SYMBOL_EXPORT void *initVulkan() { return new VulkanRuntimeManager(); } /// Deinitializes `VulkanRuntimeManager` by the given pointer. VULKAN_WRAPPER_SYMBOL_EXPORT void deinitVulkan(void *vkRuntimeManager) { delete reinterpret_cast(vkRuntimeManager); } VULKAN_WRAPPER_SYMBOL_EXPORT void runOnVulkan(void *vkRuntimeManager) { reinterpret_cast(vkRuntimeManager)->runOnVulkan(); } VULKAN_WRAPPER_SYMBOL_EXPORT void setEntryPoint(void *vkRuntimeManager, const char *entryPoint) { reinterpret_cast(vkRuntimeManager) ->setEntryPoint(entryPoint); } VULKAN_WRAPPER_SYMBOL_EXPORT void setNumWorkGroups(void *vkRuntimeManager, uint32_t x, uint32_t y, uint32_t z) { reinterpret_cast(vkRuntimeManager) ->setNumWorkGroups({x, y, z}); } VULKAN_WRAPPER_SYMBOL_EXPORT void setBinaryShader(void *vkRuntimeManager, uint8_t *shader, uint32_t size) { reinterpret_cast(vkRuntimeManager) ->setShaderModule(shader, size); } /// Binds the given memref to the given descriptor set and descriptor /// index. #define DECLARE_BIND_MEMREF(size, type, typeName) \ VULKAN_WRAPPER_SYMBOL_EXPORT void bindMemRef##size##D##typeName( \ void *vkRuntimeManager, DescriptorSetIndex setIndex, \ BindingIndex bindIndex, MemRefDescriptor *ptr) { \ bindMemRef(vkRuntimeManager, setIndex, bindIndex, ptr); \ } DECLARE_BIND_MEMREF(1, float, Float) DECLARE_BIND_MEMREF(2, float, Float) DECLARE_BIND_MEMREF(3, float, Float) DECLARE_BIND_MEMREF(1, int32_t, Int32) DECLARE_BIND_MEMREF(2, int32_t, Int32) DECLARE_BIND_MEMREF(3, int32_t, Int32) DECLARE_BIND_MEMREF(1, int16_t, Int16) DECLARE_BIND_MEMREF(2, int16_t, Int16) DECLARE_BIND_MEMREF(3, int16_t, Int16) DECLARE_BIND_MEMREF(1, int8_t, Int8) DECLARE_BIND_MEMREF(2, int8_t, Int8) DECLARE_BIND_MEMREF(3, int8_t, Int8) DECLARE_BIND_MEMREF(1, int16_t, Half) DECLARE_BIND_MEMREF(2, int16_t, Half) DECLARE_BIND_MEMREF(3, int16_t, Half) /// Fills the given 1D float memref with the given float value. VULKAN_WRAPPER_SYMBOL_EXPORT void _mlir_ciface_fillResource1DFloat(MemRefDescriptor *ptr, // NOLINT float value) { std::fill_n(ptr->allocated, ptr->sizes[0], value); } /// Fills the given 2D float memref with the given float value. VULKAN_WRAPPER_SYMBOL_EXPORT void _mlir_ciface_fillResource2DFloat(MemRefDescriptor *ptr, // NOLINT float value) { std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value); } /// Fills the given 3D float memref with the given float value. VULKAN_WRAPPER_SYMBOL_EXPORT void _mlir_ciface_fillResource3DFloat(MemRefDescriptor *ptr, // NOLINT float value) { std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2], value); } /// Fills the given 1D int memref with the given int value. VULKAN_WRAPPER_SYMBOL_EXPORT void _mlir_ciface_fillResource1DInt(MemRefDescriptor *ptr, // NOLINT int32_t value) { std::fill_n(ptr->allocated, ptr->sizes[0], value); } /// Fills the given 2D int memref with the given int value. VULKAN_WRAPPER_SYMBOL_EXPORT void _mlir_ciface_fillResource2DInt(MemRefDescriptor *ptr, // NOLINT int32_t value) { std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value); } /// Fills the given 3D int memref with the given int value. VULKAN_WRAPPER_SYMBOL_EXPORT void _mlir_ciface_fillResource3DInt(MemRefDescriptor *ptr, // NOLINT int32_t value) { std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2], value); } /// Fills the given 1D int memref with the given int8 value. VULKAN_WRAPPER_SYMBOL_EXPORT void _mlir_ciface_fillResource1DInt8(MemRefDescriptor *ptr, // NOLINT int8_t value) { std::fill_n(ptr->allocated, ptr->sizes[0], value); } /// Fills the given 2D int memref with the given int8 value. VULKAN_WRAPPER_SYMBOL_EXPORT void _mlir_ciface_fillResource2DInt8(MemRefDescriptor *ptr, // NOLINT int8_t value) { std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value); } /// Fills the given 3D int memref with the given int8 value. VULKAN_WRAPPER_SYMBOL_EXPORT void _mlir_ciface_fillResource3DInt8(MemRefDescriptor *ptr, // NOLINT int8_t value) { std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2], value); } }