DriverTrac/venv/lib/python3.12/site-packages/onnx/reference/ops/op_matmul.py
2025-11-28 09:08:33 +05:30

26 lines
644 B
Python

# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import numpy as np
from onnx.reference.ops._op import OpRunBinaryNum
def numpy_matmul(a, b):
    """Return the matrix product of *a* and *b*, reporting shape errors clearly.

    When both operands have rank at most two the product is computed with
    :func:`np.dot`; otherwise :func:`np.matmul` is used. The rank of *b* is
    only inspected when *a* has rank at most two.

    NOTE(review): an earlier docstring claimed sparse-matrix support; nothing
    in this function special-cases sparse inputs, so confirm before relying
    on it.

    Raises:
        ValueError: if NumPy rejects the operand shapes. The message names
            both shapes and the original NumPy error is chained as the cause.
    """
    # ``np.dot`` (unlike ``np.matmul``) also accepts 0-d operands, so the
    # low-rank path keeps it. The ``and`` short-circuit is deliberate: it
    # avoids touching ``b.shape`` when ``a`` is already higher-rank.
    low_rank = len(a.shape) <= 2 and len(b.shape) <= 2
    multiply = np.dot if low_rank else np.matmul
    try:
        return multiply(a, b)
    except ValueError as err:
        raise ValueError(
            f"Unable to multiply shapes {a.shape!r}, {b.shape!r}."
        ) from err
class MatMul(OpRunBinaryNum):
    """MatMul operator that delegates its computation to ``numpy_matmul``."""

    def _run(self, a, b):
        """Multiply *a* by *b* and return the product as a one-element tuple."""
        product = numpy_matmul(a, b)
        return (product,)