DriverTrac/venv/lib/python3.12/site-packages/onnx/reference/ops/op_mod.py

19 lines
482 B
Python

# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import numpy as np
from onnx.reference.op_run import OpRun
class Mod(OpRun):
    """Reference implementation of the ONNX ``Mod`` operator.

    The ``fmod`` attribute (read from ``self.fmod`` unless overridden at call
    time) selects the remainder convention:

    * ``fmod == 1``: C-style remainder via :func:`numpy.fmod`; the result
      takes the sign of the dividend ``a``.
    * otherwise: integer inputs use :func:`numpy.mod` (Python-style, the
      result takes the sign of the divisor ``b``), while float16/32/64 inputs
      still use :func:`numpy.fmod`.
    """

    def _run(self, a, b, fmod=None):
        """Compute the element-wise remainder of ``a`` divided by ``b``.

        Args:
            a: Dividend array (its ``dtype`` picks the float vs. integer path).
            b: Divisor array, combined with ``a`` by NumPy broadcasting.
            fmod: Optional override of the ``fmod`` attribute. ``None`` means
                "use ``self.fmod``"; any other value, including ``0``, is
                used exactly as given.

        Returns:
            A one-element tuple holding the remainder array.
        """
        # Only fall back to the stored attribute when no override was passed.
        # The previous `fmod or self.fmod` silently discarded an explicit 0.
        if fmod is None:
            fmod = self.fmod
        if fmod == 1:
            return (np.fmod(a, b),)
        # ONNX requires fmod=1 for floating-point inputs; this path tolerates
        # fmod=0 on floats by still using C-style fmod instead of np.mod.
        # nan_to_num maps NaN (e.g. from a zero divisor) to 0 and +/-inf to
        # large finite values.
        if a.dtype in (np.float16, np.float32, np.float64):
            return (np.nan_to_num(np.fmod(a, b)),)
        # Integer (and any other non-listed) dtypes: Python-style modulo.
        # nan_to_num leaves integer arrays unchanged.
        return (np.nan_to_num(np.mod(a, b)),)