def run()

in banding_removal/fastmri/spawn_dist.py [0:0]


def _detect_ntasks():
    """Guess how many worker processes to spawn when the caller gave none.

    Uses the number of comma-separated entries in ``CUDA_VISIBLE_DEVICES``
    when that variable is set; otherwise counts the lines printed by
    ``nvidia-smi -L``. Falls back to 2 when no usable count is obtained.
    """
    devices = os.environ.get('CUDA_VISIBLE_DEVICES')
    if devices is not None:
        return len(devices.split(','))

    try:
        count = int(os.popen("nvidia-smi -L | wc -l").read())
    except (OSError, ValueError):
        return 2

    # `wc -l` still prints 0 when nvidia-smi is missing or lists nothing.
    # Zero workers would leave the monitor loop in run() spinning forever,
    # so treat that the same as a failed detection.
    return count if count > 0 else 2


def run(args=None, ntasks=None):
    """Spawn ``ntasks`` worker processes running ``work`` and supervise them.

    Args:
        args: ``None`` (built with ``Args().parse_args()``), a dict (built
            with ``Args(**args).parse_args()``), or an already constructed
            args object. The attributes ``strace``, ``exp_dir`` and
            ``auto_requeue`` are read, and ``is_distributed`` is set to True.
        ntasks: Number of worker processes. When ``None`` it is derived from
            ``CUDA_VISIBLE_DEVICES`` or ``nvidia-smi -L``, defaulting to 2.

    This function does not return normally: once any worker exits (or
    SIGINT, and SIGTERM when ``args.auto_requeue`` is false, is received) the
    remaining workers are terminated and the process exits via
    ``sys.exit(0)``.
    """
    if args is None:
        args = Args().parse_args()
    if isinstance(args, dict):
        args = Args(**args).parse_args()

    # Some automatic ntask settings code
    if ntasks is None:
        ntasks = _detect_ntasks()

    args.is_distributed = True
    # Temp ignore for bug in pytorch dataloader, it leaks semaphores
    os.environ['PYTHONWARNINGS'] = 'ignore:semaphore_tracker:UserWarning,ignore::UserWarning'

    # Make this process the head of a process group.
    os.setpgrp()

    # Most important line in this file. CUDA fails horribly if we use the default fork
    multiprocessing.set_start_method('forkserver')

    processes = []
    for i in range(ntasks):
        # Each worker receives a single tuple argument: (rank, ntasks, args).
        p = multiprocessing.Process(target=work, args=[(i, ntasks, args)])
        p.start()

        if args.strace:
            # Additional monitoring process attached to the worker's pid
            subprocess.Popen(["strace", "-tt",
                "-o", f"{args.exp_dir}/strace_{i}.log",
                "-e", "trace=write", "-s256",
                "-p", f"{p.pid}"])

        processes.append(p)

    def send_to_alive(signum):
        """Send ``signum`` to every worker that is still alive."""
        for p in processes:
            if p.is_alive():
                try:
                    os.kill(p.pid, signum)
                except ProcessLookupError:
                    # The worker exited between is_alive() and kill().
                    pass

    def terminate(signum, frame):
        """Stop all workers (terminate, then SIGKILL) and exit with status 0."""
        # Try terminate first
        print("terminating child processes")
        sys.stdout.flush()
        for p in processes:
            if p.is_alive():
                p.terminate()

        # Wait a little: up to 20 polls of 0.1s, stopping early once all are gone
        for _ in range(20):
            if not any(p.is_alive() for p in processes):
                break
            sleep(0.1)

        # If they are still alive after that kill -9 them
        for i, p in enumerate(processes):
            if p.is_alive():
                print(f"Sending SIGKILL to process {i}")
                sys.stdout.flush()
                try:
                    os.kill(p.pid, signal.SIGKILL)
                except ProcessLookupError:
                    # The worker exited between is_alive() and kill().
                    pass

        print("exiting")
        sys.stdout.flush()
        sys.exit(0)

    if args.auto_requeue:
        def forward_usr1_signal(signum, frame):
            print("Received USR1 signal in spawn_dist", flush=True)
            send_to_alive(signal.SIGUSR1)

        def forward_term_signal(signum, frame):
            print("Received SIGTERM signal in spawn_dist", flush=True)
            send_to_alive(signal.SIGTERM)

        # For requeueing this process must not exit on SIGTERM or USR1;
        # both are forwarded to the workers instead.
        signal.signal(signal.SIGUSR1, forward_usr1_signal)
        signal.signal(signal.SIGTERM, forward_term_signal)
        signal.signal(signal.SIGINT, terminate)
    else:
        signal.signal(signal.SIGINT, terminate)
        signal.signal(signal.SIGTERM, terminate)

    # Supervise: as soon as any worker is gone, tear everything down.
    # terminate() ends in sys.exit(0), so this loop never falls through.
    while True:
        sleep(0.5)
        if any(not p.is_alive() for p in processes):
            print("Detected an exited process, so exiting main")
            terminate(None, None)