# Copyright (c) Facebook, Inc. and its affiliates.

import copy
import json

# Fallback configuration values, grouped into the "network", "optimizer" and
# "input" sections. This is the default `defaults` argument of `load_config`.
DEFAULTS = dict(
    network=dict(
        arch="resnet101",
        # One of: relu, leaky_relu, elu, identity
        activation="relu",
        # Slope when activation is leaky_relu, alpha when it is elu
        activation_param=0.01,
        input_3x3=False,
        # One of: standard, inplace, sync
        bn_mode="standard",
        classes=1000,
        dilation=1,
        # Ignored when weight_init is one of the kaiming_* variants
        weight_gain_multiplier=1,
        # One of: xavier_normal, xavier_uniform, kaiming_normal,
        # kaiming_uniform, orthogonal
        weight_init="xavier_normal",
    ),
    optimizer=dict(
        batch_size=256,
        # One of: SGD, Adam
        type="SGD",
        momentum=0.9,
        weight_decay=1e-4,
        clip=1.0,
        learning_rate=0.1,
        # -1 means: use the same learning rate as the rest of the network
        classifier_lr=-1.0,
        nesterov=False,
        schedule=dict(
            # One of: constant, step, multistep, exponential, linear
            type="constant",
            # One of: epoch, step
            mode="epoch",
            epochs=10,
            params=dict(),
        ),
    ),
    input=dict(
        # -1 means: do not scale
        scale_train=-1,
        crop_train=224,
        color_jitter_train=False,
        lighting_train=False,
        # -1 means: do not scale
        scale_val=256,
        crop_val=224,
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
)


def _merge(src, dst):
    for k, v in src.items():
        if k in dst:
            if isinstance(v, dict):
                _merge(src[k], dst[k])
        else:
            dst[k] = v


def load_config(config_file, defaults=DEFAULTS):
    """Load a JSON configuration file and complete it with default values.

    Values found in the file take precedence; any key missing from the file
    (at any nesting level) is filled in from ``defaults`` via ``_merge``.

    Args:
        config_file: Path of the JSON file to read.
        defaults: Dictionary of fallback values, ``DEFAULTS`` if omitted.

    Returns:
        The dictionary parsed from the file, completed with the defaults.

    Raises:
        TypeError: If the top-level JSON value is not an object.

    Errors raised by ``open`` and ``json.load`` (missing file, invalid JSON)
    are propagated unchanged.
    """
    # Binary mode lets the json module decode UTF-8/16/32 itself instead of
    # relying on the platform's locale encoding.
    with open(config_file, "rb") as fd:
        config = json.load(fd)
    if not isinstance(config, dict):
        raise TypeError(
            "configuration file {!r} must contain a JSON object at the top "
            "level, got {}".format(config_file, type(config).__name__)
        )
    _merge(defaults, config)
    return config
