learning/SentencePredictions.ipynb (302 lines of code) (raw):
{
"cells": [
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import train_type as tp\n",
"import tensorflow as tf\n",
"import numpy as np\n",
"\n",
"# Root of the DeepType data directory. Set the DEEPTYPE_DATA_DIR environment\n",
"# variable to point elsewhere instead of editing this cell; the default is the\n",
"# original hard-coded location.\n",
"DATA_DIR = os.environ.get(\"DEEPTYPE_DATA_DIR\", \"/mnt/big_drive/deeptype/data/\")"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"def get_prob(tagger_ins, sentence_splits):\n",
"    \"\"\"Return the per-word type probabilities for one tokenized sentence.\n",
"\n",
"    tagger_ins: the loaded tagger (a tp.SequenceTagger in this notebook).\n",
"    sentence_splits: list of word strings, e.g. sentence.split().\n",
"\n",
"    Returns element [0] of the 'type' entry of the first prediction result.\n",
"    Presumably an array of shape (len(sentence_splits), num_tags); for the\n",
"    7-word example sentence in this notebook it is (7, 109).\n",
"    \"\"\"\n",
"    # predict_proba_sentences takes a list of sentences, so wrap the single one.\n",
"    predictions = tagger_ins.predict_proba_sentences([sentence_splits])\n",
"    # Only one sentence was submitted: take the first result directly instead\n",
"    # of copying every result into a list just to index [0].\n",
"    first_result = next(iter(predictions))\n",
"    return first_result['type'][0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load all tags\n",
"(If you have multiple classifiers, read each classifier's classes.txt separately and append its entries to `tags`.)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['aaa_no_instance_subclass_or_link',\n",
" 'aaa_wikidata_prop',\n",
" 'aaa_wikimedia_category_page',\n",
" 'aaa_wikipedia_disambiguation',\n",
" 'aaa_wikipedia_list',\n",
" 'aaa_wikipedia_project_page',\n",
" 'aaa_wikipedia_template_namespace',\n",
" 'aaa_wikipedia_user_language_template',\n",
" 'activity',\n",
" 'aircraft',\n",
" 'airport',\n",
" 'algorithm',\n",
" 'alphabet',\n",
" 'anatomical_structure',\n",
" 'astronomical_object',\n",
" 'audio_visual_work',\n",
" 'award',\n",
" 'award_ceremony',\n",
" 'battle',\n",
" 'book_magazine_article',\n",
" 'brand',\n",
" 'bridge',\n",
" 'character',\n",
" 'chemical_compound',\n",
" 'clothing',\n",
" 'color',\n",
" 'concept',\n",
" 'country',\n",
" 'crime',\n",
" 'currency',\n",
" 'data_format',\n",
" 'date',\n",
" 'developmental_biology_period',\n",
" 'disease',\n",
" 'electromagnetic_wave',\n",
" 'event',\n",
" 'facility',\n",
" 'family',\n",
" 'fictional_character',\n",
" 'food',\n",
" 'gas',\n",
" 'gene',\n",
" 'genre',\n",
" 'geographical_object',\n",
" 'geometric_shape',\n",
" 'hazard',\n",
" 'human',\n",
" 'human_female',\n",
" 'human_male',\n",
" 'international_relations',\n",
" 'kinship',\n",
" 'lake',\n",
" 'language',\n",
" 'law',\n",
" 'legal_action',\n",
" 'legal_case',\n",
" 'legislative_term',\n",
" 'mathematical_object',\n",
" 'mind',\n",
" 'molecule',\n",
" 'monument',\n",
" 'mountain',\n",
" 'musical_work',\n",
" 'name',\n",
" 'natural_phenomenon',\n",
" 'number',\n",
" 'organization',\n",
" 'other_art_work',\n",
" 'people',\n",
" 'person_role',\n",
" 'physical_object',\n",
" 'physical_quantity',\n",
" 'plant',\n",
" 'populated_place',\n",
" 'position',\n",
" 'postal_code',\n",
" 'radio_program',\n",
" 'railroad',\n",
" 'record_chart',\n",
" 'region',\n",
" 'religion',\n",
" 'research',\n",
" 'river',\n",
" 'road_vehicle',\n",
" 'sea',\n",
" 'sex_toy',\n",
" 'sexual_orientation',\n",
" 'software',\n",
" 'song',\n",
" 'speech',\n",
" 'sport',\n",
" 'sport_event',\n",
" 'sports_terminology',\n",
" 'strategy',\n",
" 'taxon',\n",
" 'taxonomic_rank',\n",
" 'title',\n",
" 'train_station',\n",
" 'union',\n",
" 'unit_of_mass',\n",
" 'value',\n",
" 'vehicle',\n",
" 'vehicle_brand',\n",
" 'volcano',\n",
" 'war',\n",
" 'watercraft',\n",
" 'weapon',\n",
" 'website',\n",
" 'other']"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load all tags (if you have multiple classifiers, read each classes.txt separately and append to tags).\n",
"# Tag order must match the columns of the predicted probability vectors, so lines are kept in file order.\n",
"with open('{}/classifications/type_classification/classes.txt'.format(DATA_DIR)) as classes_file:\n",
"    # The context manager closes the file handle; strip only the trailing newline of each line.\n",
"    tags = [line.rstrip(\"\\n\") for line in classes_file]\n",
"tags"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load the LSTM Model \n",
"If you are running on a GPU, make sure it has enough free memory to hold the model."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Restoring parameters from /mnt/big_drive/deeptype/my_great_model_bckp/model.ckpt\n"
]
}
],
"source": [
"# Restore the trained sequence tagger from its checkpoint directory.\n",
"model_dir = '/mnt/big_drive/deeptype/my_great_model_bckp/'\n",
"tagger = tp.SequenceTagger(model_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Get probabilities for a sentence \n",
"Outputs a probability vector over all tags for each word in the sentence, i.e. an array of shape (N, T), \n",
"where N is the number of words in the sentence and T is the number of tags."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(7, 109)"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sentence = \"Man saw a jaguar in the jungle\"\n",
"sent_splits = sentence.split()\n",
"probs = get_prob(tagger, sent_splits)\n",
"# One row per word and one column per tag: the next cell indexes `tags` by column.\n",
"assert probs.shape == (len(sent_splits), len(tags)), probs.shape\n",
"probs.shape"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Probability for the word Man\n",
"[('human_male', 0.14398718), ('mathematical_object', 0.34086773), ('other', 0.12715971)]\n",
"\n",
"\n",
"Probability for the word saw\n",
"[('mathematical_object', 0.16467695), ('physical_object', 0.2215165), ('other', 0.248256)]\n",
"\n",
"\n",
"Probability for the word a\n",
"[('mathematical_object', 0.13781491), ('physical_object', 0.10018583), ('taxon', 0.27661625), ('other', 0.2516701)]\n",
"\n",
"\n",
"Probability for the word jaguar\n",
"[('taxon', 0.6395821), ('other', 0.10057912)]\n",
"\n",
"\n",
"Probability for the word in\n",
"[('other', 0.44680375)]\n",
"\n",
"\n",
"Probability for the word the\n",
"[('region', 0.12901735), ('other', 0.46021998)]\n",
"\n",
"\n",
"Probability for the word jungle\n",
"[('region', 0.6417663), ('other', 0.13294013)]\n",
"\n",
"\n"
]
}
],
"source": [
"# Minimum probability for a tag to be listed for a word.\n",
"threshold = 0.1\n",
"assert len(probs) == len(sent_splits)\n",
"for word, word_probs in zip(sent_splits, probs):\n",
"    print(\"Probability for the word {}\".format(word))\n",
"    # (tag, probability) pairs for every tag above the threshold, in tag order.\n",
"    above = np.where(word_probs > threshold)[0]\n",
"    print([(tags[t], word_probs[t]) for t in above])\n",
"    print(\"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}