def string2comm_type()

in workload_generator/analysis_pytorch_trace.py [0:0]


    def string2comm_type(self, s):
        """Map a string naming a collective operation to a ``CommType`` member.

        Matching is by substring, so ``s`` may carry extra prefixes or
        suffixes around the operation name.

        Args:
            s: Text to search for a known collective-operation name.

        Returns:
            The ``CommType`` member for the first pattern found in ``s``.

        Raises:
            SystemExit: With exit status 1 when ``s`` contains none of the
                known patterns. A diagnostic is printed to stdout first.
        """
        # Order matters: "reduce" is a substring of both "reduce_scatter" and
        # "all_reduce", so the more specific names must be tested before it.
        # "all_gather" / "reduce_scatter" also cover their "_..._base" variants;
        # only the "_allgather_base" spelling needs its own check.
        if "all_gather" in s or "_allgather_base" in s:
            return CommType.all_gather
        if "reduce_scatter" in s:
            return CommType.reduce_scatter
        if "all_reduce" in s:
            return CommType.all_reduce
        if "broadcast" in s:
            return CommType.broadcast
        if "barrier" in s:
            return CommType.barrier
        if "reduce" in s:
            return CommType.reduce

        print(f"can not convert {s} to any comm type")
        # Abort with a non-zero status so the failure is visible to the calling
        # process. SystemExit is raised directly because the ``exit`` helper is
        # installed by the ``site`` module and is not guaranteed to exist.
        raise SystemExit(1)