def string2comm_type()

in log_analyzer/ds_comm_log_analyzer.py [0:0]


def string2comm_type(s):
    """Return the ``CommType`` member whose name occurs inside ``s``.

    Each known operation name is tested with a containment check
    (``name in s``), in a fixed order, and the first one found selects the
    ``CommType`` attribute of the same name.

    Args:
        s: Value to search; anything supporting ``in`` with a string
            operand (the f-string warning below formats it as text).

    Returns:
        The matching ``CommType`` member. When no known name is found, a
        warning is printed to stdout and ``CommType.epoch_end`` is returned.
    """
    # Order is significant: "reduce" is a substring of both "reduce_scatter"
    # and "all_reduce", so the more specific names must be tried first.
    ordered_keywords = (
        "all_gather",
        "reduce_scatter",
        "all_reduce",
        "broadcast",
        "barrier",
        "reduce",
    )
    for keyword in ordered_keywords:
        if keyword in s:
            # Member names equal the keywords, so look the member up by name.
            return getattr(CommType, keyword)

    print(f"WARNING cannot convert {s} to CommType")
    return CommType.epoch_end