/*
 * SPDX-License-Identifier: Apache-2.0
 */

#include "onnx/defs/nn/utils.h"

#include <algorithm>
#include <limits> // std::numeric_limits, used for the -inf mask value below

namespace ONNX_NAMESPACE {
namespace defs {
namespace nn {
namespace utils {

void AttentionPropagateElemTypeFromInputToOutput(InferenceContext& ctx) {
  propagateElemTypeFromInputToOutput(ctx, 0, 0);
  int64_t kv_sequence_length = -1;
  ONNX_NAMESPACE::TensorShapeProto output_shape;
  ONNX_NAMESPACE::TensorShapeProto qk_matmul_shape;
  if (hasInputShape(ctx, 0)) {
    auto& query_shape = getInputShape(ctx, 0);
    auto& query_dims = query_shape.dim();
    if ((query_dims.size() != 3) && (query_dims.size() != 4)) {
      fail_shape_inference("Input 0 (query) must have 3 or 4 dimensions");
    }
    if (query_dims.size() == 3) {
      auto* q_num_heads_attr = ctx.getAttribute("q_num_heads");
      if (q_num_heads_attr == nullptr) {
        fail_type_inference("3D inputs expected to have q_num_heads attribute.");
      }
      auto* kv_num_heads_attr = ctx.getAttribute("kv_num_heads");
      if (kv_num_heads_attr == nullptr) {
        fail_type_inference("3D inputs expected to have kv_num_heads attribute.");
      }
    }
    *output_shape.add_dim() = query_dims[0]; // batch_size
    *output_shape.add_dim() = query_dims[1]; // q_num_heads for 4D, q_sequence_length for 3D
    *qk_matmul_shape.add_dim() = query_dims[0]; // batch_size
    if (hasInputShape(ctx, 1)) {
      auto& key_shape = getInputShape(ctx, 1);
      auto& key_dims = key_shape.dim();
      if ((key_dims.size() != 3) && (key_dims.size() != 4)) {
        fail_shape_inference("Input 1 (key) must have 3 or 4 dimensions");
      }
    }
    if (hasInputShape(ctx, 2)) {
      auto& value_shape = getInputShape(ctx, 2);
      auto& value_dims = value_shape.dim();
      if ((value_dims.size() != 3) && (value_dims.size() != 4)) {
        fail_shape_inference("Input 2 (value) must have 3 or 4 dimensions");
      }
      // Update the output shape for 4D inputs.
      // Input 0 (query) has shape (batch_size, q_num_heads, q_sequence_length, head_size)
      // Input 1 (key) has shape (batch_size, kv_num_heads, kv_sequence_length, head_size)
      // Input 2 (value) has shape (batch_size, kv_num_heads, kv_sequence_length, v_head_size)
      // Output 0 has shape (batch_size, q_num_heads, q_sequence_length, v_head_size)
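      // For example (hypothetical sizes): query (2, 8, 128, 64) with key and value
      // (2, 8, 256, 64) gives Output 0 of shape (2, 8, 128, 64) and a QK matmul
      // output of shape (2, 8, 128, 256).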
      if (value_dims.size() == 4 && query_dims.size() == 4) {
        kv_sequence_length = value_dims[2].dim_value();
        *output_shape.add_dim() = query_dims[2]; // q_sequence_length
        *output_shape.add_dim() = value_dims[3]; // v_head_size
        updateOutputShape(ctx, 0, output_shape);
        // Update qk_matmul_shape
        *qk_matmul_shape.add_dim() = query_dims[1]; // q_num_heads
        *qk_matmul_shape.add_dim() = query_dims[2]; // q_sequence_length
        qk_matmul_shape.add_dim()->set_dim_value(kv_sequence_length);
      }
      // Update the output shape for 3D inputs.
      // Input 0 (query) has shape (batch_size, q_sequence_length, q_hidden_size),
      //   q_hidden_size = q_num_heads * head_size
      // Input 1 (key) has shape (batch_size, kv_sequence_length, k_hidden_size),
      //   k_hidden_size = kv_num_heads * head_size
      // Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size),
      //   v_hidden_size = kv_num_heads * v_head_size
      // Output 0 has shape (batch_size, q_sequence_length, hidden_size),
      //   hidden_size = q_num_heads * v_head_size
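      // For example (hypothetical sizes): query (2, 128, 512) with q_num_heads = 8
      // gives head_size = 64; value (2, 256, 256) with kv_num_heads = 8 gives
      // v_head_size = 32, so Output 0 is (2, 128, 8 * 32) = (2, 128, 256).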
      if (value_dims.size() == 3 && query_dims.size() == 3) {
        kv_sequence_length = value_dims[1].dim_value();
        auto* q_num_heads_attr = ctx.getAttribute("q_num_heads");
        if (q_num_heads_attr == nullptr) {
          fail_type_inference("3D inputs expected to have q_num_heads attribute.");
        }
        auto* kv_num_heads_attr = ctx.getAttribute("kv_num_heads");
        if (kv_num_heads_attr == nullptr) {
          fail_type_inference("3D inputs expected to have kv_num_heads attribute.");
        }
        int64_t q_num_heads = q_num_heads_attr->i();
        int64_t kv_num_heads = kv_num_heads_attr->i();
        // Calculate v_head_size
        int64_t v_head_size = value_dims[2].dim_value() / kv_num_heads;
        output_shape.add_dim()->set_dim_value(v_head_size * q_num_heads);
        updateOutputShape(ctx, 0, output_shape);
        // Update qk_matmul_shape
        qk_matmul_shape.add_dim()->set_dim_value(q_num_heads);
        *qk_matmul_shape.add_dim() = query_dims[1]; // q_sequence_length
        qk_matmul_shape.add_dim()->set_dim_value(kv_sequence_length);
      }
    }
  }
  if (ctx.hasOutput(3)) { // has qk_matmul_output
    propagateElemTypeFromInputToOutput(ctx, 0, 3);
    updateOutputShape(ctx, 3, qk_matmul_shape);
  }
  if (ctx.hasOutput(1) && ctx.hasOutput(2)) { // has present outputs
    if (ctx.hasInput(4) && ctx.hasInput(5)) { // has past_key and past_value
      // Copy the types from past_key/past_value to present_key/present_value.
      propagateElemTypeFromInputToOutput(ctx, 4, 1);
      propagateElemTypeFromInputToOutput(ctx, 5, 2);
      if (hasInputShape(ctx, 4) && hasInputShape(ctx, 5)) {
        auto& past_key_shape = getInputShape(ctx, 4);
        auto& past_key_dims = past_key_shape.dim();
        auto& past_value_shape = getInputShape(ctx, 5);
        auto& past_value_dims = past_value_shape.dim();
        // past_key has shape (batch_size, kv_num_heads, past_sequence_length, head_size)
        if (past_key_dims.size() != 4) {
          fail_shape_inference("The past_key input must have 4 dimensions");
        }
        // past_value has shape (batch_size, kv_num_heads, past_sequence_length, v_head_size)
        if (past_value_dims.size() != 4) {
          fail_shape_inference("The past_value input must have 4 dimensions");
        }
        if (kv_sequence_length > 0 && past_key_dims[2].has_dim_value()) {
          int64_t total_sequence_length = kv_sequence_length + past_key_dims[2].dim_value();
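          // For example (hypothetical sizes): past_sequence_length = 16 and
          // kv_sequence_length = 4 give total_sequence_length = 20, so the
          // present key/value cache grows to 20 along its sequence axis.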
          ONNX_NAMESPACE::TensorShapeProto present_key_shape;
          for (auto& dim : past_key_dims) {
            *present_key_shape.add_dim() = dim;
          }
          ONNX_NAMESPACE::TensorShapeProto present_value_shape;
          for (auto& dim : past_value_dims) {
            *present_value_shape.add_dim() = dim;
          }
          if (ctx.hasOutput(3)) { // has qk_matmul_output; it must span the total sequence length
            qk_matmul_shape.mutable_dim(3)->set_dim_value(total_sequence_length);
            updateOutputShape(ctx, 3, qk_matmul_shape);
          }
          // present key/value have shape (batch_size, kv_num_heads, total_sequence_length, head_size)
          present_key_shape.mutable_dim(2)->set_dim_value(total_sequence_length);
          present_value_shape.mutable_dim(2)->set_dim_value(total_sequence_length);
          updateOutputShape(ctx, 1, present_key_shape);
          updateOutputShape(ctx, 2, present_value_shape);
        }
      }
    }
  }
}
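
// Appends the subgraph that materializes the additive attention bias
// ("AttnBiasCausalOrNot") from the optional attn_mask input and the is_causal
// attribute. Reading from the builder calls below, it assumes the caller has
// already defined QSeqLen, PresentKey, PastKVSeqLen, Zero1D, and NegOne1D in
// the function body being built.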
bool AttentionAppendFunctionCausalMask(const FunctionBodyBuildContext& ctx, FunctionBuilder& builder, bool padding) {
  builder.Add("NewKVSeqLen = Shape <start = -2, end = -1> (PresentKey)")
      .Add("AttnBiasShape = Concat <axis = 0> (QSeqLen, NewKVSeqLen)");
  float neg_inf = -std::numeric_limits<float>::infinity();
  builder.Const1D("FloatNegInf", neg_inf);
  builder.Const1D("ScalarZero", 0.f);
  // If attn_mask is provided
  if (ctx.hasInput(3)) {
    auto* up = ctx.getInputType(3);
    if ((up == nullptr) || (!up->has_tensor_type())) {
      return false;
    }
    int64_t U = up->tensor_type().elem_type();
    builder.Add(
        U == ONNX_NAMESPACE::TensorProto_DataType_BOOL
            ? "AttnBiasShort = Where(attn_mask, ScalarZero, FloatNegInf)"
            : "AttnBiasShort = Identity(attn_mask)");
    // If attn_mask has a shorter kv sequence length, we pad it to NewKVSeqLen with FloatNegInf
    if (padding) {
      builder.Add("MaskKVSeqLen = Shape <start = -1> (attn_mask)")
          .Add("PaddingKVSeqLen = Sub(NewKVSeqLen, MaskKVSeqLen)")
          .Add("Pads = Concat <axis = 0> (Zero1D, PaddingKVSeqLen)")
          .Add("FloatNegInfCast = CastLike(FloatNegInf, AttnBiasShort)")
          .Add("AttnBias = Pad(AttnBiasShort, Pads, FloatNegInfCast, NegOne1D)");
    } else {
      builder.Add("AttnBias = Identity(AttnBiasShort)");
    }
  } else {
    builder.Add("AttnBias = ConstantOfShape(AttnBiasShape)");
  }
  // If is_causal is set to true, the attention masking is a lower-triangular matrix when
  // the mask is a square matrix. Because of the alignment, the masking takes the form of
  // the upper-left causal bias when the mask is a non-square matrix.
  // An error is thrown if both attn_mask and is_causal are set.
  auto* is_causal_attr = ctx.getAttribute("is_causal");
  int64_t is_causal = (is_causal_attr != nullptr) ? is_causal_attr->i() : 0;
  if (is_causal == 1) {
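    // Worked example (hypothetical sizes): with q_sequence_length = 2,
    // PastKVSeqLen = 1, and total length 3, RangeRow2DPast = [[1], [2]] and
    // RangeCol2D = [[0, 1, 2]], so Less(...) yields [[F, F, T], [F, F, F]] and
    // MaskTri = [[0, 0, -inf], [0, 0, 0]]: each query attends to all past keys
    // and to keys up to its own (offset) position, never to later ones.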
builder.Const1D("Zero", static_cast<int64_t>(0))
.Const1D("One", static_cast<int64_t>(1))
.Add("ZeroNoDim = Squeeze(Zero, Zero)")
.Add("OneNoDim = Squeeze(One, Zero)")
.Add("SequenceLength = Gather(AttnBiasShape, ZeroNoDim)")
.Add("TotalSequenceLength = Gather(AttnBiasShape, OneNoDim)")
.Add("RangeRow = Range(ZeroNoDim, SequenceLength, OneNoDim)")
.Add("RangeRow2D = Unsqueeze(RangeRow, One)")
.Add("RangeCol = Range(ZeroNoDim, TotalSequenceLength, OneNoDim)")
.Add("RangeCol2D = Unsqueeze(RangeCol, Zero)")
.Add("RangeRow2DPast = Add(RangeRow2D, PastKVSeqLen)")
.Add("BoolMaskTri = Less(RangeRow2DPast, RangeCol2D)")
.Add("MaskTri = Where(BoolMaskTri, FloatNegInf, ScalarZero)")
.Add("AttnBiasCausalOrNot = Add(AttnBias, MaskTri)");
  } else {
    builder.Add("AttnBiasCausalOrNot = Identity(AttnBias)");
  }
  return true;
}
} // namespace utils
} // namespace nn
} // namespace defs
} // namespace ONNX_NAMESPACE