Unverified commit c8b06ad5, authored by myhloli, committed by GitHub

Merge branch 'master' into master

Parents: 88f5b932, 2783bb39
@@ -9,7 +9,7 @@ on:
     paths-ignore:
       - "cmds/**"
       - "**.md"
+  workflow_dispatch:
 jobs:
   pdf-test:
     runs-on: pdf
@@ -18,14 +18,16 @@ jobs:
       fail-fast: true
     steps:
+      - name: config-net
+        run: |
+          export http_proxy=http://bigdata_open_proxy:H89k5qwQRDYfz@10.140.90.20:10811
+          export https_proxy=http://bigdata_open_proxy:H89k5qwQRDYfz@10.140.90.20:10811
       - name: PDF benchmark
         uses: actions/checkout@v3
         with:
          fetch-depth: 2
       - name: check-requirements
         run: |
-          export http_proxy=http://bigdata_open_proxy:H89k5qwQRDYfz@10.140.90.20:10811
-          export https_proxy=http://bigdata_open_proxy:H89k5qwQRDYfz@10.140.90.20:10811
           changed_files=$(git diff --name-only -r HEAD~1 HEAD)
           echo $changed_files
           if [[ $changed_files =~ "requirements.txt" ]]; then
@@ -36,4 +38,12 @@ jobs:
       - name: benchmark
         run: |
           echo "start test"
-          cd tools && python ocr_badcase.py pdf_json_label_0306.json ocr_dataset.json json_files.zip output.json
+          cd tools && python ocr_badcase.py pdf_json_label_0306.json ocr_dataset.json json_files.zip badcase.json overall.json base_data.json
+  notify_to_feishu:
+    if: ${{ always() && !cancelled() && contains(needs.*.result, 'failure') && (github.ref_name == 'master') }}
+    needs: [pdf-test]
+    runs-on: [pdf]
+    steps:
+      - name: notify
+        run: |
+          curl -X POST -H "Content-Type: application/json" -d '{"msg_type":"post","content":{"post":{"zh_cn":{"title":"'${{ github.repository }}' GitHubAction Failed","content":[[{"tag":"text","text":""},{"tag":"a","text":"Please click here for details ","href":"https://github.com/'${{ github.repository }}'/actions/runs/'${GITHUB_RUN_ID}'"},{"tag":"at","user_id":"'${{ secrets.USER_ID }}'"}]]}}}}' ${{ secrets.WEBHOOK_URL }}
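The notify step posts a Feishu "post"-type message whose JSON is hard to read inline in the curl command. An equivalent construction in Python makes the schema clearer; repository, run_id, and user_id are placeholder values standing in for the workflow's expressions:

import json

repository = "owner/repo"       # ${{ github.repository }}
run_id = "1234567890"           # ${GITHUB_RUN_ID}
user_id = "<feishu-user-id>"    # ${{ secrets.USER_ID }}

payload = {
    "msg_type": "post",
    "content": {"post": {"zh_cn": {
        "title": f"{repository} GitHubAction Failed",
        "content": [[
            {"tag": "text", "text": ""},
            {"tag": "a", "text": "Please click here for details ",
             "href": f"https://github.com/{repository}/actions/runs/{run_id}"},
            {"tag": "at", "user_id": user_id},
        ]],
    }}},
}
print(json.dumps(payload, ensure_ascii=False))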
 # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
 # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

-name: PDF
+name: update-base
 on:
-  release:
-    types: [published]
+  push:
+    tags:
+      - '*released'
+  workflow_dispatch:
 jobs:
   pdf-test:
     runs-on: pdf
@@ -15,6 +16,7 @@ jobs:
     steps:
       - name: update-base
         uses: actions/checkout@v3
+      - name: start-update
         run: |
-          python update_base.py
+          echo "start test"
\ No newline at end of file
@@ -116,6 +116,7 @@ if __name__ == '__main__':
     pdf_path = r"/home/cxu/workspace/Magic-PDF/ocr_demo/j.1540-627x.2006.00176.x.pdf"
     json_file_path = r"/home/cxu/workspace/Magic-PDF/ocr_demo/j.1540-627x.2006.00176.x.json"
     # ocr_local_parse(pdf_path, json_file_path)
     book_name = "数学新星网/edu_00001236"
     ocr_online_parse(book_name)
...
 from abc import ABC, abstractmethod

 class AbsReaderWriter(ABC):
     """
     Abstract class that supports both binary and text reads and writes.
+    TODO
     """
+    MODE_TXT = "text"
+    MODE_BIN = "binary"
+
+    def __init__(self, parent_path):
+        # Additional initialization can go here if needed.
+        self.parent_path = parent_path  # For local storage this is the parent directory; for s3, writes land under this path.
+
+    @abstractmethod
+    def read(self, path: str, mode="text"):
+        """
+        For local and s3 paths alike: if path is absolute, do not join it with parent_path; if it is relative, join it with parent_path.
+        """
+        raise NotImplementedError
+
     @abstractmethod
-    def read(self, path: str):
-        pass
+    def write(self, content: str, path: str, mode=MODE_TXT):
+        """
+        For local and s3 paths alike: if path is absolute, do not join it with parent_path; if it is relative, join it with parent_path.
+        """
+        raise NotImplementedError

     @abstractmethod
-    def write(self, path: str, content: str):
-        pass
+    def read_jsonl(self, path: str, byte_start=0, byte_end=None, encoding='utf-8'):
+        """
+        For local and s3 paths alike: if path is absolute, do not join it with parent_path; if it is relative, join it with parent_path.
+        """
+        raise NotImplementedError
\ No newline at end of file
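The docstrings above all pin down the same contract: an absolute path is used as-is, a relative one is joined onto parent_path. A minimal sketch of that rule, assuming "s3://" marks an absolute remote path (resolve_path is a hypothetical helper, not part of the repository):

import os

def resolve_path(parent_path: str, path: str) -> str:
    # Absolute inputs pass through untouched: local absolute paths and
    # fully qualified s3 URIs alike.
    if os.path.isabs(path) or path.startswith("s3://"):
        return path
    # Relative inputs are joined under parent_path, which may be a local
    # directory or an s3 prefix.
    return os.path.join(parent_path, path)

print(resolve_path("/data", "book.json"))               # /data/book.json
print(resolve_path("/data", "/tmp/book.json"))          # /tmp/book.json
print(resolve_path("s3://bucket/prefix", "book.json"))  # s3://bucket/prefix/book.json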
import os
from magic_pdf.io.AbsReaderWriter import AbsReaderWriter
from loguru import logger

class DiskReaderWriter(AbsReaderWriter):
    def __init__(self, parent_path, encoding='utf-8'):
        self.path = parent_path
        self.encoding = encoding

    def read(self, mode="text"):
        if not os.path.exists(self.path):
            logger.error(f"File {self.path} does not exist")
            raise Exception(f"File {self.path} does not exist")
        if mode == "text":
            with open(self.path, 'r', encoding=self.encoding) as f:
                return f.read()
        elif mode == "binary":
            with open(self.path, 'rb') as f:
                return f.read()
        else:
            raise ValueError("Invalid mode. Use 'text' or 'binary'.")

    def write(self, data, mode="text"):
        if mode == "text":
            with open(self.path, 'w', encoding=self.encoding) as f:
                f.write(data)
            logger.info(f"Content written to {self.path}")
        elif mode == "binary":
            with open(self.path, 'wb') as f:
                f.write(data)
            logger.info(f"Content written to {self.path}")
        else:
            raise ValueError("Invalid mode. Use 'text' or 'binary'.")

# Usage example
if __name__ == "__main__":
    file_path = "example.txt"
    drw = DiskReaderWriter(file_path)
    # Write content to the file
    drw.write(b"Hello, World!", mode="binary")
    # Read the content back from the file
    content = drw.read()
    if content:
        logger.info(f"Content read from {file_path}: {content}")
-from magic_pdf.io import AbsReaderWriter
+from magic_pdf.io.AbsReaderWriter import AbsReaderWriter
+from magic_pdf.libs.commons import parse_aws_param, parse_bucket_key
+import boto3
+from loguru import logger
+from boto3.s3.transfer import TransferConfig
+from botocore.config import Config

-class DiskReaderWriter(AbsReaderWriter):
-    def __init__(self, parent_path, encoding='utf-8'):
-        self.path = parent_path
-        self.encoding = encoding
-    def read(self):
-        with open(self.path, 'rb') as f:
-            return f.read()
-    def write(self, data):
-        with open(self.path, 'wb') as f:
-            f.write(data)
\ No newline at end of file
+class S3ReaderWriter(AbsReaderWriter):
+    def __init__(self, ak: str, sk: str, endpoint_url: str, addressing_style: str):
+        self.client = self._get_client(ak, sk, endpoint_url, addressing_style)
+
+    def _get_client(self, ak: str, sk: str, endpoint_url: str, addressing_style: str):
+        s3_client = boto3.client(
+            service_name="s3",
+            aws_access_key_id=ak,
+            aws_secret_access_key=sk,
+            endpoint_url=endpoint_url,
+            config=Config(s3={"addressing_style": addressing_style},
+                          retries={'max_attempts': 5, 'mode': 'standard'}),
+        )
+        return s3_client
+
+    def read(self, s3_path, mode="text", encoding="utf-8"):
+        bucket_name, bucket_key = parse_bucket_key(s3_path)
+        res = self.client.get_object(Bucket=bucket_name, Key=bucket_key)
+        body = res["Body"].read()
+        if mode == 'text':
+            data = body.decode(encoding)  # Decode bytes to text
+        elif mode == 'binary':
+            data = body
+        else:
+            raise ValueError("Invalid mode. Use 'text' or 'binary'.")
+        return data
+
+    def write(self, data, s3_path, mode="text", encoding="utf-8"):
+        if mode == 'text':
+            body = data.encode(encoding)  # Encode text data as bytes
+        elif mode == 'binary':
+            body = data
+        else:
+            raise ValueError("Invalid mode. Use 'text' or 'binary'.")
+        bucket_name, bucket_key = parse_bucket_key(s3_path)
+        self.client.put_object(Body=body, Bucket=bucket_name, Key=bucket_key)
+        logger.info(f"Content written to {s3_path}")
+
+if __name__ == "__main__":
+    # Configure the connection info
+    ak = ""
+    sk = ""
+    endpoint_url = ""
+    addressing_style = ""
+    # Create an S3ReaderWriter object
+    s3_reader_writer = S3ReaderWriter(ak, sk, endpoint_url, addressing_style)
+    # Write text data to S3
+    text_data = "This is some text data"
+    s3_reader_writer.write(data=text_data, s3_path="s3://bucket_name/ebook/test/test.json", mode='text')
+    # Read text data from S3
+    text_data_read = s3_reader_writer.read(s3_path="s3://bucket_name/ebook/test/test.json", mode='text')
+    logger.info(f"Read text data from S3: {text_data_read}")
+    # Write binary data to S3
+    binary_data = b"This is some binary data"
+    s3_reader_writer.write(data=binary_data, s3_path="s3://bucket_name/ebook/test/test2.json", mode='binary')
+    # Read binary data from S3
+    binary_data_read = s3_reader_writer.read(s3_path="s3://bucket_name/ebook/test/test2.json", mode='binary')
+    logger.info(f"Read binary data from S3: {binary_data_read}")
\ No newline at end of file
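Note that S3ReaderWriter leaves the read_jsonl method declared by AbsReaderWriter unimplemented. One way to honor its byte_start/byte_end parameters against S3 is an HTTP Range read, which boto3's get_object accepts; a minimal sketch under that assumption (not the repository's implementation):

def read_jsonl(client, bucket, key, byte_start=0, byte_end=None, encoding='utf-8'):
    # Inclusive HTTP byte range; an open end ("bytes=N-") reads to the end of the object.
    if byte_end is None:
        byte_range = f"bytes={byte_start}-"
    else:
        byte_range = f"bytes={byte_start}-{byte_end}"
    res = client.get_object(Bucket=bucket, Key=key, Range=byte_range)
    text = res["Body"].read().decode(encoding)
    # One JSON record per line; drop the trailing empty string when the range ends on '\n'.
    return [line for line in text.split('\n') if line]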
@@ -183,11 +183,31 @@ def __valign_lines(blocks, layout_bboxes):

     return new_layout_bboxes

+def __align_text_in_layout(blocks, layout_bboxes):
+    """
+    OCR lines sometimes carry blank space at either end, so the text has to be realigned:
+    any part that overhangs the layout box is clipped at its left or right edge.
+    """
+    for layout in layout_bboxes:
+        lb = layout['layout_bbox']
+        blocks_in_layoutbox = [b for b in blocks if is_in_layout(b['bbox'], lb)]
+        if len(blocks_in_layoutbox) == 0:
+            continue
+        for block in blocks_in_layoutbox:
+            for line in block['lines']:
+                x0, x1 = line['bbox'][0], line['bbox'][2]
+                if x0 < lb[0]:
+                    line['bbox'][0] = lb[0]
+                if x1 > lb[2]:
+                    line['bbox'][2] = lb[2]
+
 def __common_pre_proc(blocks, layout_bboxes):
     """
     Language-independent preprocessing of the text.
     """
     #__add_line_period(blocks, layout_bboxes)
+    __align_text_in_layout(blocks, layout_bboxes)
     aligned_layout_bboxes = __valign_lines(blocks, layout_bboxes)

     return aligned_layout_bboxes
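The new __align_text_in_layout simply clamps each line's horizontal extent to its layout box. A toy check of that clipping, using the [x0, y0, x1, y1] bbox convention the code indexes into:

layout_bbox = [50, 0, 500, 800]   # [x0, y0, x1, y1] of the layout box
line_bbox = [30, 100, 520, 120]   # an OCR line that overhangs both edges

# The same two clamps the function applies per line:
line_bbox[0] = max(line_bbox[0], layout_bbox[0])  # clip on the left
line_bbox[2] = min(line_bbox[2], layout_bbox[2])  # clip on the right
print(line_bbox)  # [50, 100, 500, 120]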
@@ -233,7 +253,6 @@ def __split_para_in_layoutbox(lines_group, new_layout_bbox, lang="en", char_avg_
     layout_paras = []
     right_tail_distance = 1.5 * char_avg_len
     for lines in lines_group:
-
         paras = []
         total_lines = len(lines)
@@ -575,8 +594,8 @@ def __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang):

     return connected_layout_paras, page_list_info

 def para_split(pdf_info_dict, debug_mode, lang="en"):
     """
     Split text into paragraphs based on line and layout information.
...
{
"accuracy": 1.0,
"precision": 1.0,
"recall": 1.0,
"f1_score": 1.0,
"pdf间的平均编辑距离": 133.10256410256412,
"pdf间的平均bleu": 0.28838311595434046,
"分段准确率": 0.07220216606498195,
"行内公式准确率": {
"accuracy": 0.004835727492533068,
"precision": 0.008790072388831437,
"recall": 0.010634970284641852,
"f1_score": 0.009624911535739562
},
"行内公式编辑距离": 1.6176470588235294,
"行内公式bleu": 0.17154724654721457,
"行间公式准确率": {
"accuracy": 0.08490566037735849,
"precision": 0.1836734693877551,
"recall": 0.13636363636363635,
"f1_score": 0.1565217391304348
},
"行间公式编辑距离": 113.22222222222223,
"行间公式bleu": 0.2531053359913409,
"丢弃文本准确率": {
"accuracy": 0.00035398230088495576,
"precision": 0.0006389776357827476,
"recall": 0.0007930214115781126,
"f1_score": 0.0007077140835102619
},
"丢弃文本标签准确率": {
"color_background_header_txt_block": {
"precision": 0.0,
"recall": 0.0,
"f1-score": 0.0,
"support": 41.0
},
"header": {
"precision": 0.0,
"recall": 0.0,
"f1-score": 0.0,
"support": 4.0
},
"footnote": {
"precision": 1.0,
"recall": 0.009708737864077669,
"f1-score": 0.019230769230769232,
"support": 103.0
},
"on-table": {
"precision": 0.0,
"recall": 0.0,
"f1-score": 0.0,
"support": 665.0
},
"rotate": {
"precision": 0.0,
"recall": 0.0,
"f1-score": 0.0,
"support": 63.0
},
"on-image": {
"precision": 0.0,
"recall": 0.0,
"f1-score": 0.0,
"support": 380.0
},
"micro avg": {
"precision": 1.0,
"recall": 0.0007961783439490446,
"f1-score": 0.0015910898965791568,
"support": 1256.0
}
},
"丢弃图片准确率": {
"accuracy": 0.0,
"precision": 0.0,
"recall": 0.0,
"f1_score": 0.0
},
"丢弃表格准确率": {
"accuracy": 0.0,
"precision": 0.0,
"recall": 0.0,
"f1_score": 0.0
}
}
\ No newline at end of file
@@ -413,7 +413,9 @@ def bbox_match_indicator_dropped_text_block(test_dropped_text_bboxs, standard_dr

     # Compute and return the tag-match metrics
     text_block_tag_report = classification_report(y_true=standard_tag, y_pred=test_tag, labels=list(set(standard_tag) - {'None'}), output_dict=True, zero_division=0)
+    del text_block_tag_report["macro avg"]
+    del text_block_tag_report["weighted avg"]

     return text_block_report, text_block_tag_report

 def handle_multi_deletion(test_page, test_page_tag, test_page_bbox, standard_page_tag, standard_page_bbox):
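With output_dict=True, scikit-learn's classification_report returns per-label rows plus aggregate "macro avg" and "weighted avg" rows (and "micro avg" when an explicit label list is passed, as above); the two del statements strip the macro/weighted aggregates so only per-label entries survive into the saved report. A small illustration with made-up labels:

from sklearn.metrics import classification_report

report = classification_report(
    y_true=["header", "footnote", "footnote"],
    y_pred=["header", "footnote", "header"],
    output_dict=True,
    zero_division=0,
)
# Drop the aggregate rows, keeping only per-label metrics.
del report["macro avg"]
del report["weighted avg"]
print(sorted(report))  # ['accuracy', 'footnote', 'header']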
@@ -500,6 +502,142 @@ def merge_json_data(json_test_df, json_standard_df):

     return inner_merge, standard_exist, test_exist

+def consolidate_data(test_data, standard_data, key_path):
+    """
+    Consolidates data from the test and standard datasets based on the provided key path.
+
+    :param test_data: Dictionary containing the test dataset.
+    :param standard_data: Dictionary containing the standard dataset.
+    :param key_path: List of keys leading to the desired data within the dictionaries.
+    :return: Two lists holding all items found at the specified key path in the standard
+             and test data, respectively.
+    """
+    # Initialize empty lists to hold the consolidated data
+    overall_data_standard = []
+    overall_data_test = []
+
+    # Helper function to navigate through a dictionary along the key path
+    def extract_data(source_data, keys):
+        for key in keys[:-1]:
+            source_data = source_data.get(key, {})
+        return source_data.get(keys[-1], [])
+
+    # Each extracted entry is assumed to already be a list (one per page),
+    # so its elements are appended to the overall list directly.
+    for data in extract_data(standard_data, key_path):
+        overall_data_standard.extend(data)
+
+    for data in extract_data(test_data, key_path):
+        overall_data_test.extend(data)
+
+    return overall_data_standard, overall_data_test
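consolidate_data walks both result dictionaries down the same key path and flattens the per-page lists into a single overall list per side. A toy call with made-up page data:

standard = {"equations_bboxs": {"inline": [[[0, 0, 10, 10]], [[5, 5, 15, 15]]]}}
test = {"equations_bboxs": {"inline": [[[0, 0, 9, 10]], []]}}

std_all, test_all = consolidate_data(test, standard, ["equations_bboxs", "inline"])
print(std_all)   # [[0, 0, 10, 10], [5, 5, 15, 15]] -- pages flattened together
print(test_all)  # [[0, 0, 9, 10]]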
+def overall_calculate_metrics(inner_merge, json_test, json_standard, standard_exist, test_exist):
+    process_data_standard = process_equations_and_blocks(json_standard, is_standard=True)
+    process_data_test = process_equations_and_blocks(json_test, is_standard=False)
+
+    overall_report = {}
+    overall_report['accuracy'] = metrics.accuracy_score(standard_exist, test_exist)
+    overall_report['precision'] = metrics.precision_score(standard_exist, test_exist)
+    overall_report['recall'] = metrics.recall_score(standard_exist, test_exist)
+    overall_report['f1_score'] = metrics.f1_score(standard_exist, test_exist)
+
+    test_para_text = np.asarray(process_data_test['para_texts'], dtype=object)[inner_merge['pass_label'] == 'yes']
+    standard_para_text = np.asarray(process_data_standard['para_texts'], dtype=object)[inner_merge['pass_label'] == 'yes']
+    ids_yes = inner_merge['id'][inner_merge['pass_label'] == 'yes'].tolist()
+
+    pdf_dis = {}
+    pdf_bleu = {}
+
+    # Compute edit distance and BLEU for the entries whose pass_label is 'yes'
+    for idx, (a, b, id) in enumerate(zip(test_para_text, standard_para_text, ids_yes)):
+        a1 = ''.join(a)
+        b1 = ''.join(b)
+        pdf_dis[id] = Levenshtein_Distance(a, b)
+        pdf_bleu[id] = sentence_bleu([a1], b1)
+    overall_report['pdf间的平均编辑距离'] = np.mean(list(pdf_dis.values()))
+    overall_report['pdf间的平均bleu'] = np.mean(list(pdf_bleu.values()))
+
+    # Consolidate inline equation bboxes
+    overall_equations_bboxs_inline_standard, overall_equations_bboxs_inline_test = consolidate_data(process_data_test, process_data_standard, ["equations_bboxs", "inline"])
+    # Consolidate inline equation texts
+    overall_equations_texts_inline_standard, overall_equations_texts_inline_test = consolidate_data(process_data_test, process_data_standard, ["equations_texts", "inline"])
+    # Consolidate interline equation bboxes
+    overall_equations_bboxs_interline_standard, overall_equations_bboxs_interline_test = consolidate_data(process_data_test, process_data_standard, ["equations_bboxs", "interline"])
+    # Consolidate interline equation texts
+    overall_equations_texts_interline_standard, overall_equations_texts_interline_test = consolidate_data(process_data_test, process_data_standard, ["equations_texts", "interline"])
+
+    overall_dropped_bboxs_text_standard, overall_dropped_bboxs_text_test = consolidate_data(process_data_test, process_data_standard, ["dropped_bboxs", "text"])
+    overall_dropped_tags_text_standard, overall_dropped_tags_text_test = consolidate_data(process_data_test, process_data_standard, ["dropped_tags", "text"])
+    overall_dropped_bboxs_image_standard, overall_dropped_bboxs_image_test = consolidate_data(process_data_test, process_data_standard, ["dropped_bboxs", "image"])
+    overall_dropped_bboxs_table_standard, overall_dropped_bboxs_table_test = consolidate_data(process_data_test, process_data_standard, ["dropped_bboxs", "table"])
+
+    para_nums_test = process_data_test['para_nums']
+    para_nums_standard = process_data_standard['para_nums']
+    overall_para_nums_standard = [item for sublist in para_nums_standard for item in (sublist if isinstance(sublist, list) else [sublist])]
+    overall_para_nums_test = [item for sublist in para_nums_test for item in (sublist if isinstance(sublist, list) else [sublist])]
+
+    test_para_num = np.array(overall_para_nums_test)
+    standard_para_num = np.array(overall_para_nums_standard)
+    acc_para = np.mean(test_para_num == standard_para_num)
+    overall_report['分段准确率'] = acc_para
+
+    # Inline equations: accuracy, edit distance, BLEU
+    overall_report['行内公式准确率'] = bbox_match_indicator_general(
+        overall_equations_bboxs_inline_test,
+        overall_equations_bboxs_inline_standard)
+    overall_report['行内公式编辑距离'], overall_report['行内公式bleu'] = equations_indicator(
+        overall_equations_bboxs_inline_test,
+        overall_equations_bboxs_inline_standard,
+        overall_equations_texts_inline_test,
+        overall_equations_texts_inline_standard)
+
+    # Interline equations: accuracy, edit distance, BLEU
+    overall_report['行间公式准确率'] = bbox_match_indicator_general(
+        overall_equations_bboxs_interline_test,
+        overall_equations_bboxs_interline_standard)
+    overall_report['行间公式编辑距离'], overall_report['行间公式bleu'] = equations_indicator(
+        overall_equations_bboxs_interline_test,
+        overall_equations_bboxs_interline_standard,
+        overall_equations_texts_interline_test,
+        overall_equations_texts_interline_standard)
+
+    # Dropped-text accuracy and dropped-text tag accuracy
+    overall_report['丢弃文本准确率'], overall_report['丢弃文本标签准确率'] = bbox_match_indicator_dropped_text_block(
+        overall_dropped_bboxs_text_test,
+        overall_dropped_bboxs_text_standard,
+        overall_dropped_tags_text_standard,
+        overall_dropped_tags_text_test)
+
+    # Dropped-image accuracy
+    overall_report['丢弃图片准确率'] = bbox_match_indicator_general(
+        overall_dropped_bboxs_image_test,
+        overall_dropped_bboxs_image_standard)
+
+    # Dropped-table accuracy
+    overall_report['丢弃表格准确率'] = bbox_match_indicator_general(
+        overall_dropped_bboxs_table_test,
+        overall_dropped_bboxs_table_standard)
+
+    return overall_report
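Most of the heavy lifting above is delegated to bbox_match_indicator_general and equations_indicator; the paragraph-count accuracy ('分段准确率') is the one metric computed inline, as the element-wise match rate between the flattened per-page paragraph counts. A toy check with made-up counts:

import numpy as np

test_para_num = np.array([3, 5, 2, 4])      # paragraphs per page, test side
standard_para_num = np.array([3, 4, 2, 4])  # paragraphs per page, standard side
acc_para = np.mean(test_para_num == standard_para_num)
print(acc_para)  # 0.75 -- three of four pages agree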
 def calculate_metrics(inner_merge, json_test, json_standard, json_standard_origin):
@@ -602,21 +740,27 @@ def calculate_metrics(inner_merge, json_test, json_standard, json_standard_origi

     return result_dict
-def save_results(result_dict, output_path):
+def save_results(result_dict, overall_report_dict, badcase_path, overall_path):
     """
     Save the result dictionaries as JSON files to the given paths.

     Parameters:
     - result_dict: dictionary holding the computed results.
-    - output_path: path for the result file, including the file name.
+    - overall_report_dict: dictionary holding the overall metrics.
+    - badcase_path: path for the badcase report file, including the file name.
+    - overall_path: path for the overall report file, including the file name.
     """
     # Open the target file for writing
-    with open(output_path, 'w', encoding='utf-8') as f:
+    with open(badcase_path, 'w', encoding='utf-8') as f:
         # Serialize the result dictionary to JSON and write it out
         json.dump(result_dict, f, ensure_ascii=False, indent=4)
-    print(f"Results saved to: {output_path}")
+    print(f"Results saved to: {badcase_path}")
+
+    with open(overall_path, 'w', encoding='utf-8') as f:
+        # Serialize the overall report dictionary to JSON and write it out
+        json.dump(overall_report_dict, f, ensure_ascii=False, indent=4)
+    print(f"Results saved to: {overall_path}")
 def upload_to_s3(file_path, bucket_name, s3_file_name, AWS_ACCESS_KEY, AWS_SECRET_KEY, END_POINT_URL):
     """
@@ -634,7 +778,7 @@ def upload_to_s3(file_path, bucket_name, s3_file_name,AWS_ACCESS_KEY,AWS_SECRET_
     except ClientError as e:
         print(f"An error occurred while uploading the file: {e}")

-def generate_output_filename(base_path):
+def generate_filename(badcase_path, overall_path):
     """
     Generate output file names carrying the current timestamp.
@@ -647,13 +791,24 @@ def generate_output_filename(base_path):
     # Get the current time formatted as a string
     current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
     # Build and return the full output file names
-    return f"{base_path}_{current_time}.json"
+    return f"{badcase_path}_{current_time}.json", f"{overall_path}_{current_time}.json"
+def compare_edit_distance(json_file, overall_report):
+    with open(json_file, 'r', encoding='utf-8') as f:
+        json_data = json.load(f)
+    json_edit_distance = json_data['pdf间的平均编辑距离']
+
+    if overall_report['pdf间的平均编辑距离'] >= json_edit_distance:
+        return 0
+    else:
+        return 1
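compare_edit_distance acts as the regression gate: it returns 1 only when the new report's mean edit distance is strictly below the baseline stored in base_data.json, and main asserts on that value so the CI job fails otherwise. A toy run with made-up numbers and a temporary baseline file:

import json, tempfile

baseline = {"pdf间的平均编辑距离": 133.1}
with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False, encoding='utf-8') as f:
    json.dump(baseline, f, ensure_ascii=False)
    base_path = f.name

improved_report = {"pdf间的平均编辑距离": 120.0}
regressed_report = {"pdf间的平均编辑距离": 140.0}

print(compare_edit_distance(base_path, improved_report))   # 1 -- better, gate passes
print(compare_edit_distance(base_path, regressed_report))  # 0 -- worse, gate fails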
-def main(standard_file, test_file, zip_file, base_output_path, s3_bucket_name=None, s3_file_name=None, AWS_ACCESS_KEY=None, AWS_SECRET_KEY=None, END_POINT_URL=None):
+def main(standard_file, test_file, zip_file, badcase_path, overall_path, base_data_path, s3_bucket_name=None, s3_file_name=None, AWS_ACCESS_KEY=None, AWS_SECRET_KEY=None, END_POINT_URL=None):
     """
     Main function that runs the whole evaluation pipeline.
@@ -661,7 +816,8 @@ def main(standard_file, test_file, zip_file, base_output_path, s3_bucket_name=No
     - standard_file: path to the standard (ground-truth) file.
     - test_file: path to the test file.
     - zip_file: path to the zip archive.
-    - base_output_path: base path and file-name prefix for the result files.
+    - badcase_path: base path and file-name prefix for the badcase file.
+    - overall_path: base path and file-name prefix for the overall file.
     - s3_bucket_name: S3 bucket name (optional).
     - s3_file_name: file name on S3 (optional).
     - AWS_ACCESS_KEY, AWS_SECRET_KEY, END_POINT_URL: AWS credentials and endpoint URL (optional).
@@ -675,21 +831,29 @@ def main(standard_file, test_file, zip_file, base_output_path, s3_bucket_name=No
     # Merge the JSON data
     inner_merge, standard_exist, test_exist = merge_json_data(json_test_origin, json_standard_origin)

+    # Compute the overall metrics
+    overall_report_dict = overall_calculate_metrics(inner_merge, inner_merge['test_mid_json'], inner_merge['standard_mid_json'], standard_exist, test_exist)
+
     # Compute the metrics
     result_dict = calculate_metrics(inner_merge, inner_merge['test_mid_json'], inner_merge['standard_mid_json'], json_standard_origin)

     # Generate timestamped output file names
-    output_file = generate_output_filename(base_output_path)
+    badcase_file, overall_file = generate_filename(badcase_path, overall_path)

     # Save the results to JSON files
-    save_results(result_dict, output_file)
+    save_results(result_dict, overall_report_dict, badcase_file, overall_file)
+
+    result = compare_edit_distance(base_data_path, overall_report_dict)
+    print(result)
+    assert result == 1
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Main entry point that runs the whole evaluation pipeline.")
     parser.add_argument('standard_file', type=str, help='Path to the standard (ground-truth) file.')
     parser.add_argument('test_file', type=str, help='Path to the test file.')
     parser.add_argument('zip_file', type=str, help='Path to the zip archive.')
-    parser.add_argument('base_output_path', type=str, help='Base path and file-name prefix for the result files.')
+    parser.add_argument('badcase_path', type=str, help='Base path and file-name prefix for the badcase file.')
+    parser.add_argument('overall_path', type=str, help='Base path and file-name prefix for the overall file.')
+    parser.add_argument('base_data_path', type=str, help='Base path and file-name prefix for the baseline file.')
     parser.add_argument('--s3_bucket_name', type=str, help='S3 bucket name.', default=None)
     parser.add_argument('--s3_file_name', type=str, help='File name on S3.', default=None)
     parser.add_argument('--AWS_ACCESS_KEY', type=str, help='AWS access key.', default=None)
@@ -698,5 +862,5 @@ if __name__ == "__main__":

     args = parser.parse_args()

-    main(args.standard_file, args.test_file, args.zip_file, args.base_output_path, args.s3_bucket_name, args.s3_file_name, args.AWS_ACCESS_KEY, args.AWS_SECRET_KEY, args.END_POINT_URL)
+    main(args.standard_file, args.test_file, args.zip_file, args.badcase_path, args.overall_path, args.base_data_path, args.s3_bucket_name, args.s3_file_name, args.AWS_ACCESS_KEY, args.AWS_SECRET_KEY, args.END_POINT_URL)