in banding_removal/fastmri/spawn_dist.py [0:0]
def run(args=None, ntasks=None):
if args is None:
args = Args().parse_args()
if isinstance(args, dict):
args = Args(**args).parse_args()
# Some automatic ntask settings code
if ntasks is None:
try:
devices = os.environ['CUDA_VISIBLE_DEVICES']
ntasks = len(devices.split(','))
except:
try:
ntasks = int(os.popen("nvidia-smi -L | wc -l").read())
except:
ntasks = 2
args.is_distributed = True
# Temp ignore for bug in pytorch dataloader, it leaks semaphores
os.environ['PYTHONWARNINGS'] = 'ignore:semaphore_tracker:UserWarning,ignore::UserWarning'
# Make this process the head of a process group.
os.setpgrp()
# Most important line in this file. CUDA fails horribly if we use the default fork
multiprocessing.set_start_method('forkserver')
processses = []
for i in range(ntasks):
p = multiprocessing.Process(target=work, args=[(i, ntasks, args)])
p.start()
if args.strace:
# Addtional Monitoring process
subprocess.Popen(["strace", "-tt" ,
"-o", f"{args.exp_dir}/strace_{i}.log",
"-e", "trace=write", "-s256",
"-p", f"{p.pid}"])
processses.append(p)
def terminate(signum, frame):
# Try terminate first
print("terminating child processes")
sys.stdout.flush()
for i, p in enumerate(processses):
if p.is_alive():
p.terminate()
# Wait a little
for i in range(20):
if any(p.is_alive() for p in processses):
sleep(0.1)
## If they are still alive after that kill -9 them
for i, p in enumerate(processses):
if p.is_alive():
print(f"Sending SIGKILL to process {i}")
sys.stdout.flush()
os.kill(p.pid, signal.SIGKILL)
print("exiting")
sys.stdout.flush()
sys.exit(0)
if args.auto_requeue:
def forward_usr1_signal(signum, frame):
print(f"Received USR1 signal in spawn_dist", flush=True)
for i, p in enumerate(processses):
if p.is_alive():
os.kill(p.pid, signal.SIGUSR1)
def forward_term_signal(signum, frame):
print(f"Received SIGTERM signal in spawn_dist", flush=True)
for i, p in enumerate(processses):
if p.is_alive():
os.kill(p.pid, signal.SIGTERM)
# For requeing we need to ignore SIGTERM, and forward USR1
signal.signal(signal.SIGUSR1, forward_usr1_signal)
signal.signal(signal.SIGTERM, forward_term_signal)
signal.signal(signal.SIGINT, terminate)
else:
signal.signal(signal.SIGINT, terminate)
signal.signal(signal.SIGTERM, terminate)
while True:
sleep(0.5)
if any(not p.is_alive() for p in processses):
print("Detected an exited process, so exiting main")
terminate(None, None)
print("DONE")