in E2E_TOD/db_ops.py [0:0]
def queryJsons(self, domain, constraints, exactly_match=True, return_name=False):
    """Return the DB entities of ``domain`` that satisfy a belief-state constraint dict.

    Args:
        domain: key into ``self.dbs`` (e.g. 'restaurant', 'hotel', 'train').
            'taxi', 'police' and 'hospital' are handled specially.
        constraints: slot -> value dict, e.g. {'pricerange': 'cheap', 'area': 'west'}.
        exactly_match: if True a constraint value must equal the DB value;
            if False it only needs to be contained in it (``in`` test).
        return_name: if True return entity names ('id' for the train domain)
            instead of the entity dicts.

    Returns:
        A list of matching entities (or their names/ids). Special cases:
        - 'taxi': one randomly generated record (random colour/type drawn from
          ``self.dbs['taxi']`` plus a list of 10 random digits 1-9 as phone).
        - 'police': ``self.dbs['police']`` itself (not a copy).
        - 'hospital': ``[entry]`` for the first entry whose 'department' equals
          the constraint; ``[]`` when no department is given. An unmatched
          department falls through to the generic slot matching.
        - ``[]`` when every constraint value is "not mentioned" or "".
        - If a 'name' constraint exactly equals an entity's name, only that
          entity (or its name) is returned; otherwise the name is ignored and
          the remaining slots are matched against every entity.
    """
    # Constraint values that mean "no restriction on this slot".
    # Built once per call instead of once per slot.
    skip_values = frozenset(["don't care", "do n't care", "dont care",
                             "not mentioned", "dontcare", ""])
    # Booking-only slots that never filter the search.
    booking_slots = ('people', 'stay')

    # query the db
    if domain == 'taxi':
        taxi_db = self.dbs[domain]
        return [{'taxi_colors': random.choice(taxi_db['taxi_colors']),
                 'taxi_types': random.choice(taxi_db['taxi_types']),
                 'taxi_phone': [random.randint(1, 9) for _ in range(10)]}]
    if domain == 'police':
        return self.dbs['police']
    if domain == 'hospital':
        department = constraints.get('department')
        if not department:
            return []
        for entry in self.dbs['hospital']:
            if entry.get('department') == department:
                return [entry]
        # NOTE(review): an unmatched department is not answered with [] here;
        # it falls through to the generic matching below (original behavior).

    # Nothing to search on if every value is empty / "not mentioned".
    if not any(v not in ("not mentioned", "") for v in constraints.values()):
        return []

    # An exact name hit short-circuits the slot-by-slot search.
    if 'name' in constraints:
        wanted_name = constraints['name']
        for db_ent in self.dbs[domain]:
            if 'name' in db_ent and wanted_name == db_ent['name']:
                return [db_ent['name'] if return_name else db_ent]

    def _to_minutes(hhmm):
        # 'HH:MM' -> minutes since midnight; raises for any other format.
        hours, minutes = hhmm.split(':')
        return int(hours) * 60 + int(minutes)

    def _entity_matches(db_ent):
        # True when db_ent satisfies every searchable constraint.
        for slot, value in constraints.items():
            if slot == 'name':
                continue  # names were handled by the exact lookup above
            if (slot in booking_slots
                    or (domain == 'hotel' and slot == 'day')
                    or (domain == 'restaurant' and slot in ('day', 'time'))):
                continue
            if value in skip_values:
                continue
            if slot not in db_ent:
                return False
            # The DB stores free parking/internet as 'yes'.
            value = 'yes' if value == 'free' else value

            if slot in ('arrive', 'leave'):
                try:
                    wanted = _to_minutes(value)
                except (ValueError, TypeError, AttributeError):
                    # Constraint time is not in HH:MM format.
                    return False
                db_minutes = _to_minutes(db_ent[slot])
                # NOTE(review): this keeps trains arriving at/after the requested
                # time and leaving at/before it, which looks inverted relative to
                # an "arrive by" / "leave after" reading. Preserved as-is because
                # reported match/success numbers depend on it; confirm intent.
                if slot == 'arrive' and wanted > db_minutes:
                    return False
                if slot == 'leave' and wanted < db_minutes:
                    return False
            elif exactly_match:
                if value != db_ent[slot]:
                    return False
            elif value not in db_ent[slot]:
                return False
        return True

    match_result = [db_ent for db_ent in self.dbs[domain] if _entity_matches(db_ent)]
    if not return_name:
        return match_result
    # Trains are identified by 'id'; every other domain by 'name'.
    name_key = 'id' if domain == 'train' else 'name'
    return [db_ent[name_key] for db_ent in match_result]