learning/SentencePredictions.ipynb (302 lines of code) (raw):

{ "cells": [ { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "import train_type as tp\n", "import tensorflow as tf\n", "import numpy as np\n", "DATA_DIR = \"/mnt/big_drive/deeptype/data/\"" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "# function to get probabilities given a sentence split\n", "def get_prob(tagger_ins,sentence_splits):\n", " ps = tagger_ins.predict_proba_sentences([sentence_splits])\n", " output = [i for i in ps]\n", " probs = output[0]['type']\n", " return probs[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Load all tags\n", "(if you have multiple classifiers, read each classes.txt separetely and append to tags)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['aaa_no_instance_subclass_or_link',\n", " 'aaa_wikidata_prop',\n", " 'aaa_wikimedia_category_page',\n", " 'aaa_wikipedia_disambiguation',\n", " 'aaa_wikipedia_list',\n", " 'aaa_wikipedia_project_page',\n", " 'aaa_wikipedia_template_namespace',\n", " 'aaa_wikipedia_user_language_template',\n", " 'activity',\n", " 'aircraft',\n", " 'airport',\n", " 'algorithm',\n", " 'alphabet',\n", " 'anatomical_structure',\n", " 'astronomical_object',\n", " 'audio_visual_work',\n", " 'award',\n", " 'award_ceremony',\n", " 'battle',\n", " 'book_magazine_article',\n", " 'brand',\n", " 'bridge',\n", " 'character',\n", " 'chemical_compound',\n", " 'clothing',\n", " 'color',\n", " 'concept',\n", " 'country',\n", " 'crime',\n", " 'currency',\n", " 'data_format',\n", " 'date',\n", " 'developmental_biology_period',\n", " 'disease',\n", " 'electromagnetic_wave',\n", " 'event',\n", " 'facility',\n", " 'family',\n", " 'fictional_character',\n", " 'food',\n", " 'gas',\n", " 'gene',\n", " 'genre',\n", " 'geographical_object',\n", " 'geometric_shape',\n", " 'hazard',\n", " 'human',\n", " 'human_female',\n", " 'human_male',\n", " 'international_relations',\n", " 'kinship',\n", " 'lake',\n", " 'language',\n", " 'law',\n", " 'legal_action',\n", " 'legal_case',\n", " 'legislative_term',\n", " 'mathematical_object',\n", " 'mind',\n", " 'molecule',\n", " 'monument',\n", " 'mountain',\n", " 'musical_work',\n", " 'name',\n", " 'natural_phenomenon',\n", " 'number',\n", " 'organization',\n", " 'other_art_work',\n", " 'people',\n", " 'person_role',\n", " 'physical_object',\n", " 'physical_quantity',\n", " 'plant',\n", " 'populated_place',\n", " 'position',\n", " 'postal_code',\n", " 'radio_program',\n", " 'railroad',\n", " 'record_chart',\n", " 'region',\n", " 'religion',\n", " 'research',\n", " 'river',\n", " 'road_vehicle',\n", " 'sea',\n", " 'sex_toy',\n", " 'sexual_orientation',\n", " 'software',\n", " 'song',\n", " 'speech',\n", " 'sport',\n", " 'sport_event',\n", " 'sports_terminology',\n", " 'strategy',\n", " 'taxon',\n", " 'taxonomic_rank',\n", " 'title',\n", " 'train_station',\n", " 'union',\n", " 'unit_of_mass',\n", " 'value',\n", " 'vehicle',\n", " 'vehicle_brand',\n", " 'volcano',\n", " 'war',\n", " 'watercraft',\n", " 'weapon',\n", " 'website',\n", " 'other']" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# load all tags (if you have multiple classifiers read each classes.txt separetely and append to tags)\n", "tags = open('{}/classifications/type_classification/classes.txt'.format(DATA_DIR)).readlines()\n", "tags = [t.replace(\"\\n\",\"\") for t in tags]\n", "tags" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Load the LSTM Model \n", "Make sure you have sufficient memory in GPU, if you are using GPU" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from /mnt/big_drive/deeptype/my_great_model_bckp/model.ckpt\n" ] } ], "source": [ "tagger = tp.SequenceTagger('/mnt/big_drive/deeptype/my_great_model_bckp/')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Get probabilities for a sentence \n", "outputs probability vector for each word in a sentence, of size N*tags, \n", "where N is the number of words in the sentence and tags is the number of tags" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(7, 109)" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sentence = \"Man saw a jaguar in the jungle\"\n", "sent_splits = sentence.split()\n", "probs = get_prob(tagger,sent_splits)\n", "probs.shape" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Probability for the word Man\n", "[('human_male', 0.14398718), ('mathematical_object', 0.34086773), ('other', 0.12715971)]\n", "\n", "\n", "Probability for the word saw\n", "[('mathematical_object', 0.16467695), ('physical_object', 0.2215165), ('other', 0.248256)]\n", "\n", "\n", "Probability for the word a\n", "[('mathematical_object', 0.13781491), ('physical_object', 0.10018583), ('taxon', 0.27661625), ('other', 0.2516701)]\n", "\n", "\n", "Probability for the word jaguar\n", "[('taxon', 0.6395821), ('other', 0.10057912)]\n", "\n", "\n", "Probability for the word in\n", "[('other', 0.44680375)]\n", "\n", "\n", "Probability for the word the\n", "[('region', 0.12901735), ('other', 0.46021998)]\n", "\n", "\n", "Probability for the word jungle\n", "[('region', 0.6417663), ('other', 0.13294013)]\n", "\n", "\n" ] } ], "source": [ "threshold = 0.1 \n", "for k,p in enumerate(probs):\n", " print(\"Types for the word {}\".format(sent_splits[k]))\n", " print([(tags[k],p[k]) for k in np.where(p>threshold)[0]])\n", " print(\"\\n\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.3" } }, "nbformat": 4, "nbformat_minor": 2 }