def load_model()

in tensorflow_similarity/models/contrastive_model.py [0:0]


def load_model(filepath):
    """Load a ContrastiveModel previously saved as a directory.

    The directory at ``filepath`` is expected to contain:

    * ``backbone/`` and ``projector/``: Keras models loaded with
      ``tf.keras.models.load_model``.
    * ``predictor/`` (optional): loaded only if the directory exists;
      otherwise the model is built with ``predictor=None``.
    * ``config.json``: metadata holding the serialized ``optimizer``,
      ``loss``, ``metrics`` list and the ``algorithm`` name.
    * ``opt_weights.npy``: the optimizer weights, restored after the
      model has been compiled.

    Args:
        filepath: Path (``str`` or ``os.PathLike``) of the saved model
            directory.

    Returns:
        A compiled ``ContrastiveModel`` whose optimizer state has been
        restored from ``opt_weights.npy``.

    Raises:
        FileNotFoundError: If ``config.json`` or ``opt_weights.npy`` is
            missing. Errors raised by the Keras loaders/deserializers
            propagate unchanged.
    """
    spath = Path(filepath)
    backbone_path = spath / "backbone"
    proj_path = spath / "projector"
    pred_path = spath / "predictor"
    optw_path = spath / "opt_weights.npy"
    config_path = spath / "config.json"

    backbone = tf.keras.models.load_model(backbone_path)
    projector = tf.keras.models.load_model(proj_path)

    # The predictor is optional: it is only loaded when its directory exists.
    if os.path.isdir(pred_path):
        predictor = tf.keras.models.load_model(pred_path)
    else:
        predictor = None

    # config.json normally holds a JSON *string* that itself encodes the
    # metadata dict, hence the second decode. A file that stores the dict
    # directly is accepted too, instead of crashing on json.loads(dict).
    with open(config_path, "r", encoding="utf-8") as f:
        metadata = json.load(f)
    if isinstance(metadata, str):
        metadata = json.loads(metadata)

    optimizer = tf.keras.optimizers.deserialize(metadata["optimizer"])
    # NOTE(security): allow_pickle=True can execute arbitrary code while
    # unpickling; only load checkpoints from trusted sources.
    opt_weights = np.load(optw_path, allow_pickle=True)
    loss = tf.keras.losses.deserialize(metadata["loss"])
    metrics = [tf.keras.metrics.deserialize(m) for m in metadata["metrics"]]
    algorithm = metadata["algorithm"]

    model = ContrastiveModel(
        backbone=backbone,
        projector=projector,
        predictor=predictor,
        algorithm=algorithm,
    )

    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

    # Collect the trainable variables of the backbone, the projector and,
    # if present, the predictor.
    tvars = model.backbone.trainable_variables
    tvars = tvars + model.projector.trainable_variables
    if model.predictor is not None:
        tvars = tvars + model.predictor.trainable_variables

    # Presumably done because Keras optimizers build their slot variables
    # lazily on the first apply_gradients() call, and set_weights() needs
    # them to exist. One step with zero gradients builds them. The model
    # variables are snapshotted first so they can be restored in case that
    # step changes them anyway (e.g. via weight decay).
    zero_grads = [tf.zeros_like(w) for w in tvars]
    saved_vars = [tf.identity(w) for w in tvars]

    optimizer.apply_gradients(zip(zero_grads, tvars))

    # Undo any change the dummy step made to the model variables.
    for var, saved in zip(tvars, saved_vars):
        var.assign(saved)

    # Overwrite the freshly built optimizer state with the saved one.
    optimizer.set_weights(opt_weights)

    return model