402 lines
18 KiB
C
402 lines
18 KiB
C
/*
|
|
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
|
*
|
|
* NVIDIA CORPORATION and its licensors retain all intellectual property
|
|
* and proprietary rights in and to this software, related documentation
|
|
* and any modifications thereto. Any use, reproduction, disclosure or
|
|
* distribution of this software and related documentation without an express
|
|
* license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
*/
|
|
#if !defined(CUSPARSELT_HEADER_)
|
|
#define CUSPARSELT_HEADER_
|
|
|
|
#include "cusparse.h" // cusparseStatus_t
|
|
|
|
#include <cstddef> // size_t
|
|
#include <driver_types.h> // cudaStream_t
|
|
#include <library_types.h> // cudaDataType
|
|
#include <stdint.h> // uint8_t
|
|
|
|
//##############################################################################
|
|
//# CUSPARSELT VERSION INFORMATION
|
|
//##############################################################################
|
|
|
|
#define CUSPARSELT_VER_MAJOR 0
|
|
#define CUSPARSELT_VER_MINOR 7
|
|
#define CUSPARSELT_VER_PATCH 1
|
|
#define CUSPARSELT_VER_BUILD 0
|
|
#define CUSPARSELT_VERSION (CUSPARSELT_VER_MAJOR * 1000 + \
|
|
CUSPARSELT_VER_MINOR * 100 + \
|
|
CUSPARSELT_VER_PATCH)
|
|
|
|
// #############################################################################
|
|
// # MACRO
|
|
// #############################################################################
|
|
|
|
#if !defined(CUSPARSELT_API)
|
|
# if defined(_WIN32)
|
|
# define CUSPARSELT_API __stdcall
|
|
# else
|
|
# define CUSPARSELT_API
|
|
# endif
|
|
#endif
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
#if defined(__cplusplus)
|
|
extern "C" {
|
|
#endif // defined(__cplusplus)
|
|
|
|
//##############################################################################
|
|
//# OPAQUE DATA STRUCTURES
|
|
//##############################################################################
|
|
|
|
typedef struct { uint8_t data[1024]; } cusparseLtHandle_t;
|
|
|
|
typedef struct { uint8_t data[1024]; } cusparseLtMatDescriptor_t;
|
|
|
|
typedef struct { uint8_t data[1024]; } cusparseLtMatmulDescriptor_t;
|
|
|
|
typedef struct { uint8_t data[1024]; } cusparseLtMatmulAlgSelection_t;
|
|
|
|
typedef struct { uint8_t data[1024]; } cusparseLtMatmulPlan_t;
|
|
|
|
const char* CUSPARSELT_API
|
|
cusparseLtGetErrorName(cusparseStatus_t status);
|
|
|
|
const char* CUSPARSELT_API
|
|
cusparseLtGetErrorString(cusparseStatus_t status);
|
|
|
|
//##############################################################################
|
|
//# INITIALIZATION, DESTROY
|
|
//##############################################################################
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtInit(cusparseLtHandle_t* handle);
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtDestroy(const cusparseLtHandle_t* handle);
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtGetVersion(const cusparseLtHandle_t* handle,
|
|
int* version);
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtGetProperty(libraryPropertyType propertyType,
|
|
int* value);
|
|
|
|
//##############################################################################
|
|
//# MATRIX DESCRIPTOR
|
|
//##############################################################################
|
|
// Dense Matrix
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtDenseDescriptorInit(const cusparseLtHandle_t* handle,
|
|
cusparseLtMatDescriptor_t* matDescr,
|
|
int64_t rows,
|
|
int64_t cols,
|
|
int64_t ld,
|
|
uint32_t alignment,
|
|
cudaDataType valueType,
|
|
cusparseOrder_t order);
|
|
|
|
//------------------------------------------------------------------------------
|
|
// Structured Matrix
|
|
|
|
typedef enum {
|
|
CUSPARSELT_SPARSITY_50_PERCENT
|
|
} cusparseLtSparsity_t;
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtStructuredDescriptorInit(const cusparseLtHandle_t* handle,
|
|
cusparseLtMatDescriptor_t* matDescr,
|
|
int64_t rows,
|
|
int64_t cols,
|
|
int64_t ld,
|
|
uint32_t alignment,
|
|
cudaDataType valueType,
|
|
cusparseOrder_t order,
|
|
cusparseLtSparsity_t sparsity);
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtMatDescriptorDestroy(const cusparseLtMatDescriptor_t* matDescr);
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
typedef enum {
|
|
CUSPARSELT_MAT_NUM_BATCHES, // READ/WRITE
|
|
CUSPARSELT_MAT_BATCH_STRIDE // READ/WRITE
|
|
} cusparseLtMatDescAttribute_t;
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtMatDescSetAttribute(const cusparseLtHandle_t* handle,
|
|
cusparseLtMatDescriptor_t* matmulDescr,
|
|
cusparseLtMatDescAttribute_t matAttribute,
|
|
const void* data,
|
|
size_t dataSize);
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtMatDescGetAttribute(const cusparseLtHandle_t* handle,
|
|
const cusparseLtMatDescriptor_t* matmulDescr,
|
|
cusparseLtMatDescAttribute_t matAttribute,
|
|
void* data,
|
|
size_t dataSize);
|
|
|
|
//##############################################################################
|
|
//# MATMUL DESCRIPTOR
|
|
//##############################################################################
|
|
|
|
typedef enum {
|
|
CUSPARSE_COMPUTE_32I,
|
|
CUSPARSE_COMPUTE_16F,
|
|
CUSPARSE_COMPUTE_32F
|
|
} cusparseComputeType;
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtMatmulDescriptorInit(const cusparseLtHandle_t* handle,
|
|
cusparseLtMatmulDescriptor_t* matmulDescr,
|
|
cusparseOperation_t opA,
|
|
cusparseOperation_t opB,
|
|
const cusparseLtMatDescriptor_t* matA,
|
|
const cusparseLtMatDescriptor_t* matB,
|
|
const cusparseLtMatDescriptor_t* matC,
|
|
const cusparseLtMatDescriptor_t* matD,
|
|
cusparseComputeType computeType);
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
typedef enum {
|
|
CUSPARSELT_MATMUL_ACTIVATION_RELU, // READ/WRITE
|
|
CUSPARSELT_MATMUL_ACTIVATION_RELU_UPPERBOUND, // READ/WRITE
|
|
CUSPARSELT_MATMUL_ACTIVATION_RELU_THRESHOLD, // READ/WRITE
|
|
CUSPARSELT_MATMUL_ACTIVATION_GELU, // READ/WRITE
|
|
CUSPARSELT_MATMUL_ACTIVATION_GELU_SCALING, // READ/WRITE
|
|
CUSPARSELT_MATMUL_ALPHA_VECTOR_SCALING, // READ/WRITE
|
|
CUSPARSELT_MATMUL_BETA_VECTOR_SCALING, // READ/WRITE
|
|
CUSPARSELT_MATMUL_BIAS_STRIDE, // READ/WRITE
|
|
CUSPARSELT_MATMUL_BIAS_POINTER, // READ/WRITE
|
|
CUSPARSELT_MATMUL_SPARSE_MAT_POINTER, // READ/WRITE
|
|
//
|
|
CUSPARSELT_MATMUL_A_SCALE_MODE, // READ/WRITE
|
|
CUSPARSELT_MATMUL_B_SCALE_MODE, // READ/WRITE
|
|
CUSPARSELT_MATMUL_C_SCALE_MODE, // READ/WRITE
|
|
CUSPARSELT_MATMUL_D_SCALE_MODE, // READ/WRITE
|
|
CUSPARSELT_MATMUL_D_OUT_SCALE_MODE, // READ/WRITE
|
|
|
|
CUSPARSELT_MATMUL_A_SCALE_POINTER,
|
|
CUSPARSELT_MATMUL_B_SCALE_POINTER,
|
|
CUSPARSELT_MATMUL_C_SCALE_POINTER,
|
|
CUSPARSELT_MATMUL_D_SCALE_POINTER,
|
|
CUSPARSELT_MATMUL_D_OUT_SCALE_POINTER,
|
|
} cusparseLtMatmulDescAttribute_t;
|
|
|
|
typedef enum {
|
|
CUSPARSELT_MATMUL_SCALE_NONE,
|
|
CUSPARSELT_MATMUL_MATRIX_SCALE_SCALAR_32F,
|
|
CUSPARSELT_MATMUL_MATRIX_SCALE_VEC32_UE4M3,
|
|
CUSPARSELT_MATMUL_MATRIX_SCALE_VEC64_UE8M0,
|
|
} cusparseLtMatmulMatrixScale_t;
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtMatmulDescSetAttribute(const cusparseLtHandle_t* handle,
|
|
cusparseLtMatmulDescriptor_t* matmulDescr,
|
|
cusparseLtMatmulDescAttribute_t matmulAttribute,
|
|
const void* data,
|
|
size_t dataSize);
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtMatmulDescGetAttribute(
|
|
const cusparseLtHandle_t* handle,
|
|
const cusparseLtMatmulDescriptor_t* matmulDescr,
|
|
cusparseLtMatmulDescAttribute_t matmulAttribute,
|
|
void* data,
|
|
size_t dataSize);
|
|
|
|
//##############################################################################
|
|
//# ALGORITHM SELECTION
|
|
//##############################################################################
|
|
|
|
typedef enum {
|
|
CUSPARSELT_MATMUL_ALG_DEFAULT
|
|
} cusparseLtMatmulAlg_t;
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtMatmulAlgSelectionInit(
|
|
const cusparseLtHandle_t* handle,
|
|
cusparseLtMatmulAlgSelection_t* algSelection,
|
|
const cusparseLtMatmulDescriptor_t* matmulDescr,
|
|
cusparseLtMatmulAlg_t alg);
|
|
|
|
typedef enum {
|
|
CUSPARSELT_MATMUL_ALG_CONFIG_ID, // READ/WRITE
|
|
CUSPARSELT_MATMUL_ALG_CONFIG_MAX_ID, // READ-ONLY
|
|
CUSPARSELT_MATMUL_SEARCH_ITERATIONS, // READ/WRITE
|
|
CUSPARSELT_MATMUL_SPLIT_K, // READ/WRITE
|
|
CUSPARSELT_MATMUL_SPLIT_K_MODE, // READ/WRITE
|
|
CUSPARSELT_MATMUL_SPLIT_K_BUFFERS // READ/WRITE
|
|
} cusparseLtMatmulAlgAttribute_t;
|
|
|
|
typedef enum {
|
|
CUSPARSELT_INVALID_MODE = 0,
|
|
CUSPARSELT_SPLIT_K_MODE_ONE_KERNEL = 1,
|
|
CUSPARSELT_SPLIT_K_MODE_TWO_KERNELS = 2,
|
|
CUSPARSELT_HEURISTIC,
|
|
CUSPARSELT_DATAPARALLEL,
|
|
CUSPARSELT_SPLITK,
|
|
CUSPARSELT_STREAMK
|
|
} cusparseLtSplitKMode_t;
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtMatmulAlgSetAttribute(const cusparseLtHandle_t* handle,
|
|
cusparseLtMatmulAlgSelection_t* algSelection,
|
|
cusparseLtMatmulAlgAttribute_t attribute,
|
|
const void* data,
|
|
size_t dataSize);
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtMatmulAlgGetAttribute(
|
|
const cusparseLtHandle_t* handle,
|
|
const cusparseLtMatmulAlgSelection_t* algSelection,
|
|
cusparseLtMatmulAlgAttribute_t attribute,
|
|
void* data,
|
|
size_t dataSize);
|
|
|
|
//##############################################################################
|
|
//# MATMUL PLAN
|
|
//##############################################################################
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtMatmulGetWorkspace(
|
|
const cusparseLtHandle_t* handle,
|
|
const cusparseLtMatmulPlan_t* plan,
|
|
size_t* workspaceSize);
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtMatmulPlanInit(const cusparseLtHandle_t* handle,
|
|
cusparseLtMatmulPlan_t* plan,
|
|
const cusparseLtMatmulDescriptor_t* matmulDescr,
|
|
const cusparseLtMatmulAlgSelection_t* algSelection);
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtMatmulPlanDestroy(const cusparseLtMatmulPlan_t* plan);
|
|
|
|
//##############################################################################
|
|
//# MATMUL EXECUTION
|
|
//##############################################################################
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtMatmul(const cusparseLtHandle_t* handle,
|
|
const cusparseLtMatmulPlan_t* plan,
|
|
const void* alpha,
|
|
const void* d_A,
|
|
const void* d_B,
|
|
const void* beta,
|
|
const void* d_C,
|
|
void* d_D,
|
|
void* workspace,
|
|
cudaStream_t* streams,
|
|
int32_t numStreams);
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtMatmulSearch(const cusparseLtHandle_t* handle,
|
|
cusparseLtMatmulPlan_t* plan,
|
|
const void* alpha,
|
|
const void* d_A,
|
|
const void* d_B,
|
|
const void* beta,
|
|
const void* d_C,
|
|
void* d_D,
|
|
void* workspace,
|
|
// void* device_buf,
|
|
cudaStream_t* streams,
|
|
int32_t numStreams);
|
|
|
|
//##############################################################################
|
|
//# HELPER ROUTINES
|
|
//##############################################################################
|
|
// PRUNING
|
|
|
|
typedef enum {
|
|
CUSPARSELT_PRUNE_SPMMA_TILE = 0,
|
|
CUSPARSELT_PRUNE_SPMMA_STRIP = 1
|
|
} cusparseLtPruneAlg_t;
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtSpMMAPrune(const cusparseLtHandle_t* handle,
|
|
const cusparseLtMatmulDescriptor_t* matmulDescr,
|
|
const void* d_in,
|
|
void* d_out,
|
|
cusparseLtPruneAlg_t pruneAlg,
|
|
cudaStream_t stream);
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtSpMMAPruneCheck(const cusparseLtHandle_t* handle,
|
|
const cusparseLtMatmulDescriptor_t* matmulDescr,
|
|
const void* d_in,
|
|
int* valid,
|
|
cudaStream_t stream);
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtSpMMAPrune2(const cusparseLtHandle_t* handle,
|
|
const cusparseLtMatDescriptor_t* sparseMatDescr,
|
|
int isSparseA,
|
|
cusparseOperation_t op,
|
|
const void* d_in,
|
|
void* d_out,
|
|
cusparseLtPruneAlg_t pruneAlg,
|
|
cudaStream_t stream);
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtSpMMAPruneCheck2(const cusparseLtHandle_t* handle,
|
|
const cusparseLtMatDescriptor_t* sparseMatDescr,
|
|
int isSparseA,
|
|
cusparseOperation_t op,
|
|
const void* d_in,
|
|
int* d_valid,
|
|
cudaStream_t stream);
|
|
|
|
//------------------------------------------------------------------------------
|
|
// COMPRESSION
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtSpMMACompressedSize(
|
|
const cusparseLtHandle_t* handle,
|
|
const cusparseLtMatmulPlan_t* plan,
|
|
size_t* compressedSize,
|
|
size_t* compressedBufferSize);
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtSpMMACompress(const cusparseLtHandle_t* handle,
|
|
const cusparseLtMatmulPlan_t* plan,
|
|
const void* d_dense,
|
|
void* d_compressed,
|
|
void* d_compressed_buffer,
|
|
cudaStream_t stream);
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtSpMMACompressedSize2(
|
|
const cusparseLtHandle_t* handle,
|
|
const cusparseLtMatDescriptor_t* sparseMatDescr,
|
|
size_t* compressedSize,
|
|
size_t* compressedBufferSize);
|
|
|
|
cusparseStatus_t CUSPARSELT_API
|
|
cusparseLtSpMMACompress2(const cusparseLtHandle_t* handle,
|
|
const cusparseLtMatDescriptor_t* sparseMatDescr,
|
|
int isSparseA,
|
|
cusparseOperation_t op,
|
|
const void* d_dense,
|
|
void* d_compressed,
|
|
void* d_compressed_buffer,
|
|
cudaStream_t stream);
|
|
|
|
//==============================================================================
|
|
//==============================================================================
|
|
|
|
#if defined(__cplusplus)
|
|
}
|
|
#endif // defined(__cplusplus)
|
|
|
|
#endif // !defined(CUSPARSELT_HEADER_)
|
|
|