def register()

in python/mxnet/operator.py [0:0]


def register(reg_name):
    """Register a subclass of CustomOpProp to the registry with name reg_name.

    Meant to be used as a class decorator on a ``CustomOpProp`` subclass.

    Parameters
    ----------
    reg_name : str
        Operator name passed to ``MXCustomOpRegister``. The backend must later
        hand this same name back as ``op_type`` to the creator callback.

    Returns
    -------
    function
        A decorator that registers its argument and returns it unchanged.
    """
    def do_register(prop_cls):
        """Register ``prop_cls`` (a CustomOpProp subclass) with the backend.

        Builds the ctypes callback table through which the backend queries the
        operator's properties and creates/runs operator instances, registers a
        creator callback under ``reg_name`` and returns ``prop_cls`` unchanged.
        """
        # ctypes signatures of the C callbacks handed to the backend. Every
        # callback below returns True on success and False after printing the
        # Python traceback.
        fb_functype = CFUNCTYPE(c_int, c_int, POINTER(c_void_p), POINTER(c_int),
                                POINTER(c_int), c_int, c_void_p)
        del_functype = CFUNCTYPE(c_int, c_void_p)

        infershape_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int),
                                        POINTER(POINTER(mx_uint)), c_void_p)
        infertype_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), c_void_p)
        inferstorage_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), c_void_p)
        inferstorage_backward_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int),
                                                   POINTER(c_int), c_void_p)
        list_functype = CFUNCTYPE(c_int, POINTER(POINTER(POINTER(c_char))), c_void_p)
        deps_functype = CFUNCTYPE(c_int, c_int_p, c_int_p, c_int_p,
                                  c_int_p, POINTER(c_int_p), c_void_p)
        createop_functype = CFUNCTYPE(c_int, c_char_p, c_int, POINTER(POINTER(mx_uint)),
                                      POINTER(c_int), POINTER(c_int),
                                      POINTER(MXCallbackList), c_void_p)
        # Integer request codes from the backend, indexed into the strings that
        # CustomOp.forward/backward receive as ``req``.
        req_enum = ('null', 'write', 'inplace', 'add')

        def creator(op_type, argc, keys, vals, ret):
            """Instantiate ``prop_cls`` and publish its callback table.

            ``keys``/``vals`` hold ``argc`` C strings that are forwarded to the
            ``prop_cls`` constructor as string keyword arguments. The resulting
            property callback table is written to ``ret[0]``.
            """
            assert py_str(op_type) == reg_name
            kwargs = {py_str(keys[i]): py_str(vals[i]) for i in range(argc)}
            op_prop = prop_cls(**kwargs)

            def infer_shape_entry(num_tensor, tensor_dims,
                                  tensor_shapes, _):
                """C Callback for ``CustomOpProp::InferShape``.

                Reads the input shapes, calls ``op_prop.infer_shape`` and writes
                input, output and aux shapes back (in that order) into
                ``tensor_shapes``/``tensor_dims``.
                """
                try:
                    n_in = len(op_prop.list_arguments())
                    n_out = len(op_prop.list_outputs())
                    n_aux = len(op_prop.list_auxiliary_states())
                    assert num_tensor == n_in + n_out + n_aux

                    shapes = [[tensor_shapes[i][j] for j in range(tensor_dims[i])]
                              for i in range(n_in)]
                    ret = op_prop.infer_shape(shapes)
                    if len(ret) == 2:
                        ishape, oshape = ret
                        ashape = []
                    elif len(ret) == 3:
                        ishape, oshape, ashape = ret
                    else:
                        raise AssertionError("infer_shape must return 2 or 3 lists")
                    assert len(oshape) == n_out, \
                        "InferShape Error: expecting %d entries in returned output " \
                        "shapes, got %d."%(n_out, len(oshape))
                    assert len(ishape) == n_in, \
                        "InferShape Error: expecting %d entries in returned input " \
                        "shapes, got %d."%(n_in, len(ishape))
                    assert len(ashape) == n_aux, \
                        "InferShape Error: expecting %d entries in returned aux state " \
                        "shapes, got %d."%(n_aux, len(ashape))
                    rshape = list(ishape) + list(oshape) + list(ashape)
                    for i in range(n_in+n_out+n_aux):
                        tensor_shapes[i] = cast(c_array_buf(mx_uint,
                                                            array('I', rshape[i])),
                                                POINTER(mx_uint))
                        tensor_dims[i] = len(rshape[i])

                    # Keep the freshly allocated shape buffers reachable from
                    # Python after this callback returns.
                    infer_shape_entry._ref_holder = [tensor_shapes]
                except Exception:
                    print('Error in %s.infer_shape: %s' % (reg_name, traceback.format_exc()))
                    return False
                return True


            def infer_storage_type_backward_entry(num_tensor, tensor_stypes, tags, _):
                # pylint: disable=C0301
                """C Callback for CustomOpProp::InferStorageTypeBackward

                ``op_prop.infer_storage_type_backward`` may return 5 lists
                (ograd, input, output, igrad, aux) or 4 lists with the aux list
                omitted, in which case no aux stypes are assumed.
                """
                try:
                    # Bucket stypes by tag, using the same tag numbering as
                    # backward_entry: 0=input, 1=output, 2=igrad, 3=ograd, 4=aux.
                    tensors = [[] for _ in range(5)]
                    for i in range(num_tensor):
                        tensors[tags[i]].append(_STORAGE_TYPE_ID_TO_STR[tensor_stypes[i]])
                    # Ordering of stypes: ograd, input, output, igrad, aux
                    tensors = [tensors[3], tensors[0], tensors[1], tensors[2], tensors[4]]
                    ret = op_prop.infer_storage_type_backward(tensors[0],
                                                              tensors[1],
                                                              tensors[2],
                                                              tensors[3],
                                                              tensors[4])
                    if len(ret) == 4:
                        # Aux stypes were omitted: append an empty aux list.
                        # (The old ``ret += []`` was a no-op, so ret[4] failed.)
                        ret = list(ret) + [[]]
                    elif len(ret) != 5:
                        raise AssertionError("infer_storage_type_backward must return 4 or 5 lists")
                    labels = ("output gradient", "input", "output",
                              "input gradient", "aux")
                    for label, expected, returned in zip(labels, tensors, ret):
                        assert len(returned) == len(expected), \
                            "InferStorageTypeBackward Error: expecting == %d " \
                            "entries in returned %s stypes, " \
                            "got %d."%(len(expected), label, len(returned))
                    # Flatten back to the ograd, input, output, igrad, aux order.
                    rstype = [stype for ret_list in ret for stype in ret_list]

                    for i, stype in enumerate(rstype):
                        assert stype != _STORAGE_TYPE_ID_TO_STR[_STORAGE_TYPE_UNDEFINED], \
                            "stype should not be undefined"
                        assert stype in _STORAGE_TYPE_STR_TO_ID, \
                            "Provided stype: %s is not valid " \
                            "valid stypes are %s, %s, %s"%(stype,
                                                           _STORAGE_TYPE_ID_TO_STR[_STORAGE_TYPE_DEFAULT],
                                                           _STORAGE_TYPE_ID_TO_STR[_STORAGE_TYPE_ROW_SPARSE],
                                                           _STORAGE_TYPE_ID_TO_STR[_STORAGE_TYPE_CSR])
                        tensor_stypes[i] = _STORAGE_TYPE_STR_TO_ID[stype]

                    infer_storage_type_backward_entry._ref_holder = [tensor_stypes]
                except Exception:
                    print('Error in %s.infer_storage_type_backward: %s'
                          % (reg_name, traceback.format_exc()))
                    return False
                return True

            def infer_storage_type_entry(num_tensor, tensor_stypes, _):
                """C Callback for CustomOpProp::InferStorageType

                Reads the input stypes, calls ``op_prop.infer_storage_type`` and
                writes input, output and aux stypes back into ``tensor_stypes``.
                """
                try:
                    n_in = len(op_prop.list_arguments())
                    n_out = len(op_prop.list_outputs())
                    n_aux = len(op_prop.list_auxiliary_states())
                    assert num_tensor == n_in + n_out + n_aux

                    stypes = [_STORAGE_TYPE_ID_TO_STR[tensor_stypes[i]] for i in range(n_in)]
                    ret = op_prop.infer_storage_type(stypes)
                    if len(ret) == 2:
                        istype, ostype = ret
                        astype = []
                    elif len(ret) == 3:
                        istype, ostype, astype = ret
                    else:
                        raise AssertionError("infer_storage_type must return 2 or 3 lists")

                    assert len(ostype) == n_out, \
                        "InferStorageType Error: expecting %d entries in returned output " \
                        "stypes, got %d."%(n_out, len(ostype))
                    assert len(istype) == n_in, \
                        "InferStorageType Error: expecting %d entries in returned input " \
                        "stypes, got %d."%(n_in, len(istype))
                    assert len(astype) == n_aux, \
                        "InferStorageType Error: expecting %d entries in returned aux state " \
                        "stypes, got %d."%(n_aux, len(astype))
                    rstype = list(istype) + list(ostype) + list(astype)
                    for i, stype in enumerate(rstype):
                        tensor_stypes[i] = _STORAGE_TYPE_STR_TO_ID[stype]
                    infer_storage_type_entry._ref_holder = [tensor_stypes]
                except Exception:
                    print('Error in %s.infer_storage_type: %s'
                          % (reg_name, traceback.format_exc()))
                    return False
                return True

            def infer_type_entry(num_tensor, tensor_types, _):
                """C Callback for CustomOpProp::InferType

                Converts input dtype ids to numpy dtypes, calls
                ``op_prop.infer_type`` and writes input, output and aux dtype
                ids back into ``tensor_types``.
                """
                try:
                    n_in = len(op_prop.list_arguments())
                    n_out = len(op_prop.list_outputs())
                    n_aux = len(op_prop.list_auxiliary_states())
                    assert num_tensor == n_in + n_out + n_aux

                    types = [_DTYPE_MX_TO_NP[tensor_types[i]] for i in range(n_in)]
                    ret = op_prop.infer_type(types)
                    if len(ret) == 2:
                        itype, otype = ret
                        atype = []
                    elif len(ret) == 3:
                        itype, otype, atype = ret
                    else:
                        raise AssertionError("infer_type must return 2 or 3 lists")
                    assert len(otype) == n_out, \
                        "InferType Error: expecting %d entries in returned output " \
                        "types, got %d."%(n_out, len(otype))
                    assert len(itype) == n_in, \
                        "InferType Error: expecting %d entries in returned input " \
                        "types, got %d."%(n_in, len(itype))
                    assert len(atype) == n_aux, \
                        "InferType Error: expecting %d entries in returned aux state " \
                        "types, got %d."%(n_aux, len(atype))
                    rtype = list(itype) + list(otype) + list(atype)
                    for i, dtype in enumerate(rtype):
                        tensor_types[i] = _DTYPE_NP_TO_MX[dtype]

                    infer_type_entry._ref_holder = [tensor_types]
                except Exception:
                    print('Error in %s.infer_type: %s' % (reg_name, traceback.format_exc()))
                    return False
                return True

            def make_list_entry(method_name):
                """Build the C callback exposing ``op_prop.<method_name>()``.

                The callback writes a NULL-terminated array of C strings to
                ``out[0]``. ``method_name`` is looked up on ``op_prop`` at call
                time, exactly as a direct ``op_prop.<method_name>()`` would be.
                """
                def list_entry(out, _):
                    """C Callback for CustomOpProp::ListArguments/ListOutputs/ListAuxiliaryStates"""
                    try:
                        names = getattr(op_prop, method_name)()
                        names = [c_str(i) for i in names] + [c_char_p(0)]
                        names = c_array(c_char_p, names)
                        out[0] = cast(names, POINTER(POINTER(c_char)))

                        list_entry._ref_holder = [out]
                    except Exception:
                        tb = traceback.format_exc()
                        print('Error in %s.%s: %s' % (reg_name, method_name, tb))
                        return False
                    return True
                return list_entry

            list_arguments_entry = make_list_entry('list_arguments')
            list_outputs_entry = make_list_entry('list_outputs')
            list_auxiliary_states_entry = make_list_entry('list_auxiliary_states')

            def declare_backward_dependency_entry(out_grad, in_data, out_data, num_dep, deps, _):
                """C Callback for CustomOpProp::DeclareBackwardDependency

                Passes the id arrays to ``op_prop.declare_backward_dependency``,
                returns the chosen ids through ``num_dep``/``deps`` and records
                them in ``_registry.result_deps`` for use by ``backward_entry``.
                """
                try:
                    out_grad = [out_grad[i] for i in range(len(op_prop.list_outputs()))]
                    in_data = [in_data[i] for i in range(len(op_prop.list_arguments()))]
                    out_data = [out_data[i] for i in range(len(op_prop.list_outputs()))]
                    rdeps = op_prop.declare_backward_dependency(out_grad, in_data, out_data)
                    num_dep[0] = len(rdeps)
                    # NOTE(review): this is module-global state shared by every
                    # registered op; the most recent call wins.
                    _registry.result_deps = set(rdeps)
                    rdeps = cast(c_array_buf(c_int, array('i', rdeps)), c_int_p)
                    deps[0] = rdeps

                    declare_backward_dependency_entry._ref_holder = [deps]
                except Exception:
                    tb = traceback.format_exc()
                    print('Error in %s.declare_backward_dependency: %s' % (reg_name, tb))
                    return False
                return True

            def create_operator_entry(ctx, num_inputs, shapes, ndims, dtypes, ret, _):
                """C Callback for CustomOpProp::CreateOperator

                Parses the context string (e.g. ``"cpu(0)"``), calls
                ``op_prop.create_operator`` and writes the operator's
                delete/forward/backward callback table to ``ret[0]``.
                """
                try:
                    ctx = py_str(ctx)
                    sep = ctx.find('(')
                    ctx = context.Context(ctx[:sep], int(ctx[sep+1:-1]))
                    ndims = [ndims[i] for i in range(num_inputs)]
                    shapes = [[shapes[i][j] for j in range(ndims[i])] for i in range(num_inputs)]
                    dtypes = [dtypes[i] for i in range(num_inputs)]
                    op = op_prop.create_operator(ctx, shapes, dtypes)

                    def forward_entry(num_ndarray, ndarraies, tags, reqs, is_train, _):
                        """C Callback for CustomOp::Forward"""
                        try:
                            tensors = [[] for _ in range(5)]
                            for i in range(num_ndarray):
                                tag = tags[i]
                                # Only out_data (tag 1) and aux (tag 4) may be
                                # written by forward.
                                tensors[tag].append(_ndarray_cls(cast(ndarraies[i],
                                                                      NDArrayHandle),
                                                                 writable=tag in (1, 4)))
                            reqs = [req_enum[reqs[i]] for i in range(len(tensors[1]))]
                            with ctx:
                                op.forward(is_train=is_train, req=reqs,
                                           in_data=tensors[0], out_data=tensors[1],
                                           aux=tensors[4])
                        except Exception:
                            print('Error in CustomOp.forward: %s' % traceback.format_exc())
                            return False
                        return True

                    def backward_entry(num_ndarray, ndarraies, tags, reqs, is_train, _):
                        """C Callback for CustomOp::Backward"""
                        # pylint: disable=W0613
                        try:
                            tensors = [[] for _ in range(5)]
                            num_outputs = len(op_prop.list_outputs())
                            num_args = len(op_prop.list_arguments())
                            for i in range(num_ndarray):
                                if i in _registry.result_deps or i >= (num_outputs * 2 + num_args):
                                    # If it is a backward dependency or output or aux:
                                    # Set stype as undefined so that it returns
                                    # ndarray based on existing stype
                                    stype = _STORAGE_TYPE_UNDEFINED
                                else:
                                    # If it is some input, output or out grad ndarray not part of
                                    # backward dependency it is empty and thus the ndarray should
                                    # be set to default
                                    stype = _STORAGE_TYPE_DEFAULT
                                tag = tags[i]
                                # Only in_grad (tag 2) and aux (tag 4) may be
                                # written by backward.
                                tensors[tag].append(_ndarray_cls(cast(ndarraies[i],
                                                                      NDArrayHandle),
                                                                 writable=tag in (2, 4),
                                                                 stype=stype))
                            reqs = [req_enum[reqs[i]] for i in range(len(tensors[2]))]
                            with ctx:
                                op.backward(req=reqs,
                                            in_data=tensors[0], out_data=tensors[1],
                                            in_grad=tensors[2], out_grad=tensors[3],
                                            aux=tensors[4])
                        except Exception:
                            print('Error in CustomOp.backward: %s' % traceback.format_exc())
                            return False
                        return True

                    cur = _registry.inc()

                    def delete_entry(_):
                        """C Callback for CustomOp::del"""
                        try:
                            del _registry.ref_holder[cur]
                        except Exception:
                            print('Error in CustomOp.delete: %s' % traceback.format_exc())
                            return False
                        return True

                    callbacks = [del_functype(delete_entry),
                                 fb_functype(forward_entry),
                                 fb_functype(backward_entry)]
                    callbacks = [cast(i, CFUNCTYPE(c_int)) for i in callbacks]
                    contexts = [None, None, None]
                    ret[0] = MXCallbackList(c_int(len(callbacks)),
                                            cast(c_array(CFUNCTYPE(c_int), callbacks),
                                                 POINTER(CFUNCTYPE(c_int))),
                                            cast(c_array(c_void_p, contexts),
                                                 POINTER(c_void_p)))
                    # ctypes callbacks must stay referenced from Python while C
                    # may still call them; delete_entry drops this reference.
                    op._ref_holder = [ret]
                    _registry.ref_holder[cur] = op
                except Exception:
                    print('Error in %s.create_operator: %s' % (reg_name, traceback.format_exc()))
                    return False
                return True

            cur = _registry.inc()

            def delete_entry(_):
                """C Callback for CustomOpProp::del"""
                try:
                    del _registry.ref_holder[cur]
                except Exception:
                    print('Error in CustomOpProp.delete: %s' % traceback.format_exc())
                    return False
                return True

            # The order of this table is the order the backend reads it in.
            callbacks = [del_functype(delete_entry),
                         list_functype(list_arguments_entry),
                         list_functype(list_outputs_entry),
                         list_functype(list_auxiliary_states_entry),
                         infershape_functype(infer_shape_entry),
                         deps_functype(declare_backward_dependency_entry),
                         createop_functype(create_operator_entry),
                         infertype_functype(infer_type_entry),
                         inferstorage_functype(infer_storage_type_entry),
                         inferstorage_backward_functype(infer_storage_type_backward_entry)]
            callbacks = [cast(i, CFUNCTYPE(c_int)) for i in callbacks]
            contexts = [None]*len(callbacks)
            ret[0] = MXCallbackList(c_int(len(callbacks)),
                                    cast(c_array(CFUNCTYPE(c_int), callbacks),
                                         POINTER(CFUNCTYPE(c_int))),
                                    cast(c_array(c_void_p, contexts),
                                         POINTER(c_void_p)))
            # Keep op_prop and its callback table alive until the backend calls
            # delete_entry.
            op_prop._ref_holder = [ret]
            _registry.ref_holder[cur] = op_prop
            return True

        creator_functype = CFUNCTYPE(c_int, c_char_p, c_int, POINTER(c_char_p),
                                     POINTER(c_char_p), POINTER(MXCallbackList))
        creator_func = creator_functype(creator)
        check_call(_LIB.MXCustomOpRegister(c_str(reg_name), creator_func))
        # The creator callback is never unregistered, so keep it referenced.
        cur = _registry.inc()
        _registry.ref_holder[cur] = creator_func
        return prop_cls
    return do_register