in lib/models/lfb_helper.py [0:0]
def NLCore(
        model, in_blob1, in_blob2, in_dim1, in_dim2, latent_dim,
        num_feat1, num_feat2, prefix, test_mode):
    """Core logic of non-local blocks.

    Lets every position of ``in_blob1`` attend over the positions of
    ``in_blob2``:
      1. ``theta`` (from ``in_blob1``) and ``phi``/``g`` (from ``in_blob2``)
         are 1x1x1 convolutions into ``latent_dim`` channels.
      2. Each is flattened to (N, latent_dim, num_feat).
      3. The affinity ``theta^T phi`` is optionally scaled by
         ``latent_dim ** -0.5`` (``cfg.FBO_NL.SCALE``) and softmax-normalized
         over the ``num_feat2`` axis.
      4. The normalized affinity is used to take a weighted sum of ``g``.
    The result is reshaped back to theta's pre-flatten shape and projected to
    ``in_dim1`` channels. Depending on ``cfg.FBO_NL.PRE_ACT``, ``pre_act`` is
    applied before that projection, or LayerNorm after it. Dropout is applied
    during training when ``cfg.FBO_NL.LFB_DROPOUT_ON`` is set.

    Args:
        model: model helper used to add operators (``ConvNd``, ``Reshape``,
            ``Softmax``, ...).
        in_blob1: input blob feeding ``theta``. Presumably 5-D
            (N, C, T, H, W), given the ``_shape5d`` naming — not enforced here.
        in_blob2: input blob feeding ``phi`` and ``g``.
        in_dim1: channel count of ``in_blob1``; also the output channel count.
        in_dim2: channel count of ``in_blob2``.
        latent_dim: channel count of the theta/phi/g embeddings.
        num_feat1: number of positions in ``in_blob1`` after flattening.
        num_feat2: number of positions in ``in_blob2`` after flattening.
        prefix: prefix for the names of all blobs created here.
        test_mode: when True, dropout is not added.

    Returns:
        Output blob with ``in_dim1`` channels and theta's pre-flatten layout.
    """

    def conv1x1(blob_in, name, dim_in, dim_out, init_params):
        # 1x1x1 convolution, unit stride, no padding.
        return model.ConvNd(
            blob_in, name,
            dim_in,
            dim_out,
            [1, 1, 1],
            strides=[1, 1, 1],
            pads=[0, 0, 0] * 2,
            **init_params)

    def reshape_out(blob):
        # Output blob of a Reshape of `blob`: reuse it in place when the
        # config allows, otherwise write to a new '_re' blob. Applied
        # uniformly (theta, phi, g and t) so the memory behaviour of the
        # three projections is consistent.
        return blob if cfg.MODEL.ALLOW_INPLACE_RESHAPE else blob + '_re'

    def flatten(blob, num_feat):
        # (N, latent_dim, ...) -> (N, latent_dim, num_feat).
        # Returns (reshaped_blob, old_shape_blob).
        return model.Reshape(
            blob,
            [reshape_out(blob), blob + '_shape5d'],
            shape=(-1, latent_dim, num_feat))

    theta = conv1x1(
        in_blob1, prefix + '_theta', in_dim1, latent_dim, init_params1)
    phi = conv1x1(
        in_blob2, prefix + '_phi', in_dim2, latent_dim, init_params1)
    g = conv1x1(
        in_blob2, prefix + '_g', in_dim2, latent_dim, init_params1)

    # theta's original shape is kept so the output can be restored to it.
    theta, theta_shape_5d = flatten(theta, num_feat1)
    phi, _ = flatten(phi, num_feat2)
    g, _ = flatten(g, num_feat2)

    # (N, C, num_feat1)^T x (N, C, num_feat2) -> (N, num_feat1, num_feat2)
    theta_phi = model.net.BatchMatMul(
        [theta, phi], prefix + '_affinity', trans_a=1)
    if cfg.FBO_NL.SCALE:
        # Scaled dot-product affinity (in place).
        theta_phi = model.Scale(
            theta_phi, theta_phi, scale=latent_dim**-.5)
    # Normalize each theta position's affinities over the num_feat2 axis.
    p = model.Softmax(
        theta_phi, theta_phi + '_prob', engine='CUDNN', axis=2)
    # (N, C, num_feat2) x (N, num_feat1, num_feat2)^T -> (N, C, num_feat1)
    t = model.net.BatchMatMul([g, p], prefix + '_y', trans_b=1)
    # Restore theta's pre-flatten layout using its saved shape blob.
    blob_out, _ = model.Reshape(
        [t, theta_shape_5d],
        [reshape_out(t), t + '_shape3d'])

    if cfg.FBO_NL.PRE_ACT:
        blob_out = pre_act(model, blob_out)
    # Project latent_dim channels back to in_dim1.
    blob_out = conv1x1(
        blob_out, prefix + '_out', latent_dim, in_dim1, init_params2)
    if not cfg.FBO_NL.PRE_ACT:
        blob_out = model.LayerNorm(
            blob_out,
            [prefix + "_ln", prefix + "_ln_mean", prefix + "_ln_std"])[0]
    if cfg.FBO_NL.LFB_DROPOUT_ON and not test_mode:
        blob_out = model.Dropout(
            blob_out, blob_out + '_drop',
            ratio=cfg.FBO_NL.DROPOUT_RATE, is_test=False)
    return blob_out