def load_data()

in simulation/decai/simulation/data/news_data_loader.py [0:0]


    def load_data(self, train_size: int = None, test_size: int = None) -> (tuple, tuple):
        """
        Load the Signal Media news articles and match them to labelled sources.

        Reads `sources.json` (site -> problem types) and streams
        `sample-1M.jsonl` one article per line, keeping only articles whose
        media type is in `self._media_types` and whose source can be resolved
        to a labelled site via `self.find_source_site`. The resolved mapping is
        written to `source_mapping.json` in the data folder.

        NOTE: building the actual samples is still a TODO, so the returned
        values are currently `None` placeholders.

        :param train_size: Maximum number of training samples to keep. `None` keeps all.
        :param test_size: Maximum number of test samples to keep. `None` keeps all.
        :return: `(x_train, y_train), (x_test, y_test)`.
        :raises FileNotFoundError: If the Signal Media dataset or `sources.json` is missing.
        """
        # Anchor on the directory containing this file and normalise the path.
        # Appending '..' directly to `__file__` yields a path that goes "through"
        # a regular file, which POSIX systems refuse to resolve.
        data_folder_path = os.path.abspath(
            os.path.join(os.path.dirname(__file__), '..', '..', '..', 'training_data', 'news'))
        signal_data_path = os.path.join(data_folder_path, 'sample-1M.jsonl')
        if not os.path.exists(signal_data_path):
            raise FileNotFoundError(
                f"Could not find the Signal Media dataset at \"{signal_data_path}\"."
                "\nYou must obtain it from http://research.signalmedia.co/newsir16/signal-dataset.html"
                f" and follow the instructions to obtain it. Then extract it to \"{signal_data_path}\".")

        sources_path = os.path.join(data_folder_path, 'sources.json')
        if not os.path.exists(sources_path):
            raise FileNotFoundError(
                f"Could not find the sources dataset at \"{sources_path}\"."
                "\nYou must obtain it from https://github.com/OpenSourcesGroup/opensources and put"
                f" sources.json in \"{data_folder_path}\".")

        # JSON is exchanged as UTF-8; do not depend on the platform default encoding.
        with open(sources_path, encoding='utf-8') as f:
            loaded_sources = json.load(f)
        # Site -> set of its non-empty problem types (up to three per site).
        sources = {
            site: set(filter(None, (info['type'], info['2nd type'], info['3rd type'])))
            for site, info in loaded_sources.items()
        }
        self._logger.info("Found %d sources with labels.", len(sources))

        # Name: website name in `sources`.
        # Sources that could not be matched are cached with `not_found_flag`
        # so that `find_source_site` is only called once per distinct source.
        source_mapping = {}
        not_found_flag = -1
        with open(signal_data_path, encoding='utf-8') as f:
            for line in tqdm(f,
                             desc="Filtering news articles",
                             unit_scale=True, mininterval=2, unit=" articles"
                             ):
                news = json.loads(line)
                # media-type is either "News" or "Blog"
                if news['media-type'] not in self._media_types:
                    continue
                source = news['source']
                source_site = source_mapping.get(source)
                if source_site is None:
                    # First time this source is seen: resolve it and cache the outcome.
                    source_site = self.find_source_site(source, sources)
                    if source_site is None:
                        source_mapping[source] = not_found_flag
                        continue
                    source_mapping[source] = source_site
                elif source_site == not_found_flag:
                    continue
                # TODO Use article and set label.
                #  `news` also carries 'id', 'title', 'content' and 'published'.

        # Persist the mapping (sorted by source name) for inspection.
        # Assumes the values returned by `find_source_site` are JSON-serializable
        # (presumably site names) - TODO confirm.
        mapping_path = os.path.join(data_folder_path, 'source_mapping.json')
        with open(mapping_path, 'w', encoding='utf-8') as f:
            json.dump(source_mapping, f, indent=2, sort_keys=True)

        self._logger.info("Found %d sources in the articles.", len(source_mapping))

        # TODO Set up output.
        (x_train, y_train), (x_test, y_test) = (None, None), (None, None)
        # The outputs are still placeholders; only truncate once real data exists
        # so that passing a size does not fail on slicing `None`.
        if train_size is not None and x_train is not None:
            x_train, y_train = x_train[:train_size], y_train[:train_size]
        if test_size is not None and x_test is not None:
            x_test, y_test = x_test[:test_size], y_test[:test_size]

        self._logger.info("Done loading news data.")
        return (x_train, y_train), (x_test, y_test)