Commit e5907296 authored by 赵小蒙's avatar 赵小蒙

fix span overlap by confidence,remove duplicate spans

parent 34651637
...@@ -161,6 +161,17 @@ def __is_overlaps_y_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.8) ...@@ -161,6 +161,17 @@ def __is_overlaps_y_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.8)
def calculate_iou(bbox1, bbox2): def calculate_iou(bbox1, bbox2):
"""
计算两个边界框的交并比(IOU)。
Args:
bbox1 (list[float]): 第一个边界框的坐标,格式为 [x1, y1, x2, y2],其中 (x1, y1) 为左上角坐标,(x2, y2) 为右下角坐标。
bbox2 (list[float]): 第二个边界框的坐标,格式与 `bbox1` 相同。
Returns:
float: 两个边界框的交并比(IOU),取值范围为 [0, 1]。
"""
# Determine the coordinates of the intersection rectangle # Determine the coordinates of the intersection rectangle
x_left = max(bbox1[0], bbox2[0]) x_left = max(bbox1[0], bbox2[0])
y_top = max(bbox1[1], bbox2[1]) y_top = max(bbox1[1], bbox2[1])
......
...@@ -448,6 +448,12 @@ class MagicModel: ...@@ -448,6 +448,12 @@ class MagicModel:
return text_spans return text_spans
def get_all_spans(self, page_no: int) -> list: def get_all_spans(self, page_no: int) -> list:
def remove_duplicate_spans(spans):
new_spans = []
for span in spans:
if not any(span == existing_span for existing_span in new_spans):
new_spans.append(span)
return new_spans
all_spans = [] all_spans = []
model_page_info = self.__model_list[page_no] model_page_info = self.__model_list[page_no]
layout_dets = model_page_info["layout_dets"] layout_dets = model_page_info["layout_dets"]
...@@ -461,7 +467,10 @@ class MagicModel: ...@@ -461,7 +467,10 @@ class MagicModel:
for layout_det in layout_dets: for layout_det in layout_dets:
category_id = layout_det["category_id"] category_id = layout_det["category_id"]
if category_id in allow_category_id_list: if category_id in allow_category_id_list:
span = {"bbox": layout_det["bbox"]} span = {
"bbox": layout_det["bbox"],
"score": layout_det["score"]
}
if category_id == 3: if category_id == 3:
span["type"] = ContentType.Image span["type"] = ContentType.Image
elif category_id == 5: elif category_id == 5:
...@@ -476,7 +485,7 @@ class MagicModel: ...@@ -476,7 +485,7 @@ class MagicModel:
span["content"] = layout_det["text"] span["content"] = layout_det["text"]
span["type"] = ContentType.Text span["type"] = ContentType.Text
all_spans.append(span) all_spans.append(span)
return all_spans return remove_duplicate_spans(all_spans)
def get_page_size(self, page_no: int): # 获取页面宽高 def get_page_size(self, page_no: int): # 获取页面宽高
# 获取当前页的page对象 # 获取当前页的page对象
......
...@@ -19,7 +19,8 @@ from magic_pdf.pre_proc.equations_replace import remove_chars_in_text_blocks, re ...@@ -19,7 +19,8 @@ from magic_pdf.pre_proc.equations_replace import remove_chars_in_text_blocks, re
from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split
from magic_pdf.pre_proc.ocr_dict_merge import sort_blocks_by_layout, fill_spans_in_blocks, fix_block_spans, \ from magic_pdf.pre_proc.ocr_dict_merge import sort_blocks_by_layout, fill_spans_in_blocks, fix_block_spans, \
fix_discarded_block fix_discarded_block
from magic_pdf.pre_proc.ocr_span_list_modify import remove_overlaps_min_spans, get_qa_need_list_v2 from magic_pdf.pre_proc.ocr_span_list_modify import remove_overlaps_min_spans, get_qa_need_list_v2, \
remove_overlaps_low_confidence_spans
from magic_pdf.pre_proc.resolve_bbox_conflict import check_useful_block_horizontal_overlap from magic_pdf.pre_proc.resolve_bbox_conflict import check_useful_block_horizontal_overlap
...@@ -117,6 +118,8 @@ def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, ...@@ -117,6 +118,8 @@ def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter,
else: else:
raise Exception("parse_mode must be txt or ocr") raise Exception("parse_mode must be txt or ocr")
'''删除重叠spans中置信度较低的那些'''
spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
'''删除重叠spans中较小的那些''' '''删除重叠spans中较小的那些'''
spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans) spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
'''对image和table截图''' '''对image和table截图'''
......
from loguru import logger from loguru import logger
from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio, get_minbox_if_overlap_by_ratio, \ from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio, get_minbox_if_overlap_by_ratio, \
__is_overlaps_y_exceeds_threshold __is_overlaps_y_exceeds_threshold, calculate_iou
from magic_pdf.libs.drop_tag import DropTag from magic_pdf.libs.drop_tag import DropTag
from magic_pdf.libs.ocr_content_type import ContentType, BlockType from magic_pdf.libs.ocr_content_type import ContentType, BlockType
def remove_overlaps_low_confidence_spans(spans):
dropped_spans = []
# 删除重叠spans中置信度低的的那些
for span1 in spans:
for span2 in spans:
if span1 != span2:
if calculate_iou(span1['bbox'], span2['bbox']) > 0.9:
if span1['score'] < span2['score']:
span_need_remove = span1
else:
span_need_remove = span2
if span_need_remove is not None and span_need_remove not in dropped_spans:
dropped_spans.append(span_need_remove)
if len(dropped_spans) > 0:
for span_need_remove in dropped_spans:
spans.remove(span_need_remove)
span_need_remove['tag'] = DropTag.SPAN_OVERLAP
return spans, dropped_spans
def remove_overlaps_min_spans(spans): def remove_overlaps_min_spans(spans):
dropped_spans = [] dropped_spans = []
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment