def register()

in python/mxnet/operator.py [0:0]


def register(reg_name):
    """Register a subclass of CustomOpProp to the registry with name reg_name."""
    def do_register(prop_cls):
        """Register a subclass of CustomOpProp to the registry."""

        class MXCallbackList(Structure):
            """Structure that holds Callback information. Passed to CustomOpProp."""
            _fields_ = [
                ('num_callbacks', c_int),
                ('callbacks', POINTER(CFUNCTYPE(c_int))),
                ('contexts', POINTER(c_void_p))
                ]

        fb_functype = CFUNCTYPE(c_int, c_int, POINTER(c_void_p), POINTER(c_int),
                                POINTER(c_int), c_int, c_void_p)
        del_functype = CFUNCTYPE(c_int, c_void_p)

        infershape_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int),
                                        POINTER(POINTER(mx_uint)), c_void_p)
        infertype_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), c_void_p)
        list_functype = CFUNCTYPE(c_int, POINTER(POINTER(POINTER(c_char))), c_void_p)
        deps_functype = CFUNCTYPE(c_int, c_int_p, c_int_p, c_int_p,
                                  c_int_p, POINTER(c_int_p), c_void_p)
        createop_functype = CFUNCTYPE(c_int, c_char_p, c_int, POINTER(POINTER(mx_uint)),
                                      POINTER(c_int), POINTER(c_int),
                                      POINTER(MXCallbackList), c_void_p)
        req_enum = ('null', 'write', 'inplace', 'add')

        def creator(op_type, argc, keys, vals, ret):
            """internal function"""
            assert py_str(op_type) == reg_name
            kwargs = dict([(py_str(keys[i]), py_str(vals[i])) for i in range(argc)])
            op_prop = prop_cls(**kwargs)

            def infer_shape_entry(num_tensor, tensor_dims,
                                  tensor_shapes, _):
                """C Callback for ``CustomOpProp::InferShape``."""
                try:
                    n_in = len(op_prop.list_arguments())
                    n_out = len(op_prop.list_outputs())
                    n_aux = len(op_prop.list_auxiliary_states())
                    assert num_tensor == n_in + n_out + n_aux

                    shapes = [[tensor_shapes[i][j] for j in range(tensor_dims[i])]
                              for i in range(n_in)]
                    ret = op_prop.infer_shape(shapes)
                    if len(ret) == 2:
                        ishape, oshape = ret
                        ashape = []
                    elif len(ret) == 3:
                        ishape, oshape, ashape = ret
                    else:
                        raise AssertionError("infer_shape must return 2 or 3 lists")
                    assert len(oshape) == n_out
                    assert len(ishape) == n_in
                    assert len(ashape) == n_aux
                    rshape = list(ishape) + list(oshape) + list(ashape)
                    for i in range(n_in+n_out+n_aux):
                        tensor_shapes[i] = cast(c_array(mx_uint, rshape[i]), POINTER(mx_uint))
                        tensor_dims[i] = len(rshape[i])

                    infer_shape_entry._ref_holder = [tensor_shapes]
                except Exception:
                    print('Error in %s.infer_shape: %s' % (reg_name, traceback.format_exc()))
                    return False
                return True

            def infer_type_entry(num_tensor, tensor_types, _):
                """C Callback for CustomOpProp::InferType"""
                try:
                    n_in = len(op_prop.list_arguments())
                    n_out = len(op_prop.list_outputs())
                    n_aux = len(op_prop.list_auxiliary_states())
                    assert num_tensor == n_in + n_out + n_aux

                    types = [_DTYPE_MX_TO_NP[tensor_types[i]] for i in range(n_in)]
                    ret = op_prop.infer_type(types)
                    if len(ret) == 2:
                        itype, otype = ret
                        atype = []
                    elif len(ret) == 3:
                        itype, otype, atype = ret
                    else:
                        raise AssertionError("infer_type must return 2 or 3 lists")
                    assert len(otype) == n_out
                    assert len(itype) == n_in
                    assert len(atype) == n_aux
                    rtype = list(itype) + list(otype) + list(atype)
                    for i, dtype in enumerate(rtype):
                        tensor_types[i] = _DTYPE_NP_TO_MX[dtype]

                    infer_type_entry._ref_holder = [tensor_types]
                except Exception:
                    print('Error in %s.infer_type: %s' % (reg_name, traceback.format_exc()))
                    return False
                return True

            def list_outputs_entry(out, _):
                """C Callback for CustomOpProp::ListOutputs"""
                try:
                    ret = op_prop.list_outputs()
                    ret = [c_str(i) for i in ret] + [c_char_p(0)]
                    ret = c_array(c_char_p, ret)
                    out[0] = cast(ret, POINTER(POINTER(c_char)))

                    list_outputs_entry._ref_holder = [out]
                except Exception:
                    print('Error in %s.list_outputs: %s' % (reg_name, traceback.format_exc()))
                    return False
                return True

            def list_arguments_entry(out, _):
                """C Callback for CustomOpProp::ListArguments"""
                try:
                    ret = op_prop.list_arguments()
                    ret = [c_str(i) for i in ret] + [c_char_p(0)]
                    ret = c_array(c_char_p, ret)
                    out[0] = cast(ret, POINTER(POINTER(c_char)))

                    list_arguments_entry._ref_holder = [out]
                except Exception:
                    print('Error in %s.list_arguments: %s' % (reg_name, traceback.format_exc()))
                    return False
                return True

            def list_auxiliary_states_entry(out, _):
                """C Callback for CustomOpProp::ListAuxiliaryStates"""
                try:
                    ret = op_prop.list_auxiliary_states()
                    ret = [c_str(i) for i in ret] + [c_char_p(0)]
                    ret = c_array(c_char_p, ret)
                    out[0] = cast(ret, POINTER(POINTER(c_char)))

                    list_auxiliary_states_entry._ref_holder = [out]
                except Exception:
                    tb = traceback.format_exc()
                    print('Error in %s.list_auxiliary_states: %s' % (reg_name, tb))
                    return False
                return True

            def declare_backward_dependency_entry(out_grad, in_data, out_data, num_dep, deps, _):
                """C Callback for CustomOpProp::DeclareBacwardDependency"""
                try:
                    out_grad = [out_grad[i] for i in range(len(op_prop.list_outputs()))]
                    in_data = [in_data[i] for i in range(len(op_prop.list_arguments()))]
                    out_data = [out_data[i] for i in range(len(op_prop.list_outputs()))]
                    rdeps = op_prop.declare_backward_dependency(out_grad, in_data, out_data)
                    num_dep[0] = len(rdeps)
                    rdeps = cast(c_array(c_int, rdeps), c_int_p)
                    deps[0] = rdeps

                    declare_backward_dependency_entry._ref_holder = [deps]
                except Exception:
                    tb = traceback.format_exc()
                    print('Error in %s.declare_backward_dependency: %s' % (reg_name, tb))
                    return False
                return True

            def create_operator_entry(ctx, num_inputs, shapes, ndims, dtypes, ret, _):
                """C Callback for CustomOpProp::CreateOperator"""
                try:
                    ndims = [ndims[i] for i in range(num_inputs)]
                    shapes = [[shapes[i][j] for j in range(ndims[i])] for i in range(num_inputs)]
                    dtypes = [dtypes[i] for i in range(num_inputs)]
                    op = op_prop.create_operator(ctx, shapes, dtypes)

                    def forward_entry(num_ndarray, ndarraies, tags, reqs, is_train, _):
                        """C Callback for CustomOp::Forward"""
                        try:
                            tensors = [[] for i in range(5)]
                            for i in range(num_ndarray):
                                if tags[i] == 1 or tags[i] == 4:
                                    tensors[tags[i]].append(NDArray(cast(ndarraies[i],
                                                                         NDArrayHandle),
                                                                    writable=True))
                                else:
                                    tensors[tags[i]].append(NDArray(cast(ndarraies[i],
                                                                         NDArrayHandle),
                                                                    writable=False))
                            reqs = [req_enum[reqs[i]] for i in range(len(tensors[1]))]
                            op.forward(is_train=is_train, req=reqs,
                                       in_data=tensors[0], out_data=tensors[1],
                                       aux=tensors[4])
                        except Exception:
                            print('Error in CustomOp.forward: %s' % traceback.format_exc())
                            return False
                        return True

                    def backward_entry(num_ndarray, ndarraies, tags, reqs, is_train, _):
                        """C Callback for CustomOp::Backward"""
                        # pylint: disable=W0613
                        try:
                            tensors = [[] for i in range(5)]
                            for i in range(num_ndarray):
                                if tags[i] == 2 or tags[i] == 4:
                                    tensors[tags[i]].append(NDArray(cast(ndarraies[i],
                                                                         NDArrayHandle),
                                                                    writable=True))
                                else:
                                    tensors[tags[i]].append(NDArray(cast(ndarraies[i],
                                                                         NDArrayHandle),
                                                                    writable=False))
                            reqs = [req_enum[reqs[i]] for i in range(len(tensors[2]))]
                            op.backward(req=reqs,
                                        in_data=tensors[0], out_data=tensors[1],
                                        in_grad=tensors[2], out_grad=tensors[3],
                                        aux=tensors[4])
                        except Exception:
                            print('Error in CustomOp.backward: %s' % traceback.format_exc())
                            return False
                        return True

                    cur = _registry.inc()

                    def delete_entry(_):
                        """C Callback for CustomOp::del"""
                        try:
                            del _registry.ref_holder[cur]
                        except Exception:
                            print('Error in CustomOp.delete: %s' % traceback.format_exc())
                            return False
                        return True

                    callbacks = [del_functype(delete_entry),
                                 fb_functype(forward_entry),
                                 fb_functype(backward_entry)]
                    callbacks = [cast(i, CFUNCTYPE(c_int)) for i in callbacks]
                    contexts = [None, None, None]
                    ret[0] = MXCallbackList(c_int(len(callbacks)),
                                            cast(c_array(CFUNCTYPE(c_int), callbacks),
                                                 POINTER(CFUNCTYPE(c_int))),
                                            cast(c_array(c_void_p, contexts),
                                                 POINTER(c_void_p)))
                    op._ref_holder = [ret]
                    _registry.ref_holder[cur] = op
                except Exception:
                    print('Error in %s.create_operator: %s' % (reg_name, traceback.format_exc()))
                    return False
                return True

            cur = _registry.inc()

            def delete_entry(_):
                """C Callback for CustomOpProp::del"""
                try:
                    del _registry.ref_holder[cur]
                except Exception:
                    print('Error in CustomOpProp.delete: %s' % traceback.format_exc())
                    return False
                return True

            callbacks = [del_functype(delete_entry),
                         list_functype(list_arguments_entry),
                         list_functype(list_outputs_entry),
                         list_functype(list_auxiliary_states_entry),
                         infershape_functype(infer_shape_entry),
                         deps_functype(declare_backward_dependency_entry),
                         createop_functype(create_operator_entry),
                         infertype_functype(infer_type_entry)]
            callbacks = [cast(i, CFUNCTYPE(c_int)) for i in callbacks]
            contexts = [None]*len(callbacks)
            ret[0] = MXCallbackList(c_int(len(callbacks)),
                                    cast(c_array(CFUNCTYPE(c_int), callbacks),
                                         POINTER(CFUNCTYPE(c_int))),
                                    cast(c_array(c_void_p, contexts),
                                         POINTER(c_void_p)))
            op_prop._ref_holder = [ret]
            _registry.ref_holder[cur] = op_prop
            return True

        creator_functype = CFUNCTYPE(c_int, c_char_p, c_int, POINTER(c_char_p),
                                     POINTER(c_char_p), POINTER(MXCallbackList))
        creator_func = creator_functype(creator)
        check_call(_LIB.MXCustomOpRegister(c_str(reg_name), creator_func))
        cur = _registry.inc()
        _registry.ref_holder[cur] = creator_func
        return prop_cls
    return do_register