Unverified commit a1a1d177, authored by myhloli and committed by GitHub

Merge pull request #78 from icecraft/debug/txt_pipeline

parents 683fa633 6a3d1f2d
...@@ -10,9 +10,16 @@ from magic_pdf.libs.ocr_content_type import ContentType ...@@ -10,9 +10,16 @@ from magic_pdf.libs.ocr_content_type import ContentType
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
from magic_pdf.libs.math import float_gt from magic_pdf.libs.math import float_gt
from magic_pdf.libs.boxbase import _is_in, bbox_relative_pos, bbox_distance from magic_pdf.libs.boxbase import (
_is_in,
bbox_relative_pos,
bbox_distance,
_is_part_overlap,
calculate_overlap_area_in_bbox1_area_ratio,
)
from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
CAPATION_OVERLAP_AREA_RATIO = 0.6
class MagicModel: class MagicModel:
""" """
...@@ -74,13 +81,13 @@ class MagicModel: ...@@ -74,13 +81,13 @@ class MagicModel:
return [bboxes[i] for i in range(N) if keep[i]] return [bboxes[i] for i in range(N) if keep[i]]
def __tie_up_category_by_distance( def __tie_up_category_by_distance(
self, page_no, subject_category_id, object_category_id self, page_no, subject_category_id, object_category_id
): ):
""" """
假定每个 subject 最多有一个 object (可以有多个相邻的 object 合并为单个 object),每个 object 只能属于一个 subject 假定每个 subject 最多有一个 object (可以有多个相邻的 object 合并为单个 object),每个 object 只能属于一个 subject
""" """
ret = [] ret = []
MAX_DIS_OF_POINT = 10 ** 9 + 7 MAX_DIS_OF_POINT = 10**9 + 7
subjects = self.__reduct_overlap( subjects = self.__reduct_overlap(
list( list(
...@@ -123,8 +130,8 @@ class MagicModel: ...@@ -123,8 +130,8 @@ class MagicModel:
for i in range(N): for i in range(N):
for j in range(i): for j in range(i):
if ( if (
all_bboxes[i]["category_id"] == subject_category_id all_bboxes[i]["category_id"] == subject_category_id
and all_bboxes[j]["category_id"] == subject_category_id and all_bboxes[j]["category_id"] == subject_category_id
): ):
continue continue
...@@ -154,9 +161,9 @@ class MagicModel: ...@@ -154,9 +161,9 @@ class MagicModel:
if pos_flag_count > 1: if pos_flag_count > 1:
continue continue
if ( if (
all_bboxes[j]["category_id"] != object_category_id all_bboxes[j]["category_id"] != object_category_id
or j in used or j in used
or dis[i][j] == MAX_DIS_OF_POINT or dis[i][j] == MAX_DIS_OF_POINT
): ):
continue continue
arr.append((dis[i][j], j)) arr.append((dis[i][j], j))
...@@ -185,10 +192,10 @@ class MagicModel: ...@@ -185,10 +192,10 @@ class MagicModel:
continue continue
if ( if (
all_bboxes[k]["category_id"] != object_category_id all_bboxes[k]["category_id"] != object_category_id
or k in used or k in used
or k in seen or k in seen
or dis[j][k] == MAX_DIS_OF_POINT or dis[j][k] == MAX_DIS_OF_POINT
): ):
continue continue
is_nearest = True is_nearest = True
...@@ -238,7 +245,7 @@ class MagicModel: ...@@ -238,7 +245,7 @@ class MagicModel:
for bbox in caption_poses: for bbox in caption_poses:
embed_arr = [] embed_arr = []
for idx in seen: for idx in seen:
if _is_in(all_bboxes[idx]["bbox"], bbox): if calculate_overlap_area_in_bbox1_area_ratio(all_bboxes[idx]["bbox"], bbox) > CAPATION_OVERLAP_AREA_RATIO:
embed_arr.append(idx) embed_arr.append(idx)
if len(embed_arr) > 0: if len(embed_arr) > 0:
...@@ -258,7 +265,7 @@ class MagicModel: ...@@ -258,7 +265,7 @@ class MagicModel:
caption_bbox = caption_poses[max_area_idx] caption_bbox = caption_poses[max_area_idx]
for j in seen: for j in seen:
if _is_in(all_bboxes[j]["bbox"], caption_bbox): if calculate_overlap_area_in_bbox1_area_ratio(all_bboxes[j]["bbox"], caption_bbox) > CAPATION_OVERLAP_AREA_RATIO:
used.add(j) used.add(j)
subject_object_relation_map[i].append(j) subject_object_relation_map[i].append(j)
...@@ -312,8 +319,8 @@ class MagicModel: ...@@ -312,8 +319,8 @@ class MagicModel:
candidates = [] candidates = []
for j in range(N): for j in range(N):
if ( if (
all_bboxes[j]["category_id"] != subject_category_id all_bboxes[j]["category_id"] != subject_category_id
or j in with_caption_subject or j in with_caption_subject
): ):
continue continue
candidates.append((dis[i][j], j)) candidates.append((dis[i][j], j))
...@@ -335,7 +342,7 @@ class MagicModel: ...@@ -335,7 +342,7 @@ class MagicModel:
] ]
def get_tables( def get_tables(
self, page_no: int self, page_no: int
) -> list: # 3个坐标, caption, table主体,table-note ) -> list: # 3个坐标, caption, table主体,table-note
with_captions, _ = self.__tie_up_category_by_distance(page_no, 5, 6) with_captions, _ = self.__tie_up_category_by_distance(page_no, 5, 6)
with_footnotes, _ = self.__tie_up_category_by_distance(page_no, 5, 7) with_footnotes, _ = self.__tie_up_category_by_distance(page_no, 5, 7)
...@@ -358,9 +365,15 @@ class MagicModel: ...@@ -358,9 +365,15 @@ class MagicModel:
return ret return ret
def get_equations(self, page_no: int) -> list: # 有坐标,也有字 def get_equations(self, page_no: int) -> list: # 有坐标,也有字
inline_equations = self.__get_blocks_by_type(ModelBlockTypeEnum.EMBEDDING.value, page_no, ["latex"]) inline_equations = self.__get_blocks_by_type(
interline_equations = self.__get_blocks_by_type(ModelBlockTypeEnum.ISOLATED.value, page_no, ["latex"]) ModelBlockTypeEnum.EMBEDDING.value, page_no, ["latex"]
interline_equations_blocks = self.__get_blocks_by_type(ModelBlockTypeEnum.ISOLATE_FORMULA.value, page_no) )
interline_equations = self.__get_blocks_by_type(
ModelBlockTypeEnum.ISOLATED.value, page_no, ["latex"]
)
interline_equations_blocks = self.__get_blocks_by_type(
ModelBlockTypeEnum.ISOLATE_FORMULA.value, page_no
)
return inline_equations, interline_equations, interline_equations_blocks return inline_equations, interline_equations, interline_equations_blocks
def get_discarded(self, page_no: int) -> list: # 自研模型,只有坐标 def get_discarded(self, page_no: int) -> list: # 自研模型,只有坐标
...@@ -382,7 +395,7 @@ class MagicModel: ...@@ -382,7 +395,7 @@ class MagicModel:
for layout_det in layout_dets: for layout_det in layout_dets:
if layout_det["category_id"] == "15": if layout_det["category_id"] == "15":
span = { span = {
"bbox": layout_det['bbox'], "bbox": layout_det["bbox"],
"content": layout_det["text"], "content": layout_det["text"],
} }
text_spans.append(span) text_spans.append(span)
...@@ -402,9 +415,7 @@ class MagicModel: ...@@ -402,9 +415,7 @@ class MagicModel:
for layout_det in layout_dets: for layout_det in layout_dets:
category_id = layout_det["category_id"] category_id = layout_det["category_id"]
if category_id in allow_category_id_list: if category_id in allow_category_id_list:
span = { span = {"bbox": layout_det["bbox"]}
"bbox": layout_det['bbox']
}
if category_id == 3: if category_id == 3:
span["type"] = ContentType.Image span["type"] = ContentType.Image
elif category_id == 5: elif category_id == 5:
...@@ -429,7 +440,9 @@ class MagicModel: ...@@ -429,7 +440,9 @@ class MagicModel:
page_h = page.rect.height page_h = page.rect.height
return page_w, page_h return page_w, page_h
def __get_blocks_by_type(self, type: int, page_no: int, extra_col: list[str] = []) -> list: def __get_blocks_by_type(
self, type: int, page_no: int, extra_col: list[str] = []
) -> list:
blocks = [] blocks = []
for page_dict in self.__model_list: for page_dict in self.__model_list:
layout_dets = page_dict.get("layout_dets", []) layout_dets = page_dict.get("layout_dets", [])
...@@ -442,9 +455,7 @@ class MagicModel: ...@@ -442,9 +455,7 @@ class MagicModel:
bbox = item.get("bbox", None) bbox = item.get("bbox", None)
if category_id == type: if category_id == type:
block = { block = {"bbox": bbox}
"bbox": bbox
}
for col in extra_col: for col in extra_col:
block[col] = item.get(col, None) block[col] = item.get(col, None)
blocks.append(block) blocks.append(block)
......
...@@ -3,9 +3,21 @@ from magic_pdf.libs.boxbase import _is_in_or_part_overlap, _is_in ...@@ -3,9 +3,21 @@ from magic_pdf.libs.boxbase import _is_in_or_part_overlap, _is_in
def _remove_overlap_between_bbox(spans): def _remove_overlap_between_bbox(spans):
res = [] res = []
for v in spans:
keeps = [True] * len(spans)
for i in range(len(spans)):
for j in range(len(spans)):
if i == j:
continue
if _is_in(spans[i]["bbox"], spans[j]["bbox"]):
keeps[i] = False
for idx, v in enumerate(spans):
if not keeps[idx]:
continue
for i in range(len(res)): for i in range(len(res)):
if _is_in(res[i]["bbox"], v["bbox"]) or _is_in(v["bbox"], res[i]["bbox"]): if _is_in(v["bbox"], res[i]["bbox"]):
continue continue
if _is_in_or_part_overlap(res[i]["bbox"], v["bbox"]): if _is_in_or_part_overlap(res[i]["bbox"], v["bbox"]):
ix0, iy0, ix1, iy1 = res[i]["bbox"] ix0, iy0, ix1, iy1 = res[i]["bbox"]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment