in torchtext/vocab/vectors.py [0:0]
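
# Module-level dependencies assumed by this method (a sketch; in the actual
# file these live at the top of vectors.py, and the helpers `reporthook` and
# `_infer_shape` come from torchtext.utils and this module respectively):
import gzip
import logging
import os
import tarfile
import zipfile
from urllib.request import urlretrieve

import torch
from tqdm import tqdm

logger = logging.getLogger(__name__)
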
def cache(self, name, cache, url=None, max_vectors=None):
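    # Note: this monkey-patch disables HTTPS certificate verification for the
    # entire process; it works around hosts with certificate problems at the
    # cost of weakening security globally.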
    import ssl
    ssl._create_default_https_context = ssl._create_unverified_context
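
    # Compute both the raw-vector path and the serialized cache path; the .pt
    # cache key embeds max_vectors, so truncated and full loads of the same
    # file are cached separately.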
    if os.path.isfile(name):
        path = name
        if max_vectors:
            file_suffix = '_{}.pt'.format(max_vectors)
        else:
            file_suffix = '.pt'
        path_pt = os.path.join(cache, os.path.basename(name)) + file_suffix
    else:
        path = os.path.join(cache, name)
        if max_vectors:
            file_suffix = '_{}.pt'.format(max_vectors)
        else:
            file_suffix = '.pt'
        path_pt = path + file_suffix
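
    # Cache miss: download the raw vectors if needed, extract them, parse the
    # text format line by line, then serialize the result to path_pt.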
    if not os.path.isfile(path_pt):
        if not os.path.isfile(path) and url:
            logger.info('Downloading vectors from {}'.format(url))
            if not os.path.exists(cache):
                os.makedirs(cache)
            dest = os.path.join(cache, os.path.basename(url))
            if not os.path.isfile(dest):
                with tqdm(unit='B', unit_scale=True, miniters=1, desc=dest) as t:
                    try:
                        urlretrieve(url, dest, reporthook=reporthook(t))
                    except KeyboardInterrupt as e:  # remove the partial zip file
                        os.remove(dest)
                        raise e
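
            # Only .zip and .tar.gz archives are unpacked here; a bare .gz
            # that is not a tar archive stays compressed and is handled by
            # gzip.open when the vectors are read below.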
            logger.info('Extracting vectors into {}'.format(cache))
            ext = os.path.splitext(dest)[1][1:]
            if ext == 'zip':
                with zipfile.ZipFile(dest, "r") as zf:
                    zf.extractall(cache)
            elif ext == 'gz':
                if dest.endswith('.tar.gz'):
                    with tarfile.open(dest, 'r:gz') as tar:
                        tar.extractall(path=cache)

        if not os.path.isfile(path):
            raise RuntimeError('no vectors found at {}'.format(path))
logger.info("Loading vectors from {}".format(path))
ext = os.path.splitext(path)[1][1:]
if ext == 'gz':
open_file = gzip.open
else:
open_file = open
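
        # _infer_shape is a private helper defined alongside this method; it
        # scans the open file to count lines and guess the vector width, and
        # (since the loop below re-reads f from the top) it must leave the
        # file rewound to the beginning.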
        vectors_loaded = 0
        with open_file(path, 'rb') as f:
            num_lines, dim = _infer_shape(f)
            if not max_vectors or max_vectors > num_lines:
                max_vectors = num_lines

            # dim is deliberately reset to None so the loop below re-infers
            # it from the first real vector line; a one-entry header line is
            # skipped rather than mistaken for a dimension.
            itos, vectors, dim = [], torch.zeros((max_vectors, dim)), None

            for line in tqdm(f, total=max_vectors):
                # Explicitly splitting on " " is important, so we don't
                # get rid of Unicode non-breaking spaces in the vectors.
                entries = line.rstrip().split(b" ")
                word, entries = entries[0], entries[1:]
                if dim is None and len(entries) > 1:
                    dim = len(entries)
                elif len(entries) == 1:
                    logger.warning("Skipping token {} with 1-dimensional "
                                   "vector {}; likely a header".format(word, entries))
                    continue
                elif dim != len(entries):
                    raise RuntimeError(
                        "Vector for token {} has {} dimensions, but previously "
                        "read vectors have {} dimensions. All vectors must have "
                        "the same number of dimensions.".format(word, len(entries), dim))

                try:
                    if isinstance(word, bytes):
                        word = word.decode('utf-8')
                except UnicodeDecodeError:
                    logger.info("Skipping non-UTF8 token {}".format(repr(word)))
                    continue

                vectors[vectors_loaded] = torch.tensor([float(x) for x in entries])
                vectors_loaded += 1
                itos.append(word)

                if vectors_loaded == max_vectors:
                    break
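
        # Expose the parsed vocabulary and embedding matrix, then serialize
        # them so later runs can skip the slow text parse entirely.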
        self.itos = itos
        self.stoi = {word: i for i, word in enumerate(itos)}
        # Slice to vectors_loaded: skipped header / non-UTF-8 lines leave
        # unused zero rows at the end of the preallocated tensor, and keeping
        # them would misalign vector rows with itos/stoi.
        self.vectors = vectors[:vectors_loaded].view(-1, dim)
        self.dim = dim
        logger.info('Saving vectors to {}'.format(path_pt))
        if not os.path.exists(cache):
            os.makedirs(cache)
        torch.save((self.itos, self.stoi, self.vectors, self.dim), path_pt)
    else:
        # Cache hit: deserialize the tuple written above.
        logger.info('Loading vectors from {}'.format(path_pt))
        self.itos, self.stoi, self.vectors, self.dim = torch.load(path_pt)
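
Typical entry point (a hedged sketch, not part of this file: Vectors.__init__
forwards name, cache, url, and max_vectors straight to cache(), and the
fastText URL below is only illustrative):

from torchtext.vocab import Vectors

# First run: downloads wiki.simple.vec into .vector_cache/, parses it, and
# writes .vector_cache/wiki.simple.vec_50000.pt; later runs hit the .pt file.
vecs = Vectors(
    name='wiki.simple.vec',
    cache='.vector_cache',
    url='https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.simple.vec',
    max_vectors=50000,
)
print(vecs.dim, len(vecs.itos))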