Commit 831db2e0 authored by myhloli's avatar myhloli

update:Complete the parsing logic of PEK

parent 1fac6aa7
import os
import time
import cv2
import numpy as np
import yaml
from PIL import Image
from ultralytics import YOLO
from loguru import logger
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
......@@ -9,7 +13,9 @@ import unimernet.tasks as tasks
from unimernet.processors import load_processor
import argparse
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
......@@ -31,6 +37,25 @@ def mfr_model_init(weight_dir, cfg_path, device='cpu'):
return model, vis_processor
class MathDataset(Dataset):
def __init__(self, image_paths, transform=None):
self.image_paths = image_paths
self.transform = transform
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
# if not pil image, then convert to pil image
if isinstance(self.image_paths[idx], str):
raw_image = Image.open(self.image_paths[idx])
else:
raw_image = self.image_paths[idx]
if self.transform:
image = self.transform(raw_image)
return image
class CustomPEKModel:
def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
"""
......@@ -82,6 +107,83 @@ class CustomPEKModel:
logger.info('DocAnalysis init done!')
def __call__(self, images):
# layout检测 + 公式检测
doc_layout_result = []
latex_filling_list = []
mf_image_list = []
for idx, img_dict in enumerate(images):
image = img_dict["img"]
img_height, img_width = img_dict["height"], img_dict["width"]
layout_res = self.layout_model(image, ignore_catids=[])
# 公式检测
mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
new_item = {
'category_id': 13 + int(cla.item()),
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
'score': round(float(conf.item()), 2),
'latex': '',
}
layout_res['layout_dets'].append(new_item)
latex_filling_list.append(new_item)
bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
mf_image_list.append(bbox_img)
layout_res['page_info'] = dict(
page_no=idx,
height=img_height,
width=img_width
)
doc_layout_result.append(layout_res)
# 公式识别,因为识别速度较慢,为了提速,把单个pdf的所有公式裁剪完,一起批量做识别。
a = time.time()
dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
dataloader = DataLoader(dataset, batch_size=128, num_workers=0)
mfr_res = []
for imgs in dataloader:
imgs = imgs.to(self.device)
output = self.mfr_model.generate({'image': imgs})
mfr_res.extend(output['pred_str'])
for res, latex in zip(latex_filling_list, mfr_res):
res['latex'] = latex_rm_whitespace(latex)
b = time.time()
logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {round(b - a, 2)}")
if self.apply_ocr:
# ocr识别
for idx, img_dict in enumerate(images):
image = img_dict["img"]
pil_img = Image.fromarray(image)
single_page_res = doc_layout_result[idx]['layout_dets']
single_page_mfdetrec_res = []
for res in single_page_res:
if int(res['category_id']) in [13, 14]:
xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
single_page_mfdetrec_res.append({
"bbox": [xmin, ymin, xmax, ymax],
})
for res in single_page_res:
if int(res['category_id']) in [0, 1, 2, 4, 6, 7]: # 需要进行ocr的类别
xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
crop_box = [xmin, ymin, xmax, ymax]
cropped_img = Image.new('RGB', pil_img.size, 'white')
cropped_img.paste(pil_img.crop(crop_box), crop_box)
cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
ocr_res = self.ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
if ocr_res:
for box_ocr_res in ocr_res:
p1, p2, p3, p4 = box_ocr_res[0]
text, score = box_ocr_res[1]
doc_layout_result[idx]['layout_dets'].append({
'category_id': 15,
'poly': p1 + p2 + p3 + p4,
'score': round(score, 2),
'text': text,
})
def __call__(self, image):
pass
return doc_layout_result
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment