examples/rsa/semantic_parsing.py (241 lines of code) (raw):
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
"""
Combining models of RSA pragmatics and CCG-based compositional semantics.
Taken from: http://dippl.org/examples/zSemanticPragmaticMashup.html
"""
import torch
import argparse
import collections
import pyro
import pyro.distributions as dist
from search_inference import HashingMarginal, BestFirstSearch, memoize
torch.set_default_dtype(torch.float64)
def Marginal(fn=None, **kwargs):
if fn is None:
return lambda _fn: Marginal(_fn, **kwargs)
return memoize(lambda *args: HashingMarginal(BestFirstSearch(fn, **kwargs).run(*args)))
###################################################################
# Lexical semantics
###################################################################
def flip(name, p):
return pyro.sample(name, dist.Bernoulli(p)).item() == 1
# hashable state
obj = collections.namedtuple("Obj", ["name", "blond", "nice", "tall"])
def Obj(name):
return obj(name=name,
blond=flip(name + "_blond", 0.5),
nice=flip(name + "_nice", 0.5),
tall=flip(name + "_tall", 0.5))
class Meaning:
def sem(self, world):
raise NotImplementedError
__call__ = sem
def syn(self):
raise NotImplementedError
class UndefinedMeaning(Meaning):
def sem(self, world):
return None
def syn(self):
return ""
class BlondMeaning(Meaning):
def sem(self, world):
return lambda obj: obj.blond
def syn(self):
return {"dir": "L", "int": "NP", "out": "S"}
class NiceMeaning(Meaning):
def sem(self, world):
return lambda obj: obj.nice
def syn(self):
return {"dir": "L", "int": "NP", "out": "S"}
class TallMeaning(Meaning):
def sem(self, world):
return lambda obj: obj.tall
def syn(self):
return {"dir": "L", "int": "NP", "out": "S"}
class BobMeaning(Meaning):
def sem(self, world):
return list(filter(lambda obj: obj.name == "Bob", world))[0]
def syn(self):
return "NP"
class SomeMeaning(Meaning):
def sem(self, world):
def f1(P):
def f2(Q):
return len(list(filter(Q, filter(P, world)))) > 0
return f2
return f1
def syn(self):
return {
"dir": "R",
"int": {"dir": "L", "int": "NP", "out": "S"},
"out": {
"dir": "R",
"int": {"dir": "L", "int": "NP", "out": "S"},
"out": "S"
}
}
class AllMeaning(Meaning):
def sem(self, world):
def f1(P):
def f2(Q):
return len(list(filter(lambda *args: not Q(*args),
filter(P, world)))) == 0
return f2
return f1
def syn(self):
return {
"dir": "R",
"int": {"dir": "L", "int": "NP", "out": "S"},
"out": {
"dir": "R",
"int": {"dir": "L", "int": "NP", "out": "S"},
"out": "S"
}
}
class NoneMeaning(Meaning):
def sem(self, world):
def f1(P):
def f2(Q):
return len(list(filter(Q, filter(P, world)))) == 0
return f2
return f1
def syn(self):
return {
"dir": "R",
"int": {"dir": "L", "int": "NP", "out": "S"},
"out": {
"dir": "R",
"int": {"dir": "L", "int": "NP", "out": "S"},
"out": "S"
}
}
class CompoundMeaning(Meaning):
def __init__(self, sem, syn):
self._sem = sem
self._syn = syn
def sem(self, world):
return self._sem(world)
def syn(self):
return self._syn
###################################################################
# Compositional semantics
###################################################################
def heuristic(is_good):
if is_good:
return torch.tensor(0.)
return torch.tensor(-100.0)
def world_prior(num_objs, meaning_fn):
prev_factor = torch.tensor(0.)
world = []
for i in range(num_objs):
world.append(Obj("obj_{}".format(i)))
new_factor = heuristic(meaning_fn(world))
pyro.factor("factor_{}".format(i), new_factor - prev_factor)
prev_factor = new_factor
pyro.factor("factor_{}".format(num_objs), prev_factor * -1)
return tuple(world)
def lexical_meaning(word):
meanings = {
"blond": BlondMeaning,
"nice": NiceMeaning,
"Bob": BobMeaning,
"some": SomeMeaning,
"none": NoneMeaning,
"all": AllMeaning
}
if word in meanings:
return meanings[word]()
else:
return UndefinedMeaning()
def apply_world_passing(f, a):
return lambda w: f(w)(a(w))
def syntax_match(s, t):
if "dir" in s and "dir" in t:
return (s["dir"] and t["dir"]) and \
syntax_match(s["int"], t["int"]) and \
syntax_match(s["out"], t["out"])
else:
return s == t
def can_apply(meanings):
inds = []
for i, meaning in enumerate(meanings):
applies = False
s = meaning.syn()
if "dir" in s:
if s["dir"] == "L":
applies = syntax_match(s["int"], meanings[i-1].syn())
elif s["dir"] == "R":
applies = syntax_match(s["int"], meanings[i+1].syn())
else:
applies = False
if applies:
inds.append(i)
return inds
def combine_meaning(meanings, c):
possible_combos = can_apply(meanings)
N = len(possible_combos)
ix = pyro.sample("ix_{}".format(c),
dist.Categorical(torch.ones(N) / N))
i = possible_combos[ix]
s = meanings[i].syn()
if s["dir"] == "L":
f = meanings[i].sem
a = meanings[i-1].sem
new_meaning = CompoundMeaning(sem=apply_world_passing(f, a),
syn=s["out"])
return meanings[0:i-1] + [new_meaning] + meanings[i+1:]
if s["dir"] == "R":
f = meanings[i].sem
a = meanings[i+1].sem
new_meaning = CompoundMeaning(sem=apply_world_passing(f, a),
syn=s["out"])
return meanings[0:i] + [new_meaning] + meanings[i+2:]
def combine_meanings(meanings, c=0):
if len(meanings) == 1:
return meanings[0].sem
else:
return combine_meanings(combine_meaning(meanings, c), c=c+1)
def meaning(utterance):
defined = filter(lambda w: "" != w.syn(),
list(map(lexical_meaning, utterance.split(" "))))
return combine_meanings(list(defined))
@Marginal(num_samples=100)
def literal_listener(utterance):
m = meaning(utterance)
world = world_prior(2, m)
pyro.factor("world_constraint", heuristic(m(world)) * 1000)
return world
def utterance_prior():
utterances = ["some of the blond people are nice",
"all of the blond people are nice",
"none of the blond people are nice"]
ix = pyro.sample("utterance", dist.Categorical(torch.ones(3) / 3.0))
return utterances[ix]
@Marginal(num_samples=100)
def speaker(world):
utterance = utterance_prior()
L = literal_listener(utterance)
pyro.sample("speaker_constraint", L, obs=world)
return utterance
def rsa_listener(utterance, qud):
world = world_prior(2, meaning(utterance))
S = speaker(world)
pyro.sample("listener_constraint", S, obs=utterance)
return qud(world)
def literal_listener_raw(utterance, qud):
m = meaning(utterance)
world = world_prior(3, m)
pyro.factor("world_constraint", heuristic(m(world)) * 1000)
return qud(world)
def main(args):
mll = Marginal(literal_listener_raw, num_samples=args.num_samples)
def is_any_qud(world):
return any(map(lambda obj: obj.nice, world))
print(mll("all blond people are nice", is_any_qud)())
def is_all_qud(world):
m = True
for obj in world:
if obj.blond:
if obj.nice:
m = m and True
else:
m = m and False
else:
m = m and True
return m
rsa = Marginal(rsa_listener, num_samples=args.num_samples)
print(rsa("some of the blond people are nice", is_all_qud)())
if __name__ == "__main__":
assert pyro.__version__.startswith('1.4.0')
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-samples', default=10, type=int)
args = parser.parse_args()
main(args)