You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
211 lines
7.4 KiB
211 lines
7.4 KiB
//===- vulkan-runtime-wrappers.cpp - MLIR Vulkan runner wrapper library ---===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Implements C runtime wrappers around the VulkanRuntime.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <mutex>
#include <numeric>

#include "VulkanRuntime.h"
|
|
|
|
// Explicitly export entry points to the vulkan-runtime-wrapper.
|
|
// GCC/Clang attribute that gives each annotated wrapper default (public) ELF
// visibility, so the entry points stay exported even if the library is built
// with -fvisibility=hidden.
#define VULKAN_WRAPPER_SYMBOL_EXPORT __attribute__((visibility("default")))
|
|
|
|
namespace {
|
|
|
|
class VulkanRuntimeManager {
|
|
public:
|
|
VulkanRuntimeManager() = default;
|
|
VulkanRuntimeManager(const VulkanRuntimeManager &) = delete;
|
|
VulkanRuntimeManager operator=(const VulkanRuntimeManager &) = delete;
|
|
~VulkanRuntimeManager() = default;
|
|
|
|
void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex,
|
|
const VulkanHostMemoryBuffer &memBuffer) {
|
|
std::lock_guard<std::mutex> lock(mutex);
|
|
vulkanRuntime.setResourceData(setIndex, bindIndex, memBuffer);
|
|
}
|
|
|
|
void setEntryPoint(const char *entryPoint) {
|
|
std::lock_guard<std::mutex> lock(mutex);
|
|
vulkanRuntime.setEntryPoint(entryPoint);
|
|
}
|
|
|
|
void setNumWorkGroups(NumWorkGroups numWorkGroups) {
|
|
std::lock_guard<std::mutex> lock(mutex);
|
|
vulkanRuntime.setNumWorkGroups(numWorkGroups);
|
|
}
|
|
|
|
void setShaderModule(uint8_t *shader, uint32_t size) {
|
|
std::lock_guard<std::mutex> lock(mutex);
|
|
vulkanRuntime.setShaderModule(shader, size);
|
|
}
|
|
|
|
void runOnVulkan() {
|
|
std::lock_guard<std::mutex> lock(mutex);
|
|
if (failed(vulkanRuntime.initRuntime()) || failed(vulkanRuntime.run()) ||
|
|
failed(vulkanRuntime.updateHostMemoryBuffers()) ||
|
|
failed(vulkanRuntime.destroy())) {
|
|
std::cerr << "runOnVulkan failed";
|
|
}
|
|
}
|
|
|
|
private:
|
|
VulkanRuntime vulkanRuntime;
|
|
std::mutex mutex;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Host-side view of a ranked memref descriptor with element type `ElemT` and
/// rank `Rank`.
///
/// The field order and types follow MLIR's memref descriptor layout, because
/// generated code passes these structs by pointer. Do not reorder or retype
/// the fields.
template <typename ElemT, int Rank> struct MemRefDescriptor {
  // Pointer to the underlying allocation.
  ElemT *allocated;
  // Aligned pointer. Per MLIR convention, element data starts at
  // `aligned + offset`.
  ElemT *aligned;
  // Offset, in elements, from `aligned` to the first element.
  int64_t offset;
  // Extent of each dimension.
  int64_t sizes[Rank];
  // Stride, in elements, of each dimension.
  int64_t strides[Rank];
};
|
|
|
|
template <typename T, uint32_t S>
|
|
void bindMemRef(void *vkRuntimeManager, DescriptorSetIndex setIndex,
|
|
BindingIndex bindIndex, MemRefDescriptor<T, S> *ptr) {
|
|
uint32_t size = sizeof(T);
|
|
for (unsigned i = 0; i < S; i++)
|
|
size *= ptr->sizes[i];
|
|
VulkanHostMemoryBuffer memBuffer{ptr->allocated, size};
|
|
reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
|
|
->setResourceData(setIndex, bindIndex, memBuffer);
|
|
}
|
|
|
|
extern "C" {
|
|
/// Initializes `VulkanRuntimeManager` and returns a pointer to it.
|
|
VULKAN_WRAPPER_SYMBOL_EXPORT void *initVulkan() {
|
|
return new VulkanRuntimeManager();
|
|
}
|
|
|
|
/// Deinitializes `VulkanRuntimeManager` by the given pointer.
|
|
VULKAN_WRAPPER_SYMBOL_EXPORT void deinitVulkan(void *vkRuntimeManager) {
|
|
delete reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager);
|
|
}
|
|
|
|
/// Runs the configured workload on the manager behind the opaque handle.
VULKAN_WRAPPER_SYMBOL_EXPORT void runOnVulkan(void *vkRuntimeManager) {
  auto *manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
  manager->runOnVulkan();
}
|
|
|
|
/// Forwards the entry-point name to the manager behind the opaque handle.
VULKAN_WRAPPER_SYMBOL_EXPORT void setEntryPoint(void *vkRuntimeManager,
                                                const char *entryPoint) {
  auto *manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
  manager->setEntryPoint(entryPoint);
}
|
|
|
|
/// Packs the (x, y, z) work-group counts into a `NumWorkGroups` and forwards
/// them to the manager behind the opaque handle.
VULKAN_WRAPPER_SYMBOL_EXPORT void
setNumWorkGroups(void *vkRuntimeManager, uint32_t x, uint32_t y, uint32_t z) {
  auto *manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
  const NumWorkGroups groups{x, y, z};
  manager->setNumWorkGroups(groups);
}
|
|
|
|
/// Forwards the shader binary `shader` of length `size` to the manager behind
/// the opaque handle. The length is presumably in bytes; verify against
/// `VulkanRuntime::setShaderModule`.
VULKAN_WRAPPER_SYMBOL_EXPORT void
setBinaryShader(void *vkRuntimeManager, uint8_t *shader, uint32_t size) {
  auto *manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
  manager->setShaderModule(shader, size);
}
|
|
|
|
/// Binds the given memref to the given descriptor set and descriptor
|
|
/// index.
|
|
#define DECLARE_BIND_MEMREF(size, type, typeName) \
|
|
VULKAN_WRAPPER_SYMBOL_EXPORT void bindMemRef##size##D##typeName( \
|
|
void *vkRuntimeManager, DescriptorSetIndex setIndex, \
|
|
BindingIndex bindIndex, MemRefDescriptor<type, size> *ptr) { \
|
|
bindMemRef<type, size>(vkRuntimeManager, setIndex, bindIndex, ptr); \
|
|
}
|
|
|
|
DECLARE_BIND_MEMREF(1, float, Float)
|
|
DECLARE_BIND_MEMREF(2, float, Float)
|
|
DECLARE_BIND_MEMREF(3, float, Float)
|
|
DECLARE_BIND_MEMREF(1, int32_t, Int32)
|
|
DECLARE_BIND_MEMREF(2, int32_t, Int32)
|
|
DECLARE_BIND_MEMREF(3, int32_t, Int32)
|
|
DECLARE_BIND_MEMREF(1, int16_t, Int16)
|
|
DECLARE_BIND_MEMREF(2, int16_t, Int16)
|
|
DECLARE_BIND_MEMREF(3, int16_t, Int16)
|
|
DECLARE_BIND_MEMREF(1, int8_t, Int8)
|
|
DECLARE_BIND_MEMREF(2, int8_t, Int8)
|
|
DECLARE_BIND_MEMREF(3, int8_t, Int8)
|
|
DECLARE_BIND_MEMREF(1, int16_t, Half)
|
|
DECLARE_BIND_MEMREF(2, int16_t, Half)
|
|
DECLARE_BIND_MEMREF(3, int16_t, Half)
|
|
|
|
/// Fills the given 1D float memref with the given float value.
|
|
VULKAN_WRAPPER_SYMBOL_EXPORT void
|
|
_mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT
|
|
float value) {
|
|
std::fill_n(ptr->allocated, ptr->sizes[0], value);
|
|
}
|
|
|
|
/// Fills the given 2D float memref with the given float value.
|
|
VULKAN_WRAPPER_SYMBOL_EXPORT void
|
|
_mlir_ciface_fillResource2DFloat(MemRefDescriptor<float, 2> *ptr, // NOLINT
|
|
float value) {
|
|
std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
|
|
}
|
|
|
|
/// Fills the given 3D float memref with the given float value.
|
|
VULKAN_WRAPPER_SYMBOL_EXPORT void
|
|
_mlir_ciface_fillResource3DFloat(MemRefDescriptor<float, 3> *ptr, // NOLINT
|
|
float value) {
|
|
std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
|
|
value);
|
|
}
|
|
|
|
/// Fills the given 1D int memref with the given int value.
|
|
VULKAN_WRAPPER_SYMBOL_EXPORT void
|
|
_mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t, 1> *ptr, // NOLINT
|
|
int32_t value) {
|
|
std::fill_n(ptr->allocated, ptr->sizes[0], value);
|
|
}
|
|
|
|
/// Fills the given 2D int memref with the given int value.
|
|
VULKAN_WRAPPER_SYMBOL_EXPORT void
|
|
_mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t, 2> *ptr, // NOLINT
|
|
int32_t value) {
|
|
std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
|
|
}
|
|
|
|
/// Fills the given 3D int memref with the given int value.
|
|
VULKAN_WRAPPER_SYMBOL_EXPORT void
|
|
_mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT
|
|
int32_t value) {
|
|
std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
|
|
value);
|
|
}
|
|
|
|
/// Fills the given 1D int memref with the given int8 value.
|
|
VULKAN_WRAPPER_SYMBOL_EXPORT void
|
|
_mlir_ciface_fillResource1DInt8(MemRefDescriptor<int8_t, 1> *ptr, // NOLINT
|
|
int8_t value) {
|
|
std::fill_n(ptr->allocated, ptr->sizes[0], value);
|
|
}
|
|
|
|
/// Fills the given 2D int memref with the given int8 value.
|
|
VULKAN_WRAPPER_SYMBOL_EXPORT void
|
|
_mlir_ciface_fillResource2DInt8(MemRefDescriptor<int8_t, 2> *ptr, // NOLINT
|
|
int8_t value) {
|
|
std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
|
|
}
|
|
|
|
/// Fills the given 3D int memref with the given int8 value.
|
|
VULKAN_WRAPPER_SYMBOL_EXPORT void
|
|
_mlir_ciface_fillResource3DInt8(MemRefDescriptor<int8_t, 3> *ptr, // NOLINT
|
|
int8_t value) {
|
|
std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
|
|
value);
|
|
}
|
|
}
|