utils/compute_score.py (761 lines of code) (raw):

"""Scoring utilities for financial-domain LLM evaluation outputs.

Covers QA / text generation (ROUGE, BERTScore), translation (BLEU, COMET),
text classification (accuracy) and information extraction (F1).
"""

import json
import os
import re
import sys

import jieba
import nltk
import numpy as np
import pandas as pd
from bert_score import score
from comet import load_from_checkpoint
from rouge import Rouge
from sacrebleu.metrics import BLEU
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score as f1

# Raise the recursion limit (here: one hundred thousand).
# NOTE(review): presumably needed by a recursive routine in one of the metric
# libraries on long texts -- confirm which one.
sys.setrecursionlimit(100000)

# Full-width punctuation -> ASCII.  Written with escapes so the mapping cannot
# silently degrade into a no-op if the file is ever re-encoded/normalised.
_FULLWIDTH_TO_ASCII = {'\uff1a': ':', '\uff0c': ',', '\uff1b': ';', '。': '.'}
_PUNCT_TABLE = str.maketrans(_FULLWIDTH_TO_ASCII)
_PUNCT_AND_SPACE_TABLE = str.maketrans({**_FULLWIDTH_TO_ASCII, ' ': None})

# "<name>: <sentiment word>" pairs; the captured sentiment is its first char.
_SENTIMENT_PAIR_RE = re.compile(r'(\w+)[:\s]+([正|负|中|无])\w+')
# Same, but the separator between name and sentiment is optional.
_SENTIMENT_PAIR_LOOSE_RE = re.compile(r'(\w+)[:\s]*([正|负|中|无])\w+')
# Same, restricted to CJK names / sentiment words.
_CJK_SENTIMENT_PAIR_RE = re.compile(
    r'([\u4e00-\u9fa5]+)[:\s]+([正|负|中|无])[\u4e00-\u9fa5]+')

# Causality-extraction slots: prompt-side name -> key in the gold label dict.
_CAUSALITY_ENTITIES = {
    '原因类型': 'reason_type',
    '原因产品': 'reason_product',
    '原因地区': 'reason_region',
    '原因行业': 'reason_industry',
    '结果类型': 'result_type',
    '结果产品': 'result_product',
    '结果地区': 'result_region',
    '结果行业': 'result_industry',
}
_NOT_MENTIONED_RE = re.compile(r'未提及|无|无结果|无明确|None')
# Characters removed from an extracted slot value before comparison.
_CAUSALITY_PUNCT = frozenset(',\uff0c。.、;\uff1b:\uff1a ')


# --------------------------------------------------------------------------
# Generic metrics
# --------------------------------------------------------------------------

def tokenizer_en(sent):
    """Tokenize an English sentence with NLTK's word tokenizer."""
    return nltk.word_tokenize(sent)


def acc_score(y_true, y_pred):
    """Print and return the accuracy of *y_pred* against *y_true*."""
    res = accuracy_score(y_true, y_pred)
    print("ACC:", res)
    return res


def f1_score(y_true, y_pred):
    """Print and return the weighted F1 of *y_pred* against *y_true*."""
    res = f1(y_true, y_pred, average='weighted')
    print("Weighted F1:", res)
    return res


def bleu_score(reference, candidate):
    """Print and return ``(BLEU-1 precision, corpus BLEU score)``.

    Both arguments are lists of pre-tokenized, space-joined sentences.
    """
    result = BLEU().corpus_score(candidate, [reference])
    bleu_1 = result.precisions[0]
    bleu_4 = result.score
    print("BLEU-1:", bleu_1)
    print("BLEU-4:", bleu_4)
    return bleu_1, bleu_4


def bert_score(reference, candidate):
    """Print and return the mean BERTScore F1 (``bert-base-chinese``)."""
    _, _, f_measure = score(candidate, reference, model_type="bert-base-chinese",
                            lang="zh", verbose=True)
    mean_f = f_measure.mean()
    print('Bert Score: %s' % mean_f)
    return mean_f


def rouge_score(reference, candidate):
    """Print ROUGE-1/2/L F-scores (x100) and return them unscaled."""
    # One call yields all three variants; no need to score the corpus 3 times.
    scores = Rouge().get_scores(candidate, reference, avg=True)
    rouge_1 = scores['rouge-1']['f']
    rouge_2 = scores['rouge-2']['f']
    rouge_l = scores['rouge-l']['f']
    print("Rouge-1:", rouge_1 * 100)
    print("Rouge-2:", rouge_2 * 100)
    print("Rouge-L:", rouge_l * 100)
    return rouge_1, rouge_2, rouge_l


# --------------------------------------------------------------------------
# Shared helpers
# --------------------------------------------------------------------------

def _load_outputs_frame(file_path):
    """Load a result JSON file and flatten its ``outputs`` records."""
    with open(file_path, 'r', encoding='utf-8') as input_file:
        print("Reading\t" + file_path)
        sample_list = json.load(input_file)
    return pd.json_normalize(sample_list, 'outputs')


def _normalize_punctuation(text, drop_spaces=True):
    """Map full-width ``: , ; 。`` to ASCII and drop a comma before a newline.

    When *drop_spaces* is true, plain spaces are removed as well.
    """
    table = _PUNCT_AND_SPACE_TABLE if drop_spaces else _PUNCT_TABLE
    return text.translate(table).replace(',\n', '\n')


def _sentiment_pairs(text, pattern=_SENTIMENT_PAIR_RE):
    """Return ``{name: sentiment_char}`` for every *pattern* match in *text*."""
    return {name: sentiment for name, sentiment in pattern.findall(text)}


def _pairs_from_items(items, sep):
    """Split every item on *sep* into a ``key -> value`` dict.

    Returns ``None`` as soon as an item does not split into exactly two parts.
    """
    pairs = {}
    for item in items:
        parts = item.split(sep)
        if len(parts) != 2:
            return None
        pairs[parts[0]] = parts[1]
    return pairs


# --------------------------------------------------------------------------
# Machine translation
# --------------------------------------------------------------------------

def _strip_translation_prompt(instruction, source_lang, target_lang):
    """Remove the fixed translation prompt and language tags from *instruction*.

    Accepts both ASCII and full-width comma/colon in the prompt.
    """
    prompt = (r'你是一个金融行业专家[,\uff0c]请将下面金融领域的%s内容翻译成准确、专业的%s。\n%s[:\uff1a]'
              % (source_lang, target_lang, source_lang))
    text = re.sub(prompt, '', instruction)
    text = re.sub(r'%s[:\uff1a]' % target_lang, '', text)
    return text.strip()


def _compute_nmt(file_path, model_path, sub_task, tokenize, source_lang, target_lang):
    """Score one translation sub-task; return ``(BLEU, COMET system score)``."""
    print(sub_task)
    # model_path is a local COMET checkpoint, e.g.
    # XCOMET-XL/checkpoints/model.ckpt (see comet's download_model).
    model = load_from_checkpoint(model_path)

    references, candidates, comet_data = [], [], []
    # NOTE(review): unlike the other scorers, this reads a flat list with
    # 'sub_task' / 'output' / 'predict' / 'instruction' keys rather than the
    # normalised 'outputs' records -- confirm the input schema.
    with open(file_path, 'r', encoding='utf-8') as input_file:
        print("Reading\t" + file_path)
        sample_list = json.load(input_file)

    for line in sample_list:
        if line['sub_task'] != sub_task:
            continue
        references.append(" ".join(tokenize(line['output'])))
        candidates.append(" ".join(tokenize(line['predict'])))
        # COMET input format: {"src": ..., "mt": ..., "ref": ...}
        comet_data.append({
            "src": _strip_translation_prompt(line['instruction'], source_lang, target_lang),
            "mt": line['predict'],
            "ref": line['output'],
        })

    _, bleu4 = bleu_score(references, candidates)
    model_output = model.predict(comet_data, batch_size=32, gpus=1)
    print(model_output.system_score)  # system-level score
    print('\n')
    return bleu4, model_output.system_score


def compute_nmt_zh2en(file_path, model_path):
    """Score Chinese->English financial translation: ``(BLEU, COMET)``."""
    return _compute_nmt(file_path, model_path, '金融中英翻译', tokenizer_en, '中文', '英文')


def compute_nmt_en2zh(file_path, model_path):
    """Score English->Chinese financial translation: ``(BLEU, COMET)``."""
    return _compute_nmt(file_path, model_path, '金融英中翻译', jieba.lcut, '英文', '中文')


# --------------------------------------------------------------------------
# Text generation / QA
# --------------------------------------------------------------------------

def _collect_generation_pairs(samples, task):
    """Tokenize references/responses of *task*, overall and per sub-task.

    Returns ``(references, candidates, sub_task_refs, sub_task_cands)``.
    """
    references, candidates = [], []
    sub_task_references, sub_task_candidates = {}, {}
    for _, line in samples.iterrows():
        if line['__raw.task'] != task:
            continue
        # An empty response is scored as the literal string 'None'.
        response = line['response'] or 'None'
        ref = " ".join(jieba.lcut(line['__raw.output']))
        cand = " ".join(jieba.lcut(response))
        references.append(ref)
        candidates.append(cand)
        sub_task = line['__raw.sub_task']
        sub_task_references.setdefault(sub_task, []).append(ref)
        sub_task_candidates.setdefault(sub_task, []).append(cand)
    return references, candidates, sub_task_references, sub_task_candidates


def compute_text_generation(file_path):
    """Score the text-generation task.

    Returns ``(rouge_l, {sub_task: rouge_l}, bert_f1, {sub_task: bert_f1})``.
    """
    samples = _load_outputs_frame(file_path)
    references, candidates, sub_refs, sub_cands = _collect_generation_pairs(
        samples, '金融文本生成')

    print(f"task: 金融文本生成")
    _, _, rougel = rouge_score(references, candidates)
    bert = bert_score(references, candidates)
    print('\n')

    sub_task_tg, sub_task_tg_bert = {}, {}
    for sub_task, refs in sub_refs.items():
        print(f"Sub-task: {sub_task}")
        cands = sub_cands[sub_task]
        _, _, sub_task_tg[sub_task] = rouge_score(refs, cands)
        sub_task_tg_bert[sub_task] = bert_score(refs, cands)
        print('\n')
    return rougel, sub_task_tg, bert, sub_task_tg_bert


def compute_finqa(file_path):
    """Score the financial-consulting QA task; return ``(rouge_l, bert_f1)``.

    Per-sub-task ROUGE is printed but not returned.
    """
    samples = _load_outputs_frame(file_path)
    references, candidates, sub_refs, sub_cands = _collect_generation_pairs(
        samples, '金融咨询')

    print(f"task: 金融咨询")
    _, _, rougel = rouge_score(references, candidates)
    bert = bert_score(references, candidates)
    print('\n')

    for sub_task, refs in sub_refs.items():
        print(f"Sub-task: {sub_task}")
        rouge_score(refs, sub_cands[sub_task])
        print('\n')
    return rougel, bert


# --------------------------------------------------------------------------
# Text classification
# --------------------------------------------------------------------------

def _classification_hit(row):
    """Return 1 if the response of *row* counts as correct, else 0."""
    label = row['__raw.output']
    response = row['response']
    sub_task = row['__raw.sub_task']

    if sub_task == 'ESG情感分类':
        # First label character is one of '正', '中', '负'.
        if label[0] in response:
            return 1
        english_match = ((label == '负向' and response == 'Negative')
                         or (label == '正向' and response == 'Positive'))
        return 1 if english_match else 0

    if sub_task == '合规政策审核':
        # for baichuan2: a response mentioning non-compliance means "否".
        flagged = any(s in response for s in ['不合规', '不符合'])
        expected = '否' if flagged else '是'
        return 1 if label == expected else 0

    pred = response.strip('\n')
    # Strings between single quotes, e.g. "'物料'" or "['经济绩效', '非直接经济影响']".
    quoted = re.findall(r"'([^']*)'", pred)
    if quoted:
        return 1 if len(quoted) == 1 and label in quoted else 0
    # A punctuation-separated list means several answers were given -> wrong.
    if any(p in pred for p in (',', '\uff0c', '、')):
        return 0
    return 1 if label in pred else 0


def compute_text_classification(file_path, return_sub_tasks=False):
    """Return the mean accuracy of the text-classification task.

    With ``return_sub_tasks=True`` the result is
    ``(mean_accuracy, {sub_task: mean_accuracy})`` instead.
    """
    samples = _load_outputs_frame(file_path)
    samples = samples[samples['__raw.task'] == '金融文本分类']
    acc_list = [_classification_hit(row) for _, row in samples.iterrows()]
    overall = np.mean(acc_list)
    if not return_sub_tasks:
        return overall
    per_sub_task = (pd.Series(acc_list, dtype=float)
                    .groupby(samples['__raw.sub_task'].to_numpy())
                    .mean()
                    .to_dict())
    return overall, per_sub_task


# --------------------------------------------------------------------------
# Extraction: parsing helpers
# --------------------------------------------------------------------------

def get_num_infos(target_str, target_char=None):
    """Count how often each separator in *target_char* occurs in *target_str*.

    Returns ``[{'sep': ..., 'num': ...}, ...]`` for separators that occur at
    least once, sorted by ascending count (rarest first).
    """
    if target_char is None:
        target_char = ['、', ':', ';', ',', '\n']
    num_infos = [{'sep': sep, 'num': target_str.count(sep)}
                 for sep in target_char if target_str.count(sep) > 0]
    return sorted(num_infos, key=lambda x: x['num'])


def pad_list(org_list, target_length, pad_val, left=False):
    """Pad *org_list* with *pad_val* up to *target_length* (left or right)."""
    padding = [pad_val] * (target_length - len(org_list))
    return padding + org_list if left else org_list + padding


def strip(target_str, strip_lst='\n.;'):
    """Strip the characters in *strip_lst*, then any surrounding whitespace."""
    new_str = target_str.strip(strip_lst)
    new_str = re.sub(r'^\s+', '', new_str)
    new_str = re.sub(r'\s+$', '', new_str)
    return new_str


def decode_re_content(content):
    """Turn a separator-delimited mapping string into a dict.

    Tries, in order: split on the rarest separator and then the second rarest;
    the same with the separator roles swapped; a CJK "name: sentiment" regex.
    Returns ``None`` when *content* contains no known separator.
    """
    num_infos = get_num_infos(content)
    if not num_infos:
        return None
    first_sep = num_infos[0]['sep']
    second_sep = num_infos[1]['sep'] if len(num_infos) > 1 else None

    items = [strip(x) for x in content.split(first_sep)]
    if second_sep is not None:
        pairs = _pairs_from_items(items, second_sep)
    else:
        # Single separator: expect exactly "key<sep>value" (a leading extra
        # fragment is dropped when the number of fragments is odd).
        if len(items) % 2 != 0:
            items = items[1:]
        pairs = {items[0]: items[1]} if len(items) == 2 else None
    if pairs is not None:
        return pairs

    # Fallback: use the second separator as the first one.
    if second_sep is not None:
        swapped = _pairs_from_items(
            [x for x in content.split(second_sep) if first_sep in x], first_sep)
        if swapped is not None:
            return swapped

    cleaned = '\n'.join(strip(x) for x in content.split('\n') if len(x) > 0)
    return _sentiment_pairs(cleaned, _CJK_SENTIMENT_PAIR_RE)


def extract_re_expr(class_info_list, pattern, default_val=None):
    """Return the first match of *pattern* in each string (or *default_val*)."""
    find_results = []
    for class_info in class_info_list:
        found = re.findall(pattern, class_info, re.DOTALL)
        find_results.append(found[0] if found else default_val)
    return find_results


def extract_re_exprs(class_info_list, patterns, default_val=None):
    """Apply *patterns* in turn until one matches every string.

    Returns the result of the first pattern without a ``None`` entry, else the
    result of the last pattern tried (all defaults if *patterns* is empty).
    """
    find_result = [default_val] * len(class_info_list)
    for pattern in patterns:
        find_result = extract_re_expr(class_info_list, pattern, default_val)
        if None not in find_result:
            break
    return find_result


# --------------------------------------------------------------------------
# Extraction: financial event extraction
# --------------------------------------------------------------------------

def financial_extract_result_process_single(content):
    """Parse one model response into a ``{role: argument}`` dict.

    Handles both the ``keyN: ... valueN: ...`` layout and plain
    ``role: argument`` lines.
    """
    content = _normalize_punctuation(content)
    result = {}
    if 'value' in content:
        for _, key, _, value in re.findall(
                r'(key\d):(.*)\s*(value\d?):?\s*(.*)', content):
            result[strip(key)] = strip(value)
    else:
        for key, value in re.findall(r'(.*):(.*)', content):
            # A stray 'key' placeholder means the layout is broken from here on.
            if 'key' in key:
                break
            result[strip(key)] = strip(value)
    return result


def financial_extract_result_process(sub_df):
    """Parse the ``predict`` column of *sub_df*; one dict per row."""
    return [financial_extract_result_process_single(predict.strip())
            for predict in sub_df['predict']]


def cal_f1(label, pred):
    """Return the F1 (0-100) between two ``{str: str}`` dicts.

    An item counts as correct only if both key and value match.
    """
    label_set = {'|'.join(x) for x in label.items()}
    pred_set = {'|'.join(x) for x in pred.items()}
    common = label_set & pred_set
    # The epsilon keeps the division defined for empty sets.
    precision = len(common) / (len(pred_set) + 1e-12)
    recall = len(common) / (len(label_set) + 1e-12)
    if precision + recall == 0:
        return 0
    return 2 * (precision * recall) / (precision + recall) * 100


def output_adj(x):
    """Convert a stringified label back into a Python object.

    Some samples (GPT data) store the label as a string.
    """
    # NOTE(review): eval() on dataset content executes arbitrary code; only
    # use with trusted files (ast.literal_eval would be the safe alternative).
    if isinstance(x, str):
        x = eval(x)
    return x


# --------------------------------------------------------------------------
# Extraction: industry sentiment
# --------------------------------------------------------------------------

def _parse_labelled_sections(ind_info, class_info):
    """Parse the ``行业:<industries>情感分类:<sentiments>`` layout."""
    ind_info = ind_info.strip('\n.;')
    class_info = class_info.strip('\n.;')
    class_seps = get_num_infos(class_info)
    ind_seps = get_num_infos(ind_info)

    if not class_seps:
        # One sentiment shared by all listed industries.
        if ind_seps:
            industries = ind_info.split(ind_seps[0]['sep'])
            return dict(zip(industries, [class_info] * len(industries)))
        return {ind_info: class_info}

    class_items = class_info.split(class_seps[0]['sep'])
    pairs = None
    if len(class_seps) > 1:
        # Two-level layout: each item is "industry:sentiment".
        pairs = _pairs_from_items(class_items, ':')
    elif ind_seps:
        # One-level layout: the i-th sentiment belongs to the i-th industry.
        industries = ind_info.split(ind_seps[0]['sep'])
        # Pull the sentiment out of "...为/是<x>" or "...(<x>)"; parentheses
        # may be ASCII or full-width.
        sentiments = extract_re_exprs(
            class_items, [r'.*[为是](.*)', r'.*[(\uff08](.*)[)\uff09]'])
        if None in sentiments:
            max_len = max(len(industries), len(class_items))
            sentiments = pad_list(class_items, max_len, '无')
        pairs = dict(zip(industries, sentiments))
    if pairs is None:
        pairs = _sentiment_pairs(class_info)
    return pairs


def _parse_extraction_result(content):
    """Parse responses that contain the marker ``抽取结果``."""
    match = re.match(r'抽取结果.*?:(.*)', content, re.DOTALL)
    if match is None:
        return _sentiment_pairs(content)
    body = strip(match.group(1))

    class_match = re.match(r'.*情感分类结果:(.*)', body, re.DOTALL)
    if class_match is None:
        seps = get_num_infos(body)
        if not seps:
            return {body: '无'}
        items = body.split(seps[0]['sep'])
        if len(seps) == 1:
            return dict(zip(items, ['无'] * len(items)))
        pairs = _pairs_from_items(items, seps[1]['sep'])
        return pairs if pairs is not None else _sentiment_pairs(body)

    class_info = class_match.group(1)
    # Always matches here because the pattern above already did.
    ind_info = re.match(r'(.*)情感分类结果', body, re.DOTALL).group(1).strip()
    # NOTE(review): the industry text is split on the separator chosen from
    # the sentiment text (as in the original) -- confirm this is intended.
    seps = get_num_infos(class_info)
    if seps:
        sep = seps[0]['sep']
        class_items = class_info.split(sep)
        ind_items = ind_info.split(sep)
    else:
        # No separator at all: treat both sides as a single entry.
        class_items = [class_info]
        ind_items = [ind_info]
    if len(class_items) != len(ind_items):
        max_len = max(len(ind_items), len(class_items))
        ind_items = pad_list(ind_items, max_len, '无')
        class_items = pad_list(class_items, max_len, '无')
    return dict(zip(ind_items, class_items))


def _parse_industry_sentiment(content):
    """Parse an already punctuation-normalised response into a dict."""
    match = re.match(r'行业:(.*)情感分类:(.*)', content, re.DOTALL)
    if match is not None:
        result = _parse_labelled_sections(*match.groups())
    elif '抽取结果' in content:
        result = _parse_extraction_result(content)
    else:
        # e.g. "预测结果如下:xxx:xxx,xxx:xxx"
        tail = [strip(x) for x in re.findall(r'^.*结果.*?:(.*)', content, re.DOTALL)]
        result = decode_re_content(tail[0] if tail else content)
    if not result:
        # Last resort: any "<name><sentiment word>" pair anywhere in the text.
        result = _sentiment_pairs(content, _SENTIMENT_PAIR_LOOSE_RE)
    return result


def industry_classification_result_process_single(content):
    """Parse one response into an ``{industry: sentiment text}`` dict."""
    return _parse_industry_sentiment(_normalize_punctuation(content))


def industry_classification_result_process(sub_df):
    """Parse the ``predict`` column of *sub_df* for industry sentiment.

    Prints a separator per row and echoes responses that yield nothing.
    """
    contents_adj = []
    for i, predict in enumerate(sub_df['predict']):
        print(f'=============={i}==============')
        content = _normalize_punctuation(predict.strip())
        class_info_dict = _parse_industry_sentiment(content)
        if len(class_info_dict) <= 0:
            print(content)
        contents_adj.append(class_info_dict)
    return contents_adj


def process_industry_classification_output(output_adj):
    """Reduce each sentiment text to its first 正/负/中/无 character.

    Entries without such a character are dropped, so labels and predictions
    can be compared exactly when computing F1.
    """
    new_output = {}
    for k, v in output_adj.items():
        find_res = re.findall(r'(正|负|中|无).*', v)
        if find_res:
            new_output[k] = find_res[0]
    return new_output


def industry_classification_result_process1(sub_df):
    """Alternative line-oriented parser for industry-sentiment responses.

    Keeps ``<name> <value>`` pairs whose value contains 正/负/中/无.
    """
    contents_adj = []
    for predict in sub_df['predict']:
        content = '\n'.join(strip(x) for x in predict.strip().split('\n') if len(x) > 0)
        # Spaces are kept here: they may be the name/value separator.
        content = _normalize_punctuation(content, drop_spaces=False)
        find_res1 = re.findall(r'(\S+)[:|\s]+(\S+)', content)
        find_res2 = re.findall(r'(\S+)\s+(\S+)', content)
        find_res = find_res1 if len(find_res1) > len(find_res2) else find_res2
        contents_adj.append(
            {x[0]: x[1] for x in find_res if re.search('正|负|中|无', x[1])})
    return contents_adj


# --------------------------------------------------------------------------
# Extraction: scoring
# --------------------------------------------------------------------------

def cal_financial_extract_score(data):
    """F1 (0-100) of one financial-event-extraction sample."""
    predicted = financial_extract_result_process_single(data['response'])
    expected = {x['role']: x['argument'] for x in output_adj(data['__raw.output'])}
    return cal_f1(expected, predicted)


def cal_industry_classification_score(data):
    """F1 (0-100) of one industry-sentiment-extraction sample."""
    predicted = process_industry_classification_output(
        industry_classification_result_process_single(data['response']))
    expected = process_industry_classification_output(output_adj(data['__raw.output']))
    return cal_f1(expected, predicted)


def _causality_f1(label, pred):
    """F1 of one causality-extraction sample over the eight fixed slots."""
    if isinstance(label, str):
        # NOTE(review): eval() on dataset content -- trusted files only.
        label = eval(label)
    answers = label[0]

    correct = found = 0
    for ent, key in _CAUSALITY_ENTITIES.items():
        # Text after the slot name up to the end of that line.
        match = re.search(f'{ent}(.*?)(?:\n|$)', pred)
        if match is None:
            continue  # slot missing: no credit for precision or recall
        found += 1
        extracted = ''.join(ch for ch in match.group(1) if ch not in _CAUSALITY_PUNCT)
        expected = answers[key]
        if _NOT_MENTIONED_RE.search(extracted):
            # Model says "not mentioned": right if the gold slot is empty
            # (or the gold text itself is contained in the answer).
            if expected == '' or expected in extracted:
                correct += 1
        elif extracted == expected:
            correct += 1

    total = len(_CAUSALITY_ENTITIES)
    acc = correct / total
    recall = found / total
    if acc + recall == 0:
        return 0
    return 2 * acc * recall / (acc + recall)


def compute_extraction(file_path):
    """Score the extraction task; return ``(mean_f1, {sub_task: mean_f1})``."""
    samples = _load_outputs_frame(file_path)
    # Copy so that adding the 'f1' column does not write into a slice.
    samples = samples[samples['__raw.task'] == '金融文本抽取'].copy()

    f1_list = []
    for _, row in samples.iterrows():
        label = row['__raw.output']
        pred = row['response']
        sub_task = row['__raw.sub_task']
        if sub_task == '金融事件主体抽取':
            acc = 1 if label in pred else 0
            row_f1 = 2 * acc * 1 / (acc + 1)  # F1 with recall fixed at 1
        elif sub_task == '金融事件因果关系抽取':
            row_f1 = _causality_f1(label, pred)
        elif sub_task == '金融事件抽取':
            row_f1 = cal_financial_extract_score(row) / 100
        elif sub_task == '行业情感信息抽取':
            row_f1 = cal_industry_classification_score(row) / 100
        else:
            row_f1 = 0
        f1_list.append(row_f1)

    samples['f1'] = f1_list
    per_sub_task = (samples[['__raw.sub_task', 'f1']]
                    .groupby('__raw.sub_task').mean().to_dict()['f1'])
    return samples['f1'].mean(), per_sub_task


# --------------------------------------------------------------------------
# Entry point
# --------------------------------------------------------------------------

def _print_mean_std(label, values):
    """Print mean and standard deviation of *values* under *label*."""
    if not values:
        print('%s: no scores collected' % label, '\n')
        return
    print('%s mean: %s' % (label, np.mean(values)), '\n')
    print('%s Std: %s' % (label, np.std(values)), '\n')


def main(model, path, comet_model_path=None):
    """Score every result file in *path* whose name contains *model*.

    Args:
        model: substring identifying the model's result files.
        path: directory holding the result JSON files.
        comet_model_path: COMET checkpoint used for translation scoring;
            translation metrics are skipped when it is not given.
    """
    qa_list, qa_bert_list = [], []
    tg_list, tg_bert_list = [], []
    sub_task_tg_list, sub_task_tg_bert_list = [], []
    e2z_bleu_list, e2z_comet_list = [], []
    z2e_bleu_list, z2e_comet_list = [], []
    acc_list, f1_list = [], []
    sub_task_tc_list, sub_task_re_list = [], []

    for f in os.listdir(path):
        if model not in f:
            continue
        print('Model: %s' % model)
        file_path = os.path.join(path, f)

        # QA (BERTScore values are cast to float for aggregation below)
        rouge_l, qa_bert = compute_finqa(file_path)
        qa_list.append(rouge_l)
        qa_bert_list.append(float(qa_bert))

        # TG
        rouge_l_tg, sub_task_tg, tg_bert, tg_sub_task_bert = compute_text_generation(file_path)
        tg_list.append(rouge_l_tg)
        sub_task_tg_list.append(sub_task_tg)
        tg_bert_list.append(float(tg_bert))
        sub_task_tg_bert_list.append({k: float(v) for k, v in tg_sub_task_bert.items()})

        # MT (needs a COMET checkpoint)
        if comet_model_path is None:
            print('No COMET checkpoint given; skipping translation metrics.')
        else:
            bleu, comet = compute_nmt_en2zh(file_path, comet_model_path)
            e2z_bleu_list.append(bleu)
            e2z_comet_list.append(comet)
            bleu, comet = compute_nmt_zh2en(file_path, comet_model_path)
            z2e_bleu_list.append(bleu)
            z2e_comet_list.append(comet)

        # TC
        acc, sub_task_acc = compute_text_classification(file_path, return_sub_tasks=True)
        acc_list.append(acc)
        sub_task_tc_list.append(sub_task_acc)

        # RE
        re_f1, sub_task_re = compute_extraction(file_path)
        f1_list.append(re_f1)
        sub_task_re_list.append(sub_task_re)

    # Overall scores per task
    _print_mean_std('QA', qa_list)
    _print_mean_std('QA bert', qa_bert_list)
    _print_mean_std('TG', tg_list)
    _print_mean_std('TG bert', tg_bert_list)
    _print_mean_std('EN2CH', e2z_bleu_list)
    _print_mean_std('EN2CH comet', e2z_comet_list)
    _print_mean_std('CH2EN', z2e_bleu_list)
    _print_mean_std('CH2EN comet', z2e_comet_list)
    _print_mean_std('TC', acc_list)
    _print_mean_std('RE', f1_list)

    # Scores per sub-task
    # TG
    tg_df = pd.DataFrame(sub_task_tg_list)
    print('TG sub task mean: %s' % tg_df.mean())
    print('TG sub task std: %s' % tg_df.std())
    tg_bert_df = pd.DataFrame(sub_task_tg_bert_list)
    print('TG sub task bert mean: %s' % tg_bert_df.mean())
    print('TG sub task bert std: %s' % tg_bert_df.std())
    # TC
    tc_df = pd.DataFrame(sub_task_tc_list)
    print('TC sub task mean: %s' % tc_df.mean())
    print('TC sub task std: %s' % tc_df.std())
    # RE
    re_df = pd.DataFrame(sub_task_re_list)
    print('RE sub task mean: %s' % re_df.mean())
    print('RE sub task std: %s' % re_df.std())