def input_fn()

in sagemaker/src/hf_train_deploy.py [0:0]


def input_fn(request_body, request_content_type):
    """Deserialize a JSON inference request and tokenize it.

    Parses the JSON payload into a batch of text strings and tokenizes it
    with the module-level ``tokenizer``. Each sequence is padded/truncated to
    ``args.max_seq_length``.

    Args:
        request_body: Raw request payload containing JSON. It must be either
            a single string or a non-empty list of strings. Presumably
            ``str`` or ``bytes``; both are accepted by ``json.loads``.
        request_content_type: MIME type of the payload. Only
            ``"application/json"`` is supported.

    Returns:
        The tokenizer's batch encoding with PyTorch tensors
        (``return_tensors="pt"``), with ``token_type_ids`` removed if present.

    Raises:
        ValueError: If the content type is unsupported, or if the decoded
            payload is not a string or a non-empty list of strings.
    """
    if request_content_type != "application/json":
        raise ValueError("Unsupported content type: {}".format(request_content_type))

    data = json.loads(request_body)

    # Normalise a single string into a batch of one.
    if isinstance(data, str):
        data = [data]
    elif not (
        isinstance(data, list)
        and data
        and all(isinstance(item, str) for item in data)
    ):
        # Validate every element, not just the first, so malformed batches
        # fail here with a clear message instead of inside the tokenizer.
        raise ValueError(
            "Unsupported input type. Input type can be a string or a "
            "non-empty list of strings. I got {}".format(data)
        )

    tokenized_inputs = tokenizer.batch_encode_plus(
        data,
        max_length=args.max_seq_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    # token_type_ids is stripped before the encoding is returned. Presumably
    # the downstream model call does not accept it; confirm against predict_fn.
    # The default makes this a no-op for tokenizers that never emit the key,
    # where a bare pop() would raise KeyError.
    tokenized_inputs.pop("token_type_ids", None)
    return tokenized_inputs