mlsh_code/observation_network.py (20 lines of code) (raw):

import rl_algs.common.tf_util as U
import tensorflow as tf
import numpy as np
import gym


class Features(object):
    """Convolutional feature extractor over a normalized observation tensor.

    The graph built in ``__init__`` is:

    1. normalize ``ob`` with running mean/std statistics and clip to [-5, 5];
    2. conv 8x8, stride 4, 16 filters, VALID padding, ReLU;
    3. conv 4x4, stride 2, 16 filters, VALID padding, ReLU;
    4. flatten everything but the batch axis;
    5. dense layer with 64 units, ReLU.

    Attributes set by ``__init__``:
        scope:  full name of the variable scope all variables live under.
        ob_rms: the ``RunningMeanStd`` object holding the normalization
                statistics. This class never updates it itself.
        ob:     the output feature tensor of the dense layer.
    """

    def __init__(self, name, ob):
        """Build the feature network inside variable scope ``name``.

        Args:
            name: name of the variable scope to create the network in.
            ob: observation tensor fed to the first conv layer. Assumed to be
                laid out as (batch, height, width, channels) with a statically
                known non-batch shape -- TODO confirm against callers.
        """
        # The original code used RunningMeanStd without importing it, which
        # raised NameError on construction. Imported at function scope so that
        # merely importing this module does not depend on it.
        # NOTE(review): confirm this is the intended RunningMeanStd
        # implementation for this project.
        from rl_algs.common.mpi_running_mean_std import RunningMeanStd

        with tf.variable_scope(name):
            self.scope = tf.get_variable_scope().name

            with tf.variable_scope("obfilter"):
                # Keep statistics over the full per-sample shape. The previous
                # ``(ob.get_shape()[1],)`` only kept the first non-batch
                # dimension, which does not broadcast correctly against the
                # rank-4 tensor that the conv layers below require.
                ob_shape = tuple(ob.get_shape().as_list()[1:])
                self.ob_rms = RunningMeanStd(shape=ob_shape)

            # Standardize, then clip to limit the influence of outliers.
            obz = tf.clip_by_value(
                (ob - self.ob_rms.mean) / self.ob_rms.std, -5.0, 5.0)

            x = tf.nn.relu(U.conv2d(obz, 16, "l1", [8, 8], [4, 4], pad="VALID"))
            x = tf.nn.relu(U.conv2d(x, 16, "l2", [4, 4], [2, 2], pad="VALID"))
            x = U.flattenallbut0(x)
            x = tf.nn.relu(U.dense(x, 64, 'lin', U.normc_initializer(1.0)))

            self.ob = x

    def get_variables(self):
        """Return every variable in the VARIABLES collection under ``self.scope``."""
        # NOTE(review): tf.GraphKeys.VARIABLES is the legacy name of
        # GLOBAL_VARIABLES in TF 1.x; kept as-is for compatibility with the
        # TensorFlow version this project targets.
        return tf.get_collection(tf.GraphKeys.VARIABLES, self.scope)

    def get_trainable_variables(self):
        """Return the trainable variables created under ``self.scope``."""
        return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)