in submix.py [0:0]
def __init__(self, device, B, eps, public_model, ensemble,
             alpha=float('inf'), gamma=10, lambda_dec_factor=0.93, consumption_multiplier=1.0,
             lambda_solver='iteration', temp=1.0):
    """Set up the budget, model and bookkeeping state.

    Args:
        device: Stored as ``self.device``; presumably the torch device the
            models run on (not used in this constructor) -- TODO confirm.
        B: Query budget. ``queries_remaining`` starts at ``B`` and the
            per-query consumption target divides by it, so it must be
            non-zero.
        eps: Budget given to every ensemble member; ``eps_remaining`` starts
            as ``[eps] * len(ensemble)``.
        public_model: Stored as ``self.public_model`` and placed at index 0
            of ``self.LMs``.
        ensemble: Sequence of models. Its length sizes ``eps_remaining`` and
            is passed to ``self._random_pairing``; its members follow the
            public model in ``self.LMs``.
        alpha: Stored as-is (default ``inf``); presumably a Renyi-DP order --
            TODO confirm.
        gamma: Stored as-is.
        lambda_dec_factor: Stored as-is; must be strictly less than 1.
        consumption_multiplier: Scales the per-query target ``eps / B``.
        lambda_solver: Stored as-is (default ``'iteration'``).
        temp: Stored as-is; presumably a sampling temperature -- TODO confirm.

    Raises:
        ValueError: if ``lambda_dec_factor`` is not < 1 (including NaN).
            Earlier versions used an ``assert`` (``AssertionError``).
        ZeroDivisionError: if ``B == 0``.

    Side effects:
        Calls ``self._random_pairing`` and loads the ``'distilgpt2'``
        tokenizer via ``GPT2Tokenizer.from_pretrained``.
    """
    # Validate before doing any work (e.g. the tokenizer load below). Use an
    # explicit raise because ``assert`` is stripped under ``python -O``.
    # ``not (x < 1.0)`` rejects NaN, exactly as the old assert did.
    if not (lambda_dec_factor < 1.0):
        raise ValueError(
            f'lambda_dec_factor should be less than 1, got {lambda_dec_factor!r}')

    self.alpha = alpha
    self.B = B
    self.queries_remaining = B
    # One remaining-budget entry per ensemble member, all starting at eps.
    self.eps_remaining = [eps] * len(ensemble)
    self.pairing = self._random_pairing(len(ensemble))
    self.STOP = False
    self.public_model = public_model
    self.ensemble = ensemble
    # Public model first, then the ensemble members. Building a fresh list
    # leaves the caller's ``ensemble`` untouched and also accepts tuples.
    self.LMs = [public_model, *ensemble]
    self.temp = temp
    self.device = device
    self.tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
    # Per-query consumption target: eps / B scaled by consumption_multiplier.
    self.target_consumption = consumption_multiplier * eps / B
    self.lambda_dec_factor = lambda_dec_factor
    self.gamma = gamma
    # Start empty; presumably filled in by other methods (not visible here).
    self.lambs = []
    self.epsilons = []
    self.lambda_solver = lambda_solver