Commit 4c39bcd3 authored by 赵小蒙's avatar 赵小蒙

feat(magic-pdf): add conditional application of formula detection and recognition

parent bba53839
...@@ -141,34 +141,35 @@ class CustomPEKModel: ...@@ -141,34 +141,35 @@ class CustomPEKModel:
layout_cost = round(time.time() - layout_start, 2) layout_cost = round(time.time() - layout_start, 2)
logger.info(f"layout detection cost: {layout_cost}") logger.info(f"layout detection cost: {layout_cost}")
# 公式检测 if self.apply_formula:
mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0] # 公式检测
for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()): mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy] for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
new_item = { xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
'category_id': 13 + int(cla.item()), new_item = {
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax], 'category_id': 13 + int(cla.item()),
'score': round(float(conf.item()), 2), 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
'latex': '', 'score': round(float(conf.item()), 2),
} 'latex': '',
layout_res.append(new_item) }
latex_filling_list.append(new_item) layout_res.append(new_item)
bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax]) latex_filling_list.append(new_item)
mf_image_list.append(bbox_img) bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
mf_image_list.append(bbox_img)
# 公式识别
mfr_start = time.time() # 公式识别
dataset = MathDataset(mf_image_list, transform=self.mfr_transform) mfr_start = time.time()
dataloader = DataLoader(dataset, batch_size=64, num_workers=0) dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
mfr_res = [] dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
for mf_img in dataloader: mfr_res = []
mf_img = mf_img.to(self.device) for mf_img in dataloader:
output = self.mfr_model.generate({'image': mf_img}) mf_img = mf_img.to(self.device)
mfr_res.extend(output['pred_str']) output = self.mfr_model.generate({'image': mf_img})
for res, latex in zip(latex_filling_list, mfr_res): mfr_res.extend(output['pred_str'])
res['latex'] = latex_rm_whitespace(latex) for res, latex in zip(latex_filling_list, mfr_res):
mfr_cost = round(time.time() - mfr_start, 2) res['latex'] = latex_rm_whitespace(latex)
logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}") mfr_cost = round(time.time() - mfr_start, 2)
logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
# ocr识别 # ocr识别
if self.apply_ocr: if self.apply_ocr:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment