def transform_fn()

in notebooks/classify_mxnet.py [0:0]


def transform_fn(net, data, input_content_type, output_content_type):
    """
    Transform a request using the Gluon model. Called once per request.

    Decodes the payload into an image NDArray, then applies the evaluation
    preprocessing:
    - resize to 256
    - center-crop to 224
    - convert to a tensor
    - normalize

    It then runs ``net`` on a batch of one and returns the argmax class index
    serialized as JSON.

    :param net: The Gluon model.
    :param data: The request payload. Encoded image bytes for JPEG/PNG
        requests; a JSON document holding a nested array for JSON requests.
    :param input_content_type: The request content type. Must equal
        JPEG_CONTENT_TYPE, PNG_CONTENT_TYPE or JSON_CONTENT_TYPE.
    :param output_content_type: The (desired) response content type. It is
        returned unchanged; the response body itself is always JSON.
    :return: response payload and content type.
    :raises ValueError: if ``input_content_type`` is not one of the
        supported content types.
    """
    # Decode before doing any other work. Previously an unsupported content
    # type fell through every branch and crashed later with an opaque
    # UnboundLocalError on ``nda``.
    if input_content_type in (JPEG_CONTENT_TYPE, PNG_CONTENT_TYPE):
        # The same image decoder is used for both JPEG and PNG payloads.
        nda = mx.img.imdecode(data)
    elif input_content_type == JSON_CONTENT_TYPE:
        # NOTE(review): assumes the JSON array is laid out as an (H, W, C)
        # image with pixel values in [0, 255], the same form imdecode
        # produces. Confirm with clients.
        nda = mx.nd.array(json.loads(data))
    else:
        raise ValueError(
            "Unsupported input content type: {!r}".format(input_content_type))

    transforms = gluon.data.vision.transforms
    # The mean/std values match the commonly used ImageNet channel
    # statistics. Presumably the model was trained with the same values.
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
    test_augs = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    # Add a leading batch axis so the single image is passed as a batch of one.
    img = test_augs(nda).expand_dims(axis=0)

    output = net(img)
    prediction = mx.nd.argmax(output, axis=1)
    # The batch holds exactly one image, so take the first (only) prediction.
    # NOTE(review): mx.nd.argmax returns float indices, so the body looks like
    # "281.0" rather than "281". Confirm that clients accept that.
    response_body = json.dumps(prediction.asnumpy().tolist()[0])
    return response_body, output_content_type