torchbenchmark/util/env_check.py (35 lines of code) (raw):
"""
PyTorch benchmark env check utils.
This file may be loaded without torch packages installed, e.g., in OnDemand CI.
"""
import importlib
from typing import List, Dict, Tuple
MAIN_RANDOM_SEED = 1337
def set_random_seed():
import torch
import random
import numpy
torch.manual_seed(MAIN_RANDOM_SEED)
random.seed(MAIN_RANDOM_SEED)
numpy.random.seed(MAIN_RANDOM_SEED)
def get_pkg_versions(packages: List[str]) -> Dict[str, str]:
versions = {}
for module in packages:
module = importlib.import_module(module)
versions[module] = module.__version__
return versions
def has_native_amp() -> bool:
import torch
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
return True
except AttributeError:
pass
return False
def correctness_check(eager_output: Tuple['torch.Tensor'], output: Tuple['torch.Tensor']) -> float:
import torch
# sanity checks
assert len(eager_output) == len(output), "Correctness check requires two inputs have the same length"
result = 1.0
for i in range(len(eager_output)):
t1 = eager_output[i]
t2 = output[i]
cos = torch.nn.CosineSimilarity(dim=0, eps=1e-4)
# need to call float() because fp16 tensor may overflow when calculating cosine similarity
result *= cos(t1.flatten().float(), t2.flatten().float())
assert list(result.size())==[], "The result of cosine similarity must be a scalar."
return float(result)