DriverTrac/venv/lib/python3.12/site-packages/onnx/reference/ops/op_concat.py
2025-11-28 09:08:33 +05:30

23 lines
682 B
Python

# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import numpy as np
from onnx.reference.op_run import OpRun
class Concat(OpRun):
    """Reference implementation of the ONNX ``Concat`` operator.

    Every input is first normalised by :meth:`_preprocess`, then all of them
    are joined along ``axis`` with :func:`numpy.concatenate`.
    """

    def _preprocess(self, a: np.ndarray, axis: int) -> np.ndarray:
        """Prepare a single input for concatenation.

        Args:
            a: One operator input.
            axis: The concatenation axis requested for the node.

        Returns:
            ``a`` unchanged when ``axis`` already indexes one of its
            dimensions (or is negative); otherwise ``a`` reshaped with enough
            trailing size-1 dimensions that ``axis`` becomes its last
            dimension.

        Raises:
            RuntimeError: If ``a`` is zero-dimensional (a scalar).
        """
        rank = len(a.shape)
        if not rank:
            raise RuntimeError(f"Concat: one input has an empty shape: {a!r}.")
        # Number of dimensions that must be appended so that index `axis`
        # exists; zero or negative means `axis` is already in range.
        missing = axis + 1 - rank
        if missing <= 0:
            return a
        return a.reshape(a.shape + (1,) * missing)

    def _run(self, *args, axis=None):
        """Concatenate all inputs along ``axis``.

        Returns:
            A 1-tuple holding the concatenated array.
        """
        prepared = [self._preprocess(arr, axis) for arr in args]
        return (np.concatenate(prepared, axis),)