# scripts/merge_audio_to_metadata.py
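"""Merge an audio dataset with a matching metadata dataset, column-wise.

Both datasets are expected to share the same splits, the same row ordering,
and an id column used to verify the alignment. The merged dataset can be
saved to disk and/or pushed to the Hugging Face Hub.
"""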
import argparse

from datasets import load_dataset, concatenate_datasets
from multiprocess import set_start_method
if __name__ == "__main__":
    set_start_method("spawn")

    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_name", type=str, help="Repo id of the audio dataset.")
    parser.add_argument("metadata_dataset_name", type=str, help="Repo id of the metadata dataset.")
    parser.add_argument("--configuration", default=None, type=str, help="Dataset configuration to use.")
    parser.add_argument("--output_dir", default=None, type=str, help="If specified, save the dataset on disk.")
    parser.add_argument("--repo_id", default=None, type=str, help="If specified, push the dataset to the hub.")
    parser.add_argument("--cpu_num_workers", default=1, type=int, help="Number of CPU workers.")
    parser.add_argument("--strategy", default="concatenate", type=str, help="Merge strategy. For now only 'concatenate' is supported.")
    parser.add_argument("--id_column_name", default="id", type=str, help="Id column used to check that both datasets are aligned.")
    parser.add_argument("--columns_to_drop", default=None, type=str, help="Column names to drop from the metadata dataset, e.g. columns duplicated in both datasets. Separated by '+'.")

    args = parser.parse_args()
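
    # Load the main (audio) dataset and the metadata dataset; pass the same
    # configuration to both so that their splits line up.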
    if args.configuration:
        dataset = load_dataset(args.dataset_name, args.configuration)
    else:
        dataset = load_dataset(args.dataset_name)

    if args.configuration:
        metadata_dataset = load_dataset(args.metadata_dataset_name, args.configuration)
    else:
        metadata_dataset = load_dataset(args.metadata_dataset_name)
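
    # Optionally drop metadata columns that duplicate columns already present
    # in the main dataset, to avoid name collisions when concatenating.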
    columns_to_drop = None
    if args.columns_to_drop is not None:
        columns_to_drop = args.columns_to_drop.split("+")
        metadata_dataset = metadata_dataset.remove_columns(columns_to_drop)
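
    # concatenate_datasets(..., axis=1) performs a column-wise (horizontal)
    # concatenation, appending the metadata columns to the audio columns.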
    # TODO: for now, assume both datasets kept the same row ordering
    for split in dataset:
        if split not in metadata_dataset:
            raise ValueError(f"Metadata dataset is missing split '{split}', which is present in the main dataset.")

        metadata_split = metadata_dataset[split].rename_column(args.id_column_name, f"metadata_{args.id_column_name}")
        dataset[split] = concatenate_datasets([dataset[split], metadata_split], axis=1)

        # Sanity check: the two id columns must match row by row.
        mismatched = dataset[split].filter(
            lambda id1, id2: id1 != id2,
            input_columns=[args.id_column_name, f"metadata_{args.id_column_name}"],
            num_proc=args.cpu_num_workers,
        )
        if len(mismatched) != 0:
            raise ValueError(f"Concatenation failed: some ids don't match on split '{split}'.")
    if args.output_dir:
        dataset.save_to_disk(args.output_dir)

    if args.repo_id:
        if args.configuration:
            dataset.push_to_hub(args.repo_id, args.configuration)
        else:
            dataset.push_to_hub(args.repo_id)
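
# Example invocation (repo ids are hypothetical placeholders):
#   python scripts/merge_audio_to_metadata.py my-org/audio-dataset my-org/metadata-dataset \
#       --configuration default \
#       --columns_to_drop "text+speaker_id" \
#       --repo_id my-org/merged-dataset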