def reset_parameters()

in tt_embeddings_ops.py


    def reset_parameters(self, weight_dist: str) -> None:  # noqa C901
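        """Reinitialize the TT cores in place according to weight_dist.

        The uniform/normal schemes set per-core scales directly; the approx-*
        schemes shape the cores so that their contraction approximates the
        named distribution.
        """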
        assert weight_dist in [
            "uniform",
            "naive-uniform",
            "normal",
            "approx-uniform",
            "approx-normal",
        ]
        if weight_dist == "uniform":
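            # Glorot-style target variance lamb = 2 / (fan_in + fan_out). The
            # target stddev is split evenly across the tt_ndim cores, and the
            # prod(tt_ranks) ** (-1 / (2 * tt_ndim)) factor offsets the rank
            # summations performed when the cores are contracted, so the
            # reconstructed weights come out on the order of sqrt(lamb).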
            lamb = 2.0 / (self.num_embeddings + self.embedding_dim)
            stddev = np.sqrt(lamb)
            tt_ranks = np.array(self.tt_ranks)
            cr_exponent = -1.0 / (2 * self.tt_ndim)
            var = np.prod(tt_ranks ** cr_exponent)
            core_stddev = stddev ** (1.0 / self.tt_ndim) * var
            for i in range(self.tt_ndim):
                torch.nn.init.uniform_(self.tt_cores[i], 0.0, core_stddev)
        elif weight_dist == "naive-uniform":
            for i in range(self.tt_ndim):
                torch.nn.init.uniform_(
                    self.tt_cores[i], 0.0, 1 / np.sqrt(self.num_embeddings)
                )
        elif weight_dist == "normal":
            mu = 0.0
            sigma = 1.0 / np.sqrt(self.num_embeddings)
            scale = 1.0 / self.tt_ranks[0]
            for i in range(self.tt_ndim):
                torch.nn.init.normal_(self.tt_cores[i], mu, sigma)
                self.tt_cores[i].data *= scale
        elif weight_dist == "approx-normal":
            mu = 0.0
            sigma = 1.0
            scale = np.power(1 / np.sqrt(3 * self.num_embeddings), 1 / 3)
            for i in range(self.tt_ndim):
                W = np.random.normal(
                    loc=mu, scale=sigma, size=np.asarray(self.tt_cores[i].shape)
                ).astype(np.float32)
                core_shape = self.tt_cores[i].shape
                W = W.flatten()
                for ele in range(W.shape[0]):
                    # Truncated normal: resample entries more than 2 sigma out.
                    while np.abs(W[ele]) > 2:
                        W[ele] = np.random.normal(loc=mu, scale=sigma, size=[1]).astype(
                            np.float32
                        )
                W = np.reshape(W, core_shape)
                W *= scale
                self.tt_cores[i].data = torch.tensor(W, requires_grad=True)
        elif weight_dist == "approx-uniform":

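            # Build the three cores so that their product is approximately
            # uniform: the head is a narrow Gaussian around 1/sqrt(r1), the
            # mid is close to 1/sqrt(r2) with one saw-tooth entry per (m, n),
            # and the tail is small Gaussian noise with one saw-tooth entry
            # per (m, n). Convolving the saw-tooth values with the narrow
            # Gaussian factors fills the gaps between the teeth (see
            # _flat_saw_tooth below).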
            def _flat_saw_tooth(
                nb_gridpts: int, width: float, nb_samples: int = 1
            ) -> np.ndarray:
                """
                This is a "flat saw tooth" distribution
                that is, the density function is a sum of
                j*delta + uniform(-width/2, width/2), width < delta/2 in general
                a finite train of flat tooth with space in between
                The idea is that when this density function convolved
                with a very narrow gaussian-like distribution
                the space will be filled up and the result looks like a uniform distribiution
                """

                N = nb_gridpts
                delta = 1.0 / N
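                # j picks a tooth center from {-(N - 1), ..., N - 1} * delta;
                # x is a uniform jitter within a tooth of the given width.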
                j = np.random.randint(-(N - 1), N, nb_samples)
                x = -width / 2.0 + width * np.random.rand(nb_samples)
                return j * delta + x

            def _gen_block(
                dist: str, dim: List[int], center: float, param: float
            ) -> np.ndarray:
                nb_samples = (np.array(dim)).prod()
                if dist == "gaussian":
                    B = center + np.random.randn(nb_samples) * param
                elif dist == "uniform":
                    B = center - (param / 2.0) + param * np.random.rand(nb_samples)
                else:
                    assert 0, f"Unsupported distribution: {dist}"
                # pyre-fixme[16]
                B = B.reshape(dim)
                return B

            def _gen_head(dim: List[int], sigma: float = 0.01) -> np.ndarray:
                # expect dim = (1, m1, n1, r1) where r1 is the tensor train rank
                scale = 1.0 / np.sqrt(dim[-1])
                size = (np.array(dim)).prod()
                B = _gen_block("gaussian", size, scale, sigma)
                B = B.reshape(dim)
                return B

            def _gen_tail(
                dim: List[int],
                sigma: float = 0.01,
                nb_gridpts: int = 15,
                width: float = 0.7 / 30.0,
            ) -> np.ndarray:
                """
                expect dim = (r3, m3, n3, 1); r3 is the tensor train rank
                in our scheme here, all the elements are small, N(0,sigma^2)
                except on each possible m, n  there is one random odd r
                such that (r, m, n, 1) follows a saw tooth distribution
                """
                # first generate all the backgrounds as one big block
                B = _gen_block("gaussian", dim, 0.0, sigma)
                # generate the needed saw tooth distribution
                r3 = dim[0]
                B = B.reshape(r3, -1)
                nb_samples = B.shape[1]
                values = _flat_saw_tooth(nb_gridpts, width, nb_samples=nb_samples)
                for ell in range(nb_samples):
                    p = random.randrange(1, r3, 2)
                    B[p, ell] = values[ell]
                B = B.reshape(dim)
                return B

            def _gen_mid(
                dim: List[int],
                sigma: float = 0.01,
                nb_gridpts: int = 15,
                width: float = 0.7 / 30.0,
            ) -> np.ndarray:
                """
                expect dim = (r2, m2, n2, r3)
                in our scheme, all the elements are in general close to 1/sqrt(r2)
                so that the product with the head yield
                values close to 1
                but for each specific value of (m,n) in the range of (m2,n2)
                we pick a random even index k in range of r3 such that we
                make the vector (:,m,n,k) to be small except
                for one random j in range of r2 so that the value (j,m,n,k)
                is drawn for a saw tooth distribution
                so the total number of needed saw tooth samples is m2 x n2
                """
                r2, m2, n2, r3 = dim
                scale = 1.0 / np.sqrt(r2)
                B = _gen_block("gaussian", dim, scale, sigma)
                B = B.reshape(r2, m2 * n2, r3)
                values = _flat_saw_tooth(nb_gridpts, width, nb_samples=m2 * n2) / scale
                for ell in range(m2 * n2):
                    p = random.randrange(0, r3, 2)
                    v = np.random.randn(r2) * (sigma * sigma / scale)
                    B[:, ell, p] = v
                    j = random.randrange(r2)
                    B[j, ell, p] = values[ell]
                B = B.reshape(dim)
                return B

            assert self.tt_ndim == 3, "approx-uniform requires tt_ndim == 3"
            assert (
                self.num_tables == 1
            ), "approx-uniform only supported for num_tables == 1"
            scale = 1.0 / (np.sqrt(self.num_embeddings) ** (1.0 / 3.0))
            shapes = []
            for i in range(self.tt_ndim):
                core_shape = [
                    self.tt_ranks[i],
                    self.tt_p_shapes[i],
                    self.tt_q_shapes[i],
                    self.tt_ranks[i + 1],
                ]
                shapes.append(core_shape)
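            # Each core is generated in (r_i, p_i, q_i, r_{i+1}) layout, then
            # transposed and flattened to the (num_tables, p_i, -1) layout in
            # which self.tt_cores stores the cores.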
            W0 = _gen_head(shapes[0], sigma=0.01)
            W0 = W0 * scale
            W0 = W0.transpose([1, 0, 2, 3]).reshape(
                (self.num_tables, self.tt_p_shapes[0], -1)
            )
            W0 = W0.astype(np.float32)
            W1 = _gen_mid(shapes[1], sigma=0.01)
            W1 = W1 * scale
            W1 = W1.astype(np.float32)
            W1 = W1.transpose([1, 0, 2, 3]).reshape(
                (self.num_tables, self.tt_p_shapes[1], -1)
            )
            W2 = _gen_tail(shapes[2], sigma=0.01)
            W2 = W2 * scale
            W2 = W2.astype(np.float32)
            W2 = W2.transpose([1, 0, 2, 3]).reshape(
                (self.num_tables, self.tt_p_shapes[2], -1)
            )
            self.tt_cores[0].data = torch.tensor(W0, requires_grad=True)
            self.tt_cores[1].data = torch.tensor(W1, requires_grad=True)
            self.tt_cores[2].data = torch.tensor(W2, requires_grad=True)
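
For a quick sanity check of the variance bookkeeping in the "uniform" branch, the following standalone sketch (not part of tt_embeddings_ops.py; all shapes here are hypothetical) draws three cores with the same core_stddev formula, contracts them back into a full weight matrix, and compares the RMS of the reconstructed weights against the Glorot target sqrt(2 / (fan_in + fan_out)); the two should come out on the same order.

import numpy as np

# Hypothetical TT factorization: 216 = 6 * 6 * 6 rows, 64 = 4 * 4 * 4 columns.
p_shapes = [6, 6, 6]
q_shapes = [4, 4, 4]
tt_ranks = [1, 8, 8, 1]
num_embeddings = 216
embedding_dim = 64
tt_ndim = 3

# Same bookkeeping as the "uniform" branch above.
lamb = 2.0 / (num_embeddings + embedding_dim)
stddev = np.sqrt(lamb)
cr_exponent = -1.0 / (2 * tt_ndim)
core_stddev = stddev ** (1.0 / tt_ndim) * np.prod(np.array(tt_ranks) ** cr_exponent)

rng = np.random.default_rng(0)
cores = [
    rng.uniform(
        0.0,
        core_stddev,
        size=(tt_ranks[i], p_shapes[i], q_shapes[i], tt_ranks[i + 1]),
    )
    for i in range(tt_ndim)
]

# Contract the cores back into the full (num_embeddings, embedding_dim) matrix.
full = np.einsum("apqb,bstc,cuvd->psuqtv", cores[0], cores[1], cores[2])
full = full.reshape(num_embeddings, embedding_dim)

print(f"Glorot target stddev: {stddev:.4f}")
print(f"reconstructed RMS:    {np.sqrt(np.mean(full ** 2)):.4f}")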