# Copyright (c) ONNX Project Contributors # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations import numpy as np from onnx.reference.op_run import OpRun def topk_sorted_implementation(X, k, axis, largest): """See function `_kneighbors_reduce_func `_. """ if isinstance(k, np.ndarray): if k.size != 1: raise RuntimeError(f"k must be an integer not {k!r}.") k = k[0] # This conversion is needed for distribution x86. k = int(k) # Used to tiebreak ind_axis = np.indices(X.shape)[axis] if largest: ind_axis = -ind_axis sorted_indices = np.lexsort((ind_axis, X), axis=axis) sorted_values = np.take_along_axis(X, sorted_indices, axis=axis) if largest: sorted_indices = np.flip(sorted_indices, axis=axis) sorted_values = np.flip(sorted_values, axis=axis) ark = np.arange(k) topk_sorted_indices = np.take(sorted_indices, ark, axis=axis) topk_sorted_values = np.take(sorted_values, ark, axis=axis) return topk_sorted_values, topk_sorted_indices class _CommonTopK(OpRun): def _common_run(self, data, ink, axis, largest=1): """Runtime for operator *TopK*. The implementation is not the most efficient as it sorts everything then extracts the top *k* values. .. warning:: ONNX specifications may be imprecise in case of negative value for axis. The implementation follows what `onnxruntime` does in `top_k.cc `_. """ k = ink[0] axis = axis if axis >= 0 else (axis + len(data.shape)) sort, sorti = topk_sorted_implementation(data, k, axis, largest) return (sort, sorti.astype(np.int64)) class TopK_1(_CommonTopK): def _run(self, data, k=None, axis=None): """Runtime for operator *TopK*. The implementation is not the most efficient as it sorts everything then extracts the top *k* values. .. warning:: ONNX specifications may be imprecise in case of negative value for axis. The implementation follows what `onnxruntime` does in `top_k.cc `_. """ return _CommonTopK._common_run(self, data, [k], axis=axis) class TopK_10(_CommonTopK): def _run(self, data, ink, axis=None): """Runtime for operator *TopK*. The implementation is not the most efficient as it sorts everything then extracts the top *k* values. .. warning:: ONNX specifications may be imprecise in case of negative value for axis. The implementation follows what `onnxruntime` does in `top_k.cc `_. """ return _CommonTopK._common_run(self, data, ink, axis=axis) class TopK_11(_CommonTopK): def _run( self, data, ink, axis=None, largest=None, sorted=None, # noqa: A002 ): """Runtime for operator *TopK*. The implementation is not the most efficient as it sorts everything then extracts the top *k* values. .. warning:: ONNX specifications may be imprecise in case of negative value for axis. The implementation follows what `onnxruntime` does in `top_k.cc `_. """ if sorted not in (True, 1): raise RuntimeError("TopK does not implement anything for sorted=0.") return _CommonTopK._common_run(self, data, ink, axis=axis, largest=largest)