in models/attentive_nas_dynamic_model.py [0:0]
def crossover_and_reset(self, cfg1, cfg2, p=0.5):
    """Uniformly cross over two subnet configs and activate the child.

    Each gene is inherited independently: with probability ``p`` it is taken
    from ``cfg1``, otherwise from ``cfg2``. ``'resolution'`` is treated as a
    single gene. ``'width'``, ``'depth'``, ``'kernel_size'`` and
    ``'expand_ratio'`` may each be either an int (one gene) or a list, in
    which case every list element is crossed over independently.

    Args:
        cfg1: First parent config (dict containing the keys listed above).
        cfg2: Second parent config, structured like ``cfg1``.
        p: Probability of inheriting each gene from ``cfg1``.

    Returns:
        The child config dict with keys ``'resolution'``, ``'width'``,
        ``'depth'``, ``'kernel_size'`` and ``'expand_ratio'``. The same
        values are passed to ``self.set_active_subnet`` before returning.

    Raises:
        TypeError: if a gene has different types in the two parents.
        ValueError: if a list gene has different lengths in the two parents
            (``zip`` would otherwise silently drop the extra entries).
        NotImplementedError: if a gene is neither an int nor a list.
    """
    def _cross_helper(g1, g2, prob, key):
        # Explicit check rather than ``assert`` so it is not stripped by -O.
        if type(g1) is not type(g2):
            raise TypeError(
                f"gene {key!r} has mismatched types: "
                f"{type(g1).__name__} vs {type(g2).__name__}"
            )
        if isinstance(g1, int):
            return g1 if random.random() < prob else g2
        if isinstance(g1, list):
            # zip() would silently truncate to the shorter parent.
            if len(g1) != len(g2):
                raise ValueError(
                    f"gene {key!r} has mismatched lengths: "
                    f"{len(g1)} vs {len(g2)}"
                )
            return [v1 if random.random() < prob else v2 for v1, v2 in zip(g1, g2)]
        raise NotImplementedError(
            f"crossover not supported for gene {key!r} of type {type(g1).__name__}"
        )

    cfg = {}
    # Draw order (resolution first, then the keys below in order) is kept
    # fixed so that runs with a seeded ``random`` are reproducible.
    cfg['resolution'] = cfg1['resolution'] if random.random() < p else cfg2['resolution']
    for k in ['width', 'depth', 'kernel_size', 'expand_ratio']:
        cfg[k] = _cross_helper(cfg1[k], cfg2[k], p, k)
    self.set_active_subnet(
        cfg['resolution'], cfg['width'], cfg['depth'], cfg['kernel_size'], cfg['expand_ratio']
    )
    return cfg