Unverified Commit 2502db13 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub

Merge pull request #374 from myhloli/master

fix&refactor(pdf-extract-kit):  table recognition and ocr
parents ad5596fc 334ccac2
......@@ -27,7 +27,7 @@ except ImportError as e:
logger.exception(e)
logger.error(
'Required dependency not installed, please install by \n'
'"pip install magic-pdf[full] detectron2 --extra-index-url https://myhloli.github.io/wheels/"')
'"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
exit(1)
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
......@@ -188,50 +188,56 @@ class CustomPEKModel:
mfr_cost = round(time.time() - mfr_start, 2)
logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
# Select regions for OCR / formula regions / table regions
ocr_res_list = []
table_res_list = []
single_page_mfdetrec_res = []
for res in layout_res:
if int(res['category_id']) in [13, 14]:
single_page_mfdetrec_res.append({
"bbox": [int(res['poly'][0]), int(res['poly'][1]),
int(res['poly'][4]), int(res['poly'][5])],
})
elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
ocr_res_list.append(res)
elif int(res['category_id']) in [5]:
table_res_list.append(res)
# Unified crop img logic
def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
# Create a white background with an additional width and height of 50
crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
# Crop image
crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
cropped_img = input_pil_img.crop(crop_box)
return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
return return_image, return_list
pil_img = Image.fromarray(image)
# ocr识别
if self.apply_ocr:
ocr_start = time.time()
pil_img = Image.fromarray(image)
# 筛选出需要OCR的区域和公式区域
ocr_res_list = []
single_page_mfdetrec_res = []
for res in layout_res:
if int(res['category_id']) in [13, 14]:
single_page_mfdetrec_res.append({
"bbox": [int(res['poly'][0]), int(res['poly'][1]),
int(res['poly'][4]), int(res['poly'][5])],
})
elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
ocr_res_list.append(res)
# 对每一个需OCR处理的区域进行处理
# Process each area that requires OCR processing
for res in ocr_res_list:
xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
paste_x = 50
paste_y = 50
# 创建一个宽高各多50的白色背景
new_width = xmax - xmin + paste_x * 2
new_height = ymax - ymin + paste_y * 2
new_image = Image.new('RGB', (new_width, new_height), 'white')
# 裁剪图像
crop_box = (xmin, ymin, xmax, ymax)
cropped_img = pil_img.crop(crop_box)
new_image.paste(cropped_img, (paste_x, paste_y))
# 调整公式区域坐标
new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
# Adjust the coordinates of the formula area
adjusted_mfdetrec_res = []
for mf_res in single_page_mfdetrec_res:
mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
# 将公式区域坐标调整为相对于裁剪区域的坐标
# Adjust the coordinates of the formula area to the coordinates relative to the cropping area
x0 = mf_xmin - xmin + paste_x
y0 = mf_ymin - ymin + paste_y
x1 = mf_xmax - xmin + paste_x
y1 = mf_ymax - ymin + paste_y
# 过滤在图外的公式块
# Filter formula blocks outside the graph
if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
continue
else:
......@@ -239,17 +245,17 @@ class CustomPEKModel:
"bbox": [x0, y0, x1, y1],
})
# OCR识别
# OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
# 整合结果
# Integration results
if ocr_res:
for box_ocr_res in ocr_res:
p1, p2, p3, p4 = box_ocr_res[0]
text, score = box_ocr_res[1]
# 将坐标转换回原图坐标系
# Convert the coordinates back to the original coordinate system
p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
......@@ -267,35 +273,24 @@ class CustomPEKModel:
# 表格识别 table recognition
if self.apply_table:
pil_img = Image.fromarray(image)
for layout in layout_res:
if layout.get("category_id", -1) == 5:
poly = layout["poly"]
xmin, ymin = int(poly[0]), int(poly[1])
xmax, ymax = int(poly[4]), int(poly[5])
paste_x = 50
paste_y = 50
# 创建一个宽高各多50的白色背景 create a whiteboard with 50 larger width and length
new_width = xmax - xmin + paste_x * 2
new_height = ymax - ymin + paste_y * 2
new_image = Image.new('RGB', (new_width, new_height), 'white')
# 裁剪图像 crop image
crop_box = (xmin, ymin, xmax, ymax)
cropped_img = pil_img.crop(crop_box)
new_image.paste(cropped_img, (paste_x, paste_y))
start_time = time.time()
logger.info("------------------table recognition processing begins-----------------")
table_start = time.time()
for res in table_res_list:
new_image, _ = crop_img(res, pil_img)
single_table_start_time = time.time()
logger.info("------------------table recognition processing begins-----------------")
with torch.no_grad():
latex_code = self.table_model.image2latex(new_image)[0]
end_time = time.time()
run_time = end_time - start_time
logger.info(f"------------table recognition processing ends within {run_time}s-----")
if run_time > self.table_max_time:
logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
# 判断是否返回正常
if latex_code and latex_code.strip().endswith('end{tabular}'):
layout["latex"] = latex_code
else:
logger.warning(f"------------table recognition processing fails----------")
run_time = time.time() - single_table_start_time
logger.info(f"------------table recognition processing ends within {run_time}s-----")
if run_time > self.table_max_time:
logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
# 判断是否返回正常
expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
if latex_code and expected_ending:
res["latex"] = latex_code
else:
logger.warning(f"------------table recognition processing fails----------")
table_cost = round(time.time() - table_start, 2)
logger.info(f"table cost: {table_cost}")
return layout_res
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment