DriverTrac/venv/lib/python3.12/site-packages/onnx/reference/ops/op_split.py
2025-11-28 09:08:33 +05:30

52 lines
1.5 KiB
Python

# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from onnx.reference.op_run import OpRun
class CommonSplit(OpRun):
    """Shared implementation behind the versioned ``Split_*`` operators.

    Each subclass adapts its own ``_run`` signature and forwards to
    :meth:`common_run`, which does the actual slicing.
    """

    def __init__(self, onnx_node, run_params):
        OpRun.__init__(self, onnx_node, run_params)
        # Fallback chunk count: one chunk per output declared on the node.
        self.n_outputs = len(onnx_node.output)

    def common_run(self, mat, split, axis, num_outputs):
        """Cut ``mat`` into consecutive pieces along ``axis``.

        Args:
            mat: Array-like exposing ``.shape`` and tuple-of-slices indexing
                (presumably a NumPy array -- NOTE(review): confirm callers).
            split: Sequence of chunk lengths, or ``None`` to derive them.
            axis: Axis along which the pieces are taken.
            num_outputs: Chunk count used when ``split`` is ``None``; any
                falsy value (``None`` or ``0``) falls back to
                ``self.n_outputs``.

        Returns:
            A tuple holding one slice of ``mat`` per entry of ``split``.
        """
        count = num_outputs or self.n_outputs

        if split is None:
            length = mat.shape[axis]
            quotient, remainder = divmod(length, count)
            # Equal chunks when divisible; otherwise ceil-sized chunks.
            chunk = quotient if remainder == 0 else quotient + 1
            split = [chunk] * count
            # The last chunk absorbs the difference (zero when divisible).
            # NOTE(review): when count is large relative to length, this
            # size can drop to <= 0; Python slice semantics then clamp the
            # trailing pieces, presumably to empty arrays.
            split[-1] += length - chunk * count

        # Full slices on every axis; only the split axis is rewritten below.
        index = [slice(0, dim) for dim in mat.shape]
        pieces = []
        start = 0
        for size in split:
            stop = start + size
            index[axis] = slice(start, stop)
            pieces.append(mat[tuple(index)])
            start = stop
        return tuple(pieces)
class Split_2(CommonSplit):
    """``Split`` for opset 2; ``split`` and ``axis`` both arrive as keywords."""

    def _run(self, mat, axis=None, split=None):
        # No num_outputs here: the chunk count comes from the node outputs.
        return self.common_run(mat, split=split, axis=axis, num_outputs=None)
class Split_11(Split_2):
    """``Split`` for opset 11; reuses the opset-2 calling convention unchanged."""
class Split_13(CommonSplit):
    """``Split`` for opset 13; ``split`` is taken positionally after ``mat``."""

    def _run(self, mat, split=None, axis=None):
        # Chunk count, when needed, falls back to the node's output count.
        return self.common_run(mat, split=split, axis=axis, num_outputs=None)
class Split_18(CommonSplit):
    """``Split`` for opset 18; adds an explicit ``num_outputs`` parameter."""

    def _run(self, mat, split=None, axis=None, num_outputs=None):
        # num_outputs is forwarded; common_run falls back when it is falsy.
        return self.common_run(
            mat, split=split, axis=axis, num_outputs=num_outputs
        )