in src/toolkits/style_transfer/style_transfer_model_definition.cpp [38:970]
/**
 * Appends the layers of the style-transfer "transformer" network to nn_spec.
 *
 * Layer sequence, in the order the layers are added:
 *   - three encoder stages: 9x9 conv (3 -> 32 channels, stride 1), 3x3 conv
 *     (32 -> 64, stride 2), 3x3 conv (64 -> 128, stride 2), each followed by
 *     a per-style scaled/shifted instance norm and a ReLU;
 *   - five residual blocks operating on 128 channels;
 *   - two decoder stages: 2x upsampling followed by a 3x3 conv (128 -> 64,
 *     then 64 -> 32), instance norm and ReLU;
 *   - a final 9x9 conv (32 -> 3) with instance norm, a sigmoid, and a scale
 *     layer named "stylizedImage" whose weight is initialized to 255.
 *
 * The network reads from the inputs named "image" and "index". The gamma and
 * beta of every instance norm are produced by inner-product layers applied to
 * "index" with num_styles input channels.
 *
 * \param nn_spec     Spec that receives the layers.
 * \param num_styles  Number of input channels of every gamma/beta inner
 *                    product.
 * \param initialize  When true, convolution weights use
 *                    uniform_weight_initializer(LOWER_BOUND, UPPER_BOUND)
 *                    driven by a generator seeded with random_seed; when
 *                    false they use zero_weight_initializer().
 * \param random_seed Seed for the generator used when initialize is true.
 */
void define_resnet(model_spec& nn_spec, size_t num_styles, bool initialize=false, int random_seed=0) {
  std::seed_seq seed_seq{random_seed};
  std::mt19937 random_engine(seed_seq);

  // Only pay for the uniform initialization when the caller asks for it.
  weight_initializer initializer;
  if (initialize) {
    initializer = uniform_weight_initializer(LOWER_BOUND, UPPER_BOUND, &random_engine);
  } else {
    initializer = zero_weight_initializer();
  }

  // Appends padding -> VALID convolution -> instance norm whose output is
  // multiplied by a per-style gamma and offset by a per-style beta. Returns
  // the name of the resulting layer (norm_prefix + "__fwd").
  //
  // The call order below (pad, conv, gamma, beta, norm, multiply, add) is the
  // order in which the layers end up in the spec; keep it stable so that
  // layer ordering and the sequence in which `initializer` is handed to the
  // convolutions do not change.
  auto add_conv_norm = [&](const std::string& input,
                           const std::string& pad_name,
                           size_t pad,
                           const std::string& conv_name,
                           size_t out_channels,
                           size_t in_channels,
                           size_t kernel_size,
                           size_t stride,
                           const std::string& param_prefix,
                           const std::string& norm_prefix) -> std::string {
    const std::string gamma_name = param_prefix + "_gamma";
    const std::string beta_name = param_prefix + "_beta";
    const std::string normalized_name = norm_prefix + "__fwd_bn_";
    const std::string scaled_name = norm_prefix + "__fwd_mult_gamma";
    const std::string output_name = norm_prefix + "__fwd";

    nn_spec.add_padding(
        /* name */           pad_name,
        /* input */          input,
        /* padding_top */    pad,
        /* padding_bottom */ pad,
        /* padding_left */   pad,
        /* padding_right */  pad);
    nn_spec.add_convolution(
        /* name */                conv_name,
        /* input */               pad_name,
        /* num_output_channels */ out_channels,
        /* num_kernel_channels */ in_channels,
        /* kernel_height */       kernel_size,
        /* kernel_width */        kernel_size,
        /* stride_height */       stride,
        /* stride_width */        stride,
        /* padding */             padding_type::VALID,
        /* weight_init_fn */      initializer);
    // Gamma starts at one and beta at zero for every style.
    nn_spec.add_inner_product(
        /* name */                gamma_name,
        /* input */               "index",
        /* num_output_channels */ out_channels,
        /* num_input_channels */  num_styles,
        /* weight_init_fn */      scalar_weight_initializer(1.0f),
        /* bias_init_fn */        zero_weight_initializer());
    nn_spec.add_inner_product(
        /* name */                beta_name,
        /* input */               "index",
        /* num_output_channels */ out_channels,
        /* num_input_channels */  num_styles,
        /* weight_init_fn */      zero_weight_initializer(),
        /* bias_init_fn */        zero_weight_initializer());
    nn_spec.add_instancenorm(
        /* name */         normalized_name,
        /* input */        conv_name,
        /* num_channels */ out_channels,
        /* epsilon */      1e-5);
    nn_spec.add_multiplication(
        /* name */   scaled_name,
        /* inputs */ {normalized_name, gamma_name});
    nn_spec.add_addition(
        /* name */   output_name,
        /* inputs */ {scaled_name, beta_name});
    return output_name;
  };

  // Appends one residual block: conv-norm-ReLU, conv-norm, then an addition
  // with the block's input. Layers of block i are named
  // "transformer_residualblock<i>_*" while its convolution and gamma/beta
  // layers are named "transformer_residual_<i+1>_*". Returns the name of the
  // block's output layer.
  auto add_residual_block = [&](size_t block_index,
                                const std::string& input) -> std::string {
    const std::string block =
        "transformer_residualblock" + std::to_string(block_index);
    const std::string params =
        "transformer_residual_" + std::to_string(block_index + 1);
    const std::string activation_name = block + "_activation0";
    const std::string output_name = block + "__plus0";

    const std::string first_norm = add_conv_norm(
        /* input */        input,
        /* pad_name */     block + "_pad0",
        /* pad */          1,
        /* conv_name */    params + "_conv_1",
        /* out_channels */ 128,
        /* in_channels */  128,
        /* kernel_size */  3,
        /* stride */       1,
        /* param_prefix */ params + "_inst_1",
        /* norm_prefix */  block + "_instancenorm0");
    nn_spec.add_relu(
        /* name */  activation_name,
        /* input */ first_norm);

    // No activation after the second convolution.
    const std::string second_norm = add_conv_norm(
        /* input */        activation_name,
        /* pad_name */     block + "_pad1",
        /* pad */          1,
        /* conv_name */    params + "_conv_2",
        /* out_channels */ 128,
        /* in_channels */  128,
        /* kernel_size */  3,
        /* stride */       1,
        /* param_prefix */ params + "_inst_2",
        /* norm_prefix */  block + "_instancenorm1");
    nn_spec.add_addition(
        /* name */   output_name,
        /* inputs */ {input, second_norm});
    return output_name;
  };

  // ---- Encoder ----------------------------------------------------------
  std::string features = add_conv_norm(
      /* input */        "image",
      /* pad_name */     "transformer_pad0",
      /* pad */          4,
      /* conv_name */    "transformer_encode_1_conv",
      /* out_channels */ 32,
      /* in_channels */  3,
      /* kernel_size */  9,
      /* stride */       1,
      /* param_prefix */ "transformer_encode_1_inst",
      /* norm_prefix */  "transformer_instancenorm0");
  nn_spec.add_relu(
      /* name */  "transformer_activation0",
      /* input */ features);

  features = add_conv_norm(
      /* input */        "transformer_activation0",
      /* pad_name */     "transformer_pad1",
      /* pad */          1,
      /* conv_name */    "transformer_encode_2_conv",
      /* out_channels */ 64,
      /* in_channels */  32,
      /* kernel_size */  3,
      /* stride */       2,
      /* param_prefix */ "transformer_encode_2_inst",
      /* norm_prefix */  "transformer_instancenorm1");
  nn_spec.add_relu(
      /* name */  "transformer_activation1",
      /* input */ features);

  features = add_conv_norm(
      /* input */        "transformer_activation1",
      /* pad_name */     "transformer_pad2",
      /* pad */          1,
      /* conv_name */    "transformer_encode_3_conv",
      /* out_channels */ 128,
      /* in_channels */  64,
      /* kernel_size */  3,
      /* stride */       2,
      /* param_prefix */ "transformer_encode_3_inst",
      /* norm_prefix */  "transformer_instancenorm2");
  nn_spec.add_relu(
      /* name */  "transformer_activation2",
      /* input */ features);

  // ---- Residual blocks --------------------------------------------------
  // Each block consumes the previous block's output; the first one consumes
  // the last encoder activation.
  features = "transformer_activation2";
  for (size_t block_index = 0; block_index < 5; ++block_index) {
    features = add_residual_block(block_index, features);
  }

  // ---- Decoder ----------------------------------------------------------
  nn_spec.add_upsampling(
      /* name */      "transformer_upsampling0",
      /* input */     features,
      /* scaling_x */ 2,
      /* scaling_y */ 2);
  features = add_conv_norm(
      /* input */        "transformer_upsampling0",
      /* pad_name */     "transformer_pad3",
      /* pad */          1,
      /* conv_name */    "transformer_decoding_1_conv",
      /* out_channels */ 64,
      /* in_channels */  128,
      /* kernel_size */  3,
      /* stride */       1,
      /* param_prefix */ "transformer_decoding_1_inst",
      /* norm_prefix */  "transformer_instancenorm3");
  nn_spec.add_relu(
      /* name */  "transformer_activation3",
      /* input */ features);

  nn_spec.add_upsampling(
      /* name */      "transformer_upsampling1",
      /* input */     "transformer_activation3",
      /* scaling_x */ 2,
      /* scaling_y */ 2);
  features = add_conv_norm(
      /* input */        "transformer_upsampling1",
      /* pad_name */     "transformer_pad4",
      /* pad */          1,
      /* conv_name */    "transformer_decoding_2_conv",
      /* out_channels */ 32,
      /* in_channels */  64,
      /* kernel_size */  3,
      /* stride */       1,
      /* param_prefix */ "transformer_decoding_2_inst",
      /* norm_prefix */  "transformer_instancenorm4");
  nn_spec.add_relu(
      /* name */  "transformer_activation4",
      /* input */ features);

  // ---- Output -----------------------------------------------------------
  // The last stage names its gamma/beta layers after the norm itself
  // ("transformer_instancenorm5_gamma"/"_beta") rather than after the conv.
  features = add_conv_norm(
      /* input */        "transformer_activation4",
      /* pad_name */     "transformer_pad5",
      /* pad */          4,
      /* conv_name */    "transformer_conv5",
      /* out_channels */ 3,
      /* in_channels */  32,
      /* kernel_size */  9,
      /* stride */       1,
      /* param_prefix */ "transformer_instancenorm5",
      /* norm_prefix */  "transformer_instancenorm5");
  nn_spec.add_sigmoid(
      /* name */  "transformer_activation5",
      /* input */ features);
  // Scale the sigmoid output by 255 to produce the final image layer.
  nn_spec.add_scale(
      /* name */           "stylizedImage",
      /* input */          "transformer_activation5",
      /* shape_c */        {1},
      /* weight_init_fn */ scalar_weight_initializer(255.0));
}