in tt_embeddings_ops.py [0:0]
def reset_parameters(self, weight_dist: str) -> None:  # noqa C901
    """Re-initialize the tensor-train (TT) cores stored in ``self.tt_cores``.

    Args:
        weight_dist: Initialization scheme, one of ``"uniform"``,
            ``"naive-uniform"``, ``"normal"``, ``"approx-uniform"`` or
            ``"approx-normal"``.

    Attributes read from ``self``: ``tt_cores``, ``tt_ndim`` and
    ``num_embeddings`` for every scheme; additionally ``embedding_dim`` and
    ``tt_ranks`` for "uniform", ``tt_ranks`` for "normal", and ``tt_ranks``,
    ``tt_p_shapes``, ``tt_q_shapes`` and ``num_tables`` for "approx-uniform".

    "uniform", "naive-uniform" and "normal" fill the existing core tensors in
    place via ``torch.nn.init``. The "approx-*" schemes build the values in
    NumPy and replace each core's ``.data`` with a new float32 tensor placed
    on the same device as that core. The "approx-*" schemes draw from NumPy's
    global RNG ("approx-uniform" also from Python's ``random`` module), so
    ``torch.manual_seed`` alone does not make them reproducible.

    Raises:
        AssertionError: If ``weight_dist`` is not a supported scheme, or if
            "approx-uniform" is requested with ``tt_ndim != 3`` or
            ``num_tables != 1``.
    """
    assert weight_dist in [
        "uniform",
        "naive-uniform",
        "normal",
        "approx-uniform",
        "approx-normal",
    ], f"Unsupported weight_dist: {weight_dist!r}"
    if weight_dist == "uniform":
        # lamb is the Glorot/Xavier variance 2 / (fan_in + fan_out). The
        # per-core upper bound takes the tt_ndim-th root of its square root
        # and divides by the TT ranks, which are summed over when the cores
        # are contracted.
        lamb = 2.0 / (self.num_embeddings + self.embedding_dim)
        stddev = np.sqrt(lamb)
        tt_ranks = np.array(self.tt_ranks)
        cr_exponent = -1.0 / (2 * self.tt_ndim)
        rank_factor = np.prod(tt_ranks ** cr_exponent)
        core_stddev = stddev ** (1.0 / self.tt_ndim) * rank_factor
        for i in range(self.tt_ndim):
            # Note: samples from U(0, core_stddev), i.e. non-negative values.
            torch.nn.init.uniform_(self.tt_cores[i], 0.0, core_stddev)
    elif weight_dist == "naive-uniform":
        for i in range(self.tt_ndim):
            torch.nn.init.uniform_(
                self.tt_cores[i], 0.0, 1 / np.sqrt(self.num_embeddings)
            )
    elif weight_dist == "normal":
        mu = 0.0
        sigma = 1.0 / np.sqrt(self.num_embeddings)
        scale = 1.0 / self.tt_ranks[0]
        for i in range(self.tt_ndim):
            torch.nn.init.normal_(self.tt_cores[i], mu, sigma)
            self.tt_cores[i].data *= scale
    elif weight_dist == "approx-normal":
        mu = 0.0
        sigma = 1.0
        scale = np.power(1 / np.sqrt(3 * self.num_embeddings), 1 / 3)
        for i in range(self.tt_ndim):
            core = self.tt_cores[i]
            core_shape = tuple(core.shape)
            # Rejection-sample a normal truncated to [mu - 2*sigma,
            # mu + 2*sigma]: redraw every out-of-range entry until none remain.
            W = np.random.normal(loc=mu, scale=sigma, size=core_shape)
            out_of_range = np.abs(W - mu) > 2 * sigma
            while out_of_range.any():
                W[out_of_range] = np.random.normal(
                    loc=mu, scale=sigma, size=int(out_of_range.sum())
                )
                out_of_range = np.abs(W - mu) > 2 * sigma
            W = (W * scale).astype(np.float32)
            # Build the replacement on the core's own device; torch.tensor
            # would otherwise place it on the default (CPU) device.
            core.data = torch.tensor(W, device=core.device, requires_grad=True)
    elif weight_dist == "approx-uniform":

        def _flat_saw_tooth(
            nb_gridpts: int, width: float, nb_samples: int = 1
        ) -> np.ndarray:
            """Sample ``nb_samples`` values from a "flat saw tooth" distribution.

            The density is a train of flat teeth
            ``j * delta + U(-width/2, width/2)`` for integer ``j`` in
            ``[-(N - 1), N - 1]`` with ``delta = 1 / N``. In general
            ``width < delta / 2``, so the teeth are separated by gaps. The
            idea is that when this density is convolved with a very narrow
            Gaussian-like distribution, the gaps fill up and the result looks
            like a uniform distribution.
            """
            N = nb_gridpts
            delta = 1.0 / N
            j = np.random.randint(-(N - 1), N, nb_samples)
            x = -width / 2.0 + width * np.random.rand(nb_samples)
            return j * delta + x

        def _gen_block(
            dist: str, dim: List[int], center: float, param: float
        ) -> np.ndarray:
            """Return an array of shape ``dim`` with i.i.d. entries around ``center``.

            ``"gaussian"`` uses standard deviation ``param``; ``"uniform"``
            uses an interval of width ``param`` centred on ``center``.

            Raises:
                ValueError: For any other ``dist``.
            """
            nb_samples = np.array(dim).prod()
            if dist == "gaussian":
                B = center + np.random.randn(nb_samples) * param
            elif dist == "uniform":
                B = center - (param / 2.0) + param * np.random.rand(nb_samples)
            else:
                raise ValueError(f"Does not support {dist} distribution")
            return B.reshape(dim)

        def _gen_head(dim: List[int], sigma: float = 0.01) -> np.ndarray:
            """Head core; expects ``dim = (1, m1, n1, r1)``, r1 = TT rank.

            Entries are drawn from N(1/sqrt(r1), sigma^2).
            """
            scale = 1.0 / np.sqrt(dim[-1])
            return _gen_block("gaussian", dim, scale, sigma)

        def _gen_tail(
            dim: List[int],
            sigma: float = 0.01,
            nb_gridpts: int = 15,
            width: float = 0.7 / 30.0,
        ) -> np.ndarray:
            """Tail core; expects ``dim = (r3, m3, n3, 1)``, r3 = TT rank.

            All entries are small, N(0, sigma^2), except that for each
            possible (m, n) one random odd index r is chosen and the entry
            (r, m, n, 0) is drawn from the flat saw tooth distribution.
            """
            # First generate the whole background as one block.
            B = _gen_block("gaussian", dim, 0.0, sigma)
            r3 = dim[0]
            B = B.reshape(r3, -1)
            nb_samples = B.shape[1]
            values = _flat_saw_tooth(nb_gridpts, width, nb_samples=nb_samples)
            for ell in range(nb_samples):
                # Needs r3 >= 2 so that an odd index exists.
                p = random.randrange(1, r3, 2)
                B[p, ell] = values[ell]
            return B.reshape(dim)

        def _gen_mid(
            dim: List[int],
            sigma: float = 0.01,
            nb_gridpts: int = 15,
            width: float = 0.7 / 30.0,
        ) -> np.ndarray:
            """Middle core; expects ``dim = (r2, m2, n2, r3)``.

            Entries are in general close to 1/sqrt(r2), so that the product
            with the head yields values close to 1. For each (m, n) in
            range(m2) x range(n2), a random even index k in range(r3) is
            picked and the vector (:, m, n, k) is made small. The exception
            is one random j in range(r2), whose entry (j, m, n, k) is drawn
            from the flat saw tooth distribution (divided by 1/sqrt(r2)).
            This needs m2 * n2 saw tooth samples in total.
            """
            r2, m2, n2, r3 = dim
            scale = 1.0 / np.sqrt(r2)
            B = _gen_block("gaussian", dim, scale, sigma)
            B = B.reshape(r2, m2 * n2, r3)
            values = _flat_saw_tooth(nb_gridpts, width, nb_samples=m2 * n2) / scale
            for ell in range(m2 * n2):
                p = random.randrange(0, r3, 2)
                B[:, ell, p] = np.random.randn(r2) * (sigma * sigma / scale)
                j = random.randrange(r2)
                B[j, ell, p] = values[ell]
            return B.reshape(dim)

        assert self.tt_ndim == 3
        assert (
            self.num_tables == 1
        ), "approx_uniform only supported for num_tables == 1"
        scale = 1.0 / (np.sqrt(self.num_embeddings) ** (1.0 / 3.0))
        shapes = [
            [
                self.tt_ranks[i],
                self.tt_p_shapes[i],
                self.tt_q_shapes[i],
                self.tt_ranks[i + 1],
            ]
            for i in range(self.tt_ndim)
        ]
        # Generate head, mid and tail in that order (fixed RNG draw order).
        # Cores are only assigned once all three have been generated.
        weights = []
        for i, gen in enumerate((_gen_head, _gen_mid, _gen_tail)):
            W = (gen(shapes[i], sigma=0.01) * scale).astype(np.float32)
            # (r_i, p_i, q_i, r_{i+1}) -> (num_tables, p_i, r_i * q_i * r_{i+1})
            W = W.transpose([1, 0, 2, 3]).reshape(
                (self.num_tables, self.tt_p_shapes[i], -1)
            )
            weights.append(W)
        for i, W in enumerate(weights):
            core = self.tt_cores[i]
            # Keep the core on its current device (torch.tensor defaults to CPU).
            core.data = torch.tensor(W, device=core.device, requires_grad=True)