ludwig/utils/nlp_utils.py (142 lines of code) (raw):

#! /usr/bin/env python
# coding=utf-8
# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import logging
import sys

logger = logging.getLogger(__name__)

# Lazily populated cache of loaded spaCy pipelines, keyed by language code.
nlp_pipelines = {
    'en': None, 'it': None, 'es': None, 'de': None, 'fr': None, 'pt': None,
    'nl': None, 'el': None, 'nb': None, 'lt': None, 'da': None, 'pl': None,
    'ro': None, 'ja': None, 'zh': None, 'xx': None
}

# Maps each supported language code to the spaCy model package to load.
language_module_registry = {
    'en': 'en_core_web_sm',
    'it': 'it_core_news_sm',
    'es': 'es_core_news_sm',
    'de': 'de_core_news_sm',
    'fr': 'fr_core_news_sm',
    'pt': 'pt_core_news_sm',
    'nl': 'nl_core_news_sm',
    'el': 'el_core_news_sm',
    'nb': 'nb_core_news_sm',
    'lt': 'lt_core_news_sm',
    'da': 'da_core_news_sm',
    'pl': 'pl_core_news_sm',
    'ro': 'ro_core_news_sm',
    'ja': 'ja_core_news_sm',
    'zh': 'zh_core_web_sm',
    'xx': 'xx_ent_wiki_sm'
}

default_characters = [
    ' ', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
    'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    '-', ',', ';', '.', '!', '?', ':', '\'', '"', '/', '\\', '|', '_',
    '@', '#', '$', '%', '^', '&', '*', '~', '`', '+', '-', '=', '<', '>',
    '(', ')', '[', ']', '{', '}'
]

punctuation = {'.', ',', '@', '$', '%', '/', ':', ';', '+', '='}


def load_nlp_pipeline(language='xx'):
    """Return the spaCy pipeline for `language`, loading it (and downloading
    the model if needed) on first use and caching it for later calls."""
    if language not in language_module_registry:
        logger.error(
            'Language {} is not supported. '
            'Supported languages are: {}'.format(
                language, language_module_registry.keys()
            ))
        raise ValueError
    else:
        spacy_module_name = language_module_registry[language]
    global nlp_pipelines
    if nlp_pipelines[language] is None:
        logger.info('Loading NLP pipeline')
        try:
            import spacy
        except ImportError:
            logger.error(
                ' spacy is not installed. '
                'In order to install all text feature dependencies run '
                'pip install ludwig[text]'
            )
            sys.exit(-1)

        try:
            nlp_pipelines[language] = spacy.load(
                spacy_module_name,
                disable=['parser', 'tagger', 'ner']
            )
        except OSError:
            logger.info(
                ' spaCy {} model is missing, downloading it '
                '(this will only happen once)'.format(spacy_module_name)
            )
            from spacy.cli import download
            download(spacy_module_name)
            nlp_pipelines[language] = spacy.load(
                spacy_module_name,
                disable=['parser', 'tagger', 'ner']
            )

    return nlp_pipelines[language]


def pass_filters(
        token,
        filter_numbers=False,
        filter_punctuation=False,
        filter_short_tokens=False,
        filter_stopwords=False
):
    """Return True if `token` survives all of the enabled filters."""
    passes_filters = True
    if filter_numbers:
        passes_filters = not token.like_num
    if passes_filters and filter_punctuation:
        passes_filters = not bool(set(token.orth_) & punctuation)
    if passes_filters and filter_short_tokens:
        passes_filters = len(token) > 2
    if passes_filters and filter_stopwords:
        passes_filters = not token.is_stop
    return passes_filters


def process_text(
        text,
        nlp_pipeline,
        return_lemma=False,
        filter_numbers=False,
        filter_punctuation=False,
        filter_short_tokens=False,
        filter_stopwords=False
):
    """Tokenize `text` with the pipeline's tokenizer and return the tokens
    (or their lemmas) that pass the enabled filters."""
    doc = nlp_pipeline.tokenizer(text)
    return [token.lemma_ if return_lemma else token.text
            for token in doc
            if pass_filters(token,
                            filter_numbers,
                            filter_punctuation,
                            filter_short_tokens,
                            filter_stopwords)]


if __name__ == '__main__':
    text = 'Hello John, how are you doing my good old friend? ' \
           'Are you still number 732 in the list? ' \
           'Did you pay $32.43 or 54.21 for the book?'
    print(process_text(text, load_nlp_pipeline()))
    print(process_text(text, load_nlp_pipeline(),
                       filter_numbers=True,
                       filter_punctuation=True,
                       filter_short_tokens=True))
    print(process_text(text, load_nlp_pipeline(),
                       filter_stopwords=True))
    print(process_text(text, load_nlp_pipeline(),
                       return_lemma=True))
    print(process_text(text, load_nlp_pipeline(),
                       return_lemma=True,
                       filter_numbers=True,
                       filter_punctuation=True,
                       filter_short_tokens=True))
    print(process_text(text, load_nlp_pipeline(),
                       return_lemma=True,
                       filter_stopwords=True))
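
# --- Usage sketch (illustrative; not part of the module above) ---
# A minimal example of loading a language-specific pipeline instead of the
# multilingual 'xx' default. It assumes spaCy is installed and that the
# 'en_core_web_sm' model is available (load_nlp_pipeline downloads it on
# first use); the sample sentence and variable names are hypothetical.
#
#     english_nlp = load_nlp_pipeline('en')
#     tokens = process_text(
#         'Ludwig makes deep learning models easier to build.',
#         english_nlp,
#         return_lemma=True,
#         filter_stopwords=True
#     )
#     print(tokens)  # lemmatized tokens with stopwords removed
#
# Because the loaded pipeline is cached in nlp_pipelines, repeated calls to
# load_nlp_pipeline('en') reuse the same spaCy object rather than reloading it.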