# ma_policy/variable_schema.py (36 lines of code) (raw):
import numpy as np
import tensorflow as tf
# Symbolic dimension markers. A schema shape may contain these strings in
# place of an integer; the concrete size is substituted in later.
BATCH, TIMESTEPS = "batch", "timesteps"
class VariableSchema(object):
    """Shape-and-dtype description of a policy variable with symbolic dimensions.

    The shape may mix concrete integers with the module-level markers BATCH and
    TIMESTEPS. The helper methods replace those markers with concrete sizes (or
    None) to build placeholders, variables, zero arrays, or to validate a shape.
    """

    def __init__(self, shape, dtype):
        """Creates a schema for a variable used in policy.

        Allows for symbolic definition of shape. Shape can consist of integers, as well as
        strings BATCH and TIMESTEPS. This is taken advantage of in the optimizers, to
        create placeholders or variables that asynchronously prefetch the inputs.

        Parameters
        ----------
        shape: [int, np.int64, np.int32, or str]
            shape of the variable, e.g. [12, 4], [BATCH, 12], [BATCH, TIMESTEPS]
        dtype:
            tensorflow type of the variable, e.g. tf.float32, tf.int32
        """
        # The shape is wrapped in a 1-tuple for formatting: a bare tuple on the
        # right of % would be unpacked as separate format arguments and turn a
        # failed check into a TypeError (or print only the element of a 1-tuple).
        assert all(
            isinstance(s, (int, np.int64, np.int32)) or s in [BATCH, TIMESTEPS]
            for s in shape
        ), 'Bad shape %s' % (shape,)
        self.shape = shape
        self.dtype = tf.as_dtype(dtype)

    def _substituted_shape(self, batch=None, timesteps=None):
        """Return the shape as a list with BATCH/TIMESTEPS replaced by the given values.

        Integer dimensions are returned unchanged. A marker whose replacement is
        left as None becomes None in the result.
        """
        # Keyed by the module constants so the lookup follows the marker values.
        replacements = {BATCH: batch, TIMESTEPS: timesteps}
        return [replacements.get(dim, dim) for dim in self.shape]

    def substitute(self, *, batch=BATCH, timesteps=TIMESTEPS):
        """Make a new VariableSchema with batch or timesteps optionally filled in.

        Passing None (or omitting the argument) keeps that dimension symbolic.
        """
        # Coerce None to the symbolic marker. Compare against None explicitly so
        # that a legitimate size of 0 is not mistaken for "not given".
        if batch is None:
            batch = BATCH
        if timesteps is None:
            timesteps = TIMESTEPS
        shape = self._substituted_shape(batch, timesteps)
        return VariableSchema(shape=shape, dtype=self.dtype)

    def placeholder(self, *, batch=None, timesteps=None, name=None):
        """Create a tf.placeholder with this schema's dtype and substituted shape.

        Symbolic dimensions that are not given are passed to TensorFlow as None.
        """
        real_shape = self._substituted_shape(batch, timesteps)
        return tf.placeholder(self.dtype, real_shape, name=name)

    def variable(self, *, name, batch=None, timesteps=None, **kwargs):
        """Get or create a variable via tf.get_variable with the substituted shape.

        Every symbolic dimension present in the schema must be given a concrete
        value; extra keyword arguments are forwarded to tf.get_variable.
        """
        real_shape = self._substituted_shape(batch, timesteps)
        assert None not in real_shape, \
            'All symbolic dimensions must be specified, got shape %s' % (real_shape,)
        return tf.get_variable(name, real_shape, self.dtype, **kwargs)

    def np_zeros(self, *, batch=None, timesteps=None, **kwargs):
        """Return a numpy array of zeros with the substituted shape and matching dtype.

        Extra keyword arguments are forwarded to np.zeros.
        """
        real_shape = self._substituted_shape(batch, timesteps)
        np_dtype = self.dtype.as_numpy_dtype
        return np.zeros(shape=real_shape, dtype=np_dtype, **kwargs)

    def match_shape(self, shape, *, batch=None, timesteps=None):
        """Return True if `shape` is compatible with this schema.

        The lengths must be equal and every dimension must equal the expected
        one; an expected dimension of None (an unspecified symbolic dimension)
        matches any value.
        """
        expected_shape = self._substituted_shape(batch, timesteps)
        if len(expected_shape) != len(shape):
            return False
        return all(
            want is None or want == actual
            for want, actual in zip(expected_shape, shape)
        )