def get_tokenizer()

in modules/SwissArmyTransformer/sat/tokenization/__init__.py


def get_tokenizer(args=None, *, tokenizer_type=None, outer_tokenizer=None):
    '''
        If you're using an outer tokenizer, call
        `get_tokenizer(args, outer_tokenizer=tokenizer)` before `training_main`
        (note that `outer_tokenizer` is keyword-only).
    '''
    if outer_tokenizer is not None: # case 1: explicitly register an outer tokenizer
        get_tokenizer.tokenizer = outer_tokenizer
        get_tokenizer.tokenizer_type = 'outer_tokenizer'
        print_rank0('> Set tokenizer as an outer_tokenizer! Now you can get_tokenizer() everywhere.')
        return outer_tokenizer
    if tokenizer_type is None:
        if args is None:
            assert hasattr(get_tokenizer, 'tokenizer'), 'Tokenizer has never been set; pass args or an outer_tokenizer first.'
            return get_tokenizer.tokenizer
        tokenizer_type = args.tokenizer_type

    # otherwise, look up the tokenizer by tokenizer_type
    if hasattr(get_tokenizer, 'tokenizer_type') and \
        tokenizer_type == get_tokenizer.tokenizer_type:  # same type as the last call, reuse the cached tokenizer
        return get_tokenizer.tokenizer

    get_tokenizer.tokenizer_type = tokenizer_type
    # load the tokenizer according to tokenizer_type
    if tokenizer_type.startswith('cogview'): # or cogview_ICE
        from .cogview import UnifiedTokenizer
        get_tokenizer.tokenizer = UnifiedTokenizer(
            args.img_tokenizer_path,
            txt_tokenizer_type='cogview',
            device=torch.cuda.current_device()
        )
    elif tokenizer_type.startswith('glm'):
        kwargs = {"add_block_symbols": True, "add_task_mask": args.task_mask,
                  "add_decoder_mask": args.block_mask_prob > 0.0}
        if tokenizer_type == "glm_GPT2BPETokenizer":
            from .glm import GPT2BPETokenizer
            get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs)
        elif tokenizer_type == "glm_ChineseSPTokenizer":
            from .glm import ChineseSPTokenizer
            get_tokenizer.tokenizer = ChineseSPTokenizer(args.tokenizer_model_type, **kwargs)
        elif tokenizer_type == "glm_BertWordPieceTokenizer":
            from .glm import BertWordPieceTokenizer
            get_tokenizer.tokenizer = BertWordPieceTokenizer(args.tokenizer_model_type, **kwargs)
        else:
            raise ValueError(f'Unknown glm tokenizer type: {tokenizer_type}')
    elif tokenizer_type == 'icetk':
        from icetk import icetk  # standalone icetk package (unified image/text tokenizer)
        get_tokenizer.tokenizer = icetk
    elif tokenizer_type == 'icetk-glm-130B':
        from .icetk_glm_130B import _IceTokenizer  # icetk variant bundled for GLM-130B
        get_tokenizer.tokenizer = _IceTokenizer()
    # elif tokenizer_type.startswith('hf'):
    #     from .hf_tokenizer import HFT5Tokenizer
    #     if tokenizer_type == "hf_T5Tokenizer":
    #         get_tokenizer.tokenizer = HFT5Tokenizer(args.tokenizer_model_type, cache_dir=args.cache_dir)
    else:
        print_rank0('Trying to load the tokenizer from Huggingface transformers...')
        os.environ['TOKENIZERS_PARALLELISM'] = 'true'
        from transformers import AutoTokenizer
        try:
            get_tokenizer.tokenizer = AutoTokenizer.from_pretrained(tokenizer_type, trust_remote_code=True)
        except OSError:
            print_rank0(f'Cannot find {tokenizer_type} on Huggingface or in sat. Creating a fake tokenizer...')
            assert args is not None and args.vocab_size > 0, 'args.vocab_size is required to build a FakeTokenizer.'
            get_tokenizer.tokenizer = FakeTokenizer(args.vocab_size)
            return get_tokenizer.tokenizer
    print_rank0(f'> Set tokenizer as a {tokenizer_type} tokenizer! Now you can get_tokenizer() everywhere.')
    return get_tokenizer.tokenizer
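
A minimal usage sketch (not part of the module above): `MyTokenizer` and the `args` namespace are hypothetical placeholders, only `get_tokenizer` itself comes from the file, and the import path assumes the `sat` package is installed.

from argparse import Namespace

from sat.tokenization import get_tokenizer

# Option 1: register your own tokenizer once (keyword-only argument),
# before calling `training_main`, then fetch it anywhere via get_tokenizer().
class MyTokenizer:                                  # hypothetical stand-in for a real tokenizer
    def encode(self, text):
        return [ord(c) for c in text]

get_tokenizer(outer_tokenizer=MyTokenizer())
tokenizer = get_tokenizer()                         # returns the cached outer tokenizer

# Option 2: let args.tokenizer_type drive the lookup; an unrecognized type falls
# through to Huggingface AutoTokenizer (args.vocab_size is only needed if that
# lookup fails and a FakeTokenizer has to be built).
args = Namespace(tokenizer_type='bert-base-uncased')
tokenizer = get_tokenizer(args)                     # cached for later get_tokenizer() calls

The function caches its result on its own attributes (`get_tokenizer.tokenizer` / `get_tokenizer.tokenizer_type`), so repeated calls with the same type return the cached object and the tokenizer behaves like a process-wide singleton.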