in ttw/data_loader.py [0:0]
def __init__(self, data_dir, set, goldstandard_features=True, resnet_features=False, fasttext_features=False, T=2):
    """Load a dataset split and build one sample per (configuration, action sequence).

    For every configuration in ``configurations.<set>.json`` and every one of
    the ``4 ** T`` sequences of UP/DOWN/LEFT/RIGHT actions, the walk starts at
    the configuration's ``target_location`` and applies the actions in order
    with ``step_agnostic`` (bounded by the configuration's ``boundaries``).
    The features of each enabled feature loader are recorded at all ``T + 1``
    visited locations.

    Args:
        data_dir: Directory containing ``configurations.<set>.json`` and the
            feature resources (``resnetfeat.json``, ``wiki.en.bin``,
            ``<neighborhood>/text.json``).
        set: Split name substituted into the configurations filename.
        goldstandard_features: Enable ``GoldstandardFeatures`` (built on ``self.map``).
        resnet_features: Enable ``ResnetFeatures`` loaded from ``resnetfeat.json``.
        fasttext_features: Enable ``FasttextFeatures`` built from each
            neighborhood's ``text.json`` and ``wiki.en.bin``.
        T: Number of actions per trajectory. When ``T == 0`` the single
            trajectory per configuration gets the placeholder action ``[0]``.

    After construction ``self.data`` holds one list per enabled feature type
    (``'goldstandard'``, ``'resnet'``, ``'fasttext'``) plus ``'actions'``,
    ``'landmarks'`` and ``'target'``. Index ``i`` of every list refers to the
    same sample.

    Raises:
        AssertionError: If no feature type is enabled.
    """
    self.data_dir = data_dir
    # NOTE(review): `neighborhoods` is a module-level name defined elsewhere in this file.
    self.map = Map(data_dir, neighborhoods, include_empty_corners=True)
    self.T = T
    self.act_dict = ActionAgnosticDictionary()
    with open(os.path.join(data_dir, 'configurations.{}.json'.format(set))) as config_file:
        self.configs = json.load(config_file)

    self.feature_loaders = dict()
    self.data = {}
    if fasttext_features:
        textfeatures = dict()
        for n in neighborhoods:
            with open(os.path.join(data_dir, n, "text.json")) as text_file:
                textfeatures[n] = json.load(text_file)
        self.feature_loaders['fasttext'] = FasttextFeatures(textfeatures, os.path.join(data_dir, 'wiki.en.bin'))
        self.data['fasttext'] = list()
    if resnet_features:
        self.feature_loaders['resnet'] = ResnetFeatures(os.path.join(data_dir, 'resnetfeat.json'))
        # Must be keyed 'resnet': the per-sample loop below appends to
        # self.data[k] for every feature-loader key k.
        self.data['resnet'] = list()
    if goldstandard_features:
        self.feature_loaders['goldstandard'] = GoldstandardFeatures(self.map)
        self.data['goldstandard'] = list()
    assert len(self.feature_loaders) > 0, "at least one feature type must be enabled"

    self.data['actions'] = list()
    self.data['landmarks'] = list()
    self.data['target'] = list()

    action_list = ['UP', 'DOWN', 'LEFT', 'RIGHT']
    # Every length-T action sequence; for T == 0 this is a single empty tuple.
    all_possible_actions = list(itertools.product(action_list, repeat=self.T))

    for config in self.configs:
        neighborhood = config['neighborhood']
        target_loc = config['target_location']
        boundaries = config['boundaries']
        for action_seq in all_possible_actions:
            obs = {k: list() for k in self.feature_loaders}
            actions = list()
            # Work on a copy so the walk never aliases the configuration's
            # stored target location (which is also passed to get_landmarks).
            loc = copy.deepcopy(target_loc)
            # Walk the enumerated sequence itself. Sampling a random action here
            # would ignore action_seq and make the enumeration pointless.
            for act in action_seq:
                for k, feature_loader in self.feature_loaders.items():
                    obs[k].append(feature_loader.get(neighborhood, loc))
                actions.append(self.act_dict.encode(act))
                loc = step_agnostic(act, loc, boundaries)
            # Features at the final location (the (T+1)-th observation).
            for k, feature_loader in self.feature_loaders.items():
                obs[k].append(feature_loader.get(neighborhood, loc))
            if self.T == 0:
                # Placeholder action so the actions list is never empty.
                actions.append(0)

            for k in self.feature_loaders:
                self.data[k].append(obs[k])
            self.data['actions'].append(actions)
            landmarks, label_index = self.map.get_landmarks(neighborhood, boundaries, target_loc)
            self.data['landmarks'].append(landmarks)
            self.data['target'].append(label_index)