def __init__()

in tcav/tcav.py [0:0]


  def __init__(self,
               sess,
               target,
               concepts,
               bottlenecks,
               activation_generator,
               alphas,
               random_counterpart=None,
               cav_dir=None,
               num_random_exp=5,
               random_concepts=None):
    """Initialize tcav class.

    Stores the run configuration, builds the list of (target, concept) pairs
    to test via `_process_what_to_run_expand`, and computes the run params via
    `get_params`.

    Args:
      sess: tensorflow session.
      target: one target class
      concepts: A list of names of positive concept sets.
      bottlenecks: the name of a bottleneck of interest.
      activation_generator: an ActivationGeneratorInterface instance to return
                            activations. Its `get_model()` result must expose a
                            `model_name` attribute.
      alphas: list of hyper parameters to run
      random_counterpart: the random concept to run against the concepts for
                  statistical testing. If supplied, only this set will be
                  used as a positive set for calculating random TCAVs
      cav_dir: the path to store CAVs
      num_random_exp: number of random experiments to compare against. When
                      `random_concepts` is a non-empty list this value is
                      replaced by `len(random_concepts)`. An effective value
                      below 2 is logged as an error (construction continues).
      random_concepts: A list of names of random concepts for the random
                       experiments to draw from. Optional, if not provided, the
                       names will be random500_{i} for i in num_random_exp.
                       Relative TCAV can be performed by passing in the same
                       value for both concepts and random_concepts.
    """
    self.target = target
    self.concepts = concepts
    self.bottlenecks = bottlenecks
    self.activation_generator = activation_generator
    self.cav_dir = cav_dir
    self.alphas = alphas
    self.mymodel = activation_generator.get_model()
    self.model_to_run = self.mymodel.model_name
    self.sess = sess
    self.random_counterpart = random_counterpart
    # Relative TCAV: the random pool is exactly the set of positive concepts.
    self.relative_tcav = (random_concepts is not None) and (
        set(concepts) == set(random_concepts))

    # An explicit list of random concepts determines how many random
    # experiments are run, overriding num_random_exp.
    if random_concepts:
      num_random_exp = len(random_concepts)
    # Validate the *effective* count (after the override above); checking
    # before the override would miss a too-short random_concepts list.
    if num_random_exp < 2:
      tf.compat.v1.logging.error(
          'the number of random concepts has to be at least 2')

    # make pairs to test.
    self._process_what_to_run_expand(num_random_exp=num_random_exp,
                                     random_concepts=random_concepts)
    # parameters
    self.params = self.get_params()
    # Lazy %-args: formatting is deferred to the logging framework.
    tf.compat.v1.logging.info('TCAV will run %s params', len(self.params))