def collate_csv()

in utils/benchmarking_utils.py [0:0]


def collate_csv(input_files: List[str], output_file: str, is_pixart: bool = False) -> None:
    """Collate multiple identically structured CSVs into a single CSV file.

    The output header is taken from the module-level ``BENCHMARK_FIELDS``.
    Every data row of every input file is then appended to the output in the
    order the files are given.

    Args:
        input_files: Paths of the CSV files to merge. Each file's first line
            is treated as its header (``csv.DictReader`` default).
        output_file: Path of the merged CSV. It is opened in ``"w"`` mode, so
            any existing content is replaced.
        is_pixart: When true, the ``"compile_unet"`` column of the header is
            written as ``"compile_transformer"`` instead.

    Raises:
        ValueError: If ``is_pixart`` is true and ``"compile_unet"`` is not in
            ``BENCHMARK_FIELDS``, or if an input row carries a column that is
            not part of the output header (``csv.DictWriter`` default).
    """
    # Work on a private copy so the shared module-level constant is never
    # mutated by the header rename below.
    fields = list(BENCHMARK_FIELDS)
    if is_pixart:
        fields[fields.index("compile_unet")] = "compile_transformer"

    with open(output_file, mode="w", newline="") as outfile:
        writer = csv.DictWriter(outfile, fieldnames=fields)
        writer.writeheader()

        for file in input_files:
            # newline="" lets the csv module do its own line-ending handling;
            # without it, "\r\n" embedded in quoted fields is rewritten.
            with open(file, mode="r", newline="") as infile:
                writer.writerows(csv.DictReader(infile))