def auto_convert_mixed_precision()

in onnxconverter_common/auto_mixed_precision.py [0:0]


def auto_convert_mixed_precision(model, feed_dict, validate_fn=None, rtol=None, atol=None, keep_io_types=False):
    """
    Automatically converts a model to mixed precision, excluding the minimum number of nodes required to
    ensure validate_fn returns True and/or results are equal according to rtol/atol.

    The search starts from a model where every eligible node is on the fp32 block list, then repeatedly
    tries un-blocking the largest remaining segment of nodes, splitting segments that fail validation
    until single offending nodes are isolated.

    :param model: the ONNX model to convert.
    :param feed_dict: mapping of input name -> input array used to run the model.
    :param validate_fn: optional callable ``validate_fn(reference_outputs, candidate_outputs) -> bool``.
    :param rtol: relative tolerance for ``np.allclose`` output comparison. Defaults to 1e-5 when only
        ``atol`` is given.
    :param atol: absolute tolerance for ``np.allclose`` output comparison. Defaults to 1e-8 when only
        ``rtol`` is given.
    :param keep_io_types: forwarded to ``convert_float_to_float16``; when False, float32 arrays in
        ``feed_dict`` are cast to float16 before running converted models.
    :return: the converted mixed-precision model.
    :raises ValueError: if neither ``validate_fn`` nor a tolerance is given, or if validation fails for
        the original model, for the model with all nodes blocked, or for the final model.
    """
    if rtol is None and atol is not None:
        rtol = 1e-5

    if atol is None and rtol is not None:
        atol = 1e-8

    if rtol is None and validate_fn is None:
        raise ValueError("Argument `validate_fn` and `rtol` cannot both be `None`.")

    def validate(res1, res2):
        """Return True if candidate outputs ``res2`` are acceptable against reference outputs ``res1``."""
        if validate_fn is not None and not validate_fn(res1, res2):
            return False
        if rtol is not None:
            # zip() silently truncates and np.allclose broadcasts, so without these checks an output
            # list with missing outputs or wrongly-shaped tensors could be accepted as valid.
            if len(res1) != len(res2):
                return False
            for r1, r2 in zip(res1, res2):
                if np.shape(r1) != np.shape(r2):
                    return False
                if not np.allclose(r1, r2, rtol, atol):
                    return False
        return True

    model0 = onnx.shape_inference.infer_shapes(model)
    model0 = add_missing_dtypes_using_ort(model0, feed_dict)
    # Reference outputs are computed with the original (uncast) inputs.
    res0 = get_tensor_values_using_ort(model0, feed_dict)
    if not keep_io_types:
        # Feeds for the converted models: float32 arrays are cast to float16 (presumably because the
        # converted model's float inputs become fp16 when io types are not kept).
        feed_dict = {k: v.astype(np.float16) if v.dtype == np.float32 else v for k, v in feed_dict.items()}
    if not validate(res0, res0):
        raise ValueError("validation failed for original fp32 model")
    # Control-flow nodes are never placed on the block list.
    # NOTE(review): nodes are tracked by name, so unnamed ("") or duplicate-named nodes cannot be
    # blocked individually — confirm input models have unique node names.
    node_names = [n.name for n in model0.graph.node if n.op_type not in ["Loop", "If", "Scan"]]

    def run_attempt(node_block_list, return_model=False):
        """Convert model0 keeping ``node_block_list`` in fp32, run it and validate its outputs.

        Returns the validity flag, or ``(valid, model)`` when ``return_model`` is True.
        """
        print(node_block_list)
        # Deep copy so every attempt starts from the same pristine model0 (the conversion presumably
        # modifies the model it is given).
        model = float16.convert_float_to_float16(copy.deepcopy(model0), node_block_list=node_block_list,
                                                 keep_io_types=keep_io_types, disable_shape_infer=True)
        res1 = get_tensor_values_using_ort(model, feed_dict)
        valid = validate(res0, res1)
        if return_model:
            return valid, model
        print(valid)
        return valid

    if not run_attempt(node_names):
        raise ValueError("validation failed for model with all nodes in node_block_list")
    print("Sanity checks passed. Starting autoconvert.")
    segments = SegmentList(node_names)
    i = 0
    while True:
        seg = segments.get_largest()
        if seg is None:
            break
        nodes_to_try = segments.get_nodes(seg)
        i += 1
        print("Running attempt %d excluding conversion of %s nodes" % (i, len(nodes_to_try)))
        if run_attempt(nodes_to_try):
            seg.good = True
            print("Attempt succeeded.")
        else:
            print("Attempt failed.")
            # A failing single-node segment pins that node to fp32; larger segments are bisected.
            if seg.size == 1:
                seg.bad = True
            else:
                seg.split()
        print(segments)
    print("Done:", segments.get_nodes())
    valid, model = run_attempt(segments.get_nodes(), return_model=True)
    if not valid:
        raise ValueError("validation failed for final fp16 model")
    print("Final model validated successfully.")
    return model