Commit faac0a4d authored by icecraft

fix: 1. resolve incorrect pair relation of figure and footnote, 2. resolve...

fix: 1. resolve incorrect pair relation of figure and footnote, 2. resolve incorrect pair relation of table and caption #590
parent 0140d7d2
...@@ -426,3 +426,22 @@ def bbox_distance(bbox1, bbox2): ...@@ -426,3 +426,22 @@ def bbox_distance(bbox1, bbox2):
elif top: elif top:
return y2 - y1b return y2 - y1b
return 0.0 return 0.0
def box_area(bbox):
    """Return the area of a box indexed as ``[x0, y0, x1, y1]``.

    Args:
        bbox: An indexable of at least four numbers; positions 0 and 1 are
            read as one corner and positions 2 and 3 as the opposite corner.

    Returns:
        ``(x1 - x0) * (y1 - y0)``. No validation is performed, so a box
        whose corners are swapped on one axis yields a negative value.
    """
    width = bbox[2] - bbox[0]
    height = bbox[3] - bbox[1]
    return width * height
def get_overlap_area(bbox1, bbox2):
    """Return the absolute area of the intersection of two boxes.

    Both boxes are indexed as ``[x0, y0, x1, y1]``. The result is an area,
    not a ratio: it is NOT normalised by the area of either box, so callers
    that need a proportion must divide it themselves.

    Args:
        bbox1: First box.
        bbox2: Second box.

    Returns:
        ``0.0`` when the boxes do not intersect; otherwise the width times
        the height of the intersection rectangle (which is ``0`` when the
        boxes merely share an edge or a corner).
    """
    # Corners of the intersection rectangle.
    x_left = max(bbox1[0], bbox2[0])
    y_top = max(bbox1[1], bbox2[1])
    x_right = min(bbox1[2], bbox2[2])
    y_bottom = min(bbox1[3], bbox2[3])

    # An inverted rectangle on either axis means the boxes are disjoint.
    if x_right < x_left or y_bottom < y_top:
        return 0.0

    return (x_right - x_left) * (y_bottom - y_top)
import json import json
from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance, from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
bbox_relative_pos, calculate_iou, bbox_relative_pos, box_area, calculate_iou,
calculate_overlap_area_in_bbox1_area_ratio) calculate_overlap_area_in_bbox1_area_ratio,
get_overlap_area)
from magic_pdf.libs.commons import fitz, join_path from magic_pdf.libs.commons import fitz, join_path
from magic_pdf.libs.coordinate_transform import get_scale_ratio from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.libs.local_math import float_gt from magic_pdf.libs.local_math import float_gt
...@@ -12,6 +13,7 @@ from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter ...@@ -12,6 +13,7 @@ from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
CAPATION_OVERLAP_AREA_RATIO = 0.6 CAPATION_OVERLAP_AREA_RATIO = 0.6
MERGE_BOX_OVERLAP_AREA_RATIO = 1.1
class MagicModel: class MagicModel:
...@@ -124,49 +126,51 @@ class MagicModel: ...@@ -124,49 +126,51 @@ class MagicModel:
tables.append(obj) tables.append(obj)
if len(footnotes) * len(figures) == 0: if len(footnotes) * len(figures) == 0:
continue continue
dis_figure_footnote = {} dis_figure_footnote = {}
dis_table_footnote = {} dis_table_footnote = {}
for i in range(len(footnotes)): for i in range(len(footnotes)):
for j in range(len(figures)): for j in range(len(figures)):
pos_flag_count = sum( pos_flag_count = sum(
list( list(
map( map(
lambda x: 1 if x else 0, lambda x: 1 if x else 0,
bbox_relative_pos( bbox_relative_pos(
footnotes[i]['bbox'], figures[j]['bbox'] footnotes[i]['bbox'], figures[j]['bbox']
), ),
)
) )
) )
if pos_flag_count > 1: )
continue if pos_flag_count > 1:
dis_figure_footnote[i] = min( continue
bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']), dis_figure_footnote[i] = min(
dis_figure_footnote.get(i, float('inf')), bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
) dis_figure_footnote.get(i, float('inf')),
for i in range(len(footnotes)): )
for j in range(len(tables)): for i in range(len(footnotes)):
pos_flag_count = sum( for j in range(len(tables)):
list( pos_flag_count = sum(
map( list(
lambda x: 1 if x else 0, map(
bbox_relative_pos( lambda x: 1 if x else 0,
footnotes[i]['bbox'], tables[j]['bbox'] bbox_relative_pos(
), footnotes[i]['bbox'], tables[j]['bbox']
) ),
) )
) )
if pos_flag_count > 1: )
continue if pos_flag_count > 1:
continue
dis_table_footnote[i] = min( dis_table_footnote[i] = min(
bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']), bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
dis_table_footnote.get(i, float('inf')), dis_table_footnote.get(i, float('inf')),
) )
for i in range(len(footnotes)): for i in range(len(footnotes)):
if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]: if i not in dis_figure_footnote:
footnotes[i]['category_id'] = CategoryId.ImageFootnote continue
if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
footnotes[i]['category_id'] = CategoryId.ImageFootnote
def __reduct_overlap(self, bboxes): def __reduct_overlap(self, bboxes):
N = len(bboxes) N = len(bboxes)
...@@ -191,6 +195,44 @@ class MagicModel: ...@@ -191,6 +195,44 @@ class MagicModel:
筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。 筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
再求出筛选出的 subjects 和 object 的最短距离 再求出筛选出的 subjects 和 object 的最短距离
""" """
def search_overlap_between_boxes(subject_idx, object_idx):
    """Measure how much unrelated layout content lies inside a subject/object pair.

    The bounding boxes of ``all_bboxes[subject_idx]`` and
    ``all_bboxes[object_idx]`` are merged into their common enclosing box.
    Every layout detection on the current page whose category is neither
    the subject nor the object category is then intersected with that
    merged box, and the intersection area is expressed relative to the
    area of the object box.

    Args:
        subject_idx: Index into ``all_bboxes`` of the subject box.
        object_idx: Index into ``all_bboxes`` of the object box.

    Returns:
        The largest ratio seen. Scanning stops at the first detection whose
        ratio reaches ``MERGE_BOX_OVERLAP_AREA_RATIO``, so the value is only
        guaranteed to be the true maximum when it is below that threshold.
        Returns ``0`` when the object box has no positive area.
    """
    subject_bbox = all_bboxes[subject_idx]['bbox']
    object_bbox = all_bboxes[object_idx]['bbox']

    # Smallest box that encloses both the subject and the object.
    merged_bbox = [
        min(subject_bbox[0], object_bbox[0]),
        min(subject_bbox[1], object_bbox[1]),
        max(subject_bbox[2], object_bbox[2]),
        max(subject_bbox[3], object_bbox[3]),
    ]

    # Loop-invariant denominator. A zero-area object box would raise
    # ZeroDivisionError below; a negative one can never push the running
    # maximum above its initial 0, so both cases report "no overlap".
    object_area = box_area(object_bbox)
    if object_area <= 0:
        return 0

    paired_categories = (object_category_id, subject_category_id)
    ratio = 0
    for layout_det in self.__model_list[page_no]['layout_dets']:
        # Only detections of some other category can separate the pair.
        if layout_det['category_id'] in paired_categories:
            continue
        overlap = get_overlap_area(merged_bbox, layout_det['bbox'])
        ratio = max(ratio, overlap * 1.0 / object_area)
        if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO:
            break

    return ratio
def may_find_other_nearest_bbox(subject_idx, object_idx): def may_find_other_nearest_bbox(subject_idx, object_idx):
ret = float('inf') ret = float('inf')
...@@ -299,6 +341,15 @@ class MagicModel: ...@@ -299,6 +341,15 @@ class MagicModel:
): ):
continue continue
subject_idx, object_idx = i, j
if all_bboxes[j]['category_id'] == subject_category_id:
subject_idx, object_idx = j, i
if search_overlap_between_boxes(subject_idx, object_idx) >= MERGE_BOX_OVERLAP_AREA_RATIO:
dis[i][j] = float('inf')
dis[j][i] = dis[i][j]
continue
dis[i][j] = bbox_distance(all_bboxes[i]['bbox'], all_bboxes[j]['bbox']) dis[i][j] = bbox_distance(all_bboxes[i]['bbox'], all_bboxes[j]['bbox'])
dis[j][i] = dis[i][j] dis[j][i] = dis[i][j]
...@@ -627,13 +678,13 @@ class MagicModel: ...@@ -627,13 +678,13 @@ class MagicModel:
span['type'] = ContentType.Image span['type'] = ContentType.Image
elif category_id == 5: elif category_id == 5:
# 获取table模型结果 # 获取table模型结果
latex = layout_det.get("latex", None) latex = layout_det.get('latex', None)
html = layout_det.get("html", None) html = layout_det.get('html', None)
if latex: if latex:
span["latex"] = latex span['latex'] = latex
elif html: elif html:
span["html"] = html span['html'] = html
span["type"] = ContentType.Table span['type'] = ContentType.Table
elif category_id == 13: elif category_id == 13:
span['content'] = layout_det['latex'] span['content'] = layout_det['latex']
span['type'] = ContentType.InlineEquation span['type'] = ContentType.InlineEquation
......
...@@ -46,7 +46,7 @@ def do_parse( ...@@ -46,7 +46,7 @@ def do_parse(
end_page_id=None, end_page_id=None,
): ):
if debug_able: if debug_able:
logger.warning("debug mode is on") logger.warning('debug mode is on')
f_dump_content_list = True f_dump_content_list = True
f_draw_model_bbox = True f_draw_model_bbox = True
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment