sockeye/data_io.py [1640:1704]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        super().__init__(buckets=[bucket],
                         batch_size=batch_size,
                         bucket_batch_sizes=[BucketBatchSize(bucket, batch_size, None)],
                         num_source_factors=num_source_factors,
                         num_target_factors=num_target_factors,
                         permute=False,
                         dtype=dtype)
        self.data_loader = data_loader
        self.sources_sentences, self.targets_sentences = create_sequence_readers(sources, targets,
                                                                                 source_vocabs, target_vocabs)
        self.sources_iters = [iter(s) for s in self.sources_sentences]
        self.targets_iters = [iter(s) for s in self.targets_sentences]
        self.max_len_source, self.max_len_target = max_lens
        self.next_batch = None  # type: Optional[Batch]
        self.sentno = 1

    def reset(self):
        """
        Resetting is not available for this iterator; every call fails.

        :raises Exception: Always, with the message 'Not supported!'.
        """
        message = 'Not supported!'
        raise Exception(message)

    def iter_next(self) -> bool:
        """
        Read up to batch_size parallel sentences from the underlying iterators and prepare the next batch.

        Over-long sentences are truncated, not skipped: length is measured on the first source/target
        stream and, when it exceeds max_len_source/max_len_target, every stream on that side is cut to
        the maximum length. The resulting batch is stored in self.next_batch (None when nothing could
        be read).

        :return: True if the iterator can return another batch.
        """

        # Read batch_size lines from the source stream; one accumulator list per source/target stream.
        sources_sentences = [[] for _ in self.sources_sentences]  # type: List[List[str]]
        targets_sentences = [[] for _ in self.targets_sentences]  # type: List[List[str]]
        num_read = 0  # stays 0 if the iterators yield nothing
        for num_read, (sources, targets) in enumerate(
                parallel_iterate(self.sources_iters, self.targets_iters, skip_blanks=False), 1):
            # self.sentno is the 1-based number of the first sentence of this batch and num_read
            # counts from 1, so the sentence at hand is number self.sentno + num_read - 1.
            sentence_number = self.sentno + num_read - 1
            source_len = 0 if sources[0] is None else len(sources[0])
            target_len = 0 if targets[0] is None else len(targets[0])
            if source_len > self.max_len_source:
                logger.debug("Trimming source sentence {} ({} -> {})".format(sentence_number,
                                                                            source_len,
                                                                            self.max_len_source))
                sources = [source[0: self.max_len_source] for source in sources]
            if target_len > self.max_len_target:
                logger.debug("Trimming target sentence {} ({} -> {})".format(sentence_number,
                                                                            target_len,
                                                                            self.max_len_target))
                targets = [target[0: self.max_len_target] for target in targets]

            for i, source in enumerate(sources):
                sources_sentences[i].append(source)
            for i, target in enumerate(targets):
                targets_sentences[i].append(target)
            if num_read == self.batch_size:
                break

        # Report progress whenever the number of consumed lines crosses a multiple of one million.
        # Since self.sentno is 1-based, the number of lines consumed so far is self.sentno - 1.
        millions_before = (self.sentno - 1) // 1_000_000
        self.sentno += num_read
        lines_processed = self.sentno - 1
        if lines_processed // 1_000_000 != millions_before:
            logger.info("Processed {} lines".format(lines_processed))

        if num_read == 0:
            self.next_batch = None
            return False

        dataset = self.data_loader.load(sources_sentences, targets_sentences, [num_read])

        # Only index 0 of the loaded dataset is used (presumably the single bucket this iterator
        # was constructed with -- confirm against the data loader).
        source = dataset.source[0]
        target, label = create_target_and_shifted_label_sequences(dataset.target[0])
        self.next_batch = create_batch_from_parallel_sample(source, target, label)
        return True
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



sockeye/data_io_pt.py [1607:1671]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        super().__init__(buckets=[bucket],
                         batch_size=batch_size,
                         bucket_batch_sizes=[BucketBatchSize(bucket, batch_size, None)],
                         num_source_factors=num_source_factors,
                         num_target_factors=num_target_factors,
                         permute=False,
                         dtype=dtype)
        self.data_loader = data_loader
        self.sources_sentences, self.targets_sentences = create_sequence_readers(sources, targets,
                                                                                 source_vocabs, target_vocabs)
        self.sources_iters = [iter(s) for s in self.sources_sentences]
        self.targets_iters = [iter(s) for s in self.targets_sentences]
        self.max_len_source, self.max_len_target = max_lens
        self.next_batch = None  # type: Optional[Batch]
        self.sentno = 1

    def reset(self):
        """
        Resetting is not available for this iterator; every call fails.

        :raises Exception: Always, with the message 'Not supported!'.
        """
        message = 'Not supported!'
        raise Exception(message)

    def iter_next(self) -> bool:
        """
        Read up to batch_size parallel sentences from the underlying iterators and prepare the next batch.

        Over-long sentences are truncated, not skipped: length is measured on the first source/target
        stream and, when it exceeds max_len_source/max_len_target, every stream on that side is cut to
        the maximum length. The resulting batch is stored in self.next_batch (None when nothing could
        be read).

        :return: True if the iterator can return another batch.
        """

        # Read batch_size lines from the source stream; one accumulator list per source/target stream.
        sources_sentences = [[] for _ in self.sources_sentences]  # type: List[List[str]]
        targets_sentences = [[] for _ in self.targets_sentences]  # type: List[List[str]]
        num_read = 0  # stays 0 if the iterators yield nothing
        for num_read, (sources, targets) in enumerate(
                parallel_iterate(self.sources_iters, self.targets_iters, skip_blanks=False), 1):
            # self.sentno is the 1-based number of the first sentence of this batch and num_read
            # counts from 1, so the sentence at hand is number self.sentno + num_read - 1.
            sentence_number = self.sentno + num_read - 1
            source_len = 0 if sources[0] is None else len(sources[0])
            target_len = 0 if targets[0] is None else len(targets[0])
            if source_len > self.max_len_source:
                logger.debug("Trimming source sentence {} ({} -> {})".format(sentence_number,
                                                                            source_len,
                                                                            self.max_len_source))
                sources = [source[0: self.max_len_source] for source in sources]
            if target_len > self.max_len_target:
                logger.debug("Trimming target sentence {} ({} -> {})".format(sentence_number,
                                                                            target_len,
                                                                            self.max_len_target))
                targets = [target[0: self.max_len_target] for target in targets]

            for i, source in enumerate(sources):
                sources_sentences[i].append(source)
            for i, target in enumerate(targets):
                targets_sentences[i].append(target)
            if num_read == self.batch_size:
                break

        # Report progress whenever the number of consumed lines crosses a multiple of one million.
        # Since self.sentno is 1-based, the number of lines consumed so far is self.sentno - 1.
        millions_before = (self.sentno - 1) // 1_000_000
        self.sentno += num_read
        lines_processed = self.sentno - 1
        if lines_processed // 1_000_000 != millions_before:
            logger.info("Processed {} lines".format(lines_processed))

        if num_read == 0:
            self.next_batch = None
            return False

        dataset = self.data_loader.load(sources_sentences, targets_sentences, [num_read])

        # Only index 0 of the loaded dataset is used (presumably the single bucket this iterator
        # was constructed with -- confirm against the data loader).
        source = dataset.source[0]
        target, label = create_target_and_shifted_label_sequences(dataset.target[0])
        self.next_batch = create_batch_from_parallel_sample(source, target, label)
        return True
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



