Commit fe58649b authored by liukaiwen's avatar liukaiwen

Merge branch 'master' of github.com:papayalove/Magic-PDF

parents d876cbe8 e9843e15
......@@ -5,7 +5,13 @@ name: PDF
on:
push:
branches:
- master
- "master"
paths-ignore:
- "cmds/**"
- "**.md"
pull_request:
branches:
- "master"
paths-ignore:
- "cmds/**"
- "**.md"
......@@ -34,16 +40,26 @@ jobs:
pip install -r requirements.txt
fi
- name: benchmark
- name: config-net-reset
run: |
export http_proxy=""
export https_proxy=""
- name: get-benchmark-result
run: |
echo "start test"
cd tools && python ocr_badcase.py pdf_json_label_0306.json ocr_dataset.json json_files.zip badcase.json overall.json base_data.json
cd tools && python text_badcase.py pdf_json_label_0306.json pdf_json_label_0229.json json_files.zip text_badcase text_overall base_data_text.json --s3_bucket_name llm-process-pperf --s3_file_directory qa-validate/pdf-datasets/badcase --AWS_ACCESS_KEY 7X9CWNHIVOHH3LXRD5WK --AWS_SECRET_KEY IHLyTsv7h4ArzReLWUGZNKvwqB7CMrRi6e7ZyUt0 --END_POINT_URL http://p-ceph-norm-inside.pjlab.org.cn:80
python ocr_badcase.py pdf_json_label_0306.json ocr_dataset.json json_files.zip ocr_badcase ocr_overall base_data_ocr.json --s3_bucket_name llm-process-pperf --s3_file_directory qa-validate/pdf-datasets/badcase --AWS_ACCESS_KEY 7X9CWNHIVOHH3LXRD5WK --AWS_SECRET_KEY IHLyTsv7h4ArzReLWUGZNKvwqB7CMrRi6e7ZyUt0 --END_POINT_URL http://p-ceph-norm-inside.pjlab.org.cn:80
notify_to_feishu:
if: ${{ always() && !cancelled() && contains(needs.*.result, 'failure') && (github.ref_name == 'master') }}
needs: [pdf-test]
runs-on: [pdf]
runs-on: pdf
steps:
- name: notify
run: |
curl -X POST -H "Content-Type: application/json" -d '{"msg_type":"post","content":{"post":{"zh_cn":{"title":"'${{ github.repository }}' GitHubAction Failed","content":[[{"tag":"text","text":""},{"tag":"a","text":"Please click here for details ","href":"https://github.com/'${{ github.repository }}'/actions/runs/'${GITHUB_RUN_ID}'"},{"tag":"at","user_id":"'${{ secrets.USER_ID }}'"}]]}}}}' ${{ secrets.WEBHOOK_URL }}
curl ${{ secrets.WEBHOOK_URL }} -H 'Content-Type: application/json' -d '{
"msgtype": "text",
"text": {
"content": "'${{ github.repository }}' GitHubAction Failed!\n 细节请查看:https://github.com/'${{ github.repository }}'/actions/runs/'${GITHUB_RUN_ID}'"
}
}'
......@@ -115,8 +115,9 @@ def ocr_parse_pdf_core(pdf_bytes, model_output_json_list, book_name, start_page_
if __name__ == '__main__':
pdf_path = r"/home/cxu/workspace/Magic-PDF/ocr_demo/j.1540-627x.2006.00176.x.pdf"
json_file_path = r"/home/cxu/workspace/Magic-PDF/ocr_demo/j.1540-627x.2006.00176.x.json"
ocr_local_parse(pdf_path, json_file_path)
# book_name = "科数网/edu_00011318"
# ocr_online_parse(book_name)
# ocr_local_parse(pdf_path, json_file_path)
book_name = "数学新星网/edu_00001236"
ocr_online_parse(book_name)
pass
......@@ -67,9 +67,7 @@ def demo_classify_by_type(book_name=None, debug_mode=True):
img_num_list = pdf_meta["imgs_per_page"]
text_len_list = pdf_meta["text_len_per_page"]
text_layout_list = pdf_meta["text_layout_per_page"]
pdf_path = json_object.get("file_location")
is_text_pdf, results = classify(
pdf_path,
total_page,
page_width,
page_height,
......@@ -89,7 +87,7 @@ def demo_meta_scan(book_name=None, debug_mode=True):
s3_pdf_path = json_object.get("file_location")
s3_config = get_s3_config_dict(s3_pdf_path)
pdf_bytes = read_file(s3_pdf_path, s3_config)
res = pdf_meta_scan(s3_pdf_path, pdf_bytes)
res = pdf_meta_scan(pdf_bytes)
logger.info(json.dumps(res, ensure_ascii=False))
write_json_to_local(res, book_name)
......
......@@ -21,28 +21,122 @@ python magicpdf.py --json s3://llm-pdf-text/scihub/xxxx.json?bytes=0,81350
python magicpdf.py --pdf /home/llm/Downloads/xxxx.pdf --model /home/llm/Downloads/xxxx.json 或者 python magicpdf.py --pdf /home/llm/Downloads/xxxx.pdf
"""
import click
from magic_pdf.libs.config_reader import get_s3_config
from magic_pdf.libs.path_utils import (
parse_s3path,
parse_s3_range_params,
remove_non_official_s3_args,
)
from magic_pdf.libs.config_reader import get_local_dir
from magic_pdf.io.S3ReaderWriter import S3ReaderWriter, MODE_BIN
from magic_pdf.io.DiskReaderWriter import DiskReaderWriter
from magic_pdf.spark.spark_api import parse_union_pdf, parse_txt_pdf, parse_ocr_pdf
import os
import json as json_parse
from datetime import datetime
parse_pdf_methods = click.Choice(["ocr", "txt", "auto"])
def get_pdf_parse_method(method):
if method == "ocr":
return parse_ocr_pdf
elif method == "txt":
return parse_txt_pdf
return parse_union_pdf
def prepare_env():
local_parent_dir = os.path.join(
get_local_dir(), "magic-pdf", datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
)
local_image_dir = os.path.join(local_parent_dir, "images")
local_md_dir = os.path.join(local_parent_dir, "md")
os.makedirs(local_image_dir, exist_ok=True)
os.makedirs(local_md_dir, exist_ok=True)
return local_image_dir, local_md_dir
import click
@click.group()
def cli():
pass
@cli.command()
@click.option('--json', type=str, help='输入一个S3路径')
def json_command(json):
# 这里处理json相关的逻辑
print(f'处理JSON: {json}')
@click.option("--json", type=str, help="输入一个S3路径")
@click.option(
"--method",
type=parse_pdf_methods,
help="指定解析方法。txt: 文本型 pdf 解析方法, ocr: 光学识别解析 pdf, auto: 程序智能选择解析方法",
default="auto",
)
def json_command(json, method):
if not json.startswith("s3://"):
print("usage: python magipdf.py --json s3://some_bucket/some_path")
os.exit(1)
def read_s3_path(s3path):
bucket, key = parse_s3path(s3path)
s3_ak, s3_sk, s3_endpoint = get_s3_config(bucket)
s3_rw = S3ReaderWriter(
s3_ak, s3_sk, s3_endpoint, "auto", remove_non_official_s3_args(s3path)
)
may_range_params = parse_s3_range_params(json)
if may_range_params is None or 2 != len(may_range_params):
byte_start, byte_end = 0, None
else:
byte_start, byte_end = int(may_range_params[0]), int(may_range_params[1])
return s3_rw.read_jsonl(
remove_non_official_s3_args(s3path), byte_start, byte_end, MODE_BIN
)
jso = json_parse.loads(read_s3_path(json).decode("utf-8"))
pdf_data = read_s3_path(jso["file_location"])
local_image_dir, _ = prepare_env()
local_image_rw = DiskReaderWriter(local_image_dir)
parse = get_pdf_parse_method(method)
parse(pdf_data, jso["doc_layout_result"], local_image_rw, is_debug=True)
@cli.command()
@click.option('--pdf', type=click.Path(exists=True), required=True, help='PDF文件的路径')
@click.option('--model', type=click.Path(exists=True), help='模型的路径')
def pdf_command(pdf, model):
@click.option(
"--pdf", type=click.Path(exists=True), required=True, help="PDF文件的路径"
)
@click.option("--model", type=click.Path(exists=True), help="模型的路径")
@click.option(
"--method",
type=parse_pdf_methods,
help="指定解析方法。txt: 文本型 pdf 解析方法, ocr: 光学识别解析 pdf, auto: 程序智能选择解析方法",
default="auto",
)
def pdf_command(pdf, model, method):
# 这里处理pdf和模型相关的逻辑
print(f'处理PDF: {pdf}')
print(f'加载模型: {model}')
if model is None:
model = pdf.replace(".pdf", ".json")
if not os.path.exists(model):
print(f"make sure json file existed and place under {os.dirname(pdf)}")
os.eixt(1)
def read_fn(path):
disk_rw = DiskReaderWriter(os.path.dirname(path))
return disk_rw.read(os.path.basename(path), MODE_BIN)
pdf_data = read_fn(pdf)
jso = json_parse.loads(read_fn(model).decode("utf-8"))
local_image_dir, _ = prepare_env()
local_image_rw = DiskReaderWriter(local_image_dir)
parse = get_pdf_parse_method(method)
parse(pdf_data, jso["doc_layout_result"], local_image_rw, is_debug=True)
if __name__ == '__main__':
if __name__ == "__main__":
"""
python magic_pdf/cli/magicpdf.py json-command --json s3://llm-pdf-text/pdf_ebook_and_paper/format/v070/part-66028dd46437-000076.jsonl?bytes=0,308393
"""
cli()
......@@ -2,6 +2,7 @@ import math
from loguru import logger
from magic_pdf.libs.boxbase import find_bottom_nearest_text_bbox, find_top_nearest_text_bbox
from magic_pdf.libs.commons import join_path
from magic_pdf.libs.ocr_content_type import ContentType
TYPE_INLINE_EQUATION = ContentType.InlineEquation
......@@ -227,7 +228,7 @@ def __insert_before_para(text, type, element, content_list):
logger.error(f"Can't find the location of image {element.get('image_path')} in the markdown file, search target is {text}")
def mk_universal_format(para_dict: dict):
def mk_universal_format(para_dict: dict, img_buket_path):
"""
构造统一格式 https://aicarrier.feishu.cn/wiki/FqmMwcH69iIdCWkkyjvcDwNUnTY
"""
......@@ -249,7 +250,7 @@ def mk_universal_format(para_dict: dict):
for img in all_page_images:
content_node = {
"type": "image",
"img_path": img['image_path'],
"img_path": join_path(img_buket_path, img['image_path']),
"img_alt":"",
"img_title":"",
"img_caption":""
......@@ -258,7 +259,7 @@ def mk_universal_format(para_dict: dict):
for table in all_page_tables:
content_node = {
"type": "table",
"img_path": table['image_path'],
"img_path": join_path(img_buket_path, table['image_path']),
"table_latex": table.get("text"),
"table_title": "",
"table_caption": "",
......
from magic_pdf.libs.commons import s3_image_save_path, join_path
from magic_pdf.libs.commons import join_path
from magic_pdf.libs.language import detect_lang
from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
from magic_pdf.libs.ocr_content_type import ContentType
......@@ -56,7 +56,7 @@ def ocr_mk_mm_markdown(pdf_info_dict: dict):
if not span.get('image_path'):
continue
else:
content = f"![]({join_path(s3_image_save_path, span['image_path'])})"
content = f"![]({span['image_path']})"
else:
content = ocr_escape_special_markdown_char(span['content']) # 转义特殊符号
if span['type'] == ContentType.InlineEquation:
......@@ -123,7 +123,7 @@ def ocr_mk_markdown_with_para_core(paras_of_layout, mode):
content = f"\n$$\n{span['content']}\n$$\n"
elif span_type in [ContentType.Image, ContentType.Table]:
if mode == 'mm':
content = f"\n![]({join_path(s3_image_save_path, span['image_path'])})\n"
content = f"\n![]({span['image_path']})\n"
elif mode == 'nlp':
pass
if content != '':
......@@ -138,10 +138,10 @@ def ocr_mk_markdown_with_para_core(paras_of_layout, mode):
return page_markdown
def para_to_standard_format(para):
def para_to_standard_format(para, img_buket_path):
para_content = {}
if len(para) == 1:
para_content = line_to_standard_format(para[0])
para_content = line_to_standard_format(para[0], img_buket_path)
elif len(para) > 1:
para_text = ''
inline_equation_num = 0
......@@ -171,7 +171,7 @@ def para_to_standard_format(para):
}
return para_content
def make_standard_format_with_para(pdf_info_dict: dict):
def make_standard_format_with_para(pdf_info_dict: dict, img_buket_path: str):
content_list = []
for _, page_info in pdf_info_dict.items():
paras_of_layout = page_info.get("para_blocks")
......@@ -179,12 +179,12 @@ def make_standard_format_with_para(pdf_info_dict: dict):
continue
for paras in paras_of_layout:
for para in paras:
para_content = para_to_standard_format(para)
para_content = para_to_standard_format(para, img_buket_path)
content_list.append(para_content)
return content_list
def line_to_standard_format(line):
def line_to_standard_format(line, img_buket_path):
line_text = ""
inline_equation_num = 0
for span in line['spans']:
......@@ -195,13 +195,13 @@ def line_to_standard_format(line):
if span['type'] == ContentType.Image:
content = {
'type': 'image',
'img_path': join_path(s3_image_save_path, span['image_path'])
'img_path': join_path(img_buket_path, span['image_path'])
}
return content
elif span['type'] == ContentType.Table:
content = {
'type': 'table',
'img_path': join_path(s3_image_save_path, span['image_path'])
'img_path': join_path(img_buket_path, span['image_path'])
}
return content
else:
......
......@@ -15,6 +15,7 @@ from collections import Counter
import click
import numpy as np
from loguru import logger
from magic_pdf.libs.commons import mymax, get_top_percent_list
from magic_pdf.filter.pdf_meta_scan import scan_max_page, junk_limit_min
......@@ -298,7 +299,7 @@ def classify_by_img_narrow_strips(page_width, page_height, img_sz_list):
return narrow_strip_pages_ratio < 0.5
def classify(pdf_path, total_page: int, page_width, page_height, img_sz_list: list, text_len_list: list, img_num_list: list, text_layout_list: list):
def classify(total_page: int, page_width, page_height, img_sz_list: list, text_len_list: list, img_num_list: list, text_layout_list: list):
"""
这里的图片和页面长度单位是pts
:param total_page:
......@@ -323,7 +324,7 @@ def classify(pdf_path, total_page: int, page_width, page_height, img_sz_list: li
elif not any(results.values()):
return False, results
else:
print(f"WARNING: {pdf_path} is not classified by area and text_len, by_image_area: {results['by_image_area']}, by_text: {results['by_text_len']}, by_avg_words: {results['by_avg_words']}, by_img_num: {results['by_img_num']}, by_text_layout: {results['by_text_layout']}, by_img_narrow_strips: {results['by_img_narrow_strips']}", file=sys.stderr) # 利用这种情况可以快速找出来哪些pdf比较特殊,针对性修正分类算法
logger.warning(f"pdf is not classified by area and text_len, by_image_area: {results['by_image_area']}, by_text: {results['by_text_len']}, by_avg_words: {results['by_avg_words']}, by_img_num: {results['by_img_num']}, by_text_layout: {results['by_text_layout']}, by_img_narrow_strips: {results['by_img_narrow_strips']}", file=sys.stderr) # 利用这种情况可以快速找出来哪些pdf比较特殊,针对性修正分类算法
return False, results
......@@ -350,7 +351,7 @@ def main(json_file):
is_needs_password = o['is_needs_password']
if is_encrypted or total_page == 0 or is_needs_password: # 加密的,需要密码的,没有页面的,都不处理
continue
tag = classify(pdf_path, total_page, page_width, page_height, img_sz_list, text_len_list, text_layout_list)
tag = classify(total_page, page_width, page_height, img_sz_list, text_len_list, text_layout_list)
o['is_text_pdf'] = tag
print(json.dumps(o, ensure_ascii=False))
except Exception as e:
......
......@@ -287,7 +287,7 @@ def get_language(doc: fitz.Document):
return language
def pdf_meta_scan(s3_pdf_path: str, pdf_bytes: bytes):
def pdf_meta_scan(pdf_bytes: bytes):
"""
:param s3_pdf_path:
:param pdf_bytes: pdf文件的二进制数据
......@@ -298,8 +298,8 @@ def pdf_meta_scan(s3_pdf_path: str, pdf_bytes: bytes):
is_encrypted = doc.is_encrypted
total_page = len(doc)
if total_page == 0:
logger.warning(f"drop this pdf: {s3_pdf_path}, drop_reason: {DropReason.EMPTY_PDF}")
result = {"need_drop": True, "drop_reason": DropReason.EMPTY_PDF}
logger.warning(f"drop this pdf, drop_reason: {DropReason.EMPTY_PDF}")
result = {"_need_drop": True, "_drop_reason": DropReason.EMPTY_PDF}
return result
else:
page_width_pts, page_height_pts = get_pdf_page_size_pts(doc)
......@@ -322,7 +322,6 @@ def pdf_meta_scan(s3_pdf_path: str, pdf_bytes: bytes):
# 最后输出一条json
res = {
"pdf_path": s3_pdf_path,
"is_needs_password": is_needs_password,
"is_encrypted": is_encrypted,
"total_page": total_page,
......@@ -350,7 +349,7 @@ def main(s3_pdf_path: str, s3_profile: str):
"""
try:
file_content = read_file(s3_pdf_path, s3_profile)
pdf_meta_scan(s3_pdf_path, file_content)
pdf_meta_scan(file_content)
except Exception as e:
print(f"ERROR: {s3_pdf_path}, {e}", file=sys.stderr)
logger.exception(e)
......
......@@ -5,9 +5,11 @@ from loguru import logger
MODE_TXT = "text"
MODE_BIN = "binary"
class DiskReaderWriter(AbsReaderWriter):
def __init__(self, parent_path, encoding='utf-8'):
def __init__(self, parent_path, encoding="utf-8"):
self.path = parent_path
self.encoding = encoding
......@@ -20,10 +22,10 @@ class DiskReaderWriter(AbsReaderWriter):
logger.error(f"文件 {abspath} 不存在")
raise Exception(f"文件 {abspath} 不存在")
if mode == MODE_TXT:
with open(abspath, 'r', encoding = self.encoding) as f:
with open(abspath, "r", encoding=self.encoding) as f:
return f.read()
elif mode == MODE_BIN:
with open(abspath, 'rb') as f:
with open(abspath, "rb") as f:
return f.read()
else:
raise ValueError("Invalid mode. Use 'text' or 'binary'.")
......@@ -37,20 +39,21 @@ class DiskReaderWriter(AbsReaderWriter):
directory_path = os.path.dirname(abspath)
os.makedirs(directory_path)
if mode == MODE_TXT:
with open(abspath, 'w', encoding=self.encoding) as f:
with open(abspath, "w", encoding=self.encoding) as f:
f.write(content)
logger.info(f"内容已成功写入 {abspath}")
elif mode == MODE_BIN:
with open(abspath, 'wb') as f:
with open(abspath, "wb") as f:
f.write(content)
logger.info(f"内容已成功写入 {abspath}")
else:
raise ValueError("Invalid mode. Use 'text' or 'binary'.")
def read_jsonl(self, path: str, byte_start=0, byte_end=None, encoding='utf-8'):
def read_jsonl(self, path: str, byte_start=0, byte_end=None, encoding="utf-8"):
return self.read(path)
# 使用示例
if __name__ == "__main__":
file_path = "io/test/example.txt"
......@@ -63,5 +66,3 @@ if __name__ == "__main__":
content = drw.read(path=file_path)
if content:
logger.info(f"从 {file_path} 读取的内容: {content}")
......@@ -24,7 +24,7 @@ error_log_path = "s3://llm-pdf-text/err_logs/"
# json_dump_path = "s3://pdf_books_temp/json_dump/" # 这条路径仅用于临时本地测试,不能提交到main
json_dump_path = "s3://llm-pdf-text/json_dump/"
s3_image_save_path = "s3://mllm-raw-media/pdf2md_img/" # TODO 基础库不应该有这些存在的路径,应该在业务代码中定义
# s3_image_save_path = "s3://mllm-raw-media/pdf2md_img/" # 基础库不应该有这些存在的路径,应该在业务代码中定义
def get_top_percent_list(num_list, percent):
......@@ -120,28 +120,8 @@ def read_file(pdf_path: str, s3_profile):
return f.read()
def get_docx_model_output(pdf_model_output, pdf_model_s3_profile, page_id):
if isinstance(pdf_model_output, str):
model_output_json_path = join_path(pdf_model_output, f"page_{page_id + 1}.json") # 模型输出的页面编号从1开始的
if os.path.exists(model_output_json_path):
json_from_docx = read_file(model_output_json_path, pdf_model_s3_profile)
model_output_json = json.loads(json_from_docx)
else:
try:
model_output_json_path = join_path(pdf_model_output, "model.json")
with open(model_output_json_path, "r", encoding="utf-8") as f:
model_output_json = json.load(f)
model_output_json = model_output_json["doc_layout_result"][page_id]
except:
s3_model_output_json_path = join_path(pdf_model_output, f"page_{page_id + 1}.json")
s3_model_output_json_path = join_path(pdf_model_output, f"{page_id}.json")
#s3_model_output_json_path = join_path(pdf_model_output, f"page_{page_id }.json")
# logger.warning(f"model_output_json_path: {model_output_json_path} not found. try to load from s3: {s3_model_output_json_path}")
s = read_file(s3_model_output_json_path, pdf_model_s3_profile)
return json.loads(s)
elif isinstance(pdf_model_output, list):
def get_docx_model_output(pdf_model_output, page_id):
model_output_json = pdf_model_output[page_id]
return model_output_json
......
......@@ -2,6 +2,7 @@
根据bucket的名字返回对应的s3 AK, SK,endpoint三元组
"""
import json
import os
......@@ -10,11 +11,7 @@ from loguru import logger
from magic_pdf.libs.commons import parse_bucket_key
def get_s3_config(bucket_name: str):
"""
~/magic-pdf.json 读出来
"""
def read_config():
home_dir = os.path.expanduser("~")
config_file = os.path.join(home_dir, "magic-pdf.json")
......@@ -24,6 +21,14 @@ def get_s3_config(bucket_name: str):
with open(config_file, "r") as f:
config = json.load(f)
return config
def get_s3_config(bucket_name: str):
"""
~/magic-pdf.json 读出来
"""
config = read_config()
bucket_info = config.get("bucket_info")
if bucket_name not in bucket_info:
......@@ -49,5 +54,10 @@ def get_bucket_name(path):
return bucket
if __name__ == '__main__':
def get_local_dir():
config = read_config()
return config.get("temp-output-dir", "/tmp")
if __name__ == "__main__":
ak, sk, endpoint = get_s3_config("llm-raw")
from collections import Counter
from magic_pdf.libs.language import detect_lang
def get_language_from_model(model_list: list):
language_lst = []
for ocr_page_info in model_list:
page_text = ""
layout_dets = ocr_page_info["layout_dets"]
for layout_det in layout_dets:
category_id = layout_det["category_id"]
allow_category_id_list = [15]
if category_id in allow_category_id_list:
page_text += layout_det["text"]
page_language = detect_lang(page_text)
language_lst.append(page_language)
# 统计text_language_list中每种语言的个数
count_dict = Counter(language_lst)
# 输出text_language_list中出现的次数最多的语言
language = max(count_dict, key=count_dict.get)
return language
......@@ -8,7 +8,7 @@ class DropReason:
HIGH_COMPUTATIONAL_lOAD_BY_SVGS = "high_computational_load_by_svgs" # 特殊的SVG图,计算量太大,从而丢弃
HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES = "high_computational_load_by_total_pages" # 计算量超过负荷,当前方法下计算量消耗过大
MISS_DOC_LAYOUT_RESULT = "missing doc_layout_result" # 版面分析失败
Exception = "exception" # 解析中发生异常
Exception = "_exception" # 解析中发生异常
ENCRYPTED = "encrypted" # PDF是加密的
EMPTY_PDF = "total_page=0" # PDF页面总数为0
NOT_IS_TEXT_PDF = "not_is_text_pdf" # 不是文字版PDF,无法直接解析
......
import hashlib
def compute_md5(file_bytes):
hasher = hashlib.md5()
hasher.update(file_bytes)
return hasher.hexdigest().upper()
def compute_sha256(input_string):
hasher = hashlib.sha256()
# 在Python3中,需要将字符串转化为字节对象才能被哈希函数处理
input_bytes = input_string.encode('utf-8')
hasher.update(input_bytes)
return hasher.hexdigest()
from s3pathlib import S3Path
def remove_non_official_s3_args(s3path):
"""
example: s3://abc/xxxx.json?bytes=0,81350 ==> s3://abc/xxxx.json
"""
arr = s3path.split("?")
return arr[0]
def parse_s3path(s3path: str):
p = S3Path(remove_non_official_s3_args(s3path))
return p.bucket, p.key
def parse_s3_range_params(s3path: str):
"""
example: s3://abc/xxxx.json?bytes=0,81350 ==> [0, 81350]
"""
arr = s3path.split("?bytes=")
if len(arr) == 1:
return None
return arr[1].split(",")
import os
from pathlib import Path
from typing import Tuple
import io
# from app.common.s3 import get_s3_client
from magic_pdf.libs.commons import fitz
from loguru import logger
from magic_pdf.libs.commons import parse_bucket_key, join_path
from magic_pdf.libs.commons import join_path
from magic_pdf.libs.hash_utils import compute_sha256
def cut_image(bbox: Tuple, page_num: int, page: fitz.Page, save_parent_path: str, s3_return_path=None, img_s3_client=None, upload_switch=True):
def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWriter):
"""
从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径
save_path:需要同时支持s3和本地, 图片存放在save_path下,文件名是: {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。
"""
# 拼接文件名
filename = f"{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}.jpg"
# 拼接路径
image_save_path = join_path(save_parent_path, filename)
s3_img_path = join_path(s3_return_path, filename) if s3_return_path is not None else None
# 打印图片文件名
# print(f"Saved {image_save_path}")
#检查坐标
# x_check = int(bbox[2]) - int(bbox[0])
# y_check = int(bbox[3]) - int(bbox[1])
# if x_check <= 0 or y_check <= 0:
#
# if image_save_path.startswith("s3://"):
# logger.exception(f"传入图片坐标有误,x1<x0或y1<y0,{s3_img_path}")
# return s3_img_path
# else:
# logger.exception(f"传入图片坐标有误,x1<x0或y1<y0,{image_save_path}")
# return image_save_path
filename = f"{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}"
# 老版本返回不带bucket的路径
img_path = join_path(return_path, filename) if return_path is not None else None
# 新版本生成平铺路径
img_hash256_path = f"{compute_sha256(img_path)}.jpg"
# 将坐标转换为fitz.Rect对象
rect = fitz.Rect(*bbox)
......@@ -42,39 +26,17 @@ def cut_image(bbox: Tuple, page_num: int, page: fitz.Page, save_parent_path: str
# 截取图片
pix = page.get_pixmap(clip=rect, matrix=zoom)
if image_save_path.startswith("s3://"):
if not upload_switch:
pass
else:
# 图片保存到s3
bucket_name, bucket_key = parse_bucket_key(image_save_path)
# 将字节流上传到s3
byte_data = pix.tobytes(output='jpeg', jpg_quality=95)
file_obj = io.BytesIO(byte_data)
if img_s3_client is not None:
img_s3_client.upload_fileobj(file_obj, bucket_name, bucket_key)
# 每个图片上传任务都创建一个新的client
# img_s3_client_once = get_s3_client(image_save_path)
# img_s3_client_once.upload_fileobj(file_obj, bucket_name, bucket_key)
else:
logger.exception("must input img_s3_client")
return s3_img_path
else:
# 保存图片到本地
# 先检查一下image_save_path的父目录是否存在,如果不存在,就创建
parent_dir = os.path.dirname(image_save_path)
if not os.path.exists(parent_dir):
os.makedirs(parent_dir)
pix.save(image_save_path, jpg_quality=95)
# 为了直接能在markdown里看,这里把地址改为相对于mardown的地址
pth = Path(image_save_path)
image_save_path = f"{pth.parent.name}/{pth.name}"
return image_save_path
def save_images_by_bboxes(book_name: str, page_num: int, page: fitz.Page, save_path: str,
image_bboxes: list, images_overlap_backup:list, table_bboxes: list, equation_inline_bboxes: list,
equation_interline_bboxes: list, img_s3_client) -> dict:
imageWriter.write(data=byte_data, path=img_hash256_path, mode="binary")
return img_hash256_path
def save_images_by_bboxes(page_num: int, page: fitz.Page, pdf_bytes_md5: str,
image_bboxes: list, images_overlap_backup: list, table_bboxes: list,
equation_inline_bboxes: list,
equation_interline_bboxes: list, imageWriter) -> dict:
"""
返回一个dict, key为bbox, 值是图片地址
"""
......@@ -85,53 +47,30 @@ def save_images_by_bboxes(book_name: str, page_num: int, page: fitz.Page, save_p
interline_eq_info = []
# 图片的保存路径组成是这样的: {s3_or_local_path}/{book_name}/{images|tables|equations}/{page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg
s3_return_image_path = join_path(book_name, "images")
image_save_path = join_path(save_path, s3_return_image_path)
s3_return_table_path = join_path(book_name, "tables")
table_save_path = join_path(save_path, s3_return_table_path)
s3_return_equations_inline_path = join_path(book_name, "equations_inline")
equation_inline_save_path = join_path(save_path, s3_return_equations_inline_path)
s3_return_equation_interline_path = join_path(book_name, "equation_interline")
equation_interline_save_path = join_path(save_path, s3_return_equation_interline_path)
def return_path(type):
return join_path(pdf_bytes_md5, type)
for bbox in image_bboxes:
if any([bbox[0]>=bbox[2], bbox[1]>=bbox[3]]):
if any([bbox[0] >= bbox[2], bbox[1] >= bbox[3]]):
logger.warning(f"image_bboxes: 错误的box, {bbox}")
continue
image_path = cut_image(bbox, page_num, page, image_save_path, s3_return_image_path, img_s3_client)
image_path = cut_image(bbox, page_num, page, return_path("images"), imageWriter)
image_info.append({"bbox": bbox, "image_path": image_path})
for bbox in images_overlap_backup:
if any([bbox[0]>=bbox[2], bbox[1]>=bbox[3]]):
if any([bbox[0] >= bbox[2], bbox[1] >= bbox[3]]):
logger.warning(f"images_overlap_backup: 错误的box, {bbox}")
continue
image_path = cut_image(bbox, page_num, page, image_save_path, s3_return_image_path, img_s3_client)
image_path = cut_image(bbox, page_num, page, return_path("images"), imageWriter)
image_backup_info.append({"bbox": bbox, "image_path": image_path})
for bbox in table_bboxes:
if any([bbox[0]>=bbox[2], bbox[1]>=bbox[3]]):
if any([bbox[0] >= bbox[2], bbox[1] >= bbox[3]]):
logger.warning(f"table_bboxes: 错误的box, {bbox}")
continue
image_path = cut_image(bbox, page_num, page, table_save_path, s3_return_table_path, img_s3_client)
image_path = cut_image(bbox, page_num, page, return_path("tables"), imageWriter)
table_info.append({"bbox": bbox, "image_path": image_path})
for bbox in equation_inline_bboxes:
if any([bbox[0]>=bbox[2], bbox[1]>=bbox[3]]):
logger.warning(f"equation_inline_bboxes: 错误的box, {bbox}")
continue
image_path = cut_image(bbox[:4], page_num, page, equation_inline_save_path, s3_return_equations_inline_path, img_s3_client, upload_switch=False)
inline_eq_info.append({'bbox':bbox[:4], "image_path":image_path, "latex_text":bbox[4]})
for bbox in equation_interline_bboxes:
if any([bbox[0]>=bbox[2], bbox[1]>=bbox[3]]):
logger.warning(f"equation_interline_bboxes: 错误的box, {bbox}")
continue
image_path = cut_image(bbox[:4], page_num, page, equation_interline_save_path, s3_return_equation_interline_path, img_s3_client, upload_switch=False)
interline_eq_info.append({"bbox":bbox[:4], "image_path":image_path, "latex_text":bbox[4]})
return image_info, image_backup_info, table_info, inline_eq_info, interline_eq_info
\ No newline at end of file
import json
import os
import time
from loguru import logger
from magic_pdf.libs.draw_bbox import draw_layout_bbox, draw_text_bbox
from magic_pdf.libs.commons import (
read_file,
join_path,
fitz,
get_img_s3_client,
get_delta_time,
get_docx_model_output,
)
from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.libs.drop_tag import DropTag
from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.ocr_content_type import ContentType
from magic_pdf.libs.safe_filename import sanitize_filename
from magic_pdf.para.para_split import para_split
from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component
from magic_pdf.pre_proc.detect_footer_by_model import parse_footers
......@@ -38,38 +30,16 @@ from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox
def parse_pdf_by_ocr(
pdf_bytes,
pdf_model_output,
save_path,
book_name,
pdf_model_profile=None,
image_s3_config=None,
imageWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
):
save_tmp_path = os.path.join(os.path.dirname(__file__), "../..", "tmp", "unittest")
book_name = sanitize_filename(book_name)
md_bookname_save_path = ""
if debug_mode:
save_path = join_path(save_tmp_path, "md")
pdf_local_path = join_path(save_tmp_path, "download-pdfs", book_name)
if not os.path.exists(os.path.dirname(pdf_local_path)):
# 如果目录不存在,创建它
os.makedirs(os.path.dirname(pdf_local_path))
md_bookname_save_path = join_path(save_tmp_path, "md", book_name)
if not os.path.exists(md_bookname_save_path):
# 如果目录不存在,创建它
os.makedirs(md_bookname_save_path)
with open(pdf_local_path + ".pdf", "wb") as pdf_file:
pdf_file.write(pdf_bytes)
pdf_bytes_md5 = compute_md5(pdf_bytes)
pdf_docs = fitz.open("pdf", pdf_bytes)
# 初始化空的pdf_info_dict
pdf_info_dict = {}
img_s3_client = get_img_s3_client(save_path, image_s3_config)
start_time = time.time()
......@@ -91,16 +61,14 @@ def parse_pdf_by_ocr(
# 获取当前页的模型数据
ocr_page_info = get_docx_model_output(
pdf_model_output, pdf_model_profile, page_id
pdf_model_output, page_id
)
"""从json中获取每页的页码、页眉、页脚的bbox"""
page_no_bboxes = parse_pageNos(page_id, page, ocr_page_info)
header_bboxes = parse_headers(page_id, page, ocr_page_info)
footer_bboxes = parse_footers(page_id, page, ocr_page_info)
footnote_bboxes = parse_footnotes_by_model(
page_id, page, ocr_page_info, md_bookname_save_path, debug_mode=debug_mode
)
footnote_bboxes = parse_footnotes_by_model(page_id, page, ocr_page_info, debug_mode=debug_mode)
# 构建需要remove的bbox字典
need_remove_spans_bboxes_dict = {
......@@ -179,7 +147,7 @@ def parse_pdf_by_ocr(
spans, dropped_spans_by_removed_bboxes = remove_spans_by_bboxes_dict(spans, need_remove_spans_bboxes_dict)
'''对image和table截图'''
spans = cut_image_and_table(spans, page, page_id, book_name, save_path, img_s3_client)
spans = cut_image_and_table(spans, page, page_id, pdf_bytes_md5, imageWriter)
'''行内公式调整, 高度调整至与同行文字高度一致(优先左侧, 其次右侧)'''
displayed_list = []
......@@ -242,16 +210,4 @@ def parse_pdf_by_ocr(
"""分段"""
para_split(pdf_info_dict, debug_mode=debug_mode)
'''在测试时,保存调试信息'''
if debug_mode:
params_file_save_path = join_path(
save_tmp_path, "md", book_name, "preproc_out.json"
)
with open(params_file_save_path, "w", encoding="utf-8") as f:
json.dump(pdf_info_dict, f, ensure_ascii=False, indent=4)
# drow_bbox
draw_layout_bbox(pdf_info_dict, pdf_bytes, md_bookname_save_path)
draw_text_bbox(pdf_info_dict, pdf_bytes, md_bookname_save_path)
return pdf_info_dict
......@@ -12,6 +12,7 @@ from magic_pdf.layout.bbox_sort import (
)
from magic_pdf.layout.layout_sort import LAYOUT_UNPROC, get_bboxes_layout, get_columns_cnt_of_layout, sort_text_block
from magic_pdf.libs.drop_reason import DropReason
from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.markdown_utils import escape_special_markdown_char
from magic_pdf.libs.safe_filename import sanitize_filename
from magic_pdf.libs.vis_utils import draw_bbox_on_page, draw_layout_bbox_on_page
......@@ -73,46 +74,20 @@ paraMergeException_msg = ParaMergeException().message
def parse_pdf_by_txt(
pdf_bytes,
pdf_model_output,
save_path,
book_name,
pdf_model_profile=None,
image_s3_config=None,
imageWriter,
start_page_id=0,
end_page_id=None,
junk_img_bojids=[],
debug_mode=False,
):
save_tmp_path = os.path.join(os.path.dirname(__file__), "../..", "tmp", "unittest")
md_bookname_save_path = ""
book_name = sanitize_filename(book_name)
if debug_mode:
save_path = join_path(save_tmp_path, "md")
pdf_local_path = join_path(save_tmp_path, "download-pdfs", book_name)
if not os.path.exists(os.path.dirname(pdf_local_path)):
# 如果目录不存在,创建它
os.makedirs(os.path.dirname(pdf_local_path))
md_bookname_save_path = join_path(save_tmp_path, "md", book_name)
if not os.path.exists(md_bookname_save_path):
# 如果目录不存在,创建它
os.makedirs(md_bookname_save_path)
with open(pdf_local_path + ".pdf", "wb") as pdf_file:
pdf_file.write(pdf_bytes)
pdf_bytes_md5 = compute_md5(pdf_bytes)
pdf_docs = fitz.open("pdf", pdf_bytes)
pdf_info_dict = {}
img_s3_client = get_img_s3_client(save_path, image_s3_config) # 更改函数名和参数,避免歧义
# img_s3_client = "img_s3_client" #不创建这个对象,直接用字符串占位
start_time = time.time()
"""通过统计pdf全篇文字,识别正文字体"""
main_text_font = get_main_text_font(pdf_docs)
end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
for page_id in range(start_page_id, end_page_id + 1):
page = pdf_docs[page_id]
......@@ -128,20 +103,11 @@ def parse_pdf_by_txt(
# 对单页面非重复id的img数量做统计,如果当前页超过1500则直接return need_drop
"""
page_imgs = page.get_images()
img_counts = 0
for img in page_imgs:
img_bojid = img[0]
if img_bojid in junk_img_bojids: # 判断这个图片在不在junklist中
continue # 如果在junklist就不用管了,跳过
else:
recs = page.get_image_rects(img, transform=True)
if recs: # 如果这张图在当前页面有展示
img_counts += 1
if img_counts >= 1500: # 如果去除了junkimg的影响,单页img仍然超过1500的话,就排除当前pdf
logger.warning(
f"page_id: {page_id}, img_counts: {img_counts}, drop this pdf: {book_name}, drop_reason: {DropReason.HIGH_COMPUTATIONAL_lOAD_BY_IMGS}"
)
result = {"need_drop": True, "drop_reason": DropReason.HIGH_COMPUTATIONAL_lOAD_BY_IMGS}
# 去除对junkimg的依赖,简化逻辑
if len(page_imgs) > 1500: # 如果当前页超过1500张图片,直接跳过
logger.warning(f"page_id: {page_id}, img_counts: {len(page_imgs)}, drop this pdf")
result = {"_need_drop": True, "_drop_reason": DropReason.HIGH_COMPUTATIONAL_lOAD_BY_IMGS}
if not debug_mode:
return result
......@@ -154,10 +120,10 @@ def parse_pdf_by_txt(
"dict",
flags=fitz.TEXTFLAGS_TEXT,
)["blocks"]
model_output_json = get_docx_model_output(pdf_model_output, pdf_model_profile, page_id)
model_output_json = get_docx_model_output(pdf_model_output, page_id)
# 解析图片
image_bboxes = parse_images(page_id, page, model_output_json, junk_img_bojids)
image_bboxes = parse_images(page_id, page, model_output_json)
image_bboxes = fix_image_vertical(image_bboxes, text_raw_blocks) # 修正图片的位置
image_bboxes = fix_seperated_image(image_bboxes) # 合并有边重合的图片
image_bboxes = include_img_title(text_raw_blocks, image_bboxes) # 向图片上方和下方寻找title,使用规则进行匹配,暂时只支持英文规则
......@@ -225,22 +191,18 @@ def parse_pdf_by_txt(
"""
==================================================================================================================================
"""
if debug_mode: # debugmode截图到本地
save_path = join_path(save_tmp_path, "md")
# 把图、表、公式都进行截图,保存到存储上,返回图片路径作为内容
image_info, image_backup_info, table_info, inline_eq_info, interline_eq_info = save_images_by_bboxes(
book_name,
page_id,
page,
save_path,
pdf_bytes_md5,
image_bboxes,
images_overlap_backup,
table_bboxes,
equations_inline_bboxes,
equations_interline_bboxes,
# 传入img_s3_client
img_s3_client,
imageWriter
) # 只要表格和图片的截图
""""以下进入到公式替换环节 """
......@@ -253,13 +215,13 @@ def parse_pdf_by_txt(
"""去掉footnote, 从文字和图片中(先去角标再去footnote试试)"""
# 通过模型识别到的footnote
footnote_bboxes_by_model = parse_footnotes_by_model(page_id, page, model_output_json, md_bookname_save_path, debug_mode=debug_mode)
footnote_bboxes_by_model = parse_footnotes_by_model(page_id, page, model_output_json, debug_mode=debug_mode)
# 通过规则识别到的footnote
footnote_bboxes_by_rule = parse_footnotes_by_rule(remain_text_blocks, page_height, page_id, main_text_font)
"""进入pdf过滤器,去掉一些不合理的pdf"""
is_good_pdf, err = pdf_filter(page, remain_text_blocks, table_bboxes, image_bboxes)
if not is_good_pdf:
logger.warning(f"page_id: {page_id}, drop this pdf: {book_name}, reason: {err}")
logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {err}")
if not debug_mode:
return err
......@@ -273,8 +235,8 @@ def parse_pdf_by_txt(
if is_text_block_horz_overlap:
# debug_show_bbox(pdf_docs, page_id, [b['bbox'] for b in remain_text_blocks], [], [], join_path(save_path, book_name, f"{book_name}_debug.pdf"), 0)
logger.warning(f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.TEXT_BLCOK_HOR_OVERLAP}")
result = {"need_drop": True, "drop_reason": DropReason.TEXT_BLCOK_HOR_OVERLAP}
logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {DropReason.TEXT_BLCOK_HOR_OVERLAP}")
result = {"_need_drop": True, "_drop_reason": DropReason.TEXT_BLCOK_HOR_OVERLAP}
if not debug_mode:
return result
......@@ -292,24 +254,24 @@ def parse_pdf_by_txt(
layout_bboxes, layout_tree = get_bboxes_layout(all_bboxes, page_boundry, page_id)
if len(remain_text_blocks)>0 and len(all_bboxes)>0 and len(layout_bboxes)==0:
logger.warning(f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.CAN_NOT_DETECT_PAGE_LAYOUT}")
result = {"need_drop": True, "drop_reason": DropReason.CAN_NOT_DETECT_PAGE_LAYOUT}
logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {DropReason.CAN_NOT_DETECT_PAGE_LAYOUT}")
result = {"_need_drop": True, "_drop_reason": DropReason.CAN_NOT_DETECT_PAGE_LAYOUT}
if not debug_mode:
return result
"""以下去掉复杂的布局和超过2列的布局"""
if any([lay["layout_label"] == LAYOUT_UNPROC for lay in layout_bboxes]): # 复杂的布局
logger.warning(f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.COMPLICATED_LAYOUT}")
result = {"need_drop": True, "drop_reason": DropReason.COMPLICATED_LAYOUT}
logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {DropReason.COMPLICATED_LAYOUT}")
result = {"_need_drop": True, "_drop_reason": DropReason.COMPLICATED_LAYOUT}
if not debug_mode:
return result
layout_column_width = get_columns_cnt_of_layout(layout_tree)
if layout_column_width > 2: # 去掉超过2列的布局pdf
logger.warning(f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.TOO_MANY_LAYOUT_COLUMNS}")
logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {DropReason.TOO_MANY_LAYOUT_COLUMNS}")
result = {
"need_drop": True,
"drop_reason": DropReason.TOO_MANY_LAYOUT_COLUMNS,
"_need_drop": True,
"_drop_reason": DropReason.TOO_MANY_LAYOUT_COLUMNS,
"extra_info": {"column_cnt": layout_column_width},
}
if not debug_mode:
......@@ -390,28 +352,11 @@ def parse_pdf_by_txt(
for page_info in pdf_info_dict.values():
is_good_pdf, err = pdf_post_filter(page_info)
if not is_good_pdf:
logger.warning(f"page_id: {i}, drop this pdf: {book_name}, reason: {err}")
logger.warning(f"page_id: {i}, drop this pdf: {pdf_bytes_md5}, reason: {err}")
if not debug_mode:
return err
i += 1
if debug_mode:
params_file_save_path = join_path(save_tmp_path, "md", book_name, "preproc_out.json")
page_draw_rect_save_path = join_path(save_tmp_path, "md", book_name, "layout.pdf")
# dir_path = os.path.dirname(page_draw_rect_save_path)
# if not os.path.exists(dir_path):
# # 如果目录不存在,创建它
# os.makedirs(dir_path)
with open(params_file_save_path, "w", encoding="utf-8") as f:
json.dump(pdf_info_dict, f, ensure_ascii=False, indent=4)
# 先检测本地 page_draw_rect_save_path 是否存在,如果存在则删除
if os.path.exists(page_draw_rect_save_path):
os.remove(page_draw_rect_save_path)
# 绘制bbox和layout到pdf
draw_bbox_on_page(pdf_docs, pdf_info_dict, page_draw_rect_save_path)
draw_layout_bbox_on_page(pdf_docs, pdf_info_dict, header, footer, page_draw_rect_save_path)
if debug_mode:
# 打印后处理阶段耗时
logger.info(f"post_processing_time: {get_delta_time(start_time)}")
......@@ -429,56 +374,28 @@ def parse_pdf_by_txt(
para_process_pipeline = ParaProcessPipeline()
def _deal_with_text_exception(error_info):
logger.warning(f"page_id: {page_id}, drop this pdf: {book_name}, reason: {error_info}")
logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {error_info}")
if error_info == denseSingleLineBlockException_msg:
logger.warning(f"Drop this pdf: {book_name}, reason: {DropReason.DENSE_SINGLE_LINE_BLOCK}")
result = {"need_drop": True, "drop_reason": DropReason.DENSE_SINGLE_LINE_BLOCK}
logger.warning(f"Drop this pdf: {pdf_bytes_md5}, reason: {DropReason.DENSE_SINGLE_LINE_BLOCK}")
result = {"_need_drop": True, "_drop_reason": DropReason.DENSE_SINGLE_LINE_BLOCK}
return result
if error_info == titleDetectionException_msg:
logger.warning(f"Drop this pdf: {book_name}, reason: {DropReason.TITLE_DETECTION_FAILED}")
result = {"need_drop": True, "drop_reason": DropReason.TITLE_DETECTION_FAILED}
logger.warning(f"Drop this pdf: {pdf_bytes_md5}, reason: {DropReason.TITLE_DETECTION_FAILED}")
result = {"_need_drop": True, "_drop_reason": DropReason.TITLE_DETECTION_FAILED}
return result
elif error_info == titleLevelException_msg:
logger.warning(f"Drop this pdf: {book_name}, reason: {DropReason.TITLE_LEVEL_FAILED}")
result = {"need_drop": True, "drop_reason": DropReason.TITLE_LEVEL_FAILED}
logger.warning(f"Drop this pdf: {pdf_bytes_md5}, reason: {DropReason.TITLE_LEVEL_FAILED}")
result = {"_need_drop": True, "_drop_reason": DropReason.TITLE_LEVEL_FAILED}
return result
elif error_info == paraSplitException_msg:
logger.warning(f"Drop this pdf: {book_name}, reason: {DropReason.PARA_SPLIT_FAILED}")
result = {"need_drop": True, "drop_reason": DropReason.PARA_SPLIT_FAILED}
logger.warning(f"Drop this pdf: {pdf_bytes_md5}, reason: {DropReason.PARA_SPLIT_FAILED}")
result = {"_need_drop": True, "_drop_reason": DropReason.PARA_SPLIT_FAILED}
return result
elif error_info == paraMergeException_msg:
logger.warning(f"Drop this pdf: {book_name}, reason: {DropReason.PARA_MERGE_FAILED}")
result = {"need_drop": True, "drop_reason": DropReason.PARA_MERGE_FAILED}
logger.warning(f"Drop this pdf: {pdf_bytes_md5}, reason: {DropReason.PARA_MERGE_FAILED}")
result = {"_need_drop": True, "_drop_reason": DropReason.PARA_MERGE_FAILED}
return result
if debug_mode:
input_pdf_file = f"{pdf_local_path}.pdf"
output_dir = f"{save_path}/{book_name}"
output_pdf_file = f"{output_dir}/pdf_annos.pdf"
"""
Call the para_process_pipeline function to process the pdf_info_dict.
Parameters:
para_debug_mode: str or None
If para_debug_mode is None, the para_process_pipeline will not keep any intermediate results.
If para_debug_mode is "simple", the para_process_pipeline will only keep the annos on the pdf and the final results as a json file.
If para_debug_mode is "full", the para_process_pipeline will keep all the intermediate results generated during each step.
"""
pdf_info_dict, error_info = para_process_pipeline.para_process_pipeline(
pdf_info_dict,
para_debug_mode="simple",
input_pdf_path=input_pdf_file,
output_pdf_path=output_pdf_file,
)
# 打印段落处理阶段耗时
logger.info(f"para_process_time: {get_delta_time(start_time)}")
# debug的时候不return drop信息
if error_info is not None:
_deal_with_text_exception(error_info)
return pdf_info_dict
else:
pdf_info_dict, error_info = para_process_pipeline.para_process_pipeline(pdf_info_dict)
if error_info is not None:
return _deal_with_text_exception(error_info)
......
......@@ -112,7 +112,6 @@ def parse_pdf_for_train(
pdf_model_output,
save_path,
book_name,
pdf_model_profile=None,
image_s3_config=None,
start_page_id=0,
end_page_id=None,
......@@ -184,8 +183,8 @@ def parse_pdf_for_train(
f"page_id: {page_id}, img_counts: {img_counts}, drop this pdf: {book_name}, drop_reason: {DropReason.HIGH_COMPUTATIONAL_lOAD_BY_IMGS}"
)
result = {
"need_drop": True,
"drop_reason": DropReason.HIGH_COMPUTATIONAL_lOAD_BY_IMGS,
"_need_drop": True,
"_drop_reason": DropReason.HIGH_COMPUTATIONAL_lOAD_BY_IMGS,
}
if not debug_mode:
return result
......@@ -200,7 +199,7 @@ def parse_pdf_for_train(
flags=fitz.TEXTFLAGS_TEXT,
)["blocks"]
model_output_json = get_docx_model_output(
pdf_model_output, pdf_model_profile, page_id
pdf_model_output, page_id
)
# 解析图片
......@@ -397,8 +396,8 @@ def parse_pdf_for_train(
f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.TEXT_BLCOK_HOR_OVERLAP}"
)
result = {
"need_drop": True,
"drop_reason": DropReason.TEXT_BLCOK_HOR_OVERLAP,
"_need_drop": True,
"_drop_reason": DropReason.TEXT_BLCOK_HOR_OVERLAP,
}
if not debug_mode:
return result
......@@ -444,8 +443,8 @@ def parse_pdf_for_train(
f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.CAN_NOT_DETECT_PAGE_LAYOUT}"
)
result = {
"need_drop": True,
"drop_reason": DropReason.CAN_NOT_DETECT_PAGE_LAYOUT,
"_need_drop": True,
"_drop_reason": DropReason.CAN_NOT_DETECT_PAGE_LAYOUT,
}
if not debug_mode:
return result
......@@ -457,7 +456,7 @@ def parse_pdf_for_train(
logger.warning(
f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.COMPLICATED_LAYOUT}"
)
result = {"need_drop": True, "drop_reason": DropReason.COMPLICATED_LAYOUT}
result = {"_need_drop": True, "_drop_reason": DropReason.COMPLICATED_LAYOUT}
if not debug_mode:
return result
......@@ -467,8 +466,8 @@ def parse_pdf_for_train(
f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.TOO_MANY_LAYOUT_COLUMNS}"
)
result = {
"need_drop": True,
"drop_reason": DropReason.TOO_MANY_LAYOUT_COLUMNS,
"_need_drop": True,
"_drop_reason": DropReason.TOO_MANY_LAYOUT_COLUMNS,
"extra_info": {"column_cnt": layout_column_width},
}
if not debug_mode:
......@@ -617,8 +616,8 @@ def parse_pdf_for_train(
f"Drop this pdf: {book_name}, reason: {DropReason.DENSE_SINGLE_LINE_BLOCK}"
)
result = {
"need_drop": True,
"drop_reason": DropReason.DENSE_SINGLE_LINE_BLOCK,
"_need_drop": True,
"_drop_reason": DropReason.DENSE_SINGLE_LINE_BLOCK,
}
return result
if error_info == titleDetectionException_msg:
......@@ -626,27 +625,27 @@ def parse_pdf_for_train(
f"Drop this pdf: {book_name}, reason: {DropReason.TITLE_DETECTION_FAILED}"
)
result = {
"need_drop": True,
"drop_reason": DropReason.TITLE_DETECTION_FAILED,
"_need_drop": True,
"_drop_reason": DropReason.TITLE_DETECTION_FAILED,
}
return result
elif error_info == titleLevelException_msg:
logger.warning(
f"Drop this pdf: {book_name}, reason: {DropReason.TITLE_LEVEL_FAILED}"
)
result = {"need_drop": True, "drop_reason": DropReason.TITLE_LEVEL_FAILED}
result = {"_need_drop": True, "_drop_reason": DropReason.TITLE_LEVEL_FAILED}
return result
elif error_info == paraSplitException_msg:
logger.warning(
f"Drop this pdf: {book_name}, reason: {DropReason.PARA_SPLIT_FAILED}"
)
result = {"need_drop": True, "drop_reason": DropReason.PARA_SPLIT_FAILED}
result = {"_need_drop": True, "_drop_reason": DropReason.PARA_SPLIT_FAILED}
return result
elif error_info == paraMergeException_msg:
logger.warning(
f"Drop this pdf: {book_name}, reason: {DropReason.PARA_MERGE_FAILED}"
)
result = {"need_drop": True, "drop_reason": DropReason.PARA_MERGE_FAILED}
result = {"_need_drop": True, "_drop_reason": DropReason.PARA_MERGE_FAILED}
return result
if debug_mode:
......
......@@ -32,8 +32,8 @@ def meta_scan(jso: dict, doc_layout_check=True) -> dict:
if (
"doc_layout_result" not in jso
): # 检测json中是存在模型数据,如果没有则需要跳过该pdf
jso["need_drop"] = True
jso["drop_reason"] = DropReason.MISS_DOC_LAYOUT_RESULT
jso["_need_drop"] = True
jso["_drop_reason"] = DropReason.MISS_DOC_LAYOUT_RESULT
return jso
try:
data_source = get_data_source(jso)
......@@ -58,10 +58,10 @@ def meta_scan(jso: dict, doc_layout_check=True) -> dict:
start_time = time.time() # 记录开始时间
res = pdf_meta_scan(s3_pdf_path, file_content)
if res.get(
"need_drop", False
"_need_drop", False
): # 如果返回的字典里有need_drop,则提取drop_reason并跳过本次解析
jso["need_drop"] = True
jso["drop_reason"] = res["drop_reason"]
jso["_need_drop"] = True
jso["_drop_reason"] = res["_drop_reason"]
else: # 正常返回
jso["pdf_meta"] = res
jso["content"] = ""
......@@ -85,7 +85,7 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict:
if debug_mode:
pass
else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False):
if jso.get("_need_drop", False):
return jso
# 开始正式逻辑
try:
......@@ -113,8 +113,8 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict:
if (
is_encrypted or is_needs_password
): # 加密的,需要密码的,没有页面的,都不处理
jso["need_drop"] = True
jso["drop_reason"] = DropReason.ENCRYPTED
jso["_need_drop"] = True
jso["_drop_reason"] = DropReason.ENCRYPTED
else:
start_time = time.time() # 记录开始时间
is_text_pdf, results = classify(
......@@ -139,8 +139,8 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict:
if (
text_language not in allow_language
): # 如果语言不在允许的语言中,则drop
jso["need_drop"] = True
jso["drop_reason"] = DropReason.NOT_ALLOW_LANGUAGE
jso["_need_drop"] = True
jso["_drop_reason"] = DropReason.NOT_ALLOW_LANGUAGE
return jso
else:
# 先不drop
......@@ -148,8 +148,8 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict:
jso["_pdf_type"] = "OCR"
jso["pdf_meta"] = pdf_meta
jso["classify_time"] = classify_time
# jso["need_drop"] = True
# jso["drop_reason"] = DropReason.NOT_IS_TEXT_PDF
# jso["_need_drop"] = True
# jso["_drop_reason"] = DropReason.NOT_IS_TEXT_PDF
extra_info = {"classify_rules": []}
for condition, result in results.items():
if not result:
......@@ -162,7 +162,7 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict:
def drop_needdrop_pdf(jso: dict) -> dict:
if jso.get("need_drop", False):
if jso.get("_need_drop", False):
logger.info(
f"book_name is:{get_data_source(jso)}/{jso['file_id']} need drop",
file=sys.stderr,
......@@ -176,7 +176,7 @@ def pdf_intermediate_dict_to_markdown(jso: dict, debug_mode=False) -> dict:
if debug_mode:
pass
else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False):
if jso.get("_need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True
......@@ -203,7 +203,7 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
if debug_mode:
pass
else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False):
if jso.get("_need_drop", False):
return jso
# 开始正式逻辑
s3_pdf_path = jso.get("file_location")
......@@ -220,8 +220,8 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
svgs_per_page_list = jso["pdf_meta"]["svgs_per_page"]
max_svgs = max(svgs_per_page_list)
if max_svgs > 3000:
jso["need_drop"] = True
jso["drop_reason"] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_SVGS
jso["_need_drop"] = True
jso["_drop_reason"] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_SVGS
else:
try:
save_path = s3_image_save_path
......@@ -244,10 +244,10 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
debug_mode=debug_mode,
)
if pdf_info_dict.get(
"need_drop", False
"_need_drop", False
): # 如果返回的字典里有need_drop,则提取drop_reason并跳过本次解析
jso["need_drop"] = True
jso["drop_reason"] = pdf_info_dict["drop_reason"]
jso["_need_drop"] = True
jso["_drop_reason"] = pdf_info_dict["_drop_reason"]
else: # 正常返回,将 pdf_info_dict 压缩并存储
pdf_info_dict = JsonCompressor.compress_json(pdf_info_dict)
jso["pdf_intermediate_dict"] = pdf_info_dict
......@@ -269,7 +269,7 @@ def parse_pdf_for_model_train(jso: dict, start_page_id=0, debug_mode=False) -> d
if debug_mode:
pass
else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False):
if jso.get("_need_drop", False):
return jso
# 开始正式逻辑
s3_pdf_path = jso.get("file_location")
......@@ -295,8 +295,8 @@ def parse_pdf_for_model_train(jso: dict, start_page_id=0, debug_mode=False) -> d
svgs_per_page_list = jso["pdf_meta"]["svgs_per_page"]
max_svgs = max(svgs_per_page_list)
if max_svgs > 3000:
jso["need_drop"] = True
jso["drop_reason"] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_SVGS
jso["_need_drop"] = True
jso["_drop_reason"] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_SVGS
# elif total_page > 1000:
# jso['need_drop'] = True
# jso['drop_reason'] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES
......@@ -323,10 +323,10 @@ def parse_pdf_for_model_train(jso: dict, start_page_id=0, debug_mode=False) -> d
debug_mode=debug_mode,
)
if pdf_info_dict.get(
"need_drop", False
"_need_drop", False
): # 如果返回的字典里有need_drop,则提取drop_reason并跳过本次解析
jso["need_drop"] = True
jso["drop_reason"] = pdf_info_dict["drop_reason"]
jso["_need_drop"] = True
jso["_drop_reason"] = pdf_info_dict["_drop_reason"]
else: # 正常返回,将 pdf_info_dict 压缩并存储
jso["parsed_results"] = convert_to_train_format(pdf_info_dict)
pdf_info_dict = JsonCompressor.compress_json(pdf_info_dict)
......
......@@ -17,7 +17,7 @@ def ocr_pdf_intermediate_dict_to_markdown(jso: dict, debug_mode=False) -> dict:
if debug_mode:
pass
else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False):
if jso.get("_need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True
......@@ -45,7 +45,7 @@ def ocr_pdf_intermediate_dict_to_markdown_with_para(jso: dict, mode, debug_mode=
if debug_mode:
pass
else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False):
if jso.get("_need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True
......@@ -78,7 +78,7 @@ def ocr_pdf_intermediate_dict_to_markdown_with_para_and_pagination(jso: dict, de
if debug_mode:
pass
else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False):
if jso.get("_need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True
......@@ -108,7 +108,7 @@ def ocr_pdf_intermediate_dict_to_markdown_with_para_for_qa(
if debug_mode:
pass
else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False):
if jso.get("_need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True
......@@ -137,7 +137,7 @@ def ocr_pdf_intermediate_dict_to_standard_format(jso: dict, debug_mode=False) ->
if debug_mode:
pass
else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False):
if jso.get("_need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True
......@@ -165,7 +165,7 @@ def ocr_pdf_intermediate_dict_to_standard_format_with_para(jso: dict, debug_mode
if debug_mode:
pass
else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False):
if jso.get("_need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True
......@@ -221,7 +221,7 @@ def ocr_parse_pdf_core(pdf_bytes, model_output_json_list, book_name, start_page_
# 专门用来跑被drop的pdf,跑完之后需要把need_drop字段置为false
def ocr_dropped_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
if not jso.get("need_drop", False):
if not jso.get("_need_drop", False):
return jso
else:
try:
......@@ -233,7 +233,7 @@ def ocr_dropped_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
)
jso["pdf_intermediate_dict"] = JsonCompressor.compress_json(pdf_info_dict)
jso["parse_time"] = parse_time
jso["need_drop"] = False
jso["_need_drop"] = False
except Exception as e:
jso = exception_handler(jso, e)
return jso
......@@ -244,7 +244,7 @@ def ocr_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
if debug_mode:
pass
else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False):
if jso.get("_need_drop", False):
return jso
try:
pdf_bytes = get_pdf_bytes(jso)
......
......@@ -18,7 +18,7 @@ def txt_pdf_to_standard_format(jso: dict, debug_mode=False) -> dict:
if debug_mode:
pass
else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False):
if jso.get("_need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop")
jso["dropped"] = True
......@@ -46,7 +46,7 @@ def txt_pdf_to_mm_markdown_format(jso: dict, debug_mode=False) -> dict:
if debug_mode:
pass
else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False):
if jso.get("_need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop")
jso["dropped"] = True
......
......@@ -62,6 +62,6 @@ def pdf_post_filter(page_info) -> tuple:
"""
bool_is_pseudo_single_column, extra_info = __is_pseudo_single_column(page_info)
if bool_is_pseudo_single_column:
return False, {"need_drop": True, "drop_reason": DropReason.PSEUDO_SINGLE_COLUMN, "extra_info": extra_info}
return False, {"_need_drop": True, "_drop_reason": DropReason.PSEUDO_SINGLE_COLUMN, "extra_info": extra_info}
return True, None
\ No newline at end of file
......@@ -3,7 +3,7 @@ from magic_pdf.libs.commons import fitz # pyMuPDF库
from magic_pdf.libs.coordinate_transform import get_scale_ratio
def parse_footnotes_by_model(page_ID: int, page: fitz.Page, json_from_DocXchain_obj: dict, md_bookname_save_path, debug_mode=False):
def parse_footnotes_by_model(page_ID: int, page: fitz.Page, json_from_DocXchain_obj: dict, md_bookname_save_path=None, debug_mode=False):
"""
:param page_ID: int类型,当前page在当前pdf文档中是第page_D页。
:param page :fitz读取的当前页的内容
......
......@@ -3,18 +3,16 @@ from magic_pdf.libs.ocr_content_type import ContentType
from magic_pdf.libs.pdf_image_tools import cut_image
def cut_image_and_table(spans, page, page_id, book_name, save_path, img_s3_client):
def s3_return_path(type):
return join_path(book_name, type)
def cut_image_and_table(spans, page, page_id, pdf_bytes_md5, imageWriter):
def img_save_path(type):
return join_path(save_path, s3_return_path(type))
def return_path(type):
return join_path(pdf_bytes_md5, type)
for span in spans:
span_type = span['type']
if span_type == ContentType.Image:
span['image_path'] = cut_image(span['bbox'], page_id, page, img_save_path('images'), s3_return_path=s3_return_path('images'), img_s3_client=img_s3_client)
span['image_path'] = cut_image(span['bbox'], page_id, page, return_path=return_path('images'), imageWriter=imageWriter)
elif span_type == ContentType.Table:
span['image_path'] = cut_image(span['bbox'], page_id, page, img_save_path('tables'), s3_return_path=s3_return_path('tables'), img_s3_client=img_s3_client)
span['image_path'] = cut_image(span['bbox'], page_id, page, return_path=return_path('tables'), imageWriter=imageWriter)
return spans
......@@ -68,7 +68,7 @@ def pdf_filter(page:fitz.Page, text_blocks, table_bboxes, image_bboxes) -> tuple
"""
if __is_contain_color_background_rect(page, text_blocks, image_bboxes):
return False, {"need_drop": True, "drop_reason": DropReason.COLOR_BACKGROUND_TEXT_BOX}
return False, {"_need_drop": True, "_drop_reason": DropReason.COLOR_BACKGROUND_TEXT_BOX}
return True, None
\ No newline at end of file
from loguru import logger
from magic_pdf.dict2md.mkcontent import mk_universal_format
from magic_pdf.dict2md.ocr_mkcontent import make_standard_format_with_para
from magic_pdf.filter.pdf_classify_by_type import classify
from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
from magic_pdf.libs.detect_language_from_model import get_language_from_model
from magic_pdf.libs.drop_reason import DropReason
from magic_pdf.libs.json_compressor import JsonCompressor
from magic_pdf.spark.spark_api import parse_union_pdf, parse_ocr_pdf
class UNIPipe:
def __init__(self):
pass
def classify(self, pdf_bytes: bytes) -> str:
"""
根据pdf的元数据,判断是否是文本pdf,还是ocr pdf
"""
pdf_meta = pdf_meta_scan(pdf_bytes)
if pdf_meta.get("_need_drop", False): # 如果返回了需要丢弃的标志,则抛出异常
raise Exception(f"pdf meta_scan need_drop,reason is {pdf_meta['_drop_reason']}")
else:
is_encrypted = pdf_meta["is_encrypted"]
is_needs_password = pdf_meta["is_needs_password"]
if is_encrypted or is_needs_password: # 加密的,需要密码的,没有页面的,都不处理
raise Exception(f"pdf meta_scan need_drop,reason is {DropReason.ENCRYPTED}")
else:
is_text_pdf, results = classify(
pdf_meta["total_page"],
pdf_meta["page_width_pts"],
pdf_meta["page_height_pts"],
pdf_meta["image_info_per_page"],
pdf_meta["text_len_per_page"],
pdf_meta["imgs_per_page"],
pdf_meta["text_layout_per_page"],
)
if is_text_pdf:
return "txt"
else:
return "ocr"
def parse(self, pdf_bytes: bytes, image_writer, jso_useful_key) -> dict:
"""
根据pdf类型,解析pdf
"""
text_language = get_language_from_model(jso_useful_key['model_list'])
allow_language = ["zh", "en"] # 允许的语言,目前只允许简中和英文的
logger.info(f"pdf text_language is {text_language}")
if text_language not in allow_language: # 如果语言不在允许的语言中,则drop
raise Exception(f"pdf meta_scan need_drop,reason is {DropReason.NOT_ALLOW_LANGUAGE}")
else:
if jso_useful_key['_pdf_type'] == "txt":
pdf_mid_data = parse_union_pdf(pdf_bytes, jso_useful_key['model_list'], image_writer)
elif jso_useful_key['_pdf_type'] == "ocr":
pdf_mid_data = parse_ocr_pdf(pdf_bytes, jso_useful_key['model_list'], image_writer)
else:
raise Exception(f"pdf type is not txt or ocr")
return JsonCompressor.compress(pdf_mid_data)
def mk_uni_format(self, pdf_mid_data: str, img_buket_path: str) -> list:
"""
根据pdf类型,生成统一格式content_list
"""
pdf_mid_data = JsonCompressor.decompress_json(pdf_mid_data)
parse_type = pdf_mid_data["_parse_type"]
if parse_type == "txt":
content_list = mk_universal_format(pdf_mid_data, img_buket_path)
elif parse_type == "ocr":
content_list = make_standard_format_with_para(pdf_mid_data, img_buket_path)
return content_list
if __name__ == '__main__':
# 测试
pipe = UNIPipe()
pdf_bytes = open(r"D:\project\20231108code-clean\magic_pdf\tmp\unittest\download-pdfs\数学新星网\edu_00001544.pdf",
"rb").read()
pdf_type = pipe.classify(pdf_bytes)
logger.info(f"pdf_type is {pdf_type}")
......@@ -26,9 +26,9 @@ def get_bookid(jso: dict):
def exception_handler(jso: dict, e):
logger.exception(e)
jso["need_drop"] = True
jso["drop_reason"] = DropReason.Exception
jso["exception"] = f"ERROR: {e}"
jso["_need_drop"] = True
jso["_drop_reason"] = DropReason.Exception
jso["_exception"] = f"ERROR: {e}"
return jso
......
......@@ -12,27 +12,86 @@
其余部分至于构造s3cli, 获取ak,sk都在code-clean里写代码完成。不要反向依赖!!!
"""
from loguru import logger
from magic_pdf.io import AbsReaderWriter
from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
from magic_pdf.pdf_parse_by_txt import parse_pdf_by_txt
def parse_txt_pdf(pdf_bytes:bytes, pdf_models:list, imageWriter: AbsReaderWriter, is_debug=False, start_page=0, *args, **kwargs):
"""
解析文本类pdf
"""
pass
pdf_info_dict = parse_pdf_by_txt(
pdf_bytes,
pdf_models,
imageWriter,
start_page_id=start_page,
debug_mode=is_debug,
)
pdf_info_dict["parse_type"] = "txt"
return pdf_info_dict
def parse_ocr_pdf(pdf_bytes:bytes, pdf_models:list, imageWriter: AbsReaderWriter, is_debug=False, start_page=0, *args, **kwargs):
"""
解析ocr类pdf
"""
pass
pdf_info_dict = parse_pdf_by_ocr(
pdf_bytes,
pdf_models,
imageWriter,
start_page_id=start_page,
debug_mode=is_debug,
)
pdf_info_dict["_parse_type"] = "ocr"
return pdf_info_dict
def parse_union_pdf(pdf_bytes:bytes, pdf_models:list, imageWriter: AbsReaderWriter, is_debug=False, start_page=0, *args, **kwargs):
"""
ocr和文本混合的pdf,全部解析出来
"""
pass
\ No newline at end of file
def parse_pdf(method):
try:
return method(
pdf_bytes,
pdf_models,
imageWriter,
start_page_id=start_page,
debug_mode=is_debug,
)
except Exception as e:
logger.error(f"{method.__name__} error: {e}")
return None
pdf_info_dict = parse_pdf(parse_pdf_by_txt)
if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False):
logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr")
pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
if pdf_info_dict is None:
raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.")
else:
pdf_info_dict["_parse_type"] = "ocr"
else:
pdf_info_dict["_parse_type"] = "txt"
return pdf_info_dict
def spark_json_extractor(jso: dict) -> dict:
"""
从json中提取数据,返回一个dict
"""
return {
"_pdf_type": jso["_pdf_type"],
"model_list": jso["doc_layout_result"],
}
......@@ -16,3 +16,5 @@ en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_
zh_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/zh_core_web_sm-3.7.0/zh_core_web_sm-3.7.0-py3-none-any.whl
scikit-learn==1.4.1.post1
nltk==3.8.1
s3pathlib>=2.1.1
{
"accuracy": 1.0,
"precision": 1.0,
"recall": 1.0,
"f1_score": 1.0,
"pdf间的平均编辑距离": 19.82051282051282,
"pdf间的平均bleu": 0.9002485609584511,
"阅读顺序编辑距离": 0.3176895306859206,
"分段准确率": 0.8989169675090253,
"行内公式准确率": {
"accuracy": 0.9782741738066095,
"precision": 0.9782741738066095,
"recall": 1.0,
"f1_score": 0.9890177880897139
},
"行内公式编辑距离": 0.0,
"行内公式bleu": 0.20340450120213166,
"行间公式准确率": {
"accuracy": 1.0,
"precision": 1.0,
"recall": 1.0,
"f1_score": 1.0
},
"行间公式编辑距离": 0.0,
"行间公式bleu": 0.3662262622386575,
"丢弃文本准确率": {
"accuracy": 0.867870036101083,
"precision": 0.9064856711915535,
"recall": 0.9532117367168914,
"f1_score": 0.9292616930807885
},
"丢弃文本标签准确率": {
"color_background_header_txt_block": {
"precision": 0.0,
"recall": 0.0,
"f1-score": 0.0,
"support": 41.0
},
"rotate": {
"precision": 1.0,
"recall": 0.9682539682539683,
"f1-score": 0.9838709677419355,
"support": 63.0
},
"footnote": {
"precision": 1.0,
"recall": 0.883495145631068,
"f1-score": 0.9381443298969072,
"support": 103.0
},
"header": {
"precision": 1.0,
"recall": 1.0,
"f1-score": 1.0,
"support": 4.0
},
"on-image": {
"precision": 0.9947643979057592,
"recall": 1.0,
"f1-score": 0.9973753280839895,
"support": 380.0
},
"on-table": {
"precision": 1.0,
"recall": 0.9443609022556391,
"f1-score": 0.97138437741686,
"support": 665.0
},
"micro avg": {
"precision": 0.9982847341337907,
"recall": 0.9267515923566879,
"f1-score": 0.9611890999174236,
"support": 1256.0
}
},
"丢弃图片准确率": {
"accuracy": 0.8666666666666667,
"precision": 0.9285714285714286,
"recall": 0.9285714285714286,
"f1_score": 0.9285714285714286
},
"丢弃表格准确率": {
"accuracy": 0,
"precision": 0,
"recall": 0,
"f1_score": 0
}
}
\ No newline at end of file
No preview for this file type
......@@ -432,74 +432,7 @@ def handle_multi_deletion(test_page, test_page_tag, test_page_bbox, standard_pag
def check_json_files_in_zip_exist(zip_file_path, standard_json_path_in_zip, test_json_path_in_zip):
"""
检查ZIP文件中是否存在指定的JSON文件
"""
with zipfile.ZipFile(zip_file_path, 'r') as z:
# 获取ZIP文件中所有文件的列表
all_files_in_zip = z.namelist()
# 检查标准文件和测试文件是否都在ZIP文件中
if standard_json_path_in_zip not in all_files_in_zip or test_json_path_in_zip not in all_files_in_zip:
raise FileNotFoundError("One or both of the required JSON files are missing from the ZIP archive.")
def read_json_files_from_streams(standard_file_stream, test_file_stream):
"""
从文件流中读取JSON文件内容
"""
pdf_json_standard = [json.loads(line) for line in standard_file_stream]
pdf_json_test = [json.loads(line) for line in test_file_stream]
json_standard_origin = pd.DataFrame(pdf_json_standard)
json_test_origin = pd.DataFrame(pdf_json_test)
return json_standard_origin, json_test_origin
def read_json_files_from_zip(zip_file_path, standard_json_path_in_zip, test_json_path_in_zip):
"""
从ZIP文件中读取两个JSON文件并返回它们的DataFrame
"""
with zipfile.ZipFile(zip_file_path, 'r') as z:
with z.open(standard_json_path_in_zip) as standard_file_stream, \
z.open(test_json_path_in_zip) as test_file_stream:
standard_file_text_stream = TextIOWrapper(standard_file_stream, encoding='utf-8')
test_file_text_stream = TextIOWrapper(test_file_stream, encoding='utf-8')
json_standard_origin, json_test_origin = read_json_files_from_streams(
standard_file_text_stream, test_file_text_stream
)
return json_standard_origin, json_test_origin
def merge_json_data(json_test_df, json_standard_df):
"""
基于ID合并测试和标准数据集,并返回合并后的数据及存在性检查结果。
参数:
- json_test_df: 测试数据的DataFrame。
- json_standard_df: 标准数据的DataFrame。
返回:
- inner_merge: 内部合并的DataFrame,包含匹配的数据行。
- standard_exist: 标准数据存在性的Series。
- test_exist: 测试数据存在性的Series。
"""
test_data = json_test_df[['id', 'mid_json']].drop_duplicates(subset='id', keep='first').reset_index(drop=True)
standard_data = json_standard_df[['id', 'mid_json', 'pass_label']].drop_duplicates(subset='id', keep='first').reset_index(drop=True)
outer_merge = pd.merge(test_data, standard_data, on='id', how='outer')
outer_merge.columns = ['id', 'test_mid_json', 'standard_mid_json', 'pass_label']
standard_exist = outer_merge.standard_mid_json.notnull()
test_exist = outer_merge.test_mid_json.notnull()
inner_merge = pd.merge(test_data, standard_data, on='id', how='inner')
inner_merge.columns = ['id', 'test_mid_json', 'standard_mid_json', 'pass_label']
return inner_merge, standard_exist, test_exist
def consolidate_data(test_data, standard_data, key_path):
......@@ -533,6 +466,20 @@ def consolidate_data(test_data, standard_data, key_path):
return overall_data_standard, overall_data_test
def overall_calculate_metrics(inner_merge, json_test, json_standard,standard_exist, test_exist):
"""
计算整体的指标,包括准确率、精确率、召回率、F1值、平均编辑距离、平均BLEU得分、分段准确率、公式准确率、公式编辑距离、公式BLEU、丢弃文本准确率、丢弃文本标签准确率、丢弃图片准确率、丢弃表格准确率等。
Args:
inner_merge (dict): 包含merge信息的字典,包括pass_label和id等信息。
json_test (dict): 测试集的json数据。
json_standard (dict): 标准集的json数据。
standard_exist (list): 标准集中存在的id列表。
test_exist (list): 测试集中存在的id列表。
Returns:
dict: 包含整体指标值的字典。
"""
process_data_standard = process_equations_and_blocks(json_standard, is_standard=True)
process_data_test = process_equations_and_blocks(json_test, is_standard=False)
......@@ -739,9 +686,77 @@ def calculate_metrics(inner_merge, json_test, json_standard, json_standard_origi
return result_dict
def check_json_files_in_zip_exist(zip_file_path, standard_json_path_in_zip, test_json_path_in_zip):
"""
检查ZIP文件中是否存在指定的JSON文件
"""
with zipfile.ZipFile(zip_file_path, 'r') as z:
# 获取ZIP文件中所有文件的列表
all_files_in_zip = z.namelist()
# 检查标准文件和测试文件是否都在ZIP文件中
if standard_json_path_in_zip not in all_files_in_zip or test_json_path_in_zip not in all_files_in_zip:
raise FileNotFoundError("One or both of the required JSON files are missing from the ZIP archive.")
def read_json_files_from_streams(standard_file_stream, test_file_stream):
"""
从文件流中读取JSON文件内容
"""
pdf_json_standard = [json.loads(line) for line in standard_file_stream]
pdf_json_test = [json.loads(line) for line in test_file_stream]
json_standard_origin = pd.DataFrame(pdf_json_standard)
json_test_origin = pd.DataFrame(pdf_json_test)
return json_standard_origin, json_test_origin
def save_results(result_dict,overall_report_dict,badcase_path,overall_path,):
def read_json_files_from_zip(zip_file_path, standard_json_path_in_zip, test_json_path_in_zip):
"""
从ZIP文件中读取两个JSON文件并返回它们的DataFrame
"""
with zipfile.ZipFile(zip_file_path, 'r') as z:
with z.open(standard_json_path_in_zip) as standard_file_stream, \
z.open(test_json_path_in_zip) as test_file_stream:
standard_file_text_stream = TextIOWrapper(standard_file_stream, encoding='utf-8')
test_file_text_stream = TextIOWrapper(test_file_stream, encoding='utf-8')
json_standard_origin, json_test_origin = read_json_files_from_streams(
standard_file_text_stream, test_file_text_stream
)
return json_standard_origin, json_test_origin
def merge_json_data(json_test_df, json_standard_df):
"""
基于ID合并测试和标准数据集,并返回合并后的数据及存在性检查结果。
参数:
- json_test_df: 测试数据的DataFrame。
- json_standard_df: 标准数据的DataFrame。
返回:
- inner_merge: 内部合并的DataFrame,包含匹配的数据行。
- standard_exist: 标准数据存在性的Series。
- test_exist: 测试数据存在性的Series。
"""
test_data = json_test_df[['id', 'mid_json']].drop_duplicates(subset='id', keep='first').reset_index(drop=True)
standard_data = json_standard_df[['id', 'mid_json', 'pass_label']].drop_duplicates(subset='id', keep='first').reset_index(drop=True)
outer_merge = pd.merge(test_data, standard_data, on='id', how='outer')
outer_merge.columns = ['id', 'test_mid_json', 'standard_mid_json', 'pass_label']
standard_exist = outer_merge.standard_mid_json.notnull()
test_exist = outer_merge.test_mid_json.notnull()
inner_merge = pd.merge(test_data, standard_data, on='id', how='inner')
inner_merge.columns = ['id', 'test_mid_json', 'standard_mid_json', 'pass_label']
return inner_merge, standard_exist, test_exist
def save_results(result_dict,overall_report_dict,badcase_path,overall_path, s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url):
"""
将结果字典保存为JSON文件至指定路径。
......@@ -749,35 +764,46 @@ def save_results(result_dict,overall_report_dict,badcase_path,overall_path,):
- result_dict: 包含计算结果的字典。
- overall_path: 结果文件的保存路径,包括文件名。
"""
with open(overall_path, 'w', encoding='utf-8') as f:
# 将结果字典转换为JSON格式并写入文件
json.dump(overall_report_dict, f, ensure_ascii=False, indent=4)
final_overall_path = upload_to_s3(overall_path, s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url)
overall_path_res = "OCR抽取方案整体评测指标结果请查看:" + final_overall_path
print(f'\033[31m{overall_path_res}\033[0m')
# 打开指定的文件以写入
with open(badcase_path, 'w', encoding='utf-8') as f:
# 将结果字典转换为JSON格式并写入文件
json.dump(result_dict, f, ensure_ascii=False, indent=4)
final_badcase_path = upload_to_s3(badcase_path, s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url)
badcase_path_res = "OCR抽取方案评测badcase输出报告查看:" + final_badcase_path
print(f'\033[31m{badcase_path_res}\033[0m')
print(f"计算结果已经保存到文件:{badcase_path}")
with open(overall_path, 'w', encoding='utf-8') as f:
# 将结果字典转换为JSON格式并写入文件
json.dump(overall_report_dict, f, ensure_ascii=False, indent=4)
print(f"计算结果已经保存到文件:{overall_path}")
def upload_to_s3(file_path, bucket_name, s3_file_name,AWS_ACCESS_KEY,AWS_SECRET_KEY,END_POINT_URL):
def upload_to_s3(file_path, bucket_name, s3_directory, AWS_ACCESS_KEY, AWS_SECRET_KEY, END_POINT_URL):
"""
上传文件到Amazon S3
"""
s3 = boto3.client('s3',aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY,endpoint_url=END_POINT_URL)
# 创建S3客户端
s3 = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY, endpoint_url=END_POINT_URL)
try:
# 从文件路径中提取文件名
file_name = os.path.basename(file_path)
# 创建S3对象键,将s3_directory和file_name连接起来
s3_object_key = f"{s3_directory}/{file_name}" # 使用斜杠直接连接
# 上传文件到S3
s3.upload_file(file_path, bucket_name, s3_file_name)
print(f"文件 {s3_file_name} 成功上传到S3存储桶 {bucket_name} 中的路径 {file_path}")
s3.upload_file(file_path, bucket_name, s3_object_key)
s3_path = f"http://st.bigdata.shlab.tech/S3_Browser?output_path=s3://{bucket_name}/{s3_directory}/{file_name}"
return s3_path
#print(f"文件 {file_path} 成功上传到S3存储桶 {bucket_name} 中的目录 {s3_directory},文件名为 {file_name}")
except FileNotFoundError:
print(f"文件 {s3_file_name} 未找到,请检查文件路径是否正确。")
print(f"文件 {file_path} 未找到,请检查文件路径是否正确。")
except NoCredentialsError:
print("无法找到AWS凭证,请确认您的AWS访问密钥和密钥ID是否正确。")
except ClientError as e:
print(f"上传文件时发生错误:{e}")
def generate_filename(badcase_path,overall_path):
"""
生成带有当前时间戳的输出文件名。
......@@ -808,7 +834,8 @@ def compare_edit_distance(json_file, overall_report):
def main(standard_file, test_file, zip_file, badcase_path, overall_path,base_data_path,s3_bucket_name=None, s3_file_name=None, AWS_ACCESS_KEY=None, AWS_SECRET_KEY=None, END_POINT_URL=None):
def main(standard_file, test_file, zip_file, badcase_path, overall_path,base_data_path, s3_bucket_name=None, s3_file_directory=None,
aws_access_key=None, aws_secret_key=None, end_point_url=None):
"""
主函数,执行整个评估流程。
......@@ -819,7 +846,7 @@ def main(standard_file, test_file, zip_file, badcase_path, overall_path,base_dat
- badcase_path: badcase文件的基础路径和文件名前缀。
- overall_path: overall文件的基础路径和文件名前缀。
- s3_bucket_name: S3桶名称(可选)。
- s3_file_name: S3上的文件名(可选)。
- s3_file_directory: S3上的文件保存目录(可选)。
- AWS_ACCESS_KEY, AWS_SECRET_KEY, END_POINT_URL: AWS访问凭证和端点URL(可选)。
"""
# 检查文件是否存在
......@@ -840,10 +867,19 @@ def main(standard_file, test_file, zip_file, badcase_path, overall_path,base_dat
badcase_file,overall_file = generate_filename(badcase_path,overall_path)
# 保存结果到JSON文件
save_results(result_dict, overall_report_dict,badcase_file,overall_file)
#save_results(result_dict, overall_report_dict,badcase_file,overall_file)
save_results(result_dict, overall_report_dict,badcase_file,overall_file, s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url)
result=compare_edit_distance(base_data_path, overall_report_dict)
print(result)
"""
if all([s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url]):
try:
upload_to_s3(badcase_file, s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url)
upload_to_s3(overall_file, s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url)
except Exception as e:
print(f"上传到S3时发生错误: {e}")
"""
#print(result)
assert result == 1
if __name__ == "__main__":
......@@ -855,12 +891,12 @@ if __name__ == "__main__":
parser.add_argument('overall_path', type=str, help='overall文件的基础路径和文件名前缀。')
parser.add_argument('base_data_path', type=str, help='基准文件的基础路径和文件名前缀。')
parser.add_argument('--s3_bucket_name', type=str, help='S3桶名称。', default=None)
parser.add_argument('--s3_file_name', type=str, help='S3上的文件名。', default=None)
parser.add_argument('--s3_file_directory', type=str, help='S3上的文件名。', default=None)
parser.add_argument('--AWS_ACCESS_KEY', type=str, help='AWS访问密钥。', default=None)
parser.add_argument('--AWS_SECRET_KEY', type=str, help='AWS秘密密钥。', default=None)
parser.add_argument('--END_POINT_URL', type=str, help='AWS端点URL。', default=None)
args = parser.parse_args()
main(args.standard_file, args.test_file, args.zip_file, args.badcase_path,args.overall_path,args.base_data_path,args.s3_bucket_name, args.s3_file_name, args.AWS_ACCESS_KEY, args.AWS_SECRET_KEY, args.END_POINT_URL)
main(args.standard_file, args.test_file, args.zip_file, args.badcase_path,args.overall_path,args.base_data_path,args.s3_bucket_name, args.s3_file_directory, args.AWS_ACCESS_KEY, args.AWS_SECRET_KEY, args.END_POINT_URL)
import json
import pandas as pd
import numpy as np
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import argparse
import os
from sklearn.metrics import classification_report
from sklearn import metrics
from datetime import datetime
import boto3
from botocore.exceptions import NoCredentialsError, ClientError
from io import TextIOWrapper
import zipfile
def Levenshtein_Distance(str1, str2):
"""
计算并返回两个字符串之间的Levenshtein编辑距离。
参数:
- str1: 字符串,第一个比较字符串。
- str2: 字符串,第二个比较字符串。
返回:
- int: str1和str2之间的Levenshtein距离。
方法:
- 使用动态规划构建一个矩阵(matrix),其中matrix[i][j]表示str1的前i个字符和str2的前j个字符之间的Levenshtein距离。
- 矩阵的初始值设定为边界情况,即一个字符串与空字符串之间的距离。
- 遍历矩阵填充每个格子的值,根据字符是否相等选择插入、删除或替换操作的最小代价。
"""
# 初始化矩阵,大小为(len(str1)+1) x (len(str2)+1),边界情况下的距离为i和j
matrix = [[i + j for j in range(len(str2) + 1)] for i in range(len(str1) + 1)]
# 遍历str1和str2的每个字符,更新矩阵中的值
for i in range(1, len(str1) + 1):
for j in range(1, len(str2) + 1):
# 如果当前字符相等,替换代价为0;否则为1
d = 0 if (str1[i - 1] == str2[j - 1]) else 1
# 更新当前位置的值为从str1[i]转换到str2[j]的最小操作数
matrix[i][j] = min(matrix[i - 1][j] + 1, # 删除操作
matrix[i][j - 1] + 1, # 插入操作
matrix[i - 1][j - 1] + d) # 替换操作
# 返回右下角的值,即str1和str2之间的Levenshtein距离
return matrix[len(str1)][len(str2)]
def bbox_offset(b_t, b_s):
"""
判断两个边界框(bounding box)之间的重叠程度是否符合给定的标准。
参数:
- b_t: 测试文档中的边界框(bbox),格式为(x1, y1, x2, y2),
其中(x1, y1)是左上角的坐标,(x2, y2)是右下角的坐标。
- b_s: 标准文档中的边界框(bbox),格式同上。
返回:
- True: 如果两个边界框的重叠面积与两个边界框合计面积的差的比例超过0.95,
表明它们足够接近。
- False: 否则,表示两个边界框不足够接近。
注意:
- 函数首先计算两个bbox的交集区域,如果这个区域的面积相对于两个bbox的面积差非常大,
则认为这两个bbox足够接近。
- 如果交集区域的计算结果导致无效区域(比如宽度或高度为负值),或者分母为0(即两个bbox完全不重叠),
则函数会返回False。
"""
# 分别提取两个bbox的坐标
x1_t, y1_t, x2_t, y2_t = b_t
x1_s, y1_s, x2_s, y2_s = b_s
# 计算两个bbox交集区域的坐标
x1 = max(x1_t, x1_s)
x2 = min(x2_t, x2_s)
y1 = max(y1_t, y1_s)
y2 = min(y2_t, y2_s)
# 如果计算出的交集区域有效,则计算其面积
if x2 > x1 and y2 > y1:
area_overlap = (x2 - x1) * (y2 - y1)
else:
# 交集区域无效,视为无重叠
area_overlap = 0
# 计算两个bbox的总面积,减去重叠部分避免重复计算
area_t = (x2_t - x1_t) * (y2_t - y1_t) + (x2_s - x1_s) * (y2_s - y1_s) - area_overlap
# 判断重叠面积是否符合标准
# return area_overlap / total_area
if area_t-area_overlap==0 or area_overlap/area_t>0.95:
return True
else:
return False
def equations_indicator(test_equations_bboxs, standard_equations_bboxs, test_equations, standard_equations):
"""
根据边界框匹配的方程计算编辑距离和BLEU分数。
参数:
- test_equations_bboxs: 测试方程的边界框列表。
- standard_equations_bboxs: 标准方程的边界框列表。
- test_equations: 测试方程的列表。
- standard_equations: 标准方程的列表。
返回:
- 一个元组,包含匹配方程的平均Levenshtein编辑距离和BLEU分数。
"""
# 初始化匹配方程列表
test_match_equations = []
standard_match_equations = []
# 匹配方程基于边界框重叠
for index, (test_bbox, standard_bbox) in enumerate(zip(test_equations_bboxs, standard_equations_bboxs)):
if not (test_bbox and standard_bbox): # 跳过任一空列表
continue
for i, sb in enumerate(standard_bbox):
for j, tb in enumerate(test_bbox):
if bbox_offset(sb, tb):
standard_match_equations.append(standard_equations[index][i])
test_match_equations.append(test_equations[index][j])
break # 找到第一个匹配后即跳出循环
# 使用Levenshtein距离和BLEU分数计算编辑距离
dis = [Levenshtein_Distance(a, b) for a, b in zip(test_match_equations, standard_match_equations) if a and b]
# 应用平滑函数计算BLEU分数
sm_func = SmoothingFunction().method1
bleu = [sentence_bleu([a.split()], b.split(), smoothing_function=sm_func) for a, b in zip(test_match_equations, standard_match_equations) if a and b]
# 计算平均编辑距离和BLEU分数,处理空列表情况
equations_edit = np.mean(dis) if dis else float('0.0')
equations_bleu = np.mean(bleu) if bleu else float('0.0')
return equations_edit, equations_bleu
def bbox_match_indicator_general(test_bboxs_list, standard_bboxs_list):
"""
计算边界框匹配指标,支持掉落的表格、图像和文本块。
此版本的函数专注于计算基于边界框的匹配指标,而不涉及标签匹配逻辑。
参数:
- test_bboxs: 测试集的边界框列表,按页面组织。
- standard_bboxs: 标准集的边界框列表,按页面组织。
返回:
- 一个字典,包含准确度、精确度、召回率和F1分数。
"""
# 如果两个列表都完全为空,返回0值指标
if all(len(page) == 0 for page in test_bboxs_list) and all(len(page) == 0 for page in standard_bboxs_list):
return {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1_score': 0}
matched_bbox = []
matched_standard_bbox = []
for test_page, standard_page in zip(test_bboxs_list, standard_bboxs_list):
test_page_bbox, standard_page_bbox = [], []
for standard_bbox in standard_page:
if len(standard_bbox) != 4:
continue
matched = False
for test_bbox in test_page:
if len(test_bbox) == 4 and bbox_offset(standard_bbox, test_bbox):
matched = True
break
test_page_bbox.append(int(matched))
standard_page_bbox.append(1)
# 后处理以处理多删情况,保持原逻辑不变
diff_num = len(test_page) + test_page_bbox.count(0) - len(standard_page)
if diff_num > 0:
test_page_bbox.extend([1] * diff_num)
standard_page_bbox.extend([0] * diff_num)
matched_bbox.extend(test_page_bbox)
matched_standard_bbox.extend(standard_page_bbox)
block_report = {
'accuracy': metrics.accuracy_score(matched_standard_bbox, matched_bbox),
'precision': metrics.precision_score(matched_standard_bbox, matched_bbox, zero_division=0),
'recall': metrics.recall_score(matched_standard_bbox, matched_bbox, zero_division=0),
'f1_score': metrics.f1_score(matched_standard_bbox, matched_bbox, zero_division=0)
}
return block_report
def bbox_match_indicator_dropped_text_block(test_dropped_text_bboxs, standard_dropped_text_bboxs, standard_dropped_text_tag, test_dropped_text_tag):
"""
计算丢弃文本块的边界框匹配相关指标,包括准确率、精确率、召回率和F1分数,
同时也计算文本块标签的匹配指标。
参数:
- test_dropped_text_bboxs: 测试集的丢弃文本块边界框列表
- standard_dropped_text_bboxs: 标准集的丢弃文本块边界框列表
- standard_dropped_text_tag: 标准集的丢弃文本块标签列表
- test_dropped_text_tag: 测试集的丢弃文本块标签列表
返回:
- 一个包含边界框匹配指标和文本块标签匹配指标的元组
"""
test_text_bbox, standard_text_bbox = [], []
test_tag, standard_tag = [], []
for index, (test_page, standard_page) in enumerate(zip(test_dropped_text_bboxs, standard_dropped_text_bboxs)):
# 初始化每个页面的结果列表
test_page_tag, standard_page_tag = [], []
test_page_bbox, standard_page_bbox = [], []
for i, standard_bbox in enumerate(standard_page):
matched = False
for j, test_bbox in enumerate(test_page):
if bbox_offset(standard_bbox, test_bbox):
# 匹配成功,记录标签和边界框匹配结果
matched = True
test_page_tag.append(test_dropped_text_tag[index][j])
test_page_bbox.append(1)
break
if not matched:
# 未匹配,记录'None'和边界框未匹配结果
test_page_tag.append('None')
test_page_bbox.append(0)
# 标准边界框和标签总是被视为匹配的
standard_page_tag.append(standard_dropped_text_tag[index][i])
standard_page_bbox.append(1)
# 处理可能的多删情况
handle_multi_deletion(test_page, test_page_tag, test_page_bbox, standard_page_tag, standard_page_bbox)
# 合并当前页面的结果到整体结果中
test_tag.extend(test_page_tag)
standard_tag.extend(standard_page_tag)
test_text_bbox.extend(test_page_bbox)
standard_text_bbox.extend(standard_page_bbox)
# 计算和返回匹配指标
if not standard_text_bbox or not test_text_bbox:
# print("警告:边界框列表为空,跳过性能指标的计算。")
text_block_report = {
'accuracy': np.nan,
'precision': np.nan,
'recall': np.nan,
'f1_score': np.nan
}
else:
text_block_report = {
'accuracy': metrics.accuracy_score(standard_text_bbox, test_text_bbox),
'precision': metrics.precision_score(standard_text_bbox, test_text_bbox, zero_division=0),
'recall': metrics.recall_score(standard_text_bbox, test_text_bbox, zero_division=0),
'f1_score': metrics.f1_score(standard_text_bbox, test_text_bbox, zero_division=0)
}
# 对于classification_report,确保至少有一个非'None'标签存在
labels = list(set(standard_tag) - {'None'})
if labels:
text_block_tag_report = classification_report(y_true=standard_tag, y_pred=test_tag, labels=labels, output_dict=True, zero_division=0)
# 删除不需要的平均值报告,以简化输出
text_block_tag_report.pop("macro avg", None)
text_block_tag_report.pop("weighted avg", None)
else:
# print("警告:无有效标签进行匹配,跳过标签匹配指标的计算。")
text_block_tag_report = {}
return text_block_report, text_block_tag_report
def handle_multi_deletion(test_page, test_page_tag, test_page_bbox, standard_page_tag, standard_page_bbox):
"""
处理多删情况,即测试页面的边界框或标签数量多于标准页面。
"""
excess_count = len(test_page) + test_page_bbox.count(0) - len(standard_page_tag)
if excess_count > 0:
# 对于多出的项,将它们视为正确匹配的边界框,但标签视为'None'
test_page_bbox.extend([1] * excess_count)
standard_page_bbox.extend([0] * excess_count)
test_page_tag.extend(['None'] * excess_count)
standard_page_tag.extend(['None'] * excess_count)
def read_json_files(standard_file, test_file):
"""
读取JSON文件内容
"""
with open(standard_file, 'r', encoding='utf-8') as sf:
pdf_json_standard = [json.loads(line) for line in sf]
with open(test_file, 'r', encoding='utf-8') as tf:
pdf_json_test = [json.loads(line) for line in tf]
json_standard_origin = pd.DataFrame(pdf_json_standard)
json_test = pd.DataFrame(pdf_json_test)
return json_standard_origin, json_test
def merge_json_data(json_test_df, json_standard_df):
"""
基于ID合并测试和标准数据集,并返回合并后的数据及存在性检查结果。
参数:
- json_test_df: 测试数据的DataFrame。
- json_standard_df: 标准数据的DataFrame。
返回:
- inner_merge: 内部合并的DataFrame,包含匹配的数据行。
- standard_exist: 标准数据存在性的Series。
- test_exist: 测试数据存在性的Series。
"""
test_data = json_test_df[['id', 'mid_json']].drop_duplicates(subset='id', keep='first').reset_index(drop=True)
standard_data = json_standard_df[['id', 'mid_json', 'pass_label']].drop_duplicates(subset='id', keep='first').reset_index(drop=True)
outer_merge = pd.merge(test_data, standard_data, on='id', how='outer')
outer_merge.columns = ['id', 'test_mid_json', 'standard_mid_json', 'pass_label']
standard_exist = outer_merge.standard_mid_json.notnull()
test_exist = outer_merge.test_mid_json.notnull()
inner_merge = pd.merge(test_data, standard_data, on='id', how='inner')
inner_merge.columns = ['id', 'test_mid_json', 'standard_mid_json', 'pass_label']
return inner_merge, standard_exist, test_exist
def process_equations_and_blocks(json_data):
"""
处理JSON数据,提取公式、文本块、图片块和表格块的边界框和文本信息。
参数:
- json_data: 列表,包含标准文档或测试文档的JSON数据。
返回:
- 字典,包含处理后的数据。
"""
equations_bboxs = {"inline": [], "interline": []}
equations_texts = {"inline": [], "interline": []}
dropped_bboxs = {"text": [], "image": [], "table": []}
dropped_tags = {"text": []}
para_texts = []
para_nums = []
preproc_nums = []
for i in json_data:
mid_json = pd.DataFrame(i).iloc[:,:-1]
page_data = {
"equations_bboxs_list": {"inline": [], "interline": []},
"equations_texts_list": {"inline": [], "interline": []},
"dropped_bboxs_list": {"text": [], "image": [], "table": []},
"dropped_tags_list": {"text": []},
"para_texts_list": [],
"para_nums_list": [],
"preproc_nums_list":[]
}
for eq_type in ["inline", "interline"]:
for equations in mid_json.loc[f"{eq_type}_equations", :]:
bboxs = [eq['bbox'] for eq in equations]
texts = [eq['latex_text'] for eq in equations]
page_data["equations_bboxs_list"][eq_type].append(bboxs)
page_data["equations_texts_list"][eq_type].append(texts)
equations_bboxs["inline"].append(page_data["equations_bboxs_list"]["inline"])
equations_bboxs["interline"].append(page_data["equations_bboxs_list"]["interline"])
equations_texts["inline"].append(page_data["equations_texts_list"]["inline"])
equations_texts["interline"].append(page_data["equations_texts_list"]["interline"])
# 提取丢弃的文本块信息
for dropped_text_blocks in mid_json.loc['droped_text_block',:]:
bboxs, tags = [], []
for block in dropped_text_blocks:
bboxs.append(block['bbox'])
tags.append(block.get('tag', 'None'))
page_data["dropped_bboxs_list"]["text"].append(bboxs)
page_data["dropped_tags_list"]["text"].append(tags)
dropped_bboxs["text"].append(page_data["dropped_bboxs_list"]["text"])
dropped_tags["text"].append(page_data["dropped_tags_list"]["text"])
# 同时处理删除的图片块和表格块
for block_type in ['image', 'table']:
# page_blocks_list = []
for blocks in mid_json.loc[f'droped_{block_type}_block', :]:
# 如果是标准数据,直接添加整个块的列表
page_data["dropped_bboxs_list"][block_type].append(blocks)
# 将当前页面的块边界框列表添加到结果字典中
dropped_bboxs['image'].append(page_data["dropped_bboxs_list"]['image'])
dropped_bboxs['table'].append(page_data["dropped_bboxs_list"]['table'])
# 处理段落
for para_blocks in mid_json.loc['para_blocks', :]:
page_data["para_nums_list"].append(len(para_blocks)) # 计算段落数
for para_block in para_blocks:
page_data["para_texts_list"].append(para_block['text'])
for preproc_blocks in mid_json.loc['preproc_blocks', :]:
numbers=[]
for preproc_block in preproc_blocks:
numbers.append(preproc_block['number'])
page_data["preproc_nums_list"].append(numbers)
para_texts.append(page_data["para_texts_list"])
para_nums.append(page_data["para_nums_list"])
preproc_nums.append(page_data["preproc_nums_list"])
return {
"equations_bboxs": equations_bboxs,
"equations_texts": equations_texts,
"dropped_bboxs": dropped_bboxs,
"dropped_tags": dropped_tags,
"para_texts": para_texts,
"para_nums": para_nums,
"preproc_nums": preproc_nums
}
def consolidate_data(test_data, standard_data, key_path):
"""
Consolidates data from test and standard datasets based on the provided key path.
:param test_data: Dictionary containing the test dataset.
:param standard_data: Dictionary containing the standard dataset.
:param key_path: List of keys leading to the desired data within the dictionaries.
:return: List containing all items from both test and standard data at the specified key path.
"""
# Initialize an empty list to hold the consolidated data
overall_data_standard = []
overall_data_test = []
# Helper function to recursively navigate through the dictionaries based on the key path
def extract_data(source_data, keys):
for key in keys[:-1]:
source_data = source_data.get(key, {})
return source_data.get(keys[-1], [])
for data in extract_data(standard_data, key_path):
# 假设每个 single_table_tags 已经是一个列表,直接将它的元素添加到总列表中
overall_data_standard.extend(data)
for data in extract_data(test_data, key_path):
overall_data_test.extend(data)
# Extract and extend the overall data list with items from both test and standard datasets
return overall_data_standard, overall_data_test
def calculate_metrics(inner_merge, json_test, json_standard, json_standard_origin):
"""
计算指标
"""
# 创建ID到file_id的映射
id_to_file_id_map = pd.Series(json_standard_origin.file_id.values, index=json_standard_origin.id).to_dict()
# 处理标准数据和测试数据
process_data_standard = process_equations_and_blocks(json_standard)
process_data_test = process_equations_and_blocks(json_test)
# 从inner_merge中筛选出pass_label为'yes'的数据
test_para_text = np.asarray(process_data_test['para_texts'], dtype=object)[inner_merge['pass_label'] == 'yes']
standard_para_text = np.asarray(process_data_standard['para_texts'], dtype=object)[inner_merge['pass_label'] == 'yes']
ids_yes = inner_merge['id'][inner_merge['pass_label'] == 'yes'].tolist()
pdf_dis = {}
pdf_bleu = {}
# 对pass_label为'yes'的数据计算编辑距离和BLEU得分
for idx, (a, b, id) in enumerate(zip(test_para_text, standard_para_text, ids_yes)):
a1 = ''.join(a)
b1 = ''.join(b)
pdf_dis[id] = Levenshtein_Distance(a, b)
pdf_bleu[id] = sentence_bleu([a1], b1)
result_dict = {}
acc_para=[]
# 对所有数据计算其他指标
for index, id_value in enumerate(inner_merge['id'].tolist()):
result = {}
# 增加file_id到结果中
file_id = id_to_file_id_map.get(id_value, "Unknown")
result['file_id'] = file_id
# 根据id判断是否需要计算pdf_dis和pdf_bleu
if id_value in ids_yes:
result['pdf_dis'] = pdf_dis[id_value]
result['pdf_bleu'] = pdf_bleu[id_value]
# 阅读顺序编辑距离的均值
preproc_num_dis=[]
for a,b in zip(process_data_test['preproc_nums'][index],process_data_standard['preproc_nums'][index]):
preproc_num_dis.append(Levenshtein_Distance(a,b))
result['阅读顺序编辑距离']=np.mean(preproc_num_dis)
# 计算分段准确率
single_test_para_num = np.array(process_data_test['para_nums'][index])
single_standard_para_num = np.array(process_data_standard['para_nums'][index])
acc_para.append(np.mean(single_test_para_num == single_standard_para_num))
result['分段准确率'] = acc_para[index]
# 行内公式准确率和编辑距离、bleu
result['行内公式准确率'] = bbox_match_indicator_general(
process_data_test["equations_bboxs"]["inline"][index],
process_data_standard["equations_bboxs"]["inline"][index])
result['行内公式编辑距离'], result['行内公式bleu'] = equations_indicator(
process_data_test["equations_bboxs"]["inline"][index],
process_data_standard["equations_bboxs"]["inline"][index],
process_data_test["equations_texts"]["inline"][index],
process_data_standard["equations_texts"]["inline"][index])
# 行间公式准确率和编辑距离、bleu
result['行间公式准确率'] = bbox_match_indicator_general(
process_data_test["equations_bboxs"]["interline"][index],
process_data_standard["equations_bboxs"]["interline"][index])
result['行间公式编辑距离'], result['行间公式bleu'] = equations_indicator(
process_data_test["equations_bboxs"]["interline"][index],
process_data_standard["equations_bboxs"]["interline"][index],
process_data_test["equations_texts"]["interline"][index],
process_data_standard["equations_texts"]["interline"][index])
# 丢弃文本准确率,丢弃文本标签准确率
result['丢弃文本准确率'], result['丢弃文本标签准确率'] = bbox_match_indicator_dropped_text_block(
process_data_test["dropped_bboxs"]["text"][index],
process_data_standard["dropped_bboxs"]["text"][index],
process_data_standard["dropped_tags"]["text"][index],
process_data_test["dropped_tags"]["text"][index])
# 丢弃图片准确率
result['丢弃图片准确率'] = bbox_match_indicator_general(
process_data_test["dropped_bboxs"]["image"][index],
process_data_standard["dropped_bboxs"]["image"][index])
# 丢弃表格准确率
result['丢弃表格准确率'] = bbox_match_indicator_general(
process_data_test["dropped_bboxs"]["table"][index],
process_data_standard["dropped_bboxs"]["table"][index])
# 将结果存入result_dict
result_dict[id_value] = result
return result_dict
def overall_calculate_metrics(inner_merge, json_test, json_standard,standard_exist, test_exist):
"""
计算整体指标:包括准确性、精确度、召回率、F1分数以及不同方面的详细指标。
参数:
- inner_merge: 合并后的内部数据,包含测试和标准数据的合并结果。
- json_test: 测试数据的JSON格式。
- json_standard: 标准数据的JSON格式。
- standard_exist: 标准存在的标签数据。
- test_exist: 测试存在的标签数据。
返回值:
- overall_report: 包含各种指标的字典。
"""
# 处理标准数据和测试数据,提取方程式和块
process_data_standard = process_equations_and_blocks(json_standard)
process_data_test = process_equations_and_blocks(json_test)
# 初始化整体报告,并计算基础指标
overall_report = {}
overall_report['accuracy']=metrics.accuracy_score(standard_exist,test_exist)
overall_report['precision']=metrics.precision_score(standard_exist,test_exist)
overall_report['recall']=metrics.recall_score(standard_exist,test_exist)
overall_report['f1_score']=metrics.f1_score(standard_exist,test_exist)
overall_report
# 提取通过标签的数据,并计算编辑距离和BLEU得分
test_para_text = np.asarray(process_data_test['para_texts'], dtype=object)[inner_merge['pass_label'] == 'yes']
standard_para_text = np.asarray(process_data_standard['para_texts'], dtype=object)[inner_merge['pass_label'] == 'yes']
ids_yes = inner_merge['id'][inner_merge['pass_label'] == 'yes'].tolist()
pdf_dis = {}
pdf_bleu = {}
for idx,(a, b, id) in enumerate(zip(test_para_text, standard_para_text, ids_yes)):
a1 = ''.join(a)
b1 = ''.join(b)
pdf_dis[id] = Levenshtein_Distance(a, b)
pdf_bleu[id] = sentence_bleu([a1], b1)
overall_report['pdf间的平均编辑距离'] = np.mean(list(pdf_dis.values()))
overall_report['pdf间的平均bleu'] = np.mean(list(pdf_bleu.values()))
# 合并数据中的方程式bbox和inline数据
overall_equations_bboxs_inline_standard,overall_equations_bboxs_inline_test = consolidate_data(process_data_test, process_data_standard, ["equations_bboxs", "inline"])
# 合并数据中的方程式文本和inline数据
overall_equations_texts_inline_standard,overall_equations_texts_inline_test = consolidate_data(process_data_test, process_data_standard, ["equations_texts", "inline"])
# 合并数据中的方程式bbox和interline数据
overall_equations_bboxs_interline_standard,overall_equations_bboxs_interline_test = consolidate_data(process_data_test, process_data_standard, ["equations_bboxs", "interline"])
# 合并数据中的方程式文本和interline数据
overall_equations_texts_interline_standard,overall_equations_texts_interline_test = consolidate_data(process_data_test, process_data_standard, ["equations_texts", "interline"])
# 合并丢弃的bbox和text数据
overall_dropped_bboxs_text_standard,overall_dropped_bboxs_text_test = consolidate_data(process_data_test, process_data_standard, ["dropped_bboxs","text"])
# 合并丢弃的tags和text数据
overall_dropped_tags_text_standard,overall_dropped_tags_text_test = consolidate_data(process_data_test, process_data_standard, ["dropped_tags","text"])
# 合并丢弃的bbox和image数据
overall_dropped_bboxs_image_standard,overall_dropped_bboxs_image_test = consolidate_data(process_data_test, process_data_standard, ["dropped_bboxs","image"])
# 合并丢弃的bbox和table数据
overall_dropped_bboxs_table_standard,overall_dropped_bboxs_table_test=consolidate_data(process_data_test, process_data_standard,["dropped_bboxs","table"])
# 合并阅读顺序的编辑距离
overall_preproc_standard,overall_preproc_test = consolidate_data(process_data_test, process_data_standard, ["preproc_nums"])
# 计算测试和标准数据的段落数量
para_nums_test = process_data_test['para_nums']
para_nums_standard=process_data_standard['para_nums']
overall_para_nums_standard = [item for sublist in para_nums_standard for item in (sublist if isinstance(sublist, list) else [sublist])]
overall_para_nums_test = [item for sublist in para_nums_test for item in (sublist if isinstance(sublist, list) else [sublist])]
preproc_num_dis=[]
for a,b in zip(overall_preproc_standard,overall_preproc_test):
preproc_num_dis.append(Levenshtein_Distance(a,b))
overall_report['阅读顺序编辑距离']=np.mean(preproc_num_dis)
# 计算段落匹配准确性
test_para_num=np.array(overall_para_nums_test)
standard_para_num=np.array(overall_para_nums_standard)
acc_para=np.mean(test_para_num==standard_para_num)
overall_report['分段准确率'] = acc_para
# 计算并更新报告中的各种指标
overall_report['行内公式准确率'] = bbox_match_indicator_general(
overall_equations_bboxs_inline_test,
overall_equations_bboxs_inline_standard)
overall_report['行内公式编辑距离'], overall_report['行内公式bleu'] = equations_indicator(
overall_equations_bboxs_inline_test,
overall_equations_bboxs_inline_standard,
overall_equations_texts_inline_test,
overall_equations_texts_inline_standard)
overall_report['行间公式准确率'] = bbox_match_indicator_general(
overall_equations_bboxs_interline_test,
overall_equations_bboxs_interline_standard)
overall_report['行间公式编辑距离'], overall_report['行间公式bleu'] = equations_indicator(
overall_equations_bboxs_interline_test,
overall_equations_bboxs_interline_standard,
overall_equations_texts_interline_test,
overall_equations_texts_interline_standard)
overall_report['丢弃文本准确率'], overall_report['丢弃文本标签准确率'] = bbox_match_indicator_dropped_text_block(
overall_dropped_bboxs_text_test,
overall_dropped_bboxs_text_standard,
overall_dropped_tags_text_standard,
overall_dropped_tags_text_test)
overall_report['丢弃图片准确率'] = bbox_match_indicator_general(
overall_dropped_bboxs_image_test,
overall_dropped_bboxs_image_standard)
overall_report['丢弃表格准确率'] = bbox_match_indicator_general(
overall_dropped_bboxs_table_test,
overall_dropped_bboxs_table_standard)
return overall_report
def check_json_files_in_zip_exist(zip_file_path, standard_json_path_in_zip, test_json_path_in_zip):
"""
检查ZIP文件中是否存在指定的JSON文件
"""
with zipfile.ZipFile(zip_file_path, 'r') as z:
# 获取ZIP文件中所有文件的列表
all_files_in_zip = z.namelist()
# 检查标准文件和测试文件是否都在ZIP文件中
if standard_json_path_in_zip not in all_files_in_zip or test_json_path_in_zip not in all_files_in_zip:
raise FileNotFoundError("One or both of the required JSON files are missing from the ZIP archive.")
def read_json_files_from_streams(standard_file_stream, test_file_stream):
"""
从文件流中读取JSON文件内容
"""
pdf_json_standard = [json.loads(line) for line in standard_file_stream]
pdf_json_test = [json.loads(line) for line in test_file_stream]
json_standard_origin = pd.DataFrame(pdf_json_standard)
json_test_origin = pd.DataFrame(pdf_json_test)
return json_standard_origin, json_test_origin
def read_json_files_from_zip(zip_file_path, standard_json_path_in_zip, test_json_path_in_zip):
"""
从ZIP文件中读取两个JSON文件并返回它们的DataFrame
"""
with zipfile.ZipFile(zip_file_path, 'r') as z:
with z.open(standard_json_path_in_zip) as standard_file_stream, \
z.open(test_json_path_in_zip) as test_file_stream:
standard_file_text_stream = TextIOWrapper(standard_file_stream, encoding='utf-8')
test_file_text_stream = TextIOWrapper(test_file_stream, encoding='utf-8')
json_standard_origin, json_test_origin = read_json_files_from_streams(
standard_file_text_stream, test_file_text_stream
)
return json_standard_origin, json_test_origin
def merge_json_data(json_test_df, json_standard_df):
"""
基于ID合并测试和标准数据集,并返回合并后的数据及存在性检查结果。
参数:
- json_test_df: 测试数据的DataFrame。
- json_standard_df: 标准数据的DataFrame。
返回:
- inner_merge: 内部合并的DataFrame,包含匹配的数据行。
- standard_exist: 标准数据存在性的Series。
- test_exist: 测试数据存在性的Series。
"""
test_data = json_test_df[['id', 'mid_json']].drop_duplicates(subset='id', keep='first').reset_index(drop=True)
standard_data = json_standard_df[['id', 'mid_json', 'pass_label']].drop_duplicates(subset='id', keep='first').reset_index(drop=True)
outer_merge = pd.merge(test_data, standard_data, on='id', how='outer')
outer_merge.columns = ['id', 'test_mid_json', 'standard_mid_json', 'pass_label']
standard_exist = outer_merge.standard_mid_json.notnull()
test_exist = outer_merge.test_mid_json.notnull()
inner_merge = pd.merge(test_data, standard_data, on='id', how='inner')
inner_merge.columns = ['id', 'test_mid_json', 'standard_mid_json', 'pass_label']
return inner_merge, standard_exist, test_exist
def save_results(result_dict,overall_report_dict,badcase_path,overall_path, s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url):
"""
将结果字典保存为JSON文件至指定路径。
参数:
- result_dict: 包含计算结果的字典。
- overall_path: 结果文件的保存路径,包括文件名。
"""
with open(overall_path, 'w', encoding='utf-8') as f:
# 将结果字典转换为JSON格式并写入文件
json.dump(overall_report_dict, f, ensure_ascii=False, indent=4)
final_overall_path = upload_to_s3(overall_path, s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url)
overall_path_res = "文本型PDF抽取方案整体评测指标结果请查看:" + final_overall_path
print(f'\033[31m{overall_path_res}\033[0m')
# 打开指定的文件以写入
with open(badcase_path, 'w', encoding='utf-8') as f:
# 将结果字典转换为JSON格式并写入文件
json.dump(result_dict, f, ensure_ascii=False, indent=4)
final_badcase_path = upload_to_s3(badcase_path, s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url)
badcase_path_res = "文本型PDF抽取方案评测badcase输出报告查看:" + final_badcase_path
print(f'\033[31m{badcase_path_res}\033[0m')
def upload_to_s3(file_path, bucket_name, s3_directory, AWS_ACCESS_KEY, AWS_SECRET_KEY, END_POINT_URL):
"""
上传文件到Amazon S3
"""
# 创建S3客户端
s3 = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY, endpoint_url=END_POINT_URL)
try:
# 从文件路径中提取文件名
file_name = os.path.basename(file_path)
# 创建S3对象键,将s3_directory和file_name连接起来
s3_object_key = f"{s3_directory}/{file_name}" # 使用斜杠直接连接
# 上传文件到S3
s3.upload_file(file_path, bucket_name, s3_object_key)
s3_path = f"http://st.bigdata.shlab.tech/S3_Browser?output_path=s3://{bucket_name}/{s3_directory}/{file_name}"
return s3_path
#print(f"文件 {file_path} 成功上传到S3存储桶 {bucket_name} 中的目录 {s3_directory},文件名为 {file_name}")
except FileNotFoundError:
print(f"文件 {file_path} 未找到,请检查文件路径是否正确。")
except NoCredentialsError:
print("无法找到AWS凭证,请确认您的AWS访问密钥和密钥ID是否正确。")
except ClientError as e:
print(f"上传文件时发生错误:{e}")
def generate_filename(badcase_path,overall_path):
"""
生成带有当前时间戳的输出文件名。
参数:
- base_path: 基础路径和文件名前缀。
返回:
- 带有当前时间戳的完整输出文件名。
"""
# 获取当前时间并格式化为字符串
current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
# 构建并返回完整的输出文件名
return f"{badcase_path}_{current_time}.json",f"{overall_path}_{current_time}.json"
def compare_edit_distance(json_file, overall_report):
with open(json_file, 'r',encoding='utf-8') as f:
json_data = json.load(f)
json_edit_distance = json_data['pdf间的平均编辑距离']
if overall_report['pdf间的平均编辑距离'] > json_edit_distance:
return 0
else:
return 1
def main(standard_file, test_file, zip_file, badcase_path, overall_path,base_data_path, s3_bucket_name=None, s3_file_directory=None,
aws_access_key=None, aws_secret_key=None, end_point_url=None):
"""
主函数,执行整个评估流程。
参数:
- standard_file: 标准文件的路径。
- test_file: 测试文件的路径。
- zip_file: 压缩包的路径的路径。
- badcase_path: badcase文件的基础路径和文件名前缀。
- overall_path: overall文件的基础路径和文件名前缀。
- s3_bucket_name: S3桶名称(可选)。
- s3_file_directory: S3上的文件保存目录(可选)。
- AWS_ACCESS_KEY, AWS_SECRET_KEY, END_POINT_URL: AWS访问凭证和端点URL(可选)。
"""
# 检查文件是否存在
check_json_files_in_zip_exist(zip_file, standard_file, test_file)
# 读取JSON文件内容
json_standard_origin, json_test_origin = read_json_files_from_zip(zip_file, standard_file, test_file)
# 合并JSON数据
inner_merge, standard_exist, test_exist = merge_json_data(json_test_origin, json_standard_origin)
#计算总体指标
overall_report_dict=overall_calculate_metrics(inner_merge, inner_merge['test_mid_json'], inner_merge['standard_mid_json'],standard_exist, test_exist)
# 计算指标
result_dict = calculate_metrics(inner_merge, inner_merge['test_mid_json'], inner_merge['standard_mid_json'], json_standard_origin)
# 生成带时间戳的输出文件名
badcase_file,overall_file = generate_filename(badcase_path,overall_path)
# 保存结果到JSON文件
save_results(result_dict, overall_report_dict,badcase_file,overall_file, s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url)
result=compare_edit_distance(base_data_path, overall_report_dict)
"""
if all([s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url]):
try:
upload_to_s3(badcase_file, s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url)
upload_to_s3(overall_file, s3_bucket_name, s3_file_directory, aws_access_key, aws_secret_key, end_point_url)
except Exception as e:
print(f"上传到S3时发生错误: {e}")
"""
assert result == 1
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="主函数,执行整个评估流程。")
parser.add_argument('standard_file', type=str, help='标准文件的路径。')
parser.add_argument('test_file', type=str, help='测试文件的路径。')
parser.add_argument('zip_file', type=str, help='压缩包的路径。')
parser.add_argument('badcase_path', type=str, help='badcase文件的基础路径和文件名前缀。')
parser.add_argument('overall_path', type=str, help='overall文件的基础路径和文件名前缀。')
parser.add_argument('base_data_path', type=str, help='基准文件的基础路径和文件名前缀。')
parser.add_argument('--s3_bucket_name', type=str, help='S3桶名称。', default=None)
parser.add_argument('--s3_file_directory', type=str, help='S3上的文件名。', default=None)
parser.add_argument('--AWS_ACCESS_KEY', type=str, help='AWS访问密钥。', default=None)
parser.add_argument('--AWS_SECRET_KEY', type=str, help='AWS秘密密钥。', default=None)
parser.add_argument('--END_POINT_URL', type=str, help='AWS端点URL。', default=None)
args = parser.parse_args()
main(args.standard_file, args.test_file, args.zip_file, args.badcase_path,args.overall_path,args.base_data_path,args.s3_bucket_name, args.s3_file_directory, args.AWS_ACCESS_KEY, args.AWS_SECRET_KEY, args.END_POINT_URL)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment