Commit 0e2d0b8b authored by 赵小蒙

parse_pdf_by_ocr 和 cut_image 重构,使用抽象类进行写出操作

parent 00f16239
......@@ -120,29 +120,9 @@ def read_file(pdf_path: str, s3_profile):
return f.read()
def get_docx_model_output(pdf_model_output, pdf_model_s3_profile, page_id):
if isinstance(pdf_model_output, str):
model_output_json_path = join_path(pdf_model_output, f"page_{page_id + 1}.json") # 模型输出的页面编号从1开始的
if os.path.exists(model_output_json_path):
json_from_docx = read_file(model_output_json_path, pdf_model_s3_profile)
model_output_json = json.loads(json_from_docx)
else:
try:
model_output_json_path = join_path(pdf_model_output, "model.json")
with open(model_output_json_path, "r", encoding="utf-8") as f:
model_output_json = json.load(f)
model_output_json = model_output_json["doc_layout_result"][page_id]
except:
s3_model_output_json_path = join_path(pdf_model_output, f"page_{page_id + 1}.json")
s3_model_output_json_path = join_path(pdf_model_output, f"{page_id}.json")
#s3_model_output_json_path = join_path(pdf_model_output, f"page_{page_id }.json")
# logger.warning(f"model_output_json_path: {model_output_json_path} not found. try to load from s3: {s3_model_output_json_path}")
s = read_file(s3_model_output_json_path, pdf_model_s3_profile)
return json.loads(s)
elif isinstance(pdf_model_output, list):
model_output_json = pdf_model_output[page_id]
def get_docx_model_output(pdf_model_output, page_id):
    """Return the model output for a single page.

    pdf_model_output: sequence of per-page model results (0-indexed).
    page_id: index of the page whose result is wanted.
    """
    return pdf_model_output[page_id]
......
......@@ -10,35 +10,19 @@ from magic_pdf.libs.commons import parse_bucket_key, join_path
from magic_pdf.libs.hash_utils import compute_sha256
def cut_image(bbox: Tuple, page_num: int, page: fitz.Page, save_parent_path: str, s3_return_path=None, img_s3_client=None, upload_switch=True):
def cut_image(bbox: Tuple, page_num: int, page: fitz.Page, return_path, imageWriter, upload_switch=True):
"""
从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径
save_path:需要同时支持s3和本地, 图片存放在save_path下,文件名是: {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。
"""
# 拼接文件名
filename = f"{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}.jpg"
filename = f"{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}"
# 老版本返回不带bucket的路径
s3_img_path = join_path(s3_return_path, filename) if s3_return_path is not None else None
# 新版本生成s3的平铺路径
s3_img_hash256_path = f"{compute_sha256(s3_img_path)}.jpg"
# 打印图片文件名
# print(f"Saved {image_save_path}")
#检查坐标
# x_check = int(bbox[2]) - int(bbox[0])
# y_check = int(bbox[3]) - int(bbox[1])
# if x_check <= 0 or y_check <= 0:
#
# if image_save_path.startswith("s3://"):
# logger.exception(f"传入图片坐标有误,x1<x0或y1<y0,{s3_img_path}")
# return s3_img_path
# else:
# logger.exception(f"传入图片坐标有误,x1<x0或y1<y0,{image_save_path}")
# return image_save_path
img_path = join_path(return_path, filename) if return_path is not None else None
# 新版本生成平铺路径
img_hash256_path = f"{compute_sha256(img_path)}.jpg"
# 将坐标转换为fitz.Rect对象
rect = fitz.Rect(*bbox)
......@@ -47,41 +31,11 @@ def cut_image(bbox: Tuple, page_num: int, page: fitz.Page, save_parent_path: str
# 截取图片
pix = page.get_pixmap(clip=rect, matrix=zoom)
if save_parent_path.startswith("s3://"):
if not upload_switch:
pass
else:
"""图片保存到s3"""
# 从save_parent_path获取bucket_name
bucket_name, bucket_key = parse_bucket_key(save_parent_path)
# 平铺路径赋值给bucket_key
bucket_key = s3_img_hash256_path
# 将字节流上传到s3
byte_data = pix.tobytes(output='jpeg', jpg_quality=95)
file_obj = io.BytesIO(byte_data)
if img_s3_client is not None:
img_s3_client.upload_fileobj(file_obj, bucket_name, bucket_key)
# 每个图片上传任务都创建一个新的client
# img_s3_client_once = get_s3_client(image_save_path)
# img_s3_client_once.upload_fileobj(file_obj, bucket_name, bucket_key)
else:
logger.exception("must input img_s3_client")
# return s3_img_path # 早期版本要求返回不带bucket的路径
s3_image_save_path = f"s3://{bucket_name}/{s3_img_hash256_path}" # 新版本返回平铺的s3路径
return s3_image_save_path
else:
# 保存图片到本地
# 先检查一下image_save_path的父目录是否存在,如果不存在,就创建
local_image_save_path = join_path(save_parent_path, filename)
parent_dir = os.path.dirname(local_image_save_path)
if not os.path.exists(parent_dir):
os.makedirs(parent_dir)
pix.save(local_image_save_path, jpg_quality=95)
# 为了直接能在markdown里看,这里把地址改为相对于mardown的地址
pth = Path(local_image_save_path)
local_image_save_path = f"{pth.parent.name}/{pth.name}"
return local_image_save_path
byte_data = pix.tobytes(output='jpeg', jpg_quality=95)
imageWriter.write(data=byte_data, path=img_hash256_path, mode="binary")
return img_hash256_path
def save_images_by_bboxes(book_name: str, page_num: int, page: fitz.Page, save_path: str,
......
import json
import os
import time
from loguru import logger
from magic_pdf.libs.draw_bbox import draw_layout_bbox, draw_text_bbox
from magic_pdf.libs.commons import (
read_file,
join_path,
fitz,
get_img_s3_client,
get_delta_time,
get_docx_model_output,
)
......@@ -17,7 +9,6 @@ from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.libs.drop_tag import DropTag
from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.ocr_content_type import ContentType
from magic_pdf.libs.safe_filename import sanitize_filename
from magic_pdf.para.para_split import para_split
from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component
from magic_pdf.pre_proc.detect_footer_by_model import parse_footers
......@@ -39,38 +30,16 @@ from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox
def parse_pdf_by_ocr(
pdf_bytes,
pdf_model_output,
save_path,
book_name="",
pdf_model_profile=None,
image_s3_config=None,
imageWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
):
pdf_bytes_md5 = compute_md5(pdf_bytes)
save_tmp_path = os.path.join(os.path.dirname(__file__), "../..", "tmp", "unittest")
md_bookname_save_path = ""
if debug_mode:
book_name = sanitize_filename(book_name)
save_path = join_path(save_tmp_path, "md")
pdf_local_path = join_path(save_tmp_path, "download-pdfs", book_name)
if not os.path.exists(os.path.dirname(pdf_local_path)):
# 如果目录不存在,创建它
os.makedirs(os.path.dirname(pdf_local_path))
md_bookname_save_path = join_path(save_tmp_path, "md", book_name)
if not os.path.exists(md_bookname_save_path):
# 如果目录不存在,创建它
os.makedirs(md_bookname_save_path)
with open(pdf_local_path + ".pdf", "wb") as pdf_file:
pdf_file.write(pdf_bytes)
pdf_docs = fitz.open("pdf", pdf_bytes)
# 初始化空的pdf_info_dict
pdf_info_dict = {}
img_s3_client = get_img_s3_client(save_path, image_s3_config)
start_time = time.time()
......@@ -92,16 +61,14 @@ def parse_pdf_by_ocr(
# 获取当前页的模型数据
ocr_page_info = get_docx_model_output(
pdf_model_output, pdf_model_profile, page_id
pdf_model_output, page_id
)
"""从json中获取每页的页码、页眉、页脚的bbox"""
page_no_bboxes = parse_pageNos(page_id, page, ocr_page_info)
header_bboxes = parse_headers(page_id, page, ocr_page_info)
footer_bboxes = parse_footers(page_id, page, ocr_page_info)
footnote_bboxes = parse_footnotes_by_model(
page_id, page, ocr_page_info, md_bookname_save_path, debug_mode=debug_mode
)
footnote_bboxes = parse_footnotes_by_model(page_id, page, ocr_page_info, debug_mode=debug_mode)
# 构建需要remove的bbox字典
need_remove_spans_bboxes_dict = {
......@@ -180,9 +147,7 @@ def parse_pdf_by_ocr(
spans, dropped_spans_by_removed_bboxes = remove_spans_by_bboxes_dict(spans, need_remove_spans_bboxes_dict)
'''对image和table截图'''
if book_name == "":
book_name = pdf_bytes_md5
spans = cut_image_and_table(spans, page, page_id, book_name, save_path, img_s3_client)
spans = cut_image_and_table(spans, page, page_id, pdf_bytes_md5, imageWriter)
'''行内公式调整, 高度调整至与同行文字高度一致(优先左侧, 其次右侧)'''
displayed_list = []
......@@ -245,16 +210,4 @@ def parse_pdf_by_ocr(
"""分段"""
para_split(pdf_info_dict, debug_mode=debug_mode)
'''在测试时,保存调试信息'''
if debug_mode:
params_file_save_path = join_path(
save_tmp_path, "md", book_name, "preproc_out.json"
)
with open(params_file_save_path, "w", encoding="utf-8") as f:
json.dump(pdf_info_dict, f, ensure_ascii=False, indent=4)
# drow_bbox
draw_layout_bbox(pdf_info_dict, pdf_bytes, md_bookname_save_path)
draw_text_bbox(pdf_info_dict, pdf_bytes, md_bookname_save_path)
return pdf_info_dict
......@@ -73,13 +73,8 @@ paraMergeException_msg = ParaMergeException().message
def parse_pdf_by_txt(
pdf_bytes,
pdf_model_output,
save_path,
book_name,
pdf_model_profile=None,
image_s3_config=None,
start_page_id=0,
end_page_id=None,
junk_img_bojids=[],
debug_mode=False,
):
......@@ -128,22 +123,29 @@ def parse_pdf_by_txt(
# 对单页面非重复id的img数量做统计,如果当前页超过1500则直接return need_drop
"""
page_imgs = page.get_images()
img_counts = 0
for img in page_imgs:
img_bojid = img[0]
if img_bojid in junk_img_bojids: # 判断这个图片在不在junklist中
continue # 如果在junklist就不用管了,跳过
else:
recs = page.get_image_rects(img, transform=True)
if recs: # 如果这张图在当前页面有展示
img_counts += 1
if img_counts >= 1500: # 如果去除了junkimg的影响,单页img仍然超过1500的话,就排除当前pdf
logger.warning(
f"page_id: {page_id}, img_counts: {img_counts}, drop this pdf: {book_name}, drop_reason: {DropReason.HIGH_COMPUTATIONAL_lOAD_BY_IMGS}"
)
# 去除对junkimg的依赖,简化逻辑
if len(page_imgs) > 1500: # 如果当前页超过1500张图片,直接跳过
logger.warning(f"page_id: {page_id}, img_counts: {len(page_imgs)}, drop this pdf: {book_name}")
result = {"need_drop": True, "drop_reason": DropReason.HIGH_COMPUTATIONAL_lOAD_BY_IMGS}
if not debug_mode:
return result
# img_counts = 0
# for img in page_imgs:
# img_bojid = img[0]
# if img_bojid in junk_img_bojids: # 判断这个图片在不在junklist中
# continue # 如果在junklist就不用管了,跳过
# else:
# recs = page.get_image_rects(img, transform=True)
# if recs: # 如果这张图在当前页面有展示
# img_counts += 1
# if img_counts >= 1500: # 如果去除了junkimg的影响,单页img仍然超过1500的话,就排除当前pdf
# logger.warning(
# f"page_id: {page_id}, img_counts: {img_counts}, drop this pdf: {book_name}, drop_reason: {DropReason.HIGH_COMPUTATIONAL_lOAD_BY_IMGS}"
# )
# result = {"need_drop": True, "drop_reason": DropReason.HIGH_COMPUTATIONAL_lOAD_BY_IMGS}
# if not debug_mode:
# return result
"""
==================================================================================================================================
......@@ -154,10 +156,10 @@ def parse_pdf_by_txt(
"dict",
flags=fitz.TEXTFLAGS_TEXT,
)["blocks"]
model_output_json = get_docx_model_output(pdf_model_output, pdf_model_profile, page_id)
model_output_json = get_docx_model_output(pdf_model_output, page_id)
# 解析图片
image_bboxes = parse_images(page_id, page, model_output_json, junk_img_bojids)
image_bboxes = parse_images(page_id, page, model_output_json)
image_bboxes = fix_image_vertical(image_bboxes, text_raw_blocks) # 修正图片的位置
image_bboxes = fix_seperated_image(image_bboxes) # 合并有边重合的图片
image_bboxes = include_img_title(text_raw_blocks, image_bboxes) # 向图片上方和下方寻找title,使用规则进行匹配,暂时只支持英文规则
......
......@@ -112,7 +112,6 @@ def parse_pdf_for_train(
pdf_model_output,
save_path,
book_name,
pdf_model_profile=None,
image_s3_config=None,
start_page_id=0,
end_page_id=None,
......@@ -200,7 +199,7 @@ def parse_pdf_for_train(
flags=fitz.TEXTFLAGS_TEXT,
)["blocks"]
model_output_json = get_docx_model_output(
pdf_model_output, pdf_model_profile, page_id
pdf_model_output, page_id
)
# 解析图片
......
......@@ -3,7 +3,7 @@ from magic_pdf.libs.commons import fitz # pyMuPDF库
from magic_pdf.libs.coordinate_transform import get_scale_ratio
def parse_footnotes_by_model(page_ID: int, page: fitz.Page, json_from_DocXchain_obj: dict, md_bookname_save_path, debug_mode=False):
def parse_footnotes_by_model(page_ID: int, page: fitz.Page, json_from_DocXchain_obj: dict, md_bookname_save_path=None, debug_mode=False):
"""
:param page_ID: int类型,当前page在当前pdf文档中是第page_D页。
:param page :fitz读取的当前页的内容
......
......@@ -3,21 +3,16 @@ from magic_pdf.libs.ocr_content_type import ContentType
from magic_pdf.libs.pdf_image_tools import cut_image
def cut_image_and_table(spans, page, page_id, book_name, save_path, img_s3_client):
def cut_image_and_table(spans, page, page_id, pdf_bytes_md5, imageWriter):
"""spark环境book_name为pdf_bytes_md5,本地环境会传正常bookname"""
def s3_return_path(type):
return join_path(book_name, type)
def img_save_path(type):
return join_path(save_path, s3_return_path(type))
def return_path(type):
return join_path(pdf_bytes_md5, type)
for span in spans:
span_type = span['type']
if span_type == ContentType.Image:
span['image_path'] = cut_image(span['bbox'], page_id, page, img_save_path('images'), s3_return_path=s3_return_path('images'), img_s3_client=img_s3_client)
span['image_path'] = cut_image(span['bbox'], page_id, page, return_path=return_path('images'), imageWriter=imageWriter)
elif span_type == ContentType.Table:
span['image_path'] = cut_image(span['bbox'], page_id, page, img_save_path('tables'), s3_return_path=s3_return_path('tables'), img_s3_client=img_s3_client)
span['image_path'] = cut_image(span['bbox'], page_id, page, return_path=return_path('tables'), imageWriter=imageWriter)
return spans
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment