examples/rsa/semantic_parsing.py (241 lines of code) (raw):

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

"""
Combining models of RSA pragmatics and CCG-based compositional semantics.

Taken from: http://dippl.org/examples/zSemanticPragmaticMashup.html
"""

import argparse
import collections

import torch

import pyro
import pyro.distributions as dist

from search_inference import HashingMarginal, BestFirstSearch, memoize

torch.set_default_dtype(torch.float64)


def Marginal(fn=None, **kwargs):
    """Wrap ``fn`` so that calling it runs ``BestFirstSearch`` and returns a marginal.

    Usable directly (``Marginal(fn, num_samples=10)``) or as a decorator
    factory (``@Marginal(num_samples=10)``).  ``kwargs`` are forwarded to
    ``BestFirstSearch``; the resulting callable is wrapped in ``memoize``.
    """
    if fn is None:
        # Decorator-with-arguments form: wait for the function to be supplied.
        return lambda _fn: Marginal(_fn, **kwargs)
    return memoize(lambda *args: HashingMarginal(BestFirstSearch(fn, **kwargs).run(*args)))


###################################################################
# Lexical semantics
###################################################################

def flip(name, p):
    """Sample a Bernoulli(p) site called ``name`` and return it as a Python bool."""
    return pyro.sample(name, dist.Bernoulli(p)).item() == 1


# hashable state
obj = collections.namedtuple("Obj", ["name", "blond", "nice", "tall"])


def Obj(name):
    """Sample an object whose three boolean properties are independent fair flips."""
    return obj(name=name,
               blond=flip(name + "_blond", 0.5),
               nice=flip(name + "_nice", 0.5),
               tall=flip(name + "_tall", 0.5))


def _predicate_syn():
    """Return a fresh category dict for a one-place predicate (takes an NP on its left)."""
    return {"dir": "L", "int": "NP", "out": "S"}


def _quantifier_syn():
    """Return a fresh category dict for a quantifier (takes two predicates on its right)."""
    return {
        "dir": "R",
        "int": _predicate_syn(),
        "out": {
            "dir": "R",
            "int": _predicate_syn(),
            "out": "S",
        },
    }


class Meaning:
    """Base class pairing a world-dependent denotation with a syntactic category."""

    def sem(self, world):
        """Return the denotation of this meaning in ``world``."""
        raise NotImplementedError

    def __call__(self, world):
        # Dispatch at call time.  The previous ``__call__ = sem`` captured the
        # base-class function, so calling any subclass instance raised
        # NotImplementedError instead of using the subclass's ``sem``.
        return self.sem(world)

    def syn(self):
        """Return the syntactic category: an atom string or a functor dict."""
        raise NotImplementedError


class UndefinedMeaning(Meaning):
    """Meaning of a word outside the lexicon; its empty category marks it for removal."""

    def sem(self, world):
        return None

    def syn(self):
        return ""


class BlondMeaning(Meaning):
    """Predicate true of objects whose ``blond`` field is set."""

    def sem(self, world):
        return lambda obj: obj.blond

    def syn(self):
        return _predicate_syn()


class NiceMeaning(Meaning):
    """Predicate true of objects whose ``nice`` field is set."""

    def sem(self, world):
        return lambda obj: obj.nice

    def syn(self):
        return _predicate_syn()


class TallMeaning(Meaning):
    """Predicate true of objects whose ``tall`` field is set."""

    def sem(self, world):
        return lambda obj: obj.tall

    def syn(self):
        return _predicate_syn()


class BobMeaning(Meaning):
    """Proper name denoting the first object in the world named "Bob"."""

    def sem(self, world):
        # Raises IndexError when the world contains no object named "Bob".
        return [o for o in world if o.name == "Bob"][0]

    def syn(self):
        return "NP"


class SomeMeaning(Meaning):
    """Quantifier: at least one object satisfying P also satisfies Q."""

    def sem(self, world):
        def f1(P):
            def f2(Q):
                return any(Q(o) for o in filter(P, world))
            return f2
        return f1

    def syn(self):
        return _quantifier_syn()


class AllMeaning(Meaning):
    """Quantifier: every object satisfying P also satisfies Q (vacuously true if none)."""

    def sem(self, world):
        def f1(P):
            def f2(Q):
                return all(Q(o) for o in filter(P, world))
            return f2
        return f1

    def syn(self):
        return _quantifier_syn()


class NoneMeaning(Meaning):
    """Quantifier: no object satisfying P also satisfies Q."""

    def sem(self, world):
        def f1(P):
            def f2(Q):
                return not any(Q(o) for o in filter(P, world))
            return f2
        return f1

    def syn(self):
        return _quantifier_syn()


class CompoundMeaning(Meaning):
    """Meaning built by function application, holding an explicit denotation and category."""

    def __init__(self, sem, syn):
        self._sem = sem
        self._syn = syn

    def sem(self, world):
        return self._sem(world)

    def syn(self):
        return self._syn


###################################################################
# Compositional semantics
###################################################################

def heuristic(is_good):
    """Map a truth value to a log-factor: 0 when truthy, -100 otherwise."""
    if is_good:
        return torch.tensor(0.)
    return torch.tensor(-100.0)


def world_prior(num_objs, meaning_fn):
    """Sample a world of ``num_objs`` objects, scoring partial worlds along the way.

    After each object is added, the change in ``heuristic(meaning_fn(world))`` is
    emitted as a factor; a final factor subtracts the last heuristic value, so
    the factors emitted here sum to zero.
    NOTE(review): the intermediate factors presumably exist to steer the
    best-first search toward consistent worlds -- confirm against search_inference.

    Returns the world as a tuple (hashable).
    """
    prev_factor = torch.tensor(0.)
    world = []
    for i in range(num_objs):
        world.append(Obj("obj_{}".format(i)))
        new_factor = heuristic(meaning_fn(world))
        pyro.factor("factor_{}".format(i), new_factor - prev_factor)
        prev_factor = new_factor

    # Cancel the accumulated heuristic so it has no net effect on the score.
    pyro.factor("factor_{}".format(num_objs), prev_factor * -1)
    return tuple(world)


def lexical_meaning(word):
    """Look up ``word`` in the lexicon; unknown words get ``UndefinedMeaning``."""
    # NOTE(review): TallMeaning is defined above but is not registered here, so
    # "tall" is currently treated as an undefined word -- confirm this is intended.
    meanings = {
        "blond": BlondMeaning,
        "nice": NiceMeaning,
        "Bob": BobMeaning,
        "some": SomeMeaning,
        "none": NoneMeaning,
        "all": AllMeaning,
    }
    return meanings.get(word, UndefinedMeaning)()


def apply_world_passing(f, a):
    """Compose two world-indexed denotations: apply ``f(w)`` to ``a(w)``."""
    return lambda w: f(w)(a(w))


def _is_functor(category):
    """Return True for a functor category (a dict with a "dir" key), False for atoms."""
    # An explicit dict check: ``"dir" in "NP"`` would be a substring test.
    return isinstance(category, dict) and "dir" in category


def syntax_match(s, t):
    """Return whether syntactic categories ``s`` and ``t`` are structurally equal."""
    if _is_functor(s) and _is_functor(t):
        # Directions must be *equal*; ``s["dir"] and t["dir"]`` was always truthy
        # and therefore let an "L" functor match an "R" functor.
        return (s["dir"] == t["dir"]
                and syntax_match(s["int"], t["int"])
                and syntax_match(s["out"], t["out"]))
    return s == t


def can_apply(meanings):
    """Return the indices of meanings that can consume an adjacent argument.

    A functor with direction "L" looks at its left neighbour and one with "R"
    at its right neighbour; a functor at the edge with no such neighbour is
    not applicable.
    """
    inds = []
    for i, meaning in enumerate(meanings):
        s = meaning.syn()
        if not _is_functor(s):
            continue
        if s["dir"] == "L":
            arg_ix = i - 1
        elif s["dir"] == "R":
            arg_ix = i + 1
        else:
            continue
        # Bounds check: index -1 would silently wrap to the last element and
        # index len(meanings) would raise IndexError.
        if not 0 <= arg_ix < len(meanings):
            continue
        if syntax_match(s["int"], meanings[arg_ix].syn()):
            inds.append(i)
    return inds


def combine_meaning(meanings, c):
    """Perform one function application, chosen uniformly among the applicable ones.

    ``c`` is the step counter used to name the sample site ``ix_<c>``.
    Returns a new list with the functor and its argument replaced by their
    compound meaning.

    Raises:
        ValueError: if no adjacent pair of meanings can be combined.
    """
    possible_combos = can_apply(meanings)
    N = len(possible_combos)
    if N == 0:
        raise ValueError("no applicable combination among {} meaning(s); "
                         "utterance cannot be parsed".format(len(meanings)))
    ix = pyro.sample("ix_{}".format(c), dist.Categorical(torch.ones(N) / N))
    i = possible_combos[ix]
    functor = meanings[i]
    s = functor.syn()
    # can_apply only yields functors whose direction is "L" or "R".
    if s["dir"] == "L":
        argument = meanings[i - 1]
        before, after = meanings[:i - 1], meanings[i + 1:]
    else:
        argument = meanings[i + 1]
        before, after = meanings[:i], meanings[i + 2:]
    new_meaning = CompoundMeaning(sem=apply_world_passing(functor.sem, argument.sem),
                                  syn=s["out"])
    return before + [new_meaning] + after


def combine_meanings(meanings, c=0):
    """Repeatedly combine ``meanings`` until one remains and return its ``sem``."""
    while len(meanings) != 1:
        meanings = combine_meaning(meanings, c)
        c += 1
    return meanings[0].sem


def meaning(utterance):
    """Parse a space-separated utterance into a world -> value function.

    Words without a lexical entry (empty syntactic category) are dropped
    before composition.
    """
    lexical = (lexical_meaning(word) for word in utterance.split(" "))
    return combine_meanings([m for m in lexical if m.syn() != ""])


@Marginal(num_samples=100)
def literal_listener(utterance):
    """Sample a two-object world, heavily penalising worlds where the utterance is false."""
    m = meaning(utterance)
    world = world_prior(2, m)
    pyro.factor("world_constraint", heuristic(m(world)) * 1000)
    return world


def utterance_prior():
    """Sample one of the candidate utterances uniformly."""
    utterances = ["some of the blond people are nice",
                  "all of the blond people are nice",
                  "none of the blond people are nice"]
    n = len(utterances)
    ix = pyro.sample("utterance", dist.Categorical(torch.ones(n) / n))
    return utterances[ix]


@Marginal(num_samples=100)
def speaker(world):
    """Sample an utterance, conditioned on the literal listener recovering ``world``."""
    utterance = utterance_prior()
    L = literal_listener(utterance)
    pyro.sample("speaker_constraint", L, obs=world)
    return utterance


def rsa_listener(utterance, qud):
    """Pragmatic listener: sample a world conditioned on the speaker producing ``utterance``.

    Returns ``qud(world)``, the answer to the question under discussion.
    """
    world = world_prior(2, meaning(utterance))
    S = speaker(world)
    pyro.sample("listener_constraint", S, obs=utterance)
    return qud(world)


def literal_listener_raw(utterance, qud):
    """Literal listener over three-object worlds; returns ``qud(world)``."""
    m = meaning(utterance)
    world = world_prior(3, m)
    pyro.factor("world_constraint", heuristic(m(world)) * 1000)
    return qud(world)


def main(args):
    """Print the literal-listener and RSA-listener marginals for two example queries."""
    mll = Marginal(literal_listener_raw, num_samples=args.num_samples)

    def is_any_qud(world):
        # Is at least one object nice?
        return any(o.nice for o in world)

    print(mll("all blond people are nice", is_any_qud)())

    def is_all_qud(world):
        # Is every blond object nice?  (True when there are no blond objects.)
        return all(o.nice for o in world if o.blond)

    rsa = Marginal(rsa_listener, num_samples=args.num_samples)

    print(rsa("some of the blond people are nice", is_all_qud)())


if __name__ == "__main__":
    assert pyro.__version__.startswith('1.4.0')
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument('-n', '--num-samples', default=10, type=int)
    args = parser.parse_args()
    main(args)