def compute_text_classification()

in utils/compute_score.py [0:0]


def compute_text_classification(file_path):
    with open(file_path, 'r', encoding='utf-8') as input_file:
        print("Reading\t" + file_path)
        sample_list = json.load(input_file)
        samples = pd.json_normalize(sample_list, 'outputs')
    samples = samples[samples['__raw.task'] == '金融文本分类']
    pattern = r"'([^']*)'"  # 提取单引号之间字符串
    puncs = [',', ',', '、']
    acc_list = []
    for _, row in samples.iterrows():
        label = row['__raw.output']
        if row['__raw.sub_task'] == 'ESG情感分类':
            if label[0] in row['response']:  # '正', '中', '负'
                acc = 1
            elif label == '负向' and row['response'] == 'Negative' or label == '正向' and row['response'] == 'Positive':
                acc = 1
            else:
                acc = 0
        elif row['__raw.sub_task'] == '合规政策审核':  # for baichuan2
            if any(s in row['response'] for s in ['不合规', '不符合']):
                if row['__raw.output'] == '否':
                    acc = 1
                else:
                    acc = 0
            else:
                if row['__raw.output'] == '是':
                    acc = 1
                else:
                    acc = 0
        else:
            pred = row['response'].strip('\n')
            if re.findall(pattern, pred):  # "'物料'", " ['经济绩效', '非直接经济影响']"
                pred = re.findall(pattern, pred)
                if len(pred) == 1 and label in pred:
                    acc = 1
                else:
                    acc = 0
            elif any(p in pred for p in puncs):  # "反腐败行为, 非虚假营销, 依法合规纳税, 反不正当竞争, 安全管理实践, 能源, 市场占有率, 排放"
                acc = 0
            else:  # " 依法合规纳税"
                if label in pred:
                    acc = 1
                else:
                    acc = 0
        acc_list.append(acc)
    return np.mean(acc_list)