in tcav/cav.py [0:0]
def get_or_train_cav(concepts,
                     bottleneck,
                     acts,
                     cav_dir=None,
                     cav_hparams=None,
                     overwrite=False):
    """Return the requested CAV, loading it from disk or training it.

    When `cav_dir` is given and a saved CAV for this
    (concepts, bottleneck, model_type, alpha) combination already exists
    there, it is loaded and returned unless `overwrite` is True. In every
    other case a new CAV is built and trained on the supplied activations,
    which are expected to be present already.

    Args:
      concepts: concepts the CAV is built from. With two concepts, list the
        positive concept first and the negative one second
        (e.g., ['striped', 'random500_1']).
      bottleneck: name of the bottleneck the CAV is computed for.
      acts: dictionary of activations per concept and bottleneck,
        e.g., acts[concept][bottleneck]. Only the entries for `concepts`
        are handed to training.
      cav_dir: directory holding saved CAVs; it is created if missing.
        When None, no saved CAV is looked up and the CAV is constructed
        with a path of None.
      cav_hparams: hyperparameters used to learn the CAV; must provide the
        'model_type' and 'alpha' entries. Defaults to
        CAV.default_hparams() when None.
      overwrite: if True, ignore any saved CAV file and train again.

    Returns:
      A CAV instance, either loaded from `cav_dir` or freshly trained.
    """
    hparams = cav_hparams if cav_hparams is not None else CAV.default_hparams()

    save_path = None
    if cav_dir is not None:
        utils.make_dir_if_not_exists(cav_dir)
        key = CAV.cav_key(concepts, bottleneck, hparams['model_type'],
                          hparams['alpha'])
        # The key may contain '/', which would otherwise be read as a
        # path separator in the file name.
        file_name = key.replace('/', '.') + '.pkl'
        save_path = os.path.join(cav_dir, file_name)

        reuse_saved = (not overwrite) and tf.io.gfile.exists(save_path)
        if reuse_saved:
            tf.compat.v1.logging.info(
                'CAV already exists: {}'.format(save_path))
            loaded = CAV.load_cav(save_path)
            tf.compat.v1.logging.info(
                'CAV accuracies: {}'.format(loaded.accuracies))
            return loaded

    # Nothing reusable on disk (or caching disabled): train from scratch.
    tf.compat.v1.logging.info('Training CAV {} - {} alpha {}'.format(
        concepts, bottleneck, hparams['alpha']))
    trained = CAV(concepts, bottleneck, hparams, save_path)

    # Restrict the activations to the concepts this CAV is built from.
    concept_acts = {}
    for concept in concepts:
        concept_acts[concept] = acts[concept]
    trained.train(concept_acts)

    tf.compat.v1.logging.info(
        'CAV accuracies: {}'.format(trained.accuracies))
    return trained