def __init__()

in src/nv_wavenet_generator.py [0:0]


    def __init__(self, model, sample_count, batch_size, implementation, cond_repeat=800):
        self.model = model
        self.cond_repeat = cond_repeat

        fname = Path(__file__)
        if "CUDA_VISIBLE_DEVICES" in os.environ:
            postfix = os.environ["CUDA_VISIBLE_DEVICES"].replace(',', '_')
        else:
            postfix = ""

        try:
            import nv_wavenet_ext
            self.wavenet_cu = nv_wavenet_ext
        except ImportError as e:
            print("Failed loading nv_wavenet_ext, building dynamically")
            build_name = "wavenet_cu_" + next(tempfile._get_candidate_names()) + postfix
            if not os.path.exists(".build_dir"):
                os.mkdir(".build_dir")

            if not os.path.exists('.build_dir/' + build_name):
                os.mkdir('.build_dir/' + build_name)

            self.wavenet_cu = load(name=build_name,
                                   sources=[fname.parent / "nv-wavenet/wavenet_infer.cu",
                                            fname.parent / "nv-wavenet/wavenet_infer_wrapper.cpp",
                                            fname.parent / "nv-wavenet/matrix.cpp"],
                                   build_directory='.build_dir/' + build_name,
                                   verbose=False,
                                   extra_cuda_cflags=["-arch=sm_70", "-std=c++14", "--use_fast_math",
                                                      "-maxrregcount 128", "--ptxas-options=-v",
                                                      "--expt-relaxed-constexpr", "-D__GNUC__=6"])

        embedding_prev, embedding_curr = model.export_embed_weights()
        conv_init_weight, conv_init_bias, conv_out_weight, \
        conv_out_bias, conv_end_weight, conv_end_bias = model.export_final_weights()
        dilate_weights, dilate_biases, res_weights, \
        res_biases, skip_weights, skip_biases = model.export_layer_weights()
        use_embed_tanh = False
        layers = len(model.layers)
        blocks = model.blocks
        max_dilation = 2 ** (layers // blocks - 1)

        self.R = self.wavenet_cu.num_res_channels()
        self.S = self.wavenet_cu.num_skip_channels()
        self.A = self.wavenet_cu.num_out_channels()

        self.max_dilation = max_dilation
        self.use_embed_tanh = use_embed_tanh
        assert embedding_prev.size() == (self.A, self.R), \
            ("embedding_prev: {} doesn't match compiled"
             " nv-wavenet size: {}").format(embedding_prev.size(),
                                            (self.A, self.R))
        self.embedding_prev = column_major(torch.t(embedding_prev))

        assert embedding_curr.size() == (self.A, self.R), \
            ("embedding_curr: {} doesn't match compiled"
             " nv-wavenet size: {}").format(embedding_curr.size(),
                                            (self.A, self.R))
        self.embedding_curr = column_major(torch.t(embedding_curr))

        assert conv_init_weight.size()[:2] == (self.S, self.R), \
            ("conv_init_weight: {} doesn't match compiled"
             " nv-wavenet size: {}").format(conv_init_weight.size()[:2],
                                            (self.S, self.R))
        self.conv_init = column_major(conv_init_weight)
        self.conv_init_bias = column_major(conv_init_bias)

        assert conv_out_weight.size()[:2] == (self.S, self.S), \
            ("conv_out_weight: {} doesn't match compiled"
             " nv-wavenet size: {}").format(conv_out_weight.size()[:2],
                                            (self.S, self.S))
        self.conv_out = column_major(conv_out_weight)
        self.conv_out_bias = column_major(conv_out_bias)

        assert conv_end_weight.size()[:2] == (self.A, self.S), \
            ("conv_end_weight: {} doesn't match compiled"
             " nv-wavenet size: {}").format(conv_end_weight.size()[:2],
                                            (self.A, self.S))
        self.conv_end = column_major(conv_end_weight)
        self.conv_end_bias = column_major(conv_end_bias)

        self.dilate_weights_prev = []
        self.dilate_weights_curr = []
        for weight in dilate_weights:
            assert weight.size(2) == 2, \
                "nv-wavenet only supports kernel_size 2"
            assert weight.size()[:2] == (2 * self.R, self.R), \
                ("dilated weight: {} doesn't match compiled"
                 " nv-wavenet size: {}").format(weight.size()[:2],
                                                (2 * self.R, self.R))
            Wprev = column_major(weight[:, :, 0])
            Wcurr = column_major(weight[:, :, 1])
            self.dilate_weights_prev.append(Wprev)
            self.dilate_weights_curr.append(Wcurr)

        for bias in dilate_biases:
            assert (bias.size(0) == 2 * self.R)
        for weight in res_weights:
            assert weight.size()[:2] == (self.R, self.R), \
                ("residual weight: {} doesn't match compiled"
                 " nv-wavenet size: {}").format(weight.size()[:2],
                                                (self.R, self.R))
        for bias in res_biases:
            assert (bias.size(0) == self.R), \
                ("residual bias: {} doesn't match compiled"
                 " nv-wavenet size: {}").format(bias.size(0), self.R)
        for weight in skip_weights:
            assert weight.size()[:2] == (self.S, self.R), \
                ("skip weight: {} doesn't match compiled"
                 " nv-wavenet size: {}").format(weight.size()[:2],
                                                (self.S, self.R))
        for bias in skip_biases:
            assert (bias.size(0) == self.S), \
                ("skip bias: {} doesn't match compiled"
                 " nv-wavenet size: {}").format(bias.size(0), self.S)

        self.dilate_biases = [column_major(bias) for bias in dilate_biases]
        self.res_weights = [column_major(weight) for weight in res_weights]
        self.res_biases = [column_major(bias) for bias in res_biases]
        self.skip_weights = [column_major(weight) for weight in skip_weights]
        self.skip_biases = [column_major(bias) for bias in skip_biases]

        # There's an extra residual layer that's not used
        # self.res_weights.append(torch.zeros(self.R,self.R))
        # self.res_biases.append(torch.zeros(self.R))

        assert (len(self.res_biases) == len(self.skip_biases) and
                len(self.res_biases) == len(self.dilate_biases) and
                len(self.res_weights) == len(self.skip_weights) and
                len(self.res_weights) == len(dilate_weights)), \