Commit e7ce3051 authored by myhloli's avatar myhloli

fix(magic_pdf): optimize formula area selection for OCR

parent 5f992de4
......@@ -168,33 +168,74 @@ class CustomPEKModel:
if self.apply_ocr:
ocr_start = time.time()
pil_img = Image.fromarray(image)
# 筛选出需要OCR的区域和公式区域
ocr_res_list = []
single_page_mfdetrec_res = []
for res in layout_res:
if int(res['category_id']) in [13, 14]:
xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
single_page_mfdetrec_res.append({
"bbox": [xmin, ymin, xmax, ymax],
"bbox": [int(res['poly'][0]), int(res['poly'][1]),
int(res['poly'][4]), int(res['poly'][5])],
})
for res in layout_res:
if int(res['category_id']) in [0, 1, 2, 4, 6, 7]: # 需要进行ocr的类别
xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
crop_box = (xmin, ymin, xmax, ymax)
cropped_img = Image.new('RGB', pil_img.size, 'white')
cropped_img.paste(pil_img.crop(crop_box), crop_box)
cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
ocr_res = self.ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
if ocr_res:
for box_ocr_res in ocr_res:
p1, p2, p3, p4 = box_ocr_res[0]
text, score = box_ocr_res[1]
layout_res.append({
'category_id': 15,
'poly': p1 + p2 + p3 + p4,
'score': round(score, 2),
'text': text,
})
elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
ocr_res_list.append(res)
# 对每一个需OCR处理的区域进行处理
for res in ocr_res_list:
xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
paste_x = 50
paste_y = 50
# 创建一个宽高各多50的白色背景
new_width = xmax - xmin + paste_x*2
new_height = ymax - ymin + paste_y*2
new_image = Image.new('RGB', (new_width, new_height), 'white')
# 裁剪图像
crop_box = (xmin, ymin, xmax, ymax)
cropped_img = pil_img.crop(crop_box)
new_image.paste(cropped_img, (paste_x, paste_y))
# 调整公式区域坐标
adjusted_mfdetrec_res = []
for mf_res in single_page_mfdetrec_res:
mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
# 将公式区域坐标调整为相对于裁剪区域的坐标
x0 = mf_xmin - xmin + paste_x
y0 = mf_ymin - ymin + paste_y
x1 = mf_xmax - xmin + paste_x
y1 = mf_ymax - ymin + paste_y
if any([x0 < 0, y0 < 0, x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height, x1 > new_width, y1 > new_height]):
continue
else:
adjusted_mfdetrec_res.append({
"bbox": [x0, y0, x1, y1],
})
# OCR识别
ocr_res = self.ocr_model.ocr(np.array(new_image), mfd_res=adjusted_mfdetrec_res)[0]
# 整合结果
if ocr_res:
for box_ocr_res in ocr_res:
p1, p2, p3, p4 = box_ocr_res[0]
text, score = box_ocr_res[1]
# 将坐标转换回原图坐标系
p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
layout_res.append({
'category_id': 15,
'poly': p1 + p2 + p3 + p4,
'score': round(score, 2),
'text': text,
})
ocr_cost = round(time.time() - ocr_start, 2)
logger.info(f"ocr cost: {ocr_cost}")
......
......@@ -10,12 +10,17 @@ from paddleocr import PaddleOCR
from paddleocr.ppocr.utils.logging import get_logger
from paddleocr.ppocr.utils.utility import check_and_read, alpha_to_color, binarize_img
from paddleocr.tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image, get_minarea_rect_crop
from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
logger = get_logger()
def img_decode(content: bytes):
np_arr = np.frombuffer(content, dtype=np.uint8)
return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
def check_img(img):
if isinstance(img, bytes):
img = img_decode(img)
......@@ -51,6 +56,7 @@ def check_img(img):
return img
def sorted_boxes(dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
......@@ -75,49 +81,87 @@ def sorted_boxes(dt_boxes):
return _boxes
def formula_in_text(mf_bbox, text_bbox):
x1, y1, x2, y2 = mf_bbox
x3, y3 = text_bbox[0]
x4, y4 = text_bbox[2]
left_box, right_box = None, None
same_line = abs((y1+y2)/2 - (y3+y4)/2) / abs(y4-y3) < 0.2
if not same_line:
return False, left_box, right_box
else:
drop_origin = False
left_x = x1 - 1
right_x = x2 + 1
if x3 < x1 and x2 < x4:
drop_origin = True
left_box = np.array([text_bbox[0], [left_x, text_bbox[1][1]], [left_x, text_bbox[2][1]], text_bbox[3]]).astype('float32')
right_box = np.array([[right_x, text_bbox[0][1]], text_bbox[1], text_bbox[2], [right_x, text_bbox[3][1]]]).astype('float32')
if x3 < x1 and x1 <= x4 <= x2:
drop_origin = True
left_box = np.array([text_bbox[0], [left_x, text_bbox[1][1]], [left_x, text_bbox[2][1]], text_bbox[3]]).astype('float32')
if x1 <= x3 <= x2 and x2 < x4:
drop_origin = True
right_box = np.array([[right_x, text_bbox[0][1]], text_bbox[1], text_bbox[2], [right_x, text_bbox[3][1]]]).astype('float32')
if x1 <= x3 < x4 <= x2:
drop_origin = True
return drop_origin, left_box, right_box
def update_det_boxes(dt_boxes, mfdetrec_res):
new_dt_boxes = dt_boxes
for mf_box in mfdetrec_res:
flag, left_box, right_box = False, None, None
for idx, text_box in enumerate(new_dt_boxes):
ret, left_box, right_box = formula_in_text(mf_box['bbox'], text_box)
if ret:
new_dt_boxes.pop(idx)
if left_box is not None:
new_dt_boxes.append(left_box)
if right_box is not None:
new_dt_boxes.append(right_box)
break
def bbox_to_points(bbox):
""" 将bbox格式转换为四个顶点的数组 """
x0, y0, x1, y1 = bbox
return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32')
def points_to_bbox(points):
""" 将四个顶点的数组转换为bbox格式 """
x0, y0 = points[0]
x1, _ = points[1]
_, y1 = points[2]
return [x0, y0, x1, y1]
def merge_intervals(intervals):
# Sort the intervals based on the start value
intervals.sort(key=lambda x: x[0])
merged = []
for interval in intervals:
# If the list of merged intervals is empty or if the current
# interval does not overlap with the previous, simply append it.
if not merged or merged[-1][1] < interval[0]:
merged.append(interval)
else:
# Otherwise, there is overlap, so we merge the current and previous intervals.
merged[-1][1] = max(merged[-1][1], interval[1])
return merged
def remove_intervals(original, masks):
# Merge all mask intervals
merged_masks = merge_intervals(masks)
result = []
original_start, original_end = original
for mask in merged_masks:
mask_start, mask_end = mask
# If the mask starts after the original range, ignore it
if mask_start > original_end:
continue
# If the mask ends before the original range starts, ignore it
if mask_end < original_start:
continue
# Remove the masked part from the original range
if original_start < mask_start:
result.append([original_start, mask_start - 1])
original_start = max(mask_end + 1, original_start)
# Add the remaining part of the original range, if any
if original_start <= original_end:
result.append([original_start, original_end])
return result
def update_det_boxes(dt_boxes, mfd_res):
new_dt_boxes = []
for text_box in dt_boxes:
text_bbox = points_to_bbox(text_box)
masks_list = []
for mf_box in mfd_res:
mf_bbox = mf_box['bbox']
if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox):
masks_list.append([mf_bbox[0], mf_bbox[2]])
text_x_range = [text_bbox[0], text_bbox[2]]
text_remove_mask_range = remove_intervals(text_x_range, masks_list)
temp_dt_box = []
for text_remove_mask in text_remove_mask_range:
temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]]))
if len(temp_dt_box) > 0:
new_dt_boxes.extend(temp_dt_box)
return new_dt_boxes
class ModifiedPaddleOCR(PaddleOCR):
def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, mfd_res=None, alpha_color=(255, 255, 255)):
"""
......@@ -197,7 +241,7 @@ class ModifiedPaddleOCR(PaddleOCR):
if not rec:
return cls_res
return ocr_res
def __call__(self, img, cls=True, mfd_res=None):
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
......@@ -226,7 +270,7 @@ class ModifiedPaddleOCR(PaddleOCR):
dt_boxes = update_det_boxes(dt_boxes, mfd_res)
aft = time.time()
logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
len(dt_boxes), aft-bef))
len(dt_boxes), aft - bef))
for bno in range(len(dt_boxes)):
tmp_box = copy.deepcopy(dt_boxes[bno])
......@@ -257,4 +301,60 @@ class ModifiedPaddleOCR(PaddleOCR):
filter_rec_res.append(rec_result)
end = time.time()
time_dict['all'] = end - start
return filter_boxes, filter_rec_res, time_dict
\ No newline at end of file
return filter_boxes, filter_rec_res, time_dict
if __name__ == '__main__':
def merge_intervals(intervals):
# Sort the intervals based on the start value
intervals.sort(key=lambda x: x[0])
merged = []
for interval in intervals:
# If the list of merged intervals is empty or if the current
# interval does not overlap with the previous, simply append it.
if not merged or merged[-1][1] < interval[0]:
merged.append(interval)
else:
# Otherwise, there is overlap, so we merge the current and previous intervals.
merged[-1][1] = max(merged[-1][1], interval[1])
return merged
def remove_intervals(original, masks):
# Merge all mask intervals
merged_masks = merge_intervals(masks)
result = []
original_start, original_end = original
for mask in merged_masks:
mask_start, mask_end = mask
# If the mask starts after the original range, ignore it
if mask_start > original_end:
continue
# If the mask ends before the original range starts, ignore it
if mask_end < original_start:
continue
# Remove the masked part from the original range
if original_start < mask_start:
result.append([original_start, mask_start - 1])
original_start = max(mask_end + 1, original_start)
# Add the remaining part of the original range, if any
if original_start <= original_end:
result.append([original_start, original_end])
return result
# Test the function
original_range = [1, 100]
masks = [[0, 15], [25, 40], [55, 80]]
result = remove_intervals(original_range, masks)
print(result) # Expected output: [[1, 4], [21, 59], [81, 100]]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment