Unverified commit 95e7e3a7, authored by Xiaomeng Zhao and committed by GitHub

Merge pull request #160 from icecraft/fix/figure_caption_relation

fix: object cluster algorithm
parents 6a9ad924 ddff4b42
......@@ -15,7 +15,8 @@ from magic_pdf.libs.boxbase import (
bbox_relative_pos,
bbox_distance,
_is_part_overlap,
calculate_overlap_area_in_bbox1_area_ratio, calculate_iou,
calculate_overlap_area_in_bbox1_area_ratio,
calculate_iou,
)
from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
......@@ -78,9 +79,23 @@ class MagicModel:
for layout_det2 in layout_dets:
if layout_det1 == layout_det2:
continue
if layout_det1["category_id"] in [0,1,2,3,4,5,6,7,8,9] and layout_det2["category_id"] in [0,1,2,3,4,5,6,7,8,9]:
if calculate_iou(layout_det1['bbox'], layout_det2['bbox']) > 0.9:
if layout_det1['score'] < layout_det2['score']:
if layout_det1["category_id"] in [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
] and layout_det2["category_id"] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
if (
calculate_iou(layout_det1["bbox"], layout_det2["bbox"])
> 0.9
):
if layout_det1["score"] < layout_det2["score"]:
layout_det_need_remove = layout_det1
else:
layout_det_need_remove = layout_det2
......@@ -97,11 +112,11 @@ class MagicModel:
def __init__(self, model_list: list, docs: fitz.Document):
    """Keep the model output and the document, then run the clean-up passes.

    Args:
        model_list: Model results; other methods in this file index it by
            page number (``self.__model_list[page_no]``).
        docs: The ``fitz.Document`` the model results belong to.
    """
    self.__docs = docs
    self.__model_list = model_list

    # Add bbox information to all model data (scaling, poly -> bbox).
    self.__fix_axis()

    # Remove model data with very low confidence (< 0.05) to improve quality.
    self.__fix_by_remove_low_confidence()

    # Among high-IoU (> 0.9) pairs, remove the one with the lower confidence.
    self.__fix_by_remove_high_iou_and_low_confidence()
def __reduct_overlap(self, bboxes):
......@@ -125,16 +140,6 @@ class MagicModel:
ret = []
MAX_DIS_OF_POINT = 10**9 + 7
def expand_bbox(bbox1, bbox2):
    """Return the smallest box that encloses both ``bbox1`` and ``bbox2``.

    Each box is indexed as ``[x0, y0, x1, y1]``. The result takes the
    smaller of the two values at positions 0 and 1 and the larger of the
    two values at positions 2 and 3, and is always returned as a new list.
    """
    low_x, low_y = min(bbox1[0], bbox2[0]), min(bbox1[1], bbox2[1])
    high_x, high_y = max(bbox1[2], bbox2[2]), max(bbox1[3], bbox2[3])
    return [low_x, low_y, high_x, high_y]
def get_bbox_area(bbox):
    """Return the area of ``bbox``, indexed as ``[x0, y0, x1, y1]``.

    Absolute differences are used for both sides, so the result is
    non-negative even when the corners are given in reverse order.
    """
    width = abs(bbox[2] - bbox[0])
    height = abs(bbox[3] - bbox[1])
    return width * height
# subject 和 object 的 bbox 会合并成一个大的 bbox (named: merged bbox)。 筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
# 再求出筛选出的 subjects 和 object 的最短距离!
def may_find_other_nearest_bbox(subject_idx, object_idx):
......@@ -177,6 +182,13 @@ class MagicModel:
return ret
def expand_bbbox(idxes):
    """Return the box enclosing every ``all_bboxes`` entry named in ``idxes``.

    ``all_bboxes`` is resolved from the surrounding scope; each selected
    entry must carry a ``"bbox"`` value indexable at positions 0..3
    (x0, y0, x1, y1).

    Args:
        idxes: Iterable of indexes into ``all_bboxes``. It is consumed
            exactly once, so one-shot iterables (generators) are accepted
            as well as lists.

    Returns:
        A tuple ``(x0, y0, x1, y1)``: the minimum of the x0/y0 values and
        the maximum of the x1/y1 values over the selected boxes.

    Raises:
        ValueError: If ``idxes`` is empty (raised by ``min``).
    """
    # Resolve each box once instead of repeating the lookup per coordinate.
    boxes = [all_bboxes[idx]["bbox"] for idx in idxes]
    return (
        min(box[0] for box in boxes),
        min(box[1] for box in boxes),
        max(box[2] for box in boxes),
        max(box[3] for box in boxes),
    )
subjects = self.__reduct_overlap(
list(
map(
......@@ -268,7 +280,9 @@ class MagicModel:
or dis[i][j] == MAX_DIS_OF_POINT
):
continue
left, right, _, _ = bbox_relative_pos(all_bboxes[i]["bbox"], all_bboxes[j]["bbox"]) # 由 pos_flag_count 相关逻辑保证本段逻辑准确性
left, right, _, _ = bbox_relative_pos(
all_bboxes[i]["bbox"], all_bboxes[j]["bbox"]
) # 由 pos_flag_count 相关逻辑保证本段逻辑准确性
if left or right:
one_way_dis = all_bboxes[i]["bbox"][2] - all_bboxes[i]["bbox"][0]
else:
......@@ -322,6 +336,10 @@ class MagicModel:
break
if is_nearest:
nx0, ny0, nx1, ny1 = expand_bbbox(list(seen) + [k])
n_dis = bbox_distance(all_bboxes[i]["bbox"], [nx0, ny0, nx1, ny1])
if float_gt(dis[i][j], n_dis):
continue
tmp.append(k)
seen.add(k)
......@@ -331,20 +349,7 @@ class MagicModel:
# 已经获取到某个 figure 下所有的最靠近的 captions,以及最靠近这些 captions 的 captions 。
# 先扩一下 bbox,
x0s = [all_bboxes[idx]["bbox"][0] for idx in seen] + [
all_bboxes[i]["bbox"][0]
]
y0s = [all_bboxes[idx]["bbox"][1] for idx in seen] + [
all_bboxes[i]["bbox"][1]
]
x1s = [all_bboxes[idx]["bbox"][2] for idx in seen] + [
all_bboxes[i]["bbox"][2]
]
y1s = [all_bboxes[idx]["bbox"][3] for idx in seen] + [
all_bboxes[i]["bbox"][3]
]
ox0, oy0, ox1, oy1 = min(x0s), min(y0s), max(x1s), max(y1s)
ox0, oy0, ox1, oy1 = expand_bbbox(list(seen) + [i])
ix0, iy0, ix1, iy1 = all_bboxes[i]["bbox"]
# 分成了 4 个截取空间,需要计算落在每个截取空间下 objects 合并后占据的矩形面积
......@@ -455,8 +460,10 @@ class MagicModel:
with_caption_subject.add(j)
return ret, total_subject_object_dis
def get_imgs(self, page_no: int): # @许瑞
records, _ = self.__tie_up_category_by_distance(page_no, 3, 4)
def get_imgs(self, page_no: int):
figure_captions, _ = self.__tie_up_category_by_distance(
page_no, 3, 4
)
return [
{
"bbox": record["all"],
......@@ -464,7 +471,7 @@ class MagicModel:
"img_caption_bbox": record.get("object_body", None),
"score": record["score"],
}
for record in records
for record in figure_captions
]
def get_tables(
......@@ -535,6 +542,7 @@ class MagicModel:
if not any(span == existing_span for existing_span in new_spans):
new_spans.append(span)
return new_spans
all_spans = []
model_page_info = self.__model_list[page_no]
layout_dets = model_page_info["layout_dets"]
......@@ -548,10 +556,7 @@ class MagicModel:
for layout_det in layout_dets:
category_id = layout_det["category_id"]
if category_id in allow_category_id_list:
span = {
"bbox": layout_det["bbox"],
"score": layout_det["score"]
}
span = {"bbox": layout_det["bbox"], "score": layout_det["score"]}
if category_id == 3:
span["type"] = ContentType.Image
elif category_id == 5:
......@@ -604,7 +609,6 @@ class MagicModel:
return self.__model_list[page_no]
if __name__ == "__main__":
drw = DiskReaderWriter(r"D:/project/20231108code-clean")
if 0:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment