in modules/adaptive_span.py [0:0]
def add_args(parser):
    """Register the adaptive-span command-line options on ``parser``.

    Options are added in this order: ``--adapt-span``, ``--adapt-span-loss``,
    ``--adapt-span-len``, ``--adapt-span-init``, ``--adapt-span-cache``,
    ``--adapt-span-trim-step`` and ``--adapt-span-layer``. The three boolean
    switches default to ``False`` and become ``True`` when given.

    Args:
        parser: any object exposing an argparse-style ``add_argument`` method,
            e.g. an ``argparse.ArgumentParser`` or an argument group.

    Returns:
        None; ``parser`` is mutated in place.
    """
    # NOTE(review): the float-typed options below use int defaults. argparse
    # only applies ``type`` to *string* defaults, so e.g. ``args.adapt_span_len``
    # is the int 32 when omitted but a float when passed on the command line.
    # Confirm downstream code handles both before changing either side.
    switch = dict(action="store_true", default=False)
    option_table = (
        ("--adapt-span", dict(switch, help="enable adaptive attention span")),
        (
            "--adapt-span-loss",
            dict(type=float, default=0, help="loss coeff on attention span"),
        ),
        (
            "--adapt-span-len",
            dict(type=float, default=32, help="ramp length of adaptive span"),
        ),
        (
            "--adapt-span-init",
            dict(type=float, default=0, help="initial attention span ratio"),
        ),
        (
            "--adapt-span-cache",
            dict(switch, help="adapt cache size to reduce memory usage"),
        ),
        ("--adapt-span-trim-step", dict(type=int, default=64, help="trim step")),
        (
            "--adapt-span-layer",
            dict(switch, help="constrain all heads in a layer to have same span"),
        ),
    )
    # Registration order is preserved so --help output is unchanged.
    for option_name, option_kwargs in option_table:
        parser.add_argument(option_name, **option_kwargs)