in utils/benchmarking_utils.py [0:0]
def collate_csv(input_files: List[str], output_file: str, is_pixart=False):
    """Collate multiple identically structured CSVs into a single CSV file.

    The header written to ``output_file`` is ``BENCHMARK_FIELDS``. When
    ``is_pixart`` is true, the ``"compile_unet"`` column name is replaced by
    ``"compile_transformer"``; the module-level field list itself is never
    modified. Each input file is read with ``csv.DictReader``, and its data rows
    are appended to the output in order, so input headers are not repeated.

    Args:
        input_files: Paths of the CSV files to concatenate, in output order.
        output_file: Path of the CSV file to create (truncated if it exists).
        is_pixart: Use the PixArt column naming described above.

    Raises:
        ValueError: If ``is_pixart`` is true and ``BENCHMARK_FIELDS`` has no
            ``"compile_unet"`` entry. Also raised if an input row contains a
            column that is not in the output header (``csv.DictWriter``'s
            default ``extrasaction="raise"``).
    """
    # Work on a copy so the shared module-level field list is never mutated.
    fields = list(BENCHMARK_FIELDS)
    if is_pixart:
        fields[fields.index("compile_unet")] = "compile_transformer"

    # The csv module requires newline="" on files given to both the writer
    # and the reader. Otherwise, quoted fields containing newlines do not
    # round-trip correctly.
    with open(output_file, mode="w", newline="") as outfile:
        writer = csv.DictWriter(outfile, fieldnames=fields)
        writer.writeheader()
        for path in input_files:
            with open(path, mode="r", newline="") as infile:
                writer.writerows(csv.DictReader(infile))