modules/SwissArmyTransformer/sat/ops/scaled_mask_softmax.py (10 lines of code) (raw):

try: from apex.transformer.functional import FusedScaleMaskSoftmax from apex.transformer.enums import AttnMaskType except ModuleNotFoundError: from sat.helpers import print_rank0 print_rank0( "Please install apex to use FusedScaleMaskSoftmax, otherwise the inference efficiency will be greatly reduced" ) FusedScaleMaskSoftmax = None AttnMaskType = None