def get_log_prob_calibration_thresholds()

in agent/affordance_extractors/lm_affordance_extractor.py [0:0]


    def get_log_prob_calibration_thresholds(self):
        """Set the log-prob calibration thresholds for attributes and unknown actions.

        Every threshold triple is ``(lo, md, hi)`` in log-prob space. For unknown
        actions, the triple maps a log-prob ``x`` to an estimated probability:
        - 0 below ``lo``;
        - rising linearly to 0.5 at ``md``;
        - rising linearly to 1 at ``hi``;
        - 1 above ``hi``.
        The attribute triples are tuned through ``calc_error_with_3_thresholds``
        (presumably the same mapping -- TODO confirm).

        If CALIBRATION_THRESHOLDS_PATH exists, still starts with the "Delete this
        line" header and parses cleanly, the cached values are applied and nothing
        is recomputed. Otherwise every triple is re-tuned by grid search against
        the manual labels in TARG_ATTR_SCORES_PATH and TARG_CMD_SCORES_PATH. The
        results are applied to ``affordable_attribute.thresholds`` and
        ``self.unknown_action_calibration_thresholds``, then written back to
        CALIBRATION_THRESHOLDS_PATH.

        Raises:
            ValueError: if the label files yield no target nouns or no labeled
                unknown-action commands, so there is nothing to calibrate against.
        """
        threshold_file_header = "# Delete this line to recalculate the thresholds on the next run.\n"

        # Can previously computed thresholds be found?
        if self._load_log_prob_calibration_thresholds(threshold_file_header):
            return

        # No, so recompute them.
        print("Recomputing the log-prob calibration thresholds.")
        target_nouns = self._load_target_attribute_probs()
        if not target_nouns:
            raise ValueError("No target nouns found in {}".format(TARG_ATTR_SCORES_PATH))

        output_lines = [threshold_file_header]

        # Tune each attribute separately.
        for affordable_attribute in self.affordable_attributes:
            thresh_lo, thresh_md, thresh_hi = self._tune_attribute_thresholds(affordable_attribute, target_nouns)
            affordable_attribute.thresholds = (thresh_lo, thresh_md, thresh_hi)
            output_lines.append("{:7.3f}\t{:7.3f}\t{:7.3f}\t{}\n".format(
                thresh_lo, thresh_md, thresh_hi, affordable_attribute.attribute_name))

        # Then the thresholds for unknown actions (always the last line of the file).
        thresh_lo, thresh_md, thresh_hi = self._tune_unknown_action_thresholds(target_nouns)
        self.unknown_action_calibration_thresholds = (thresh_lo, thresh_md, thresh_hi)
        output_lines.append("{:7.3f}\t{:7.3f}\t{:7.3f}\tunknown actions\n".format(thresh_lo, thresh_md, thresh_hi))

        # Write the cache only once everything has been computed. A failure part-way
        # through then cannot leave behind a truncated, header-only file.
        with open(CALIBRATION_THRESHOLDS_PATH, 'w') as thresholds_file:
            thresholds_file.writelines(output_lines)

    def _load_log_prob_calibration_thresholds(self, header):
        """Apply the thresholds cached in CALIBRATION_THRESHOLDS_PATH, if usable.

        Expected layout: ``header``, then one ``lo<TAB>md<TAB>hi<TAB>attribute name``
        line per attribute, then a final line whose name field is "unknown actions".

        Returns True if the cache was applied. Returns False if the thresholds must
        be recomputed: the file is missing, the header line was deleted, or the
        contents are malformed or name an attribute that no longer exists. Nothing
        is applied unless the whole file parses, so a bad cache never leaves the
        thresholds half-updated.
        """
        if not os.path.isfile(CALIBRATION_THRESHOLDS_PATH):
            return False
        with open(CALIBRATION_THRESHOLDS_PATH, 'r') as thresholds_file:
            lines = thresholds_file.readlines()
        if not lines or lines[0] != header:
            return False

        rows = []
        for line in lines[1:]:
            if not line.strip():
                continue
            # rstrip rather than line[:-1]: the last line may lack a trailing newline.
            fields = line.rstrip('\n').split('\t', 3)
            if len(fields) != 4:
                return False
            try:
                thresholds = (float(fields[0]), float(fields[1]), float(fields[2]))
            except ValueError:
                return False
            rows.append((fields[3], thresholds))
        if not rows:
            return False

        unknown_name, unknown_thresholds = rows[-1]
        attribute_rows = rows[:-1]
        if unknown_name != "unknown actions":
            return False
        if any(name not in self.affordable_attributes_by_name for name, _ in attribute_rows):
            return False

        self.unknown_action_calibration_thresholds = unknown_thresholds
        for name, thresholds in attribute_rows:
            self.affordable_attributes_by_name[name].thresholds = thresholds
        return True

    def _load_target_attribute_probs(self):
        """Read the manually labeled attribute scores and return the target nouns.

        TARG_ATTR_SCORES_PATH is read as comma-separated text with a header row.
        Each later row is ``noun,score_0,score_1,...``, with one integer score per
        entry of ``self.affordable_attributes`` in the same order. ``score / 8`` is
        appended to that attribute's ``target_probs``. Blank lines are skipped so
        nouns and target probs stay aligned by index.
        """
        with open(TARG_ATTR_SCORES_PATH, 'r') as scores_file:
            lines = scores_file.readlines()
        target_nouns = []
        for line in lines[1:]:  # lines[0] is the header row.
            if not line.strip():
                continue
            fields = line.rstrip('\n').split(',')
            target_nouns.append(fields[0])
            for i, score in enumerate(fields[1:]):
                self.affordable_attributes[i].target_probs.append(int(score) / 8.)
        return target_nouns

    def _tune_attribute_thresholds(self, affordable_attribute, target_nouns):
        """Grid-search the (lo, md, hi) log-prob thresholds for one attribute.

        Thresholds are tuned one at a time. Each step minimizes the mean of the
        squared errors reported by the calc_error_* helpers over ``target_nouns``:
        - md over [0, -9.99], using the single-threshold error;
        - hi over [md, md + 9.99], with lo pinned to md;
        - lo over [md, md - 9.99], with hi at its best value.
        All grids use steps of 0.01. ``target_nouns`` must be non-empty.

        Returns the tuple (lo, md, hi).
        """
        num_nouns = len(target_nouns)
        scores = [self.conditional_log_prob_of_attribute_given_noun(affordable_attribute.attribute_name, noun)
                  for noun in target_nouns]

        def mse_1(thresh):
            total = 0.
            for i_noun, score in enumerate(scores):
                _, _, squared_error = self.calc_error_with_1_threshold(score, thresh, i_noun, affordable_attribute)
                total += squared_error
            return total / num_nouns

        def mse_3(thresh_lo, thresh_md, thresh_hi):
            total = 0.
            for i_noun, score in enumerate(scores):
                _, _, squared_error = self.calc_error_with_3_thresholds(
                    score, thresh_lo, thresh_md, thresh_hi, i_noun, affordable_attribute)
                total += squared_error
            return total / num_nouns

        # Middle threshold, where the estimated probability is pinned to 0.5.
        thresh_md, _ = self._grid_search(
            (0. - i * 0.01 for i in range(1000)), mse_1, best=-1)

        # Hi threshold (clipping to 1), with lo pinned to md for now.
        thresh_hi, lowest_error = self._grid_search(
            (thresh_md + i * 0.01 for i in range(1000)),
            lambda hi: mse_3(thresh_md, thresh_md, hi),
            best=-1)

        # Lo threshold (clipping to 0), evaluated with the best hi. It continues from
        # the hi search's error, so lo only moves off md when that strictly helps.
        # Candidate 0 (lo == md) reproduces the hi search's final configuration,
        # which is why md is the default.
        thresh_lo, _ = self._grid_search(
            (thresh_md - i * 0.01 for i in range(1000)),
            lambda lo: mse_3(lo, thresh_md, thresh_hi),
            best=thresh_md, lowest_error=lowest_error)

        return (thresh_lo, thresh_md, thresh_hi)

    def _tune_unknown_action_thresholds(self, target_nouns):
        """Grid-search the (lo, md, hi) log-prob thresholds for unknown actions.

        Training pairs: each action returned by
        extract_unknown_actions_with_log_probs(noun) whose "<action> <noun>"
        command has a manual label in TARG_CMD_SCORES_PATH becomes one
        (log-prob floored at -25, label / 8) pair.

        Tuning:
        - md is chosen over [0, -9.9], treating the mapping as a 0/1 step at md;
        - hi is tuned independently over md + [0, 9.9], on the pairs above md;
        - lo is tuned independently over md - [0, 9.9], on the pairs below md.
        All grids use steps of 0.1. If no pair lies above (below) md, hi (lo)
        defaults to md.

        Returns the tuple (lo, md, hi).

        Raises:
            ValueError: if no labeled command pairs were found.
        """
        # Load the target command scores (the manual labels).
        with open(TARG_CMD_SCORES_PATH, 'r') as scores_file:
            lines = scores_file.readlines()
        target_command_scores = {}
        for line in lines:
            if not line.strip():
                continue
            fields = line.rstrip('\n').split(',')
            target_command_scores[fields[0]] = int(fields[1])

        # Gather the (log-prob, target prob) pairs used for tuning.
        x_y_pairs = []
        for noun in target_nouns:
            for scored_action in self.extract_unknown_actions_with_log_probs(noun):
                command = scored_action[0] + ' ' + noun
                if command in target_command_scores:
                    lp = max(-25.0, scored_action[1])
                    x_y_pairs.append((lp, target_command_scores[command] / 8.))
        for x, y in x_y_pairs:
            print("{}\t{}".format(x, y))
        if not x_y_pairs:
            raise ValueError("No labeled unknown-action commands found in {}".format(TARG_CMD_SCORES_PATH))

        def mean_squared_error(pairs, estimate):
            total = 0.
            for x, y in pairs:
                error = estimate(x) - y
                total += error * error
            return total / len(pairs)

        # Middle threshold: output is a step from 0 to 1 at md.
        def step_error(md):
            return mean_squared_error(x_y_pairs, lambda x: 1. if x > md else 0.)

        thresh_md, _ = self._grid_search((0. - i * 0.1 for i in range(100)), step_error, best=-1)

        # Hi threshold (clipping to 1), fitted only on pairs above md. Candidate 0
        # (hi == md) never divides by zero because every such x is >= hi.
        above = [(x, y) for x, y in x_y_pairs if x > thresh_md]

        def upper_error(hi):
            return mean_squared_error(
                above, lambda x: 1. if x >= hi else 0.5 + 0.5 * (x - thresh_md) / (hi - thresh_md))

        thresh_hi = thresh_md
        if above:
            thresh_hi, _ = self._grid_search((thresh_md + i * 0.1 for i in range(100)), upper_error, best=-1)

        # Lo threshold (clipping to 0), fitted only on pairs below md. Candidate 0
        # (lo == md) never divides by zero because every such x is < lo.
        below = [(x, y) for x, y in x_y_pairs if x < thresh_md]

        def lower_error(lo):
            return mean_squared_error(
                below, lambda x: 0.5 * (x - lo) / (thresh_md - lo) if x >= lo else 0.)

        thresh_lo = thresh_md
        if below:
            thresh_lo, _ = self._grid_search((thresh_md - i * 0.1 for i in range(100)), lower_error, best=-1)

        return (thresh_lo, thresh_md, thresh_hi)

    @staticmethod
    def _grid_search(candidates, error_of, best, lowest_error=100.):
        """Scan ``candidates`` in order and return ``(best, lowest_error)``.

        A candidate replaces the incumbent only if its error is strictly lower.
        Ties therefore keep the earlier candidate, and ``best`` is returned
        unchanged if no candidate beats the starting ``lowest_error``.
        """
        for candidate in candidates:
            error = error_of(candidate)
            if error < lowest_error:
                best, lowest_error = candidate, error
        return best, lowest_error