Commit 23bacc60 authored by Shuimo's avatar Shuimo

add an option to freely output 'badcase.json'

parents d1457937 4191fa96
...@@ -40,15 +40,20 @@ jobs: ...@@ -40,15 +40,20 @@ jobs:
pip install -r requirements.txt pip install -r requirements.txt
fi fi
- name: config-net-reset
- name: benchmark run: |
export http_proxy=""
export https_proxy=""
- name: get-benchmark-result
run: | run: |
echo "start test" echo "start test"
cd tools && python ocr_badcase.py pdf_json_label_0306.json ocr_dataset.json json_files.zip badcase.json overall.json base_data.json cd tools && python text_badcase.py pdf_json_label_0306.json pdf_json_label_0229.json json_files.zip text_badcase text_overall base_data_text.json --s3_bucket_name llm-process-pperf --s3_file_directory qa-validate/pdf-datasets/badcase --AWS_ACCESS_KEY 7X9CWNHIVOHH3LXRD5WK --AWS_SECRET_KEY IHLyTsv7h4ArzReLWUGZNKvwqB7CMrRi6e7ZyUt0 --END_POINT_URL http://p-ceph-norm-inside.pjlab.org.cn:80
python ocr_badcase.py pdf_json_label_0306.json ocr_dataset.json json_files.zip ocr_badcase ocr_overall base_data_ocr.json --s3_bucket_name llm-process-pperf --s3_file_directory qa-validate/pdf-datasets/badcase --AWS_ACCESS_KEY 7X9CWNHIVOHH3LXRD5WK --AWS_SECRET_KEY IHLyTsv7h4ArzReLWUGZNKvwqB7CMrRi6e7ZyUt0 --END_POINT_URL http://p-ceph-norm-inside.pjlab.org.cn:80
notify_to_feishu: notify_to_feishu:
if: ${{ always() && !cancelled() && contains(needs.*.result, 'failure') && (github.ref_name == 'master') }} if: ${{ always() && !cancelled() && contains(needs.*.result, 'failure') && (github.ref_name == 'master') }}
needs: [pdf-test] needs: [pdf-test]
runs-on: [pdf] runs-on: pdf
steps: steps:
- name: notify - name: notify
run: | run: |
......
...@@ -22,15 +22,15 @@ git clone https://github.com/magicpdf/Magic-PDF.git ...@@ -22,15 +22,15 @@ git clone https://github.com/magicpdf/Magic-PDF.git
2.Install the requirements 2.Install the requirements
```sh ```sh
cd Magic-PDF
pip install -r requirements.txt pip install -r requirements.txt
``` ```
3.Run the main script 3.Run the command line
```sh ```sh
use demo/text_demo.py export PYTHONPATH=.
or python magic_pdf/cli/magicpdf.py --help
use demo/ocr_demo.py
``` ```
### 版权说明 ### 版权说明
......
...@@ -15,7 +15,7 @@ from loguru import logger ...@@ -15,7 +15,7 @@ from loguru import logger
from magic_pdf.libs.config_reader import get_s3_config_dict from magic_pdf.libs.config_reader import get_s3_config_dict
from magic_pdf.pdf_parse_by_txt import parse_pdf_by_txt from magic_pdf.pdf_parse_by_txt import parse_pdf_by_txt
from magic_pdf.spark.base import get_data_source from magic_pdf.spark.spark_api import get_data_source
def demo_parse_pdf(book_name=None, start_page_id=0, debug_mode=True): def demo_parse_pdf(book_name=None, start_page_id=0, debug_mode=True):
...@@ -67,9 +67,7 @@ def demo_classify_by_type(book_name=None, debug_mode=True): ...@@ -67,9 +67,7 @@ def demo_classify_by_type(book_name=None, debug_mode=True):
img_num_list = pdf_meta["imgs_per_page"] img_num_list = pdf_meta["imgs_per_page"]
text_len_list = pdf_meta["text_len_per_page"] text_len_list = pdf_meta["text_len_per_page"]
text_layout_list = pdf_meta["text_layout_per_page"] text_layout_list = pdf_meta["text_layout_per_page"]
pdf_path = json_object.get("file_location")
is_text_pdf, results = classify( is_text_pdf, results = classify(
pdf_path,
total_page, total_page,
page_width, page_width,
page_height, page_height,
...@@ -89,7 +87,7 @@ def demo_meta_scan(book_name=None, debug_mode=True): ...@@ -89,7 +87,7 @@ def demo_meta_scan(book_name=None, debug_mode=True):
s3_pdf_path = json_object.get("file_location") s3_pdf_path = json_object.get("file_location")
s3_config = get_s3_config_dict(s3_pdf_path) s3_config = get_s3_config_dict(s3_pdf_path)
pdf_bytes = read_file(s3_pdf_path, s3_config) pdf_bytes = read_file(s3_pdf_path, s3_config)
res = pdf_meta_scan(s3_pdf_path, pdf_bytes) res = pdf_meta_scan(pdf_bytes)
logger.info(json.dumps(res, ensure_ascii=False)) logger.info(json.dumps(res, ensure_ascii=False))
write_json_to_local(res, book_name) write_json_to_local(res, book_name)
......
...@@ -21,28 +21,175 @@ python magicpdf.py --json s3://llm-pdf-text/scihub/xxxx.json?bytes=0,81350 ...@@ -21,28 +21,175 @@ python magicpdf.py --json s3://llm-pdf-text/scihub/xxxx.json?bytes=0,81350
python magicpdf.py --pdf /home/llm/Downloads/xxxx.pdf --model /home/llm/Downloads/xxxx.json 或者 python magicpdf.py --pdf /home/llm/Downloads/xxxx.pdf python magicpdf.py --pdf /home/llm/Downloads/xxxx.pdf --model /home/llm/Downloads/xxxx.json 或者 python magicpdf.py --pdf /home/llm/Downloads/xxxx.pdf
""" """
import os
import json as json_parse
import click
from loguru import logger
from pathlib import Path
from magic_pdf.pipe.UNIPipe import UNIPipe
from magic_pdf.pipe.OCRPipe import OCRPipe
from magic_pdf.pipe.TXTPipe import TXTPipe
from magic_pdf.libs.config_reader import get_s3_config
from magic_pdf.libs.path_utils import (
parse_s3path,
parse_s3_range_params,
remove_non_official_s3_args,
)
from magic_pdf.libs.config_reader import get_local_dir
from magic_pdf.rw.S3ReaderWriter import S3ReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
parse_pdf_methods = click.Choice(["ocr", "txt", "auto"])
def prepare_env(pdf_file_name):
    """Create (if missing) and return the local output directories for one PDF.

    Returns:
        tuple[str, str]: ``(image_dir, md_dir)`` — the images subdirectory and
        its parent markdown directory under ``<local_dir>/magic-pdf/<name>``.
    """
    md_dir = os.path.join(get_local_dir(), "magic-pdf", pdf_file_name)
    image_dir = os.path.join(md_dir, "images")
    for directory in (image_dir, md_dir):
        os.makedirs(directory, exist_ok=True)
    return image_dir, md_dir
def _do_parse(pdf_file_name, pdf_bytes, model_list, parse_method, image_writer, md_writer, image_dir):
    """Run the selected parsing pipeline on a PDF and write markdown + JSON output.

    Args:
        pdf_file_name: base name (no extension) for the output files.
        pdf_bytes: raw bytes of the PDF document.
        model_list: pre-computed layout-model output for the document.
        parse_method: one of "auto", "txt", "ocr" (see parse_pdf_methods).
        image_writer: writer rooted at the local image directory.
        md_writer: writer rooted at the local markdown directory.
        image_dir: relative image directory name embedded in markdown links.

    Raises:
        SystemExit: when parse_method is not a recognized value.
    """
    if parse_method == "auto":
        pipe = UNIPipe(pdf_bytes, model_list, image_writer, image_dir, is_debug=True)
    elif parse_method == "txt":
        pipe = TXTPipe(pdf_bytes, model_list, image_writer, image_dir, is_debug=True)
    elif parse_method == "ocr":
        pipe = OCRPipe(pdf_bytes, model_list, image_writer, image_dir, is_debug=True)
    else:
        # BUG FIX: the original called os.exit(1), which does not exist and
        # would raise AttributeError instead of terminating cleanly.
        print("unknown parse method")
        raise SystemExit(1)

    pipe.pipe_classify()
    pipe.pipe_parse()
    md_content = pipe.pipe_mk_markdown()

    md_writer.write(
        content=md_content, path=f"{pdf_file_name}.md", mode=AbsReaderWriter.MODE_TXT
    )
    md_writer.write(
        content=json_parse.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4),
        path=f"{pdf_file_name}.json",
        mode=AbsReaderWriter.MODE_TXT,
    )
import click
@click.group() @click.group()
def cli(): def cli():
pass pass
@cli.command() @cli.command()
@click.option('--json', type=str, help='输入一个S3路径') @click.option("--json", type=str, help="输入一个S3路径")
def json_command(json): @click.option(
# 这里处理json相关的逻辑 "--method",
print(f'处理JSON: {json}') type=parse_pdf_methods,
help="指定解析方法。txt: 文本型 pdf 解析方法, ocr: 光学识别解析 pdf, auto: 程序智能选择解析方法",
default="auto",
)
def json_command(json, method):
if not json.startswith("s3://"):
print("usage: python magipdf.py --json s3://some_bucket/some_path")
os.exit(1)
def read_s3_path(s3path):
    """Read a (possibly byte-ranged) S3 object and return its raw bytes.

    The path may carry a non-official query like ``?bytes=start,length``;
    when present it is converted to an inclusive ``[byte_start, byte_end]``
    range before reading. NOTE(review): the conversion below assumes the
    second range parameter is a length, not an end offset — confirm against
    the producer of these URLs.
    """
    bucket, key = parse_s3path(s3path)

    s3_ak, s3_sk, s3_endpoint = get_s3_config(bucket)
    s3_rw = S3ReaderWriter(
        s3_ak, s3_sk, s3_endpoint, "auto", remove_non_official_s3_args(s3path)
    )
    may_range_params = parse_s3_range_params(s3path)
    if may_range_params is None or 2 != len(may_range_params):
        # No usable range parameters: read the whole object.
        byte_start, byte_end = 0, None
    else:
        byte_start, byte_end = int(may_range_params[0]), int(may_range_params[1])
        # Convert (offset, length) into an inclusive end offset.
        byte_end += byte_start - 1
    return s3_rw.read_jsonl(
        remove_non_official_s3_args(s3path),
        byte_start,
        byte_end,
        AbsReaderWriter.MODE_BIN,
    )
jso = json_parse.loads(read_s3_path(json).decode("utf-8"))
s3_file_path = jso["file_location"]
pdf_file_name = Path(s3_file_path).stem
pdf_data = read_s3_path(s3_file_path)
local_image_dir, local_md_dir = prepare_env(pdf_file_name)
local_image_rw, local_md_rw = DiskReaderWriter(local_image_dir), DiskReaderWriter(
local_md_dir
)
_do_parse(
pdf_file_name,
pdf_data,
jso["doc_layout_result"],
method,
local_image_rw,
local_md_rw,
os.path.basename(local_image_dir),
)
@cli.command() @cli.command()
@click.option('--pdf', type=click.Path(exists=True), required=True, help='PDF文件的路径') @click.option(
@click.option('--model', type=click.Path(exists=True), help='模型的路径') "--pdf", type=click.Path(exists=True), required=True, help="PDF文件的路径"
def pdf_command(pdf, model): )
@click.option("--model", type=click.Path(exists=True), help="模型的路径")
@click.option(
"--method",
type=parse_pdf_methods,
help="指定解析方法。txt: 文本型 pdf 解析方法, ocr: 光学识别解析 pdf, auto: 程序智能选择解析方法",
default="auto",
)
def pdf_command(pdf, model, method):
# 这里处理pdf和模型相关的逻辑 # 这里处理pdf和模型相关的逻辑
print(f'处理PDF: {pdf}') if model is None:
print(f'加载模型: {model}') model = pdf.replace(".pdf", ".json")
if not os.path.exists(model):
print(f"make sure json file existed and place under {os.dirname(pdf)}")
os.exit(1)
def read_fn(path):
    """Read the file at *path* in binary mode via the project disk reader."""
    return DiskReaderWriter(os.path.dirname(path)).read(
        os.path.basename(path), AbsReaderWriter.MODE_BIN
    )
pdf_data = read_fn(pdf)
jso = json_parse.loads(read_fn(model).decode("utf-8"))
pdf_file_name = Path(pdf).stem
local_image_dir, local_md_dir = prepare_env(pdf_file_name)
local_image_rw, local_md_rw = DiskReaderWriter(local_image_dir), DiskReaderWriter(
local_md_dir
)
_do_parse(
pdf_file_name,
pdf_data,
jso,
method,
local_image_rw,
local_md_rw,
os.path.basename(local_image_dir),
)
if __name__ == '__main__': if __name__ == "__main__":
"""
python magic_pdf/cli/magicpdf.py json-command --json s3://llm-pdf-text/pdf_ebook_and_paper/manual/v001/part-660407a28beb-000002.jsonl?bytes=0,63551
"""
cli() cli()
...@@ -2,6 +2,7 @@ import math ...@@ -2,6 +2,7 @@ import math
from loguru import logger from loguru import logger
from magic_pdf.libs.boxbase import find_bottom_nearest_text_bbox, find_top_nearest_text_bbox from magic_pdf.libs.boxbase import find_bottom_nearest_text_bbox, find_top_nearest_text_bbox
from magic_pdf.libs.commons import join_path
from magic_pdf.libs.ocr_content_type import ContentType from magic_pdf.libs.ocr_content_type import ContentType
TYPE_INLINE_EQUATION = ContentType.InlineEquation TYPE_INLINE_EQUATION = ContentType.InlineEquation
...@@ -227,12 +228,12 @@ def __insert_before_para(text, type, element, content_list): ...@@ -227,12 +228,12 @@ def __insert_before_para(text, type, element, content_list):
logger.error(f"Can't find the location of image {element.get('image_path')} in the markdown file, search target is {text}") logger.error(f"Can't find the location of image {element.get('image_path')} in the markdown file, search target is {text}")
def mk_universal_format(para_dict: dict): def mk_universal_format(pdf_info_list: list, img_buket_path):
""" """
构造统一格式 https://aicarrier.feishu.cn/wiki/FqmMwcH69iIdCWkkyjvcDwNUnTY 构造统一格式 https://aicarrier.feishu.cn/wiki/FqmMwcH69iIdCWkkyjvcDwNUnTY
""" """
content_lst = [] content_lst = []
for _, page_info in para_dict.items(): for page_info in pdf_info_list:
page_lst = [] # 一个page内的段落列表 page_lst = [] # 一个page内的段落列表
para_blocks = page_info.get("para_blocks") para_blocks = page_info.get("para_blocks")
pymu_raw_blocks = page_info.get("preproc_blocks") pymu_raw_blocks = page_info.get("preproc_blocks")
...@@ -249,7 +250,7 @@ def mk_universal_format(para_dict: dict): ...@@ -249,7 +250,7 @@ def mk_universal_format(para_dict: dict):
for img in all_page_images: for img in all_page_images:
content_node = { content_node = {
"type": "image", "type": "image",
"img_path": img['image_path'], "img_path": join_path(img_buket_path, img['image_path']),
"img_alt":"", "img_alt":"",
"img_title":"", "img_title":"",
"img_caption":"" "img_caption":""
...@@ -258,7 +259,7 @@ def mk_universal_format(para_dict: dict): ...@@ -258,7 +259,7 @@ def mk_universal_format(para_dict: dict):
for table in all_page_tables: for table in all_page_tables:
content_node = { content_node = {
"type": "table", "type": "table",
"img_path": table['image_path'], "img_path": join_path(img_buket_path, table['image_path']),
"table_latex": table.get("text"), "table_latex": table.get("text"),
"table_title": "", "table_title": "",
"table_caption": "", "table_caption": "",
......
from loguru import logger
from magic_pdf.libs.commons import join_path
from magic_pdf.libs.language import detect_lang from magic_pdf.libs.language import detect_lang
from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
from magic_pdf.libs.ocr_content_type import ContentType from magic_pdf.libs.ocr_content_type import ContentType, BlockType
import wordninja import wordninja
import re import re
...@@ -16,90 +19,41 @@ def split_long_words(text): ...@@ -16,90 +19,41 @@ def split_long_words(text):
return ' '.join(segments) return ' '.join(segments)
def ocr_mk_nlp_markdown(pdf_info_dict: dict): def ocr_mk_mm_markdown_with_para(pdf_info_list: list, img_buket_path):
markdown = []
for _, page_info in pdf_info_dict.items():
blocks = page_info.get("preproc_blocks")
if not blocks:
continue
for block in blocks:
for line in block['lines']:
line_text = ''
for span in line['spans']:
if not span.get('content'):
continue
content = ocr_escape_special_markdown_char(span['content']) # 转义特殊符号
if span['type'] == ContentType.InlineEquation:
content = f"${content}$"
elif span['type'] == ContentType.InterlineEquation:
content = f"$$\n{content}\n$$"
line_text += content + ' '
# 在行末添加两个空格以强制换行
markdown.append(line_text.strip() + ' ')
return '\n'.join(markdown)
def ocr_mk_mm_markdown(pdf_info_dict: dict):
markdown = []
for _, page_info in pdf_info_dict.items():
blocks = page_info.get("preproc_blocks")
if not blocks:
continue
for block in blocks:
for line in block['lines']:
line_text = ''
for span in line['spans']:
if not span.get('content'):
if not span.get('image_path'):
continue
else:
content = f"![]({span['image_path']})"
else:
content = ocr_escape_special_markdown_char(span['content']) # 转义特殊符号
if span['type'] == ContentType.InlineEquation:
content = f"${content}$"
elif span['type'] == ContentType.InterlineEquation:
content = f"$$\n{content}\n$$"
line_text += content + ' '
# 在行末添加两个空格以强制换行
markdown.append(line_text.strip() + ' ')
return '\n'.join(markdown)
def ocr_mk_mm_markdown_with_para(pdf_info_dict: dict):
markdown = [] markdown = []
for _, page_info in pdf_info_dict.items(): for page_info in pdf_info_list:
paras_of_layout = page_info.get("para_blocks") paras_of_layout = page_info.get("para_blocks")
page_markdown = ocr_mk_markdown_with_para_core(paras_of_layout, "mm") page_markdown = ocr_mk_markdown_with_para_core_v2(paras_of_layout, "mm", img_buket_path)
markdown.extend(page_markdown) markdown.extend(page_markdown)
return '\n\n'.join(markdown) return '\n\n'.join(markdown)
def ocr_mk_nlp_markdown_with_para(pdf_info_dict: dict): def ocr_mk_nlp_markdown_with_para(pdf_info_dict: list):
markdown = [] markdown = []
for _, page_info in pdf_info_dict.items(): for page_info in pdf_info_dict:
paras_of_layout = page_info.get("para_blocks") paras_of_layout = page_info.get("para_blocks")
page_markdown = ocr_mk_markdown_with_para_core(paras_of_layout, "nlp") page_markdown = ocr_mk_markdown_with_para_core_v2(paras_of_layout, "nlp")
markdown.extend(page_markdown) markdown.extend(page_markdown)
return '\n\n'.join(markdown) return '\n\n'.join(markdown)
def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: dict):
def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list, img_buket_path):
markdown_with_para_and_pagination = [] markdown_with_para_and_pagination = []
for page_no, page_info in pdf_info_dict.items(): page_no = 0
for page_info in pdf_info_dict:
paras_of_layout = page_info.get("para_blocks") paras_of_layout = page_info.get("para_blocks")
if not paras_of_layout: if not paras_of_layout:
continue continue
page_markdown = ocr_mk_markdown_with_para_core(paras_of_layout, "mm") page_markdown = ocr_mk_markdown_with_para_core_v2(paras_of_layout, "mm", img_buket_path)
markdown_with_para_and_pagination.append({ markdown_with_para_and_pagination.append({
'page_no': page_no, 'page_no': page_no,
'md_content': '\n\n'.join(page_markdown) 'md_content': '\n\n'.join(page_markdown)
}) })
page_no += 1
return markdown_with_para_and_pagination return markdown_with_para_and_pagination
def ocr_mk_markdown_with_para_core(paras_of_layout, mode): def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=""):
page_markdown = [] page_markdown = []
for paras in paras_of_layout: for paras in paras_of_layout:
for para in paras: for para in paras:
...@@ -122,7 +76,7 @@ def ocr_mk_markdown_with_para_core(paras_of_layout, mode): ...@@ -122,7 +76,7 @@ def ocr_mk_markdown_with_para_core(paras_of_layout, mode):
content = f"\n$$\n{span['content']}\n$$\n" content = f"\n$$\n{span['content']}\n$$\n"
elif span_type in [ContentType.Image, ContentType.Table]: elif span_type in [ContentType.Image, ContentType.Table]:
if mode == 'mm': if mode == 'mm':
content = f"\n![]({span['image_path']})\n" content = f"\n![]({join_path(img_buket_path, span['image_path'])})\n"
elif mode == 'nlp': elif mode == 'nlp':
pass pass
if content != '': if content != '':
...@@ -137,10 +91,86 @@ def ocr_mk_markdown_with_para_core(paras_of_layout, mode): ...@@ -137,10 +91,86 @@ def ocr_mk_markdown_with_para_core(paras_of_layout, mode):
return page_markdown return page_markdown
def para_to_standard_format(para): def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
page_markdown = []
for para_block in paras_of_layout:
para_text = ''
para_type = para_block.get('type')
if para_type == BlockType.Text:
para_text = merge_para_with_text(para_block)
elif para_type == BlockType.Title:
para_text = f"# {merge_para_with_text(para_block)}"
elif para_type == BlockType.InterlineEquation:
para_text = merge_para_with_text(para_block)
elif para_type == BlockType.Image:
if mode == 'nlp':
continue
elif mode == 'mm':
img_blocks = para_block.get('blocks')
for img_block in img_blocks:
if img_block.get('type') == BlockType.ImageBody:
for line in img_block.get('lines'):
for span in line['spans']:
if span.get('type') == ContentType.Image:
para_text = f"\n![]({join_path(img_buket_path, span['image_path'])})\n"
for img_block in img_blocks:
if img_block.get('type') == BlockType.ImageCaption:
para_text += merge_para_with_text(img_block)
elif para_type == BlockType.Table:
if mode == 'nlp':
continue
elif mode == 'mm':
table_blocks = para_block.get('blocks')
for table_block in table_blocks:
if table_block.get('type') == BlockType.TableBody:
for line in table_block.get('lines'):
for span in line['spans']:
if span.get('type') == ContentType.Table:
para_text = f"\n![]({join_path(img_buket_path, span['image_path'])})\n"
for table_block in table_blocks:
if table_block.get('type') == BlockType.TableCaption:
para_text += merge_para_with_text(table_block)
elif table_block.get('type') == BlockType.TableFootnote:
para_text += merge_para_with_text(table_block)
if para_text.strip() == '':
continue
else:
page_markdown.append(para_text.strip() + ' ')
return page_markdown
def merge_para_with_text(para):
    """Concatenate the span contents of one paragraph block into a single string.

    Text spans are language-detected so long English words can be split;
    inline equations are wrapped in ``$...$`` and interline equations in
    ``$$...$$``. English chunks are joined with spaces, other languages
    without.
    """
    pieces = []
    for line in para['lines']:
        for span in line['spans']:
            kind = span.get('type')
            text = ''
            lang = ''
            if kind == ContentType.Text:
                text = span['content']
                lang = detect_lang(text)
                if lang == 'en':
                    # Only split long words for English; splitting CJK text loses characters.
                    text = ocr_escape_special_markdown_char(split_long_words(text))
                else:
                    text = ocr_escape_special_markdown_char(text)
            elif kind == ContentType.InlineEquation:
                text = f"${span['content']}$"
            elif kind == ContentType.InterlineEquation:
                text = f"\n$$\n{span['content']}\n$$\n"
            if text != '':
                # English context separates chunks with spaces; Chinese does not.
                pieces.append(text + ' ' if lang == 'en' else text)
    return ''.join(pieces)
def para_to_standard_format(para, img_buket_path):
para_content = {} para_content = {}
if len(para) == 1: if len(para) == 1:
para_content = line_to_standard_format(para[0]) para_content = line_to_standard_format(para[0], img_buket_path)
elif len(para) > 1: elif len(para) > 1:
para_text = '' para_text = ''
inline_equation_num = 0 inline_equation_num = 0
...@@ -148,6 +178,7 @@ def para_to_standard_format(para): ...@@ -148,6 +178,7 @@ def para_to_standard_format(para):
for span in line['spans']: for span in line['spans']:
language = '' language = ''
span_type = span.get('type') span_type = span.get('type')
content = ""
if span_type == ContentType.Text: if span_type == ContentType.Text:
content = span['content'] content = span['content']
language = detect_lang(content) language = detect_lang(content)
...@@ -170,20 +201,21 @@ def para_to_standard_format(para): ...@@ -170,20 +201,21 @@ def para_to_standard_format(para):
} }
return para_content return para_content
def make_standard_format_with_para(pdf_info_dict: dict):
def make_standard_format_with_para(pdf_info_dict: list, img_buket_path: str):
content_list = [] content_list = []
for _, page_info in pdf_info_dict.items(): for page_info in pdf_info_dict:
paras_of_layout = page_info.get("para_blocks") paras_of_layout = page_info.get("para_blocks")
if not paras_of_layout: if not paras_of_layout:
continue continue
for paras in paras_of_layout: for paras in paras_of_layout:
for para in paras: for para in paras:
para_content = para_to_standard_format(para) para_content = para_to_standard_format(para, img_buket_path)
content_list.append(para_content) content_list.append(para_content)
return content_list return content_list
def line_to_standard_format(line): def line_to_standard_format(line, img_buket_path):
line_text = "" line_text = ""
inline_equation_num = 0 inline_equation_num = 0
for span in line['spans']: for span in line['spans']:
...@@ -194,13 +226,13 @@ def line_to_standard_format(line): ...@@ -194,13 +226,13 @@ def line_to_standard_format(line):
if span['type'] == ContentType.Image: if span['type'] == ContentType.Image:
content = { content = {
'type': 'image', 'type': 'image',
'img_path': span['image_path'] 'img_path': join_path(img_buket_path, span['image_path'])
} }
return content return content
elif span['type'] == ContentType.Table: elif span['type'] == ContentType.Table:
content = { content = {
'type': 'table', 'type': 'table',
'img_path': span['image_path'] 'img_path': join_path(img_buket_path, span['image_path'])
} }
return content return content
else: else:
...@@ -226,7 +258,7 @@ def line_to_standard_format(line): ...@@ -226,7 +258,7 @@ def line_to_standard_format(line):
return content return content
def ocr_mk_mm_standard_format(pdf_info_dict: dict): def ocr_mk_mm_standard_format(pdf_info_dict: list):
""" """
content_list content_list
type string image/text/table/equation(行间的单独拿出来,行内的和text合并) type string image/text/table/equation(行间的单独拿出来,行内的和text合并)
...@@ -236,7 +268,7 @@ def ocr_mk_mm_standard_format(pdf_info_dict: dict): ...@@ -236,7 +268,7 @@ def ocr_mk_mm_standard_format(pdf_info_dict: dict):
img_path string s3://full/path/to/img.jpg img_path string s3://full/path/to/img.jpg
""" """
content_list = [] content_list = []
for _, page_info in pdf_info_dict.items(): for page_info in pdf_info_dict:
blocks = page_info.get("preproc_blocks") blocks = page_info.get("preproc_blocks")
if not blocks: if not blocks:
continue continue
......
...@@ -15,6 +15,7 @@ from collections import Counter ...@@ -15,6 +15,7 @@ from collections import Counter
import click import click
import numpy as np import numpy as np
from loguru import logger
from magic_pdf.libs.commons import mymax, get_top_percent_list from magic_pdf.libs.commons import mymax, get_top_percent_list
from magic_pdf.filter.pdf_meta_scan import scan_max_page, junk_limit_min from magic_pdf.filter.pdf_meta_scan import scan_max_page, junk_limit_min
...@@ -298,7 +299,7 @@ def classify_by_img_narrow_strips(page_width, page_height, img_sz_list): ...@@ -298,7 +299,7 @@ def classify_by_img_narrow_strips(page_width, page_height, img_sz_list):
return narrow_strip_pages_ratio < 0.5 return narrow_strip_pages_ratio < 0.5
def classify(pdf_path, total_page: int, page_width, page_height, img_sz_list: list, text_len_list: list, img_num_list: list, text_layout_list: list): def classify(total_page: int, page_width, page_height, img_sz_list: list, text_len_list: list, img_num_list: list, text_layout_list: list):
""" """
这里的图片和页面长度单位是pts 这里的图片和页面长度单位是pts
:param total_page: :param total_page:
...@@ -323,7 +324,7 @@ def classify(pdf_path, total_page: int, page_width, page_height, img_sz_list: li ...@@ -323,7 +324,7 @@ def classify(pdf_path, total_page: int, page_width, page_height, img_sz_list: li
elif not any(results.values()): elif not any(results.values()):
return False, results return False, results
else: else:
print(f"WARNING: {pdf_path} is not classified by area and text_len, by_image_area: {results['by_image_area']}, by_text: {results['by_text_len']}, by_avg_words: {results['by_avg_words']}, by_img_num: {results['by_img_num']}, by_text_layout: {results['by_text_layout']}, by_img_narrow_strips: {results['by_img_narrow_strips']}", file=sys.stderr) # 利用这种情况可以快速找出来哪些pdf比较特殊,针对性修正分类算法 logger.warning(f"pdf is not classified by area and text_len, by_image_area: {results['by_image_area']}, by_text: {results['by_text_len']}, by_avg_words: {results['by_avg_words']}, by_img_num: {results['by_img_num']}, by_text_layout: {results['by_text_layout']}, by_img_narrow_strips: {results['by_img_narrow_strips']}", file=sys.stderr) # 利用这种情况可以快速找出来哪些pdf比较特殊,针对性修正分类算法
return False, results return False, results
...@@ -350,7 +351,7 @@ def main(json_file): ...@@ -350,7 +351,7 @@ def main(json_file):
is_needs_password = o['is_needs_password'] is_needs_password = o['is_needs_password']
if is_encrypted or total_page == 0 or is_needs_password: # 加密的,需要密码的,没有页面的,都不处理 if is_encrypted or total_page == 0 or is_needs_password: # 加密的,需要密码的,没有页面的,都不处理
continue continue
tag = classify(pdf_path, total_page, page_width, page_height, img_sz_list, text_len_list, text_layout_list) tag = classify(total_page, page_width, page_height, img_sz_list, text_len_list, text_layout_list)
o['is_text_pdf'] = tag o['is_text_pdf'] = tag
print(json.dumps(o, ensure_ascii=False)) print(json.dumps(o, ensure_ascii=False))
except Exception as e: except Exception as e:
......
...@@ -287,7 +287,7 @@ def get_language(doc: fitz.Document): ...@@ -287,7 +287,7 @@ def get_language(doc: fitz.Document):
return language return language
def pdf_meta_scan(s3_pdf_path: str, pdf_bytes: bytes): def pdf_meta_scan(pdf_bytes: bytes):
""" """
:param s3_pdf_path: :param s3_pdf_path:
:param pdf_bytes: pdf文件的二进制数据 :param pdf_bytes: pdf文件的二进制数据
...@@ -298,8 +298,8 @@ def pdf_meta_scan(s3_pdf_path: str, pdf_bytes: bytes): ...@@ -298,8 +298,8 @@ def pdf_meta_scan(s3_pdf_path: str, pdf_bytes: bytes):
is_encrypted = doc.is_encrypted is_encrypted = doc.is_encrypted
total_page = len(doc) total_page = len(doc)
if total_page == 0: if total_page == 0:
logger.warning(f"drop this pdf: {s3_pdf_path}, drop_reason: {DropReason.EMPTY_PDF}") logger.warning(f"drop this pdf, drop_reason: {DropReason.EMPTY_PDF}")
result = {"need_drop": True, "drop_reason": DropReason.EMPTY_PDF} result = {"_need_drop": True, "_drop_reason": DropReason.EMPTY_PDF}
return result return result
else: else:
page_width_pts, page_height_pts = get_pdf_page_size_pts(doc) page_width_pts, page_height_pts = get_pdf_page_size_pts(doc)
...@@ -322,7 +322,6 @@ def pdf_meta_scan(s3_pdf_path: str, pdf_bytes: bytes): ...@@ -322,7 +322,6 @@ def pdf_meta_scan(s3_pdf_path: str, pdf_bytes: bytes):
# 最后输出一条json # 最后输出一条json
res = { res = {
"pdf_path": s3_pdf_path,
"is_needs_password": is_needs_password, "is_needs_password": is_needs_password,
"is_encrypted": is_encrypted, "is_encrypted": is_encrypted,
"total_page": total_page, "total_page": total_page,
...@@ -350,7 +349,7 @@ def main(s3_pdf_path: str, s3_profile: str): ...@@ -350,7 +349,7 @@ def main(s3_pdf_path: str, s3_profile: str):
""" """
try: try:
file_content = read_file(s3_pdf_path, s3_profile) file_content = read_file(s3_pdf_path, s3_profile)
pdf_meta_scan(s3_pdf_path, file_content) pdf_meta_scan(file_content)
except Exception as e: except Exception as e:
print(f"ERROR: {s3_pdf_path}, {e}", file=sys.stderr) print(f"ERROR: {s3_pdf_path}, {e}", file=sys.stderr)
logger.exception(e) logger.exception(e)
......
from enum import Enum
class ModelBlockTypeEnum(Enum):
    """Block-type category ids used when interpreting layout-model output.

    NOTE(review): the numeric values appear to mirror an external layout
    model's category ids (hence the gaps) — confirm against the model's
    label mapping before changing them.
    """
    TITLE = 0
    PLAIN_TEXT = 1
    ABANDON = 2
    ISOLATE_FORMULA = 8
    EMBEDDING = 13
    ISOLATED = 14
\ No newline at end of file
from loguru import logger from loguru import logger
import math
def _is_in_or_part_overlap(box1, box2) -> bool: def _is_in_or_part_overlap(box1, box2) -> bool:
""" """
...@@ -332,3 +332,42 @@ def find_right_nearest_text_bbox(pymu_blocks, obj_bbox): ...@@ -332,3 +332,42 @@ def find_right_nearest_text_bbox(pymu_blocks, obj_bbox):
return right_boxes[0] return right_boxes[0]
else: else:
return None return None
def bbox_relative_pos(bbox1, bbox2):
    """Return (left, right, bottom, top) flags: where bbox2 lies strictly
    outside bbox1 on each side.

    Boxes are (x0, y0, x1, y1). The flag names follow the original
    convention — presumably screen coordinates with y increasing downward;
    TODO confirm.
    """
    ax0, ay0, ax1, ay1 = bbox1
    bx0, by0, bx1, by1 = bbox2
    is_left = bx1 < ax0
    is_right = ax1 < bx0
    is_bottom = by1 < ay0
    is_top = ay1 < by0
    return is_left, is_right, is_bottom, is_top
def bbox_distance(bbox1, bbox2):
    """Shortest Euclidean distance between two axis-aligned boxes.

    Returns 0 when the boxes intersect; for diagonal neighbours it is the
    distance between the nearest corners, otherwise the gap along one axis.
    """
    x1, y1, x1b, y1b = bbox1
    x2, y2, x2b, y2b = bbox2
    left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
    # For well-formed boxes at most one of (left, right) and one of
    # (bottom, top) can hold, so exact tuple dispatch is safe.
    corner_pairs = {
        (True, False, False, True): ((x1, y1b), (x2b, y2)),
        (True, False, True, False): ((x1, y1), (x2b, y2b)),
        (False, True, True, False): ((x1b, y1), (x2, y2b)),
        (False, True, False, True): ((x1b, y1b), (x2, y2)),
    }
    key = (left, right, bottom, top)
    if key in corner_pairs:
        (px, py), (qx, qy) = corner_pairs[key]
        return math.hypot(px - qx, py - qy)
    if left:
        return x1 - x2b
    if right:
        return x2 - x1b
    if bottom:
        return y1 - y2b
    if top:
        return y2 - y1b
    return 0  # boxes intersect
\ No newline at end of file
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
根据bucket的名字返回对应的s3 AK, SK,endpoint三元组 根据bucket的名字返回对应的s3 AK, SK,endpoint三元组
""" """
import json import json
import os import os
...@@ -10,20 +11,24 @@ from loguru import logger ...@@ -10,20 +11,24 @@ from loguru import logger
from magic_pdf.libs.commons import parse_bucket_key from magic_pdf.libs.commons import parse_bucket_key
def get_s3_config(bucket_name: str): def read_config():
"""
~/magic-pdf.json 读出来
"""
home_dir = os.path.expanduser("~") home_dir = os.path.expanduser("~")
config_file = os.path.join(home_dir, "magic-pdf.json") config_file = os.path.join(home_dir, "magic-pdf.json")
if not os.path.exists(config_file): if not os.path.exists(config_file):
raise Exception("magic-pdf.json not found") raise Exception(f"{config_file} not found")
with open(config_file, "r") as f: with open(config_file, "r") as f:
config = json.load(f) config = json.load(f)
return config
def get_s3_config(bucket_name: str):
"""
~/magic-pdf.json 读出来
"""
config = read_config()
bucket_info = config.get("bucket_info") bucket_info = config.get("bucket_info")
if bucket_name not in bucket_info: if bucket_name not in bucket_info:
...@@ -49,5 +54,10 @@ def get_bucket_name(path): ...@@ -49,5 +54,10 @@ def get_bucket_name(path):
return bucket return bucket
if __name__ == '__main__': def get_local_dir():
config = read_config()
return config.get("temp-output-dir", "/tmp")
if __name__ == "__main__":
ak, sk, endpoint = get_s3_config("llm-raw") ak, sk, endpoint = get_s3_config("llm-raw")
def dict_to_list(input_dict):
    """Return the values of *input_dict* as a list, preserving insertion order.

    Idiom fix: replaces a manual accumulation loop with the equivalent
    ``list(dict.values())``.
    """
    return list(input_dict.values())
def get_scale_ratio(ocr_page_info, page): def get_scale_ratio(model_page_info, page):
pix = page.get_pixmap(dpi=72) pix = page.get_pixmap(dpi=72)
pymu_width = int(pix.w) pymu_width = int(pix.w)
pymu_height = int(pix.h) pymu_height = int(pix.h)
width_from_json = ocr_page_info['page_info']['width'] width_from_json = model_page_info['page_info']['width']
height_from_json = ocr_page_info['page_info']['height'] height_from_json = model_page_info['page_info']['height']
horizontal_scale_ratio = width_from_json / pymu_width horizontal_scale_ratio = width_from_json / pymu_width
vertical_scale_ratio = height_from_json / pymu_height vertical_scale_ratio = height_from_json / pymu_height
return horizontal_scale_ratio, vertical_scale_ratio return horizontal_scale_ratio, vertical_scale_ratio
from collections import Counter
from magic_pdf.libs.language import detect_lang
def get_language_from_model(model_list: list):
    """Detect the dominant language of a document from its model output.

    For each page, the text of every OCR text detection (category_id 15) is
    concatenated and passed to ``detect_lang``; the language detected on the
    most pages wins (ties resolved by first occurrence).
    """
    per_page_langs = []
    for page_info in model_list:
        ocr_text_parts = []
        for det in page_info["layout_dets"]:
            # category_id 15 == OCR-recognised text line
            if det["category_id"] in [15]:
                ocr_text_parts.append(det["text"])
        per_page_langs.append(detect_lang("".join(ocr_text_parts)))
    # Count how many pages voted for each language and pick the most common.
    lang_counts = Counter(per_page_langs)
    return max(lang_counts, key=lang_counts.get)
...@@ -8,7 +8,7 @@ class DropReason: ...@@ -8,7 +8,7 @@ class DropReason:
HIGH_COMPUTATIONAL_lOAD_BY_SVGS = "high_computational_load_by_svgs" # 特殊的SVG图,计算量太大,从而丢弃 HIGH_COMPUTATIONAL_lOAD_BY_SVGS = "high_computational_load_by_svgs" # 特殊的SVG图,计算量太大,从而丢弃
HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES = "high_computational_load_by_total_pages" # 计算量超过负荷,当前方法下计算量消耗过大 HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES = "high_computational_load_by_total_pages" # 计算量超过负荷,当前方法下计算量消耗过大
MISS_DOC_LAYOUT_RESULT = "missing doc_layout_result" # 版面分析失败 MISS_DOC_LAYOUT_RESULT = "missing doc_layout_result" # 版面分析失败
Exception = "exception" # 解析中发生异常 Exception = "_exception" # 解析中发生异常
ENCRYPTED = "encrypted" # PDF是加密的 ENCRYPTED = "encrypted" # PDF是加密的
EMPTY_PDF = "total_page=0" # PDF页面总数为0 EMPTY_PDF = "total_page=0" # PDF页面总数为0
NOT_IS_TEXT_PDF = "not_is_text_pdf" # 不是文字版PDF,无法直接解析 NOT_IS_TEXT_PDF = "not_is_text_pdf" # 不是文字版PDF,无法直接解析
......
...@@ -16,3 +16,4 @@ class DropTag: ...@@ -16,3 +16,4 @@ class DropTag:
FOOTNOTE = "footnote" FOOTNOTE = "footnote"
NOT_IN_LAYOUT = "not_in_layout" NOT_IN_LAYOUT = "not_in_layout"
SPAN_OVERLAP = "span_overlap" SPAN_OVERLAP = "span_overlap"
BLOCK_OVERLAP = "block_overlap"
def float_gt(a, b):
    """Return True iff ``a`` exceeds ``b`` by more than the 1e-4 tolerance."""
    return (a - b) > 0.0001
def float_equal(a, b):
    """Return True iff ``a`` and ``b`` differ by at most the 1e-4 tolerance."""
    return abs(a - b) <= 0.0001
\ No newline at end of file
...@@ -5,3 +5,16 @@ class ContentType: ...@@ -5,3 +5,16 @@ class ContentType:
InlineEquation = "inline_equation" InlineEquation = "inline_equation"
InterlineEquation = "interline_equation" InterlineEquation = "interline_equation"
class BlockType:
    # String tags labelling the semantic type of a layout block in the
    # intermediate JSON structure.
    Image = "image"  # composite image block (body + caption)
    ImageBody = "image_body"  # the picture itself
    ImageCaption = "image_caption"  # caption text attached to an image
    Table = "table"  # composite table block (body + caption + footnote)
    TableBody = "table_body"  # the table itself
    TableCaption = "table_caption"  # caption text attached to a table
    TableFootnote = "table_footnote"  # footnote text under a table
    Text = "text"  # ordinary paragraph text
    Title = "title"  # heading / title text
    InterlineEquation = "interline_equation"  # display (block-level) equation
    Footnote = "footnote"  # page footnote
from s3pathlib import S3Path
def remove_non_official_s3_args(s3path):
    """Strip the non-standard query suffix from an s3 path.

    example: s3://abc/xxxx.json?bytes=0,81350 ==> s3://abc/xxxx.json
    """
    clean_path, _, _ = s3path.partition("?")
    return clean_path
def parse_s3path(s3path: str):
    """Split an s3 path into ``(bucket, key)``, ignoring any ``?`` suffix."""
    path_obj = S3Path(remove_non_official_s3_args(s3path))
    return path_obj.bucket, path_obj.key
def parse_s3_range_params(s3path: str):
    """Extract the byte-range parameters appended to an s3 path.

    example: s3://abc/xxxx.json?bytes=0,81350 ==> ['0', '81350']
    Returns None when the path carries no ``?bytes=`` suffix.
    """
    if "?bytes=" not in s3path:
        return None
    return s3path.split("?bytes=")[1].split(",")
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.libs.commons import fitz from magic_pdf.libs.commons import fitz
from loguru import logger
from magic_pdf.libs.commons import join_path from magic_pdf.libs.commons import join_path
from magic_pdf.libs.hash_utils import compute_sha256 from magic_pdf.libs.hash_utils import compute_sha256
def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWriter): def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWriter: AbsReaderWriter):
""" """
从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径 从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径
save_path:需要同时支持s3和本地, 图片存放在save_path下,文件名是: {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。 save_path:需要同时支持s3和本地, 图片存放在save_path下,文件名是: {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。
...@@ -28,49 +28,6 @@ def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWri ...@@ -28,49 +28,6 @@ def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWri
byte_data = pix.tobytes(output='jpeg', jpg_quality=95) byte_data = pix.tobytes(output='jpeg', jpg_quality=95)
imageWriter.write(data=byte_data, path=img_hash256_path, mode="binary") imageWriter.write(byte_data, img_hash256_path, AbsReaderWriter.MODE_BIN)
return img_hash256_path return img_hash256_path
def save_images_by_bboxes(page_num: int, page: fitz.Page, pdf_bytes_md5: str,
                          image_bboxes: list, images_overlap_backup: list, table_bboxes: list,
                          equation_inline_bboxes: list,
                          equation_interline_bboxes: list, imageWriter) -> tuple:
    """
    Crop and save every image/table region of one page via ``cut_image``.

    Save path layout:
    {s3_or_local_path}/{book_name}/{images|tables|equations}/{page_num}_{x0}_{y0}_{x1}_{y1}.jpg

    Returns a 5-tuple of lists (the original docstring wrongly claimed a dict):
    (image_info, image_backup_info, table_info, inline_eq_info, interline_eq_info),
    each entry a ``{"bbox": ..., "image_path": ...}`` dict.

    NOTE: the two equation-bbox parameters are accepted for interface
    compatibility but are no longer cropped here, so the last two lists are
    always empty.
    """
    def _cut_all(bboxes, sub_dir, label):
        # Shared loop: skip degenerate boxes (zero/negative extent), crop the rest.
        infos = []
        for bbox in bboxes:
            if bbox[0] >= bbox[2] or bbox[1] >= bbox[3]:
                logger.warning(f"{label}: 错误的box, {bbox}")
                continue
            image_path = cut_image(bbox, page_num, page,
                                   join_path(pdf_bytes_md5, sub_dir), imageWriter)
            infos.append({"bbox": bbox, "image_path": image_path})
        return infos

    image_info = _cut_all(image_bboxes, "images", "image_bboxes")
    image_backup_info = _cut_all(images_overlap_backup, "images", "images_overlap_backup")
    table_info = _cut_all(table_bboxes, "tables", "table_bboxes")
    inline_eq_info = []
    interline_eq_info = []
    return image_info, image_backup_info, table_info, inline_eq_info, interline_eq_info
import json
import math
from magic_pdf.libs.commons import fitz
from loguru import logger
from magic_pdf.libs.commons import join_path
from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.libs.ocr_content_type import ContentType
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
from magic_pdf.libs.math import float_gt
from magic_pdf.libs.boxbase import _is_in, bbox_relative_pos, bbox_distance
from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
class MagicModel:
    """
    Accessor layer over the per-page model-inference output.

    ``model_list`` holds one dict per page ({"page_info": ..., "layout_dets":
    [...]}) and ``docs`` is the matching fitz document.  Detections are
    rescaled into PyMuPDF 72-dpi coordinates on construction.  Every getter
    returns an empty list when the page has no matching elements.
    """

    def __fix_axis(self):
        # Rescale each detection's 8-value "poly" into a 4-value "bbox" in
        # PyMuPDF coordinates, then drop zero-width/zero-height detections.
        for model_page_info in self.__model_list:
            need_remove_list = []
            page_no = model_page_info["page_info"]["page_no"]
            horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
                model_page_info, self.__docs[page_no]
            )
            layout_dets = model_page_info["layout_dets"]
            for layout_det in layout_dets:
                x0, y0, _, _, x1, y1, _, _ = layout_det["poly"]
                bbox = [
                    int(x0 / horizontal_scale_ratio),
                    int(y0 / vertical_scale_ratio),
                    int(x1 / horizontal_scale_ratio),
                    int(y1 / vertical_scale_ratio),
                ]
                layout_det["bbox"] = bbox
                # Remove spans whose width or height collapsed to 0.
                if bbox[2] - bbox[0] == 0 or bbox[3] - bbox[1] == 0:
                    need_remove_list.append(layout_det)
            for need_remove in need_remove_list:
                layout_dets.remove(need_remove)

    def __fix_by_confidence(self):
        # Drop detections scored below 0.6.  Currently disabled (see __init__).
        for model_page_info in self.__model_list:
            need_remove_list = []
            layout_dets = model_page_info["layout_dets"]
            for layout_det in layout_dets:
                if layout_det["score"] < 0.6:
                    need_remove_list.append(layout_det)
            for need_remove in need_remove_list:
                layout_dets.remove(need_remove)

    def __init__(self, model_list: list, docs: "fitz.Document"):
        # ``fitz.Document`` is quoted (lazy annotation) so the class can be
        # defined even before fitz is imported; behavior is unchanged.
        self.__model_list = model_list
        self.__docs = docs
        self.__fix_axis()
        # TODO: dropping low-confidence detections currently breaks paragraph
        # splitting; re-enable __fix_by_confidence() once that is fixed.
        # self.__fix_by_confidence()

    def __reduct_overlap(self, bboxes):
        # Keep only bboxes not fully contained inside another bbox.
        N = len(bboxes)
        keep = [True] * N
        for i in range(N):
            for j in range(N):
                if i == j:
                    continue
                if _is_in(bboxes[i], bboxes[j]):
                    keep[i] = False
        return [bboxes[i] for i in range(N) if keep[i]]

    def __tie_up_category_by_distance(
        self, page_no, subject_category_id, object_category_id
    ):
        """
        Pair every subject (e.g. image/table body) with its nearest objects
        (e.g. captions).  Assumes each subject has at most one object group
        (several adjacent objects may merge into one) and each object belongs
        to at most one subject.

        Returns (records, total_subject_object_dis) where each record carries
        "subject_body", optional "object_body" and the union bbox "all".
        """
        ret = []
        MAX_DIS_OF_POINT = 10 ** 9 + 7
        subjects = self.__reduct_overlap(
            list(
                map(
                    lambda x: x["bbox"],
                    filter(
                        lambda x: x["category_id"] == subject_category_id,
                        self.__model_list[page_no]["layout_dets"],
                    ),
                )
            )
        )
        objects = self.__reduct_overlap(
            list(
                map(
                    lambda x: x["bbox"],
                    filter(
                        lambda x: x["category_id"] == object_category_id,
                        self.__model_list[page_no]["layout_dets"],
                    ),
                )
            )
        )
        subject_object_relation_map = {}
        # Sort subjects by squared distance from the page origin for a stable order.
        subjects.sort(key=lambda x: x[0] ** 2 + x[1] ** 2)
        all_bboxes = []
        for v in subjects:
            all_bboxes.append({"category_id": subject_category_id, "bbox": v})
        for v in objects:
            all_bboxes.append({"category_id": object_category_id, "bbox": v})
        N = len(all_bboxes)
        # Pairwise distances; subject-subject pairs stay "infinite".
        dis = [[MAX_DIS_OF_POINT] * N for _ in range(N)]
        for i in range(N):
            for j in range(i):
                if (
                    all_bboxes[i]["category_id"] == subject_category_id
                    and all_bboxes[j]["category_id"] == subject_category_id
                ):
                    continue
                dis[i][j] = bbox_distance(all_bboxes[i]["bbox"], all_bboxes[j]["bbox"])
                dis[j][i] = dis[i][j]
        used = set()
        for i in range(N):
            # Find the objects tied to subject i.
            if all_bboxes[i]["category_id"] != subject_category_id:
                continue
            seen = set()
            candidates = []
            arr = []
            for j in range(N):
                # Only consider objects on exactly one side of the subject.
                pos_flag_count = sum(
                    map(
                        lambda x: 1 if x else 0,
                        bbox_relative_pos(
                            all_bboxes[i]["bbox"], all_bboxes[j]["bbox"]
                        ),
                    )
                )
                if pos_flag_count > 1:
                    continue
                if (
                    all_bboxes[j]["category_id"] != object_category_id
                    or j in used
                    or dis[i][j] == MAX_DIS_OF_POINT
                ):
                    continue
                arr.append((dis[i][j], j))
            arr.sort(key=lambda x: x[0])
            if len(arr) > 0:
                candidates.append(arr[0][1])
                seen.add(arr[0][1])
            # Seed acquired; repeatedly pull in objects whose nearest neighbour
            # is already in the cluster.
            for j in set(candidates):
                tmp = []
                for k in range(i + 1, N):
                    pos_flag_count = sum(
                        map(
                            lambda x: 1 if x else 0,
                            bbox_relative_pos(
                                all_bboxes[j]["bbox"], all_bboxes[k]["bbox"]
                            ),
                        )
                    )
                    if pos_flag_count > 1:
                        continue
                    if (
                        all_bboxes[k]["category_id"] != object_category_id
                        or k in used
                        or k in seen
                        or dis[j][k] == MAX_DIS_OF_POINT
                    ):
                        continue
                    is_nearest = True
                    for l in range(i + 1, N):
                        if l in (j, k) or l in used or l in seen:
                            continue
                        if not float_gt(dis[l][k], dis[j][k]):
                            is_nearest = False
                            break
                    if is_nearest:
                        tmp.append(k)
                        seen.add(k)
                candidates = tmp
                if len(candidates) == 0:
                    break
            # All nearest captions (and their nearest captions) collected;
            # expand a bbox over subject + cluster.
            x0s = [all_bboxes[idx]["bbox"][0] for idx in seen] + [
                all_bboxes[i]["bbox"][0]
            ]
            y0s = [all_bboxes[idx]["bbox"][1] for idx in seen] + [
                all_bboxes[i]["bbox"][1]
            ]
            x1s = [all_bboxes[idx]["bbox"][2] for idx in seen] + [
                all_bboxes[i]["bbox"][2]
            ]
            y1s = [all_bboxes[idx]["bbox"][3] for idx in seen] + [
                all_bboxes[i]["bbox"][3]
            ]
            ox0, oy0, ox1, oy1 = min(x0s), min(y0s), max(x1s), max(y1s)
            ix0, iy0, ix1, iy1 = all_bboxes[i]["bbox"]
            # Four candidate caption strips (left/top/bottom/right of the
            # subject); pick the one whose merged objects cover the most area.
            caption_poses = [
                [ox0, oy0, ix0, oy1],
                [ox0, oy0, ox1, iy0],
                [ox0, iy1, ox1, oy1],
                [ix1, oy0, ox1, oy1],
            ]
            caption_areas = []
            for bbox in caption_poses:
                embed_arr = []
                for idx in seen:
                    if _is_in(all_bboxes[idx]["bbox"], bbox):
                        embed_arr.append(idx)
                if len(embed_arr) > 0:
                    embed_x0 = min([all_bboxes[idx]["bbox"][0] for idx in embed_arr])
                    embed_y0 = min([all_bboxes[idx]["bbox"][1] for idx in embed_arr])
                    embed_x1 = max([all_bboxes[idx]["bbox"][2] for idx in embed_arr])
                    embed_y1 = max([all_bboxes[idx]["bbox"][3] for idx in embed_arr])
                    caption_areas.append(
                        int(abs(embed_x1 - embed_x0) * abs(embed_y1 - embed_y0))
                    )
                else:
                    caption_areas.append(0)
            subject_object_relation_map[i] = []
            if max(caption_areas) > 0:
                max_area_idx = caption_areas.index(max(caption_areas))
                caption_bbox = caption_poses[max_area_idx]
                for j in seen:
                    if _is_in(all_bboxes[j]["bbox"], caption_bbox):
                        used.add(j)
                        subject_object_relation_map[i].append(j)
        for i in sorted(subject_object_relation_map.keys()):
            result = {
                "subject_body": all_bboxes[i]["bbox"],
                "all": all_bboxes[i]["bbox"],
            }
            if len(subject_object_relation_map[i]) > 0:
                x0 = min(
                    [all_bboxes[j]["bbox"][0] for j in subject_object_relation_map[i]]
                )
                y0 = min(
                    [all_bboxes[j]["bbox"][1] for j in subject_object_relation_map[i]]
                )
                x1 = max(
                    [all_bboxes[j]["bbox"][2] for j in subject_object_relation_map[i]]
                )
                y1 = max(
                    [all_bboxes[j]["bbox"][3] for j in subject_object_relation_map[i]]
                )
                result["object_body"] = [x0, y0, x1, y1]
                result["all"] = [
                    min(x0, all_bboxes[i]["bbox"][0]),
                    min(y0, all_bboxes[i]["bbox"][1]),
                    max(x1, all_bboxes[i]["bbox"][2]),
                    max(y1, all_bboxes[i]["bbox"][3]),
                ]
            ret.append(result)

        total_subject_object_dis = 0
        # Distance of already-matched pairs.
        for i in subject_object_relation_map.keys():
            for j in subject_object_relation_map[i]:
                total_subject_object_dis += bbox_distance(
                    all_bboxes[i]["bbox"], all_bboxes[j]["bbox"]
                )
        # Approximate distance for unmatched subjects/objects.
        # BUGFIX: the comprehension previously tested
        # subject_object_relation_map[i] with the leaked loop variable i;
        # it must test each key.
        with_caption_subject = set(
            [
                key
                for key in subject_object_relation_map.keys()
                if len(subject_object_relation_map[key]) > 0
            ]
        )
        for i in range(N):
            if all_bboxes[i]["category_id"] != object_category_id or i in used:
                continue
            candidates = []
            for j in range(N):
                if (
                    all_bboxes[j]["category_id"] != subject_category_id
                    or j in with_caption_subject
                ):
                    continue
                candidates.append((dis[i][j], j))
            if len(candidates) > 0:
                candidates.sort(key=lambda x: x[0])
                # BUGFIX: accumulate the distance (tuple[0]), not the index,
                # and mark the chosen subject (tuple[1]) instead of the stale
                # loop variable j.
                total_subject_object_dis += candidates[0][0]
                with_caption_subject.add(candidates[0][1])
        return ret, total_subject_object_dis

    def get_imgs(self, page_no: int):
        """Images paired with their captions (category 3 bodies, 4 captions)."""
        records, _ = self.__tie_up_category_by_distance(page_no, 3, 4)
        return [
            {
                "bbox": record["all"],
                "img_body_bbox": record["subject_body"],
                "img_caption_bbox": record.get("object_body", None),
            }
            for record in records
        ]

    def get_tables(
        self, page_no: int
    ) -> list:
        """Tables: caption (6) and footnote (7) tied to each table body (5)."""
        with_captions, _ = self.__tie_up_category_by_distance(page_no, 5, 6)
        with_footnotes, _ = self.__tie_up_category_by_distance(page_no, 5, 7)
        ret = []
        N, M = len(with_captions), len(with_footnotes)
        assert N == M  # both runs pair against the same set of table bodies
        for i in range(N):
            record = {
                "table_caption_bbox": with_captions[i].get("object_body", None),
                "table_body_bbox": with_captions[i]["subject_body"],
                "table_footnote_bbox": with_footnotes[i].get("object_body", None),
            }
            x0 = min(with_captions[i]["all"][0], with_footnotes[i]["all"][0])
            y0 = min(with_captions[i]["all"][1], with_footnotes[i]["all"][1])
            x1 = max(with_captions[i]["all"][2], with_footnotes[i]["all"][2])
            y1 = max(with_captions[i]["all"][3], with_footnotes[i]["all"][3])
            record["bbox"] = [x0, y0, x1, y1]
            ret.append(record)
        return ret

    def get_equations(self, page_no: int) -> list:
        """Equations: bboxes plus latex text where the model provides it."""
        inline_equations = self.__get_blocks_by_type(ModelBlockTypeEnum.EMBEDDING.value, page_no, ["latex"])
        interline_equations = self.__get_blocks_by_type(ModelBlockTypeEnum.ISOLATED.value, page_no, ["latex"])
        interline_equations_blocks = self.__get_blocks_by_type(ModelBlockTypeEnum.ISOLATE_FORMULA.value, page_no)
        return inline_equations, interline_equations, interline_equations_blocks

    def get_discarded(self, page_no: int) -> list:
        """Blocks the layout model marked as abandoned (bbox only)."""
        return self.__get_blocks_by_type(ModelBlockTypeEnum.ABANDON.value, page_no)

    def get_text_blocks(self, page_no: int) -> list:
        """Plain-text layout blocks (bbox only, no text content)."""
        return self.__get_blocks_by_type(ModelBlockTypeEnum.PLAIN_TEXT.value, page_no)

    def get_title_blocks(self, page_no: int) -> list:
        """Title layout blocks (bbox only, no text content)."""
        return self.__get_blocks_by_type(ModelBlockTypeEnum.TITLE.value, page_no)

    def get_ocr_text(self, page_no: int) -> list:
        """OCR text spans (bbox + content).

        BUGFIX: category_id was compared against the string "15" while every
        other code path (e.g. get_all_spans) uses the integer 15, so this
        method always returned an empty list.
        """
        text_spans = []
        model_page_info = self.__model_list[page_no]
        layout_dets = model_page_info["layout_dets"]
        for layout_det in layout_dets:
            if layout_det["category_id"] == 15:
                span = {
                    "bbox": layout_det['bbox'],
                    "content": layout_det["text"],
                }
                text_spans.append(span)
        return text_spans

    def get_all_spans(self, page_no: int) -> list:
        """All detections usable as spans:
        3: image, 5: table, 13: inline equation, 14: interline equation,
        15: OCR text.
        """
        all_spans = []
        model_page_info = self.__model_list[page_no]
        layout_dets = model_page_info["layout_dets"]
        allow_category_id_list = [3, 5, 13, 14, 15]
        for layout_det in layout_dets:
            category_id = layout_det["category_id"]
            if category_id in allow_category_id_list:
                span = {
                    "bbox": layout_det['bbox']
                }
                if category_id == 3:
                    span["type"] = ContentType.Image
                elif category_id == 5:
                    span["type"] = ContentType.Table
                elif category_id == 13:
                    span["content"] = layout_det["latex"]
                    span["type"] = ContentType.InlineEquation
                elif category_id == 14:
                    span["content"] = layout_det["latex"]
                    span["type"] = ContentType.InterlineEquation
                elif category_id == 15:
                    span["content"] = layout_det["text"]
                    span["type"] = ContentType.Text
                all_spans.append(span)
        return all_spans

    def get_page_size(self, page_no: int):
        """Width and height of the page in PyMuPDF coordinates."""
        page = self.__docs[page_no]
        page_w = page.rect.width
        page_h = page.rect.height
        return page_w, page_h

    def __get_blocks_by_type(self, block_type: int, page_no: int, extra_col: list[str] = []) -> list:
        # Collect all detections of ``block_type`` on ``page_no``; ``extra_col``
        # names additional keys to copy into each block (missing keys -> None).
        # (Parameter renamed from ``type`` to stop shadowing the builtin;
        # all call sites are positional.)
        blocks = []
        for page_dict in self.__model_list:
            layout_dets = page_dict.get("layout_dets", [])
            page_info = page_dict.get("page_info", {})
            page_number = page_info.get("page_no", -1)
            if page_no != page_number:
                continue
            for item in layout_dets:
                category_id = item.get("category_id", -1)
                bbox = item.get("bbox", None)
                if category_id == block_type:
                    block = {
                        "bbox": bbox
                    }
                    for col in extra_col:
                        block[col] = item.get(col, None)
                    blocks.append(block)
        return blocks
if __name__ == "__main__":
    # Ad-hoc manual smoke test for MagicModel; only runs when the module is
    # executed directly, never on import.
    drw = DiskReaderWriter(r"D:/project/20231108code-clean")
    if 0:
        # Variant 1 (disabled): load a local PDF + model json from a Windows path.
        pdf_file_path = r"linshixuqiu\19983-00.pdf"
        model_file_path = r"linshixuqiu\19983-00_new.json"
        pdf_bytes = drw.read(pdf_file_path, AbsReaderWriter.MODE_BIN)
        model_json_txt = drw.read(model_file_path, AbsReaderWriter.MODE_TXT)
        model_list = json.loads(model_json_txt)
        write_path = r"D:\project\20231108code-clean\linshixuqiu\19983-00"
        img_bucket_path = "imgs"
        img_writer = DiskReaderWriter(join_path(write_path, img_bucket_path))
        pdf_docs = fitz.open("pdf", pdf_bytes)
        magic_model = MagicModel(model_list, pdf_docs)
    if 1:
        # Variant 2 (enabled): load a sample from /opt/data and print the
        # image groups of the first 7 pages.
        model_list = json.loads(
            drw.read("/opt/data/pdf/20240418/j.chroma.2009.03.042.json")
        )
        pdf_bytes = drw.read(
            "/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf", AbsReaderWriter.MODE_BIN
        )
        pdf_docs = fitz.open("pdf", pdf_bytes)
        magic_model = MagicModel(model_list, pdf_docs)
        for i in range(7):
            print(magic_model.get_imgs(i))
...@@ -299,9 +299,9 @@ def __split_para_in_layoutbox(lines_group, new_layout_bbox, lang="en", char_avg_ ...@@ -299,9 +299,9 @@ def __split_para_in_layoutbox(lines_group, new_layout_bbox, lang="en", char_avg_
layout_list_info[0] = True layout_list_info[0] = True
if end==total_lines-1: if end==total_lines-1:
layout_list_info[1] = True layout_list_info[1] = True
else: else: # 是普通文本
for i, line in enumerate(lines[start:end+1]): for i, line in enumerate(lines[start:end+1]):
# 如果i有下一行,那么就要根据下一行位置综合判断是否要分段。如果i之后没有行,那么只需要判断一下行结尾特征。 # 如果i有下一行,那么就要根据下一行位置综合判断是否要分段。如果i之后没有行,那么只需要判断i行自己的结尾特征。
cur_line_type = line['spans'][-1]['type'] cur_line_type = line['spans'][-1]['type']
next_line = lines[i+1] if i<total_lines-1 else None next_line = lines[i+1] if i<total_lines-1 else None
...@@ -547,7 +547,7 @@ def __connect_middle_align_text(page_paras, new_layout_bbox, page_num, lang, deb ...@@ -547,7 +547,7 @@ def __connect_middle_align_text(page_paras, new_layout_bbox, page_num, lang, deb
if "Table" in first_line_text or "Figure" in first_line_text: if "Table" in first_line_text or "Figure" in first_line_text:
pass pass
if debug_mode: if debug_mode:
logger.info(line_hi.std()) logger.debug(line_hi.std())
if line_hi.std()<2: if line_hi.std()<2:
"""行高度相同,那么判断是否居中""" """行高度相同,那么判断是否居中"""
...@@ -560,7 +560,7 @@ def __connect_middle_align_text(page_paras, new_layout_bbox, page_num, lang, deb ...@@ -560,7 +560,7 @@ def __connect_middle_align_text(page_paras, new_layout_bbox, page_num, lang, deb
merge_para = [l[0] for l in layout_para[start:end+1]] merge_para = [l[0] for l in layout_para[start:end+1]]
para_text = ''.join([__get_span_text(span) for line in merge_para for span in line['spans']]) para_text = ''.join([__get_span_text(span) for line in merge_para for span in line['spans']])
if debug_mode: if debug_mode:
logger.info(para_text) logger.debug(para_text)
layout_para[start:end+1] = [merge_para] layout_para[start:end+1] = [merge_para]
index_offset -= end-start index_offset -= end-start
...@@ -587,6 +587,8 @@ def __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang): ...@@ -587,6 +587,8 @@ def __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang):
3. 参照上述行尾特征进行分段。 3. 参照上述行尾特征进行分段。
4. 图、表,目前独占一行,不考虑分段。 4. 图、表,目前独占一行,不考虑分段。
""" """
if page_num==343:
pass
lines_group = __group_line_by_layout(blocks, layout_bboxes, lang) # block内分段 lines_group = __group_line_by_layout(blocks, layout_bboxes, lang) # block内分段
layout_paras, layout_list_info = __split_para_in_layoutbox(lines_group, new_layout_bbox, lang) # layout内分段 layout_paras, layout_list_info = __split_para_in_layoutbox(lines_group, new_layout_bbox, lang) # layout内分段
layout_paras2, page_list_info = __connect_list_inter_layout(layout_paras, new_layout_bbox, layout_list_info, page_num, lang) # layout之间连接列表段落 layout_paras2, page_list_info = __connect_list_inter_layout(layout_paras, new_layout_bbox, layout_list_info, page_num, lang) # layout之间连接列表段落
......
from sklearn.cluster import DBSCAN
import numpy as np
from loguru import logger
from magic_pdf.libs.boxbase import _is_in_or_part_overlap_with_area_ratio as is_in_layout
from magic_pdf.libs.ocr_content_type import ContentType, BlockType
from magic_pdf.model.magic_model import MagicModel
# Characters that terminate a sentence/line; used when deciding whether a line
# ends a paragraph (covers both ASCII and full-width CJK punctuation).
LINE_STOP_FLAG = ['.', '!', '?', '。', '!', '?', ":", ":", ")", ")", ";"]
# Short aliases for the span content types used throughout this module.
INLINE_EQUATION = ContentType.InlineEquation
INTERLINE_EQUATION = ContentType.InterlineEquation
TEXT = ContentType.Text
def __get_span_text(span):
    """Text content of a span; falls back to its image path when empty."""
    content = span.get('content', '')
    if len(content) == 0:
        return span.get('image_path', '')
    return content
def __detect_list_lines(lines, new_layout_bboxes, lang):
    """
    Detect whether ``lines`` contain list items and split the list rows out.

    The target pattern: a flush-left line starting with a capital letter or a
    digit, followed by several indented lines whose first letters include
    lower case.  Only implemented for English (``lang == 'en'``); for other
    languages the lines are returned unchanged with no list info.

    Returns (segments, ones_indices) where segments is a list of
    ('text'|'list', start, end) line-index ranges and ones_indices lists, per
    list segment, the indices of its item-head (code 1) lines.
    """
    def find_repeating_patterns(lst):
        # Find runs where code 1 is followed by codes 2/3 at least twice in a
        # row; each such run is one list region.
        indices = []
        ones_indices = []
        i = 0
        while i < len(lst) - 1:  # ensure at least 2 elements remain
            if lst[i] == 1 and lst[i + 1] in [2, 3]:  # guard against consecutive 1s
                start = i
                ones_in_this_interval = [i]
                i += 1
                while i < len(lst) and lst[i] in [2, 3]:
                    i += 1
                # Only accept the run if another 1 + [2,3] sequence follows.
                if i < len(lst) - 1 and lst[i] == 1 and lst[i + 1] in [2, 3] and lst[i - 1] in [2, 3]:
                    while i < len(lst) and lst[i] in [1, 2, 3]:
                        if lst[i] == 1:
                            ones_in_this_interval.append(i)
                        i += 1
                    indices.append((start, i - 1))
                    ones_indices.append(ones_in_this_interval)
                else:
                    i += 1
            else:
                i += 1
        return indices, ones_indices
    """===================="""
    def split_indices(slen, index_array):
        # Convert list intervals into a full ('text'/'list') partition of
        # [0, slen).
        result = []
        last_end = 0
        for start, end in sorted(index_array):
            if start > last_end:
                # Gap before this interval is plain text.
                result.append(('text', last_end, start - 1))
            # The interval itself is a list.
            result.append(('list', start, end))
            last_end = end + 1
        if last_end < slen:
            # Trailing remainder is plain text.
            result.append(('text', last_end, slen - 1))
        return result
    """===================="""
    if lang != 'en':
        return lines, None
    else:
        total_lines = len(lines)
        line_fea_encode = []
        """
        Encode one feature code per line:
        1. flush-left and starts with a capital letter or digit -> 1
        2. flush-left, any other first character               -> 4
        3. indented, first character upper-case                -> 2
        4. indented, first character not upper-case            -> 3
        (0 when the line matches no layout box.)
        """
        for l in lines:
            first_char = __get_span_text(l['spans'][0])[0]
            layout = __find_layout_bbox_by_line(l['bbox'], new_layout_bboxes)
            if not layout:
                line_fea_encode.append(0)
            else:
                layout_left = layout[0]
                if l['bbox'][0] == layout_left:
                    if first_char.isupper() or first_char.isdigit():
                        line_fea_encode.append(1)
                    else:
                        line_fea_encode.append(4)
                else:
                    if first_char.isupper():
                        line_fea_encode.append(2)
                    else:
                        line_fea_encode.append(3)
        # Using the codes, pick rows where 1,2,3 repeat at least twice: a list.
        list_indice, list_start_idx = find_repeating_patterns(line_fea_encode)
        if len(list_indice) > 0:
            logger.info(f"发现了列表,列表行数:{list_indice}, {list_start_idx}")
        # TODO check whether the indented rows inside each list are left-aligned.
        segments = []  # NOTE(review): never filled or returned — looks vestigial
        for start, end in list_indice:
            for i in range(start, end + 1):
                if i > 0:
                    if line_fea_encode[i] == 4:
                        logger.info(f"列表行的第{i}行不是顶格的")
                        break
            else:
                logger.info(f"列表行的第{start}到第{end}行是列表")
        return split_indices(total_lines, list_indice), list_start_idx
def __valign_lines(blocks, layout_bboxes):
    """
    Vertically align line edges inside each layout box.

    Scan the left (x0) and right (x1) edges of all lines; edges within a small
    threshold of each other are snapped to a common value (cluster min for x0,
    cluster max for x1).  Mutates line and block bboxes in place and returns
    the recomputed layout bboxes.
    3 is an empirical threshold; TODO derive it (e.g. 1.5 body-text chars).
    """
    min_distance = 3
    min_sample = 2
    new_layout_bboxes = []
    for layout_box in layout_bboxes:
        blocks_in_layoutbox = [b for b in blocks if b["type"] == BlockType.Text and is_in_layout(b['bbox'], layout_box['layout_bbox'])]
        if len(blocks_in_layoutbox) == 0 or len(blocks_in_layoutbox[0]["lines"]) == 0:
            new_layout_bboxes.append(layout_box['layout_bbox'])
            continue
        # Cluster x0 and x1 edges separately with 1-D DBSCAN (second column is
        # a dummy zero to satisfy the 2-D input shape).
        x0_lst = np.array([[line['bbox'][0], 0] for block in blocks_in_layoutbox for line in block['lines']])
        x1_lst = np.array([[line['bbox'][2], 0] for block in blocks_in_layoutbox for line in block['lines']])
        x0_clusters = DBSCAN(eps=min_distance, min_samples=min_sample).fit(x0_lst)
        x1_clusters = DBSCAN(eps=min_distance, min_samples=min_sample).fit(x1_lst)
        x0_uniq_label = np.unique(x0_clusters.labels_)
        x1_uniq_label = np.unique(x1_clusters.labels_)
        x0_2_new_val = {}  # map: raw x0 value -> snapped value
        x1_2_new_val = {}  # map: raw x1 value -> snapped value
        for label in x0_uniq_label:
            if label == -1:  # -1 == DBSCAN noise, leave those edges alone
                continue
            x0_index_of_label = np.where(x0_clusters.labels_ == label)
            x0_raw_val = x0_lst[x0_index_of_label][:, 0]
            x0_new_val = np.min(x0_lst[x0_index_of_label][:, 0])
            x0_2_new_val.update({idx: x0_new_val for idx in x0_raw_val})
        for label in x1_uniq_label:
            if label == -1:
                continue
            x1_index_of_label = np.where(x1_clusters.labels_ == label)
            x1_raw_val = x1_lst[x1_index_of_label][:, 0]
            x1_new_val = np.max(x1_lst[x1_index_of_label][:, 0])
            x1_2_new_val.update({idx: x1_new_val for idx in x1_raw_val})
        for block in blocks_in_layoutbox:
            for line in block['lines']:
                x0, x1 = line['bbox'][0], line['bbox'][2]
                if x0 in x0_2_new_val:
                    line['bbox'][0] = int(x0_2_new_val[x0])
                if x1 in x1_2_new_val:
                    line['bbox'][2] = int(x1_2_new_val[x1])
            # Edges that did not cluster are left untouched.
        # Line widths changed, so recompute each block bbox from its lines.
        for block in blocks_in_layoutbox:
            if len(block["lines"]) > 0:
                block['bbox'] = [min([line['bbox'][0] for line in block['lines']]),
                                 min([line['bbox'][1] for line in block['lines']]),
                                 max([line['bbox'][2] for line in block['lines']]),
                                 max([line['bbox'][3] for line in block['lines']])]
        """Recompute the layout bbox as well, since the block bboxes changed."""
        layout_x0 = min([block['bbox'][0] for block in blocks_in_layoutbox])
        layout_y0 = min([block['bbox'][1] for block in blocks_in_layoutbox])
        layout_x1 = max([block['bbox'][2] for block in blocks_in_layoutbox])
        layout_y1 = max([block['bbox'][3] for block in blocks_in_layoutbox])
        new_layout_bboxes.append([layout_x0, layout_y0, layout_x1, layout_y1])
    return new_layout_bboxes
def __align_text_in_layout(blocks, layout_bboxes):
    """
    OCR lines sometimes carry blank padding before/after the text; clamp each
    text line's horizontal extent to its layout box so nothing sticks out on
    either side.  Mutates the line bboxes in place.
    """
    for layout in layout_bboxes:
        layout_box = layout['layout_bbox']
        text_blocks = [b for b in blocks
                       if b["type"] == BlockType.Text and is_in_layout(b['bbox'], layout_box)]
        for block in text_blocks:
            for line in block.get("lines", []):
                # Clamp left edge up to the layout's left, right edge down to
                # the layout's right.
                line['bbox'][0] = max(line['bbox'][0], layout_box[0])
                line['bbox'][2] = min(line['bbox'][2], layout_box[2])
def __common_pre_proc(blocks, layout_bboxes):
    """
    Language-agnostic text preprocessing: clamp lines into their layout box,
    then vertically align left/right edges.  Returns the recomputed layout
    bboxes.
    """
    # __add_line_period(blocks, layout_bboxes)  # kept disabled, as before
    __align_text_in_layout(blocks, layout_bboxes)
    return __valign_lines(blocks, layout_bboxes)
def __pre_proc_zh_blocks(blocks, layout_bboxes):
    """
    Paragraph pre-processing for Chinese text. (Stub — not implemented yet.)
    """
    pass
def __pre_proc_en_blocks(blocks, layout_bboxes):
    """
    Paragraph pre-processing for English text. (Stub — not implemented yet.)
    """
    pass
def __group_line_by_layout(blocks, layout_bboxes, lang="en"):
    """
    Group lines (and their blocks) by the layout box that contains them.

    Returns (lines_group, blocks_group): for each layout box, the list of text
    lines inside it and the list of blocks inside it.  At present each block
    is a single paragraph.

    BUGFIX: the original rebound ``blocks`` inside the loop, so every layout
    after the first filtered from the previous layout's subset instead of the
    full page; a per-layout local is used instead.
    """
    lines_group = []
    blocks_group = []
    for lyout in layout_bboxes:
        lines = [line for block in blocks
                 if block["type"] == BlockType.Text and is_in_layout(block['bbox'], lyout['layout_bbox'])
                 for line in block['lines']]
        blocks_in_layout = [block for block in blocks
                            if is_in_layout(block['bbox'], lyout['layout_bbox'])]
        lines_group.append(lines)
        blocks_group.append(blocks_in_layout)
    return lines_group, blocks_group
def __split_para_in_layoutbox2(lines_group, new_layout_bbox, lang="en", char_avg_len=10):
    """
    Placeholder for a second-generation in-layout paragraph splitter.
    (Empty stub — no implementation yet.)
    """
def __split_para_in_layoutbox(blocks_group, new_layout_bbox, lang="en", char_avg_len=10):
    """
    Per-layout paragraph split: each element of ``blocks_group`` is the list
    of blocks inside one layout box.
    1. Compute each group's left/right bounds.
    2. Split on end-of-line features: a line ending with a stop character,
       noticeably short of the right bound, whose next line starts flush.

    Returns ``list_info``: for each layout, a pair
    [starts_with_list, ends_with_list].
    """
    list_info = []  # per layout: does it start with / end with a list
    for blocks in blocks_group:
        is_start_list = None
        is_end_list = None
        if len(blocks) == 0:
            list_info.append([False, False])
            continue
        if blocks[0]["type"] != BlockType.Text and blocks[-1]["type"] != BlockType.Text:
            list_info.append([False, False])
            continue
        if blocks[0]["type"] != BlockType.Text:
            is_start_list = False
        if blocks[-1]["type"] != BlockType.Text:
            is_end_list = False

        lines = [line for block in blocks if
                 block["type"] == BlockType.Text for line in
                 block['lines']]
        total_lines = len(lines)
        if total_lines == 1:  # a single line cannot form a list
            list_info.append([False, False])
            continue
        """Before the real split, statistically probe the block's alignment.
        Alignment styles:
        1. Left-aligned text block (flush-left lines, or more flush-right than
           not; flush first letters mix upper and lower case)
           1) right-aligned lines form their own paragraph
           2) centered lines are merged by font/line-height into one paragraph
        2. Left-aligned list block (flush-left line count <= indented count,
           indented first letters include lower case, flush ones ~90% upper;
           more than one flush-left line so the pattern repeats): each
           flush-left line opens a paragraph, following indented lines belong
           to it.
        """
        text_segments, list_start_line = __detect_list_lines(lines, new_layout_bbox, lang)
        """Partition the lines according to the detected list ranges."""
        # layout_right = __find_layout_bbox_by_line(lines[0]['bbox'], new_layout_bbox)[2]
        # layout_left = __find_layout_bbox_by_line(lines[0]['bbox'], new_layout_bbox)[0]
        para = []  # elements are lines  (NOTE(review): never filled — vestigial)
        layout_list_info = [False, False]  # [starts_with_list, ends_with_list]
        for content_type, start, end in text_segments:
            if content_type == 'list':
                # Only flag list-start/end when the boundary block is text.
                if start == 0 and is_start_list is None:
                    layout_list_info[0] = True
                if end == total_lines - 1 and is_end_list is None:
                    layout_list_info[1] = True
        # paras = __split_para_lines(lines, text_blocks)
        list_info.append(layout_list_info)
    return list_info
def __split_para_lines(lines: list, text_blocks: list) -> list:
    """
    Partition ``lines`` into paragraphs: each table/image/interline-equation
    line becomes its own paragraph; text lines are grouped by the text block
    whose bbox contains them.  Paragraphs are returned sorted by the y0 of
    their first line.

    BUGFIX: ``list.extend()`` returns None — the original assigned its result
    to ``paras`` and then called ``sorted(None)``, raising a TypeError.  The
    lists are concatenated instead.
    """
    text_paras = []
    other_paras = []
    text_lines = []
    for line in lines:
        spans_types = [span["type"] for span in line]
        # Non-text content stands alone as a one-line paragraph.
        if (ContentType.Table in spans_types
                or ContentType.Image in spans_types
                or ContentType.InterlineEquation in spans_types):
            other_paras.append([line])
            continue
        text_lines.append(line)

    for block in text_blocks:
        block_bbox = block["bbox"]
        para = []
        for line in text_lines:
            if is_in_layout(line["bbox"], block_bbox):
                para.append(line)
        if len(para) > 0:
            text_paras.append(para)

    paras = other_paras + text_paras  # extend() returns None; concatenate
    paras_sorted = sorted(paras, key=lambda x: x[0]["bbox"][1])
    return paras_sorted
def __connect_list_inter_layout(blocks_group, new_layout_bbox, layout_list_info, page_num, lang):
    """
    Connect list items across adjacent layout boxes on one page: when the
    previous layout ends with a list and the next one starts with same-indent
    single-line paragraphs, move those lines into the previous layout's last
    paragraph.  Mutates ``blocks_group`` in place.

    Returns (blocks_group, [page_starts_with_list, page_ends_with_list]).

    NOTE(review): the guard below `or`s the same condition twice; the second
    operand was presumably meant to test ``layout_list_info`` — confirm.
    """
    if len(blocks_group) == 0 or len(blocks_group) == 0:  # guard: the final return indexes [0] and [-1]
        return blocks_group, [False, False]
    for i in range(1, len(blocks_group)):
        if len(blocks_group[i]) == 0 or len(blocks_group[i-1]) == 0:
            continue
        pre_layout_list_info = layout_list_info[i - 1]
        next_layout_list_info = layout_list_info[i]
        pre_last_para = blocks_group[i - 1][-1].get("lines", [])
        next_paras = blocks_group[i]
        next_first_para = next_paras[0]
        # Previous layout ends in a list, next starts with non-list text:
        # check for identically-indented leading lines.
        if pre_layout_list_info[1] and not next_layout_list_info[0] and next_first_para["type"] == BlockType.Text:
            logger.info(f"连接page {page_num} 内的list")
            # Collect consecutive leading single-line paragraphs that are
            # indented relative to their layout's left edge.
            may_list_lines = []
            for j in range(len(next_paras)):
                lines = next_paras[j].get("lines", [])
                if len(lines) == 1:  # single-line only; multi-line needs deeper analysis
                    if lines[0]['bbox'][0] > __find_layout_bbox_by_line(lines[0]['bbox'], new_layout_bbox)[0]:
                        may_list_lines.append(lines[0])
                    else:
                        break
                else:
                    break
            # If those lines all share one indent, append them to the
            # previous layout's last paragraph and drop them from this one.
            if len(may_list_lines) > 0 and len(set([x['bbox'][0] for x in may_list_lines])) == 1:
                pre_last_para.extend(may_list_lines)
                blocks_group[i] = blocks_group[i][len(may_list_lines):]
                # layout_paras[i] = layout_paras[i][len(may_list_lines):]
    return blocks_group, [layout_list_info[0][0], layout_list_info[-1][1]]  # also return the page-level start/end list flags
def __connect_list_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox, next_page_layout_bbox,
                              pre_page_list_info, next_page_list_info, page_num, lang):
    """Try to join a list that is split across a page boundary.

    When the previous page ends with a list and the next page does not start
    with one, leading single-line paragraphs on the next page that share one
    indented left edge are moved onto the previous page's last paragraph.

    Both page structures are mutated in place.

    Returns:
        True when lines were moved, False otherwise.
    """
    # Nothing to join when either page (or its boundary layout) is empty —
    # the final return would otherwise blow up on empty input.
    if not pre_page_paras or not next_page_paras:
        return False
    if not pre_page_paras[-1] or not next_page_paras[0]:
        return False
    # Only plain text blocks participate in list joining.
    if pre_page_paras[-1][-1]["type"] != BlockType.Text or next_page_paras[0][0]["type"] != BlockType.Text:
        return False
    # Require: previous page ends in a list, next page does not begin with one.
    if not (pre_page_list_info[1] and not next_page_list_info[0]):
        return False
    logger.info(f"连接page {page_num} 内的list")
    # Collect leading single-line paragraphs indented past the left edge of
    # their layout; stop at the first paragraph that breaks the pattern.
    candidate_lines = []
    for para in next_page_paras[0]:
        para_lines = para["lines"]
        if len(para_lines) != 1:  # only single-line paragraphs qualify
            break
        only_line = para_lines[0]
        if only_line['bbox'][0] <= __find_layout_bbox_by_line(only_line['bbox'], next_page_layout_bbox)[0]:
            break
        candidate_lines.append(only_line)
    # All candidates must share exactly one indentation to count as one list;
    # if so, attach them to the previous page's last paragraph.
    if candidate_lines and len({line['bbox'][0] for line in candidate_lines}) == 1:
        pre_page_paras[-1][-1]["lines"].extend(candidate_lines)
        next_page_paras[0] = next_page_paras[0][len(candidate_lines):]
        return True
    return False
def __find_layout_bbox_by_line(line_bbox, layout_bboxes):
    """Return the first layout bbox that contains line_bbox, or None."""
    return next(
        (candidate for candidate in layout_bboxes if is_in_layout(line_bbox, candidate)),
        None,
    )
def __connect_para_inter_layoutbox(blocks_group, new_layout_bbox, lang):
    """Merge paragraphs across adjacent layouts on one page.

    The last paragraph of the previous layout and the first paragraph of the
    next layout are joined when both of the following hold:
    1. The previous layout's last line fills the full layout width and does
       not end with a sentence-stop character.
    2. The next layout's first line starts flush with its layout's left edge.

    Merged lines are moved onto the previous layout's last paragraph; the
    donor paragraph keeps an empty "lines" list and "lines_deleted"=True.

    Returns:
        The per-layout block lists, with paragraphs merged where possible.
    """
    connected_layout_blocks = []
    if len(blocks_group) == 0:
        return connected_layout_blocks

    connected_layout_blocks.append(blocks_group[0])
    for i in range(1, len(blocks_group)):
        try:
            if len(blocks_group[i]) == 0 or len(blocks_group[i - 1]) == 0:  # TODO handle connecting across empty layouts
                continue
            # Only text blocks are candidates for cross-layout merging.
            if blocks_group[i - 1][-1]["type"] != BlockType.Text or blocks_group[i][0]["type"] != BlockType.Text:
                connected_layout_blocks.append(blocks_group[i])
                continue
            pre_last_line = blocks_group[i - 1][-1]["lines"][-1]
            next_first_line = blocks_group[i][0]["lines"][0]
        except Exception:
            logger.error(f"page layout {i} has no line")
            continue
        pre_last_line_text = ''.join([__get_span_text(span) for span in pre_last_line['spans']])
        pre_last_line_type = pre_last_line['spans'][-1]['type']
        next_first_line_text = ''.join([__get_span_text(span) for span in next_first_line['spans']])
        next_first_line_type = next_first_line['spans'][0]['type']
        # Only plain text / inline equations may be merged.
        if pre_last_line_type not in [TEXT, INLINE_EQUATION] or next_first_line_type not in [TEXT, INLINE_EQUATION]:
            connected_layout_blocks.append(blocks_group[i])
            continue
        pre_x2_max = __find_layout_bbox_by_line(pre_last_line['bbox'], new_layout_bbox)[2]
        next_x0_min = __find_layout_bbox_by_line(next_first_line['bbox'], new_layout_bbox)[0]
        pre_last_line_text = pre_last_line_text.strip()
        next_first_line_text = next_first_line_text.strip()
        # BUGFIX: an all-whitespace last line strips to "", and indexing [-1]
        # on it raised IndexError; an empty line can never satisfy the join
        # conditions, so test emptiness first.
        if pre_last_line_text and pre_last_line['bbox'][2] == pre_x2_max and \
                pre_last_line_text[-1] not in LINE_STOP_FLAG and \
                next_first_line['bbox'][0] == next_x0_min:
            # Join: move the next layout's first-paragraph lines onto the
            # previous layout's last paragraph, then mark the donor deleted.
            connected_layout_blocks[-1][-1]["lines"].extend(blocks_group[i][0]["lines"])
            blocks_group[i][0]["lines"] = []
            blocks_group[i][0]["lines_deleted"] = True
            connected_layout_blocks.append(blocks_group[i])
        else:
            # Not joinable: keep this layout's paragraphs as they are.
            connected_layout_blocks.append(blocks_group[i])
    return connected_layout_blocks
def __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox, next_page_layout_bbox, page_num,
                              lang):
    """Merge the last paragraph of one page with the first paragraph of the next.

    The two paragraphs are joined when:
    1. The previous page's last line fills the full layout width and does not
       end with a sentence-stop character.
    2. The next page's first line starts flush with its layout's left edge.

    Both page structures are mutated in place.

    Returns:
        True when the paragraphs were merged, False otherwise.
    """
    # Some pages may contain no text at all.
    # BUGFIX: the original guarded pre_page_paras[0], but the layout accessed
    # below is pre_page_paras[-1]; guard the one actually used.
    if len(pre_page_paras) == 0 or len(next_page_paras) == 0 or len(pre_page_paras[-1]) == 0 or len(
            next_page_paras[0]) == 0:
        return False
    pre_last_block = pre_page_paras[-1][-1]
    next_first_block = next_page_paras[0][0]
    # Only text blocks are candidates for cross-page merging.
    if pre_last_block["type"] != BlockType.Text or next_first_block["type"] != BlockType.Text:
        return False
    if len(pre_last_block["lines"]) == 0 or len(next_first_block["lines"]) == 0:
        return False
    pre_last_para = pre_last_block["lines"]
    next_first_para = next_first_block["lines"]
    pre_last_line = pre_last_para[-1]
    next_first_line = next_first_para[0]
    pre_last_line_text = ''.join([__get_span_text(span) for span in pre_last_line['spans']])
    pre_last_line_type = pre_last_line['spans'][-1]['type']
    next_first_line_text = ''.join([__get_span_text(span) for span in next_first_line['spans']])
    next_first_line_type = next_first_line['spans'][0]['type']
    if pre_last_line_type not in [TEXT, INLINE_EQUATION] or next_first_line_type not in [TEXT,
                                                                                        INLINE_EQUATION]:  # TODO handle paragraphs interrupted by tables, images, display math
        # Not plain text — do not connect.
        return False
    pre_x2_max = __find_layout_bbox_by_line(pre_last_line['bbox'], pre_page_layout_bbox)[2]
    next_x0_min = __find_layout_bbox_by_line(next_first_line['bbox'], next_page_layout_bbox)[0]
    pre_last_line_text = pre_last_line_text.strip()
    next_first_line_text = next_first_line_text.strip()
    # BUGFIX: an all-whitespace last line strips to "" and [-1] raised
    # IndexError; an empty line can never satisfy the join conditions.
    if pre_last_line_text and pre_last_line['bbox'][2] == pre_x2_max and \
            pre_last_line_text[-1] not in LINE_STOP_FLAG and \
            next_first_line['bbox'][0] == next_x0_min:
        # Join: append the next page's first-paragraph lines to the previous
        # page's last paragraph, then mark the donor paragraph deleted.
        pre_last_para.extend(next_first_para)
        next_page_paras[0][0]["lines"] = []
        next_page_paras[0][0]["lines_deleted"] = True
        return True
    else:
        return False
def find_consecutive_true_regions(input_array):
    """Return (start, end) index pairs for every run of consecutive truthy
    values in *input_array* that spans at least two elements.

    Indices are inclusive; isolated single True values are ignored.
    """
    regions = []
    run_start = None  # index where the current run of True values began
    for idx, flag in enumerate(input_array):
        if flag:
            if run_start is None:
                run_start = idx
        else:
            # Run ended at idx-1; record it only if it covers >= 2 elements.
            if run_start is not None and idx - run_start > 1:
                regions.append((run_start, idx - 1))
            run_start = None
    # A run that extends to the end of the array is closed here.
    if run_start is not None and len(input_array) - run_start > 1:
        regions.append((run_start, len(input_array) - 1))
    return regions
def __connect_middle_align_text(page_paras, new_layout_bbox, page_num, lang, debug_mode):
    """
    Find runs of consecutive, horizontally centered single-line text and merge
    each run into one paragraph when all lines in the run share the same height.
    A line counts as centered when:
    1. Its horizontal midpoint spans the layout's center.
    2. There is blank space on both its left and right sides.

    Mutates page_paras in place: the first block of a merged run keeps all the
    lines, later blocks get empty "lines" and "lines_deleted"=True.
    """
    for layout_i, layout_para in enumerate(page_paras):
        layout_box = new_layout_bbox[layout_i]
        single_line_paras_tag = []
        for i in range(len(layout_para)):
            #single_line_paras_tag.append(len(layout_para[i]) == 1 and layout_para[i][0]['spans'][0]['type'] == TEXT)
            single_line_paras_tag.append(layout_para[i]['type'] == BlockType.Text and len(layout_para[i]["lines"]) == 1)
        """Find runs of consecutive single-line text; merge a run into one paragraph when its line heights match."""
        consecutive_single_line_indices = find_consecutive_true_regions(single_line_paras_tag)
        if len(consecutive_single_line_indices) > 0:
            #index_offset = 0
            """Check whether the lines of each run have equal height and are centered."""
            for start, end in consecutive_single_line_indices:
                #start += index_offset
                #end += index_offset
                # Heights of every (single) line in the run.
                line_hi = np.array([block["lines"][0]['bbox'][3] - block["lines"][0]['bbox'][1] for block in layout_para[start:end + 1]])
                first_line_text = ''.join([__get_span_text(span) for span in layout_para[start]["lines"][0]['spans']])
                # NOTE(review): deliberate no-op — presumably a hook/breakpoint
                # spot for Table/Figure captions; confirm before removing.
                if "Table" in first_line_text or "Figure" in first_line_text:
                    pass
                if debug_mode:
                    logger.info(line_hi.std())
                # Height spread below 2 units counts as "same height".
                if line_hi.std() < 2:
                    """Heights match — now test whether the lines are centered."""
                    all_left_x0 = [block["lines"][0]['bbox'][0] for block in layout_para[start:end + 1]]
                    all_right_x1 = [block["lines"][0]['bbox'][2] for block in layout_para[start:end + 1]]
                    layout_center = (layout_box[0] + layout_box[2]) / 2
                    # Every line must straddle the center, and the run must not
                    # be flush against either layout edge on all lines.
                    if all([x0 < layout_center < x1 for x0, x1 in zip(all_left_x0, all_right_x1)]) \
                            and not all([x0 == layout_box[0] for x0 in all_left_x0]) \
                            and not all([x1 == layout_box[2] for x1 in all_right_x1]):
                        merge_para = [block["lines"][0] for block in layout_para[start:end + 1]]
                        para_text = ''.join([__get_span_text(span) for line in merge_para for span in line['spans']])
                        if debug_mode:
                            logger.info(para_text)
                        # Keep the merged lines on the run's first block and
                        # blank out the donors.
                        layout_para[start]["lines"] = merge_para
                        for i_para in range(start+1, end+1):
                            layout_para[i_para]["lines"] = []
                            layout_para[i_para]["lines_deleted"] = True
                        #layout_para[start:end + 1] = [merge_para]
                        #index_offset -= end - start
    return
def __merge_signle_list_text(page_paras, new_layout_bbox, page_num, lang):
    """
    Find consecutive single-line paragraphs where the first line is flush left
    and the following ones share one indentation, and merge them into a single
    paragraph.

    NOTE(review): not implemented yet — currently a no-op placeholder. The
    misspelled name ("signle") is kept because callers reference it.
    """
    pass
def __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang):
    """
    Split one page into paragraphs from its lines and layout information,
    using a simple end-of-line-feature heuristic.

    Pipeline:
    1. Group lines by layout box.
    2. Split paragraphs inside each layout (lines ending well before the
       layout's right edge, or with sentence-stop characters, end a paragraph).
    3. Connect list paragraphs across layouts.
    4. Connect ordinary paragraphs across layouts.
    Figures and tables each occupy a block of their own and are not split.

    Returns:
        (connected_layout_blocks, page_list_info) where page_list_info is
        [page_starts_with_list, page_ends_with_list].
    """
    lines_group, blocks_group = __group_line_by_layout(blocks, layout_bboxes, lang)
    layout_list_info = __split_para_in_layoutbox(blocks_group, new_layout_bbox, lang)
    blocks_group, page_list_info = __connect_list_inter_layout(
        blocks_group, new_layout_bbox, layout_list_info, page_num, lang)
    connected_layout_blocks = __connect_para_inter_layoutbox(
        blocks_group, new_layout_bbox, lang)
    return connected_layout_blocks, page_list_info
def para_split(pdf_info_dict, debug_mode, lang="en"):
    """
    Split every page of pdf_info_dict into paragraphs, stitch together
    paragraphs and lists that continue across page boundaries, merge centered
    single-line runs, and finally flatten each page's layouts into one flat
    block list stored under page['para_blocks']. Mutates pdf_info_dict.
    """
    page_layouts = []     # one layout-bbox list per page
    page_list_flags = []  # per page: [starts_with_list, ends_with_list]

    # Pass 1: per-page paragraph splitting.
    for page_num, page in pdf_info_dict.items():
        blocks = page['preproc_blocks']
        layout_bboxes = page['layout_bboxes']
        new_layout_bbox = __common_pre_proc(blocks, layout_bboxes)
        page_layouts.append(new_layout_bbox)
        splited_blocks, list_flags = __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang)
        page_list_flags.append(list_flags)
        page['para_blocks'] = splited_blocks

    # Pass 2: connect paragraphs (and lists) that straddle page boundaries.
    pdf_infos = list(pdf_info_dict.values())
    for page_num in range(1, len(pdf_infos)):
        pre_page_paras = pdf_infos[page_num - 1]['para_blocks']
        next_page_paras = pdf_infos[page_num]['para_blocks']
        pre_page_layout_bbox = page_layouts[page_num - 1]
        next_page_layout_bbox = page_layouts[page_num]
        is_conn = __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox,
                                            next_page_layout_bbox, page_num, lang)
        if debug_mode and is_conn:
            logger.info(f"连接了第{page_num - 1}页和第{page_num}页的段落")
        is_list_conn = __connect_list_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox,
                                                 next_page_layout_bbox, page_list_flags[page_num - 1],
                                                 page_list_flags[page_num], page_num, lang)
        if debug_mode and is_list_conn:
            logger.info(f"连接了第{page_num - 1}页和第{page_num}页的列表段落")

    # Pass 3: catch remaining mergeable content that the passes above miss:
    # 1. a flush-left line followed by several indented lines;
    # 2. consecutive centered single lines of identical height.
    for page_num, page in enumerate(pdf_info_dict.values()):
        page_paras = page['para_blocks']
        new_layout_bbox = page_layouts[page_num]
        __connect_middle_align_text(page_paras, new_layout_bbox, page_num, lang, debug_mode=debug_mode)
        __merge_signle_list_text(page_paras, new_layout_bbox, page_num, lang)

    # Pass 4: flatten the layout level so each page holds one flat block list.
    for page in pdf_info_dict.values():
        page["para_blocks"] = [block for layout in page['para_blocks'] for block in layout]
...@@ -5,6 +5,7 @@ from magic_pdf.libs.commons import ( ...@@ -5,6 +5,7 @@ from magic_pdf.libs.commons import (
get_delta_time, get_delta_time,
get_docx_model_output, get_docx_model_output,
) )
from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.coordinate_transform import get_scale_ratio from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.libs.drop_tag import DropTag from magic_pdf.libs.drop_tag import DropTag
from magic_pdf.libs.hash_utils import compute_md5 from magic_pdf.libs.hash_utils import compute_md5
...@@ -15,7 +16,7 @@ from magic_pdf.pre_proc.detect_footer_by_model import parse_footers ...@@ -15,7 +16,7 @@ from magic_pdf.pre_proc.detect_footer_by_model import parse_footers
from magic_pdf.pre_proc.detect_footnote import parse_footnotes_by_model from magic_pdf.pre_proc.detect_footnote import parse_footnotes_by_model
from magic_pdf.pre_proc.detect_header import parse_headers from magic_pdf.pre_proc.detect_header import parse_headers
from magic_pdf.pre_proc.detect_page_number import parse_pageNos from magic_pdf.pre_proc.detect_page_number import parse_pageNos
from magic_pdf.pre_proc.ocr_cut_image import cut_image_and_table from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
from magic_pdf.pre_proc.ocr_detect_layout import layout_detect from magic_pdf.pre_proc.ocr_detect_layout import layout_detect
from magic_pdf.pre_proc.ocr_dict_merge import ( from magic_pdf.pre_proc.ocr_dict_merge import (
merge_spans_to_line_by_layout, merge_lines_to_block, merge_spans_to_line_by_layout, merge_lines_to_block,
...@@ -26,7 +27,6 @@ from magic_pdf.pre_proc.ocr_span_list_modify import remove_spans_by_bboxes, remo ...@@ -26,7 +27,6 @@ from magic_pdf.pre_proc.ocr_span_list_modify import remove_spans_by_bboxes, remo
from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox
def parse_pdf_by_ocr( def parse_pdf_by_ocr(
pdf_bytes, pdf_bytes,
pdf_model_output, pdf_model_output,
...@@ -147,7 +147,7 @@ def parse_pdf_by_ocr( ...@@ -147,7 +147,7 @@ def parse_pdf_by_ocr(
spans, dropped_spans_by_removed_bboxes = remove_spans_by_bboxes_dict(spans, need_remove_spans_bboxes_dict) spans, dropped_spans_by_removed_bboxes = remove_spans_by_bboxes_dict(spans, need_remove_spans_bboxes_dict)
'''对image和table截图''' '''对image和table截图'''
spans = cut_image_and_table(spans, page, page_id, pdf_bytes_md5, imageWriter) spans = ocr_cut_image_and_table(spans, page, page_id, pdf_bytes_md5, imageWriter)
'''行内公式调整, 高度调整至与同行文字高度一致(优先左侧, 其次右侧)''' '''行内公式调整, 高度调整至与同行文字高度一致(优先左侧, 其次右侧)'''
displayed_list = [] displayed_list = []
...@@ -208,6 +208,13 @@ def parse_pdf_by_ocr( ...@@ -208,6 +208,13 @@ def parse_pdf_by_ocr(
pdf_info_dict[f"page_{page_id}"] = page_info pdf_info_dict[f"page_{page_id}"] = page_info
"""分段""" """分段"""
para_split(pdf_info_dict, debug_mode=debug_mode) para_split(pdf_info_dict, debug_mode=debug_mode)
return pdf_info_dict """dict转list"""
pdf_info_list = dict_to_list(pdf_info_dict)
new_pdf_info_dict = {
"pdf_info": pdf_info_list,
}
return new_pdf_info_dict
import time
from loguru import logger
from magic_pdf.layout.layout_sort import get_bboxes_layout
from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.commons import fitz, get_delta_time
from magic_pdf.model.magic_model import MagicModel
from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split
from magic_pdf.pre_proc.ocr_dict_merge import sort_blocks_by_layout, fill_spans_in_blocks, fix_block_spans
from magic_pdf.pre_proc.ocr_span_list_modify import remove_overlaps_min_spans, get_qa_need_list_v2
# from magic_pdf.para.para_split import para_split
from magic_pdf.para.para_split_v2 import para_split
def parse_pdf_by_ocr(pdf_bytes,
                     model_list,
                     imageWriter,
                     start_page_id=0,
                     end_page_id=None,
                     debug_mode=False,
                     ):
    """
    Parse a PDF page-by-page using OCR model output.

    Args:
        pdf_bytes: raw bytes of the PDF file.
        model_list: per-page model output used to build the MagicModel.
        imageWriter: writer that stores cropped image/table snippets.
        start_page_id: first page index to parse (inclusive).
        end_page_id: last page index to parse (inclusive); None means parse to
            the last page. NOTE(review): the `if end_page_id` test treats an
            explicit 0 as falsy and also parses to the end — confirm intended.
        debug_mode: when True, logs per-page timing.

    Returns:
        {"pdf_info": [...]} — one page-info entry per parsed page.

    Raises:
        Re-raises (after logging) any exception thrown by para_split.
    """
    pdf_bytes_md5 = compute_md5(pdf_bytes)
    pdf_docs = fitz.open("pdf", pdf_bytes)

    '''Initialize an empty pdf_info_dict'''
    pdf_info_dict = {}

    '''Initialize magic_model from model_list and the fitz document'''
    magic_model = MagicModel(model_list, pdf_docs)

    '''Parse the requested page range'''
    end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1

    '''Record the start time'''
    start_time = time.time()

    for page_id in range(start_page_id, end_page_id + 1):
        '''In debug mode, log how long each page took to parse'''
        if debug_mode:
            time_now = time.time()
            logger.info(
                f"page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}"
            )
            start_time = time_now

        '''Fetch the block info needed below from magic_model'''
        img_blocks = magic_model.get_imgs(page_id)
        table_blocks = magic_model.get_tables(page_id)
        discarded_blocks = magic_model.get_discarded(page_id)
        text_blocks = magic_model.get_text_blocks(page_id)
        title_blocks = magic_model.get_title_blocks(page_id)
        inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations(page_id)

        page_w, page_h = magic_model.get_page_size(page_id)

        '''Collect the bboxes of all blocks in one list'''
        all_bboxes = ocr_prepare_bboxes_for_layout_split(
            img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
            interline_equation_blocks, page_w, page_h)

        '''Compute the layout from the block info'''
        page_boundry = [0, 0, page_w, page_h]
        layout_bboxes, layout_tree = get_bboxes_layout(all_bboxes, page_boundry, page_id)

        '''Sort all blocks that stay on this page by layout order'''
        sorted_blocks = sort_blocks_by_layout(all_bboxes, layout_bboxes)

        '''Fetch all spans that need to be stitched in'''
        spans = magic_model.get_all_spans(page_id)

        '''Drop the smaller of any two overlapping spans'''
        spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)

        '''Crop images and tables to files'''
        spans = ocr_cut_image_and_table(spans, pdf_docs[page_id], page_id, pdf_bytes_md5, imageWriter)

        '''Fill the spans into the sorted blocks'''
        block_with_spans = fill_spans_in_blocks(sorted_blocks, spans)

        '''Fix up the blocks (image/table sub-structure)'''
        fix_blocks = fix_block_spans(block_with_spans, img_blocks, table_blocks)

        '''Collect the lists QA needs exposed'''
        images, tables, interline_equations = get_qa_need_list_v2(fix_blocks)

        '''Build this page's entry of pdf_info_dict'''
        page_info = ocr_construct_page_component_v2(fix_blocks, layout_bboxes, page_id, page_w, page_h, layout_tree,
                                                    images, tables, interline_equations, discarded_blocks)
        pdf_info_dict[f"page_{page_id}"] = page_info

    """Paragraph splitting"""
    try:
        para_split(pdf_info_dict, debug_mode=debug_mode)
    except Exception as e:
        logger.exception(e)
        raise e

    """Convert the page dict into a list"""
    pdf_info_list = dict_to_list(pdf_info_dict)
    new_pdf_info_dict = {
        "pdf_info": pdf_info_list,
    }

    return new_pdf_info_dict
...@@ -11,11 +11,13 @@ from magic_pdf.layout.bbox_sort import ( ...@@ -11,11 +11,13 @@ from magic_pdf.layout.bbox_sort import (
prepare_bboxes_for_layout_split, prepare_bboxes_for_layout_split,
) )
from magic_pdf.layout.layout_sort import LAYOUT_UNPROC, get_bboxes_layout, get_columns_cnt_of_layout, sort_text_block from magic_pdf.layout.layout_sort import LAYOUT_UNPROC, get_bboxes_layout, get_columns_cnt_of_layout, sort_text_block
from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.drop_reason import DropReason from magic_pdf.libs.drop_reason import DropReason
from magic_pdf.libs.hash_utils import compute_md5 from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.markdown_utils import escape_special_markdown_char from magic_pdf.libs.markdown_utils import escape_special_markdown_char
from magic_pdf.libs.safe_filename import sanitize_filename from magic_pdf.libs.safe_filename import sanitize_filename
from magic_pdf.libs.vis_utils import draw_bbox_on_page, draw_layout_bbox_on_page from magic_pdf.libs.vis_utils import draw_bbox_on_page, draw_layout_bbox_on_page
from magic_pdf.pre_proc.cut_image import txt_save_images_by_bboxes
from magic_pdf.pre_proc.detect_images import parse_images from magic_pdf.pre_proc.detect_images import parse_images
from magic_pdf.pre_proc.detect_tables import parse_tables # 获取tables的bbox from magic_pdf.pre_proc.detect_tables import parse_tables # 获取tables的bbox
from magic_pdf.pre_proc.detect_equation import parse_equations # 获取equations的bbox from magic_pdf.pre_proc.detect_equation import parse_equations # 获取equations的bbox
...@@ -47,8 +49,6 @@ from para.exceptions import ( ...@@ -47,8 +49,6 @@ from para.exceptions import (
) )
''' '''
from magic_pdf.libs.commons import read_file, join_path
from magic_pdf.libs.pdf_image_tools import save_images_by_bboxes
from magic_pdf.post_proc.remove_footnote import merge_footnote_blocks, remove_footnote_blocks from magic_pdf.post_proc.remove_footnote import merge_footnote_blocks, remove_footnote_blocks
from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
from magic_pdf.pre_proc.equations_replace import combine_chars_to_pymudict, remove_chars_in_text_blocks, replace_equations_in_textblock from magic_pdf.pre_proc.equations_replace import combine_chars_to_pymudict, remove_chars_in_text_blocks, replace_equations_in_textblock
...@@ -107,7 +107,7 @@ def parse_pdf_by_txt( ...@@ -107,7 +107,7 @@ def parse_pdf_by_txt(
# 去除对junkimg的依赖,简化逻辑 # 去除对junkimg的依赖,简化逻辑
if len(page_imgs) > 1500: # 如果当前页超过1500张图片,直接跳过 if len(page_imgs) > 1500: # 如果当前页超过1500张图片,直接跳过
logger.warning(f"page_id: {page_id}, img_counts: {len(page_imgs)}, drop this pdf") logger.warning(f"page_id: {page_id}, img_counts: {len(page_imgs)}, drop this pdf")
result = {"need_drop": True, "drop_reason": DropReason.HIGH_COMPUTATIONAL_lOAD_BY_IMGS} result = {"_need_drop": True, "_drop_reason": DropReason.HIGH_COMPUTATIONAL_lOAD_BY_IMGS}
if not debug_mode: if not debug_mode:
return result return result
...@@ -193,7 +193,7 @@ def parse_pdf_by_txt( ...@@ -193,7 +193,7 @@ def parse_pdf_by_txt(
""" """
# 把图、表、公式都进行截图,保存到存储上,返回图片路径作为内容 # 把图、表、公式都进行截图,保存到存储上,返回图片路径作为内容
image_info, image_backup_info, table_info, inline_eq_info, interline_eq_info = save_images_by_bboxes( image_info, image_backup_info, table_info, inline_eq_info, interline_eq_info = txt_save_images_by_bboxes(
page_id, page_id,
page, page,
pdf_bytes_md5, pdf_bytes_md5,
...@@ -236,7 +236,7 @@ def parse_pdf_by_txt( ...@@ -236,7 +236,7 @@ def parse_pdf_by_txt(
if is_text_block_horz_overlap: if is_text_block_horz_overlap:
# debug_show_bbox(pdf_docs, page_id, [b['bbox'] for b in remain_text_blocks], [], [], join_path(save_path, book_name, f"{book_name}_debug.pdf"), 0) # debug_show_bbox(pdf_docs, page_id, [b['bbox'] for b in remain_text_blocks], [], [], join_path(save_path, book_name, f"{book_name}_debug.pdf"), 0)
logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {DropReason.TEXT_BLCOK_HOR_OVERLAP}") logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {DropReason.TEXT_BLCOK_HOR_OVERLAP}")
result = {"need_drop": True, "drop_reason": DropReason.TEXT_BLCOK_HOR_OVERLAP} result = {"_need_drop": True, "_drop_reason": DropReason.TEXT_BLCOK_HOR_OVERLAP}
if not debug_mode: if not debug_mode:
return result return result
...@@ -255,14 +255,14 @@ def parse_pdf_by_txt( ...@@ -255,14 +255,14 @@ def parse_pdf_by_txt(
if len(remain_text_blocks)>0 and len(all_bboxes)>0 and len(layout_bboxes)==0: if len(remain_text_blocks)>0 and len(all_bboxes)>0 and len(layout_bboxes)==0:
logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {DropReason.CAN_NOT_DETECT_PAGE_LAYOUT}") logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {DropReason.CAN_NOT_DETECT_PAGE_LAYOUT}")
result = {"need_drop": True, "drop_reason": DropReason.CAN_NOT_DETECT_PAGE_LAYOUT} result = {"_need_drop": True, "_drop_reason": DropReason.CAN_NOT_DETECT_PAGE_LAYOUT}
if not debug_mode: if not debug_mode:
return result return result
"""以下去掉复杂的布局和超过2列的布局""" """以下去掉复杂的布局和超过2列的布局"""
if any([lay["layout_label"] == LAYOUT_UNPROC for lay in layout_bboxes]): # 复杂的布局 if any([lay["layout_label"] == LAYOUT_UNPROC for lay in layout_bboxes]): # 复杂的布局
logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {DropReason.COMPLICATED_LAYOUT}") logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {DropReason.COMPLICATED_LAYOUT}")
result = {"need_drop": True, "drop_reason": DropReason.COMPLICATED_LAYOUT} result = {"_need_drop": True, "_drop_reason": DropReason.COMPLICATED_LAYOUT}
if not debug_mode: if not debug_mode:
return result return result
...@@ -270,8 +270,8 @@ def parse_pdf_by_txt( ...@@ -270,8 +270,8 @@ def parse_pdf_by_txt(
if layout_column_width > 2: # 去掉超过2列的布局pdf if layout_column_width > 2: # 去掉超过2列的布局pdf
logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {DropReason.TOO_MANY_LAYOUT_COLUMNS}") logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {DropReason.TOO_MANY_LAYOUT_COLUMNS}")
result = { result = {
"need_drop": True, "_need_drop": True,
"drop_reason": DropReason.TOO_MANY_LAYOUT_COLUMNS, "_drop_reason": DropReason.TOO_MANY_LAYOUT_COLUMNS,
"extra_info": {"column_cnt": layout_column_width}, "extra_info": {"column_cnt": layout_column_width},
} }
if not debug_mode: if not debug_mode:
...@@ -377,27 +377,34 @@ def parse_pdf_by_txt( ...@@ -377,27 +377,34 @@ def parse_pdf_by_txt(
logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {error_info}") logger.warning(f"page_id: {page_id}, drop this pdf: {pdf_bytes_md5}, reason: {error_info}")
if error_info == denseSingleLineBlockException_msg: if error_info == denseSingleLineBlockException_msg:
logger.warning(f"Drop this pdf: {pdf_bytes_md5}, reason: {DropReason.DENSE_SINGLE_LINE_BLOCK}") logger.warning(f"Drop this pdf: {pdf_bytes_md5}, reason: {DropReason.DENSE_SINGLE_LINE_BLOCK}")
result = {"need_drop": True, "drop_reason": DropReason.DENSE_SINGLE_LINE_BLOCK} result = {"_need_drop": True, "_drop_reason": DropReason.DENSE_SINGLE_LINE_BLOCK}
return result return result
if error_info == titleDetectionException_msg: if error_info == titleDetectionException_msg:
logger.warning(f"Drop this pdf: {pdf_bytes_md5}, reason: {DropReason.TITLE_DETECTION_FAILED}") logger.warning(f"Drop this pdf: {pdf_bytes_md5}, reason: {DropReason.TITLE_DETECTION_FAILED}")
result = {"need_drop": True, "drop_reason": DropReason.TITLE_DETECTION_FAILED} result = {"_need_drop": True, "_drop_reason": DropReason.TITLE_DETECTION_FAILED}
return result return result
elif error_info == titleLevelException_msg: elif error_info == titleLevelException_msg:
logger.warning(f"Drop this pdf: {pdf_bytes_md5}, reason: {DropReason.TITLE_LEVEL_FAILED}") logger.warning(f"Drop this pdf: {pdf_bytes_md5}, reason: {DropReason.TITLE_LEVEL_FAILED}")
result = {"need_drop": True, "drop_reason": DropReason.TITLE_LEVEL_FAILED} result = {"_need_drop": True, "_drop_reason": DropReason.TITLE_LEVEL_FAILED}
return result return result
elif error_info == paraSplitException_msg: elif error_info == paraSplitException_msg:
logger.warning(f"Drop this pdf: {pdf_bytes_md5}, reason: {DropReason.PARA_SPLIT_FAILED}") logger.warning(f"Drop this pdf: {pdf_bytes_md5}, reason: {DropReason.PARA_SPLIT_FAILED}")
result = {"need_drop": True, "drop_reason": DropReason.PARA_SPLIT_FAILED} result = {"_need_drop": True, "_drop_reason": DropReason.PARA_SPLIT_FAILED}
return result return result
elif error_info == paraMergeException_msg: elif error_info == paraMergeException_msg:
logger.warning(f"Drop this pdf: {pdf_bytes_md5}, reason: {DropReason.PARA_MERGE_FAILED}") logger.warning(f"Drop this pdf: {pdf_bytes_md5}, reason: {DropReason.PARA_MERGE_FAILED}")
result = {"need_drop": True, "drop_reason": DropReason.PARA_MERGE_FAILED} result = {"_need_drop": True, "_drop_reason": DropReason.PARA_MERGE_FAILED}
return result return result
pdf_info_dict, error_info = para_process_pipeline.para_process_pipeline(pdf_info_dict) pdf_info_dict, error_info = para_process_pipeline.para_process_pipeline(pdf_info_dict)
if error_info is not None: if error_info is not None:
return _deal_with_text_exception(error_info) return _deal_with_text_exception(error_info)
return pdf_info_dict
"""dict转list"""
pdf_info_list = dict_to_list(pdf_info_dict)
new_pdf_info_dict = {
"pdf_info": pdf_info_list,
}
return new_pdf_info_dict
import time
from loguru import logger
from magic_pdf.layout.layout_sort import get_bboxes_layout
from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.commons import fitz, get_delta_time
from magic_pdf.model.magic_model import MagicModel
from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split
from magic_pdf.pre_proc.ocr_dict_merge import (
sort_blocks_by_layout,
fill_spans_in_blocks,
fix_block_spans,
)
from magic_pdf.libs.ocr_content_type import ContentType
from magic_pdf.pre_proc.ocr_span_list_modify import (
remove_overlaps_min_spans,
get_qa_need_list_v2,
)
from magic_pdf.pre_proc.equations_replace import (
combine_chars_to_pymudict,
remove_chars_in_text_blocks,
replace_equations_in_textblock,
)
from magic_pdf.pre_proc.equations_replace import (
combine_chars_to_pymudict,
remove_chars_in_text_blocks,
replace_equations_in_textblock,
)
from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
from magic_pdf.libs.math import float_equal
from magic_pdf.para.para_split_v2 import para_split
def txt_spans_extract(pdf_page, inline_equations, interline_equations):
    """Extract text spans from a pymupdf page, with equations substituted in.

    Reads both the span-level ("dict") and char-level ("rawdict") text of the
    page, merges them, replaces equation regions with the supplied equations,
    strips citation markers and char data, and returns a flat list of
    {"bbox", "content", "type": ContentType.Text} span dicts, skipping
    zero-width or zero-height spans.
    """
    raw_blocks = pdf_page.get_text("dict", flags=fitz.TEXTFLAGS_TEXT)["blocks"]
    char_blocks = pdf_page.get_text("rawdict", flags=fitz.TEXTFLAGS_TEXT)[
        "blocks"
    ]
    merged_blocks = combine_chars_to_pymudict(raw_blocks, char_blocks)
    merged_blocks = replace_equations_in_textblock(
        merged_blocks, inline_equations, interline_equations
    )
    merged_blocks = remove_citation_marker(merged_blocks)
    merged_blocks = remove_chars_in_text_blocks(merged_blocks)
    # Flatten block -> line -> span, dropping degenerate (zero-area) spans.
    return [
        {
            "bbox": list(span["bbox"]),
            "content": span["text"],
            "type": ContentType.Text,
        }
        for block in merged_blocks
        for line in block["lines"]
        for span in line["spans"]
        if not (
            float_equal(span["bbox"][0], span["bbox"][2])
            or float_equal(span["bbox"][1], span["bbox"][3])
        )
    ]
def replace_text_span(pymu_spans, ocr_spans):
    """Replace the OCR text spans with pymupdf-extracted text spans.

    Keeps every non-text span from ocr_spans (images, tables, equations) and
    appends all pymu_spans after them.
    """
    non_text_spans = [span for span in ocr_spans if span["type"] != ContentType.Text]
    return non_text_spans + pymu_spans
def parse_pdf_by_txt(
    pdf_bytes,
    model_list,
    imageWriter,
    start_page_id=0,
    end_page_id=None,
    debug_mode=False,
):
    """
    Parse a text-layer PDF page-by-page, substituting pymupdf-extracted text
    spans for the OCR text spans.

    Args:
        pdf_bytes: raw bytes of the PDF file.
        model_list: per-page model output used to build the MagicModel.
        imageWriter: writer that stores cropped image/table snippets.
        start_page_id: first page index to parse (inclusive).
        end_page_id: last page index to parse (inclusive); None means parse to
            the last page. NOTE(review): the `if end_page_id` test treats an
            explicit 0 as falsy and also parses to the end — confirm intended.
        debug_mode: when True, logs per-page timing.

    Returns:
        {"pdf_info": [...]} — one page-info entry per parsed page.

    Raises:
        Re-raises (after logging) any exception thrown by para_split.
    """
    pdf_bytes_md5 = compute_md5(pdf_bytes)
    pdf_docs = fitz.open("pdf", pdf_bytes)

    """Initialize an empty pdf_info_dict"""
    pdf_info_dict = {}

    """Initialize magic_model from model_list and the fitz document"""
    magic_model = MagicModel(model_list, pdf_docs)

    """Parse the requested page range"""
    end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1

    """Record the start time"""
    start_time = time.time()

    for page_id in range(start_page_id, end_page_id + 1):
        """In debug mode, log how long each page took to parse"""
        if debug_mode:
            time_now = time.time()
            logger.info(
                f"page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}"
            )
            start_time = time_now

        """Fetch the block info needed below from magic_model"""
        img_blocks = magic_model.get_imgs(page_id)
        table_blocks = magic_model.get_tables(page_id)
        discarded_blocks = magic_model.get_discarded(page_id)
        text_blocks = magic_model.get_text_blocks(page_id)
        title_blocks = magic_model.get_title_blocks(page_id)
        inline_equations, interline_equations, interline_equation_blocks = (
            magic_model.get_equations(page_id)
        )

        page_w, page_h = magic_model.get_page_size(page_id)

        """Collect the bboxes of all blocks in one list"""
        all_bboxes = ocr_prepare_bboxes_for_layout_split(
            img_blocks,
            table_blocks,
            discarded_blocks,
            text_blocks,
            title_blocks,
            interline_equation_blocks,
            page_w,
            page_h,
        )

        """Compute the layout from the block info"""
        page_boundry = [0, 0, page_w, page_h]
        layout_bboxes, layout_tree = get_bboxes_layout(
            all_bboxes, page_boundry, page_id
        )

        """Sort all blocks that stay on this page by layout order"""
        sorted_blocks = sort_blocks_by_layout(all_bboxes, layout_bboxes)

        """Replace OCR text spans with pymupdf-extracted spans!"""
        ocr_spans = magic_model.get_all_spans(page_id)
        pymu_spans = txt_spans_extract(
            pdf_docs[page_id], inline_equations, interline_equations
        )
        spans = replace_text_span(pymu_spans, ocr_spans)

        """Drop the smaller of any two overlapping spans"""
        spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)

        """Crop images and tables to files"""
        spans = ocr_cut_image_and_table(
            spans, pdf_docs[page_id], page_id, pdf_bytes_md5, imageWriter
        )

        """Fill the spans into the sorted blocks"""
        block_with_spans = fill_spans_in_blocks(sorted_blocks, spans)

        """Fix up the blocks (image/table sub-structure)"""
        fix_blocks = fix_block_spans(block_with_spans, img_blocks, table_blocks)

        """Collect the lists QA needs exposed"""
        images, tables, interline_equations = get_qa_need_list_v2(fix_blocks)

        """Build this page's entry of pdf_info_dict"""
        page_info = ocr_construct_page_component_v2(
            fix_blocks,
            layout_bboxes,
            page_id,
            page_w,
            page_h,
            layout_tree,
            images,
            tables,
            interline_equations,
            discarded_blocks,
        )
        pdf_info_dict[f"page_{page_id}"] = page_info

    """Paragraph splitting"""
    try:
        para_split(pdf_info_dict, debug_mode=debug_mode)
    except Exception as e:
        logger.exception(e)
        raise e

    """Convert the page dict into a list"""
    pdf_info_list = dict_to_list(pdf_info_dict)
    new_pdf_info_dict = {
        "pdf_info": pdf_info_list,
    }

    return new_pdf_info_dict
if __name__ == "__main__":
if 1:
import fitz
import json
with open("/opt/data/pdf/20240418/25536-00.pdf", "rb") as f:
pdf_bytes = f.read()
pdf_docs = fitz.open("pdf", pdf_bytes)
with open("/opt/data/pdf/20240418/25536-00.json") as f:
model_list = json.loads(f.readline())
magic_model = MagicModel(model_list, pdf_docs)
for i in range(7):
print(magic_model.get_imgs(i))
for page_no, page in enumerate(pdf_docs):
inline_equations, interline_equations, interline_equation_blocks = (
magic_model.get_equations(page_no)
)
text_raw_blocks = page.get_text("dict", flags=fitz.TEXTFLAGS_TEXT)["blocks"]
char_level_text_blocks = page.get_text(
"rawdict", flags=fitz.TEXTFLAGS_TEXT
)["blocks"]
text_blocks = combine_chars_to_pymudict(
text_raw_blocks, char_level_text_blocks
)
text_blocks = replace_equations_in_textblock(
text_blocks, inline_equations, interline_equations
)
text_blocks = remove_citation_marker(text_blocks)
text_blocks = remove_chars_in_text_blocks(text_blocks)
...@@ -26,6 +26,7 @@ from magic_pdf.libs.drop_reason import DropReason ...@@ -26,6 +26,7 @@ from magic_pdf.libs.drop_reason import DropReason
from magic_pdf.libs.markdown_utils import escape_special_markdown_char from magic_pdf.libs.markdown_utils import escape_special_markdown_char
from magic_pdf.libs.safe_filename import sanitize_filename from magic_pdf.libs.safe_filename import sanitize_filename
from magic_pdf.libs.vis_utils import draw_bbox_on_page, draw_layout_bbox_on_page from magic_pdf.libs.vis_utils import draw_bbox_on_page, draw_layout_bbox_on_page
from magic_pdf.pre_proc.cut_image import txt_save_images_by_bboxes
from magic_pdf.pre_proc.detect_images import parse_images from magic_pdf.pre_proc.detect_images import parse_images
from magic_pdf.pre_proc.detect_tables import parse_tables # 获取tables的bbox from magic_pdf.pre_proc.detect_tables import parse_tables # 获取tables的bbox
from magic_pdf.pre_proc.detect_equation import parse_equations # 获取equations的bbox from magic_pdf.pre_proc.detect_equation import parse_equations # 获取equations的bbox
...@@ -62,7 +63,6 @@ from para.exceptions import ( ...@@ -62,7 +63,6 @@ from para.exceptions import (
""" """
from magic_pdf.libs.commons import read_file, join_path from magic_pdf.libs.commons import read_file, join_path
from magic_pdf.libs.pdf_image_tools import save_images_by_bboxes
from magic_pdf.post_proc.remove_footnote import ( from magic_pdf.post_proc.remove_footnote import (
merge_footnote_blocks, merge_footnote_blocks,
remove_footnote_blocks, remove_footnote_blocks,
...@@ -183,8 +183,8 @@ def parse_pdf_for_train( ...@@ -183,8 +183,8 @@ def parse_pdf_for_train(
f"page_id: {page_id}, img_counts: {img_counts}, drop this pdf: {book_name}, drop_reason: {DropReason.HIGH_COMPUTATIONAL_lOAD_BY_IMGS}" f"page_id: {page_id}, img_counts: {img_counts}, drop this pdf: {book_name}, drop_reason: {DropReason.HIGH_COMPUTATIONAL_lOAD_BY_IMGS}"
) )
result = { result = {
"need_drop": True, "_need_drop": True,
"drop_reason": DropReason.HIGH_COMPUTATIONAL_lOAD_BY_IMGS, "_drop_reason": DropReason.HIGH_COMPUTATIONAL_lOAD_BY_IMGS,
} }
if not debug_mode: if not debug_mode:
return result return result
...@@ -323,7 +323,7 @@ def parse_pdf_for_train( ...@@ -323,7 +323,7 @@ def parse_pdf_for_train(
# 把图、表、公式都进行截图,保存到存储上,返回图片路径作为内容 # 把图、表、公式都进行截图,保存到存储上,返回图片路径作为内容
image_info, image_backup_info, table_info, inline_eq_info, interline_eq_info = ( image_info, image_backup_info, table_info, inline_eq_info, interline_eq_info = (
save_images_by_bboxes( txt_save_images_by_bboxes(
book_name, book_name,
page_id, page_id,
page, page,
...@@ -396,8 +396,8 @@ def parse_pdf_for_train( ...@@ -396,8 +396,8 @@ def parse_pdf_for_train(
f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.TEXT_BLCOK_HOR_OVERLAP}" f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.TEXT_BLCOK_HOR_OVERLAP}"
) )
result = { result = {
"need_drop": True, "_need_drop": True,
"drop_reason": DropReason.TEXT_BLCOK_HOR_OVERLAP, "_drop_reason": DropReason.TEXT_BLCOK_HOR_OVERLAP,
} }
if not debug_mode: if not debug_mode:
return result return result
...@@ -443,8 +443,8 @@ def parse_pdf_for_train( ...@@ -443,8 +443,8 @@ def parse_pdf_for_train(
f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.CAN_NOT_DETECT_PAGE_LAYOUT}" f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.CAN_NOT_DETECT_PAGE_LAYOUT}"
) )
result = { result = {
"need_drop": True, "_need_drop": True,
"drop_reason": DropReason.CAN_NOT_DETECT_PAGE_LAYOUT, "_drop_reason": DropReason.CAN_NOT_DETECT_PAGE_LAYOUT,
} }
if not debug_mode: if not debug_mode:
return result return result
...@@ -456,7 +456,7 @@ def parse_pdf_for_train( ...@@ -456,7 +456,7 @@ def parse_pdf_for_train(
logger.warning( logger.warning(
f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.COMPLICATED_LAYOUT}" f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.COMPLICATED_LAYOUT}"
) )
result = {"need_drop": True, "drop_reason": DropReason.COMPLICATED_LAYOUT} result = {"_need_drop": True, "_drop_reason": DropReason.COMPLICATED_LAYOUT}
if not debug_mode: if not debug_mode:
return result return result
...@@ -466,8 +466,8 @@ def parse_pdf_for_train( ...@@ -466,8 +466,8 @@ def parse_pdf_for_train(
f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.TOO_MANY_LAYOUT_COLUMNS}" f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.TOO_MANY_LAYOUT_COLUMNS}"
) )
result = { result = {
"need_drop": True, "_need_drop": True,
"drop_reason": DropReason.TOO_MANY_LAYOUT_COLUMNS, "_drop_reason": DropReason.TOO_MANY_LAYOUT_COLUMNS,
"extra_info": {"column_cnt": layout_column_width}, "extra_info": {"column_cnt": layout_column_width},
} }
if not debug_mode: if not debug_mode:
...@@ -616,8 +616,8 @@ def parse_pdf_for_train( ...@@ -616,8 +616,8 @@ def parse_pdf_for_train(
f"Drop this pdf: {book_name}, reason: {DropReason.DENSE_SINGLE_LINE_BLOCK}" f"Drop this pdf: {book_name}, reason: {DropReason.DENSE_SINGLE_LINE_BLOCK}"
) )
result = { result = {
"need_drop": True, "_need_drop": True,
"drop_reason": DropReason.DENSE_SINGLE_LINE_BLOCK, "_drop_reason": DropReason.DENSE_SINGLE_LINE_BLOCK,
} }
return result return result
if error_info == titleDetectionException_msg: if error_info == titleDetectionException_msg:
...@@ -625,27 +625,27 @@ def parse_pdf_for_train( ...@@ -625,27 +625,27 @@ def parse_pdf_for_train(
f"Drop this pdf: {book_name}, reason: {DropReason.TITLE_DETECTION_FAILED}" f"Drop this pdf: {book_name}, reason: {DropReason.TITLE_DETECTION_FAILED}"
) )
result = { result = {
"need_drop": True, "_need_drop": True,
"drop_reason": DropReason.TITLE_DETECTION_FAILED, "_drop_reason": DropReason.TITLE_DETECTION_FAILED,
} }
return result return result
elif error_info == titleLevelException_msg: elif error_info == titleLevelException_msg:
logger.warning( logger.warning(
f"Drop this pdf: {book_name}, reason: {DropReason.TITLE_LEVEL_FAILED}" f"Drop this pdf: {book_name}, reason: {DropReason.TITLE_LEVEL_FAILED}"
) )
result = {"need_drop": True, "drop_reason": DropReason.TITLE_LEVEL_FAILED} result = {"_need_drop": True, "_drop_reason": DropReason.TITLE_LEVEL_FAILED}
return result return result
elif error_info == paraSplitException_msg: elif error_info == paraSplitException_msg:
logger.warning( logger.warning(
f"Drop this pdf: {book_name}, reason: {DropReason.PARA_SPLIT_FAILED}" f"Drop this pdf: {book_name}, reason: {DropReason.PARA_SPLIT_FAILED}"
) )
result = {"need_drop": True, "drop_reason": DropReason.PARA_SPLIT_FAILED} result = {"_need_drop": True, "_drop_reason": DropReason.PARA_SPLIT_FAILED}
return result return result
elif error_info == paraMergeException_msg: elif error_info == paraMergeException_msg:
logger.warning( logger.warning(
f"Drop this pdf: {book_name}, reason: {DropReason.PARA_MERGE_FAILED}" f"Drop this pdf: {book_name}, reason: {DropReason.PARA_MERGE_FAILED}"
) )
result = {"need_drop": True, "drop_reason": DropReason.PARA_MERGE_FAILED} result = {"_need_drop": True, "_drop_reason": DropReason.PARA_MERGE_FAILED}
return result return result
if debug_mode: if debug_mode:
......
from abc import ABC, abstractmethod
from magic_pdf.dict2md.mkcontent import mk_universal_format, mk_mm_markdown
from magic_pdf.dict2md.ocr_mkcontent import make_standard_format_with_para, ocr_mk_mm_markdown_with_para
from magic_pdf.filter.pdf_classify_by_type import classify
from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.libs.drop_reason import DropReason
from magic_pdf.libs.json_compressor import JsonCompressor
class AbsPipe(ABC):
    """Abstract base class shared by the txt and ocr PDF processing pipelines.

    Concrete subclasses implement the stateful pipe_* steps; the static
    helpers implement the stateless classification / formatting logic.
    """

    PIP_OCR = "ocr"
    PIP_TXT = "txt"

    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, img_parent_path: str, is_debug: bool = False):
        self.pdf_bytes = pdf_bytes
        self.model_list = model_list
        self.image_writer = image_writer
        self.img_parent_path = img_parent_path
        self.pdf_mid_data = None  # uncompressed intermediate parse result
        self.is_debug = is_debug

    def get_compress_pdf_mid_data(self):
        """Return the intermediate parse data as a compressed JSON string."""
        return JsonCompressor.compress_json(self.pdf_mid_data)

    @abstractmethod
    def pipe_classify(self):
        """Stateful classification step (decides txt vs ocr)."""
        raise NotImplementedError

    @abstractmethod
    def pipe_parse(self):
        """Stateful parsing step (fills self.pdf_mid_data)."""
        raise NotImplementedError

    @abstractmethod
    def pipe_mk_uni_format(self):
        """Stateful assembly of the unified content-list format."""
        raise NotImplementedError

    @abstractmethod
    def pipe_mk_markdown(self):
        """Stateful assembly of the markdown output."""
        raise NotImplementedError

    @staticmethod
    def classify(pdf_bytes: bytes) -> str:
        """Classify a PDF as text-based or OCR-needed from its metadata.

        Returns AbsPipe.PIP_TXT or AbsPipe.PIP_OCR.
        Raises Exception when the metadata scan says the PDF must be
        dropped, or when the PDF is encrypted / password protected.
        """
        pdf_meta = pdf_meta_scan(pdf_bytes)
        if pdf_meta.get("_need_drop", False):  # meta scan flagged the PDF as unusable
            raise Exception(f"pdf meta_scan need_drop,reason is {pdf_meta['_drop_reason']}")
        is_encrypted = pdf_meta["is_encrypted"]
        is_needs_password = pdf_meta["is_needs_password"]
        if is_encrypted or is_needs_password:  # encrypted / password-protected PDFs are not processed
            raise Exception(f"pdf meta_scan need_drop,reason is {DropReason.ENCRYPTED}")
        is_text_pdf, results = classify(
            pdf_meta["total_page"],
            pdf_meta["page_width_pts"],
            pdf_meta["page_height_pts"],
            pdf_meta["image_info_per_page"],
            pdf_meta["text_len_per_page"],
            pdf_meta["imgs_per_page"],
            pdf_meta["text_layout_per_page"],
        )
        return AbsPipe.PIP_TXT if is_text_pdf else AbsPipe.PIP_OCR

    @staticmethod
    def mk_uni_format(compressed_pdf_mid_data: str, img_buket_path: str) -> list:
        """Build the unified content_list for the parse type recorded in the data.

        Raises Exception on an unrecognized parse type instead of failing
        later with an UnboundLocalError.
        """
        pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
        parse_type = pdf_mid_data["_parse_type"]
        pdf_info_list = pdf_mid_data["pdf_info"]
        if parse_type == AbsPipe.PIP_TXT:
            content_list = mk_universal_format(pdf_info_list, img_buket_path)
        elif parse_type == AbsPipe.PIP_OCR:
            content_list = make_standard_format_with_para(pdf_info_list, img_buket_path)
        else:
            raise Exception(f"unknown parse type: {parse_type}")
        return content_list

    @staticmethod
    def mk_markdown(compressed_pdf_mid_data: str, img_buket_path: str) -> list:
        """Build markdown output for the parse type recorded in the data.

        Raises Exception on an unrecognized parse type instead of failing
        later with an UnboundLocalError.
        """
        pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
        parse_type = pdf_mid_data["_parse_type"]
        pdf_info_list = pdf_mid_data["pdf_info"]
        if parse_type == AbsPipe.PIP_TXT:
            # content_list = mk_universal_format(pdf_info_list, img_buket_path)
            # md_content = mk_mm_markdown(content_list)
            md_content = ocr_mk_mm_markdown_with_para(pdf_info_list, img_buket_path)
        elif parse_type == AbsPipe.PIP_OCR:
            md_content = ocr_mk_mm_markdown_with_para(pdf_info_list, img_buket_path)
        else:
            raise Exception(f"unknown parse type: {parse_type}")
        return md_content
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.libs.json_compressor import JsonCompressor
from magic_pdf.pipe.AbsPipe import AbsPipe
from magic_pdf.user_api import parse_ocr_pdf
class OCRPipe(AbsPipe):
    """Pipeline that always takes the OCR parsing path."""

    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, img_parent_path: str, is_debug: bool = False):
        super().__init__(pdf_bytes, model_list, image_writer, img_parent_path, is_debug)

    def pipe_classify(self):
        """No-op: the pipeline type is fixed to OCR."""
        pass

    def pipe_parse(self):
        """Run the OCR parser and keep the uncompressed result on the instance."""
        self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug)

    def pipe_mk_uni_format(self):
        """Produce the unified content list from the compressed parse data."""
        return AbsPipe.mk_uni_format(self.get_compress_pdf_mid_data(), self.img_parent_path)

    def pipe_mk_markdown(self):
        """Produce markdown from the compressed parse data."""
        return AbsPipe.mk_markdown(self.get_compress_pdf_mid_data(), self.img_parent_path)
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.libs.json_compressor import JsonCompressor
from magic_pdf.pipe.AbsPipe import AbsPipe
from magic_pdf.user_api import parse_txt_pdf
class TXTPipe(AbsPipe):
    """Pipeline that always takes the text-PDF parsing path."""

    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, img_parent_path: str, is_debug: bool = False):
        super().__init__(pdf_bytes, model_list, image_writer, img_parent_path, is_debug)

    def pipe_classify(self):
        """No-op: the pipeline type is fixed to txt."""
        pass

    def pipe_parse(self):
        """Run the text-PDF parser and keep the uncompressed result on the instance."""
        self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug)

    def pipe_mk_uni_format(self):
        """Produce the unified content list from the compressed parse data."""
        return AbsPipe.mk_uni_format(self.get_compress_pdf_mid_data(), self.img_parent_path)

    def pipe_mk_markdown(self):
        """Produce markdown from the compressed parse data."""
        return AbsPipe.mk_markdown(self.get_compress_pdf_mid_data(), self.img_parent_path)
import json
from loguru import logger
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
from magic_pdf.libs.commons import join_path
from magic_pdf.pipe.AbsPipe import AbsPipe
from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
class UNIPipe(AbsPipe):
    """Pipeline that first classifies the PDF, then dispatches to the txt or ocr parser."""

    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, img_parent_path: str,
                 is_debug: bool = False):
        # Default to OCR until pipe_classify() has run.
        self.pdf_type = self.PIP_OCR
        super().__init__(pdf_bytes, model_list, image_writer, img_parent_path, is_debug)

    def pipe_classify(self):
        """Decide between txt and ocr using the PDF's metadata."""
        self.pdf_type = UNIPipe.classify(self.pdf_bytes)

    def pipe_parse(self):
        """Parse with the parser matching the classified type; unknown types are a no-op."""
        parser = None
        if self.pdf_type == self.PIP_TXT:
            parser = parse_union_pdf
        elif self.pdf_type == self.PIP_OCR:
            parser = parse_ocr_pdf
        if parser is not None:
            self.pdf_mid_data = parser(self.pdf_bytes, self.model_list, self.image_writer,
                                       is_debug=self.is_debug)

    def pipe_mk_uni_format(self):
        """Produce the unified content list from the compressed parse data."""
        return AbsPipe.mk_uni_format(self.get_compress_pdf_mid_data(), self.img_parent_path)

    def pipe_mk_markdown(self):
        """Produce markdown from the compressed parse data."""
        return AbsPipe.mk_markdown(self.get_compress_pdf_mid_data(), self.img_parent_path)
if __name__ == '__main__':
    # Manual smoke test: parse a local PDF with its model output and dump
    # the markdown / intermediate JSON / content list to disk.
    drw = DiskReaderWriter(r"D:/project/20231108code-clean")
    pdf_file_path = r"linshixuqiu\19983-00.pdf"
    model_file_path = r"linshixuqiu\19983-00.json"
    pdf_bytes = drw.read(pdf_file_path, AbsReaderWriter.MODE_BIN)
    model_json_txt = drw.read(model_file_path, AbsReaderWriter.MODE_TXT)
    model_list = json.loads(model_json_txt)
    write_path = r"D:\project\20231108code-clean\linshixuqiu\19983-00"
    img_bucket_path = "imgs"
    img_writer = DiskReaderWriter(join_path(write_path, img_bucket_path))
    pipe = UNIPipe(pdf_bytes, model_list, img_writer, img_bucket_path)
    pipe.pipe_classify()
    pipe.pipe_parse()
    md_content = pipe.pipe_mk_markdown()
    # Pre-bind content_list so a failure in pipe_mk_uni_format() does not
    # leave the name unbound (the original code raised NameError below).
    content_list = None
    try:
        content_list = pipe.pipe_mk_uni_format()
    except Exception as e:
        logger.exception(e)
    md_writer = DiskReaderWriter(write_path)
    md_writer.write(md_content, "19983-00.md", AbsReaderWriter.MODE_TXT)
    md_writer.write(json.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4), "19983-00.json",
                    AbsReaderWriter.MODE_TXT)
    if content_list is not None:  # only dump the content list when it was produced
        md_writer.write(str(content_list), "19983-00.txt", AbsReaderWriter.MODE_TXT)
...@@ -32,8 +32,8 @@ def meta_scan(jso: dict, doc_layout_check=True) -> dict: ...@@ -32,8 +32,8 @@ def meta_scan(jso: dict, doc_layout_check=True) -> dict:
if ( if (
"doc_layout_result" not in jso "doc_layout_result" not in jso
): # 检测json中是存在模型数据,如果没有则需要跳过该pdf ): # 检测json中是存在模型数据,如果没有则需要跳过该pdf
jso["need_drop"] = True jso["_need_drop"] = True
jso["drop_reason"] = DropReason.MISS_DOC_LAYOUT_RESULT jso["_drop_reason"] = DropReason.MISS_DOC_LAYOUT_RESULT
return jso return jso
try: try:
data_source = get_data_source(jso) data_source = get_data_source(jso)
...@@ -58,10 +58,10 @@ def meta_scan(jso: dict, doc_layout_check=True) -> dict: ...@@ -58,10 +58,10 @@ def meta_scan(jso: dict, doc_layout_check=True) -> dict:
start_time = time.time() # 记录开始时间 start_time = time.time() # 记录开始时间
res = pdf_meta_scan(s3_pdf_path, file_content) res = pdf_meta_scan(s3_pdf_path, file_content)
if res.get( if res.get(
"need_drop", False "_need_drop", False
): # 如果返回的字典里有need_drop,则提取drop_reason并跳过本次解析 ): # 如果返回的字典里有need_drop,则提取drop_reason并跳过本次解析
jso["need_drop"] = True jso["_need_drop"] = True
jso["drop_reason"] = res["drop_reason"] jso["_drop_reason"] = res["_drop_reason"]
else: # 正常返回 else: # 正常返回
jso["pdf_meta"] = res jso["pdf_meta"] = res
jso["content"] = "" jso["content"] = ""
...@@ -85,7 +85,7 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict: ...@@ -85,7 +85,7 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict:
if debug_mode: if debug_mode:
pass pass
else: # 如果debug没开,则检测是否有needdrop字段 else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False): if jso.get("_need_drop", False):
return jso return jso
# 开始正式逻辑 # 开始正式逻辑
try: try:
...@@ -113,8 +113,8 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict: ...@@ -113,8 +113,8 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict:
if ( if (
is_encrypted or is_needs_password is_encrypted or is_needs_password
): # 加密的,需要密码的,没有页面的,都不处理 ): # 加密的,需要密码的,没有页面的,都不处理
jso["need_drop"] = True jso["_need_drop"] = True
jso["drop_reason"] = DropReason.ENCRYPTED jso["_drop_reason"] = DropReason.ENCRYPTED
else: else:
start_time = time.time() # 记录开始时间 start_time = time.time() # 记录开始时间
is_text_pdf, results = classify( is_text_pdf, results = classify(
...@@ -139,8 +139,8 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict: ...@@ -139,8 +139,8 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict:
if ( if (
text_language not in allow_language text_language not in allow_language
): # 如果语言不在允许的语言中,则drop ): # 如果语言不在允许的语言中,则drop
jso["need_drop"] = True jso["_need_drop"] = True
jso["drop_reason"] = DropReason.NOT_ALLOW_LANGUAGE jso["_drop_reason"] = DropReason.NOT_ALLOW_LANGUAGE
return jso return jso
else: else:
# 先不drop # 先不drop
...@@ -148,8 +148,8 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict: ...@@ -148,8 +148,8 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict:
jso["_pdf_type"] = "OCR" jso["_pdf_type"] = "OCR"
jso["pdf_meta"] = pdf_meta jso["pdf_meta"] = pdf_meta
jso["classify_time"] = classify_time jso["classify_time"] = classify_time
# jso["need_drop"] = True # jso["_need_drop"] = True
# jso["drop_reason"] = DropReason.NOT_IS_TEXT_PDF # jso["_drop_reason"] = DropReason.NOT_IS_TEXT_PDF
extra_info = {"classify_rules": []} extra_info = {"classify_rules": []}
for condition, result in results.items(): for condition, result in results.items():
if not result: if not result:
...@@ -162,7 +162,7 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict: ...@@ -162,7 +162,7 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict:
def drop_needdrop_pdf(jso: dict) -> dict: def drop_needdrop_pdf(jso: dict) -> dict:
if jso.get("need_drop", False): if jso.get("_need_drop", False):
logger.info( logger.info(
f"book_name is:{get_data_source(jso)}/{jso['file_id']} need drop", f"book_name is:{get_data_source(jso)}/{jso['file_id']} need drop",
file=sys.stderr, file=sys.stderr,
...@@ -176,7 +176,7 @@ def pdf_intermediate_dict_to_markdown(jso: dict, debug_mode=False) -> dict: ...@@ -176,7 +176,7 @@ def pdf_intermediate_dict_to_markdown(jso: dict, debug_mode=False) -> dict:
if debug_mode: if debug_mode:
pass pass
else: # 如果debug没开,则检测是否有needdrop字段 else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False): if jso.get("_need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"]) book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr) logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True jso["dropped"] = True
...@@ -203,7 +203,7 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict: ...@@ -203,7 +203,7 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
if debug_mode: if debug_mode:
pass pass
else: # 如果debug没开,则检测是否有needdrop字段 else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False): if jso.get("_need_drop", False):
return jso return jso
# 开始正式逻辑 # 开始正式逻辑
s3_pdf_path = jso.get("file_location") s3_pdf_path = jso.get("file_location")
...@@ -220,8 +220,8 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict: ...@@ -220,8 +220,8 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
svgs_per_page_list = jso["pdf_meta"]["svgs_per_page"] svgs_per_page_list = jso["pdf_meta"]["svgs_per_page"]
max_svgs = max(svgs_per_page_list) max_svgs = max(svgs_per_page_list)
if max_svgs > 3000: if max_svgs > 3000:
jso["need_drop"] = True jso["_need_drop"] = True
jso["drop_reason"] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_SVGS jso["_drop_reason"] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_SVGS
else: else:
try: try:
save_path = s3_image_save_path save_path = s3_image_save_path
...@@ -244,10 +244,10 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict: ...@@ -244,10 +244,10 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
debug_mode=debug_mode, debug_mode=debug_mode,
) )
if pdf_info_dict.get( if pdf_info_dict.get(
"need_drop", False "_need_drop", False
): # 如果返回的字典里有need_drop,则提取drop_reason并跳过本次解析 ): # 如果返回的字典里有need_drop,则提取drop_reason并跳过本次解析
jso["need_drop"] = True jso["_need_drop"] = True
jso["drop_reason"] = pdf_info_dict["drop_reason"] jso["_drop_reason"] = pdf_info_dict["_drop_reason"]
else: # 正常返回,将 pdf_info_dict 压缩并存储 else: # 正常返回,将 pdf_info_dict 压缩并存储
pdf_info_dict = JsonCompressor.compress_json(pdf_info_dict) pdf_info_dict = JsonCompressor.compress_json(pdf_info_dict)
jso["pdf_intermediate_dict"] = pdf_info_dict jso["pdf_intermediate_dict"] = pdf_info_dict
...@@ -269,7 +269,7 @@ def parse_pdf_for_model_train(jso: dict, start_page_id=0, debug_mode=False) -> d ...@@ -269,7 +269,7 @@ def parse_pdf_for_model_train(jso: dict, start_page_id=0, debug_mode=False) -> d
if debug_mode: if debug_mode:
pass pass
else: # 如果debug没开,则检测是否有needdrop字段 else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False): if jso.get("_need_drop", False):
return jso return jso
# 开始正式逻辑 # 开始正式逻辑
s3_pdf_path = jso.get("file_location") s3_pdf_path = jso.get("file_location")
...@@ -295,8 +295,8 @@ def parse_pdf_for_model_train(jso: dict, start_page_id=0, debug_mode=False) -> d ...@@ -295,8 +295,8 @@ def parse_pdf_for_model_train(jso: dict, start_page_id=0, debug_mode=False) -> d
svgs_per_page_list = jso["pdf_meta"]["svgs_per_page"] svgs_per_page_list = jso["pdf_meta"]["svgs_per_page"]
max_svgs = max(svgs_per_page_list) max_svgs = max(svgs_per_page_list)
if max_svgs > 3000: if max_svgs > 3000:
jso["need_drop"] = True jso["_need_drop"] = True
jso["drop_reason"] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_SVGS jso["_drop_reason"] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_SVGS
# elif total_page > 1000: # elif total_page > 1000:
# jso['need_drop'] = True # jso['need_drop'] = True
# jso['drop_reason'] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES # jso['drop_reason'] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES
...@@ -323,10 +323,10 @@ def parse_pdf_for_model_train(jso: dict, start_page_id=0, debug_mode=False) -> d ...@@ -323,10 +323,10 @@ def parse_pdf_for_model_train(jso: dict, start_page_id=0, debug_mode=False) -> d
debug_mode=debug_mode, debug_mode=debug_mode,
) )
if pdf_info_dict.get( if pdf_info_dict.get(
"need_drop", False "_need_drop", False
): # 如果返回的字典里有need_drop,则提取drop_reason并跳过本次解析 ): # 如果返回的字典里有need_drop,则提取drop_reason并跳过本次解析
jso["need_drop"] = True jso["_need_drop"] = True
jso["drop_reason"] = pdf_info_dict["drop_reason"] jso["_drop_reason"] = pdf_info_dict["_drop_reason"]
else: # 正常返回,将 pdf_info_dict 压缩并存储 else: # 正常返回,将 pdf_info_dict 压缩并存储
jso["parsed_results"] = convert_to_train_format(pdf_info_dict) jso["parsed_results"] = convert_to_train_format(pdf_info_dict)
pdf_info_dict = JsonCompressor.compress_json(pdf_info_dict) pdf_info_dict = JsonCompressor.compress_json(pdf_info_dict)
......
...@@ -17,7 +17,7 @@ def ocr_pdf_intermediate_dict_to_markdown(jso: dict, debug_mode=False) -> dict: ...@@ -17,7 +17,7 @@ def ocr_pdf_intermediate_dict_to_markdown(jso: dict, debug_mode=False) -> dict:
if debug_mode: if debug_mode:
pass pass
else: # 如果debug没开,则检测是否有needdrop字段 else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False): if jso.get("_need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"]) book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr) logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True jso["dropped"] = True
...@@ -45,7 +45,7 @@ def ocr_pdf_intermediate_dict_to_markdown_with_para(jso: dict, mode, debug_mode= ...@@ -45,7 +45,7 @@ def ocr_pdf_intermediate_dict_to_markdown_with_para(jso: dict, mode, debug_mode=
if debug_mode: if debug_mode:
pass pass
else: # 如果debug没开,则检测是否有needdrop字段 else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False): if jso.get("_need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"]) book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr) logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True jso["dropped"] = True
...@@ -78,7 +78,7 @@ def ocr_pdf_intermediate_dict_to_markdown_with_para_and_pagination(jso: dict, de ...@@ -78,7 +78,7 @@ def ocr_pdf_intermediate_dict_to_markdown_with_para_and_pagination(jso: dict, de
if debug_mode: if debug_mode:
pass pass
else: # 如果debug没开,则检测是否有needdrop字段 else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False): if jso.get("_need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"]) book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr) logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True jso["dropped"] = True
...@@ -108,7 +108,7 @@ def ocr_pdf_intermediate_dict_to_markdown_with_para_for_qa( ...@@ -108,7 +108,7 @@ def ocr_pdf_intermediate_dict_to_markdown_with_para_for_qa(
if debug_mode: if debug_mode:
pass pass
else: # 如果debug没开,则检测是否有needdrop字段 else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False): if jso.get("_need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"]) book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr) logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True jso["dropped"] = True
...@@ -137,7 +137,7 @@ def ocr_pdf_intermediate_dict_to_standard_format(jso: dict, debug_mode=False) -> ...@@ -137,7 +137,7 @@ def ocr_pdf_intermediate_dict_to_standard_format(jso: dict, debug_mode=False) ->
if debug_mode: if debug_mode:
pass pass
else: # 如果debug没开,则检测是否有needdrop字段 else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False): if jso.get("_need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"]) book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr) logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True jso["dropped"] = True
...@@ -165,7 +165,7 @@ def ocr_pdf_intermediate_dict_to_standard_format_with_para(jso: dict, debug_mode ...@@ -165,7 +165,7 @@ def ocr_pdf_intermediate_dict_to_standard_format_with_para(jso: dict, debug_mode
if debug_mode: if debug_mode:
pass pass
else: # 如果debug没开,则检测是否有needdrop字段 else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False): if jso.get("_need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"]) book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop", file=sys.stderr) logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
jso["dropped"] = True jso["dropped"] = True
...@@ -221,7 +221,7 @@ def ocr_parse_pdf_core(pdf_bytes, model_output_json_list, book_name, start_page_ ...@@ -221,7 +221,7 @@ def ocr_parse_pdf_core(pdf_bytes, model_output_json_list, book_name, start_page_
# 专门用来跑被drop的pdf,跑完之后需要把need_drop字段置为false # 专门用来跑被drop的pdf,跑完之后需要把need_drop字段置为false
def ocr_dropped_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict: def ocr_dropped_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
if not jso.get("need_drop", False): if not jso.get("_need_drop", False):
return jso return jso
else: else:
try: try:
...@@ -233,7 +233,7 @@ def ocr_dropped_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict: ...@@ -233,7 +233,7 @@ def ocr_dropped_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
) )
jso["pdf_intermediate_dict"] = JsonCompressor.compress_json(pdf_info_dict) jso["pdf_intermediate_dict"] = JsonCompressor.compress_json(pdf_info_dict)
jso["parse_time"] = parse_time jso["parse_time"] = parse_time
jso["need_drop"] = False jso["_need_drop"] = False
except Exception as e: except Exception as e:
jso = exception_handler(jso, e) jso = exception_handler(jso, e)
return jso return jso
...@@ -244,7 +244,7 @@ def ocr_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict: ...@@ -244,7 +244,7 @@ def ocr_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
if debug_mode: if debug_mode:
pass pass
else: # 如果debug没开,则检测是否有needdrop字段 else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False): if jso.get("_need_drop", False):
return jso return jso
try: try:
pdf_bytes = get_pdf_bytes(jso) pdf_bytes = get_pdf_bytes(jso)
......
...@@ -18,7 +18,7 @@ def txt_pdf_to_standard_format(jso: dict, debug_mode=False) -> dict: ...@@ -18,7 +18,7 @@ def txt_pdf_to_standard_format(jso: dict, debug_mode=False) -> dict:
if debug_mode: if debug_mode:
pass pass
else: # 如果debug没开,则检测是否有needdrop字段 else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False): if jso.get("_need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"]) book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop") logger.info(f"book_name is:{book_name} need drop")
jso["dropped"] = True jso["dropped"] = True
...@@ -46,7 +46,7 @@ def txt_pdf_to_mm_markdown_format(jso: dict, debug_mode=False) -> dict: ...@@ -46,7 +46,7 @@ def txt_pdf_to_mm_markdown_format(jso: dict, debug_mode=False) -> dict:
if debug_mode: if debug_mode:
pass pass
else: # 如果debug没开,则检测是否有needdrop字段 else: # 如果debug没开,则检测是否有needdrop字段
if jso.get("need_drop", False): if jso.get("_need_drop", False):
book_name = join_path(get_data_source(jso), jso["file_id"]) book_name = join_path(get_data_source(jso), jso["file_id"])
logger.info(f"book_name is:{book_name} need drop") logger.info(f"book_name is:{book_name} need drop")
jso["dropped"] = True jso["dropped"] = True
......
...@@ -62,6 +62,6 @@ def pdf_post_filter(page_info) -> tuple: ...@@ -62,6 +62,6 @@ def pdf_post_filter(page_info) -> tuple:
""" """
bool_is_pseudo_single_column, extra_info = __is_pseudo_single_column(page_info) bool_is_pseudo_single_column, extra_info = __is_pseudo_single_column(page_info)
if bool_is_pseudo_single_column: if bool_is_pseudo_single_column:
return False, {"need_drop": True, "drop_reason": DropReason.PSEUDO_SINGLE_COLUMN, "extra_info": extra_info} return False, {"_need_drop": True, "_drop_reason": DropReason.PSEUDO_SINGLE_COLUMN, "extra_info": extra_info}
return True, None return True, None
\ No newline at end of file
def construct_page_component(page_id, image_info, table_info, text_blocks_preproc, layout_bboxes, inline_eq_info,
def construct_page_component(page_id, image_info, table_info, text_blocks_preproc, layout_bboxes, inline_eq_info, interline_eq_info, raw_pymu_blocks, interline_eq_info, raw_pymu_blocks,
removed_text_blocks, removed_image_blocks, images_backup, droped_table_block, table_backup,layout_tree, removed_text_blocks, removed_image_blocks, images_backup, droped_table_block, table_backup,
layout_tree,
page_w, page_h, footnote_bboxes_tmp): page_w, page_h, footnote_bboxes_tmp):
""" """
...@@ -51,3 +52,19 @@ def ocr_construct_page_component(blocks, layout_bboxes, page_id, page_w, page_h, ...@@ -51,3 +52,19 @@ def ocr_construct_page_component(blocks, layout_bboxes, page_id, page_w, page_h,
'droped_bboxes': need_remove_spans_bboxes_dict, 'droped_bboxes': need_remove_spans_bboxes_dict,
} }
return return_dict return return_dict
def ocr_construct_page_component_v2(blocks, layout_bboxes, page_id, page_w, page_h, layout_tree,
                                    images, tables, interline_equations, discarded_blocks):
    """Assemble the v2 per-page result dict for the OCR pipeline.

    Pure packaging: every argument is stored under its corresponding key
    without transformation. The page size is recorded as [width, height].
    """
    return {
        'preproc_blocks': blocks,
        'layout_bboxes': layout_bboxes,
        'page_idx': page_id,
        'page_size': [page_w, page_h],
        '_layout_tree': layout_tree,
        'images': images,
        'tables': tables,
        'interline_equations': interline_equations,
        'discarded_blocks': discarded_blocks,
    }
from loguru import logger
from magic_pdf.libs.commons import join_path
from magic_pdf.libs.ocr_content_type import ContentType
from magic_pdf.libs.pdf_image_tools import cut_image
def ocr_cut_image_and_table(spans, page, page_id, pdf_bytes_md5, imageWriter):
    """Crop image/table spans out of *page* and attach each crop's path to its span.

    Spans of other types, and spans whose bbox fails check_img_bbox, are left
    untouched. The (mutated) span list is returned.
    """

    def _dest(subdir):
        # Crops are grouped under {pdf_bytes_md5}/{images|tables}.
        return join_path(pdf_bytes_md5, subdir)

    for span in spans:
        kind = span['type']
        if kind == ContentType.Image:
            subdir = 'images'
        elif kind == ContentType.Table:
            subdir = 'tables'
        else:
            continue
        if not check_img_bbox(span['bbox']):
            # Degenerate/inverted bbox — skip rather than crop garbage.
            continue
        span['image_path'] = cut_image(span['bbox'], page_id, page,
                                       return_path=_dest(subdir),
                                       imageWriter=imageWriter)
    return spans
def txt_save_images_by_bboxes(page_num: int, page, pdf_bytes_md5: str,
                              image_bboxes: list, images_overlap_backup: list, table_bboxes: list,
                              equation_inline_bboxes: list,
                              equation_interline_bboxes: list, imageWriter) -> tuple:
    """Cut image/table regions from *page* and record where each crop was saved.

    Returns a 5-tuple
    (image_info, image_backup_info, table_info, inline_eq_info, interline_eq_info)
    where the first three are lists of {"bbox": ..., "image_path": ...} dicts.

    NOTE(review): equation_inline_bboxes / equation_interline_bboxes are accepted
    but never used here, and the two equation result lists are always returned
    empty — confirm whether callers rely on that.

    Saved path layout:
    {s3_or_local_path}/{book_name}/{images|tables|equations}/{page_num}_{x0}_{y0}_{x1}_{y1}.jpg
    """

    def _dest(subdir):
        return join_path(pdf_bytes_md5, subdir)

    def _cut_all(bboxes, subdir):
        # Crop every geometrically valid bbox and pair it with its saved path.
        results = []
        for bbox in bboxes:
            if not check_img_bbox(bbox):
                continue
            image_path = cut_image(bbox, page_num, page, _dest(subdir), imageWriter)
            results.append({"bbox": bbox, "image_path": image_path})
        return results

    image_info = _cut_all(image_bboxes, "images")
    image_backup_info = _cut_all(images_overlap_backup, "images")
    table_info = _cut_all(table_bboxes, "tables")
    inline_eq_info = []
    interline_eq_info = []
    return image_info, image_backup_info, table_info, inline_eq_info, interline_eq_info
def check_img_bbox(bbox) -> bool:
    """Return True when *bbox* = (x0, y0, x1, y1) spans a positive area.

    A box is invalid when it is degenerate or inverted on either axis
    (x0 >= x1 or y0 >= y1); invalid boxes are logged and rejected.
    """
    # Plain boolean test instead of any([...]) — no throwaway list.
    if bbox[0] >= bbox[2] or bbox[1] >= bbox[3]:
        logger.warning(f"image_bboxes: 错误的box, {bbox}")
        return False
    return True
""" """
对pymupdf返回的结构里的公式进行替换,替换为模型识别的公式结果 对pymupdf返回的结构里的公式进行替换,替换为模型识别的公式结果
""" """
from magic_pdf.libs.commons import fitz from magic_pdf.libs.commons import fitz
import json import json
import os import os
...@@ -17,24 +18,24 @@ def combine_chars_to_pymudict(block_dict, char_dict): ...@@ -17,24 +18,24 @@ def combine_chars_to_pymudict(block_dict, char_dict):
把block级别的pymupdf 结构里加入char结构 把block级别的pymupdf 结构里加入char结构
""" """
# 因为block_dict 被裁剪过,因此先把他和char_dict文字块对齐,才能进行补充 # 因为block_dict 被裁剪过,因此先把他和char_dict文字块对齐,才能进行补充
char_map = {tuple(item['bbox']):item for item in char_dict} char_map = {tuple(item["bbox"]): item for item in char_dict}
for i in range(len(block_dict)): # blcok for i in range(len(block_dict)): # blcok
block = block_dict[i] block = block_dict[i]
key = block['bbox'] key = block["bbox"]
char_dict_item = char_map[tuple(key)] char_dict_item = char_map[tuple(key)]
char_dict_map = {tuple(item['bbox']):item for item in char_dict_item['lines']} char_dict_map = {tuple(item["bbox"]): item for item in char_dict_item["lines"]}
for j in range(len(block['lines'])): for j in range(len(block["lines"])):
lines = block['lines'][j] lines = block["lines"][j]
with_char_lines = char_dict_map[lines['bbox']] with_char_lines = char_dict_map[lines["bbox"]]
for k in range(len(lines['spans'])): for k in range(len(lines["spans"])):
spans = lines['spans'][k] spans = lines["spans"][k]
try: try:
chars = with_char_lines['spans'][k]['chars'] chars = with_char_lines["spans"][k]["chars"]
except Exception as e: except Exception as e:
logger.error(char_dict[i]['lines'][j]) logger.error(char_dict[i]["lines"][j])
spans['chars'] = chars spans["chars"] = chars
return block_dict return block_dict
...@@ -54,23 +55,22 @@ def calculate_overlap_area_2_minbox_area_ratio(bbox1, min_bbox): ...@@ -54,23 +55,22 @@ def calculate_overlap_area_2_minbox_area_ratio(bbox1, min_bbox):
# The area of overlap area # The area of overlap area
intersection_area = (x_right - x_left) * (y_bottom - y_top) intersection_area = (x_right - x_left) * (y_bottom - y_top)
min_box_area = (min_bbox[3]-min_bbox[1])*(min_bbox[2]-min_bbox[0]) min_box_area = (min_bbox[3] - min_bbox[1]) * (min_bbox[2] - min_bbox[0])
if min_box_area==0: if min_box_area == 0:
return 0 return 0
else: else:
return intersection_area / min_box_area return intersection_area / min_box_area
def _is_xin(bbox1, bbox2): def _is_xin(bbox1, bbox2):
area1 = abs(bbox1[2]-bbox1[0])*abs(bbox1[3]-bbox1[1]) area1 = abs(bbox1[2] - bbox1[0]) * abs(bbox1[3] - bbox1[1])
area2 = abs(bbox2[2]-bbox2[0])*abs(bbox2[3]-bbox2[1]) area2 = abs(bbox2[2] - bbox2[0]) * abs(bbox2[3] - bbox2[1])
if area1<area2: if area1 < area2:
ratio = calculate_overlap_area_2_minbox_area_ratio(bbox2, bbox1) ratio = calculate_overlap_area_2_minbox_area_ratio(bbox2, bbox1)
else: else:
ratio = calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2) ratio = calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2)
return ratio>0.6 return ratio > 0.6
def remove_text_block_in_interline_equation_bbox(interline_bboxes, text_blocks): def remove_text_block_in_interline_equation_bbox(interline_bboxes, text_blocks):
...@@ -78,8 +78,11 @@ def remove_text_block_in_interline_equation_bbox(interline_bboxes, text_blocks): ...@@ -78,8 +78,11 @@ def remove_text_block_in_interline_equation_bbox(interline_bboxes, text_blocks):
for eq_bbox in interline_bboxes: for eq_bbox in interline_bboxes:
removed_txt_blk = [] removed_txt_blk = []
for text_blk in text_blocks: for text_blk in text_blocks:
text_bbox = text_blk['bbox'] text_bbox = text_blk["bbox"]
if calculate_overlap_area_2_minbox_area_ratio(eq_bbox['bbox'], text_bbox)>=0.7: if (
calculate_overlap_area_2_minbox_area_ratio(eq_bbox["bbox"], text_bbox)
>= 0.7
):
removed_txt_blk.append(text_blk) removed_txt_blk.append(text_blk)
for blk in removed_txt_blk: for blk in removed_txt_blk:
text_blocks.remove(blk) text_blocks.remove(blk)
...@@ -87,7 +90,6 @@ def remove_text_block_in_interline_equation_bbox(interline_bboxes, text_blocks): ...@@ -87,7 +90,6 @@ def remove_text_block_in_interline_equation_bbox(interline_bboxes, text_blocks):
return text_blocks return text_blocks
def _is_in_or_part_overlap(box1, box2) -> bool: def _is_in_or_part_overlap(box1, box2) -> bool:
""" """
两个bbox是否有部分重叠或者包含 两个bbox是否有部分重叠或者包含
...@@ -98,54 +100,78 @@ def _is_in_or_part_overlap(box1, box2) -> bool: ...@@ -98,54 +100,78 @@ def _is_in_or_part_overlap(box1, box2) -> bool:
x0_1, y0_1, x1_1, y1_1 = box1 x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2 x0_2, y0_2, x1_2, y1_2 = box2
return not (x1_1 < x0_2 or # box1在box2的左边 return not (
x0_1 > x1_2 or # box1在box2的右边 x1_1 < x0_2 # box1在box2的左边
y1_1 < y0_2 or # box1在box2的上边 or x0_1 > x1_2 # box1在box2的右边
y0_1 > y1_2) # box1在box2的下边 or y1_1 < y0_2 # box1在box2的上边
or y0_1 > y1_2
) # box1在box2的下边
def remove_text_block_overlap_interline_equation_bbox(interline_eq_bboxes, pymu_block_list): def remove_text_block_overlap_interline_equation_bbox(
interline_eq_bboxes, pymu_block_list
):
"""消除掉行行内公式有部分重叠的文本块的内容。 """消除掉行行内公式有部分重叠的文本块的内容。
同时重新计算消除重叠之后文本块的大小""" 同时重新计算消除重叠之后文本块的大小"""
deleted_block = [] deleted_block = []
for text_block in pymu_block_list: for text_block in pymu_block_list:
deleted_line = [] deleted_line = []
for line in text_block['lines']: for line in text_block["lines"]:
deleted_span = [] deleted_span = []
for span in line['spans']: for span in line["spans"]:
deleted_chars = [] deleted_chars = []
for char in span['chars']: for char in span["chars"]:
if any([_is_in_or_part_overlap(char['bbox'], eq_bbox['bbox']) for eq_bbox in interline_eq_bboxes]): if any(
[
_is_in_or_part_overlap(char["bbox"], eq_bbox["bbox"])
for eq_bbox in interline_eq_bboxes
]
):
deleted_chars.append(char) deleted_chars.append(char)
# 检查span里没有char则删除这个span # 检查span里没有char则删除这个span
for char in deleted_chars: for char in deleted_chars:
span['chars'].remove(char) span["chars"].remove(char)
# 重新计算这个span的大小 # 重新计算这个span的大小
if len(span['chars'])==0: # 删除这个span if len(span["chars"]) == 0: # 删除这个span
deleted_span.append(span) deleted_span.append(span)
else: else:
span['bbox'] = min([b['bbox'][0] for b in span['chars']]),min([b['bbox'][1] for b in span['chars']]),max([b['bbox'][2] for b in span['chars']]), max([b['bbox'][3] for b in span['chars']]) span["bbox"] = (
min([b["bbox"][0] for b in span["chars"]]),
min([b["bbox"][1] for b in span["chars"]]),
max([b["bbox"][2] for b in span["chars"]]),
max([b["bbox"][3] for b in span["chars"]]),
)
# 检查这个span # 检查这个span
for span in deleted_span: for span in deleted_span:
line['spans'].remove(span) line["spans"].remove(span)
if len(line['spans'])==0: #删除这个line if len(line["spans"]) == 0: # 删除这个line
deleted_line.append(line) deleted_line.append(line)
else: else:
line['bbox'] = min([b['bbox'][0] for b in line['spans']]),min([b['bbox'][1] for b in line['spans']]),max([b['bbox'][2] for b in line['spans']]), max([b['bbox'][3] for b in line['spans']]) line["bbox"] = (
min([b["bbox"][0] for b in line["spans"]]),
min([b["bbox"][1] for b in line["spans"]]),
max([b["bbox"][2] for b in line["spans"]]),
max([b["bbox"][3] for b in line["spans"]]),
)
# 检查这个block是否可以删除 # 检查这个block是否可以删除
for line in deleted_line: for line in deleted_line:
text_block['lines'].remove(line) text_block["lines"].remove(line)
if len(text_block['lines'])==0: # 删除block if len(text_block["lines"]) == 0: # 删除block
deleted_block.append(text_block) deleted_block.append(text_block)
else: else:
text_block['bbox'] = min([b['bbox'][0] for b in text_block['lines']]),min([b['bbox'][1] for b in text_block['lines']]),max([b['bbox'][2] for b in text_block['lines']]), max([b['bbox'][3] for b in text_block['lines']]) text_block["bbox"] = (
min([b["bbox"][0] for b in text_block["lines"]]),
min([b["bbox"][1] for b in text_block["lines"]]),
max([b["bbox"][2] for b in text_block["lines"]]),
max([b["bbox"][3] for b in text_block["lines"]]),
)
# 检查text block删除 # 检查text block删除
for block in deleted_block: for block in deleted_block:
pymu_block_list.remove(block) pymu_block_list.remove(block)
if len(pymu_block_list)==0: if len(pymu_block_list) == 0:
return [] return []
return pymu_block_list return pymu_block_list
...@@ -154,8 +180,8 @@ def remove_text_block_overlap_interline_equation_bbox(interline_eq_bboxes, pymu_ ...@@ -154,8 +180,8 @@ def remove_text_block_overlap_interline_equation_bbox(interline_eq_bboxes, pymu_
def insert_interline_equations_textblock(interline_eq_bboxes, pymu_block_list): def insert_interline_equations_textblock(interline_eq_bboxes, pymu_block_list):
"""在行间公式对应的地方插上一个伪造的block""" """在行间公式对应的地方插上一个伪造的block"""
for eq in interline_eq_bboxes: for eq in interline_eq_bboxes:
bbox = eq['bbox'] bbox = eq["bbox"]
latex_content = eq['latex_text'] latex_content = eq["latex"]
text_block = { text_block = {
"number": len(pymu_block_list), "number": len(pymu_block_list),
"type": 0, "type": 0,
...@@ -172,24 +198,19 @@ def insert_interline_equations_textblock(interline_eq_bboxes, pymu_block_list): ...@@ -172,24 +198,19 @@ def insert_interline_equations_textblock(interline_eq_bboxes, pymu_block_list):
"ascender": 0.9409999847412109, "ascender": 0.9409999847412109,
"descender": -0.3050000071525574, "descender": -0.3050000071525574,
"text": f"\n$$\n{latex_content}\n$$\n", "text": f"\n$$\n{latex_content}\n$$\n",
"origin": [ "origin": [bbox[0], bbox[1]],
bbox[0], "bbox": bbox,
bbox[1]
],
"bbox": bbox
} }
], ],
"wmode": 0, "wmode": 0,
"dir": [ "dir": [1.0, 0.0],
1.0, "bbox": bbox,
0.0
],
"bbox": bbox
} }
] ],
} }
pymu_block_list.append(text_block) pymu_block_list.append(text_block)
def x_overlap_ratio(box1, box2): def x_overlap_ratio(box1, box2):
a, _, c, _ = box1 a, _, c, _ = box1
e, _, g, _ = box2 e, _, g, _ = box2
...@@ -205,8 +226,10 @@ def x_overlap_ratio(box1, box2): ...@@ -205,8 +226,10 @@ def x_overlap_ratio(box1, box2):
return overlap_ratio return overlap_ratio
def __is_x_dir_overlap(bbox1, bbox2): def __is_x_dir_overlap(bbox1, bbox2):
return not (bbox1[2]<bbox2[0] or bbox1[0]>bbox2[2]) return not (bbox1[2] < bbox2[0] or bbox1[0] > bbox2[2])
def __y_overlap_ratio(box1, box2): def __y_overlap_ratio(box1, box2):
"""""" """"""
...@@ -224,6 +247,7 @@ def __y_overlap_ratio(box1, box2): ...@@ -224,6 +247,7 @@ def __y_overlap_ratio(box1, box2):
return overlap_ratio return overlap_ratio
def replace_line_v2(eqinfo, line): def replace_line_v2(eqinfo, line):
""" """
扫描这一行所有的和公式框X方向重叠的char,然后计算char的左、右x0, x1,位于这个区间内的span删除掉。 扫描这一行所有的和公式框X方向重叠的char,然后计算char的左、右x0, x1,位于这个区间内的span删除掉。
...@@ -233,54 +257,55 @@ def replace_line_v2(eqinfo, line): ...@@ -233,54 +257,55 @@ def replace_line_v2(eqinfo, line):
first_overlap_span_idx = -1 first_overlap_span_idx = -1
last_overlap_span = -1 last_overlap_span = -1
delete_chars = [] delete_chars = []
for i in range(0, len(line['spans'])): for i in range(0, len(line["spans"])):
if line['spans'][i].get("_type", None) is not None: if line["spans"][i].get("_type", None) is not None:
continue # 忽略,因为已经是插入的伪造span公式了 continue # 忽略,因为已经是插入的伪造span公式了
for char in line['spans'][i]['chars']: for char in line["spans"][i]["chars"]:
if __is_x_dir_overlap(eqinfo['bbox'], char['bbox']): if __is_x_dir_overlap(eqinfo["bbox"], char["bbox"]):
line_txt = "" line_txt = ""
for span in line['spans']: for span in line["spans"]:
span_txt = "<span>" span_txt = "<span>"
for ch in span['chars']: for ch in span["chars"]:
span_txt = span_txt + ch['c'] span_txt = span_txt + ch["c"]
span_txt = span_txt + "</span>" span_txt = span_txt + "</span>"
line_txt = line_txt + span_txt line_txt = line_txt + span_txt
if first_overlap_span_idx == -1: if first_overlap_span_idx == -1:
first_overlap_span = line['spans'][i] first_overlap_span = line["spans"][i]
first_overlap_span_idx = i first_overlap_span_idx = i
last_overlap_span = line['spans'][i] last_overlap_span = line["spans"][i]
delete_chars.append(char) delete_chars.append(char)
# 第一个和最后一个char要进行检查,到底属于公式多还是属于正常span多 # 第一个和最后一个char要进行检查,到底属于公式多还是属于正常span多
if len(delete_chars)>0: if len(delete_chars) > 0:
ch0_bbox = delete_chars[0]['bbox'] ch0_bbox = delete_chars[0]["bbox"]
if x_overlap_ratio(eqinfo['bbox'], ch0_bbox)<0.51: if x_overlap_ratio(eqinfo["bbox"], ch0_bbox) < 0.51:
delete_chars.remove(delete_chars[0]) delete_chars.remove(delete_chars[0])
if len(delete_chars)>0: if len(delete_chars) > 0:
ch0_bbox = delete_chars[-1]['bbox'] ch0_bbox = delete_chars[-1]["bbox"]
if x_overlap_ratio(eqinfo['bbox'], ch0_bbox)<0.51: if x_overlap_ratio(eqinfo["bbox"], ch0_bbox) < 0.51:
delete_chars.remove(delete_chars[-1]) delete_chars.remove(delete_chars[-1])
# 计算x方向上被删除区间内的char的真实x0, x1 # 计算x方向上被删除区间内的char的真实x0, x1
if len(delete_chars): if len(delete_chars):
x0, x1 = min([b['bbox'][0] for b in delete_chars]), max([b['bbox'][2] for b in delete_chars]) x0, x1 = min([b["bbox"][0] for b in delete_chars]), max(
[b["bbox"][2] for b in delete_chars]
)
else: else:
logger.debug(f"行内公式替换没有发生,尝试下一行匹配, eqinfo={eqinfo}") logger.debug(f"行内公式替换没有发生,尝试下一行匹配, eqinfo={eqinfo}")
return False return False
# 删除位于x0, x1这两个中间的span # 删除位于x0, x1这两个中间的span
delete_span = [] delete_span = []
for span in line['spans']: for span in line["spans"]:
span_box = span['bbox'] span_box = span["bbox"]
if x0<=span_box[0] and span_box[2]<=x1: if x0 <= span_box[0] and span_box[2] <= x1:
delete_span.append(span) delete_span.append(span)
for span in delete_span: for span in delete_span:
line['spans'].remove(span) line["spans"].remove(span)
equation_span = { equation_span = {
"size": 9.962599754333496, "size": 9.962599754333496,
...@@ -291,67 +316,91 @@ def replace_line_v2(eqinfo, line): ...@@ -291,67 +316,91 @@ def replace_line_v2(eqinfo, line):
"ascender": 0.9409999847412109, "ascender": 0.9409999847412109,
"descender": -0.3050000071525574, "descender": -0.3050000071525574,
"text": "", "text": "",
"origin": [ "origin": [337.1410153102337, 216.0205245153934],
337.1410153102337,
216.0205245153934
],
"bbox": [ "bbox": [
337.1410153102337, 337.1410153102337,
216.0205245153934, 216.0205245153934,
390.4496373892022, 390.4496373892022,
228.50171037628277 228.50171037628277,
] ],
} }
#equation_span = line['spans'][0].copy() # equation_span = line['spans'][0].copy()
equation_span['text'] = f" ${eqinfo['latex_text']}$ " equation_span["text"] = f" ${eqinfo['latex']}$ "
equation_span['bbox'] = [x0, equation_span['bbox'][1], x1, equation_span['bbox'][3]] equation_span["bbox"] = [x0, equation_span["bbox"][1], x1, equation_span["bbox"][3]]
equation_span['origin'] = [equation_span['bbox'][0], equation_span['bbox'][1]] equation_span["origin"] = [equation_span["bbox"][0], equation_span["bbox"][1]]
equation_span['chars'] = delete_chars equation_span["chars"] = delete_chars
equation_span['_type'] = TYPE_INLINE_EQUATION equation_span["_type"] = TYPE_INLINE_EQUATION
equation_span['_eq_bbox'] = eqinfo['bbox'] equation_span["_eq_bbox"] = eqinfo["bbox"]
line['spans'].insert(first_overlap_span_idx+1, equation_span) # 放入公式 line["spans"].insert(first_overlap_span_idx + 1, equation_span) # 放入公式
# logger.info(f"==>text is 【{line_txt}】, equation is 【{eqinfo['latex_text']}】") # logger.info(f"==>text is 【{line_txt}】, equation is 【{eqinfo['latex_text']}】")
# 第一个、和最后一个有overlap的span进行分割,然后插入对应的位置 # 第一个、和最后一个有overlap的span进行分割,然后插入对应的位置
first_span_chars = [char for char in first_overlap_span['chars'] if (char['bbox'][2]+char['bbox'][0])/2<x0] first_span_chars = [
tail_span_chars = [char for char in last_overlap_span['chars'] if (char['bbox'][0]+char['bbox'][2])/2>x1] char
for char in first_overlap_span["chars"]
if (char["bbox"][2] + char["bbox"][0]) / 2 < x0
]
tail_span_chars = [
char
for char in last_overlap_span["chars"]
if (char["bbox"][0] + char["bbox"][2]) / 2 > x1
]
if len(first_span_chars)>0: if len(first_span_chars) > 0:
first_overlap_span['chars'] = first_span_chars first_overlap_span["chars"] = first_span_chars
first_overlap_span['text'] = ''.join([char['c'] for char in first_span_chars]) first_overlap_span["text"] = "".join([char["c"] for char in first_span_chars])
first_overlap_span['bbox'] = (first_overlap_span['bbox'][0], first_overlap_span['bbox'][1], max([chr['bbox'][2] for chr in first_span_chars]), first_overlap_span['bbox'][3]) first_overlap_span["bbox"] = (
first_overlap_span["bbox"][0],
first_overlap_span["bbox"][1],
max([chr["bbox"][2] for chr in first_span_chars]),
first_overlap_span["bbox"][3],
)
# first_overlap_span['_type'] = "first" # first_overlap_span['_type'] = "first"
else: else:
# 删掉 # 删掉
if first_overlap_span not in delete_span: if first_overlap_span not in delete_span:
line['spans'].remove(first_overlap_span) line["spans"].remove(first_overlap_span)
if len(tail_span_chars) > 0:
if len(tail_span_chars)>0: if last_overlap_span == first_overlap_span: # 这个时候应该插入一个新的
if last_overlap_span==first_overlap_span: # 这个时候应该插入一个新的 tail_span_txt = "".join([char["c"] for char in tail_span_chars])
tail_span_txt = ''.join([char['c'] for char in tail_span_chars])
last_span_to_insert = last_overlap_span.copy() last_span_to_insert = last_overlap_span.copy()
last_span_to_insert['chars'] = tail_span_chars last_span_to_insert["chars"] = tail_span_chars
last_span_to_insert['text'] = ''.join([char['c'] for char in tail_span_chars]) last_span_to_insert["text"] = "".join(
last_span_to_insert['bbox'] = (min([chr['bbox'][0] for chr in tail_span_chars]), last_overlap_span['bbox'][1], last_overlap_span['bbox'][2], last_overlap_span['bbox'][3]) [char["c"] for char in tail_span_chars]
)
last_span_to_insert["bbox"] = (
min([chr["bbox"][0] for chr in tail_span_chars]),
last_overlap_span["bbox"][1],
last_overlap_span["bbox"][2],
last_overlap_span["bbox"][3],
)
# 插入到公式对象之后 # 插入到公式对象之后
equation_idx = line['spans'].index(equation_span) equation_idx = line["spans"].index(equation_span)
line['spans'].insert(equation_idx+1, last_span_to_insert) # 放入公式 line["spans"].insert(equation_idx + 1, last_span_to_insert) # 放入公式
else: # 直接修改原来的span else: # 直接修改原来的span
last_overlap_span['chars'] = tail_span_chars last_overlap_span["chars"] = tail_span_chars
last_overlap_span['text'] = ''.join([char['c'] for char in tail_span_chars]) last_overlap_span["text"] = "".join([char["c"] for char in tail_span_chars])
last_overlap_span['bbox'] = (min([chr['bbox'][0] for chr in tail_span_chars]), last_overlap_span['bbox'][1], last_overlap_span['bbox'][2], last_overlap_span['bbox'][3]) last_overlap_span["bbox"] = (
min([chr["bbox"][0] for chr in tail_span_chars]),
last_overlap_span["bbox"][1],
last_overlap_span["bbox"][2],
last_overlap_span["bbox"][3],
)
else: else:
# 删掉 # 删掉
if last_overlap_span not in delete_span and last_overlap_span!=first_overlap_span: if (
line['spans'].remove(last_overlap_span) last_overlap_span not in delete_span
and last_overlap_span != first_overlap_span
):
line["spans"].remove(last_overlap_span)
remain_txt = "" remain_txt = ""
for span in line['spans']: for span in line["spans"]:
span_txt = "<span>" span_txt = "<span>"
for char in span['chars']: for char in span["chars"]:
span_txt = span_txt + char['c'] span_txt = span_txt + char["c"]
span_txt = span_txt + "</span>" span_txt = span_txt + "</span>"
...@@ -364,11 +413,16 @@ def replace_line_v2(eqinfo, line): ...@@ -364,11 +413,16 @@ def replace_line_v2(eqinfo, line):
def replace_eq_blk(eqinfo, text_block): def replace_eq_blk(eqinfo, text_block):
"""替换行内公式""" """替换行内公式"""
for line in text_block['lines']: for line in text_block["lines"]:
line_bbox = line['bbox'] line_bbox = line["bbox"]
if _is_xin(eqinfo['bbox'], line_bbox) or __y_overlap_ratio(eqinfo['bbox'], line_bbox)>0.6: # 定位到行, 使用y方向重合率是因为有的时候,一个行的宽度会小于公式位置宽度:行很高,公式很窄, if (
_is_xin(eqinfo["bbox"], line_bbox)
or __y_overlap_ratio(eqinfo["bbox"], line_bbox) > 0.6
): # 定位到行, 使用y方向重合率是因为有的时候,一个行的宽度会小于公式位置宽度:行很高,公式很窄,
replace_succ = replace_line_v2(eqinfo, line) replace_succ = replace_line_v2(eqinfo, line)
if not replace_succ: # 有的时候,一个pdf的line高度从API里会计算的有问题,因此在行内span级别会替换不成功,这就需要继续重试下一行 if (
not replace_succ
): # 有的时候,一个pdf的line高度从API里会计算的有问题,因此在行内span级别会替换不成功,这就需要继续重试下一行
continue continue
else: else:
break break
...@@ -380,9 +434,9 @@ def replace_eq_blk(eqinfo, text_block): ...@@ -380,9 +434,9 @@ def replace_eq_blk(eqinfo, text_block):
def replace_inline_equations(inline_equation_bboxes, raw_text_blocks): def replace_inline_equations(inline_equation_bboxes, raw_text_blocks):
"""替换行内公式""" """替换行内公式"""
for eqinfo in inline_equation_bboxes: for eqinfo in inline_equation_bboxes:
eqbox = eqinfo['bbox'] eqbox = eqinfo["bbox"]
for blk in raw_text_blocks: for blk in raw_text_blocks:
if _is_xin(eqbox, blk['bbox']): if _is_xin(eqbox, blk["bbox"]):
if not replace_eq_blk(eqinfo, blk): if not replace_eq_blk(eqinfo, blk):
logger.error(f"行内公式没有替换成功:{eqinfo} ") logger.error(f"行内公式没有替换成功:{eqinfo} ")
else: else:
...@@ -390,22 +444,29 @@ def replace_inline_equations(inline_equation_bboxes, raw_text_blocks): ...@@ -390,22 +444,29 @@ def replace_inline_equations(inline_equation_bboxes, raw_text_blocks):
return raw_text_blocks return raw_text_blocks
def remove_chars_in_text_blocks(text_blocks): def remove_chars_in_text_blocks(text_blocks):
"""删除text_blocks里的char""" """删除text_blocks里的char"""
for blk in text_blocks: for blk in text_blocks:
for line in blk['lines']: for line in blk["lines"]:
for span in line['spans']: for span in line["spans"]:
_ = span.pop("chars", "no such key") _ = span.pop("chars", "no such key")
return text_blocks return text_blocks
def replace_equations_in_textblock(raw_text_blocks, inline_equation_bboxes, interline_equation_bboxes): def replace_equations_in_textblock(
raw_text_blocks, inline_equation_bboxes, interline_equation_bboxes
):
""" """
替换行间和和行内公式为latex 替换行间和和行内公式为latex
""" """
raw_text_blocks = remove_text_block_in_interline_equation_bbox(interline_equation_bboxes, raw_text_blocks) # 消除重叠:第一步,在公式内部的 raw_text_blocks = remove_text_block_in_interline_equation_bbox(
raw_text_blocks = remove_text_block_overlap_interline_equation_bbox(interline_equation_bboxes, raw_text_blocks) # 消重,第二步,和公式覆盖的 interline_equation_bboxes, raw_text_blocks
) # 消除重叠:第一步,在公式内部的
raw_text_blocks = remove_text_block_overlap_interline_equation_bbox(
interline_equation_bboxes, raw_text_blocks
) # 消重,第二步,和公式覆盖的
insert_interline_equations_textblock(interline_equation_bboxes, raw_text_blocks) insert_interline_equations_textblock(interline_equation_bboxes, raw_text_blocks)
raw_text_blocks = replace_inline_equations(inline_equation_bboxes, raw_text_blocks) raw_text_blocks = replace_inline_equations(inline_equation_bboxes, raw_text_blocks)
...@@ -414,34 +475,38 @@ def replace_equations_in_textblock(raw_text_blocks, inline_equation_bboxes, inte ...@@ -414,34 +475,38 @@ def replace_equations_in_textblock(raw_text_blocks, inline_equation_bboxes, inte
def draw_block_on_pdf_with_txt_replace_eq_bbox(json_path, pdf_path): def draw_block_on_pdf_with_txt_replace_eq_bbox(json_path, pdf_path):
""" """ """
"""
new_pdf = f"{Path(pdf_path).parent}/{Path(pdf_path).stem}.step3-消除行内公式text_block.pdf" new_pdf = f"{Path(pdf_path).parent}/{Path(pdf_path).stem}.step3-消除行内公式text_block.pdf"
with open(json_path, "r", encoding='utf-8') as f: with open(json_path, "r", encoding="utf-8") as f:
obj = json.loads(f.read()) obj = json.loads(f.read())
if os.path.exists(new_pdf): if os.path.exists(new_pdf):
os.remove(new_pdf) os.remove(new_pdf)
new_doc = fitz.open('') new_doc = fitz.open("")
doc = fitz.open(pdf_path) doc = fitz.open(pdf_path)
new_doc = fitz.open(pdf_path) new_doc = fitz.open(pdf_path)
for i in range(len(new_doc)): for i in range(len(new_doc)):
page = new_doc[i] page = new_doc[i]
inline_equation_bboxes = obj[f"page_{i}"]['inline_equations'] inline_equation_bboxes = obj[f"page_{i}"]["inline_equations"]
interline_equation_bboxes = obj[f"page_{i}"]['interline_equations'] interline_equation_bboxes = obj[f"page_{i}"]["interline_equations"]
raw_text_blocks = obj[f'page_{i}']['preproc_blocks'] raw_text_blocks = obj[f"page_{i}"]["preproc_blocks"]
raw_text_blocks = remove_text_block_in_interline_equation_bbox(interline_equation_bboxes, raw_text_blocks) # 消除重叠:第一步,在公式内部的 raw_text_blocks = remove_text_block_in_interline_equation_bbox(
raw_text_blocks = remove_text_block_overlap_interline_equation_bbox(interline_equation_bboxes, raw_text_blocks) # 消重,第二步,和公式覆盖的 interline_equation_bboxes, raw_text_blocks
) # 消除重叠:第一步,在公式内部的
raw_text_blocks = remove_text_block_overlap_interline_equation_bbox(
interline_equation_bboxes, raw_text_blocks
) # 消重,第二步,和公式覆盖的
insert_interline_equations_textblock(interline_equation_bboxes, raw_text_blocks) insert_interline_equations_textblock(interline_equation_bboxes, raw_text_blocks)
raw_text_blocks = replace_inline_equations(inline_equation_bboxes, raw_text_blocks) raw_text_blocks = replace_inline_equations(
inline_equation_bboxes, raw_text_blocks
)
# 为了检验公式是否重复,把每一行里,含有公式的span背景改成黄色的 # 为了检验公式是否重复,把每一行里,含有公式的span背景改成黄色的
color_map = [fitz.pdfcolor['blue'],fitz.pdfcolor['green']] color_map = [fitz.pdfcolor["blue"], fitz.pdfcolor["green"]]
j = 0 j = 0
for blk in raw_text_blocks: for blk in raw_text_blocks:
for i,line in enumerate(blk['lines']): for i, line in enumerate(blk["lines"]):
# line_box = line['bbox'] # line_box = line['bbox']
# shape = page.new_shape() # shape = page.new_shape()
...@@ -450,20 +515,20 @@ def draw_block_on_pdf_with_txt_replace_eq_bbox(json_path, pdf_path): ...@@ -450,20 +515,20 @@ def draw_block_on_pdf_with_txt_replace_eq_bbox(json_path, pdf_path):
# shape.commit() # shape.commit()
# j = j+1 # j = j+1
for i, span in enumerate(line['spans']): for i, span in enumerate(line["spans"]):
shape_page = page.new_shape() shape_page = page.new_shape()
span_type = span.get('_type') span_type = span.get("_type")
color = fitz.pdfcolor['blue'] color = fitz.pdfcolor["blue"]
if span_type=='first': if span_type == "first":
color = fitz.pdfcolor['blue'] color = fitz.pdfcolor["blue"]
elif span_type=='tail': elif span_type == "tail":
color = fitz.pdfcolor['green'] color = fitz.pdfcolor["green"]
elif span_type==TYPE_INLINE_EQUATION: elif span_type == TYPE_INLINE_EQUATION:
color = fitz.pdfcolor['black'] color = fitz.pdfcolor["black"]
else: else:
color = None color = None
b = span['bbox'] b = span["bbox"]
shape_page.draw_rect(b) shape_page.draw_rect(b)
shape_page.finish(color=None, fill=color, fill_opacity=0.3) shape_page.finish(color=None, fill=color, fill_opacity=0.3)
...@@ -471,13 +536,13 @@ def draw_block_on_pdf_with_txt_replace_eq_bbox(json_path, pdf_path): ...@@ -471,13 +536,13 @@ def draw_block_on_pdf_with_txt_replace_eq_bbox(json_path, pdf_path):
new_doc.save(new_pdf) new_doc.save(new_pdf)
logger.info(f"save ok {new_pdf}") logger.info(f"save ok {new_pdf}")
final_json = json.dumps(obj, ensure_ascii=False,indent=2) final_json = json.dumps(obj, ensure_ascii=False, indent=2)
with open("equations_test/final_json.json", "w") as f: with open("equations_test/final_json.json", "w") as f:
f.write(final_json) f.write(final_json)
return new_pdf return new_pdf
if __name__=="__main__": if __name__ == "__main__":
# draw_block_on_pdf_with_txt_replace_eq_bbox(new_json_path, equation_color_pdf) # draw_block_on_pdf_with_txt_replace_eq_bbox(new_json_path, equation_color_pdf)
pass pass
from magic_pdf.libs.commons import join_path
from magic_pdf.libs.ocr_content_type import ContentType
from magic_pdf.libs.pdf_image_tools import cut_image
def cut_image_and_table(spans, page, page_id, pdf_bytes_md5, imageWriter):
    """Crop each image/table span from *page* and store the crop path on the span.

    Spans of any other type are passed through untouched; the (mutated) span
    list is returned.
    """

    def _dest(subdir):
        # Crops are grouped under {pdf_bytes_md5}/{images|tables}.
        return join_path(pdf_bytes_md5, subdir)

    for span in spans:
        kind = span['type']
        if kind == ContentType.Image:
            subdir = 'images'
        elif kind == ContentType.Table:
            subdir = 'tables'
        else:
            continue
        span['image_path'] = cut_image(span['bbox'], page_id, page,
                                       return_path=_dest(subdir),
                                       imageWriter=imageWriter)
    return spans
from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio, calculate_overlap_area_in_bbox1_area_ratio, \
calculate_iou
from magic_pdf.libs.drop_tag import DropTag
from magic_pdf.libs.ocr_content_type import BlockType
def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_blocks, text_blocks,
                                        title_blocks, interline_equation_blocks, page_w, page_h):
    """Flatten the typed block lists into one candidate bbox list for layout split.

    Each entry is a 12-slot row: [x0, y0, x1, y1, ...padding..., block_type at
    index 7, ...padding...].  After flattening, nesting conflicts are resolved
    and footnote-like discarded boxes are appended.
    """
    all_bboxes = []

    def _append_all(blocks, block_type):
        # Emit one row per block, tagging slot 7 with its type.
        for blk in blocks:
            x0, y0, x1, y1 = blk['bbox']
            all_bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None])

    _append_all(img_blocks, BlockType.Image)
    _append_all(table_blocks, BlockType.Table)
    _append_all(text_blocks, BlockType.Text)
    _append_all(title_blocks, BlockType.Title)
    _append_all(interline_equation_blocks, BlockType.InterlineEquation)

    # Resolve block-nesting conflicts:
    # 1) when text and title boxes overlap, trust the text box;
    all_bboxes = fix_text_overlap_title_blocks(all_bboxes)
    # 2) anything overlapping a discarded box yields to the discarded box;
    all_bboxes = remove_need_drop_blocks(all_bboxes, discarded_blocks)
    # 3) if a large box still contains a small one, drop the small one.
    all_bboxes = remove_overlaps_min_blocks(all_bboxes)

    # Keep only footnote-like discarded boxes: wider than 1/3 of the page,
    # taller than 10, and located in the lower half of the page.
    for discarded in discarded_blocks:
        x0, y0, x1, y1 = discarded['bbox']
        if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
            all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Footnote, None, None, None, None])
    return all_bboxes
def fix_text_overlap_title_blocks(all_bboxes):
    """Drop title blocks that heavily overlap a text block.

    When a title bbox and a text bbox have IoU > 0.8 the text block wins and
    the title block is removed from *all_bboxes* (in place).  Returns the
    mutated list.

    Bug fix: the original called ``all_bboxes.remove(title_block)``
    unconditionally, so a title overlapping two or more text blocks was
    removed twice, raising ``ValueError: list.remove(x): x not in list``.
    We now guard the removal with a membership check.
    """
    # Collect the text and title rows (block type lives at index 7).
    text_blocks = [block for block in all_bboxes if block[7] == BlockType.Text]
    title_blocks = [block for block in all_bboxes if block[7] == BlockType.Title]

    for text_block in text_blocks:
        for title_block in title_blocks:
            text_block_bbox = text_block[0], text_block[1], text_block[2], text_block[3]
            title_block_bbox = title_block[0], title_block[1], title_block[2], title_block[3]
            if calculate_iou(text_block_bbox, title_block_bbox) > 0.8:
                # Guard: the same title may match several text blocks; only
                # remove it the first time.
                if title_block in all_bboxes:
                    all_bboxes.remove(title_block)
    return all_bboxes
def remove_need_drop_blocks(all_bboxes, discarded_blocks):
    """Remove blocks that mostly lie inside a discarded region.

    A block is dropped when more than 60% of its own area falls inside any
    discarded bbox.  Mutates *all_bboxes* in place and returns it.

    Bug fix: the original kept scanning the remaining discarded boxes after a
    removal, so a block overlapping two discarded boxes was removed twice and
    raised ``ValueError``.  We now ``break`` after the first removal.
    """
    for block in all_bboxes.copy():
        # Hoist the block bbox; it is invariant over the inner loop.
        block_bbox = block[0], block[1], block[2], block[3]
        for discarded_block in discarded_blocks:
            if calculate_overlap_area_in_bbox1_area_ratio(block_bbox, discarded_block['bbox']) > 0.6:
                all_bboxes.remove(block)
                break  # block already removed; do not try to remove it again
    return all_bboxes
def remove_overlaps_min_blocks(all_bboxes):
    """Of every pair of heavily overlapping blocks, delete the smaller one.

    Mutates *all_bboxes* in place and returns it.  Pairs are compared over
    snapshot copies so removals during the scan do not skip elements; the
    removal itself is re-resolved against the live list.
    """
    # Delete the smaller member of each overlapping block pair.
    for block1 in all_bboxes.copy():
        for block2 in all_bboxes.copy():
            if block1 != block2:
                block1_bbox = [block1[0], block1[1], block1[2], block1[3]]
                block2_bbox = [block2[0], block2[1], block2[2], block2[3]]
                # Returns the smaller bbox when the overlap ratio exceeds 0.8,
                # otherwise None.
                overlap_box = get_minbox_if_overlap_by_ratio(block1_bbox, block2_bbox, 0.8)
                if overlap_box is not None:
                    # Map the losing bbox back to its owner row in the LIVE
                    # list; yields None if that row was already removed.
                    bbox_to_remove = next(
                        (block for block in all_bboxes if [block[0], block[1], block[2], block[3]] == overlap_box),
                        None)
                    if bbox_to_remove is not None:
                        all_bboxes.remove(bbox_to_remove)
    return all_bboxes
...@@ -3,7 +3,9 @@ from loguru import logger ...@@ -3,7 +3,9 @@ from loguru import logger
from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold, get_minbox_if_overlap_by_ratio, \ from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold, get_minbox_if_overlap_by_ratio, \
calculate_overlap_area_in_bbox1_area_ratio calculate_overlap_area_in_bbox1_area_ratio
from magic_pdf.libs.drop_tag import DropTag from magic_pdf.libs.drop_tag import DropTag
from magic_pdf.libs.ocr_content_type import ContentType from magic_pdf.libs.ocr_content_type import ContentType, BlockType
from magic_pdf.pre_proc.ocr_span_list_modify import modify_y_axis, modify_inline_equation
from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox
# 将每一个line中的span从左到右排序 # 将每一个line中的span从左到右排序
...@@ -24,6 +26,7 @@ def line_sort_spans_by_left_to_right(lines): ...@@ -24,6 +26,7 @@ def line_sort_spans_by_left_to_right(lines):
}) })
return line_objects return line_objects
def merge_spans_to_line(spans): def merge_spans_to_line(spans):
if len(spans) == 0: if len(spans) == 0:
return [] return []
...@@ -37,7 +40,8 @@ def merge_spans_to_line(spans): ...@@ -37,7 +40,8 @@ def merge_spans_to_line(spans):
# 如果当前的span类型为"interline_equation" 或者 当前行中已经有"interline_equation" # 如果当前的span类型为"interline_equation" 或者 当前行中已经有"interline_equation"
# image和table类型,同上 # image和table类型,同上
if span['type'] in [ContentType.InterlineEquation, ContentType.Image, ContentType.Table] or any( if span['type'] in [ContentType.InterlineEquation, ContentType.Image, ContentType.Table] or any(
s['type'] in [ContentType.InterlineEquation, ContentType.Image, ContentType.Table] for s in current_line): s['type'] in [ContentType.InterlineEquation, ContentType.Image, ContentType.Table] for s in
current_line):
# 则开始新行 # 则开始新行
lines.append(current_line) lines.append(current_line)
current_line = [span] current_line = [span]
...@@ -57,6 +61,7 @@ def merge_spans_to_line(spans): ...@@ -57,6 +61,7 @@ def merge_spans_to_line(spans):
return lines return lines
def merge_spans_to_line_by_layout(spans, layout_bboxes): def merge_spans_to_line_by_layout(spans, layout_bboxes):
lines = [] lines = []
new_spans = [] new_spans = []
...@@ -103,7 +108,205 @@ def merge_lines_to_block(lines): ...@@ -103,7 +108,205 @@ def merge_lines_to_block(lines):
return blocks return blocks
def sort_blocks_by_layout(all_bboxes, layout_bboxes):
    """Order page blocks into reading order using the layout boxes.

    NOTE: mutates *all_bboxes* in place — every block assigned to a layout is
    removed from it.  Footnote blocks are skipped and therefore absent from
    the result.  Returns the flat, ordered block list.
    """
    new_blocks = []
    sort_blocks = []
    for item in layout_bboxes:
        layout_bbox = item['layout_bbox']
        # Walk the blocks and collect the ones belonging to this layout.
        layout_blocks = []
        for block in all_bboxes:
            # Footnotes are excluded from layout ordering.
            if block[7] == BlockType.Footnote:
                continue
            block_bbox = [block[0], block[1], block[2], block[3]]
            # >80% of the block area inside the layout box counts as "inside".
            if calculate_overlap_area_in_bbox1_area_ratio(block_bbox, layout_bbox) > 0.8:
                layout_blocks.append(block)
        # Keep the group only when the layout actually contains blocks.
        if len(layout_blocks) > 0:
            new_blocks.append(layout_blocks)
            # Remove the assigned blocks so a later layout cannot claim them.
            for layout_block in layout_blocks:
                all_bboxes.remove(layout_block)
    # Sort each layout group internally top-to-bottom by y0, then flatten
    # in layout order.
    if len(new_blocks) > 0:
        for bboxes_in_layout_block in new_blocks:
            bboxes_in_layout_block.sort(key=lambda x: x[1])
            sort_blocks.extend(bboxes_in_layout_block)
    # sort_blocks now holds every surviving block of the page in final order.
    return sort_blocks
def fill_spans_in_blocks(blocks, spans):
    '''
    Assign each span to the block whose bbox covers it.

    *blocks* rows carry their type at index 7 and bbox at indices 0..3.
    Mutates *spans* in place: spans consumed by a block are removed so a
    later block cannot claim them again.  Returns a list of dicts with keys
    'type', 'bbox' and 'spans'.
    '''
    block_with_spans = []
    for block in blocks:
        block_type = block[7]
        block_bbox = block[0:4]
        block_dict = {
            'type': block_type,
            'bbox': block_bbox,
        }
        block_spans = []
        for span in spans:
            span_bbox = span['bbox']
            # A span belongs to the block when >70% of its own area lies inside.
            if calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.7:
                block_spans.append(span)
        '''行内公式调整, 高度调整至与同行文字高度一致(优先左侧, 其次右侧)'''
        # Inline-equation fix-up: align equation height with the text of the
        # same line (prefers the left neighbour, then the right).
        displayed_list = []
        text_inline_lines = []
        modify_y_axis(block_spans, displayed_list, text_inline_lines)
        '''模型识别错误的行间公式, type类型转换成行内公式'''
        # Interline equations the model mis-detected are converted to inline.
        block_spans = modify_inline_equation(block_spans, displayed_list, text_inline_lines)
        '''bbox去除粘连'''
        # Separate bboxes that stick to each other.
        block_spans = remove_overlap_between_bbox(block_spans)
        block_dict['spans'] = block_spans
        block_with_spans.append(block_dict)
        # Drop consumed spans from the shared input list.
        if len(block_spans) > 0:
            for span in block_spans:
                spans.remove(span)
    return block_with_spans
def fix_block_spans(block_with_spans, img_blocks, table_blocks):
    '''
    Normalize blocks that still carry a flat span list.

    1. Image and table blocks nest caption/footnote content, so their text
       spans are redistributed into the caption/footnote sub-blocks of the
       matching img_block / table_block.
    2. The flat 'spans' key is removed from every block in the process.

    Blocks of any unknown type are silently dropped.
    '''
    fixed = []
    for blk in block_with_spans:
        kind = blk['type']
        if kind == BlockType.Image:
            fixed.append(fix_image_block(blk, img_blocks))
        elif kind == BlockType.Table:
            fixed.append(fix_table_block(blk, table_blocks))
        elif kind in (BlockType.Text, BlockType.Title, BlockType.InterlineEquation):
            fixed.append(fix_text_block(blk))
        # any other block type is discarded
    return fixed
def merge_spans_to_block(spans: list, block_bbox: list, block_type: str):
    """Group the spans falling inside *block_bbox* into a typed block.

    A span is claimed when >80% of its own area lies inside *block_bbox*.
    Returns ``(block, consumed_spans)`` where *block* has keys 'bbox',
    'type' and 'lines' (lines are left-to-right sorted).  *spans* itself
    is not modified.
    """
    consumed = [
        s for s in spans
        if calculate_overlap_area_in_bbox1_area_ratio(s['bbox'], block_bbox) > 0.8
    ]
    raw_lines = merge_spans_to_line(consumed)
    ordered_lines = line_sort_spans_by_left_to_right(raw_lines)
    block = {
        'bbox': block_bbox,
        'type': block_type,
        'lines': ordered_lines,
    }
    return block, consumed
def make_body_block(span: dict, block_bbox: list, block_type: str):
    """Wrap a single span into a one-line body block.

    Used for image/table bodies, where the whole block is exactly one span;
    the single line shares the block's bbox.
    """
    return {
        'bbox': block_bbox,
        'type': block_type,
        'lines': [
            {'bbox': block_bbox, 'spans': [span]},
        ],
    }
def fix_image_block(block, img_blocks):
    """Rebuild an image block as nested sub-blocks (body + optional caption).

    Matches *block* against the detector's *img_blocks* by exact bbox
    equality, moves the image span into an ImageBody sub-block and groups
    the remaining text spans inside the caption bbox into an ImageCaption
    sub-block.  The flat 'spans' key is deleted before returning.
    """
    block['blocks'] = []
    # Find the img_block whose bbox matches this block exactly.
    for img_block in img_blocks:
        if img_block['bbox'] == block['bbox']:
            # Create the image body sub-block from the matching image span.
            for span in block['spans']:
                if span['type'] == ContentType.Image and span['bbox'] == img_block['img_body_bbox']:
                    img_body_block = make_body_block(span, img_block['img_body_bbox'], BlockType.ImageBody)
                    block['blocks'].append(img_body_block)
                    # The body span is consumed; the immediate break keeps the
                    # remove-while-iterating safe.
                    block['spans'].remove(span)
                    break
            # Attach a caption sub-block when the detector reported one.
            if img_block['img_caption_bbox'] is not None:
                img_caption_block, img_caption_spans = merge_spans_to_block(
                    block['spans'], img_block['img_caption_bbox'], BlockType.ImageCaption
                )
                block['blocks'].append(img_caption_block)
            break
    del block['spans']
    return block
def fix_table_block(block, table_blocks):
    """Rebuild a table block as nested sub-blocks (body, caption, footnote).

    Matches *block* against the detector's *table_blocks* by exact bbox
    equality, moves the table span into a TableBody sub-block, then groups
    remaining text spans into TableCaption / TableFootnote sub-blocks when
    those bboxes were detected.  The flat 'spans' key is deleted before
    returning.
    """
    block['blocks'] = []
    # Find the table_block whose bbox matches this block exactly.
    for table_block in table_blocks:
        if table_block['bbox'] == block['bbox']:
            # Create the table body sub-block from the matching table span.
            for span in block['spans']:
                if span['type'] == ContentType.Table and span['bbox'] == table_block['table_body_bbox']:
                    table_body_block = make_body_block(span, table_block['table_body_bbox'], BlockType.TableBody)
                    block['blocks'].append(table_body_block)
                    # The body span is consumed; the immediate break keeps the
                    # remove-while-iterating safe.
                    block['spans'].remove(span)
                    break
            # Attach a caption sub-block when the detector reported one.
            if table_block['table_caption_bbox'] is not None:
                table_caption_block, table_caption_spans = merge_spans_to_block(
                    block['spans'], table_block['table_caption_bbox'], BlockType.TableCaption
                )
                block['blocks'].append(table_caption_block)
                # Spans consumed by the caption must not also be claimed by
                # the footnote grouping below, so drop them here.
                if len(table_caption_spans) > 0:
                    for span in table_caption_spans:
                        block['spans'].remove(span)
            # Attach a footnote sub-block when the detector reported one.
            if table_block['table_footnote_bbox'] is not None:
                table_footnote_block, table_footnote_spans = merge_spans_to_block(
                    block['spans'], table_block['table_footnote_bbox'], BlockType.TableFootnote
                )
                block['blocks'].append(table_footnote_block)
            break
    del block['spans']
    return block
def fix_text_block(block):
    """Replace a text-like block's flat span list with sorted lines.

    Spans are merged into lines, each line's spans are ordered left to
    right, and the now-redundant 'spans' key is deleted.  Mutates and
    returns *block*.
    """
    merged = merge_spans_to_line(block['spans'])
    block['lines'] = line_sort_spans_by_left_to_right(merged)
    del block['spans']
    return block
...@@ -3,7 +3,7 @@ from loguru import logger ...@@ -3,7 +3,7 @@ from loguru import logger
from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio, get_minbox_if_overlap_by_ratio, \ from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio, get_minbox_if_overlap_by_ratio, \
__is_overlaps_y_exceeds_threshold __is_overlaps_y_exceeds_threshold
from magic_pdf.libs.drop_tag import DropTag from magic_pdf.libs.drop_tag import DropTag
from magic_pdf.libs.ocr_content_type import ContentType from magic_pdf.libs.ocr_content_type import ContentType, BlockType
def remove_overlaps_min_spans(spans): def remove_overlaps_min_spans(spans):
...@@ -50,7 +50,8 @@ def remove_spans_by_bboxes_dict(spans, need_remove_spans_bboxes_dict): ...@@ -50,7 +50,8 @@ def remove_spans_by_bboxes_dict(spans, need_remove_spans_bboxes_dict):
need_remove_spans.append(span) need_remove_spans.append(span)
break break
# 当drop_tag为DropTag.FOOTNOTE时, 判断span是否在removed_bboxes中任意一个的下方,如果是,则删除该span # 当drop_tag为DropTag.FOOTNOTE时, 判断span是否在removed_bboxes中任意一个的下方,如果是,则删除该span
elif drop_tag == DropTag.FOOTNOTE and (span['bbox'][1]+span['bbox'][3])/2 > removed_bbox[3] and removed_bbox[0] < (span['bbox'][0]+span['bbox'][2])/2 < removed_bbox[2]: elif drop_tag == DropTag.FOOTNOTE and (span['bbox'][1] + span['bbox'][3]) / 2 > removed_bbox[3] and \
removed_bbox[0] < (span['bbox'][0] + span['bbox'][2]) / 2 < removed_bbox[2]:
need_remove_spans.append(span) need_remove_spans.append(span)
break break
...@@ -162,9 +163,10 @@ def modify_inline_equation(spans: list, displayed_list: list, text_inline_lines: ...@@ -162,9 +163,10 @@ def modify_inline_equation(spans: list, displayed_list: list, text_inline_lines:
text_line = text_inline_lines[j] text_line = text_inline_lines[j]
y0, y1 = text_line[1] y0, y1 = text_line[1]
if ( if (
span_y0 < y0 and span_y > y0 or span_y0 < y1 and span_y > y1 or span_y0 < y0 and span_y > y1) and __is_overlaps_y_exceeds_threshold( span_y0 < y0 < span_y or span_y0 < y1 < span_y or span_y0 < y0 and span_y > y1
span['bbox'], (0, y0, 0, y1)): ) and __is_overlaps_y_exceeds_threshold(
span['bbox'], (0, y0, 0, y1)
):
# 调整公式类型 # 调整公式类型
if span["type"] == ContentType.InterlineEquation: if span["type"] == ContentType.InterlineEquation:
# 最后一行是行间公式 # 最后一行是行间公式
...@@ -181,7 +183,7 @@ def modify_inline_equation(spans: list, displayed_list: list, text_inline_lines: ...@@ -181,7 +183,7 @@ def modify_inline_equation(spans: list, displayed_list: list, text_inline_lines:
span["bbox"][1] = y0 span["bbox"][1] = y0
span["bbox"][3] = y1 span["bbox"][3] = y1
break break
elif span_y < y0 or span_y0 < y0 and span_y > y0 and not __is_overlaps_y_exceeds_threshold(span['bbox'], elif span_y < y0 or span_y0 < y0 < span_y and not __is_overlaps_y_exceeds_threshold(span['bbox'],
(0, y0, 0, y1)): (0, y0, 0, y1)):
break break
else: else:
...@@ -211,3 +213,19 @@ def get_qa_need_list(blocks): ...@@ -211,3 +213,19 @@ def get_qa_need_list(blocks):
else: else:
continue continue
return images, tables, interline_equations, inline_equations return images, tables, interline_equations, inline_equations
def get_qa_need_list_v2(blocks):
    """Partition *blocks* by type into (images, tables, interline_equations).

    Blocks of any other type are ignored.  The returned lists hold the
    original block dicts (no copies are made).
    """
    images = []
    tables = []
    interline_equations = []
    # Dispatch table: block type -> destination bucket.
    buckets = {
        BlockType.Image: images,
        BlockType.Table: tables,
        BlockType.InterlineEquation: interline_equations,
    }
    for block in blocks:
        bucket = buckets.get(block["type"])
        if bucket is not None:
            bucket.append(block)
    return images, tables, interline_equations
...@@ -68,7 +68,7 @@ def pdf_filter(page:fitz.Page, text_blocks, table_bboxes, image_bboxes) -> tuple ...@@ -68,7 +68,7 @@ def pdf_filter(page:fitz.Page, text_blocks, table_bboxes, image_bboxes) -> tuple
""" """
if __is_contain_color_background_rect(page, text_blocks, image_bboxes): if __is_contain_color_background_rect(page, text_blocks, image_bboxes):
return False, {"need_drop": True, "drop_reason": DropReason.COLOR_BACKGROUND_TEXT_BOX} return False, {"_need_drop": True, "_drop_reason": DropReason.COLOR_BACKGROUND_TEXT_BOX}
return True, None return True, None
\ No newline at end of file
...@@ -5,7 +5,7 @@ def _remove_overlap_between_bbox(spans): ...@@ -5,7 +5,7 @@ def _remove_overlap_between_bbox(spans):
res = [] res = []
for v in spans: for v in spans:
for i in range(len(res)): for i in range(len(res)):
if _is_in(res[i]["bbox"], v["bbox"]): if _is_in(res[i]["bbox"], v["bbox"]) or _is_in(v["bbox"], res[i]["bbox"]):
continue continue
if _is_in_or_part_overlap(res[i]["bbox"], v["bbox"]): if _is_in_or_part_overlap(res[i]["bbox"], v["bbox"]):
ix0, iy0, ix1, iy1 = res[i]["bbox"] ix0, iy0, ix1, iy1 = res[i]["bbox"]
...@@ -17,21 +17,21 @@ def _remove_overlap_between_bbox(spans): ...@@ -17,21 +17,21 @@ def _remove_overlap_between_bbox(spans):
if diff_y > diff_x: if diff_y > diff_x:
if x1 >= ix1: if x1 >= ix1:
mid = (x0 + ix1) // 2 mid = (x0 + ix1) // 2
ix1 = min(mid, ix1) ix1 = min(mid - 0.25, ix1)
x0 = max(mid + 1, x0) x0 = max(mid + 0.25, x0)
else: else:
mid = (ix0 + x1) // 2 mid = (ix0 + x1) // 2
ix0 = max(mid + 1, ix0) ix0 = max(mid + 0.25, ix0)
x1 = min(mid, x1) x1 = min(mid -0.25, x1)
else: else:
if y1 >= iy1: if y1 >= iy1:
mid = (y0 + iy1) // 2 mid = (y0 + iy1) // 2
y0 = max(mid + 1, y0) y0 = max(mid + 0.25, y0)
iy1 = min(iy1, mid) iy1 = min(iy1, mid-0.25)
else: else:
mid = (iy0 + y1) // 2 mid = (iy0 + y1) // 2
y1 = min(y1, mid) y1 = min(y1, mid-0.25)
iy0 = max(mid + 1, iy0) iy0 = max(mid + 0.25, iy0)
res[i]["bbox"] = [ix0, iy0, ix1, iy1] res[i]["bbox"] = [ix0, iy0, ix1, iy1]
v["bbox"] = [x0, y0, x1, y1] v["bbox"] = [x0, y0, x1, y1]
......
import os import os
from magic_pdf.io.AbsReaderWriter import AbsReaderWriter from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from loguru import logger from loguru import logger
MODE_TXT = "text" MODE_TXT = "text"
MODE_BIN = "binary" MODE_BIN = "binary"
class DiskReaderWriter(AbsReaderWriter): class DiskReaderWriter(AbsReaderWriter):
def __init__(self, parent_path, encoding='utf-8'): def __init__(self, parent_path, encoding="utf-8"):
self.path = parent_path self.path = parent_path
self.encoding = encoding self.encoding = encoding
...@@ -20,10 +22,10 @@ class DiskReaderWriter(AbsReaderWriter): ...@@ -20,10 +22,10 @@ class DiskReaderWriter(AbsReaderWriter):
logger.error(f"文件 {abspath} 不存在") logger.error(f"文件 {abspath} 不存在")
raise Exception(f"文件 {abspath} 不存在") raise Exception(f"文件 {abspath} 不存在")
if mode == MODE_TXT: if mode == MODE_TXT:
with open(abspath, 'r', encoding = self.encoding) as f: with open(abspath, "r", encoding=self.encoding) as f:
return f.read() return f.read()
elif mode == MODE_BIN: elif mode == MODE_BIN:
with open(abspath, 'rb') as f: with open(abspath, "rb") as f:
return f.read() return f.read()
else: else:
raise ValueError("Invalid mode. Use 'text' or 'binary'.") raise ValueError("Invalid mode. Use 'text' or 'binary'.")
...@@ -33,32 +35,32 @@ class DiskReaderWriter(AbsReaderWriter): ...@@ -33,32 +35,32 @@ class DiskReaderWriter(AbsReaderWriter):
abspath = path abspath = path
else: else:
abspath = os.path.join(self.path, path) abspath = os.path.join(self.path, path)
directory_path = os.path.dirname(abspath)
if not os.path.exists(directory_path):
os.makedirs(directory_path)
if mode == MODE_TXT: if mode == MODE_TXT:
with open(abspath, 'w', encoding=self.encoding) as f: with open(abspath, "w", encoding=self.encoding) as f:
f.write(content) f.write(content)
logger.info(f"内容已成功写入 {abspath}")
elif mode == MODE_BIN: elif mode == MODE_BIN:
with open(abspath, 'wb') as f: with open(abspath, "wb") as f:
f.write(content) f.write(content)
logger.info(f"内容已成功写入 {abspath}")
else: else:
raise ValueError("Invalid mode. Use 'text' or 'binary'.") raise ValueError("Invalid mode. Use 'text' or 'binary'.")
def read_jsonl(self, path: str, byte_start=0, byte_end=None, encoding='utf-8'): def read_jsonl(self, path: str, byte_start=0, byte_end=None, encoding="utf-8"):
return self.read(path) return self.read(path)
# 使用示例 # 使用示例
if __name__ == "__main__": if __name__ == "__main__":
file_path = "io/example.txt" file_path = "io/test/example.txt"
drw = DiskReaderWriter("D:\projects\papayfork\Magic-PDF\magic_pdf") drw = DiskReaderWriter("D:\projects\papayfork\Magic-PDF\magic_pdf")
# 写入内容到文件 # 写入内容到文件
drw.write(b"Hello, World!", path="io/example.txt", mode="binary") drw.write(b"Hello, World!", path="io/test/example.txt", mode="binary")
# 从文件读取内容 # 从文件读取内容
content = drw.read(path=file_path) content = drw.read(path=file_path)
if content: if content:
logger.info(f"从 {file_path} 读取的内容: {content}") logger.info(f"从 {file_path} 读取的内容: {content}")
from magic_pdf.io.AbsReaderWriter import AbsReaderWriter from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.libs.commons import parse_aws_param, parse_bucket_key from magic_pdf.libs.commons import parse_aws_param, parse_bucket_key
import boto3 import boto3
from loguru import logger from loguru import logger
...@@ -11,7 +11,7 @@ MODE_BIN = "binary" ...@@ -11,7 +11,7 @@ MODE_BIN = "binary"
class S3ReaderWriter(AbsReaderWriter): class S3ReaderWriter(AbsReaderWriter):
def __init__(self, ak: str, sk: str, endpoint_url: str, addressing_style: str, parent_path: str): def __init__(self, ak: str, sk: str, endpoint_url: str, addressing_style: str = 'auto', parent_path: str = ''):
self.client = self._get_client(ak, sk, endpoint_url, addressing_style) self.client = self._get_client(ak, sk, endpoint_url, addressing_style)
self.path = parent_path self.path = parent_path
......
from loguru import logger
from magic_pdf.libs.drop_reason import DropReason
from loguru import logger
from magic_pdf.libs.drop_reason import DropReason
def get_data_source(jso: dict):
    """Return the record's data source, falling back to the legacy
    'file_source' key when 'data_source' is absent or None."""
    source = jso.get("data_source")
    return source if source is not None else jso.get("file_source")
def get_data_type(jso: dict):
    """Return the record's data type, falling back to the legacy
    'file_type' key when 'data_type' is absent or None."""
    dtype = jso.get("data_type")
    return dtype if dtype is not None else jso.get("file_type")
def get_bookid(jso: dict):
    """Return the record's book id, falling back to the legacy
    'original_file_id' key when 'bookid' is absent or None."""
    bid = jso.get("bookid")
    return bid if bid is not None else jso.get("original_file_id")
def exception_handler(jso: dict, e):
    """Log exception *e* and tag *jso* for dropping.

    Sets the drop marker, the drop reason and a textual copy of the
    exception on the record, then returns it.
    """
    logger.exception(e)
    jso.update({
        "need_drop": True,
        "drop_reason": DropReason.Exception,
        "exception": f"ERROR: {e}",
    })
    return jso
def get_bookname(jso: dict):
    """Build the canonical book name as '<data_source>/<file_id>'."""
    return f"{get_data_source(jso)}/{jso.get('file_id')}"
from loguru import logger
""" from magic_pdf.libs.drop_reason import DropReason
用户输入:
model数组,每个元素代表一个页面
pdf在s3的路径
截图保存的s3位置
然后:
1)根据s3路径,调用spark集群的api,拿到ak,sk,endpoint,构造出s3PDFReader
2)根据用户输入的s3地址,调用spark集群的api,拿到ak,sk,endpoint,构造出s3ImageWriter
其余部分至于构造s3cli, 获取ak,sk都在code-clean里写代码完成。不要反向依赖!!! def get_data_source(jso: dict):
data_source = jso.get("data_source")
if data_source is None:
data_source = jso.get("file_source")
return data_source
"""
from loguru import logger
from magic_pdf.io import AbsReaderWriter def get_data_type(jso: dict):
from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr data_type = jso.get("data_type")
from magic_pdf.pdf_parse_by_txt import parse_pdf_by_txt if data_type is None:
data_type = jso.get("file_type")
return data_type
def parse_txt_pdf(pdf_bytes:bytes, pdf_models:list, imageWriter: AbsReaderWriter, is_debug=False, start_page=0, *args, **kwargs): def get_bookid(jso: dict):
""" book_id = jso.get("bookid")
解析文本类pdf if book_id is None:
""" book_id = jso.get("original_file_id")
pdf_info_dict = parse_pdf_by_txt( return book_id
pdf_bytes,
pdf_models,
imageWriter,
start_page_id=start_page,
debug_mode=is_debug,
)
pdf_info_dict["parse_type"] = "txt"
return pdf_info_dict def exception_handler(jso: dict, e):
logger.exception(e)
jso["_need_drop"] = True
jso["_drop_reason"] = DropReason.Exception
jso["_exception"] = f"ERROR: {e}"
return jso
def parse_ocr_pdf(pdf_bytes:bytes, pdf_models:list, imageWriter: AbsReaderWriter, is_debug=False, start_page=0, *args, **kwargs): def get_bookname(jso: dict):
""" data_source = get_data_source(jso)
解析ocr类pdf file_id = jso.get("file_id")
""" book_name = f"{data_source}/{file_id}"
pdf_info_dict = parse_pdf_by_ocr( return book_name
pdf_bytes,
pdf_models,
imageWriter,
start_page_id=start_page,
debug_mode=is_debug,
)
pdf_info_dict["parse_type"] = "ocr"
return pdf_info_dict def spark_json_extractor(jso: dict) -> dict:
def parse_union_pdf(pdf_bytes:bytes, pdf_models:list, imageWriter: AbsReaderWriter, is_debug=False, start_page=0, *args, **kwargs):
""" """
ocr和文本混合的pdf,全部解析出来 从json中提取数据,返回一个dict
""" """
def parse_pdf(method):
try:
return method(
pdf_bytes,
pdf_models,
imageWriter,
start_page_id=start_page,
debug_mode=is_debug,
)
except Exception as e:
logger.error(f"{method.__name__} error: {e}")
return None
pdf_info_dict = parse_pdf(parse_pdf_by_txt)
if pdf_info_dict is None or pdf_info_dict.get("need_drop", False):
logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr")
pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
if pdf_info_dict is None:
raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.")
else:
pdf_info_dict["parse_type"] = "ocr"
else:
pdf_info_dict["parse_type"] = "txt"
return pdf_info_dict
def spark_json_extractor(jso:dict): return {
pass "_pdf_type": jso["_pdf_type"],
"model_list": jso["doc_layout_result"],
}
"""
用户输入:
model数组,每个元素代表一个页面
pdf在s3的路径
截图保存的s3位置
然后:
1)根据s3路径,调用spark集群的api,拿到ak,sk,endpoint,构造出s3PDFReader
2)根据用户输入的s3地址,调用spark集群的api,拿到ak,sk,endpoint,构造出s3ImageWriter
其余部分至于构造s3cli, 获取ak,sk都在code-clean里写代码完成。不要反向依赖!!!
"""
from loguru import logger
from magic_pdf.rw import AbsReaderWriter
from magic_pdf.pdf_parse_by_ocr_v2 import parse_pdf_by_ocr
from magic_pdf.pdf_parse_by_txt_v2 import parse_pdf_by_txt
PARSE_TYPE_TXT = "txt"
PARSE_TYPE_OCR = "ocr"
def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False, start_page=0, *args,
                  **kwargs):
    """Parse a text-based (born-digital) PDF.

    Delegates to ``parse_pdf_by_txt`` and stamps the result with the txt
    parse type so downstream stages know which pipeline produced it.
    """
    result = parse_pdf_by_txt(
        pdf_bytes,
        pdf_models,
        imageWriter,
        start_page_id=start_page,
        debug_mode=is_debug,
    )
    result["_parse_type"] = PARSE_TYPE_TXT
    return result
def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False, start_page=0, *args,
                  **kwargs):
    """Parse a scanned PDF via the OCR pipeline.

    Delegates to ``parse_pdf_by_ocr`` and stamps the result with the ocr
    parse type so downstream stages know which pipeline produced it.
    """
    result = parse_pdf_by_ocr(
        pdf_bytes,
        pdf_models,
        imageWriter,
        start_page_id=start_page,
        debug_mode=is_debug,
    )
    result["_parse_type"] = PARSE_TYPE_OCR
    return result
def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False, start_page=0,
                    *args, **kwargs):
    """Parse a mixed (text + scanned) PDF: try text extraction, fall back to OCR.

    The txt pipeline runs first; if it errors out or marks the document for
    dropping, the OCR pipeline is tried.  Raises when both pipelines fail.
    """

    def _attempt(parser):
        # Run one parser; failures are logged and reported as None so the
        # caller can fall back instead of crashing.
        try:
            return parser(
                pdf_bytes,
                pdf_models,
                imageWriter,
                start_page_id=start_page,
                debug_mode=is_debug,
            )
        except Exception as e:
            logger.error(f"{parser.__name__} error: {e}")
            return None

    pdf_info_dict = _attempt(parse_pdf_by_txt)
    if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False):
        logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr")
        pdf_info_dict = _attempt(parse_pdf_by_ocr)
        if pdf_info_dict is None:
            raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.")
        pdf_info_dict["_parse_type"] = PARSE_TYPE_OCR
    else:
        pdf_info_dict["_parse_type"] = PARSE_TYPE_TXT
    return pdf_info_dict
...@@ -14,5 +14,7 @@ termcolor>=2.4.0 ...@@ -14,5 +14,7 @@ termcolor>=2.4.0
wordninja>=2.0.0 wordninja>=2.0.0
en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl
zh_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/zh_core_web_sm-3.7.0/zh_core_web_sm-3.7.0-py3-none-any.whl zh_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/zh_core_web_sm-3.7.0/zh_core_web_sm-3.7.0-py3-none-any.whl
scikit-learn==1.4.1.post1 scikit-learn>=1.0.2
nltk==3.8.1 nltk==3.8.1
s3pathlib>=2.1.1
# 工具脚本使用说明
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment