def validate()

in tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/seq2seq_mia.py [0:0]


  def validate(self):
    """Validates the inputs.

    Checks that train/test logits and labels are all provided, that
    `vocab_size`, `train_size` and `test_size` are set and integral, and then
    runs `_is_iterator` on each of the four logits/labels fields.

    Raises:
      ValueError: if only one side of a train/test pair is set, if logits or
        labels are missing, or if `vocab_size`, `train_size` or `test_size` is
        unset or not an integer.
    """
    # numbers.Integral accepts built-in ints as well as NumPy integer scalars,
    # so callers passing e.g. np.int64 sizes are not rejected.
    import numbers  # pylint: disable=g-import-not-at-top

    if (self.logits_train is None) != (self.logits_test is None):
      raise ValueError(
          'logits_train and logits_test should both be either set or unset')

    if (self.labels_train is None) != (self.labels_test is None):
      raise ValueError(
          'labels_train and labels_test should both be either set or unset')

    # Combined with the pairwise checks above, this ensures all four of
    # logits/labels for train/test are set.
    if self.logits_train is None or self.labels_train is None:
      raise ValueError(
          'Labels, logits of training, test sets should all be set')

    if (self.vocab_size is None or self.train_size is None or
        self.test_size is None):
      raise ValueError('vocab_size, train_size, test_size should all be set')

    # All three sizes are non-None at this point; verify their actual type
    # (a bare `not int` would always be False and never reject anything).
    for field_name in ('vocab_size', 'train_size', 'test_size'):
      if not isinstance(getattr(self, field_name), numbers.Integral):
        raise ValueError(f'{field_name} should be of integer type')

    _is_iterator(self.logits_train, 'logits_train')
    _is_iterator(self.logits_test, 'logits_test')
    _is_iterator(self.labels_train, 'labels_train')
    _is_iterator(self.labels_test, 'labels_test')