Commit 2b8db660 authored by myhloli

update: Modify the PEK module to parse page by page.

parent 40802b79
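In short: doc_analyze no longer takes a model_type argument, CustomPEKModel.__call__ now receives a single page image instead of the whole image list, and per-stage timings (model init, layout, formula recognition, OCR) are logged. A minimal caller sketch, assuming magic-pdf's inside models are enabled; the import path of doc_analyze is an assumption for illustration, not part of this diff:

from magic_pdf.model.model_list import MODEL
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze  # module path assumed

with open("demo.pdf", "rb") as f:  # hypothetical input file
    pdf_bytes = f.read()

# model_type is gone from the signature; every model now runs page by page inside doc_analyze
model_json = doc_analyze(pdf_bytes, ocr=True, show_log=False, model=MODEL.PEK)

for page in model_json:
    print(page["page_info"]["page_no"], len(page["layout_dets"]))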
+import time
 import fitz
 import numpy as np
 from loguru import logger
-from magic_pdf.model.model_list import MODEL, MODEL_TYPE
+from magic_pdf.model.model_list import MODEL
 import magic_pdf.model as model_config
@@ -44,9 +46,9 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
     return images

-def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, model=MODEL.PEK,
-                model_type=MODEL_TYPE.MULTI_PAGE):
+def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, model=MODEL.PEK):
     if model_config.__use_inside_model__:
+        model_init_start = time.time()
         if model == MODEL.Paddle:
             from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
             custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log)
@@ -56,6 +58,8 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, mod
         else:
             logger.error("Not allow model_name!")
             exit(1)
+        model_init_cost = time.time() - model_init_start
+        logger.info(f"model init cost: {model_init_cost}")
     else:
         logger.error("use_inside_model is False, not allow to use inside model")
         exit(1)
@@ -63,16 +67,16 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, mod
     images = load_images_from_pdf(pdf_bytes)
     model_json = []
-    if model_type == MODEL_TYPE.SINGLE_PAGE:
-        for index, img_dict in enumerate(images):
-            img = img_dict["img"]
-            page_width = img_dict["width"]
-            page_height = img_dict["height"]
-            result = custom_model(img)
-            page_info = {"page_no": index, "height": page_height, "width": page_width}
-            page_dict = {"layout_dets": result, "page_info": page_info}
-            model_json.append(page_dict)
-    elif model_type == MODEL_TYPE.MULTI_PAGE:
-        model_json = custom_model(images)
+    doc_analyze_start = time.time()
+    for index, img_dict in enumerate(images):
+        img = img_dict["img"]
+        page_width = img_dict["width"]
+        page_height = img_dict["height"]
+        result = custom_model(img)
+        page_info = {"page_no": index, "height": page_height, "width": page_width}
+        page_dict = {"layout_dets": result, "page_info": page_info}
+        model_json.append(page_dict)
+    doc_analyze_cost = time.time() - doc_analyze_start
+    logger.info(f"doc analyze cost: {doc_analyze_cost}")
     return model_json
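Each model_json entry now corresponds to one page and is assembled in the loop above. A sketch of the resulting shape, with made-up values:

# Illustrative only: one entry of model_json for a single page
page_dict = {
    "layout_dets": [
        # produced by CustomPEKModel.__call__ for this page: layout blocks,
        # formulas (category_id 13/14, with 'latex') and, when OCR is enabled,
        # text lines (category_id 15, with 'text')
        {"category_id": 1, "poly": [90, 120, 980, 120, 980, 360, 90, 360], "score": 0.97},
    ],
    "page_info": {"page_no": 0, "height": 2200, "width": 1700},
}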
 class MODEL:
     Paddle = "pp_structure_v2"
     PEK = "pdf_extract_kit"

-class MODEL_TYPE:
-    # single-page parsing
-    SINGLE_PAGE = 1
-    # multi-page parsing
-    MULTI_PAGE = 2
@@ -107,83 +107,78 @@ class CustomPEKModel:
         logger.info('DocAnalysis init done!')

-    def __call__(self, images):
-        # layout detection + formula detection
-        doc_layout_result = []
+    def __call__(self, image):
         latex_filling_list = []
         mf_image_list = []
-        for idx, img_dict in enumerate(images):
-            image = img_dict["img"]
-            img_height, img_width = img_dict["height"], img_dict["width"]
-            layout_res = self.layout_model(image, ignore_catids=[])
-            # formula detection
-            mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
-            for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
-                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
-                new_item = {
-                    'category_id': 13 + int(cla.item()),
-                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-                    'score': round(float(conf.item()), 2),
-                    'latex': '',
-                }
-                layout_res['layout_dets'].append(new_item)
-                latex_filling_list.append(new_item)
-                bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
-                mf_image_list.append(bbox_img)
-            layout_res['page_info'] = dict(
-                page_no=idx,
-                height=img_height,
-                width=img_width
-            )
-            doc_layout_result.append(layout_res)
-        # formula recognition: recognition is slow, so to speed it up, all formulas in a pdf are cropped first and recognized together in one batch.
-        a = time.time()
+        # layout detection
+        layout_start = time.time()
+        layout_res = self.layout_model(image, ignore_catids=[])
+        layout_cost = round(time.time() - layout_start, 2)
+        logger.info(f"layout detection cost: {layout_cost}")
+        # formula detection
+        mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
+        for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
+            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+            new_item = {
+                'category_id': 13 + int(cla.item()),
+                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                'score': round(float(conf.item()), 2),
+                'latex': '',
+            }
+            layout_res.append(new_item)
+            latex_filling_list.append(new_item)
+            bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
+            mf_image_list.append(bbox_img)
+        # formula recognition
+        mfr_start = time.time()
         dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
-        dataloader = DataLoader(dataset, batch_size=128, num_workers=0)
+        dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
         mfr_res = []
-        for imgs in dataloader:
-            imgs = imgs.to(self.device)
-            output = self.mfr_model.generate({'image': imgs})
+        for mf_img in dataloader:
+            mf_img = mf_img.to(self.device)
+            output = self.mfr_model.generate({'image': mf_img})
             mfr_res.extend(output['pred_str'])
         for res, latex in zip(latex_filling_list, mfr_res):
             res['latex'] = latex_rm_whitespace(latex)
-        b = time.time()
-        logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {round(b - a, 2)}")
+        mfr_cost = round(time.time() - mfr_start, 2)
+        logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")

         # ocr recognition
         if self.apply_ocr:
-            for idx, img_dict in enumerate(images):
-                image = img_dict["img"]
-                pil_img = Image.fromarray(image)
-                single_page_res = doc_layout_result[idx]['layout_dets']
-                single_page_mfdetrec_res = []
-                for res in single_page_res:
-                    if int(res['category_id']) in [13, 14]:
-                        xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
-                        xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
-                        single_page_mfdetrec_res.append({
-                            "bbox": [xmin, ymin, xmax, ymax],
-                        })
-                for res in single_page_res:
-                    if int(res['category_id']) in [0, 1, 2, 4, 6, 7]:  # categories that need OCR
-                        xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
-                        xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
-                        crop_box = [xmin, ymin, xmax, ymax]
-                        cropped_img = Image.new('RGB', pil_img.size, 'white')
-                        cropped_img.paste(pil_img.crop(crop_box), crop_box)
-                        cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
-                        ocr_res = self.ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
-                        if ocr_res:
-                            for box_ocr_res in ocr_res:
-                                p1, p2, p3, p4 = box_ocr_res[0]
-                                text, score = box_ocr_res[1]
-                                doc_layout_result[idx]['layout_dets'].append({
-                                    'category_id': 15,
-                                    'poly': p1 + p2 + p3 + p4,
-                                    'score': round(score, 2),
-                                    'text': text,
-                                })
+            ocr_start = time.time()
+            pil_img = Image.fromarray(image)
+            single_page_mfdetrec_res = []
+            for res in layout_res:
+                if int(res['category_id']) in [13, 14]:
+                    xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
+                    xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
+                    single_page_mfdetrec_res.append({
+                        "bbox": [xmin, ymin, xmax, ymax],
+                    })
+            for res in layout_res:
+                if int(res['category_id']) in [0, 1, 2, 4, 6, 7]:  # categories that need OCR
+                    xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
+                    xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
+                    crop_box = (xmin, ymin, xmax, ymax)
+                    cropped_img = Image.new('RGB', pil_img.size, 'white')
+                    cropped_img.paste(pil_img.crop(crop_box), crop_box)
+                    cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
+                    ocr_res = self.ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
+                    if ocr_res:
+                        for box_ocr_res in ocr_res:
+                            p1, p2, p3, p4 = box_ocr_res[0]
+                            text, score = box_ocr_res[1]
+                            layout_res.append({
+                                'category_id': 15,
+                                'poly': p1 + p2 + p3 + p4,
+                                'score': round(score, 2),
+                                'text': text,
+                            })
+            ocr_cost = round(time.time() - ocr_start, 2)
+            logger.info(f"ocr cost: {ocr_cost}")

-        return doc_layout_result
+        return layout_res
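With this change CustomPEKModel is invoked once per page and returns a flat list of detections; the caller (doc_analyze) attaches page_info itself. A minimal sketch of the new contract; the pek instance and the zero-filled ndarray are assumptions standing in for a real model and the "img" array produced by load_images_from_pdf:

import numpy as np

# "pek" is assumed to be an already-initialized CustomPEKModel (constructor arguments omitted here)
page_img = np.zeros((2200, 1700, 3), dtype=np.uint8)  # stand-in for one rendered PDF page

layout_res = pek(page_img)  # flat list of dicts; formulas carry 'latex', OCR lines carry 'text'
for det in layout_res:
    print(det["category_id"], det["score"])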
@@ -8,6 +8,7 @@ from detectron2.data import MetadataCatalog, DatasetCatalog
 from detectron2.data.datasets import register_coco_instances
 from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch, DefaultPredictor

 def add_vit_config(cfg):
     """
     Add config for VIT.
@@ -72,14 +73,14 @@ def setup(args):
     cfg.merge_from_list(args.opts)
     cfg.freeze()
     default_setup(cfg, args)
     register_coco_instances(
         "scihub_train",
         {},
         cfg.SCIHUB_DATA_DIR_TRAIN + ".json",
         cfg.SCIHUB_DATA_DIR_TRAIN
     )
     return cfg
@@ -94,10 +95,11 @@ class DotDict(dict):
         if isinstance(value, dict):
             value = DotDict(value)
         return value

     def __setattr__(self, key, value):
         self[key] = value

 class Layoutlmv3_Predictor(object):
     def __init__(self, weights, config_file):
         layout_args = {
@@ -113,14 +115,16 @@ class Layoutlmv3_Predictor(object):
         layout_args = DotDict(layout_args)
         cfg = setup(layout_args)
-        self.mapping = ["title", "plain text", "abandon", "figure", "figure_caption", "table", "table_caption", "table_footnote", "isolate_formula", "formula_caption"]
+        self.mapping = ["title", "plain text", "abandon", "figure", "figure_caption", "table", "table_caption",
+                        "table_footnote", "isolate_formula", "formula_caption"]
         MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes = self.mapping
         self.predictor = DefaultPredictor(cfg)

     def __call__(self, image, ignore_catids=[]):
-        page_layout_result = {
-            "layout_dets": []
-        }
+        # page_layout_result = {
+        #     "layout_dets": []
+        # }
+        layout_dets = []
         outputs = self.predictor(image)
         boxes = outputs["instances"].to("cpu")._fields["pred_boxes"].tensor.tolist()
         labels = outputs["instances"].to("cpu")._fields["pred_classes"].tolist()
@@ -128,7 +132,7 @@ class Layoutlmv3_Predictor(object):
         for bbox_idx in range(len(boxes)):
             if labels[bbox_idx] in ignore_catids:
                 continue
-            page_layout_result["layout_dets"].append({
+            layout_dets.append({
                 "category_id": labels[bbox_idx],
                 "poly": [
                     boxes[bbox_idx][0], boxes[bbox_idx][1],
@@ -138,4 +142,4 @@ class Layoutlmv3_Predictor(object):
                 ],
                 "score": scores[bbox_idx]
             })
-        return page_layout_result
+        return layout_dets
\ No newline at end of file
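Layoutlmv3_Predictor.__call__ now returns a bare list instead of a {"layout_dets": [...]} dict, which is why the CustomPEKModel changes earlier in this commit append formula and OCR items directly onto layout_res. Any other caller that read the old key needs the small adjustment sketched here; layout_model and image are placeholder names for an initialized predictor and a page image:

# Before this commit:
#     result = layout_model(image, ignore_catids=[])
#     dets = result["layout_dets"]
# After this commit the predictor returns the list directly:
dets = layout_model(image, ignore_catids=[])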
@@ -136,9 +136,10 @@ class ModifiedPaddleOCR(PaddleOCR):
             logger.error('When input a list of images, det must be false')
             exit(0)
         if cls == True and self.use_angle_cls == False:
-            logger.warning(
-                'Since the angle classifier is not initialized, it will not be used during the forward process'
-            )
+            pass
+            # logger.warning(
+            #     'Since the angle classifier is not initialized, it will not be used during the forward process'
+            # )
         img = check_img(img)
         # for infer pdf file
......