in spinoffs/oryx/oryx/core/interpreters/unzip.py [0:0]
def unzip_to_init_apply_subjaxprs(master, settings, keys, pvals):
  """Function transformation that returns init/apply jaxprs.

  This is a generator-based transformation. It first yields the input tracers
  (with an empty kwargs dict) to the wrapped function and receives the
  function's outputs back; it then yields a single ``(success, payload)`` pair.

  Args:
    master: trace master used to construct the `UnzipTrace` and the dynamic
      context.
    settings: unzip settings. They are used to build the `UnzipContext`, and a
      truthy `settings.block` skips the init/apply split entirely.
    keys: per-input key markers, passed to `trace.new_arg` together with the
      matching entry of `pvals`; inputs whose tracer has a truthy `.key` are
      treated as key inputs.
    pvals: partial values for the inputs, one per entry of `keys`.

  Yields:
    First `(in_tracers, {})`, to invoke the wrapped function. Then either:

    * `(False, (jaxpr, (out_pvals, out_keys, consts, env)))`: a single jaxpr
      over all inputs, when `settings.block` is set or the variables do not
      split the computation; or
    * `(True, ((init_jaxpr, init_consts, init_env),
      (apply_jaxpr, apply_consts, apply_env), (var_pvals, out_pvals),
      (variable_names, variable_tree, out_keys)))`: an init jaxpr mapping the
      key inputs to the variables' inputs, and an apply jaxpr mapping the
      variables' outputs plus the non-key inputs to the outputs.

  Raises:
    ValueError: if a variable has no name, or if two distinct variables share
      the same name.
  """
  trace = UnzipTrace(master, jax_core.cur_sublevel())
  # Set up one UnzipTracer per input and split them into key inputs and
  # ordinary (non-key) inputs.
  in_tracers = safe_map(lambda a: trace.new_arg(a[0], a[1]), zip(pvals, keys))
  key_tracers = [t for t in in_tracers if t.key]
  abstract_tracers = [t for t in in_tracers if not t.key]
  # Pass the input tracers into the wrapped function to get output tracers.
  context = UnzipContext(settings)
  with trace_util.new_dynamic_context(master, context):
    ans = yield in_tracers, {}
  out_tracers = safe_map(trace.full_raise, safe_map(jax_core.full_lower, ans))
  out_pvals = [t.pval for t in out_tracers]
  all_tracers = jax_util.toposort(out_tracers)
  variable_tracers = [t for t in all_tracers if t.variable_recipe]
  if not settings.block:
    # Probe whether the variables define a cut of the computation graph.
    # Every variable tracer is temporarily turned into a lambda binding, and we
    # try to build a jaxpr whose only inputs are those bindings plus the
    # non-key inputs. `_tracers_to_jaxpr` raising `VariableError` is taken to
    # mean the outputs cannot be computed that way, so no split is possible.
    old_recipes = [t.recipe for t in variable_tracers]
    try:
      for t in variable_tracers:
        t.recipe = pe.LambdaBinding()
      _tracers_to_jaxpr(variable_tracers + abstract_tracers, out_tracers)
    except VariableError:
      success = False
    else:
      success = True
    finally:
      # Always put the original recipes back, whether or not the probe
      # succeeded, so the jaxprs built below see the unmodified graph.
      for t, old_recipe in safe_zip(variable_tracers, old_recipes):
        t.recipe = old_recipe
  else:
    success = False
  if not success:
    # No init/apply split: emit a single jaxpr over all inputs.
    jaxpr, consts, env = _tracers_to_jaxpr(in_tracers, out_tracers)
    out_keys = [t.is_key() for t in out_tracers]
    yield success, (jaxpr, (out_pvals, out_keys, consts, env))
    return
  # Collect the distinct variable recipes, keyed by name. Names are validated
  # here, before any tracer recipe is rewritten below. The missing-name check
  # comes first, so several unnamed variables are reported as unnamed rather
  # than as duplicates of the name `None`.
  variable_recipes = {}
  for t in all_tracers:
    recipe = t.variable_recipe
    if not recipe:
      continue
    name = recipe.name
    if name is None:
      raise ValueError('Must provide name for variable.')
    if name in variable_recipes and variable_recipes[name] is not recipe:
      raise ValueError('Cannot use duplicate variable name: {}'.format(name))
    variable_recipes[name] = recipe
  variables = {
      name: (recipe.in_tracers, recipe.out_tracers)
      for name, recipe in variable_recipes.items()
  }
  variable_names, variable_io_tracers = jax_util.unzip2(variables.items())
  var_in_tracers, var_out_tracers = jax_util.unzip2(variable_io_tracers)
  flat_var_in_tracers, variable_tree = tree_util.tree_flatten(var_in_tracers)
  var_pvals = [t.pval for t in flat_var_in_tracers]
  flat_var_out_tracers, _ = tree_util.tree_flatten(var_out_tracers)
  # init: key inputs -> inputs of the variable recipes.
  init_jaxpr, init_consts, init_env = _tracers_to_jaxpr(key_tracers,
                                                         flat_var_in_tracers)
  # apply: variable outputs (now lambda-bound) + non-key inputs -> outputs.
  for t in flat_var_out_tracers:
    t.recipe = pe.LambdaBinding()
  apply_jaxpr, apply_consts, apply_env = _tracers_to_jaxpr(
      flat_var_out_tracers + abstract_tracers, out_tracers)
  out_keys = [t.is_key() for t in out_tracers]
  yield success, ((init_jaxpr, init_consts, init_env),
                  (apply_jaxpr, apply_consts, apply_env),
                  (var_pvals, out_pvals),
                  (variable_names, variable_tree, out_keys))