def _init_graph()

in tensorflow_graphics/projects/local_implicit_grid/core/reconstruction.py [0:0]


  def _init_graph(self):
    """Initialize computation graph for tensorflow.

    Builds a fresh `tf.Graph` (stored in `self.graph`) that fits a sparse set of
    per-cell latent codes to the point samples fed through
    `self.point_coords_ph` / `self.point_values_ph`, then opens a `tf.Session`
    on it. Only the latent codes (`self.feat_sparse_`) are in the optimizer's
    var_list; the `LocalImplicitGrid` decoder weights are not updated by
    `self.train_op`.

    Reads (must be set before this is called): `npts`, `grid_shape`,
    `params['codelen']`, `params['refiner_nf']`, `overlap`, `indep_pt_loss`,
    `xmin`, `xmax`, `occ_idx`, `occ_idx_flat`, `init_std`, `alpha_lat`,
    `learning_rate`, `var_prefix`, `nows`, `ckpt`.

    Sets: `graph`, `point_coords_ph`, `point_values_ph`, `point_coords`,
    `point_values`, `liggrid`, `occ_idx_flat_`, `shape_`, `feat_sparse_`,
    `feat_grid`, `feat_norm`, `preds`, `labels_01`, `loss_pt`, `loss_lat`,
    `loss`, `pvalue`, `ppred`, `accu`, `optimizer`, `fgrid_vars`, `train_op`,
    `map_dict` and `sess`. It also sets `weights` and `preds_interp` when
    `self.indep_pt_loss` is truthy, and `saver` when `self.nows` is falsy.
    """
    self.graph = tf.Graph()
    with self.graph.as_default():
      # Batch size is fixed to 1; npts points per feed.
      self.point_coords_ph = tf.placeholder(
          tf.float32,
          shape=[1, self.npts, 3])  # placeholder
      self.point_values_ph = tf.placeholder(
          tf.float32,
          shape=[1, self.npts, 1])  # placeholder

      self.point_coords = self.point_coords_ph
      self.point_values = self.point_values_ph
      # With interp=False (indep_pt_loss), the model returns per-neighbour
      # predictions and interpolation weights instead of a blended value
      # (see the tuple unpacking below). 'nn' is presumably a nearest-cell
      # lookup; confirm in lig.LocalImplicitGrid.
      self.liggrid = lig.LocalImplicitGrid(
          size=self.grid_shape,
          in_features=self.params['codelen'],
          out_features=1,
          num_filters=self.params['refiner_nf'],
          net_type='imnet',
          method='linear' if self.overlap else 'nn',
          x_location_max=(1.0 if self.overlap else 2.0),
          name='lig',
          interp=(not self.indep_pt_loss),
          min_grid_value=self.xmin,
          max_grid_value=self.xmax)

      # Build the dense latent grid from one trainable code per occupied cell.
      # Cells not listed in occ_idx_flat stay zero (tf.scatter_nd semantics).
      # NOTE(review): tf.scatter_nd requires `indices` and `shape` to share a
      # dtype. shape_ is int64, so occ_idx_flat must be an int64 array, and
      # it is assumed to have the same length as occ_idx. Confirm at the
      # call site.
      si, sj, sk = self.grid_shape
      self.occ_idx_flat_ = tf.convert_to_tensor(
          self.occ_idx_flat[:, np.newaxis])
      self.shape_ = tf.constant([si*sj*sk, self.params['codelen']],
                                dtype=tf.int64)
      self.feat_sparse_ = tf.Variable(
          (tf.random.normal(shape=[self.occ_idx.shape[0],
                                   self.params['codelen']]) *
           self.init_std),
          trainable=True,
          name='feat_sparse')
      self.feat_grid = tf.scatter_nd(self.occ_idx_flat_,
                                     self.feat_sparse_,
                                     self.shape_)
      self.feat_grid = tf.reshape(self.feat_grid,
                                  [1, si, sj, sk, self.params['codelen']])
      # Euclidean norm of each latent code; used as a regularizer below.
      self.feat_norm = tf.norm(self.feat_sparse_, axis=-1)

      if self.indep_pt_loss:
        self.preds, self.weights = self.liggrid(self.feat_grid,
                                                self.point_coords,
                                                training=True)
        # preds: [b, n, 8, 1], weights: [b, n, 8]
        self.preds_interp = tf.reduce_sum(
            tf.expand_dims(self.weights, axis=-1)*self.preds,
            axis=2)  # [b, n, 1]
        # Append the blended prediction as a 9th slot. Each neighbour's
        # prediction and the blend are then supervised with the same label.
        self.preds = tf.concat([self.preds,
                                self.preds_interp[:, :, tf.newaxis, :]],
                               axis=2)  # preds: [b, n, 9, 1]
        self.point_values = tf.broadcast_to(
            self.point_values[:, :, tf.newaxis, :],
            shape=self.preds.shape)  # [b, n, 9, 1]
      else:
        self.preds = self.liggrid(self.feat_grid,
                                  self.point_coords,
                                  training=True)  # [b, n, 1]

      # Assumes point_values are in {-1, +1}; maps them to {0, 1} targets.
      self.labels_01 = (self.point_values+1) / 2  # turn labels to 0, 1 labels
      self.loss_pt = tf.losses.sigmoid_cross_entropy(
          self.labels_01,
          logits=self.preds,
          reduction=tf.losses.Reduction.NONE)
      # Total loss = mean per-point BCE + alpha_lat * mean latent-code norm.
      self.loss_lat = tf.reduce_mean(self.feat_norm) * self.alpha_lat
      self.loss = tf.reduce_mean(self.loss_pt) + self.loss_lat

      # Accuracy: fraction of the npts points whose predicted sign matches
      # the label sign. A zero on either side counts as a miss. In the
      # indep_pt_loss branch, index -1 selects the blended (interpolated)
      # prediction appended above.
      if self.indep_pt_loss:
        self.pvalue = tf.sign(self.point_values[:, :, -1, 0])
        self.ppred = tf.sign(self.preds[:, :, -1, 0])
      else:
        self.pvalue = tf.sign(self.point_values[..., 0])
        self.ppred = tf.sign(self.preds[:, :, 0])
      self.accu = tf.reduce_sum(tf.cast(
          tf.logical_or(tf.logical_and(self.pvalue > 0, self.ppred > 0),
                        tf.logical_and(self.pvalue < 0, self.ppred < 0)),
          tf.float32)) / float(self.npts)

      # Optimize only the latent codes; the decoder weights stay frozen.
      self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
      self.fgrid_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope='feat_sparse')
      # get_collection already returns a list, so pass it directly rather
      # than nesting it in another list.
      self.train_op = self.optimizer.minimize(
          self.loss,
          global_step=tf.train.get_or_create_global_step(),
          var_list=self.fgrid_vars)

      self.map_dict = self._get_var_mapping(model=self.liggrid,
                                            scope=self.var_prefix)
      self.sess = tf.Session()
      # Restore the mapped decoder variables from self.ckpt unless self.nows
      # is set. _initialize_uninitialized presumably initializes everything
      # not restored (latent codes, Adam slots, global step); confirm.
      if not self.nows:
        self.saver = tf.train.Saver(self.map_dict)
        self.saver.restore(self.sess, self.ckpt)
      self._initialize_uninitialized(self.sess)