Unverified commit 763688c0, authored by Xiaomeng Zhao and committed by GitHub

Merge pull request #695 from icecraft/fix/caption_match

fix: caption or footnote match algorithm
parents 3458f85a ef45ad08
...@@ -110,6 +110,26 @@ class MagicModel: ...@@ -110,6 +110,26 @@ class MagicModel:
self.__fix_by_remove_high_iou_and_low_confidence() self.__fix_by_remove_high_iou_and_low_confidence()
self.__fix_footnote() self.__fix_footnote()
def _bbox_distance(self, bbox1, bbox2):
    """Return the distance between two boxes, or infinity for implausible pairs.

    Wraps ``bbox_distance`` with two rejection rules based on the flags
    returned by ``bbox_relative_pos(bbox1, bbox2)``:

    * more than one of the (left, right, bottom, top) flags is set;
    * the boxes are side by side (left/right) and their extents along
      index 1..3 differ by more than 50% of the smaller one, or they are
      stacked (bottom/top) and their extents along index 0..2 differ by
      more than 50% of the smaller one.

    Args:
        bbox1: Box indexable as ``[x0, y0, x1, y1]``.
        bbox2: Box indexable as ``[x0, y0, x1, y1]``.

    Returns:
        ``float('inf')`` when a rejection rule applies, otherwise the value
        of ``bbox_distance(bbox1, bbox2)``.
    """
    left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)

    # A pair flagged in more than one direction is never matched.
    if sum(1 for flag in (left, right, bottom, top) if flag) > 1:
        return float('inf')

    # At most one flag is set here, so the two cases are mutually exclusive.
    if left or right:
        lo, hi = 1, 3  # side by side: compare extents along the y axis
    elif bottom or top:
        lo, hi = 0, 2  # stacked: compare extents along the x axis
    else:
        # No direction flag set: nothing to compare, use the raw distance.
        return bbox_distance(bbox1, bbox2)

    len1 = bbox1[hi] - bbox1[lo]
    len2 = bbox2[hi] - bbox2[lo]
    min_len, max_len = min(len1, len2), max(len1, len2)

    if min_len == 0:
        # Degenerate box: the relative difference is unbounded unless both
        # extents are zero. Guarding here avoids a ZeroDivisionError.
        if max_len > 0:
            return float('inf')
    elif (max_len - min_len) / min_len > 0.5:
        return float('inf')

    return bbox_distance(bbox1, bbox2)
def __fix_footnote(self): def __fix_footnote(self):
# 3: figure, 5: table, 7: footnote # 3: figure, 5: table, 7: footnote
for model_page_info in self.__model_list: for model_page_info in self.__model_list:
...@@ -144,7 +164,7 @@ class MagicModel: ...@@ -144,7 +164,7 @@ class MagicModel:
if pos_flag_count > 1: if pos_flag_count > 1:
continue continue
dis_figure_footnote[i] = min( dis_figure_footnote[i] = min(
bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']), self._bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
dis_figure_footnote.get(i, float('inf')), dis_figure_footnote.get(i, float('inf')),
) )
for i in range(len(footnotes)): for i in range(len(footnotes)):
...@@ -163,7 +183,7 @@ class MagicModel: ...@@ -163,7 +183,7 @@ class MagicModel:
continue continue
dis_table_footnote[i] = min( dis_table_footnote[i] = min(
bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']), self._bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
dis_table_footnote.get(i, float('inf')), dis_table_footnote.get(i, float('inf')),
) )
for i in range(len(footnotes)): for i in range(len(footnotes)):
...@@ -350,7 +370,7 @@ class MagicModel: ...@@ -350,7 +370,7 @@ class MagicModel:
dis[j][i] = dis[i][j] dis[j][i] = dis[i][j]
continue continue
dis[i][j] = bbox_distance(all_bboxes[i]['bbox'], all_bboxes[j]['bbox']) dis[i][j] = self._bbox_distance(all_bboxes[i]['bbox'], all_bboxes[j]['bbox'])
dis[j][i] = dis[i][j] dis[j][i] = dis[i][j]
used = set() used = set()
...@@ -441,7 +461,7 @@ class MagicModel: ...@@ -441,7 +461,7 @@ class MagicModel:
if is_nearest: if is_nearest:
nx0, ny0, nx1, ny1 = expand_bbbox(list(seen) + [k]) nx0, ny0, nx1, ny1 = expand_bbbox(list(seen) + [k])
n_dis = bbox_distance( n_dis = self._bbox_distance(
all_bboxes[i]['bbox'], [nx0, ny0, nx1, ny1] all_bboxes[i]['bbox'], [nx0, ny0, nx1, ny1]
) )
if float_gt(dis[i][j], n_dis): if float_gt(dis[i][j], n_dis):
...@@ -537,7 +557,7 @@ class MagicModel: ...@@ -537,7 +557,7 @@ class MagicModel:
# 计算已经配对的 distance 距离 # 计算已经配对的 distance 距离
for i in subject_object_relation_map.keys(): for i in subject_object_relation_map.keys():
for j in subject_object_relation_map[i]: for j in subject_object_relation_map[i]:
total_subject_object_dis += bbox_distance( total_subject_object_dis += self._bbox_distance(
all_bboxes[i]['bbox'], all_bboxes[j]['bbox'] all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
) )
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment