in utils/compute_score.py [0:0]
# Regex capturing the text between pairs of single quotes, e.g. "['经济绩效', '非直接经济影响']".
_QUOTED_PATTERN = r"'([^']*)'"
# Separators that mark an unquoted multi-label answer (always scored as wrong).
# The previous list contained the ASCII ',' twice; the full-width comma '，' was
# presumably intended as the second entry.
_MULTI_LABEL_PUNCS = (',', '，', '、')
# Substrings that mark a response as answering "non-compliant" in 合规政策审核.
_NON_COMPLIANT_MARKERS = ('不合规', '不符合')


def _score_esg_sentiment(label, response):
    """Score one ESG情感分类 sample: 1 if ``response`` matches ``label``, else 0.

    A response counts as correct if it contains the first character of the
    label (e.g. '正' / '中' / '负'), or if it is exactly 'Negative' for
    label '负向' or exactly 'Positive' for label '正向'.
    """
    # Guard against an empty label, which previously raised IndexError.
    if label and label[0] in response:
        return 1
    if (label, response) in (('负向', 'Negative'), ('正向', 'Positive')):
        return 1
    return 0


def _score_compliance(label, response):
    """Score one 合规政策审核 sample: 1 if the implied verdict equals ``label``.

    A response containing any non-compliance marker is read as '否';
    anything else is read as '是'. (Originally added for baichuan2 outputs.)
    """
    predicted = '否' if any(m in response for m in _NON_COMPLIANT_MARKERS) else '是'
    return int(label == predicted)


def _score_single_label(label, response):
    """Score a single-label classification sample: 1 only for exactly one matching label."""
    pred = response.strip('\n')
    quoted = re.findall(_QUOTED_PATTERN, pred)
    if quoted:
        # Quoted answers such as "'物料'" or " ['经济绩效', '非直接经济影响']":
        # correct only when exactly one quoted label is given and it is the label.
        return int(len(quoted) == 1 and label in quoted)
    if any(p in pred for p in _MULTI_LABEL_PUNCS):
        # Unquoted comma/enumeration-separated lists are multi-label answers.
        return 0
    # Plain answer such as " 依法合规纳税": substring match against the label.
    return int(label in pred)


def compute_text_classification(file_path):
    """Compute mean accuracy on the '金融文本分类' samples of a result file.

    The file is a UTF-8 JSON list of records, each holding an ``outputs`` list.
    Each output is flattened by ``pd.json_normalize``, and these fields are read:
    ``response``, ``__raw.task``, ``__raw.sub_task`` and ``__raw.output``
    (the gold label). Only rows whose task is '金融文本分类' are scored. The
    scoring rule depends on the sub-task:

    * 'ESG情感分类'  -> ``_score_esg_sentiment``
    * '合规政策审核' -> ``_score_compliance``
    * anything else -> ``_score_single_label``

    A missing (non-string) response is treated as an empty answer and scored 0.

    Args:
        file_path: Path to the JSON result file.

    Returns:
        Mean per-sample accuracy in [0, 1], or NaN when no sample matches the task.
    """
    with open(file_path, 'r', encoding='utf-8') as input_file:
        print("Reading\t" + file_path)
        sample_list = json.load(input_file)
    samples = pd.json_normalize(sample_list, 'outputs')
    samples = samples[samples['__raw.task'] == '金融文本分类']

    acc_list = []
    for _, row in samples.iterrows():
        label = row['__raw.output']
        response = row['response']
        # json_normalize fills absent responses with NaN; score those as wrong
        # instead of crashing on string operations.
        if not isinstance(response, str):
            response = ''
        sub_task = row['__raw.sub_task']
        if sub_task == 'ESG情感分类':
            acc = _score_esg_sentiment(label, response)
        elif sub_task == '合规政策审核':
            acc = _score_compliance(label, response)
        else:
            acc = _score_single_label(label, response)
        acc_list.append(acc)

    if not acc_list:
        # np.mean([]) would also give NaN, but with a RuntimeWarning.
        return float('nan')
    return np.mean(acc_list)