# MTRF/algorithms/softlearning/policies/gaussian_policy.py
# NOTE: the imports, the SCALE_DIAG_MIN_MAX constant, and the class header
# below are reconstructed context, assumed from the softlearning codebase;
# the original excerpt contains only __init__.
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.python.keras.engine import training_utils

from softlearning.distributions.squash_bijector import SquashBijector
from softlearning.models.utils import create_inputs, flatten_input_structure

from .base_policy import LatentSpacePolicy

# Clamp range for the predicted log standard deviations (assumed value,
# as in softlearning).
SCALE_DIAG_MIN_MAX = (-20, 2)


class GaussianPolicy(LatentSpacePolicy):
    def __init__(self,
                 input_shapes,
                 output_shape,
                 *args,
                 squash=True,
                 preprocessors=None,
                 name=None,
                 **kwargs):
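        # Record the constructor arguments so the instance can be serialized
        # and re-instantiated later (name-mangled call into the Serializable
        # mixin).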
        self._Serializable__initialize(locals())

        self._input_shapes = input_shapes
        self._output_shape = output_shape
        self._squash = squash
        self._name = name

        super(GaussianPolicy, self).__init__(*args, **kwargs)

        self.preprocessors = preprocessors
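
        # Create one Keras Input per observation field and pair each input
        # with its (optional) preprocessor.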
        inputs_flat = create_inputs(input_shapes)
        preprocessors_flat = (
            flatten_input_structure(preprocessors)
            if preprocessors is not None
            else tuple(None for _ in inputs_flat))

        assert len(inputs_flat) == len(preprocessors_flat), (
            inputs_flat, preprocessors_flat)

        preprocessed_inputs = [
            preprocessor(input_) if preprocessor is not None else input_
            for preprocessor, input_
            in zip(preprocessors_flat, inputs_flat)
        ]
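
        # Cast the preprocessed inputs to a floating dtype and concatenate
        # them into a single conditioning vector for the policy network.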
        float_inputs = tf.keras.layers.Lambda(
            lambda inputs: training_utils.cast_if_floating_dtype(inputs)
        )(preprocessed_inputs)

        conditions = tf.keras.layers.Lambda(
            lambda inputs: tf.concat(inputs, axis=-1)
        )(float_inputs)

        self.condition_inputs = inputs_flat
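
        # self._shift_and_log_scale_diag_net (not shown in this excerpt)
        # builds the policy network; its output holds both the Gaussian mean
        # (shift) and the diagonal log standard deviation, split along the
        # last axis.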
        shift_and_log_scale_diag = self._shift_and_log_scale_diag_net(
            output_size=output_shape[0] * 2,
        )(conditions)

        shift, log_scale_diag = tf.keras.layers.Lambda(
            lambda shift_and_log_scale_diag: tf.split(
                shift_and_log_scale_diag,
                num_or_size_splits=2,
                axis=-1)
        )(shift_and_log_scale_diag)
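
        # Clamp the log standard deviations for numerical stability.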
        log_scale_diag = tf.keras.layers.Lambda(
            lambda log_scale_diag: tf.clip_by_value(
                log_scale_diag, *SCALE_DIAG_MIN_MAX)
        )(log_scale_diag)
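
        # Standard normal base distribution; per-example latents z ~ N(0, I)
        # are sampled with a dynamic batch size taken from the conditions.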
        batch_size = tf.keras.layers.Lambda(
            lambda x: tf.shape(input=x)[0])(conditions)

        base_distribution = tfp.distributions.MultivariateNormalDiag(
            loc=tf.zeros(output_shape),
            scale_diag=tf.ones(output_shape))

        latents = tf.keras.layers.Lambda(
            lambda batch_size: base_distribution.sample(batch_size)
        )(batch_size)

        self.latents_model = tf.keras.Model(self.condition_inputs, latents)
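
        # A separate latent input lets callers compute actions for
        # externally supplied noise (fixed latents).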
        self.latents_input = tf.keras.layers.Input(
            shape=output_shape, name='latents')
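
        # Affine transform of the latents:
        #     raw_actions = shift + exp(log_scale_diag) * latents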
        def raw_actions_fn(inputs):
            shift, log_scale_diag, latents = inputs
            bijector = tfp.bijectors.Affine(
                shift=shift,
                scale_diag=tf.exp(log_scale_diag))
            actions = bijector.forward(latents)
            return actions

        raw_actions = tf.keras.layers.Lambda(
            raw_actions_fn
        )((shift, log_scale_diag, latents))

        raw_actions_for_fixed_latents = tf.keras.layers.Lambda(
            raw_actions_fn
        )((shift, log_scale_diag, self.latents_input))
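
        # Optionally squash raw actions into (-1, 1); SquashBijector is a
        # tanh bijector in softlearning.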
        squash_bijector = (
            SquashBijector()
            if self._squash
            else tfp.bijectors.Identity())

        actions = tf.keras.layers.Lambda(
            lambda raw_actions: squash_bijector.forward(raw_actions)
        )(raw_actions)
        self.actions_model = tf.keras.Model(self.condition_inputs, actions)

        actions_for_fixed_latents = tf.keras.layers.Lambda(
            lambda raw_actions: squash_bijector.forward(raw_actions)
        )(raw_actions_for_fixed_latents)
        self.actions_model_for_fixed_latents = tf.keras.Model(
            (*self.condition_inputs, self.latents_input),
            actions_for_fixed_latents)
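
        # Deterministic actions: squash the Gaussian mean (shift) directly.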
        deterministic_actions = tf.keras.layers.Lambda(
            lambda shift: squash_bijector.forward(shift)
        )(shift)
        self.deterministic_actions_model = tf.keras.Model(
            self.condition_inputs, deterministic_actions)
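
        # Log-probability of actions under the squashed Gaussian, computed
        # through the change-of-variables formula of the chained bijectors.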
        def log_pis_fn(inputs):
            shift, log_scale_diag, actions = inputs
            base_distribution = tfp.distributions.MultivariateNormalDiag(
                loc=tf.zeros(output_shape),
                scale_diag=tf.ones(output_shape))
            bijector = tfp.bijectors.Chain((
                squash_bijector,
                tfp.bijectors.Affine(
                    shift=shift,
                    scale_diag=tf.exp(log_scale_diag)),
            ))
            distribution = (
                tfp.distributions.ConditionalTransformedDistribution(
                    distribution=base_distribution,
                    bijector=bijector))
            log_pis = distribution.log_prob(actions)[:, None]
            return log_pis
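
        # A separate action input lets callers query log-probs for arbitrary
        # actions, not only the ones sampled above.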
        self.actions_input = tf.keras.layers.Input(
            shape=output_shape, name='actions')

        log_pis = tf.keras.layers.Lambda(
            log_pis_fn)([shift, log_scale_diag, actions])

        log_pis_for_action_input = tf.keras.layers.Lambda(
            log_pis_fn)([shift, log_scale_diag, self.actions_input])
        self.log_pis_model = tf.keras.Model(
            (*self.condition_inputs, self.actions_input),
            log_pis_for_action_input)
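
        # Diagnostics model exposes intermediate tensors for logging and
        # debugging.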
        self.diagnostics_model = tf.keras.Model(
            self.condition_inputs,
            (shift, log_scale_diag, log_pis, raw_actions, actions))
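
        # Usage sketch (assumed calling convention, matching the Keras
        # functional models above):
        #     actions = self.actions_model(observations_flat)
        #     log_pis = self.log_pis_model([*observations_flat, actions])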