def inversion()

in utilities/gan-inversion/02_invert_images.py [0:0]


        def inversion(image_list):
            """Invert one batch of images into the generator's latent/label inputs.

            Uses the TF session, graph nodes and settings captured from the
            enclosing scope:
              1. Re-initialize the latent variable, the label-logit variable and
                 the optimizer's variables.
              2. Seed the label logits with ``deep_target_logits_node`` evaluated
                 on the target images.
              3. Random search: draw ``random_search_steps`` standard-normal
                 latent batches and keep the one with the lowest ``total_loss``.
              4. Run ``gradient_descent_steps`` updates of the ``train`` op,
                 printing the loss terms every 10 steps.
              5. Save the images produced by ``output_node`` to
                 ``inverted_img_dir`` and the final latents plus ``label_dist``
                 values as JSON to ``inverted_seed_dir``, each file named after
                 the source image's file stem.

            Args:
                image_list: paths to exactly ``batch_size`` image files.

            Raises:
                ValueError: if ``len(image_list) != batch_size``.
            """
            # Explicit check instead of ``assert`` so it also runs under ``python -O``;
            # the latent batches drawn below always have ``batch_size`` rows.
            if len(image_list) != batch_size:
                raise ValueError(
                    "expected {} images, got {}".format(batch_size, len(image_list)))

            # Reset per-batch state. Running the variables' existing ``initializer``
            # ops (rather than building tf.variables_initializer(...) on every call)
            # keeps the graph from growing with each batch.
            sess.run([latent_node_variable.initializer, label_node_logits.initializer] +
                     [v.initializer for v in opt.variables()])

            def _load(image_file):
                # Close the file handle; convert() loads the pixels into memory.
                with PIL.Image.open(image_file) as img:
                    return toNetworkSpace(img.convert("RGB"))

            target_imgs = np.array([_load(image_file) for image_file in image_list])
            feed = {target_node: target_imgs, truncation_node: 1.0}

            # Start the label logits from the network's prediction on the targets.
            # Variable.load() feeds the existing initializer instead of adding a new
            # tf.assign op (with the array baked in as a constant) to the graph.
            [initial_logits] = sess.run([deep_target_logits_node], {target_node: target_imgs})
            label_node_logits.load(initial_logits, sess)

            # Random search for a good starting point for gradient descent.
            best_loss = np.inf
            best_latents = None
            for i in range(random_search_steps):
                latents = np.random.randn(batch_size, latent_size)
                latent_node_variable.load(latents, sess)
                [loss] = sess.run([total_loss], feed)
                if loss < best_loss:
                    print(i, loss)
                    best_latents = latents
                    best_loss = loss
            if best_latents is not None:
                latent_node_variable.load(best_latents, sess)
            else:
                # random_search_steps == 0, or every sampled loss was NaN.
                print("Random search produced no candidate; keeping current latents")

            print("Starting gradient descent")
            loss_nodes = [total_loss, reg_loss_node, pixel_loss_node, deep_loss_node]
            for i in range(gradient_descent_steps):
                if i % 10 == 0:
                    # Fetch the loss terms in the same run as the update, so a
                    # logging iteration applies exactly one optimizer step.
                    [_, total_loss_val, reg_loss_val, pixel_loss_val, deep_loss_val] = sess.run(
                        [train] + loss_nodes, feed)
                    print(i, total_loss_val, pixel_loss_val, deep_loss_val, reg_loss_val)
                else:
                    sess.run(train, feed)

            # Write out the final images / seeds, named after each source file's stem.
            [output_images, latents, labels] = sess.run(
                [output_node, latent_node_variable, label_dist], feed)
            for image_file, image, latent, label in zip(image_list, output_images, latents, labels):
                # splitext strips only the final extension ("a.b.png" -> "a.b"), so
                # inputs that share a prefix before their first dot do not collide.
                stem = os.path.splitext(os.path.basename(image_file))[0]
                PIL.Image.fromarray(fromNetworkSpace(image), 'RGB').save(
                    os.path.join(inverted_img_dir, "{}.png".format(stem)))

                latent_file = os.path.join(inverted_seed_dir, "{}.json".format(stem))
                with open(latent_file, "w") as f:
                    json.dump({"latents": latent.tolist(), "labels": label.tolist()}, f)

            print("batch done")