in tensorflow_graphics/projects/local_implicit_grid/core/reconstruction.py [0:0]
def _init_graph(self):
  """Initialize computation graph for tensorflow.

  Builds a fresh `tf.Graph` (stored in `self.graph`) containing:
    * placeholders for one batch of `self.npts` query points,
      `point_coords_ph` with shape [1, npts, 3] and `point_values_ph` with
      shape [1, npts, 1];
    * a `lig.LocalImplicitGrid` decoder (`self.liggrid`);
    * a trainable variable named 'feat_sparse' holding one latent code of
      length `params['codelen']` per row of `self.occ_idx`, scattered into
      an otherwise-zero dense grid of shape [1, si, sj, sk, codelen] at the
      flat indices `self.occ_idx_flat`;
    * the point loss, latent-norm loss, a sign-agreement accuracy, and an
      Adam `train_op` whose `var_list` contains only 'feat_sparse';
    * a `tf.Session` (`self.sess`). Unless `self.nows` is set, the
      variables in `self.map_dict` are restored from `self.ckpt`;
      `self._initialize_uninitialized` is then called on the session.
  """
  self.graph = tf.Graph()
  with self.graph.as_default():
    codelen = self.params['codelen']
    # Inputs: a single batch of `npts` query points and their target values.
    self.point_coords_ph = tf.placeholder(
        tf.float32,
        shape=[1, self.npts, 3])  # placeholder
    self.point_values_ph = tf.placeholder(
        tf.float32,
        shape=[1, self.npts, 1])  # placeholder
    self.point_coords = self.point_coords_ph
    self.point_values = self.point_values_ph
    self.liggrid = lig.LocalImplicitGrid(
        size=self.grid_shape,
        in_features=codelen,
        out_features=1,
        num_filters=self.params['refiner_nf'],
        net_type='imnet',
        method='linear' if self.overlap else 'nn',
        x_location_max=(1.0 if self.overlap else 2.0),
        name='lig',
        # With an independent point loss the decoder is called with
        # interp=False and (see the branch below) returns per-corner
        # predictions plus interpolation weights.
        interp=(not self.indep_pt_loss),
        min_grid_value=self.xmin,
        max_grid_value=self.xmax)

    # Latent grid: only the cells listed in `occ_idx` get a trainable code.
    # `tf.scatter_nd` places them into a zero-filled dense grid.
    si, sj, sk = self.grid_shape
    self.occ_idx_flat_ = tf.convert_to_tensor(
        self.occ_idx_flat[:, np.newaxis])
    self.shape_ = tf.constant([si * sj * sk, codelen], dtype=tf.int64)
    self.feat_sparse_ = tf.Variable(
        (tf.random.normal(shape=[self.occ_idx.shape[0], codelen]) *
         self.init_std),
        trainable=True,
        name='feat_sparse')
    self.feat_grid = tf.scatter_nd(self.occ_idx_flat_,
                                   self.feat_sparse_,
                                   self.shape_)
    self.feat_grid = tf.reshape(self.feat_grid,
                                [1, si, sj, sk, codelen])
    self.feat_norm = tf.norm(self.feat_sparse_, axis=-1)

    # Predictions.
    if self.indep_pt_loss:
      self.preds, self.weights = self.liggrid(self.feat_grid,
                                              self.point_coords,
                                              training=True)
      # preds: [b, n, 8, 1], weights: [b, n, 8]
      self.preds_interp = tf.reduce_sum(
          tf.expand_dims(self.weights, axis=-1) * self.preds,
          axis=2)  # [b, n, 1]
      # Append the weighted (interpolated) prediction as a 9th entry so the
      # loss below supervises each corner prediction and the blend alike.
      self.preds = tf.concat([self.preds,
                              self.preds_interp[:, :, tf.newaxis, :]],
                             axis=2)  # preds: [b, n, 9, 1]
      # Repeat each target value across all 9 predictions.
      self.point_values = tf.broadcast_to(
          self.point_values[:, :, tf.newaxis, :],
          shape=self.preds.shape)  # [b, n, 9, 1]
    else:
      self.preds = self.liggrid(self.feat_grid,
                                self.point_coords,
                                training=True)  # [b, n, 1]

    # Losses. (v + 1) / 2 maps -1/+1 point labels to 0/1 targets for the
    # sigmoid cross-entropy on the raw logits.
    self.labels_01 = (self.point_values + 1) / 2  # turn labels to 0, 1 labels
    self.loss_pt = tf.losses.sigmoid_cross_entropy(
        self.labels_01,
        logits=self.preds,
        reduction=tf.losses.Reduction.NONE)
    # Latent regularizer: mean L2 norm of the occupied-cell codes.
    self.loss_lat = tf.reduce_mean(self.feat_norm) * self.alpha_lat
    self.loss = tf.reduce_mean(self.loss_pt) + self.loss_lat

    # compute accuracy metric: share of the `npts` points whose target and
    # predicted signs are both positive or both negative (a zero sign never
    # counts). With the independent point loss only the last (interpolated)
    # prediction is scored.
    if self.indep_pt_loss:
      self.pvalue = tf.sign(self.point_values[:, :, -1, 0])
      self.ppred = tf.sign(self.preds[:, :, -1, 0])
    else:
      self.pvalue = tf.sign(self.point_values[..., 0])
      self.ppred = tf.sign(self.preds[:, :, 0])
    self.accu = tf.reduce_sum(tf.cast(
        tf.logical_or(tf.logical_and(self.pvalue > 0, self.ppred > 0),
                      tf.logical_and(self.pvalue < 0, self.ppred < 0)),
        tf.float32)) / float(self.npts)

    # get optimizer. Only the 'feat_sparse' latent codes are in var_list,
    # so the decoder's variables are not updated by train_op.
    self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
    self.fgrid_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                        scope='feat_sparse')
    self.train_op = self.optimizer.minimize(
        self.loss,
        global_step=tf.train.get_or_create_global_step(),
        # tf.get_collection already returns a list; pass it as-is instead of
        # nesting it inside another list.
        var_list=self.fgrid_vars)

    # Session creation and weight loading. The session is created inside
    # `as_default()`, so it runs `self.graph`.
    self.map_dict = self._get_var_mapping(model=self.liggrid,
                                          scope=self.var_prefix)
    self.sess = tf.Session()
    if not self.nows:
      # Restore the variables mapped for `self.liggrid` from the checkpoint.
      self.saver = tf.train.Saver(self.map_dict)
      self.saver.restore(self.sess, self.ckpt)
    # NOTE(review): presumably initializes every variable not restored above
    # (latent codes, optimizer slots, global step) - confirm in the helper.
    self._initialize_uninitialized(self.sess)