def compute_vae_loss()

in src/model_def.py [0:0]


def compute_vae_loss(encoder_mean, encoder_lgvar, vae, x):
    """Compute the per-sample loss terms of a Variational Autoencoder.

    PARAMETERS
    ----------
    encoder_mean : callable
        Model part applied to ``x``; its output is used as the latent means.
    encoder_lgvar : callable
        Model part applied to ``x``; its output is used as the latent
        log-variances.
    vae : callable
        The full Variational Autoencoder; ``vae(x)`` is the reconstruction
        compared against ``x``.
    x : tensor
        Input minibatch. Assumed to be 4-D (batch axis followed by three
        data axes), because the reconstruction term is summed over axes
        1, 2 and 3 -- TODO confirm against the callers.

    RETURNS
    -------
    tuple
        ``(reconstruction_loss, reconstruction_loss + kl_loss)``.
        The first entry is the binary cross-entropy summed over axes
        1, 2 and 3; the second adds the KL term summed over axis 1 of the
        latent outputs. Neither is averaged over the minibatch here.
    """
    latent_mean = encoder_mean(x)
    latent_log_var = encoder_lgvar(x)
    reconstruction = vae(x)

    # Reconstruction term, -E[log p(x|z)]: element-wise binary cross-entropy
    # with x as the target, collapsed to one value per instance.
    per_element_bce = K.binary_crossentropy(x, reconstruction)
    reconstruction_loss = tf.reduce_sum(per_element_bce, axis=[1, 2, 3])

    # KL term: closed form of KL(N(mean, exp(log_var)) || N(0, I)),
    # i.e. 0.5 * sum(var + mean^2 - 1 - log_var) over the latent axis.
    variance = K.exp(latent_log_var)
    squared_mean = K.square(latent_mean)
    kl_per_unit = variance + squared_mean - 1. - latent_log_var
    kl_loss = 0.5 * K.sum(kl_per_unit, axis=1)

    total_loss = reconstruction_loss + kl_loss
    return reconstruction_loss, total_loss