Commit 2b8db660 authored by myhloli

update: Modify the PEK module to parse page by page.

parent 40802b79
import time
import fitz
import numpy as np
from loguru import logger
from magic_pdf.model.model_list import MODEL, MODEL_TYPE
from magic_pdf.model.model_list import MODEL
import magic_pdf.model as model_config
......@@ -44,9 +46,9 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
return images
def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, model=MODEL.PEK,
model_type=MODEL_TYPE.MULTI_PAGE):
def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, model=MODEL.PEK):
if model_config.__use_inside_model__:
model_init_start = time.time()
if model == MODEL.Paddle:
from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log)
......@@ -56,6 +58,8 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, mod
else:
logger.error("Not allow model_name!")
exit(1)
model_init_cost = time.time() - model_init_start
logger.info(f"model init cost: {model_init_cost}")
else:
logger.error("use_inside_model is False, not allow to use inside model")
exit(1)
......@@ -63,7 +67,7 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, mod
images = load_images_from_pdf(pdf_bytes)
model_json = []
if model_type == MODEL_TYPE.SINGLE_PAGE:
doc_analyze_start = time.time()
for index, img_dict in enumerate(images):
img = img_dict["img"]
page_width = img_dict["width"]
......@@ -72,7 +76,7 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, mod
page_info = {"page_no": index, "height": page_height, "width": page_width}
page_dict = {"layout_dets": result, "page_info": page_info}
model_json.append(page_dict)
elif model_type == MODEL_TYPE.MULTI_PAGE:
model_json = custom_model(images)
doc_analyze_cost = time.time() - doc_analyze_start
logger.info(f"doc analyze cost: {doc_analyze_cost}")
return model_json
class MODEL:
    """Names of the layout-analysis backends that can be requested by callers.

    The members are plain strings, so they can be compared with ``==`` against
    a value passed in as the ``model`` argument.
    """

    # Backend backed by the ``pp_structure_v2`` module (CustomPaddleModel).
    Paddle = "pp_structure_v2"
    # Backend named "pdf_extract_kit"; presumably served by CustomPEKModel —
    # TODO confirm against the dispatch code.
    PEK = "pdf_extract_kit"
class MODEL_TYPE:
    """Integer flags selecting how many pages a model call processes at once."""

    # Single-page parsing: the model is invoked once per page image.
    SINGLE_PAGE = 1
    # Multi-page parsing: the model receives the whole list of page images.
    MULTI_PAGE = 2
......@@ -107,15 +107,17 @@ class CustomPEKModel:
logger.info('DocAnalysis init done!')
def __call__(self, images):
# layout检测 + 公式检测
doc_layout_result = []
def __call__(self, image):
latex_filling_list = []
mf_image_list = []
for idx, img_dict in enumerate(images):
image = img_dict["img"]
img_height, img_width = img_dict["height"], img_dict["width"]
# layout检测
layout_start = time.time()
layout_res = self.layout_model(image, ignore_catids=[])
layout_cost = round(time.time() - layout_start, 2)
logger.info(f"layout detection cost: {layout_cost}")
# 公式检测
mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
......@@ -126,51 +128,42 @@ class CustomPEKModel:
'score': round(float(conf.item()), 2),
'latex': '',
}
layout_res['layout_dets'].append(new_item)
layout_res.append(new_item)
latex_filling_list.append(new_item)
bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
mf_image_list.append(bbox_img)
layout_res['page_info'] = dict(
page_no=idx,
height=img_height,
width=img_width
)
doc_layout_result.append(layout_res)
# 公式识别,因为识别速度较慢,为了提速,把单个pdf的所有公式裁剪完,一起批量做识别。
a = time.time()
# 公式识别
mfr_start = time.time()
dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
dataloader = DataLoader(dataset, batch_size=128, num_workers=0)
dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
mfr_res = []
for imgs in dataloader:
imgs = imgs.to(self.device)
output = self.mfr_model.generate({'image': imgs})
for mf_img in dataloader:
mf_img = mf_img.to(self.device)
output = self.mfr_model.generate({'image': mf_img})
mfr_res.extend(output['pred_str'])
for res, latex in zip(latex_filling_list, mfr_res):
res['latex'] = latex_rm_whitespace(latex)
b = time.time()
logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {round(b - a, 2)}")
mfr_cost = round(time.time() - mfr_start, 2)
logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
# ocr识别
if self.apply_ocr:
for idx, img_dict in enumerate(images):
image = img_dict["img"]
ocr_start = time.time()
pil_img = Image.fromarray(image)
single_page_res = doc_layout_result[idx]['layout_dets']
single_page_mfdetrec_res = []
for res in single_page_res:
for res in layout_res:
if int(res['category_id']) in [13, 14]:
xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
single_page_mfdetrec_res.append({
"bbox": [xmin, ymin, xmax, ymax],
})
for res in single_page_res:
for res in layout_res:
if int(res['category_id']) in [0, 1, 2, 4, 6, 7]: # 需要进行ocr的类别
xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
crop_box = [xmin, ymin, xmax, ymax]
crop_box = (xmin, ymin, xmax, ymax)
cropped_img = Image.new('RGB', pil_img.size, 'white')
cropped_img.paste(pil_img.crop(crop_box), crop_box)
cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
......@@ -179,11 +172,13 @@ class CustomPEKModel:
for box_ocr_res in ocr_res:
p1, p2, p3, p4 = box_ocr_res[0]
text, score = box_ocr_res[1]
doc_layout_result[idx]['layout_dets'].append({
layout_res.append({
'category_id': 15,
'poly': p1 + p2 + p3 + p4,
'score': round(score, 2),
'text': text,
})
ocr_cost = round(time.time() - ocr_start, 2)
logger.info(f"ocr cost: {ocr_cost}")
return doc_layout_result
return layout_res
......@@ -8,6 +8,7 @@ from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch, DefaultPredictor
def add_vit_config(cfg):
"""
Add config for VIT.
......@@ -98,6 +99,7 @@ class DotDict(dict):
# Attribute assignment is redirected to item assignment, so ``obj.key = value``
# stores the value under ``obj[key]`` instead of in the instance ``__dict__``.
# NOTE(review): the enclosing class appears to be ``DotDict(dict)`` — confirm.
def __setattr__(self, key, value):
self[key] = value
class Layoutlmv3_Predictor(object):
def __init__(self, weights, config_file):
layout_args = {
......@@ -113,14 +115,16 @@ class Layoutlmv3_Predictor(object):
layout_args = DotDict(layout_args)
cfg = setup(layout_args)
self.mapping = ["title", "plain text", "abandon", "figure", "figure_caption", "table", "table_caption", "table_footnote", "isolate_formula", "formula_caption"]
self.mapping = ["title", "plain text", "abandon", "figure", "figure_caption", "table", "table_caption",
"table_footnote", "isolate_formula", "formula_caption"]
MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes = self.mapping
self.predictor = DefaultPredictor(cfg)
def __call__(self, image, ignore_catids=[]):
page_layout_result = {
"layout_dets": []
}
# page_layout_result = {
# "layout_dets": []
# }
layout_dets = []
outputs = self.predictor(image)
boxes = outputs["instances"].to("cpu")._fields["pred_boxes"].tensor.tolist()
labels = outputs["instances"].to("cpu")._fields["pred_classes"].tolist()
......@@ -128,7 +132,7 @@ class Layoutlmv3_Predictor(object):
for bbox_idx in range(len(boxes)):
if labels[bbox_idx] in ignore_catids:
continue
page_layout_result["layout_dets"].append({
layout_dets.append({
"category_id": labels[bbox_idx],
"poly": [
boxes[bbox_idx][0], boxes[bbox_idx][1],
......@@ -138,4 +142,4 @@ class Layoutlmv3_Predictor(object):
],
"score": scores[bbox_idx]
})
return page_layout_result
\ No newline at end of file
return layout_dets
......@@ -136,9 +136,10 @@ class ModifiedPaddleOCR(PaddleOCR):
logger.error('When input a list of images, det must be false')
exit(0)
if cls == True and self.use_angle_cls == False:
logger.warning(
'Since the angle classifier is not initialized, it will not be used during the forward process'
)
pass
# logger.warning(
# 'Since the angle classifier is not initialized, it will not be used during the forward process'
# )
img = check_img(img)
# for infer pdf file
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment