def make_pgd_attacker()

in online_attacks/attacks/pgd.py [0:0]


def make_pgd_attacker(classifier: Module, params: PGDParams = PGDParams()) -> PGDAttack:
    """Build a PGD attacker for ``classifier`` configured from ``params``.

    ``params.norm`` selects the attack class. The remaining settings
    (``eps``, ``nb_iter``, ``eps_iter``, ``rand_init``, ``clip_min``,
    ``clip_max``, ``targeted``) are forwarded unchanged as keyword arguments
    to that class's constructor.

    Args:
        classifier: Model to attack; passed as the first positional argument
            of the attack constructor.
        params: Attack configuration. ``params.norm`` must equal ``"Linf"``,
            ``"L2"`` or ``"L1"``. The default ``PGDParams()`` instance is
            created once when the function is defined and shared by every
            call that omits ``params``; this function only reads it.

    Returns:
        A ``LinfPGDAttack``, ``L2PGDAttack`` or ``L1PGDAttack`` instance,
        depending on ``params.norm``.

    Raises:
        ValueError: If ``params.norm`` is not one of the supported norms.
    """
    # Pick the attack class first; all three share the same constructor
    # keywords, so the instance is built by a single call below.
    if params.norm == "Linf":
        attack_cls = LinfPGDAttack
    elif params.norm == "L2":
        attack_cls = L2PGDAttack
    elif params.norm == "L1":
        attack_cls = L1PGDAttack
    else:
        # Fail loudly here rather than falling through with nothing bound,
        # which would otherwise surface as an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported PGD norm {params.norm!r}; expected 'Linf', 'L2' or 'L1'."
        )

    return attack_cls(
        classifier,
        eps=params.eps,
        nb_iter=params.nb_iter,
        eps_iter=params.eps_iter,
        rand_init=params.rand_init,
        clip_min=params.clip_min,
        clip_max=params.clip_max,
        targeted=params.targeted,
    )