in training/flax/convert_train_state_to_hf.py [0:0]
def main():
    """Rebuild the student train state and export its parameters in HF Transformers format.

    Arguments come either from a single ``.json`` file passed as the only
    command-line argument, or from regular command-line flags. The student
    model is loaded, an AdamW optimizer and a ``TrainState`` are built around
    it, and, when ``resume_from_checkpoint`` names a folder that contains
    ``train_state.msgpack``, the state is restored from that file. Process 0
    then writes the parameters to ``<output_dir>/checkpoint-<step>`` and, if
    ``push_to_hub`` is set, pushes the output directory to the Hub.
    """
    # 1. Parse input arguments
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # Model and training arguments live in separate dataclasses.
    arg_parser = HfArgumentParser((ModelArguments, Seq2SeqTrainingArguments))

    cli_args = sys.argv[1:]
    if len(cli_args) == 1 and cli_args[0].endswith(".json"):
        # A lone argument that is a path to a json file holds every argument.
        model_args, training_args = arg_parser.parse_json_file(json_file=os.path.abspath(cli_args[0]))
    else:
        model_args, training_args = arg_parser.parse_args_into_dataclasses()

    # Handle the repository creation
    if training_args.push_to_hub:
        # Use the explicit hub id when given, otherwise derive one from the output folder name.
        repo_name = training_args.hub_model_id
        if repo_name is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name,
                token=training_args.hub_token,
            )
        create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
        repo = Repository(
            training_args.output_dir,
            clone_from=repo_name,
            token=training_args.hub_token,
        )

    # 5. Load pretrained config, model and processor
    auth_token = True if model_args.use_auth_token else None
    config_source = model_args.config_name or model_args.model_name_or_path

    config = AutoConfig.from_pretrained(
        config_source,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=auth_token,
    )

    student_model, student_params = FlaxWhisperForConditionalGeneration.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        dtype=getattr(jnp, model_args.dtype),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=auth_token,
        _do_init=False,
        use_scan=model_args.load_with_scan_weights,
    )

    # enable scan / gradient checkpointing if necessary in the student model
    if model_args.use_scan:
        # switch the nn.Module to scan, then convert the unrolled params to match
        student_model.enable_scan()
        student_params = student_model.convert_unroll_to_scan(student_params)

    # Initialize our student state
    _, dropout_rng = jax.random.split(jax.random.PRNGKey(training_args.seed))

    num_train_steps = int(training_args.max_steps)

    # Create learning rate schedule
    lr_schedule_fn = create_learning_rate_fn(
        num_train_steps,
        training_args.lr_scheduler_type,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # boolean mask with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat = traverse_util.flatten_dict(params)
        norm_markers = (
            "layer_norm",
            "self_attn_layer_norm",
            "final_layer_norm",
            "encoder_attn_layer_norm",
        )

        # Collect the trailing two path components of every LayerNorm parameter.
        norm_suffixes = set()
        for key_path in flat:
            joined = "".join(key_path).lower()
            if any(marker in joined for marker in norm_markers):
                norm_suffixes.add(key_path[-2:])

        mask = {}
        for key_path in flat:
            mask[key_path] = key_path[-1] != "bias" and key_path[-2:] not in norm_suffixes
        return traverse_util.unflatten_dict(mask)

    # create adam optimizer
    optimizer = optax.adamw(
        learning_rate=lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    # Setup train state
    student_state = TrainState.create(
        apply_fn=student_model.__call__,
        params=student_params,
        tx=optimizer,
        dropout_rng=dropout_rng,
        max_grad_norm=training_args.max_grad_norm,
    )

    # Restore the serialized train state when a checkpoint folder was supplied.
    if training_args.resume_from_checkpoint is not None:
        state_file = os.path.join(training_args.resume_from_checkpoint, "train_state.msgpack")
        if os.path.isfile(state_file):
            logger.info(
                f"Checkpoint detected, resuming training at {training_args.resume_from_checkpoint}. To avoid "
                "this behavior, omit the resume_from_checkpoint argument."
            )
            with Path(state_file).open("rb") as f:
                student_state = from_bytes(student_state, f.read())
        else:
            logger.warning(
                f"Checkpoint {training_args.resume_from_checkpoint} not detected, training from scratch. Ensure "
                f"you pass the path to a folder with a valid checkpoint for your model."
            )

    cur_step = int(jax.device_get(student_state.step))

    # save weights in HF Transformers format
    if jax.process_index() == 0:
        # Saved weights use the unrolled layout, so leave scan mode before converting.
        student_model.disable_scan()
        unrolled_params = student_model.convert_scan_to_unroll(student_state.params)
        host_params = jax.device_get(unrolled_params)

        save_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}")
        student_model.save_pretrained(save_dir, params=host_params)

        if training_args.push_to_hub:
            repo.push_to_hub(
                commit_message=f"Saving weights of step {cur_step}",
                blocking=False,
            )