in src/speech_reps/models/bertphone.py [0:0]
def __init__(self, model_name, prefix=None, params=None, **kwargs):
    """Build the BERT encoder and input projection from a named hyperparameter set.

    Parameters
    ----------
    model_name : str
        Key into ``bert_hparams`` selecting the predefined hyperparameters
        (indexing presumably raises ``KeyError`` for an unknown name).
    prefix, params
        Forwarded unchanged to the parent class constructor.
    **kwargs
        Hyperparameter overrides merged over the predefined set. A key that
        already exists in the predefined set may only be overridden if it is
        one of ``use_residual``, ``dropout``, ``embed_dropout`` or
        ``word_embed``; keys absent from the predefined set are merged as-is.

    Raises
    ------
    ValueError
        If ``kwargs`` tries to override a predefined, non-mutable setting.
    """
    super(BertPhone, self).__init__(prefix=prefix, params=params)
    # Work on a copy: updating the shared ``bert_hparams`` entry in place
    # would leak these overrides into every later model with this name.
    predefined_args = dict(bert_hparams[model_name])
    mutable_args = frozenset(['use_residual', 'dropout', 'embed_dropout', 'word_embed'])
    # Explicit check instead of ``assert``, which ``python -O`` strips out.
    forbidden = sorted(k for k in predefined_args
                       if k in kwargs and k not in mutable_args)
    if forbidden:
        raise ValueError('Cannot override predefined model settings: {}'.format(
            ', '.join(forbidden)))
    predefined_args.update(kwargs)
    # encoder: built only from the merged hyperparameters; attention weights
    # and intermediate-layer outputs are not requested.
    self.encoder = BERTEncoder(attention_cell=predefined_args['attention_cell'],
                               num_layers=predefined_args['num_layers'],
                               units=predefined_args['units'],
                               hidden_size=predefined_args['hidden_size'],
                               max_length=predefined_args['max_length'],
                               num_heads=predefined_args['num_heads'],
                               scaled=predefined_args['scaled'],
                               dropout=predefined_args['dropout'],
                               output_attention=False,
                               output_all_encodings=False,
                               use_residual=predefined_args['use_residual'])
    # Projection to ``embed_size`` units; flatten=False keeps the leading axes
    # and transforms only the last one.
    self.embed = nn.Dense(predefined_args['embed_size'], flatten=False)