# src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py

import os
import re

import tensorflow as tf
import torch

from transformers import BertConfig, BertModel
from transformers.utils import logging


logger = logging.get_logger(__name__)


def load_tf2_weights_in_bert(model, tf_checkpoint_path, config):
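    """Load weights from a TF 2.x object-based checkpoint into a PyTorch BERT model.

    Only the embedding/encoder/pooler weights are converted; optimizer slots and
    bookkeeping variables are skipped, and checkpoints that contain MLM/NSP heads
    are rejected.
    """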
tf_path = os.path.abspath(tf_checkpoint_path)
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from the TF model; tf.train.list_variables returns (name, shape) pairs
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
layer_depth = []
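    # TF2 object-based checkpoints name variables after the object graph, e.g.
    # "model/layer_with_weights-4/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE"
    # (illustrative; exact names depend on how the TF model was built)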
for full_name, shape in init_vars:
        # logger.info(f"Loading TF weight {full_name} with shape {shape}")
name = full_name.split("/")
if full_name == "_CHECKPOINTABLE_OBJECT_GRAPH" or name[0] in ["global_step", "save_counter"]:
logger.info(f"Skipping non-model layer {full_name}")
continue
if "optimizer" in full_name:
logger.info(f"Skipping optimization layer {full_name}")
continue
if name[0] == "model":
# ignore initial 'model'
name = name[1:]
# figure out how many levels deep the name is
depth = 0
for _name in name:
if _name.startswith("layer_with_weights"):
depth += 1
else:
break
layer_depth.append(depth)
# read data
array = tf.train.load_variable(tf_path, full_name)
names.append("/".join(name))
arrays.append(array)
logger.info(f"Read a total of {len(arrays):,} layers")
# Sanity check
if len(set(layer_depth)) != 1:
raise ValueError(f"Found layer names with different depths (layer depth {list(set(layer_depth))})")
layer_depth = list(set(layer_depth))[0]
if layer_depth != 1:
raise ValueError(
"The model contains more than just the embedding/encoder layers. This script does not handle MLM/NSP"
" heads."
)
# convert layers
logger.info("Converting weights...")
for full_name, array in zip(names, arrays):
name = full_name.split("/")
pointer = model
trace = []
for i, m_name in enumerate(name):
if m_name == ".ATTRIBUTES":
# variable names end with .ATTRIBUTES/VARIABLE_VALUE
break
if m_name.startswith("layer_with_weights"):
layer_num = int(m_name.split("-")[-1])
if layer_num <= 2:
# embedding layers
# layer_num 0: word_embeddings
# layer_num 1: position_embeddings
# layer_num 2: token_type_embeddings
continue
elif layer_num == 3:
# embedding LayerNorm
trace.extend(["embeddings", "LayerNorm"])
pointer = getattr(pointer, "embeddings")
pointer = getattr(pointer, "LayerNorm")
                elif 4 <= layer_num < config.num_hidden_layers + 4:
# encoder layers
trace.extend(["encoder", "layer", str(layer_num - 4)])
pointer = getattr(pointer, "encoder")
pointer = getattr(pointer, "layer")
pointer = pointer[layer_num - 4]
elif layer_num == config.num_hidden_layers + 4:
# pooler layer
trace.extend(["pooler", "dense"])
pointer = getattr(pointer, "pooler")
pointer = getattr(pointer, "dense")
elif m_name == "embeddings":
trace.append("embeddings")
pointer = getattr(pointer, "embeddings")
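                # `layer_num` was parsed from the preceding "layer_with_weights-N"
                # component (0, 1, or 2), which selected this embeddings branch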
if layer_num == 0:
trace.append("word_embeddings")
pointer = getattr(pointer, "word_embeddings")
elif layer_num == 1:
trace.append("position_embeddings")
pointer = getattr(pointer, "position_embeddings")
elif layer_num == 2:
trace.append("token_type_embeddings")
pointer = getattr(pointer, "token_type_embeddings")
else:
raise ValueError(f"Unknown embedding layer with name {full_name}")
trace.append("weight")
pointer = getattr(pointer, "weight")
elif m_name == "_attention_layer":
# self-attention layer
trace.extend(["attention", "self"])
pointer = getattr(pointer, "attention")
pointer = getattr(pointer, "self")
elif m_name == "_attention_layer_norm":
# output attention norm
trace.extend(["attention", "output", "LayerNorm"])
pointer = getattr(pointer, "attention")
pointer = getattr(pointer, "output")
pointer = getattr(pointer, "LayerNorm")
elif m_name == "_attention_output_dense":
# output attention dense
trace.extend(["attention", "output", "dense"])
pointer = getattr(pointer, "attention")
pointer = getattr(pointer, "output")
pointer = getattr(pointer, "dense")
elif m_name == "_output_dense":
# output dense
trace.extend(["output", "dense"])
pointer = getattr(pointer, "output")
pointer = getattr(pointer, "dense")
elif m_name == "_output_layer_norm":
# output dense
trace.extend(["output", "LayerNorm"])
pointer = getattr(pointer, "output")
pointer = getattr(pointer, "LayerNorm")
elif m_name == "_key_dense":
# attention key
trace.append("key")
pointer = getattr(pointer, "key")
elif m_name == "_query_dense":
# attention query
trace.append("query")
pointer = getattr(pointer, "query")
elif m_name == "_value_dense":
# attention value
trace.append("value")
pointer = getattr(pointer, "value")
elif m_name == "_intermediate_dense":
                # feed-forward intermediate dense
trace.extend(["intermediate", "dense"])
pointer = getattr(pointer, "intermediate")
pointer = getattr(pointer, "dense")
elif m_name == "_output_layer_norm":
# output layer norm
trace.append("output")
pointer = getattr(pointer, "output")
# weights & biases
elif m_name in ["bias", "beta"]:
trace.append("bias")
pointer = getattr(pointer, "bias")
elif m_name in ["kernel", "gamma"]:
trace.append("weight")
pointer = getattr(pointer, "weight")
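            # "gamma"/"beta" are TF's historical names for LayerNorm scale/offset,
            # which correspond to PyTorch's weight/bias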
else:
logger.warning(f"Ignored {m_name}")
        # for the attention projections a reshape is necessary: the TF2 attention
        # layers store these weights as multi-dimensional einsum kernels/biases
        # (e.g. (hidden, heads, head_size)), which must first be reshaped to the
        # shapes of the corresponding torch modules
trace = ".".join(trace)
if re.match(r"(\S+)\.attention\.self\.(key|value|query)\.(bias|weight)", trace) or re.match(
r"(\S+)\.attention\.output\.dense\.weight", trace
):
array = array.reshape(pointer.data.shape)
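        # TF stores dense kernels as (in_features, out_features); torch.nn.Linear
        # keeps weight as (out_features, in_features), hence the transpose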
if "kernel" in full_name:
array = array.transpose()
if pointer.shape == array.shape:
pointer.data = torch.from_numpy(array)
else:
raise ValueError(
f"Shape mismatch in layer {full_name}: Model expects shape {pointer.shape} but layer contains shape:"
f" {array.shape}"
)
logger.info(f"Successfully set variable {full_name} to PyTorch layer {trace}")
return model
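

if __name__ == "__main__":
    # Minimal driver sketch showing how the function above would be wired up.
    # Assumption: the flag names and the plain state_dict dump here are
    # illustrative placeholders, not necessarily the script's actual CLI.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--tf_checkpoint_path", type=str, required=True, help="Path to the TF 2.x checkpoint.")
    parser.add_argument("--bert_config_file", type=str, required=True, help="Path to the BERT config JSON.")
    parser.add_argument("--pytorch_dump_path", type=str, required=True, help="Where to write the PyTorch weights.")
    args = parser.parse_args()

    config = BertConfig.from_json_file(args.bert_config_file)
    model = BertModel(config)
    load_tf2_weights_in_bert(model, args.tf_checkpoint_path, config)
    torch.save(model.state_dict(), args.pytorch_dump_path)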