def unzip_to_init_apply_subjaxprs()

in spinoffs/oryx/oryx/core/interpreters/unzip.py [0:0]


def _variables_form_cut(variable_tracers, abstract_tracers, out_tracers):
  """Returns whether `variable_tracers` separate the inputs from `out_tracers`.

  Each variable tracer's recipe is temporarily replaced with a fresh
  `pe.LambdaBinding()` so it acts as a jaxpr input, and `_tracers_to_jaxpr` is
  asked to build a jaxpr from the variables plus the non-key inputs to
  `out_tracers`. A `VariableError` from that call is treated as "the variables
  do not define a cut". The original recipes are always restored before
  returning, so the traced graph is left as it was found.
  """
  # Snapshot before the `try` so the `finally` clause can always rely on it.
  old_recipes = [t.recipe for t in variable_tracers]
  try:
    for t in variable_tracers:
      t.recipe = pe.LambdaBinding()
    _tracers_to_jaxpr(variable_tracers + abstract_tracers, out_tracers)
  except VariableError:
    return False
  finally:
    # Undo the temporary rebinding regardless of the outcome.
    for t, old_recipe in safe_zip(variable_tracers, old_recipes):
      t.recipe = old_recipe
  return True


def _collect_variable_recipes(variable_tracers):
  """Maps each variable name to its variable recipe, validating the names.

  Insertion order follows `variable_tracers`. Several tracers may share one
  recipe object; that is not treated as a duplicate.

  Raises:
    ValueError: if a variable has no name (`None`), or if two distinct recipes
      use the same name.
  """
  variable_recipes = {}
  for t in variable_tracers:
    recipe = t.variable_recipe
    name = recipe.name
    # Check for a missing name first so unnamed variables get this message
    # rather than being reported as duplicates of the name `None`.
    if name is None:
      raise ValueError('Must provide name for variable.')
    if name in variable_recipes and variable_recipes[name] is not recipe:
      raise ValueError('Cannot use duplicate variable name: {}'.format(name))
    variable_recipes[name] = recipe
  return variable_recipes


def unzip_to_init_apply_subjaxprs(master, settings, keys, pvals):
  """Function transformation that returns init/apply jaxprs.

  This is a coroutine-style generator, presumably driven by a
  transformation wrapper (confirm against callers). It first yields
  `(in_tracers, {})` and expects to be sent back the wrapped function's
  outputs. It then yields a single `(success, payload)` pair.

  The traced outputs are split at the tracers that carry a `variable_recipe`.
  If those variables cut the graph between the inputs and the outputs, the
  computation is unzipped into two jaxprs:
  - an init jaxpr, from the key inputs to the variables' input tracers;
  - an apply jaxpr, from the variables' output tracers plus the non-key
    inputs to the outputs.
  Otherwise, or when `settings.block` is set, a single jaxpr over all inputs
  is produced.

  Args:
    master: passed to `UnzipTrace` and to `trace_util.new_dynamic_context`.
    settings: `settings.block` disables unzipping. It is also passed to
      `UnzipContext`.
    keys: one value per input, passed to `trace.new_arg` alongside the
      matching entry of `pvals`. Inputs whose tracer has a truthy `.key` feed
      the init jaxpr.
    pvals: one partial value per input. Must have the same length as `keys`.

  Yields:
    First, `(in_tracers, {})`. Then, on failure:
    `(False, (jaxpr, (out_pvals, out_keys, consts, env)))`.
    On success:
    `(True, ((init_jaxpr, init_consts, init_env),
    (apply_jaxpr, apply_consts, apply_env), (var_pvals, out_pvals),
    (variable_names, variable_tree, out_keys)))`.

  Raises:
    ValueError: when unzipping, if a variable has no name or two distinct
      variables share a name.
  """
  trace = UnzipTrace(master, jax_core.cur_sublevel())
  # One input tracer per (pval, key) pair. `safe_map` rejects a length
  # mismatch instead of silently truncating as `zip` would.
  in_tracers = safe_map(trace.new_arg, pvals, keys)
  key_tracers = [t for t in in_tracers if t.key]
  abstract_tracers = [t for t in in_tracers if not t.key]

  # Hand the input tracers to the wrapped function and receive its outputs.
  context = UnzipContext(settings)
  with trace_util.new_dynamic_context(master, context):
    ans = yield in_tracers, {}
  out_tracers = safe_map(trace.full_raise, safe_map(jax_core.full_lower, ans))
  out_pvals = [t.pval for t in out_tracers]
  out_keys = [t.is_key() for t in out_tracers]

  all_tracers = jax_util.toposort(out_tracers)
  variable_tracers = [t for t in all_tracers if t.variable_recipe]
  if settings.block:
    success = False
  else:
    success = _variables_form_cut(variable_tracers, abstract_tracers,
                                  out_tracers)

  if not success:
    # Fall back to one jaxpr over every input.
    jaxpr, consts, env = _tracers_to_jaxpr(in_tracers, out_tracers)
    yield success, (jaxpr, (out_pvals, out_keys, consts, env))
    return

  # Validate names before building any jaxpr or rebinding recipes.
  variable_recipes = _collect_variable_recipes(variable_tracers)
  variables = {
      name: (recipe.in_tracers, recipe.out_tracers)
      for name, recipe in variable_recipes.items()
  }
  variable_names, var_in_out_tracers = jax_util.unzip2(variables.items())
  var_in_tracers, var_out_tracers = jax_util.unzip2(var_in_out_tracers)
  flat_var_in_tracers, variable_tree = tree_util.tree_flatten(var_in_tracers)
  var_pvals = [t.pval for t in flat_var_in_tracers]
  flat_var_out_tracers, _ = tree_util.tree_flatten(var_out_tracers)

  # init: key inputs -> variables' input tracers.
  init_jaxpr, init_consts, init_env = _tracers_to_jaxpr(key_tracers,
                                                        flat_var_in_tracers)
  # apply: the variables' output tracers become jaxpr inputs alongside the
  # non-key inputs.
  for t in flat_var_out_tracers:
    t.recipe = pe.LambdaBinding()
  apply_jaxpr, apply_consts, apply_env = _tracers_to_jaxpr(
      flat_var_out_tracers + abstract_tracers, out_tracers)
  yield success, ((init_jaxpr, init_consts,
                   init_env), (apply_jaxpr, apply_consts, apply_env),
                  (var_pvals, out_pvals), (variable_names, variable_tree,
                                           out_keys))