def get_or_train_cav()

in tcav/cav.py [0:0]


def get_or_train_cav(concepts,
                     bottleneck,
                     acts,
                     cav_dir=None,
                     cav_hparams=None,
                     overwrite=False):
  """Returns the CAV for `concepts` at `bottleneck`, training it if needed.

  The activations in `acts` must already have been computed; this function
  only consumes them.

  When `cav_dir` is given, the CAV file path is derived from
  `CAV.cav_key(...)` (with '/' replaced by '.') plus a '.pkl' suffix. If that
  file exists and `overwrite` is False, the stored CAV is loaded and returned.
  In every other case a new CAV is constructed (with that path, or None when
  no `cav_dir` is given) and trained on the concepts' activations.

  Args:
    concepts: ordered collection of concept names used for the CAV. Order
      matters: with two concepts, give the positive concept first and the
      negative concept second (e.g., ['striped', 'random500_1']).
    bottleneck: the bottleneck the CAV is computed for.
    acts: dictionary of activations indexed as acts[concept][bottleneck];
      it must contain an entry for every name in `concepts`.
    cav_dir: optional directory holding cached CAV files; passed to
      utils.make_dir_if_not_exists before use. When None, nothing is read
      from disk.
    cav_hparams: hyperparameters used to learn the CAV. This function reads
      'alpha' always and 'model_type' when `cav_dir` is set. Defaults to
      CAV.default_hparams().
    overwrite: if True, retrain even when a cached CAV file already exists.

  Returns:
    A CAV instance, either loaded from `cav_dir` or freshly trained.
  """
  hparams = cav_hparams if cav_hparams is not None else CAV.default_hparams()

  cav_path = None
  if cav_dir is not None:
    utils.make_dir_if_not_exists(cav_dir)
    key = CAV.cav_key(concepts, bottleneck, hparams['model_type'],
                      hparams['alpha'])
    # A '/' inside the key would make os.path.join point into a
    # subdirectory; flatten it so the key names a single file in cav_dir.
    cav_path = os.path.join(cav_dir, key.replace('/', '.') + '.pkl')

    # Short-circuits: the file system is only queried when overwrite is off.
    reuse_cached = (not overwrite) and tf.io.gfile.exists(cav_path)
    if reuse_cached:
      tf.compat.v1.logging.info('CAV already exists: {}'.format(cav_path))
      cached = CAV.load_cav(cav_path)
      tf.compat.v1.logging.info('CAV accuracies: {}'.format(cached.accuracies))
      return cached

  tf.compat.v1.logging.info('Training CAV {} - {} alpha {}'.format(
      concepts, bottleneck, hparams['alpha']))
  trained = CAV(concepts, bottleneck, hparams, cav_path)
  # Restrict the activations to just the concepts this CAV separates.
  concept_acts = {concept: acts[concept] for concept in concepts}
  trained.train(concept_acts)
  tf.compat.v1.logging.info('CAV accuracies: {}'.format(trained.accuracies))
  return trained