in benchmarking/nursery/lstm_wikitext2/lstm_wikitext2.py [0:0]
def objective(config):
    """Train a language model on the configured corpus, reporting per epoch.

    After each epoch the validation loss is computed, clipped, and turned
    into the metric ``-exp(loss)`` (the negative validation perplexity),
    which is sent through a ``Reporter`` together with the epoch number and
    the elapsed wall-clock time (measured from just after the dataset has
    been loaded).

    Keys read from ``config`` in this function:
        ``dropout``, ``batch_size``, ``clip``, ``lr_factor``, ``lr``,
        ``epochs``, ``dataset_path``, ``report_current_best`` and,
        optionally, ``trial_id`` (its presence switches on debug printing).
    ``config`` is also handed to the checkpoint helpers, which may read
    further keys not visible here.

    :param config: dict-like hyperparameter / settings container
    """
    # Fixed (non-tuned) settings
    model_type = "rnn"
    emsize = 200
    nhid = emsize
    nlayers = 2
    eval_batch_size = 10
    bptt = 35
    tied = True
    nhead = 2  # only used by the transformer variant
    seed = np.random.randint(10000)

    # Tuned hyperparameters and run options
    dropout = config['dropout']
    batch_size = config['batch_size']
    clip = config['clip']
    lr_factor = config['lr_factor']
    report_current_best = parse_bool(config['report_current_best'])
    trial_id = config.get('trial_id')
    debug_log = trial_id is not None
    if debug_log:
        print("Trial {}: Starting evaluation".format(trial_id), flush=True)

    torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    #######################################################################
    # Load data
    #######################################################################
    path = config['dataset_path']
    os.makedirs(path, exist_ok=True)
    # Several worker processes may share one instance, so the dataset
    # files are guarded by a file lock while the corpus is set up.
    lock = SoftFileLock(os.path.join(path, 'lock'))
    try:
        with lock.acquire(timeout=120, poll_intervall=1):
            corpus = Corpus(path)
    except Timeout:
        # Best effort: proceed without the lock rather than failing
        print(
            "WARNING: Could not obtain lock for dataset files. Trying anyway...",
            flush=True)
        corpus = Corpus(path)

    # The clock starts only now, so that the (potentially long) initial
    # dataset download is not counted towards the elapsed time.
    ts_start = time.time()
    report = Reporter()

    def batchify(data, bsz):
        """Arrange ``data`` into ``bsz`` columns and move it to ``device``."""
        # Number of complete rows available per column
        rows = data.size(0) // bsz
        # Drop the trailing elements which do not fill a complete row
        trimmed = data.narrow(0, 0, rows * bsz)
        # One column per batch entry
        columns = trimmed.view(bsz, -1).t().contiguous()
        return columns.to(device)

    train_data = batchify(corpus.train, batch_size)
    val_data = batchify(corpus.valid, eval_batch_size)

    #######################################################################
    # Build the model
    #######################################################################
    ntokens = len(corpus.dictionary)
    if model_type != "transformer":
        model = RNNModel(
            'LSTM', ntokens, emsize, nhid, nlayers, dropout,
            tied).to(device)
    else:
        model = TransformerModel(
            ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)
    criterion = nn.CrossEntropyLoss()

    def get_batch(source, i):
        """Return inputs of up to ``bptt`` rows from row ``i`` and the
        flattened targets, which are the same rows shifted by one."""
        length = min(bptt, len(source) - 1 - i)
        inputs = source[i:i + length]
        shifted = source[i + 1:i + 1 + length].view(-1)
        return inputs, shifted

    def evaluate(model, corpus, criterion, data_source):
        """Return the loss on ``data_source``, summed with per-chunk length
        weights and divided by ``len(data_source) - 1``."""
        # Evaluation mode disables dropout
        model.eval()
        vocab_size = len(corpus.dictionary)
        recurrent = model_type != 'transformer'
        if recurrent:
            hidden = model.init_hidden(eval_batch_size)
        loss_sum = 0.
        with torch.no_grad():
            for start in range(0, data_source.size(0) - 1, bptt):
                data, targets = get_batch(data_source, start)
                if recurrent:
                    output, hidden = model(data, hidden)
                    hidden = repackage_hidden(hidden)
                else:
                    output = model(data)
                chunk_loss = criterion(
                    output.view(-1, vocab_size), targets).item()
                loss_sum += len(data) * chunk_loss
        return loss_sum / (len(data_source) - 1)

    def train(model, corpus, criterion, train_data, lr, batch_size, clip):
        """Run one pass over ``train_data`` using plain SGD updates with
        step size ``lr`` and gradient-norm clipping at ``clip``."""
        # Training mode enables dropout
        model.train()
        vocab_size = len(corpus.dictionary)
        recurrent = model_type != 'transformer'
        if recurrent:
            hidden = model.init_hidden(batch_size)
        for start in range(0, train_data.size(0) - 1, bptt):
            data, targets = get_batch(train_data, start)
            model.zero_grad()
            if recurrent:
                # Detach the hidden state from its history. Otherwise,
                # backpropagation would reach back to the start of the
                # dataset.
                hidden = repackage_hidden(hidden)
                output, hidden = model(data, hidden)
            else:
                output = model(data)
            loss = criterion(output.view(-1, vocab_size), targets)
            loss.backward()
            # Gradient-norm clipping counters exploding gradients in
            # RNNs / LSTMs
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            for param in model.parameters():
                param.data.add_(param.grad.data, alpha=-lr)

    # Checkpointing
    # The learning rate and the best validation loss belong to the state
    # to be checkpointed. They live in a dict so that the load / save
    # functions can modify them in place.
    mutable_state = {
        'lr': config['lr'],
        'best_val_loss': None}
    load_model_fn, save_model_fn = pytorch_load_save_functions(
        {'model': model}, mutable_state)
    # Resume from checkpoint (optional)
    resume_from = resume_from_checkpointed_model(config, load_model_fn)

    # Loop over epochs
    first_epoch = resume_from + 1
    for epoch in range(first_epoch, config['epochs'] + 1):
        train(model, corpus, criterion, train_data, mutable_state['lr'],
              batch_size, clip)
        val_loss = np.clip(
            evaluate(model, corpus, criterion, val_data), 1e-10, 10)
        elapsed_time = time.time() - ts_start
        # Fallback value in case the loss is still not finite after clipping
        if not np.isfinite(val_loss):
            val_loss = 7
        best_val_loss = mutable_state['best_val_loss']
        improved = (not best_val_loss) or val_loss < best_val_loss
        if improved:
            best_val_loss = val_loss
            mutable_state['best_val_loss'] = val_loss
        else:
            # No improvement on the validation set: anneal the learning rate
            mutable_state['lr'] /= lr_factor
        # Report either the best loss so far or the current one
        if report_current_best:
            reported_loss = best_val_loss
        else:
            reported_loss = val_loss
        metric_value = -math.exp(reported_loss)
        report(**{RESOURCE_ATTR: epoch,
                  METRIC_NAME: metric_value,
                  ELAPSED_TIME_ATTR: elapsed_time})
        # Write checkpoint (optional)
        checkpoint_model_at_rung_level(config, save_model_fn, epoch)
        if debug_log:
            print("Trial {}: epoch = {}, objective = {:.3f}, elapsed_time = {:.2f}".format(
                trial_id, epoch, metric_value, elapsed_time), flush=True)