def train()

in tcav/cav.py [0:0]


  def train(self, acts):
    """Train the CAVs from the activations.

    Builds a training set from the activations of `self.concepts` at
    `self.bottleneck`, then fits a linear classifier chosen by
    `self.hparams['model_type']` on it. Afterwards it stores the accuracies
    returned by `self._train_lm` in `self.accuracies`, stores one direction
    per label in `self.cavs`, and calls `self._save_cavs()`.

    Args:
      acts: A dictionary of activations, keyed first by concept name and
        then by bottleneck name, i.e. of the form
            {'concept1': {'bottleneck name1': [...act array...],
                          'bottleneck name2': [...act array...],
                          ...},
             'concept2': {'bottleneck name1': [...act array...],
                          ...},
             ...}

    Raises:
      ValueError: if `self.hparams['model_type']` is neither 'linear' nor
        'logistic'.
    """

    tf.compat.v1.logging.info('training with alpha={}'.format(self.hparams['alpha']))
    x, labels, labels2text = CAV._create_cav_training_set(
        self.concepts, self.bottleneck, acts)

    model_type = self.hparams['model_type']
    if model_type == 'linear':
      lm = linear_model.SGDClassifier(
          alpha=self.hparams['alpha'],
          max_iter=self.hparams['max_iter'],
          tol=self.hparams['tol'])
    elif model_type == 'logistic':
      # NOTE(review): alpha/max_iter/tol from hparams are not forwarded here,
      # so sklearn's LogisticRegression defaults apply -- confirm intended.
      lm = linear_model.LogisticRegression()
    else:
      raise ValueError('Invalid hparams.model_type: {}'.format(model_type))

    self.accuracies = self._train_lm(lm, x, labels, labels2text)
    if len(lm.coef_) == 1:
      # Binary case: sklearn returns a single coefficient row, oriented toward
      # the positive class (label 1). The concept is assigned to label 0 by
      # default, so its CAV is the negated row; label 1 keeps the row as-is.
      self.cavs = [-1 * lm.coef_[0], lm.coef_[0]]
    else:
      # Multiclass case: coef_ already holds one row per label.
      self.cavs = list(lm.coef_)
    self._save_cavs()