in tcav/tcav.py [0:0]
def run(self, num_workers=10, run_parallel=False, overwrite=False,
        return_proto=False):
    """Run TCAV for every parameter set in ``self.params``.

    Each entry of ``self.params`` (concept and random experiments) is passed
    to ``self._run_single_set``; the per-set results are collected in the
    same order as ``self.params``.

    Args:
        num_workers: number of pool workers, used only when ``run_parallel``
            is True.
        run_parallel: if True, run the parameter sets through a worker pool;
            otherwise run them one after another. Also forwarded to
            ``_run_single_set``.
        overwrite: if True, overwrite any saved CAV files.
        return_proto: if True, return the results converted by
            ``utils.results_to_proto``; otherwise return them as a list.

    Returns:
        results: either the object produced by ``utils.results_to_proto``
        (documented as a tcav Results proto) or a list with one entry per
        parameter set, as returned by ``_run_single_set`` (documented as
        dicts).
    """
    # for random exp, a machine with cpu = 30, ram = 300G, disk = 10G and
    # pool worker 50 seems to work.
    num_params = len(self.params)
    tf.compat.v1.logging.info('running %s params', num_params)
    results = []
    now = time.time()
    if run_parallel:
        # NOTE(review): a lambda is handed to Pool.imap. That only works if
        # `multiprocessing` here is a thread-backed pool (e.g. imported from
        # `multiprocessing.dummy`); a process pool cannot pickle lambdas.
        # Confirm against the module-level imports.
        # The context manager tears the pool down even if a param set
        # raises, instead of leaving its workers alive.
        with multiprocessing.Pool(num_workers) as pool:
            # imap yields results in the same order as self.params.
            for i, res in enumerate(pool.imap(
                    lambda p: self._run_single_set(
                        p, overwrite=overwrite, run_parallel=run_parallel),
                    self.params), 1):
                tf.compat.v1.logging.info(
                    'Finished running param %s of %s', i, num_params)
                results.append(res)
    else:
        # Count from 1 so progress reads "1 of N" ... "N of N", matching the
        # parallel branch above.
        for i, param in enumerate(self.params, 1):
            tf.compat.v1.logging.info('Running param %s of %s', i, num_params)
            results.append(self._run_single_set(
                param, overwrite=overwrite, run_parallel=run_parallel))
    tf.compat.v1.logging.info('Done running %s params. Took %s seconds...',
                              num_params, time.time() - now)
    if return_proto:
        return utils.results_to_proto(results)
    return results