pytext/models/representations/lightconv.py [173:231]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    def reorder_encoder_out(self, encoder_out: Dict[str, Tensor], new_order: Tensor):
        encoder = encoder_out["encoder_out"]
        encoder = encoder.index_select(1, new_order)
        output_dict = {"encoder_out": encoder}

        output_dict["src_tokens"] = encoder_out["src_tokens"].index_select(0, new_order)
        padding_mask = encoder_out.get("encoder_mask", None)
        if padding_mask is not None:
            padding_mask = padding_mask.index_select(0, new_order)
            output_dict["encoder_mask"] = padding_mask
        return output_dict

    def pos_embed(self, x, src_tokens):
        if self.combine_pos_embed == PostionalEmbedCombine.SUM.value:
            x = self.project_in_dim(x)
            return self._vanilla_transformer(x, src_tokens)
        elif self.combine_pos_embed == PostionalEmbedCombine.CONCAT.value:
            return self._concat_pos_embed(x, src_tokens)
        else:
            raise NotImplementedError("Method not supported")

    def _vanilla_transformer(self, x, src_tokens):
        x += self.embed_positions(src_tokens)
        return x

    def _concat_pos_embed(self, x, src_tokens):
        pos_embed = self.embed_positions(src_tokens)
        return torch.cat([x, pos_embed], dim=2)

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.no_token_positional_embeddings:
            return self.max_source_positions
        return min(self.max_source_positions, self.embed_positions.max_positions())

    def tile_encoder_out(
        self, tile_size: int, encoder_out: Dict[str, Tensor]
    ) -> Dict[str, Tensor]:
        tiled_out = torch.jit.annotate(Dict[str, Tensor], {})

        x = encoder_out["encoder_out"]
        new_x = x.repeat(1, tile_size, 1)
        tiled_out["encoder_out"] = new_x

        if "encoder_mask" in encoder_out:
            new_encoder_mask = encoder_out["encoder_mask"].repeat(tile_size, 1)
            tiled_out["encoder_mask"] = new_encoder_mask
        if "src_tokens" in encoder_out:
            tiled_out["src_tokens"] = encoder_out["src_tokens"].repeat(tile_size, 1)
        if "src_lengths" in encoder_out:
            tiled_out["src_lengths"] = encoder_out["src_lengths"].repeat(tile_size, 1)

        return tiled_out

    def extra_repr(self):
        s = "dropout={}, embed_scale={}, normalize={}".format(
            self.dropout, self.embed_scale, self.normalize
        )
        return s
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
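
Both excerpts operate on a fairseq-style encoder output dictionary in which "encoder_out" keeps the batch on dim 1 (hence the index_select(1, ...) above) while "src_tokens" and "encoder_mask" keep it on dim 0. Below is a minimal sketch of what reorder_encoder_out and tile_encoder_out do to such a dictionary, using toy tensors; the concrete shapes, the standalone dict, and the beam-search framing are illustrative assumptions, not code from either file.

import torch
from typing import Dict
from torch import Tensor

# Toy encoder output following the layout implied above:
#   encoder_out  : T x B x C (batch on dim 1)
#   src_tokens   : B x T     (batch on dim 0)
#   encoder_mask : B x T
T, B, C = 4, 2, 8
encoder_out: Dict[str, Tensor] = {
    "encoder_out": torch.randn(T, B, C),
    "src_tokens": torch.randint(0, 100, (B, T)),
    "encoder_mask": torch.zeros(B, T, dtype=torch.bool),
}

# Reordering (as in reorder_encoder_out): select batch entries in a new order,
# e.g. when beam search reshuffles hypotheses.
new_order = torch.tensor([1, 0])
reordered = {
    "encoder_out": encoder_out["encoder_out"].index_select(1, new_order),
    "src_tokens": encoder_out["src_tokens"].index_select(0, new_order),
    "encoder_mask": encoder_out["encoder_mask"].index_select(0, new_order),
}

# Tiling (as in tile_encoder_out): repeat every tensor tile_size times along
# its batch dimension, growing the effective batch from B to B * tile_size.
tile_size = 3
tiled = {
    "encoder_out": encoder_out["encoder_out"].repeat(1, tile_size, 1),
    "src_tokens": encoder_out["src_tokens"].repeat(tile_size, 1),
    "encoder_mask": encoder_out["encoder_mask"].repeat(tile_size, 1),
}
assert reordered["encoder_out"].shape == (T, B, C)
assert tiled["encoder_out"].shape == (T, B * tile_size, C)
assert tiled["src_tokens"].shape == (B * tile_size, T)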



pytext/models/seq_models/conv_encoder.py [317:375]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    def reorder_encoder_out(self, encoder_out: Dict[str, Tensor], new_order: Tensor):
        encoder = encoder_out["encoder_out"]
        encoder = encoder.index_select(1, new_order)
        output_dict = {"encoder_out": encoder}

        output_dict["src_tokens"] = encoder_out["src_tokens"].index_select(0, new_order)
        padding_mask = encoder_out.get("encoder_mask", None)
        if padding_mask is not None:
            padding_mask = padding_mask.index_select(0, new_order)
            output_dict["encoder_mask"] = padding_mask
        return output_dict

    def pos_embed(self, x, src_tokens):
        if self.combine_pos_embed == PostionalEmbedCombine.SUM.value:
            x = self.project_in_dim(x)
            return self._vanilla_transformer(x, src_tokens)
        elif self.combine_pos_embed == PostionalEmbedCombine.CONCAT.value:
            return self._concat_pos_embed(x, src_tokens)
        else:
            raise NotImplementedError("Method not supported")

    def _vanilla_transformer(self, x, src_tokens):
        x += self.embed_positions(src_tokens)
        return x

    def _concat_pos_embed(self, x, src_tokens):
        pos_embed = self.embed_positions(src_tokens)
        return torch.cat([x, pos_embed], dim=2)

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.no_token_positional_embeddings:
            return self.max_source_positions
        return min(self.max_source_positions, self.embed_positions.max_positions())

    def tile_encoder_out(
        self, tile_size: int, encoder_out: Dict[str, Tensor]
    ) -> Dict[str, Tensor]:
        tiled_out = torch.jit.annotate(Dict[str, Tensor], {})

        x = encoder_out["encoder_out"]
        new_x = x.repeat(1, tile_size, 1)
        tiled_out["encoder_out"] = new_x

        if "encoder_mask" in encoder_out:
            new_encoder_mask = encoder_out["encoder_mask"].repeat(tile_size, 1)
            tiled_out["encoder_mask"] = new_encoder_mask
        if "src_tokens" in encoder_out:
            tiled_out["src_tokens"] = encoder_out["src_tokens"].repeat(tile_size, 1)
        if "src_lengths" in encoder_out:
            tiled_out["src_lengths"] = encoder_out["src_lengths"].repeat(tile_size, 1)

        return tiled_out

    def extra_repr(self):
        s = "dropout={}, embed_scale={}, normalize={}".format(
            self.dropout, self.embed_scale, self.normalize
        )
        return s
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
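
The pos_embed branch shared by both files combines token and positional embeddings in one of two ways: PostionalEmbedCombine.SUM projects the input to the embedding width (project_in_dim) and adds the positional embedding element-wise, while CONCAT appends it along the feature dimension (dim 2), so downstream layers see embed_dim + pos_embed_dim features. A rough sketch of the two combinations follows; the concrete widths, the batch/time ordering, and the random stand-ins for embed_positions are assumptions for illustration only.

import torch

T, B = 5, 2                      # sequence length and batch size (illustrative)
embed_dim, pos_dim = 16, 8       # token / positional embedding widths (illustrative)

token_embed = torch.randn(T, B, embed_dim)   # stand-in for the (projected) token embeddings
pos_for_sum = torch.randn(T, B, embed_dim)   # stand-in for embed_positions(...) in the SUM path
pos_for_cat = torch.randn(T, B, pos_dim)     # stand-in for embed_positions(...) in the CONCAT path

# SUM path (_vanilla_transformer): element-wise addition, feature width stays embed_dim.
combined_sum = token_embed + pos_for_sum

# CONCAT path (_concat_pos_embed): concatenate along the feature dimension (dim 2).
combined_cat = torch.cat([token_embed, pos_for_cat], dim=2)

assert combined_sum.shape == (T, B, embed_dim)
assert combined_cat.shape == (T, B, embed_dim + pos_dim)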



