def neural_voxel_renderer_plus()

in tensorflow_graphics/projects/neural_voxel_renderer/models.py [0:0]


def neural_voxel_renderer_plus(voxels,
                               rerendering,
                               light_pos,
                               size=4,
                               norm2d='batchnorm',
                               norm3d='batchnorm'):
  """Builds the Neural Voxel Renderer+ keras model.

  The network has six stages, each under its own name scope:

  * VoxelProcessing: 3-D convolutions and residual blocks encode the voxels;
    the result is reshaped to 32x32x(32*32) and projected to a 2-D feature map
    with a 1x1 convolution.
  * ProjectionProcessing: 2-D residual blocks refine the projected features.
  * LightProcessing: two dense layers embed `light_pos` into a 64-d code that
    is replicated over the 32x32 grid.
  * Merger: projected features and light code are concatenated and refined.
  * Decoder: three transposed-convolution stages upsample by a factor of 8.
  * ImageProcessingNetwork / NeuralRerenderingNetwork: features extracted from
    `rerendering` are added to the decoded features and passed through
    `unet_3x_with_res_in_mid` to produce the final 3-channel image.

  The shape comments in the body assume a 128x128x128 voxel grid and a 256x256
  re-rendering (the hard-coded 32x32 reshapes require the former) -- TODO
  confirm against the callers.

  Args:
    voxels: Tensor holding the voxel grid; wrapped in a keras `Input`.
    rerendering: Tensor holding the re-rendered image; wrapped in a keras
      `Input`.
    light_pos: Tensor holding the light position; wrapped in a keras `Input`.
      Presumably of shape [batch, k] so that the tiled code can be reshaped to
      32x32x64 -- verify against the caller.
    size: Kernel size forwarded to the convolution blocks and to the two plain
      `Conv2D` layers.
    norm2d: Normalization option forwarded to every 2-D block.
    norm3d: Normalization option forwarded to every 3-D block.

  Returns:
    A `tf.keras.Model` taking `[voxels, rerendering, light_pos]` and producing
    the predicted image.
  """
  # Every residual stack in this network chains the same number of blocks.
  num_residual_blocks = 5

  with tf.name_scope('Network/'):

    voxels = layers.Input(tensor=voxels)
    rerendering = layers.Input(tensor=rerendering)
    light_pos = layers.Input(tensor=light_pos)

    nf_2d = 512

    with tf.name_scope('VoxelProcessing'):
      vol0_a = layer_utils.conv_block_3d(voxels,
                                         nfilters=16,
                                         size=size,
                                         strides=2,
                                         normalization=norm3d)  # 64x64x64x16
      vol0_b = layer_utils.conv_block_3d(vol0_a,
                                         nfilters=16,
                                         size=size,
                                         strides=1,
                                         normalization=norm3d)  # 64x64x64x16
      vol1_a = layer_utils.conv_block_3d(vol0_b,
                                         nfilters=16,
                                         size=size,
                                         strides=2,
                                         normalization=norm3d)  # 32x32x32x16
      vol1_b = layer_utils.conv_block_3d(vol1_a,
                                         nfilters=32,
                                         size=size,
                                         strides=1,
                                         normalization=norm3d)  # 32x32x32x32
      vol1_c = layer_utils.conv_block_3d(vol1_b,
                                         nfilters=32,
                                         size=size,
                                         strides=1,
                                         normalization=norm3d)  # 32x32x32x32
      shortcut = vol1_c
      vol = vol1_c
      for _ in range(num_residual_blocks):
        vol = layer_utils.residual_block_3d(vol,
                                            32,
                                            strides=(1, 1, 1),
                                            normalization=norm3d)  # 32x
      encoded_vol = layers.add([shortcut, vol])
      # Fold the last spatial axis into the channels so that the volume can be
      # processed as a 2-D feature map.
      encoded_vol = layers.Reshape([32, 32, 32*32])(encoded_vol)
      encoded_vol = layers.Conv2D(nf_2d,
                                  kernel_size=1,
                                  strides=(1, 1),
                                  padding='same',
                                  kernel_initializer=initializer)(encoded_vol)
      latent_projection = layers.LeakyReLU()(encoded_vol)  # 32x32x512

    with tf.name_scope('ProjectionProcessing'):
      shortcut = latent_projection  # 32x32xnf_2d
      proj = latent_projection
      for _ in range(num_residual_blocks):
        proj = layer_utils.residual_block_2d(
            proj,
            nfilters=nf_2d,
            strides=(1, 1),
            normalization=norm2d)  # 32x32xnf_2d
      encoded_proj = layers.add([shortcut, proj])  # 32x32xnf_2d

    with tf.name_scope('LightProcessing'):
      fc_light = layers.Dense(64, kernel_initializer=initializer)(light_pos)
      light_code = layers.Dense(64, kernel_initializer=initializer)(fc_light)
      # Repeat the 64-d code 32*32 times along the feature axis so that the
      # reshape below places one copy of it at every spatial location.
      tile_light = layers.Lambda(lambda v: tf.tile(v[0], [1, 32*32]))
      light_code = tile_light([light_code])
      light_code = layers.Reshape((32, 32, 64))(light_code)  # 32x32x64

    with tf.name_scope('Merger'):
      latent_code_final = layers.concatenate([encoded_proj, light_code])
      # This is a 2-D block, so it takes the 2-D normalization option like
      # every other 2-D block here (it previously received `norm3d`; the two
      # coincide under the default arguments).
      latent_code_final = layer_utils.conv_block_2d(latent_code_final,
                                                    nfilters=nf_2d,
                                                    size=size,
                                                    strides=1,
                                                    normalization=norm2d)
      shortcut = latent_code_final
      merged = latent_code_final
      for _ in range(num_residual_blocks):
        merged = layer_utils.residual_block_2d(
            merged,
            nfilters=nf_2d,
            strides=(1, 1),
            normalization=norm2d)  # 32x32xnf_2d

      latent_code_final2 = layers.add([shortcut, merged])  # 32x32xnf_2d

    with tf.name_scope('Decoder'):
      decoded = latent_code_final2
      # Each stage doubles the resolution with a transposed convolution and
      # then refines it: 64x64x128 -> 128x128x64 -> 256x256x32.
      for nfilters in (128, 64, 32):
        decoded = layer_utils.conv_t_block_2d(decoded,
                                              nfilters=nfilters,
                                              size=size,
                                              strides=2,
                                              normalization=norm2d)
        decoded = layer_utils.conv_block_2d(decoded,
                                            nfilters=nfilters,
                                            size=size,
                                            strides=1,
                                            normalization=norm2d)
      # Despite its name this is a 32-channel feature map, not an RGB image; it
      # is merged with the re-rendering features below.
      rendered_image = layers.Conv2D(32,
                                     size,
                                     strides=1,
                                     padding='same',
                                     kernel_initializer=initializer,
                                     use_bias=False)(decoded)  # 256x256x32

    with tf.name_scope('ImageProcessingNetwork'):
      ec1 = layer_utils.conv_block_2d(rerendering,
                                      nfilters=32,
                                      size=size,
                                      strides=1,
                                      normalization=norm2d)  # 256x256x32
      ec2 = layer_utils.conv_block_2d(ec1,
                                      nfilters=32,
                                      size=size,
                                      strides=1,
                                      normalization=norm2d)  # 256x256x32

    with tf.name_scope('NeuralRerenderingNetwork'):
      latent_img = layers.add([rendered_image, ec2])
      target_code = unet_3x_with_res_in_mid(latent_img, 32, norm2d=norm2d)
      out0 = layer_utils.conv_block_2d(target_code,
                                       nfilters=32,
                                       size=size,
                                       strides=1,
                                       normalization=norm2d)  # 256x
      predicted_image = layers.Conv2D(3,
                                      size,
                                      strides=1,
                                      padding='same',
                                      kernel_initializer=initializer,
                                      use_bias=False)(out0)  # 256x256x3

    return tf.keras.Model(inputs=[voxels, rerendering, light_pos],
                          outputs=[predicted_image])