def add_args()

in modules/adaptive_span.py [0:0]


def add_args(parser):
    parser.add_argument(
        "--adapt-span",
        action="store_true",
        default=False,
        help="enable adaptive attention span",
    )
    parser.add_argument(
        "--adapt-span-loss", type=float, default=0, help="loss coeff on attention span"
    )
    parser.add_argument(
        "--adapt-span-len", type=float, default=32, help="ramp length of adaptive span"
    )
    parser.add_argument(
        "--adapt-span-init", type=float, default=0, help="initial attention span ratio"
    )
    parser.add_argument(
        "--adapt-span-cache",
        action="store_true",
        default=False,
        help="adapt cache size to reduce memory usage",
    )
    parser.add_argument(
        "--adapt-span-trim-step", type=int, default=64, help="trim step"
    )
    parser.add_argument(
        "--adapt-span-layer",
        action="store_true",
        default=False,
        help="constrain all heads in a layer to have same span",
    )