/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <cstdint>
#include "./FbgemmBuild.h" // @manual
#include "./UtilsAvx2.h" // @manual

/// @defgroup fbgemm-quant-utils-avx2 Quantization Utilities (AVX2)
///

namespace fbgemm {

/// Struct from <a href="https://github.com/google/gemmlowp">`gemmlowp`</a>
///
/// A structure to hold quantization parameters `scale` and `zero_point`.
/// These values are the constants in the quantization equation
///
/// `real_value = scale * (quantized_value - zero_point)`
///
/// In other words, `zero_point` is the quantized value that corresponds
/// to the real value 0, and `scale` is the difference between the real values
/// corresponding to consecutive quantized values.
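///
/// For example (a minimal sketch; the parameter values below are
/// illustrative, not FBGEMM defaults), with `scale = 0.1f` and
/// `zero_point = 128` for unsigned 8-bit quantization:
///
/// @code
///   #include <algorithm>
///   #include <cmath>
///
///   fbgemm::TensorQuantizationParams qparams;
///   qparams.scale = 0.1f;      // real-value step between adjacent codes
///   qparams.zero_point = 128;  // quantized code that represents real 0.0
///   qparams.precision = 8;     // bit width of the quantized type
///
///   // Dequantize: real_value = scale * (quantized_value - zero_point)
///   float real = qparams.scale * (140 - qparams.zero_point); // 1.2f
///
///   // Quantize (the inverse mapping, with rounding and clamping):
///   int q = qparams.zero_point +
///       static_cast<int>(std::nearbyint(1.2f / qparams.scale));
///   q = std::min(255, std::max(0, q)); // 140
/// @endcode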
struct FBGEMM_API TensorQuantizationParams {
  float scale;
  std::int32_t zero_point;
  int precision;
  float Min() const;
  float Max() const;
};

/// Parameters used when scaling int32 intermediate matrix multiplication
/// results down to 8-bit integers.
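///
/// As a simplified scalar model of the fixed-point path (a sketch for
/// illustration only; the AVX2 kernels below implement the vectorized,
/// saturating version), one int32 input `x` is requantized roughly as:
///
/// @code
///   #include <algorithm>
///
///   // params is a populated RequantizationParams; x is one int32 input.
///   std::int64_t prod = static_cast<std::int64_t>(x) * params.multiplier;
///   std::int64_t rounded = // round-to-nearest via the half-ULP nudge
///       (prod + (1LL << (params.right_shift - 1))) >> params.right_shift;
///   std::int32_t shifted = static_cast<std::int32_t>(rounded) +
///       params.target_qparams.zero_point;
///   std::uint8_t y = static_cast<std::uint8_t>(
///       std::min(255, std::max(0, shifted))); // saturate to uint8
/// @endcode
///
/// The floating-point path instead computes
/// `round(x * real_multiplier) + zero_point` with the same clamping.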
struct FBGEMM_API RequantizationParams {
  /// For floating-point requantization
  float real_multiplier;

  /// For fixed-point requantization
  std::int32_t multiplier;
  int right_shift;

  TensorQuantizationParams target_qparams;
};

////////////////////////////////////////////////////////////////////////////////
// Utility functions
////////////////////////////////////////////////////////////////////////////////

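/// @ingroup fbgemm-quant-utils-avx2
///
/// Quantize a buffer of floats into `T` using `qparams`. A minimal usage
/// sketch (the buffer contents and qparams values here are hypothetical):
///
/// @code
///   #include <vector>
///
///   std::vector<float> src = {-1.0f, 0.0f, 0.5f, 1.0f};
///   std::vector<std::uint8_t> dst(src.size());
///
///   fbgemm::TensorQuantizationParams qparams;
///   qparams.scale = 2.0f / 255;  // roughly covers the real range [-1, 1]
///   qparams.zero_point = 128;    // maps real 0.0 near the middle code
///   qparams.precision = 8;
///
///   fbgemm::QuantizeAvx2(
///       src.data(),
///       dst.data(),
///       static_cast<std::int64_t>(src.size()),
///       qparams);
/// @endcode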
template <typename T = std::uint8_t, bool LEGACY = true>
void QuantizeAvx2(
    const float* src,
    T* dst,
    int64_t len,
    const TensorQuantizationParams& qparams);

template <typename T = std::uint8_t>
void FusedQuantizeDequantizeAvx2(
    const float* src,
    float* dst,
    int len,
    const TensorQuantizationParams& qparams,
    float noise_ratio = 0.0f);

/// @ingroup fbgemm-quant-utils-avx2
///
/// Random number generator in [0, 9] based on
/// <a href="https://www.jstatsoft.org/v08/i14/paper">this paper</a>.
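///
/// The paper's `xor128` generator keeps four 32-bit words of state and
/// advances them with shift-and-xor steps. A sketch of that recurrence
/// (taken from the paper; not necessarily this implementation's exact
/// constants):
///
/// @code
///   static std::uint32_t x = 123456789, y = 362436069, z = 521288629,
///                        w = 88675123;
///   std::uint32_t t = x ^ (x << 11);
///   x = y; y = z; z = w;
///   w = w ^ (w >> 19) ^ (t ^ (t >> 8));
/// @endcode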
uint32_t FBGEMM_API Xor128();

/// @ingroup fbgemm-quant-utils-avx2
///
/// @brief Find the min and max value in a float matrix.
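///
/// A typical use (a hypothetical sketch of one common way to derive
/// asymmetric qparams from the observed range; not an FBGEMM recipe):
///
/// @code
///   #include <cmath>
///
///   float mn = 0.0f, mx = 0.0f;
///   fbgemm::FindMinMax(data, &mn, &mx, len); // data, len defined elsewhere
///
///   fbgemm::TensorQuantizationParams qparams;
///   qparams.precision = 8;
///   qparams.scale = (mx - mn) / 255; // assumes mx > mn
///   qparams.zero_point =
///       static_cast<std::int32_t>(std::nearbyint(-mn / qparams.scale));
/// @endcode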
void FBGEMM_API FindMinMax(const float* m, float* min, float* max, int64_t len);

void RequantizeFixedPointAvx2(
    const std::int32_t* src,
    std::uint8_t* dst,
    int len,
    const RequantizationParams& params);

void RequantizeAvx2(
    const std::int32_t* src,
    std::uint8_t* dst,
    int len,
    const RequantizationParams& params);

/// @ingroup fbgemm-quant-utils-avx2
///
/// Requantize with AVX2, with bias fused in.
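///
/// The template parameters select compile-time variants of the kernel. As
/// suggested by their names (an informal reading, not official
/// documentation): `A_SYMMETRIC` / `B_SYMMETRIC` indicate a zero point of 0
/// for the A / B matrix, so the corresponding offset correction can be
/// skipped; `Q_GRAN` is the granularity of B's quantization (per-tensor,
/// per-group, or per-channel); `HAS_BIAS` fuses a bias addition of type
/// `BIAS_TYPE`; and `FUSE_RELU` clamps negative outputs to the zero point.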
template <
    bool A_SYMMETRIC,
    bool B_SYMMETRIC,
    QuantizationGranularity Q_GRAN,
    bool HAS_BIAS,
    bool FUSE_RELU,
    typename BIAS_TYPE = std::int32_t,
    bool DIRECT = false>
FBGEMM_API void requantizeOutputProcessingAvx2(
    std::uint8_t* out,
    const std::int32_t* inp,
    const block_type_t& block,
    int ld_out,
    int ld_in,
    const requantizationParams_t<BIAS_TYPE>& r);

template <
    bool A_SYMMETRIC,
    bool B_SYMMETRIC,
    QuantizationGranularity Q_GRAN,
    bool HAS_BIAS,
    bool FUSE_RELU,
    int C_PER_G,
    typename BIAS_TYPE = std::int32_t>
FBGEMM_API void requantizeOutputProcessingGConvAvx2(
    std::uint8_t* out,
    const std::int32_t* inp,
    const block_type_t& block,
    int ld_out,
    int ld_in,
    const requantizationParams_t<BIAS_TYPE>& r);

template <
    bool A_SYMMETRIC,
    bool B_SYMMETRIC,
    QuantizationGranularity Q_GRAN,
    bool HAS_BIAS,
    bool FUSE_RELU>
FBGEMM_API void requantizeForFloatAvx2(
    float* out,
    const std::int32_t* inp,
    const block_type_t& block,
    int ld_out,
    int ld_in,
    const requantizationForFloatParams_t& r);

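/// Rowwise quantization into a fused "scale and bias" (SB) row format. As the
/// names suggest (an informal description, not official documentation), each
/// output row stores its packed `BIT_RATE`-bit quantized values followed by a
/// per-row scale and bias, so rows can be dequantized independently; the
/// "SBHalf" variants store the scale/bias as fp16 and "SBFloat" as fp32. The
/// `...ToFloatOrHalf` functions below perform the inverse transformation.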
template <typename InputType, int BIT_RATE>
void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2(
    const InputType* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output);

template <typename InputType>
void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2(
    const InputType* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output,
    const InputType* rowwise_min_max = nullptr);

template <typename OutputType, int BIT_RATE>
void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2(
    const std::uint8_t* input,
    size_t input_rows,
    int input_columns,
    OutputType* output);

template <typename OutputType>
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(
    const std::uint8_t* input,
    size_t input_rows,
    int input_columns,
    OutputType* output);

} // namespace fbgemm