def build_hrnet()

in models/vision/detection/awsdet/models/backbones/hrnet.py [0:0]


# Supported HRNet classification variants, keyed by model name. The value is
# the channel width of the highest-resolution branch; each additional
# (lower-resolution) branch doubles it, e.g. W18 -> 18, 36, 72, 144.
_HRNET_CLS_WIDTHS = {
    'hrnet_w18c': 18,
    'hrnet_w32c': 32,
}


def _hrnet_norm_cfg():
    """Return a fresh BatchNorm config dict used by every HRNet sub-module."""
    return dict(
        type='BN',
        axis=-1,
        momentum=0.9,
        eps=1e-5,
    )


def _hrnet_stage_cfg(index, num_modules, num_channels, expansion):
    """Return the config dict for HRNet stage ``index`` (1-based).

    The number of branches equals ``len(num_channels)``; every branch uses
    4 blocks.
    """
    num_branches = len(num_channels)
    return dict(
        name='s{}'.format(index),
        num_modules=num_modules,
        num_branches=num_branches,
        num_blocks=(4, ) * num_branches,
        num_channels=tuple(num_channels),
        expansion=expansion,
        act_cfg=dict(type='relu', ),
        norm_cfg=_hrnet_norm_cfg(),
        weight_decay=5e-5,
    )


def _hrnet_cls_config(width):
    """Return the full HRNet classification config for base branch ``width``."""

    def branch_channels(num_branches):
        # Channel width doubles with each lower-resolution branch.
        return tuple(width * 2**i for i in range(num_branches))

    return dict(
        type='HRNet',
        num_stages=4,
        stem=dict(
            channels=64,
            kernel_size=3,
            stride=2,
            padding='same',
            use_bias=False,
            act_cfg=dict(type='relu', ),
            norm_cfg=_hrnet_norm_cfg(),
            weight_decay=5e-5,
        ),
        # Stage 1 is a single 64-channel branch with expansion 4 in every
        # variant; only stages 2-4 depend on ``width``.
        stage1=_hrnet_stage_cfg(1, num_modules=1, num_channels=(64, ),
                                expansion=4),
        stage2=_hrnet_stage_cfg(2, num_modules=1,
                                num_channels=branch_channels(2), expansion=1),
        stage3=_hrnet_stage_cfg(3, num_modules=4,
                                num_channels=branch_channels(3), expansion=1),
        stage4=_hrnet_stage_cfg(4, num_modules=3,
                                num_channels=branch_channels(4), expansion=1),
        # The classification head uses the same channel widths for every
        # variant.
        head=dict(
            name='cls_head',
            channels=(32, 64, 128, 256),
            expansion=4,
            act_cfg=dict(type='relu', ),
            norm_cfg=_hrnet_norm_cfg(),
            weight_decay=5e-5,
        ),
    )


def build_hrnet(model_name, include_top=True):
    """Build an HRNet classification model.

    Args:
        model_name: One of 'hrnet_w18c' or 'hrnet_w32c'.
        include_top: Forwarded unchanged to ``HRNet`` as its second
            positional argument (presumably toggles the classification
            head -- confirm against ``HRNet``).

    Returns:
        The ``HRNet`` instance, after it has been called once on a symbolic
        input of shape (None, None, 3).

    Raises:
        ValueError: If ``model_name`` is not a supported variant.
    """
    # Validate up front: previously an unknown name fell through both
    # branches and failed later with an UnboundLocalError.
    try:
        width = _HRNET_CLS_WIDTHS[model_name]
    except KeyError:
        raise ValueError(
            'Unsupported HRNet model {!r}; expected one of {}'.format(
                model_name, sorted(_HRNET_CLS_WIDTHS))) from None

    hrnet = HRNet(_hrnet_cls_config(width), include_top)
    # Call the model once on a symbolic input (any spatial size, 3 channels)
    # so that it is built; the symbolic output itself is not needed.
    # Assumes ``layers`` is tf.keras.layers -- confirm at the file's imports.
    inputs = layers.Input(shape=(None, None, 3))
    hrnet(inputs, training=False)
    return hrnet