in extra_scripts/convert_vissl_to_classy_vision.py [0:0]
def _build_arg_parser():
    """Build the command-line parser for the SSL -> Classy Vision converter.

    Returns:
        argparse.ArgumentParser: parser describing the input checkpoint, the
        ResNet trunk depth and (optionally) how the heads should be converted.
    """
    parser = argparse.ArgumentParser(
        description="Convert SSL framework RN50 models to Classy Vision models"
    )
    parser.add_argument(
        "--depth", type=int, default=50, help="Depth of the ResNet model to convert"
    )
    parser.add_argument(
        "--num_fc_layers",
        type=int,
        default=1,
        help="Number of linear layers in the head",
    )
    parser.add_argument(
        "--include_heads",
        default=False,
        action="store_true",
        help="Whether to convert the head as well or not",
    )
    parser.add_argument(
        "--use_bn_head",
        default=False,
        action="store_true",
        help="Whether BN is in the head or not",
    )
    parser.add_argument(
        "--use_bias_head_fc",
        default=False,
        action="store_true",
        help="Whether FC layers in head have bias param or not",
    )
    parser.add_argument(
        "--output_head_prefix",
        type=str,
        default="embedding",
        help="The [output_prefix]_fc.weight for the heads",
    )
    # Both paths are mandatory: with a None default, torch.load / torch.save
    # fail with an unhelpful error (torch.save only after all the work is done).
    parser.add_argument(
        "--input_model_file",
        type=str,
        required=True,
        help="Path to input model weights to be converted",
    )
    parser.add_argument(
        "--output_model",
        type=str,
        required=True,
        help="Path to save converted RN-50 model",
    )
    parser.add_argument(
        "--state_dict_key_name",
        type=str,
        default="classy_state_dict",
        help="Pytorch model state_dict key name",
    )
    return parser


def main():
    """Convert an SSL-framework ResNet checkpoint into Classy Vision format.

    Loads ``--input_model_file`` and takes the entry stored under
    ``--state_dict_key_name``. When that key is ``"classy_state_dict"``, the
    trunk (and, optionally, the heads) are read from
    ``["base_model"]["model"]``. Otherwise the entry itself is treated as the
    trunk state dict. The trunk is converted with
    ``convert_trunk_to_classy_model``. With ``--include_heads``, the heads are
    converted with ``convert_heads_to_classy_model``. The result is saved to
    ``--output_model`` as
    ``{"classy_state_dict": {"base_model": {"model": {"trunk": ..., ["heads": ...]}}}}``.

    Raises:
        KeyError: if the checkpoint has no ``--state_dict_key_name`` entry.
        SystemExit: on invalid command-line arguments (raised by ``argparse``).
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    is_classy_checkpoint = args.state_dict_key_name == "classy_state_dict"

    # Heads can only be located in the nested classy layout. Validate before
    # loading a potentially large checkpoint, and use an explicit error rather
    # than `assert` (asserts are stripped under `python -O`).
    if args.include_heads and not is_classy_checkpoint:
        parser.error(
            "Can't convert heads: --include_heads requires "
            "--state_dict_key_name=classy_state_dict"
        )

    logger.info("Loading weights...")
    # map_location="cpu" allows converting checkpoints saved from GPU on a
    # machine without (matching) GPUs.
    checkpoint = torch.load(args.input_model_file, map_location="cpu")
    if args.state_dict_key_name not in checkpoint:
        raise KeyError(
            f"{args.state_dict_key_name} not found in {args.input_model_file}"
        )
    state_dict = checkpoint[args.state_dict_key_name]

    state_dict_head = None
    if is_classy_checkpoint:
        model_state = state_dict["base_model"]["model"]
        # Only touch "heads" when they are requested, so trunk-only
        # conversion works even if the checkpoint has no heads entry.
        if args.include_heads:
            state_dict_head = model_state["heads"]
        state_dict = model_state["trunk"]

    converted_model = {"trunk": convert_trunk_to_classy_model(state_dict, args.depth)}
    if args.include_heads:
        converted_model["heads"] = convert_heads_to_classy_model(
            state_dict_head,
            args.output_head_prefix,
            args.num_fc_layers,
            args.use_bn_head,
            args.use_bias_head_fc,
        )
    output_state_dict = {"classy_state_dict": {"base_model": {"model": converted_model}}}

    logger.info("Saving converted weights to: %s", args.output_model)
    torch.save(output_state_dict, args.output_model)
    logger.info("Done!!")