def main()

in src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/caffe2_validate.py [0:0]


def main():
    """Run a serialized Caffe2 model over a dataset and print accuracy stats.

    Command-line options come from the module-level ``parser``. When
    ``args.c2_prefix`` is set, the init and predict net paths are derived from
    it by appending ``.init.pb`` and ``.predict.pb``; otherwise ``args.c2_init``
    and ``args.c2_predict`` are used as given.

    For every batch yielded by the loader the predict net is run once and the
    two values returned by ``accuracy_np`` are accumulated (reported as Prec@1
    and Prec@5). Progress is printed every ``args.print_freq`` batches and a
    summary line is printed at the end.
    """
    args = parser.parse_args()
    args.gpu_id = 0
    if args.c2_prefix:
        args.c2_init = args.c2_prefix + '.init.pb'
        args.c2_predict = args.c2_prefix + '.predict.pb'

    model = model_helper.ModelHelper(name="validation_net", init_params=False)

    # Load the parameter-initialization net from its serialized NetDef.
    init_net_proto = caffe2_pb2.NetDef()
    with open(args.c2_init, "rb") as f:
        init_net_proto.ParseFromString(f.read())
    model.param_init_net = core.Net(init_net_proto)

    # Load the prediction net from its serialized NetDef.
    predict_net_proto = caffe2_pb2.NetDef()
    with open(args.c2_predict, "rb") as f:
        predict_net_proto.ParseFromString(f.read())
    model.net = core.Net(predict_net_proto)

    data_config = resolve_data_config(None, args)
    loader = create_loader(
        Dataset(args.data, load_bytes=args.tf_preprocessing),
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=False,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=data_config['crop_pct'],
        tensorflow_preprocessing=args.tf_preprocessing)

    # The first external input/output of the predict net are used as the
    # image input blob and the prediction output blob.
    input_blob = model.net.external_inputs[0]
    output_blob = model.net.external_outputs[0]

    # CUDA execution is disabled: it was crashing without a useful error
    # message. Flip this flag to try the GPU path again.
    use_cuda = False
    if use_cuda:
        device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id)
        model.net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
        model.param_init_net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
    else:
        device_opts = None

    # Create the input blob (filled with Gaussian noise) as part of the init
    # net so that it exists in the workspace before the predict net is created.
    model.param_init_net.GaussianFill(
        [], input_blob.GetUnscopedName(),
        shape=(1,) + data_config['input_size'], mean=0.0, std=1.0)
    workspace.RunNetOnce(model.param_init_net)
    workspace.CreateNet(model.net, overwrite=True)

    batch_time = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()
    for i, (batch, target) in enumerate(loader):
        num_samples = batch.size(0)

        # Feed the batch, run the net once and fetch the prediction.
        caffe2_in = batch.data.numpy()
        workspace.FeedBlob(input_blob, caffe2_in, device_opts)
        workspace.RunNet(model.net, num_iter=1)
        output = workspace.FetchBlob(output_blob)

        # Accumulate accuracy, weighted by the number of samples in the batch.
        prec1, prec5 = accuracy_np(output.data, target.numpy())
        top1.update(prec1.item(), num_samples)
        top5.update(prec5.item(), num_samples)

        # Measure elapsed wall-clock time (seconds) for this batch.
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            rate_avg = num_samples / batch_time.avg
            # batch_time is in seconds, so scale by 1000 to report ms/sample.
            ms_avg = 1000 * batch_time.avg / num_samples
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s, {ms_avg:.3f} ms/sample) \t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                i, len(loader), batch_time=batch_time, rate_avg=rate_avg,
                ms_avg=ms_avg, top1=top1, top5=top5))

    # Final summary: average precision and its complement to 100.
    print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
        top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))