Unverified commit 2502db13, authored by Xiaomeng Zhao and committed by GitHub

Merge pull request #374 from myhloli/master

fix&refactor(pdf-extract-kit):  table recognition and ocr
parents ad5596fc 334ccac2
...@@ -27,7 +27,7 @@ except ImportError as e: ...@@ -27,7 +27,7 @@ except ImportError as e:
logger.exception(e) logger.exception(e)
logger.error( logger.error(
'Required dependency not installed, please install by \n' 'Required dependency not installed, please install by \n'
'"pip install magic-pdf[full] detectron2 --extra-index-url https://myhloli.github.io/wheels/"') '"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
exit(1) exit(1)
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
...@@ -188,13 +188,9 @@ class CustomPEKModel: ...@@ -188,13 +188,9 @@ class CustomPEKModel:
mfr_cost = round(time.time() - mfr_start, 2) mfr_cost = round(time.time() - mfr_start, 2)
logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}") logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
# ocr识别 # Select regions for OCR / formula regions / table regions
if self.apply_ocr:
ocr_start = time.time()
pil_img = Image.fromarray(image)
# 筛选出需要OCR的区域和公式区域
ocr_res_list = [] ocr_res_list = []
table_res_list = []
single_page_mfdetrec_res = [] single_page_mfdetrec_res = []
for res in layout_res: for res in layout_res:
if int(res['category_id']) in [13, 14]: if int(res['category_id']) in [13, 14]:
...@@ -204,34 +200,44 @@ class CustomPEKModel: ...@@ -204,34 +200,44 @@ class CustomPEKModel:
}) })
elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]: elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
ocr_res_list.append(res) ocr_res_list.append(res)
elif int(res['category_id']) in [5]:
table_res_list.append(res)
# Unified crop img logic
def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
    """Cut one layout region out of the page and paste it onto a white canvas.

    Args:
        input_res: Layout result; its ``'poly'`` entry is read at indices
            0/1 (used as the left/top corner) and 4/5 (used as the
            right/bottom corner).
        input_pil_img: Page image the region is cropped from
            (presumably a PIL image -- confirm against the caller).
        crop_paste_x: Margin added on both the left and right of the crop.
        crop_paste_y: Margin added on both the top and bottom of the crop.

    Returns:
        A ``(canvas, info)`` pair, where ``info`` is the list
        ``[crop_paste_x, crop_paste_y, xmin, ymin, xmax, ymax,
        canvas_width, canvas_height]``.
    """
    poly = input_res['poly']
    left, top, right, bottom = (int(poly[idx]) for idx in (0, 1, 4, 5))

    # The canvas is the crop size plus the margin on each side; with the
    # default margins of 0 it is exactly the size of the cropped region.
    canvas_width = right - left + crop_paste_x * 2
    canvas_height = bottom - top + crop_paste_y * 2
    canvas = Image.new('RGB', (canvas_width, canvas_height), 'white')

    # Place the cropped region on the canvas, offset by the margins.
    region = input_pil_img.crop((left, top, right, bottom))
    canvas.paste(region, (crop_paste_x, crop_paste_y))

    return canvas, [
        crop_paste_x, crop_paste_y,
        left, top, right, bottom,
        canvas_width, canvas_height,
    ]
pil_img = Image.fromarray(image)
# 对每一个需OCR处理的区域进行处理 # ocr识别
if self.apply_ocr:
ocr_start = time.time()
# Process each area that requires OCR processing
for res in ocr_res_list: for res in ocr_res_list:
xmin, ymin = int(res['poly'][0]), int(res['poly'][1]) new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
xmax, ymax = int(res['poly'][4]), int(res['poly'][5]) paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
# Adjust the coordinates of the formula area
paste_x = 50
paste_y = 50
# 创建一个宽高各多50的白色背景
new_width = xmax - xmin + paste_x * 2
new_height = ymax - ymin + paste_y * 2
new_image = Image.new('RGB', (new_width, new_height), 'white')
# 裁剪图像
crop_box = (xmin, ymin, xmax, ymax)
cropped_img = pil_img.crop(crop_box)
new_image.paste(cropped_img, (paste_x, paste_y))
# 调整公式区域坐标
adjusted_mfdetrec_res = [] adjusted_mfdetrec_res = []
for mf_res in single_page_mfdetrec_res: for mf_res in single_page_mfdetrec_res:
mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"] mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
# 将公式区域坐标调整为相对于裁剪区域的坐标 # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
x0 = mf_xmin - xmin + paste_x x0 = mf_xmin - xmin + paste_x
y0 = mf_ymin - ymin + paste_y y0 = mf_ymin - ymin + paste_y
x1 = mf_xmax - xmin + paste_x x1 = mf_xmax - xmin + paste_x
y1 = mf_ymax - ymin + paste_y y1 = mf_ymax - ymin + paste_y
# 过滤在图外的公式块 # Filter formula blocks outside the graph
if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]): if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
continue continue
else: else:
...@@ -239,17 +245,17 @@ class CustomPEKModel: ...@@ -239,17 +245,17 @@ class CustomPEKModel:
"bbox": [x0, y0, x1, y1], "bbox": [x0, y0, x1, y1],
}) })
# OCR识别 # OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR) new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0] ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
# 整合结果 # Integration results
if ocr_res: if ocr_res:
for box_ocr_res in ocr_res: for box_ocr_res in ocr_res:
p1, p2, p3, p4 = box_ocr_res[0] p1, p2, p3, p4 = box_ocr_res[0]
text, score = box_ocr_res[1] text, score = box_ocr_res[1]
# 将坐标转换回原图坐标系 # Convert the coordinates back to the original coordinate system
p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin] p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin] p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin] p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
...@@ -267,35 +273,24 @@ class CustomPEKModel: ...@@ -267,35 +273,24 @@ class CustomPEKModel:
# 表格识别 table recognition # 表格识别 table recognition
if self.apply_table: if self.apply_table:
pil_img = Image.fromarray(image) table_start = time.time()
for layout in layout_res: for res in table_res_list:
if layout.get("category_id", -1) == 5: new_image, _ = crop_img(res, pil_img)
poly = layout["poly"] single_table_start_time = time.time()
xmin, ymin = int(poly[0]), int(poly[1])
xmax, ymax = int(poly[4]), int(poly[5])
paste_x = 50
paste_y = 50
# 创建一个宽高各多50的白色背景 create a whiteboard with 50 larger width and length
new_width = xmax - xmin + paste_x * 2
new_height = ymax - ymin + paste_y * 2
new_image = Image.new('RGB', (new_width, new_height), 'white')
# 裁剪图像 crop image
crop_box = (xmin, ymin, xmax, ymax)
cropped_img = pil_img.crop(crop_box)
new_image.paste(cropped_img, (paste_x, paste_y))
start_time = time.time()
logger.info("------------------table recognition processing begins-----------------") logger.info("------------------table recognition processing begins-----------------")
with torch.no_grad():
latex_code = self.table_model.image2latex(new_image)[0] latex_code = self.table_model.image2latex(new_image)[0]
end_time = time.time() run_time = time.time() - single_table_start_time
run_time = end_time - start_time
logger.info(f"------------table recognition processing ends within {run_time}s-----") logger.info(f"------------table recognition processing ends within {run_time}s-----")
if run_time > self.table_max_time: if run_time > self.table_max_time:
logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------") logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
# 判断是否返回正常 # 判断是否返回正常
if latex_code and latex_code.strip().endswith('end{tabular}'): expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
layout["latex"] = latex_code if latex_code and expected_ending:
res["latex"] = latex_code
else: else:
logger.warning(f"------------table recognition processing fails----------") logger.warning(f"------------table recognition processing fails----------")
table_cost = round(time.time() - table_start, 2)
logger.info(f"table cost: {table_cost}")
return layout_res return layout_res
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment