# Copyright 2018 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from jax._src import core as _core_src from jax._src.interpreters import partial_eval as _pe_src from jax._src.interpreters.partial_eval import ( DynamicJaxprTracer as DynamicJaxprTracer, JaxprTracer as JaxprTracer, PartialVal as PartialVal, Val as Val, custom_partial_eval_rules as custom_partial_eval_rules, dce_jaxpr as dce_jaxpr, dce_jaxpr_call_rule as dce_jaxpr_call_rule, dce_jaxpr_closed_call_rule as dce_jaxpr_closed_call_rule, dce_jaxpr_consts as dce_jaxpr_consts, dce_rules as dce_rules, partial_eval_jaxpr_custom_rules as partial_eval_jaxpr_custom_rules, trace_to_jaxpr_dynamic as trace_to_jaxpr_dynamic, trace_to_jaxpr_nounits as trace_to_jaxpr_nounits, ) _deprecations = { # Deprecated for JAX v0.7.1; finalize in JAX v0.9.0. "AbstractedAxesSpec": ( "jax.interpreters.partial_eval.AbstractedAxesSpec is deprecated.", _pe_src.AbstractedAxesSpec, ), "AbstractedAxisName": ( "jax.interpreters.partial_eval.AbstractedAxisName is deprecated.", _pe_src.AbstractedAxisName, ), "BoundedAxisSize": ( "jax.interpreters.partial_eval.BoundedAxisSize is deprecated.", _pe_src.BoundedAxisSize, ), "Const": ( "jax.interpreters.partial_eval.Const is deprecated.", _pe_src.Const, ), "ConstFoldRule": ( "jax.interpreters.partial_eval.ConstFoldRule is deprecated.", _pe_src.ConstFoldRule, ), "ConstVar": ( "jax.interpreters.partial_eval.ConstVar is deprecated.", _pe_src.ConstVar, ), "DCERule": ( "jax.interpreters.partial_eval.DCERule is deprecated.", _pe_src.DCERule, ), "DynamicJaxprTrace": ( "jax.interpreters.partial_eval.DynamicJaxprTrace is deprecated.", _pe_src.DynamicJaxprTrace, ), "ForwardingRule": ( "jax.interpreters.partial_eval.ForwardingRule is deprecated.", _pe_src.ForwardingRule, ), "FreeVar": ( "jax.interpreters.partial_eval.FreeVar is deprecated.", _pe_src.FreeVar, ), "Jaxpr": ( ( "jax.interpreters.partial_eval.Jaxpr is deprecated. Use" " jax.extend.core.Jaxpr, and please note that you must" " `import jax.extend` explicitly." ), _core_src.Jaxpr, ), "JaxprEqnRecipe": ( "jax.interpreters.partial_eval.JaxprEqnRecipe is deprecated.", _pe_src.JaxprEqnRecipe, ), "JaxprStackFrame": ( "jax.interpreters.partial_eval.JaxprStackFrame is deprecated.", _pe_src.JaxprStackFrame, ), "JaxprTrace": ( "jax.interpreters.partial_eval.JaxprTrace is deprecated.", _pe_src.JaxprTrace, ), "JaxprTracerRecipe": ( "jax.interpreters.partial_eval.JaxprTracerRecipe is deprecated.", _pe_src.JaxprTracerRecipe, ), "LambdaBinding": ( "jax.interpreters.partial_eval.LambdaBinding is deprecated.", _pe_src.LambdaBinding, ), "ParamsUpdater": ( "jax.interpreters.partial_eval.ParamsUpdater is deprecated.", _pe_src.ParamsUpdater, ), "PartialEvalCustomResult": ( "jax.interpreters.partial_eval.PartialEvalCustomResult is deprecated.", _pe_src.PartialEvalCustomResult, ), "PartialEvalCustomRule": ( "jax.interpreters.partial_eval.PartialEvalCustomRule is deprecated.", _pe_src.PartialEvalCustomRule, ), "ResAvalUpdater": ( "jax.interpreters.partial_eval.ResAvalUpdater is deprecated.", _pe_src.ResAvalUpdater, ), "TracerAsName": ( "jax.interpreters.partial_eval.TracerAsName is deprecated.", _pe_src.TracerAsName, ), "TracerId": ( "jax.interpreters.partial_eval.TracerId is deprecated.", _pe_src.TracerId, ), "abstract_eval_fun": ( "jax.interpreters.partial_eval.abstract_eval_fun is deprecated.", _pe_src.abstract_eval_fun, ), "call_padding_rule": ( "jax.interpreters.partial_eval.call_padding_rule is deprecated.", _pe_src.call_padding_rule, ), "call_param_updaters": ( "jax.interpreters.partial_eval.call_param_updaters is deprecated.", _pe_src.call_param_updaters, ), "call_partial_eval_custom_rule": ( "jax.interpreters.partial_eval.call_partial_eval_custom_rule is deprecated.", _pe_src.call_partial_eval_custom_rule, ), "call_partial_eval_rules": ( "jax.interpreters.partial_eval.call_partial_eval_rules is deprecated.", _pe_src.call_partial_eval_rules, ), "close_jaxpr": ( "jax.interpreters.partial_eval.close_jaxpr is deprecated.", _pe_src.close_jaxpr, ), "closed_call_partial_eval_custom_rule": ( "jax.interpreters.partial_eval.closed_call_partial_eval_custom_rule is deprecated.", _pe_src.closed_call_partial_eval_custom_rule, ), "config": ( "jax.interpreters.partial_eval.config is deprecated; use jax.config directly.", _pe_src.config, ), "const_fold_rules": ( "jax.interpreters.partial_eval.const_fold_rules is deprecated.", _pe_src.const_fold_rules, ), "convert_constvars_jaxpr": ( "jax.interpreters.partial_eval.convert_constvars_jaxpr is deprecated.", _pe_src.convert_constvars_jaxpr, ), "convert_envvars_to_constvars": ( "jax.interpreters.partial_eval.convert_envvars_to_constvars is deprecated.", _pe_src.convert_envvars_to_constvars, ), "convert_invars_to_constvars": ( "jax.interpreters.partial_eval.convert_invars_to_constvars is deprecated.", _pe_src.convert_invars_to_constvars, ), "custom_staging_rules": ( "jax.interpreters.partial_eval.custom_staging_rules is deprecated.", _pe_src.custom_staging_rules, ), "def_trivial_padding": ( "jax.interpreters.partial_eval.def_trivial_padding is deprecated.", _pe_src.def_trivial_padding, ), "forwarding_rules": ( "jax.interpreters.partial_eval.forwarding_rules is deprecated.", _pe_src.forwarding_rules, ), "has_effects": ( "jax.interpreters.partial_eval.has_effects is deprecated.", _pe_src.has_effects, ), "infer_lambda_input_type": ( "jax.interpreters.partial_eval.infer_lambda_input_type is deprecated.", _pe_src.infer_lambda_input_type, ), "instantiate_const_at": ( "jax.interpreters.partial_eval.instantiate_const_at is deprecated.", _pe_src.instantiate_const_at, ), "make_jaxpr_effects": ( "jax.interpreters.partial_eval.make_jaxpr_effects is deprecated.", _pe_src.make_jaxpr_effects, ), "move_binders_to_back": ( "jax.interpreters.partial_eval.move_binders_to_back is deprecated.", _pe_src.move_binders_to_back, ), "move_binders_to_front": ( "jax.interpreters.partial_eval.move_binders_to_front is deprecated.", _pe_src.move_binders_to_front, ), "new_eqn_recipe": ( "jax.interpreters.partial_eval.new_eqn_recipe is deprecated.", _pe_src.new_eqn_recipe, ), "pad_jaxpr": ( "jax.interpreters.partial_eval.pad_jaxpr is deprecated.", _pe_src.pad_jaxpr, ), "padding_rules": ( "jax.interpreters.partial_eval.padding_rules is deprecated.", _pe_src.padding_rules, ), "partial_eval_jaxpr_custom": ( "jax.interpreters.partial_eval.partial_eval_jaxpr_custom is deprecated.", _pe_src.partial_eval_jaxpr_custom, ), "partial_eval_jaxpr_custom_rule_not_implemented": ( "jax.interpreters.partial_eval.partial_eval_jaxpr_custom_rule_not_implemented is deprecated.", _pe_src.partial_eval_jaxpr_custom_rule_not_implemented, ), "partial_eval_jaxpr_nounits": ( "jax.interpreters.partial_eval.partial_eval_jaxpr_nounits is deprecated.", _pe_src.partial_eval_jaxpr_nounits, ), "partial_eval_wrapper_nounits": ( "jax.interpreters.partial_eval.partial_eval_wrapper_nounits is deprecated.", _pe_src.partial_eval_wrapper_nounits, ), "partition_pvals": ( "jax.interpreters.partial_eval.partition_pvals is deprecated.", _pe_src.partition_pvals, ), "recipe_to_eqn": ( "jax.interpreters.partial_eval.recipe_to_eqn is deprecated.", _pe_src.recipe_to_eqn, ), "trace_to_jaxpr_dynamic2": ( "jax.interpreters.partial_eval.trace_to_jaxpr_dynamic2 is deprecated.", _pe_src.trace_to_jaxpr_dynamic2, ), "trace_to_subjaxpr_nounits": ( "jax.interpreters.partial_eval.trace_to_subjaxpr_nounits is deprecated.", _pe_src.trace_to_subjaxpr_nounits, ), "trace_to_subjaxpr_nounits_fwd": ( "jax.interpreters.partial_eval.trace_to_subjaxpr_nounits_fwd is deprecated.", _pe_src.trace_to_subjaxpr_nounits_fwd, ), "tracers_to_jaxpr": ( "jax.interpreters.partial_eval.tracers_to_jaxpr is deprecated.", _pe_src.tracers_to_jaxpr, ), } import typing if typing.TYPE_CHECKING: AbstractedAxesSpec = _pe_src.AbstractedAxesSpec AbstractedAxisName = _pe_src.AbstractedAxisName BoundedAxisSize = _pe_src.BoundedAxisSize Const = _pe_src.Const ConstFoldRule = _pe_src.ConstFoldRule ConstVar = _pe_src.ConstVar DCERule = _pe_src.DCERule DynamicJaxprTrace = _pe_src.DynamicJaxprTrace ForwardingRule = _pe_src.ForwardingRule FreeVar = _pe_src.FreeVar Jaxpr = _core_src.Jaxpr JaxprEqnRecipe = _pe_src.JaxprEqnRecipe JaxprStackFrame = _pe_src.JaxprStackFrame JaxprTrace = _pe_src.JaxprTrace JaxprTracerRecipe = _pe_src.JaxprTracerRecipe LambdaBinding = _pe_src.LambdaBinding ParamsUpdater = _pe_src.ParamsUpdater PartialEvalCustomResult = _pe_src.PartialEvalCustomResult PartialEvalCustomRule = _pe_src.PartialEvalCustomRule ResAvalUpdater = _pe_src.ResAvalUpdater TracerAsName = _pe_src.TracerAsName TracerId = _pe_src.TracerId abstract_eval_fun = _pe_src.abstract_eval_fun call_padding_rule = _pe_src.call_padding_rule call_param_updaters = _pe_src.call_param_updaters call_partial_eval_custom_rule = _pe_src.call_partial_eval_custom_rule call_partial_eval_rules = _pe_src.call_partial_eval_rules close_jaxpr = _pe_src.close_jaxpr closed_call_partial_eval_custom_rule = _pe_src.closed_call_partial_eval_custom_rule config = _pe_src.config const_fold_rules = _pe_src.const_fold_rules convert_constvars_jaxpr = _pe_src.convert_constvars_jaxpr convert_envvars_to_constvars = _pe_src.convert_envvars_to_constvars convert_invars_to_constvars = _pe_src.convert_invars_to_constvars custom_staging_rules = _pe_src.custom_staging_rules def_trivial_padding = _pe_src.def_trivial_padding forwarding_rules = _pe_src.forwarding_rules has_effects = _pe_src.has_effects infer_lambda_input_type = _pe_src.infer_lambda_input_type instantiate_const_at = _pe_src.instantiate_const_at make_jaxpr_effects = _pe_src.make_jaxpr_effects move_binders_to_back = _pe_src.move_binders_to_back move_binders_to_front = _pe_src.move_binders_to_front new_eqn_recipe = _pe_src.new_eqn_recipe pad_jaxpr = _pe_src.pad_jaxpr padding_rules = _pe_src.padding_rules partial_eval_jaxpr_custom = _pe_src.partial_eval_jaxpr_custom partial_eval_jaxpr_custom_rule_not_implemented = _pe_src.partial_eval_jaxpr_custom_rule_not_implemented partial_eval_jaxpr_nounits = _pe_src.partial_eval_jaxpr_nounits partial_eval_wrapper_nounits = _pe_src.partial_eval_wrapper_nounits partition_pvals = _pe_src.partition_pvals recipe_to_eqn = _pe_src.recipe_to_eqn trace_to_jaxpr_dynamic2 = _pe_src.trace_to_jaxpr_dynamic2 trace_to_subjaxpr_nounits = _pe_src.trace_to_subjaxpr_nounits trace_to_subjaxpr_nounits_fwd = _pe_src.trace_to_subjaxpr_nounits_fwd tracers_to_jaxpr = _pe_src.tracers_to_jaxpr else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr del typing, _pe_src, _core_src