in ignite/metrics/gan/fid.py [0:0]
def update(self, output: Sequence[torch.Tensor]) -> None:
    """Fold one ``(train, test)`` batch into the metric's running state.

    Both elements of ``output`` are first passed through
    ``self._extract_features``. Every row of the resulting feature batches
    is then handed, one at a time, to ``self._online_update`` together with
    the matching accumulator pair (``_train_total``/``_train_sigma`` or
    ``_test_total``/``_test_sigma``).

    Args:
        output: two-element sequence unpacked as ``(train, test)``.

    Raises:
        ValueError: if the extracted train and test features differ in
            their first or second dimension.
    """
    train_batch, test_batch = output
    train_feats = self._extract_features(train_batch)
    test_feats = self._extract_features(test_batch)

    # Both feature batches must agree on dimension 0 and dimension 1.
    train_shape = train_feats.shape
    test_shape = test_feats.shape
    shapes_agree = train_shape[0] == test_shape[0] and train_shape[1] == test_shape[1]
    if not shapes_agree:
        raise ValueError(
            f"""
Number of Training Features and Testing Features should be equal ({train_shape} != {test_shape})
"""
        )

    # Accumulate the train-side statistics (mean / covariance terms, as
    # maintained by ``_online_update``) one feature row at a time.
    for train_row in train_feats:
        self._online_update(train_row, self._train_total, self._train_sigma)

    # Same accumulation for the test-side statistics.
    for test_row in test_feats:
        self._online_update(test_row, self._test_total, self._test_sigma)

    # The example counter advances by the number of train feature rows.
    self._num_examples += train_feats.shape[0]