// Copyright (c) ONNX Project Contributors /* * SPDX-License-Identifier: Apache-2.0 */ #include "onnx/inliner/inliner.h" #include #include #include #include #include #include #include #include "onnx/common/assertions.h" #include "onnx/common/constants.h" #include "onnx/common/proto_util.h" #include "onnx/common/visitor.h" #include "onnx/defs/parser.h" #include "onnx/shape_inference/attribute_binder.h" #include "onnx/shape_inference/implementation.h" #include "onnx/version_converter/convert.h" namespace ONNX_NAMESPACE { namespace inliner { namespace { // internal/private API using namespace internal; using OpsetMapBase = std::unordered_map; // A representation of the opset versions required by a model or a function. // Used to check for compatibility between a model and a function or between // two functions. struct OpsetMap : public OpsetMapBase { public: // Construct a map representing the opset versions required by a model. explicit OpsetMap(const ModelProto& model) { (void)Add(model.opset_import()); } // Adds the opset versions required by a function to the map. Returns true // iff the function is compatible with the map, i.e., if the function does // not require a different version for any domain already in the map. bool Add(const FunctionProto& function) { return Add(function.opset_import()); } // Returns the set of mismatches in the opset requirements of given // function and the map. OpsetMapBase Mismatches(const FunctionProto& function) const { return Mismatches(function.opset_import()); } private: OpsetMapBase Mismatches(const google::protobuf::RepeatedPtrField& list) const { OpsetMapBase result; for (const auto& pair : list) { auto iter = this->find(NormalizeDomain(pair.domain())); if ((iter != this->end()) && (iter->second != pair.version())) result.insert(*iter); } return result; } bool Add(const google::protobuf::RepeatedPtrField& list) { for (const auto& pair : list) { auto domain = NormalizeDomain(pair.domain()); auto version = pair.version(); auto iter = this->find(domain); if (iter != this->end()) { if (iter->second != version) return false; } else { (*this)[domain] = version; } } return true; } }; using RepeatedNodeProto = google::protobuf::RepeatedPtrField; class NameGenerator : private Visitor { public: explicit NameGenerator(const GraphProto& graph) : index_(0) { NameGenerator::VisitGraph(graph); } explicit NameGenerator(const FunctionProto& function) : index_(0) { NameGenerator::VisitFunction(function); } void ResetFor(const GraphProto& graph) { index_ = 0; existing_names_.clear(); NameGenerator::VisitGraph(graph); } void ResetFor(const FunctionProto& function) { index_ = 0; existing_names_.clear(); NameGenerator::VisitFunction(function); } // Creates a new unique name, based on a suggested name, and adds it to the set // of existing names. Returns the newly created name. std::string CreateNew(const std::string& suggested) { std::string name = suggested; while (existing_names_.count(name) > 0) { name = suggested + "_" + std::to_string(index_++); } existing_names_.insert(name); return name; } void Add(const std::string& name) { // We don't bother to check for empty string names. Ok to add them. existing_names_.insert(name); } bool ProcessGraph(const GraphProto& graph) override { for (const auto& x : graph.input()) Add(x.name()); for (const auto& x : graph.initializer()) Add(x.name()); // Adding graph outputs is redundant for a valid graph, but we do it anyway, // to produce better results for invalid graphs. for (const auto& x : graph.output()) Add(x.name()); return true; } bool ProcessFunction(const FunctionProto& function) override { for (const auto& x : function.input()) Add(x); for (const auto& x : function.output()) Add(x); return true; } bool ProcessNode(const NodeProto& node) override { // We use a single name-space for node names and variable names, to keep name-generation simple. Add(node.name()); for (const std::string& name : node.input()) { Add(name); } for (const std::string& name : node.output()) { Add(name); } return true; } private: unsigned int index_; std::unordered_set existing_names_; }; class InliningRenamer : public MutableVisitor { private: std::string suffix; NameGenerator& generator; protected: std::vector> rename_scopes{}; public: InliningRenamer(std::string suffix_, NameGenerator& generator_) : suffix(std::move(suffix_)), generator(generator_) { // Create an empty mapping for the top-level scope. rename_scopes.emplace_back(); } // We use a two-level renaming scheme to generate names for variables when inlined in the // main graph. First, we add a suffix (specific to the call-site being inlined). // Thus, "temp" in called-function becomes "temp__1" for the first inlined function-call // and "temp__2" for the second inlined function-call. In addition, there is a subsequent // iterative check that ensures that this names does not clash with any pre-existing names, // and tries another counter-based suffix in the case of a clash, stopping when successful. std::string MakeUnique(const std::string& name) { return generator.CreateNew(name + suffix); } /** * @brief Binds a formal parameter name to an actual parameter name. * * @param formal_name The formal parameter name to bind. * @param actual_name The actual parameter name to bind to. */ void BindFormalToActual(const std::string& formal_name, const std::string& actual_name) { auto& current_scope = rename_scopes.back(); current_scope[formal_name] = actual_name; } /** * @brief Creates a unique name for the given name and binds it. * * This method creates a unique name based on the suffix and binds the original * name to the unique name for later reference renaming. * * @param original_name The name to create a unique version of. * @return The unique name that was created and bound. */ std::string BindToUniqueName(const std::string& original_name) { // First create the unique name using MakeUnique std::string unique_name = MakeUnique(original_name); // Then bind the original name to the unique name auto& current_scope = rename_scopes.back(); current_scope[original_name] = unique_name; return unique_name; } template void Bind( google::protobuf::RepeatedPtrField& formals, const google::protobuf::RepeatedPtrField& actuals) { // Every formal parameter name FP should be replace by the corresponding actual parameter name AP. // However, if AP is empty, it is a missing optional parameter. This does not make any difference // for inputs. However, for outputs we use a unique dummy name to handle the case that it // is used in an output-context where it is not optional. ONNX_ASSERTM( actuals.size() <= formals.size(), "Number of actual parameters cannot exceed number of formal parameters") auto& current_scope = rename_scopes.back(); int i = 0; for (; i < actuals.size(); ++i) { std::string& formal = *formals.Mutable(i); std::string rename_as = actuals.Get(i); if (isOutput) if (rename_as.empty()) rename_as = MakeUnique(formal); current_scope[formal] = rename_as; if (!rename_as.empty()) formal = rename_as; } for (; i < formals.size(); ++i) { std::string& formal = *formals.Mutable(i); std::string rename_as = isOutput ? MakeUnique(formal) : std::string(); current_scope[formal] = rename_as; if (!rename_as.empty()) formal = rename_as; } } // Process a node: bool ProcessNode(NodeProto* node) override { if (!node->name().empty()) node->set_name(MakeUnique(node->name())); for (auto& x : *node->mutable_input()) { LookupOrRename(x, false); } for (auto& y : *node->mutable_output()) { LookupOrRename(y, true); } return true; // Process attribute subgraphs in traversal } // Process a sub-graph, contained as an attribute in a control-flow op node. // Since we need both pre-processing and post-processing in the traversal, we // override the VisitGraph method. void VisitGraph(GraphProto* graph) override { rename_scopes.emplace_back(); for (auto& x : *graph->mutable_input()) Rename(*x.mutable_name()); for (auto& init : *graph->mutable_initializer()) Rename(*init.mutable_name()); for (auto& y : *graph->mutable_output()) Rename(*y.mutable_name()); for (auto& n : *graph->mutable_node()) VisitNode(&n); rename_scopes.pop_back(); } private: // Replace given name with a unique version of the name, and cache the // renaming-binding in current scope. void Rename(std::string& name) { auto new_name = MakeUnique(name); auto& current_scope = rename_scopes.back(); current_scope[name] = new_name; name = new_name; } void LookupOrRename(std::string& name, bool is_new_def) { if (name.empty()) return; for (auto i = rename_scopes.size(); i > 0; --i) { const auto& map = rename_scopes[i - 1]; auto iter = map.find(name); if (iter != map.end()) { name = iter->second; return; } } if (is_new_def) { Rename(name); } // Otherwise, it is a reference to an outer-scope variable that should not be renamed. } public: // Renames variables in a FunctionProto for inlining a particular call-site. This does the following: // (i) Rename all intermediate variables in the function to ensure that they are unique (wrt the main graph). // (ii) Rename inputs and outputs using names of actual parameters. static void Rename(const NodeProto& callnode, FunctionProto& callee, std::string unique_suffix, NameGenerator& generator) { InliningRenamer renamer(std::move(unique_suffix), generator); renamer.Bind(*callee.mutable_input(), callnode.input()); renamer.Bind(*callee.mutable_output(), callnode.output()); renamer.VisitFunction(&callee); for (auto& v : *callee.mutable_value_info()) renamer.LookupOrRename(*v.mutable_name(), false); } }; // Identify the set of all "input" variables used by a given node. // This includes the variables listed as node.input, as well as // implicit inputs referred to in any graph-valued-attribute of the node. // In the case of variables referenced in sub-graphs, only non-local variables // are treated as implicit inputs. class ComputeInputs : private Visitor { private: std::vector> namescopes; bool InNestedScope() const { return !namescopes.empty(); } std::unordered_set& CurrentScope() { return namescopes.back(); } bool IsLocalVar(const std::string& name) const { for (auto& scope : namescopes) { if (scope.count(name) > 0) { return true; } } return false; } void VisitGraph(const GraphProto& graph) override { namescopes.emplace_back(); for (auto& x : graph.input()) CurrentScope().insert(x.name()); for (auto& init : graph.initializer()) CurrentScope().insert(init.name()); for (auto& n : graph.node()) VisitNode(n); namescopes.pop_back(); } bool ProcessNode(const NodeProto& node) override { for (auto& var : node.input()) { if (!var.empty() && !IsLocalVar(var)) { result.push_back(var); } } if (InNestedScope()) { for (auto& var : node.output()) { if (!var.empty()) { CurrentScope().insert(var); } } } return true; // process sub-graphs } public: std::vector result; explicit ComputeInputs(const NodeProto& node) { result.reserve(node.input_size()); ComputeInputs::VisitNode(node); } }; std::vector GetUsedVars(const NodeProto& node) { return ComputeInputs(node).result; } using ConstNodeMap = std::unordered_map; ConstNodeMap FindConstantNodes(const GraphProto& graph) { ConstNodeMap result; for (const NodeProto& node : graph.node()) { if (IsOnnxDomain(node.domain()) && (node.op_type() == "Constant")) { result[node.output(0)] = &node; } } return result; } const TypeProto& GetType(const ModelProto& model, const std::string& var) { for (auto& vi : model.graph().value_info()) { if (vi.name() == var) return vi.type(); } for (auto& vi : model.graph().input()) { if (vi.name() == var) return vi.type(); } for (auto& vi : model.graph().output()) { if (vi.name() == var) return vi.type(); } ONNX_ASSERTM(false, "Type unknown for %s", var.c_str()) } void ConvertVersion(ModelProto& model, const NodeProto& call_node, FunctionProto& function, int target_version) { shape_inference::InferShapes(model); ModelProto function_as_model; function_as_model.set_ir_version(model.ir_version()); *function_as_model.mutable_opset_import() = function.opset_import(); GraphProto& graph = *function_as_model.mutable_graph(); // The graph's inputs are all the variables used in the call_node. auto used_vars = GetUsedVars(call_node); auto constant_node_map = FindConstantNodes(model.graph()); RepeatedNodeProto& function_nodes = *function.mutable_node(); RepeatedNodeProto& nodes = *graph.mutable_node(); nodes.Reserve(function_nodes.size() + used_vars.size()); auto* inputs = graph.mutable_input(); for (const auto& var : used_vars) { auto* new_input = inputs->Add(); new_input->set_name(var); *new_input->mutable_type() = GetType(model, var); // Create a copy of constants used by the call_node. // We do not handle initializers-as-constants for now. auto it = constant_node_map.find(var); if (it != constant_node_map.end()) { *nodes.Add() = *(it->second); } } // outputs: from call_node node outputs auto* outputs = graph.mutable_output(); for (const auto& var : call_node.output()) { if (!var.empty()) { auto* new_output = outputs->Add(); new_output->set_name(var); *new_output->mutable_type() = GetType(model, var); } } // TODO: Use std::move when it is fully supported on all protobuf platforms used for (auto& function_node : function_nodes) *nodes.Add() = function_node; function_nodes.Clear(); auto converted = ONNX_NAMESPACE::version_conversion::ConvertVersion(function_as_model, target_version); function_nodes.Swap(converted.mutable_graph()->mutable_node()); // Append new initializers to main graph initializers for (auto& added_initializer : converted.graph().initializer()) *model.mutable_graph()->mutable_initializer()->Add() = added_initializer; for (auto& added_initializer : converted.graph().sparse_initializer()) *model.mutable_graph()->mutable_sparse_initializer()->Add() = added_initializer; } int64_t GetDomainVersion(const ModelProto& model, const std::string& domain) { for (const auto& opset : model.opset_import()) { if (opset.domain() == domain) { return opset.version(); } } return 0; } class VectorSet : public FunctionIdSet { public: VectorSet(FunctionIdVector&& function_ids, bool invert) : function_ids_(std::move(function_ids)), invert_(invert) {} bool Contains(const std::string& function_domain, const std::string& function_name) const override { bool found = std::find(function_ids_.begin(), function_ids_.end(), std::make_pair(function_domain, function_name)) != function_ids_.end(); return invert_ ? !found : found; } private: FunctionIdVector function_ids_; bool invert_; }; constexpr int64_t kNoConversion = -1; using FunctionMap = std::unordered_map>; using NodeList = google::protobuf::RepeatedPtrField; struct InlinerImpl { ModelProto& model; const FunctionIdSet& to_inline; const FunctionMap* function_map; const ISchemaRegistry* schema_registry = nullptr; NameGenerator name_generator; int inline_count = 0; // Construct inliner for inlining call-sites inside main graph of a model. InlinerImpl( ModelProto& model_, const FunctionIdSet& to_inline_, const FunctionMap* function_map_, const ISchemaRegistry* schema_registry_) : model(model_), to_inline(to_inline_), function_map(function_map_), schema_registry(schema_registry_), name_generator(model_.graph()) {} virtual ~InlinerImpl() = default; virtual bool GetCallee(const NodeProto& node, FunctionProto& callee, int64_t& target_version) { const std::string& domain = node.domain(); const std::string& function_name = node.op_type(); if (!to_inline.Contains(domain, function_name)) { return false; } if (function_map != nullptr) { auto iter = this->function_map->find(GetCalleeId(node)); if (iter != this->function_map->end()) { callee = *iter->second.first; target_version = iter->second.second; return true; } } if (schema_registry != nullptr) { int64_t domain_version = GetDomainVersion(model, domain); const auto* op_schema = schema_registry->GetSchema(node.op_type(), domain_version, domain); if (op_schema == nullptr) { // If the schema is not found, we cannot inline the function. return false; } if (op_schema->HasFunction()) { const FunctionProto* function_ptr = op_schema->GetFunction(domain_version, false); if (function_ptr != nullptr) { callee = *function_ptr; target_version = kNoConversion; return true; } } // Check if this node has a schema defined function proto. if (op_schema->HasContextDependentFunction()) { shape_inference::InferShapes(model); // TODO: do shape inference incrementally std::vector input_types; for (const auto& input : node.input()) { input_types.emplace_back(GetType(model, input)); } ONNX_NAMESPACE::FunctionBodyBuildContextImpl function_body_ctx(node, input_types); target_version = kNoConversion; return op_schema->BuildContextDependentFunction(function_body_ctx, callee, domain_version); } } return false; } /** Shared utility function used for inlining into either a GraphProto or a FunctionProto. * @param nodes Mutable list of nodes (of function or graph) * @param value_infos Mutable list of value_infos (of function or graph) */ void Process(NodeList& nodes, ValueInfoList& value_infos) { NodeList original_nodes; // Move all nodes into original_nodes original_nodes.Swap(&nodes); std::function append_node = [&](NodeProto& node) { FunctionProto callee; int64_t target_version = kNoConversion; if (GetCallee(node, callee, target_version)) { // Bind attribute parameters internal::AttributeBinder::BindAttributes(node, callee); // Rename variable names in callee InliningRenamer::Rename(node, callee, "__" + std::to_string(++(this->inline_count)), this->name_generator); if (target_version != kNoConversion) { ConvertVersion(model, node, callee, target_version); } std::unordered_set actual_parameters; for (const auto& x : node.input()) actual_parameters.insert(x); for (const auto& x : node.output()) actual_parameters.insert(x); // Append valueinfos of called function for (auto& callee_vi : callee.value_info()) { if (actual_parameters.count(callee_vi.name()) == 0) { *value_infos.Add() = callee_vi; } } // Append nodes of called function for (auto& callee_node : *callee.mutable_node()) append_node(callee_node); } else { // Append node without inlining. // TODO: use std::move instead of copying. Use of move doesn't seem to work with // protobuf in some platforms/settings. [nodes->Add(std::move(node));] for (auto& attr : *node.mutable_attribute()) { if (attr.has_g()) { ProcessGraph(*attr.mutable_g()); } for (auto& g : *attr.mutable_graphs()) { ProcessGraph(g); } } *nodes.Add() = node; } }; for (auto& node : original_nodes) { append_node(node); } } /** Utility function used for inlining into a GraphProto. * @param graph Mutable graph */ void ProcessGraph(GraphProto& graph) { auto* nodes = graph.mutable_node(); auto* value_infos = graph.mutable_value_info(); Process(*nodes, *value_infos); } /** Utility function used for inlining into a FunctionProto. * @param function Mutable function */ void ProcessFunction(FunctionProto& function) { auto* nodes = function.mutable_node(); auto* value_infos = function.mutable_value_info(); Process(*nodes, *value_infos); } static void InlineLocalFunctions(ModelProto& model, bool convert_version) { FunctionIdVector empty_set; VectorSet all_functions(std::move(empty_set), true); OpsetMap model_imports(model); FunctionMap map; // For every function, we check if there is a mismatch between the opset versions // required for the function and the model. If there is no mismatch, we can inline // this function. If there is a mismatch only for the standard ONNX domain, we // can inline after version-conversion (if the version-conversion is successful). // Otherwise, we cannot inline, since currently version-conversion supports only // standard ONNX domain. for (auto& function : model.functions()) { auto mismatches = model_imports.Mismatches(function); auto iter = mismatches.find(ONNX_DOMAIN); int64_t target_onnx_version = kNoConversion; if (convert_version && (iter != mismatches.end())) { target_onnx_version = iter->second; mismatches.erase(iter); } if (mismatches.empty()) { map[GetFunctionImplId(function)] = std::pair(&function, target_onnx_version); } } InlinerImpl inliner(model, all_functions, &map, nullptr); inliner.ProcessGraph(*model.mutable_graph()); // Remove all model-local functions. We do not remove functions with a mis-matched // opset version. They need to be handled some other way, eg., using a version-adapter. auto* local_functions = model.mutable_functions(); for (auto it = local_functions->begin(); it != local_functions->end();) { if (map.count(GetFunctionImplId(*it)) > 0) it = local_functions->erase(it); else ++it; } } static void InlineSelectedFunctions(ModelProto& model, const FunctionIdSet& to_inline, const ISchemaRegistry* schema_registry) { OpsetMap model_imports(model); FunctionMap map; std::vector non_inlined_functions; // If there is any mismatch between the opset versions required for any of the // functions and the model, the inliner will fail. for (auto& function : *model.mutable_functions()) { if (!model_imports.Add(function)) ONNX_THROW("Model has functions with incompatible opset versions."); if (to_inline.Contains(function.domain(), function.name())) { map[GetFunctionImplId(function)] = std::pair(&function, kNoConversion); } else { non_inlined_functions.push_back(&function); } } InlinerImpl inliner(model, to_inline, &map, schema_registry); inliner.ProcessGraph(*model.mutable_graph()); for (auto* function_ptr : non_inlined_functions) { inliner.ProcessFunction(*function_ptr); } // Remove all inlined model-local functions. auto* local_functions = model.mutable_functions(); for (auto it = local_functions->begin(); it != local_functions->end();) { if (map.count(GetFunctionImplId(*it)) > 0) it = local_functions->erase(it); else ++it; } } static void InlineSelectedLocalFunctions(ModelProto& model, const FunctionIdSet& to_inline) { InlineSelectedFunctions(model, to_inline, nullptr); } }; } // namespace // Public API implementation: std::unique_ptr ONNX_NAMESPACE::inliner::FunctionIdSet::Create( FunctionIdVector&& function_ids, bool invert) { return std::make_unique(std::move(function_ids), invert); } void InlineLocalFunctions(ModelProto& model, bool convert_version) { InlinerImpl::InlineLocalFunctions(model, convert_version); } void InlineSelectedLocalFunctions(ModelProto& model, const FunctionIdSet& to_inline) { InlinerImpl::InlineSelectedLocalFunctions(model, to_inline); } void InlineSelectedFunctions(ModelProto& model, const FunctionIdSet& to_inline) { InlineSelectedLocalFunctions(model, to_inline); } void InlineSelectedFunctions( ModelProto& model, const FunctionIdSet& to_inline, const ISchemaRegistry* schema_registry) { if (schema_registry == nullptr) { schema_registry = OpSchemaRegistry::Instance(); } InlinerImpl::InlineSelectedFunctions(model, to_inline, schema_registry); } // Implementation of the Renamer class using InliningRenamer directly class Renamer::Impl { private: NameGenerator generator_; InliningRenamer renamer_; public: Impl(const std::string& prefix, const GraphProto& graph) : generator_(graph), renamer_("__" + prefix, generator_) {} Impl(const std::string& prefix, const FunctionProto& function) : generator_(function), renamer_("__" + prefix, generator_) {} InliningRenamer& GetRenamer() { return renamer_; } void BindName(const std::string& formal_name, const std::string& actual_name) { renamer_.BindFormalToActual(formal_name, actual_name); } void RenameNode(NodeProto& node) { // Use the InliningRenamer's ProcessNode method which handles graph-value attributes renamer_.ProcessNode(&node); } }; Renamer::Renamer(const std::string& prefix, const GraphProto& graph) : pImpl_(std::make_unique(prefix, graph)) {} Renamer::Renamer(const std::string& prefix, const FunctionProto& function) : pImpl_(std::make_unique(prefix, function)) {} Renamer::~Renamer() = default; void Renamer::BindName(const std::string& formal_name, const std::string& actual_name) { pImpl_->BindName(formal_name, actual_name); } void Renamer::RenameNode(NodeProto& node) { pImpl_->RenameNode(node); } std::string Renamer::BindToUniqueName(const std::string& original_name) { return pImpl_->GetRenamer().BindToUniqueName(original_name); } } // namespace inliner } // namespace ONNX_NAMESPACE