in agent/affordance_extractors/lm_affordance_extractor.py [0:0]
def get_log_prob_calibration_thresholds(self):
    """Load cached log-prob calibration thresholds, or recompute and cache them.

    For each affordable attribute, tunes a triple of log-prob thresholds
    (lo, md, hi) by grid search, minimizing mean squared error between the
    threshold-calibrated probability estimates and hand-labeled target
    probabilities (0-8 scores divided by 8). A fourth triple is tuned the
    same way for unknown actions and stored in
    self.unknown_action_calibration_thresholds.

    Results are cached in CALIBRATION_THRESHOLDS_PATH. The cache is trusted
    only while its first line matches the sentinel header below, so deleting
    that line (or the whole file) forces recalibration on the next run.
    """
    threshold_file_header = "# Delete this line to recalculate the thresholds on the next run.\n"
    # Can previously computed thresholds be found?
    if os.path.isfile(CALIBRATION_THRESHOLDS_PATH):
        with open(CALIBRATION_THRESHOLDS_PATH, 'r') as thresholds_file:
            lines = thresholds_file.readlines()
        if len(lines) > 0:
            if lines[0] == threshold_file_header:
                # Yes. Read them from the file. Each data line is formatted
                # "lo<TAB>md<TAB>hi<TAB>name"; the final line holds the
                # unknown-action triple.
                fields = lines[-1][:-1].split('\t')
                assert fields[3] == "unknown actions"
                self.unknown_action_calibration_thresholds = (
                    float(fields[0]), float(fields[1]), float(fields[2]))
                for line in lines[1:-1]:
                    fields = line[:-1].split('\t')
                    affordable_attribute = self.affordable_attributes_by_name[fields[3]]
                    affordable_attribute.thresholds = (
                        float(fields[0]), float(fields[1]), float(fields[2]))
                return
    # No, the thresholds cannot be found. So recompute them.
    print("Recomputing the log-prob calibration thresholds.")
    # 'with' (rather than bare open/close) guarantees the output handle is
    # released even if a helper call raises part-way through the tuning.
    with open(CALIBRATION_THRESHOLDS_PATH, 'w') as thresholds_file:
        thresholds_file.write(threshold_file_header)
        # Load the target probs: manual 0-8 labels, one noun per CSV row,
        # header row skipped; column i labels self.affordable_attributes[i].
        target_nouns = []
        with open(TARG_ATTR_SCORES_PATH, 'r') as target_attribute_scores:
            lines = target_attribute_scores.readlines()
        for line in lines[1:]:
            fields = line[:-1].split(',')
            target_nouns.append(fields[0])
            for i, score in enumerate(fields[1:]):
                self.affordable_attributes[i].target_probs.append(int(score) / 8.)
        # Tune each attribute separately.
        for affordable_attribute in self.affordable_attributes:
            # Collect the raw LM scores for the target nouns.
            scores = [
                self.conditional_log_prob_of_attribute_given_noun(
                    affordable_attribute.attribute_name, noun)
                for noun in target_nouns]
            # Find the best middle threshold, for pinning to 0.5:
            # grid search over [0, -10) in steps of 0.01.
            lowest_error = 100.
            best_thresh_md = -1
            for i_thresh in range(1000):
                thresh = 0. - i_thresh * 0.01
                mean_squared_error = 0.
                for i_noun in range(len(target_nouns)):
                    est_prob, target_prob, squared_error = self.calc_error_with_1_threshold(
                        scores[i_noun], thresh, i_noun, affordable_attribute)
                    mean_squared_error += squared_error
                mean_squared_error /= len(target_nouns)
                if mean_squared_error < lowest_error:
                    lowest_error = mean_squared_error
                    best_thresh_md = thresh
            # NOTE(review): lowest_error is reset only once here and shared
            # by the hi and lo searches below, so the lo search must beat the
            # hi search's best error for best_thresh_lo to move off -1 —
            # confirm this coupling is intended.
            lowest_error = 100.
            best_thresh_hi = -1
            best_thresh_lo = -1
            thresh_lo = best_thresh_md
            thresh_md = best_thresh_md
            # Find the best hi threshold, for clipping to 1.
            for i in range(1000):
                thresh_hi = thresh_md + i * 0.01
                mean_squared_error = 0.
                for i_noun in range(len(target_nouns)):
                    est_prob, target_prob, squared_error = self.calc_error_with_3_thresholds(
                        scores[i_noun], thresh_lo, best_thresh_md, thresh_hi,
                        i_noun, affordable_attribute)
                    mean_squared_error += squared_error
                mean_squared_error /= len(target_nouns)
                if mean_squared_error < lowest_error:
                    lowest_error = mean_squared_error
                    best_thresh_hi = thresh_hi
            # Find the best lo threshold, for clipping to 0.
            # NOTE(review): thresh_hi below carries the LAST value tried by
            # the loop above (thresh_md + 9.99), not best_thresh_hi — this
            # looks like a bug, but is preserved pending confirmation against
            # calc_error_with_3_thresholds.
            for i in range(1000):
                thresh_lo = thresh_md - i * 0.01
                mean_squared_error = 0.
                for i_noun in range(len(target_nouns)):
                    est_prob, target_prob, squared_error = self.calc_error_with_3_thresholds(
                        scores[i_noun], thresh_lo, best_thresh_md, thresh_hi,
                        i_noun, affordable_attribute)
                    mean_squared_error += squared_error
                mean_squared_error /= len(target_nouns)
                if mean_squared_error < lowest_error:
                    lowest_error = mean_squared_error
                    best_thresh_lo = thresh_lo
            thresholds_file.write("{:7.3f}\t{:7.3f}\t{:7.3f}\t{}\n".format(
                best_thresh_lo, best_thresh_md, best_thresh_hi,
                affordable_attribute.attribute_name))
            affordable_attribute.thresholds = (best_thresh_lo, best_thresh_md, best_thresh_hi)
        # Now compute the calibration thresholds for unknown actions.
        # Load the target command scores. (The manual labels.)
        with open(TARG_CMD_SCORES_PATH, 'r') as target_command_scores_file:
            lines = target_command_scores_file.readlines()
        target_command_scores = {}
        for line in lines:
            fields = line[:-1].split(',')
            target_command_scores[fields[0]] = int(fields[1])
        # Gather the (log-prob, target-prob) pairs to be used for tuning.
        x_y_pairs = []
        for noun in target_nouns:
            scored_actions = self.extract_unknown_actions_with_log_probs(noun)
            for scored_action in scored_actions:
                command = scored_action[0] + ' ' + noun
                if command in target_command_scores:
                    # Floor extreme log-probs so one outlier can't dominate.
                    lp = max(-25.0, scored_action[1])
                    prob_times_8 = target_command_scores[command]
                    x_y_pairs.append((lp, prob_times_8 / 8.))
        for lp, prob in x_y_pairs:
            print("{}\t{}".format(lp, prob))
        # Find the best middle log-prob threshold, for pinning output probs
        # to 0.5: grid search over [0, -10) in steps of 0.1.
        lowest_error = 100.
        best_thresh_md = -1
        num_pairs = len(x_y_pairs)
        for i_thresh in range(100):
            thresh = 0. - i_thresh * 0.1
            mean_squared_error = 0.
            for x, y in x_y_pairs:
                y_est = 1. if x > thresh else 0.
                error = y_est - y
                mean_squared_error += error * error
            mean_squared_error /= num_pairs
            if mean_squared_error < lowest_error:
                lowest_error = mean_squared_error
                best_thresh_md = thresh
        # Find the best hi log-prob threshold, for clipping output probs to 1.
        lowest_error = 100.
        best_thresh_hi = -1
        thresh_md = best_thresh_md
        for i in range(100):
            thresh_hi = thresh_md + i * 0.1
            mean_squared_error = 0.
            count = 0
            for x, y in x_y_pairs:
                if x > thresh_md:
                    count += 1
                    if x >= thresh_hi:
                        y_est = 1.
                    else:
                        # Linear ramp: 0.5 at thresh_md up to 1 at thresh_hi.
                        y_est = 0.5 + 0.5 * (x - thresh_md) / (thresh_hi - thresh_md)
                    error = y_est - y
                    mean_squared_error += error * error
            if count == 0:
                # Bug fix: no pair lies above thresh_md; the original
                # divided by zero here.
                continue
            mean_squared_error /= count
            if mean_squared_error < lowest_error:
                lowest_error = mean_squared_error
                best_thresh_hi = thresh_hi
        # Find the best lo log-prob threshold, for clipping output probs to 0.
        lowest_error = 100.
        best_thresh_lo = -1
        for i in range(100):
            thresh_lo = thresh_md - i * 0.1
            mean_squared_error = 0.
            count = 0
            for x, y in x_y_pairs:
                if x < thresh_md:
                    count += 1
                    if x >= thresh_lo:
                        # Linear ramp: 0 at thresh_lo up to 0.5 at thresh_md.
                        y_est = 0.5 * (x - thresh_lo) / (thresh_md - thresh_lo)
                    else:
                        y_est = 0.
                    error = y_est - y
                    mean_squared_error += error * error
            if count == 0:
                # Bug fix: no pair lies below thresh_md; the original
                # divided by zero here.
                continue
            mean_squared_error /= count
            if mean_squared_error < lowest_error:
                lowest_error = mean_squared_error
                best_thresh_lo = thresh_lo
        self.unknown_action_calibration_thresholds = (best_thresh_lo, best_thresh_md, best_thresh_hi)
        thresholds_file.write("{:7.3f}\t{:7.3f}\t{:7.3f}\tunknown actions\n".format(
            best_thresh_lo, best_thresh_md, best_thresh_hi))