in simulation/decai/simulation/data/news_data_loader.py [0:0]
def load_data(self, train_size: int = None, test_size: int = None) -> (tuple, tuple):
    """
    Load Signal Media news articles and match their sources against the OpenSources labels.

    Reads ``sample-1M.jsonl`` (one JSON article per line) and ``sources.json`` from the
    ``training_data/news`` folder. Articles whose ``media-type`` is not in ``self._media_types``,
    or whose source cannot be matched by ``self.find_source_site`` to an entry of ``sources.json``,
    are skipped. The article-source -> site mapping that was built is written to
    ``source_mapping.json`` in the same folder, sorted by article source name.

    NOTE: building the actual training/test sets is still a TODO, so the returned values are
    currently ``None``.

    :param train_size: If given (and data has been built), the maximum number of training samples.
    :param test_size: If given (and data has been built), the maximum number of test samples.
    :return: ``(x_train, y_train), (x_test, y_test)``.
    :raises FileNotFoundError: If ``sample-1M.jsonl`` or ``sources.json`` is missing.
    """
    # Normalize lexically: `__file__` is a file, so letting the OS resolve
    # `news_data_loader.py/..` fails on POSIX systems (ENOTDIR) and the dataset is never found.
    data_folder_path = os.path.normpath(os.path.join(__file__, '../../../../training_data/news'))
    signal_data_path = os.path.join(data_folder_path, 'sample-1M.jsonl')
    if not os.path.exists(signal_data_path):
        raise FileNotFoundError(f"Could not find the Signal Media dataset at \"{signal_data_path}\"."
                                "\nYou must obtain it from http://research.signalmedia.co/newsir16/signal-dataset.html"
                                f" and follow the instructions to obtain it. Then extract it to \"{signal_data_path}\".")
    sources_path = os.path.join(data_folder_path, 'sources.json')
    if not os.path.exists(sources_path):
        raise FileNotFoundError(f"Could not find the sources dataset at \"{sources_path}\"."
                                "\nYou must obtain it from https://github.com/OpenSourcesGroup/opensources and put"
                                f" sources.json in \"{data_folder_path}\".")

    # JSON text is UTF-8; don't rely on the platform's default encoding.
    with open(sources_path, encoding='utf-8') as f:
        loaded_sources = json.load(f)
    # Site -> set of the non-empty problem types ('type', '2nd type', '3rd type') it is labeled with.
    sources = {
        source: set(filter(None, (info['type'], info['2nd type'], info['3rd type'])))
        for source, info in loaded_sources.items()
    }
    self._logger.info("Found %d sources with labels.", len(sources))

    # Article source name -> website name in `sources`, or `not_found_flag` when no match exists.
    # Acts as a per-source cache so `find_source_site` runs once per distinct source name.
    source_mapping = {}
    not_found_flag = -1
    with open(signal_data_path, encoding='utf-8') as f:
        for line in tqdm(f,
                         desc="Filtering news articles",
                         unit_scale=True, mininterval=2, unit=" articles"
                         ):
            news = json.loads(line)
            # media-type is either "News" or "Blog"
            media_type = news['media-type']
            if media_type not in self._media_types:
                continue
            source = news['source']
            source_site = source_mapping.get(source)
            if source_site is None:
                source_site = self.find_source_site(source, sources)
                if source_site is None:
                    source_site = not_found_flag
                source_mapping[source] = source_site
            if source_site == not_found_flag:
                continue
            # TODO Use article and set label.
            news_id = news['id']
            title = news['title']
            text = news['content']
            published_date = news['published']

    # Persist the mapping sorted by article source name. Values are site names (keys of `sources`)
    # or `not_found_flag`; assumes `find_source_site` returns such a JSON-serializable name.
    with open(os.path.join(data_folder_path, 'source_mapping.json'), 'w', encoding='utf-8') as f:
        json.dump(dict(sorted(source_mapping.items(), key=itemgetter(0))), f, indent=2)
    self._logger.info("Found %d sources in the articles.", len(source_mapping))

    # TODO Set up output.
    (x_train, y_train), (x_test, y_test) = (None, None), (None, None)
    # Only truncate once data exists: slicing the `None` placeholders would raise a TypeError.
    if train_size is not None and x_train is not None:
        x_train, y_train = x_train[:train_size], y_train[:train_size]
    if test_size is not None and x_test is not None:
        x_test, y_test = x_test[:test_size], y_test[:test_size]
    self._logger.info("Done loading news data.")
    return (x_train, y_train), (x_test, y_test)