in tcav/tcav.py [0:0]
def __init__(self,
             sess,
             target,
             concepts,
             bottlenecks,
             activation_generator,
             alphas,
             random_counterpart=None,
             cav_dir=None,
             num_random_exp=5,
             random_concepts=None):
    """Initialize the TCAV class.

    Args:
      sess: tensorflow session.
      target: one target class.
      concepts: A list of names of positive concept sets.
      bottlenecks: the name(s) of the bottleneck(s) of interest.
      activation_generator: an ActivationGeneratorInterface instance to return
        activations. Its ``get_model()`` result is stored as ``self.mymodel``
        and its ``model_name`` as ``self.model_to_run``.
      alphas: list of hyper parameters to run.
      random_counterpart: the random concept to run against the concepts for
        statistical testing. If supplied, only this set will be used as a
        positive set for calculating random TCAVs.
      cav_dir: the path to store CAVs.
      num_random_exp: number of random experiments to compare against.
        Replaced by ``len(random_concepts)`` when a non-empty
        ``random_concepts`` is given. An error is logged (not raised) if the
        resulting number is below 2.
      random_concepts: A list of names of random concepts for the random
        experiments to draw from. Optional, if not provided, the names will
        be random500_{i} for i in num_random_exp. Relative TCAV can be
        performed by passing in the same value for both concepts and
        random_concepts.
    """
    self.target = target
    self.concepts = concepts
    self.bottlenecks = bottlenecks
    self.activation_generator = activation_generator
    self.cav_dir = cav_dir
    self.alphas = alphas
    self.mymodel = activation_generator.get_model()
    self.model_to_run = self.mymodel.model_name
    self.sess = sess
    self.random_counterpart = random_counterpart
    # Relative TCAV mode is enabled when random_concepts contains exactly the
    # same set of names as concepts (order and duplicates are ignored).
    self.relative_tcav = (random_concepts is not None) and (
        set(concepts) == set(random_concepts))

    # An explicit list of random concepts determines how many random
    # experiments are run, so resolve the effective count BEFORE validating
    # it; otherwise the check would inspect a value that is about to be
    # discarded.
    if random_concepts:
        num_random_exp = len(random_concepts)
    if num_random_exp < 2:
        # Deliberately logged rather than raised, preserving the original
        # best-effort behavior for existing callers.
        tf.compat.v1.logging.error(
            'the number of random concepts has to be at least 2')

    # make pairs to test.
    self._process_what_to_run_expand(num_random_exp=num_random_exp,
                                     random_concepts=random_concepts)
    # parameters
    self.params = self.get_params()
    # Lazy %-style args: formatting is deferred to the logging backend.
    tf.compat.v1.logging.info('TCAV will %s params', len(self.params))