Commit 86d7cff1 authored by 许瑞's avatar 许瑞

fix: table caption relation

parent 3a0a08e4
...@@ -21,6 +21,7 @@ from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum ...@@ -21,6 +21,7 @@ from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
CAPATION_OVERLAP_AREA_RATIO = 0.6 CAPATION_OVERLAP_AREA_RATIO = 0.6
class MagicModel: class MagicModel:
""" """
每个函数没有得到元素的时候返回空list 每个函数没有得到元素的时候返回空list
...@@ -89,23 +90,37 @@ class MagicModel: ...@@ -89,23 +90,37 @@ class MagicModel:
ret = [] ret = []
MAX_DIS_OF_POINT = 10**9 + 7 MAX_DIS_OF_POINT = 10**9 + 7
def expand_bbox(bbox1, bbox2):
x0 = min(bbox1[0], bbox2[0])
y0 = min(bbox1[1], bbox2[1])
x1 = max(bbox1[2], bbox2[2])
y1 = max(bbox1[3], bbox2[3])
return [x0, y0, x1, y1]
def get_bbox_area(bbox):
return abs(bbox[2] - bbox[0]) * abs(bbox[3] - bbox[1])
# subject 和 object 的 bbox 会合并成一个大的 bbox (named: merged bbox)。 筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。 # subject 和 object 的 bbox 会合并成一个大的 bbox (named: merged bbox)。 筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
# 再求出筛选出的 subjects 和 object 的最短距离! # 再求出筛选出的 subjects 和 object 的最短距离!
def may_find_other_nearest_bbox(subject_idx, object_idx): def may_find_other_nearest_bbox(subject_idx, object_idx):
ret = float("inf") ret = float("inf")
x0 = min(all_bboxes[subject_idx]["bbox"][0], all_bboxes[object_idx]["bbox"][0]) x0, y0, x1, y1 = expand_bbox(
y0 = min(all_bboxes[subject_idx]["bbox"][1], all_bboxes[object_idx]["bbox"][1]) all_bboxes[subject_idx]["bbox"], all_bboxes[object_idx]["bbox"]
x1 = max(all_bboxes[subject_idx]["bbox"][2], all_bboxes[object_idx]["bbox"][2]) )
y1 = max(all_bboxes[subject_idx]["bbox"][3], all_bboxes[object_idx]["bbox"][3])
object_area = abs(all_bboxes[object_idx]["bbox"][2] - all_bboxes[object_idx]["bbox"][0]) * abs(all_bboxes[object_idx]["bbox"][3] - all_bboxes[object_idx]["bbox"][1]) object_area = get_bbox_area(all_bboxes[object_idx]["bbox"])
for i in range(len(all_bboxes)): for i in range(len(all_bboxes)):
if i == subject_idx or all_bboxes[i]["category_id"] != subject_category_id: if (
i == subject_idx
or all_bboxes[i]["category_id"] != subject_category_id
):
continue continue
if _is_part_overlap([x0, y0, x1, y1], all_bboxes[i]["bbox"]) or _is_in(all_bboxes[i]["bbox"], [x0, y0, x1, y1]): if _is_part_overlap([x0, y0, x1, y1], all_bboxes[i]["bbox"]) or _is_in(
i_area = abs(all_bboxes[i]["bbox"][2] - all_bboxes[i]["bbox"][0]) * abs(all_bboxes[i]["bbox"][3] - all_bboxes[i]["bbox"][1]) all_bboxes[i]["bbox"], [x0, y0, x1, y1]
):
i_area = get_bbox_area(all_bboxes[i]["bbox"])
if i_area >= object_area: if i_area >= object_area:
ret = min(float("inf"), dis[i][object_idx]) ret = min(ret, dis[i][object_idx])
return ret return ret
subjects = self.__reduct_overlap( subjects = self.__reduct_overlap(
...@@ -190,7 +205,7 @@ class MagicModel: ...@@ -190,7 +205,7 @@ class MagicModel:
arr.sort(key=lambda x: x[0]) arr.sort(key=lambda x: x[0])
if len(arr) > 0: if len(arr) > 0:
# bug: 离该subject 最近的 object 可能跨越了其它的 subject 。比如 [this subect] [some sbuject] [the nearest objec of subject] # bug: 离该subject 最近的 object 可能跨越了其它的 subject 。比如 [this subect] [some sbuject] [the nearest objec of subject]
if may_find_other_nearest_bbox(i, j) >= arr[0][0]: if may_find_other_nearest_bbox(i, arr[0][1]) >= arr[0][0]:
candidates.append(arr[0][1]) candidates.append(arr[0][1])
seen.add(arr[0][1]) seen.add(arr[0][1])
...@@ -266,7 +281,12 @@ class MagicModel: ...@@ -266,7 +281,12 @@ class MagicModel:
for bbox in caption_poses: for bbox in caption_poses:
embed_arr = [] embed_arr = []
for idx in seen: for idx in seen:
if calculate_overlap_area_in_bbox1_area_ratio(all_bboxes[idx]["bbox"], bbox) > CAPATION_OVERLAP_AREA_RATIO: if (
calculate_overlap_area_in_bbox1_area_ratio(
all_bboxes[idx]["bbox"], bbox
)
> CAPATION_OVERLAP_AREA_RATIO
):
embed_arr.append(idx) embed_arr.append(idx)
if len(embed_arr) > 0: if len(embed_arr) > 0:
...@@ -286,7 +306,12 @@ class MagicModel: ...@@ -286,7 +306,12 @@ class MagicModel:
caption_bbox = caption_poses[max_area_idx] caption_bbox = caption_poses[max_area_idx]
for j in seen: for j in seen:
if calculate_overlap_area_in_bbox1_area_ratio(all_bboxes[j]["bbox"], caption_bbox) > CAPATION_OVERLAP_AREA_RATIO: if (
calculate_overlap_area_in_bbox1_area_ratio(
all_bboxes[j]["bbox"], caption_bbox
)
> CAPATION_OVERLAP_AREA_RATIO
):
used.add(j) used.add(j)
subject_object_relation_map[i].append(j) subject_object_relation_map[i].append(j)
...@@ -482,6 +507,9 @@ class MagicModel: ...@@ -482,6 +507,9 @@ class MagicModel:
blocks.append(block) blocks.append(block)
return blocks return blocks
def get_model_list(self, page_no):
return self.__model_list[page_no]
if __name__ == "__main__": if __name__ == "__main__":
drw = DiskReaderWriter(r"D:/project/20231108code-clean") drw = DiskReaderWriter(r"D:/project/20231108code-clean")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment