DriverTrac/venv/lib/python3.12/site-packages/triton/tools/extra/hip/compile.cpp

67 lines
1.8 KiB
C++

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
/* clang-format off */
#include <stdio.h>
#include <stdint.h>
#include <inttypes.h>
#include <string.h>
#include <hip/hip_runtime.h>
// helpers to check for hip errors
// Evaluates a HIP call and hands its result to gpuAssert together with the
// call site (__FILE__ / __LINE__).
// Wrapped in do/while(0) so the expansion is exactly one statement and stays
// safe inside an unbraced if/else. The last line carries no continuation
// backslash, so the macro cannot absorb whatever line follows it.
#define HIP_CHECK(ans) do {{\
  gpuAssert((ans), __FILE__, __LINE__);\
}} while (0)
// Terminates the process when a HIP call did not succeed.
//
// code:       result of the HIP call being checked.
// file, line: call site supplied by HIP_CHECK; accepted for interface
//             compatibility but not included in the message.
//
// On failure prints "Triton Error [HIP]: <message>" to stdout and exits with
// the error code as the process status. Returns normally on hipSuccess.
static inline void gpuAssert(hipError_t code, const char *file, int line) {{
  (void)file;
  (void)line;
  if (code == hipSuccess)
    return;
  const char *str = nullptr;
  // The lookup can itself fail and leave str unset, so fall back to a fixed
  // text instead of reading an uninitialised pointer.
  if (hipDrvGetErrorString(code, &str) != hipSuccess || str == nullptr)
    str = "unknown error";
  // Print directly: no fixed-size scratch buffer, so no unbounded strcat.
  // NOTE(review): uses a single-backslash newline escape on the assumption
  // that the template text reaches the compiler without being unescaped; the
  // doubled backslash printed a literal backslash-n. Confirm with generator.
  printf("Triton Error [HIP]: %s\n", str);
  exit(code);
}}
// globals
// Identifier of the byte array that embeds the kernel binary.
#define HSACO_NAME {kernel_name}_hsaco
// Module handle; starts null and is written by hipModuleLoadData in the loader.
hipModule_t {kernel_name}_mod = nullptr;
// Function handle; starts null, and the launcher treats null as "not loaded yet".
hipFunction_t {kernel_name}_func = nullptr;
// Kernel binary image, passed to hipModuleLoadData by the loader.
unsigned char HSACO_NAME[{bin_size}] = {{ {bin_data} }};
// Unloads the kernel's module and clears the cached handles.
// Exits the process (via HIP_CHECK) if the unload call reports an error.
void unload_{kernel_name}(void) {{
  HIP_CHECK(hipModuleUnload({kernel_name}_mod));
  // Reset both handles so the launcher's null check reloads the module on the
  // next call instead of launching a function from an unloaded module.
  {kernel_name}_func = nullptr;
  {kernel_name}_mod = nullptr;
}}
// Loads the embedded kernel binary as a module and looks up the kernel
// function in it, storing both handles in the file-level globals.
// Exits the process (via HIP_CHECK) if either HIP call reports an error.
void load_{kernel_name}() {{
  // The unused locals (device index, shared-memory size) were dropped; the
  // binary image is passed straight to the loader.
  HIP_CHECK(hipModuleLoadData(&{kernel_name}_mod, (void *)&HSACO_NAME));
  HIP_CHECK(hipModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}"));
}}
/*
{kernel_docstring}
*/
// Launches the kernel on `stream`, loading the module on first use.
//
// Returns the status of hipModuleLaunchKernel, or hipErrorInvalidValue
// without launching when any grid dimension is zero.
hipError_t {kernel_name}(hipStream_t stream, {signature}) {{
  // Lazy load: the function handle stays null until the loader has run.
  if ({kernel_name}_func == nullptr) {{
    load_{kernel_name}();
  }}
  unsigned int gX = {gridX};
  unsigned int gY = {gridY};
  unsigned int gZ = {gridZ};
  // Declared here so the generated argument list can take their addresses.
  hipDeviceptr_t global_scratch = 0;
  hipDeviceptr_t profile_scratch = 0;
  void *args[{num_args}] = {{ {arg_pointers} }};
  // Check each dimension separately: the product of three unsigned ints can
  // wrap around to zero for a large non-empty grid and reject it wrongly.
  if (gX == 0 || gY == 0 || gZ == 0) {{
    return hipErrorInvalidValue;
  }}
  // TODO: shared memory
  // Block shape is (num_warps * warpSize, 1, 1).
  return hipModuleLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * warpSize, 1, 1, {shared}, stream, args, nullptr);
}}