in workload_generator/analysis_pytorch_trace.py [0:0]
def string2comm_type(self, s):
    """Map a collective-operation name to a ``CommType`` member.

    Classification is by substring match against an ordered list of rules.
    The first rule with any matching substring wins. Order matters because
    ``"reduce"`` is a substring of the all-reduce and reduce-scatter names,
    so the generic ``reduce`` rule must be checked last.

    Both underscored spellings (``all_gather``, ``all_reduce``) and
    unseparated spellings (``allgather``, ``allreduce``, as in
    ``_allgather_base``) are recognised. Without the unseparated
    ``allreduce`` needle, such names would fall through to the generic
    ``reduce`` rule.

    Args:
        s: Operation name to classify.

    Returns:
        The matching ``CommType`` member.

    Raises:
        SystemExit: If ``s`` matches no known collective. The exit status
            is non-zero and the message is written to stderr, so a failed
            conversion is not reported as success.
    """
    # (candidate substrings, comm type) pairs, checked in order.
    rules = (
        # "allgather" also covers "_allgather_base"; "all_gather" covers
        # "_all_gather_base".
        (("all_gather", "allgather"), CommType.all_gather),
        # Also covers "_reduce_scatter_base".
        (("reduce_scatter",), CommType.reduce_scatter),
        (("all_reduce", "allreduce"), CommType.all_reduce),
        (("broadcast",), CommType.broadcast),
        (("barrier",), CommType.barrier),
        # Must stay last: "reduce" is a substring of the names above.
        (("reduce",), CommType.reduce),
    )
    for needles, comm_type in rules:
        if any(needle in s for needle in needles):
            return comm_type
    # A string argument to SystemExit is printed to stderr and yields exit
    # status 1. The builtin ``exit`` is added by ``site`` for interactive
    # use and is not guaranteed to exist.
    raise SystemExit(f"can not convert {s} to any comm type")