Unverified Commit 03469909 authored by icecraft's avatar icecraft Committed by GitHub

Feat/support footnote in figure (#532)

* feat: support figure footnote

* feat: using the relative position to combine footnote, table, image

* feat: add the readme of projects

* fix: code spell in unittest

---------
Co-authored-by: 's avataricecraft <xurui1@pjlab.org.cn>
parent 4331b837
......@@ -3,6 +3,7 @@ repos:
rev: 5.0.4
hooks:
- id: flake8
args: ["--max-line-length=120", "--ignore=E131,E125,W503,W504,E203"]
- repo: https://github.com/PyCQA/isort
rev: 5.11.5
hooks:
......@@ -11,6 +12,7 @@ repos:
rev: v0.32.0
hooks:
- id: yapf
args: ["--style={based_on_style: google, column_limit: 120, indent_width: 4}"]
- repo: https://github.com/codespell-project/codespell
rev: v2.2.1
hooks:
......@@ -41,4 +43,4 @@ repos:
rev: v1.3.1
hooks:
- id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"]
args: ["--in-place", "--wrap-descriptions", "119"]
import re
import wordninja
from loguru import logger
from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
from magic_pdf.libs.commons import join_path
from magic_pdf.libs.language import detect_lang
from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
from magic_pdf.libs.ocr_content_type import ContentType, BlockType
import wordninja
import re
from magic_pdf.libs.ocr_content_type import BlockType, ContentType
def __is_hyphen_at_line_end(line):
......@@ -37,8 +38,9 @@ def split_long_words(text):
def ocr_mk_mm_markdown_with_para(pdf_info_list: list, img_buket_path):
markdown = []
for page_info in pdf_info_list:
paras_of_layout = page_info.get("para_blocks")
page_markdown = ocr_mk_markdown_with_para_core_v2(paras_of_layout, "mm", img_buket_path)
paras_of_layout = page_info.get('para_blocks')
page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'mm', img_buket_path)
markdown.extend(page_markdown)
return '\n\n'.join(markdown)
......@@ -46,29 +48,34 @@ def ocr_mk_mm_markdown_with_para(pdf_info_list: list, img_buket_path):
def ocr_mk_nlp_markdown_with_para(pdf_info_dict: list):
markdown = []
for page_info in pdf_info_dict:
paras_of_layout = page_info.get("para_blocks")
page_markdown = ocr_mk_markdown_with_para_core_v2(paras_of_layout, "nlp")
paras_of_layout = page_info.get('para_blocks')
page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'nlp')
markdown.extend(page_markdown)
return '\n\n'.join(markdown)
def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list, img_buket_path):
def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list,
img_buket_path):
markdown_with_para_and_pagination = []
page_no = 0
for page_info in pdf_info_dict:
paras_of_layout = page_info.get("para_blocks")
paras_of_layout = page_info.get('para_blocks')
if not paras_of_layout:
continue
page_markdown = ocr_mk_markdown_with_para_core_v2(paras_of_layout, "mm", img_buket_path)
page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'mm', img_buket_path)
markdown_with_para_and_pagination.append({
'page_no': page_no,
'md_content': '\n\n'.join(page_markdown)
'page_no':
page_no,
'md_content':
'\n\n'.join(page_markdown)
})
page_no += 1
return markdown_with_para_and_pagination
def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=""):
def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=''):
page_markdown = []
for paras in paras_of_layout:
for para in paras:
......@@ -81,8 +88,9 @@ def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=""):
if span_type == ContentType.Text:
content = span['content']
language = detect_lang(content)
if language == 'en': # 只对英文长词进行分词处理,中文分词会丢失文本
content = ocr_escape_special_markdown_char(split_long_words(content))
if (language == 'en'): # 只对英文长词进行分词处理,中文分词会丢失文本
content = ocr_escape_special_markdown_char(
split_long_words(content))
else:
content = ocr_escape_special_markdown_char(content)
elif span_type == ContentType.InlineEquation:
......@@ -106,7 +114,9 @@ def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=""):
return page_markdown
def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
mode,
img_buket_path=''):
page_markdown = []
for para_block in paras_of_layout:
para_text = ''
......@@ -114,7 +124,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
if para_type == BlockType.Text:
para_text = merge_para_with_text(para_block)
elif para_type == BlockType.Title:
para_text = f"# {merge_para_with_text(para_block)}"
para_text = f'# {merge_para_with_text(para_block)}'
elif para_type == BlockType.InterlineEquation:
para_text = merge_para_with_text(para_block)
elif para_type == BlockType.Image:
......@@ -130,11 +140,13 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
for block in para_block['blocks']: # 2nd.拼image_caption
if block['type'] == BlockType.ImageCaption:
para_text += merge_para_with_text(block)
for block in para_block['blocks']: # 2nd.拼image_caption
if block['type'] == BlockType.ImageFootnote:
para_text += merge_para_with_text(block)
elif para_type == BlockType.Table:
if mode == 'nlp':
continue
elif mode == 'mm':
table_caption = ''
for block in para_block['blocks']: # 1st.拼table_caption
if block['type'] == BlockType.TableCaption:
para_text += merge_para_with_text(block)
......@@ -163,6 +175,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
def merge_para_with_text(para_block):
def detect_language(text):
en_pattern = r'[a-zA-Z]+'
en_matches = re.findall(en_pattern, text)
......@@ -171,19 +184,19 @@ def merge_para_with_text(para_block):
if en_length / len(text) >= 0.5:
return 'en'
else:
return "unknown"
return 'unknown'
else:
return "empty"
return 'empty'
para_text = ''
for line in para_block['lines']:
line_text = ""
line_lang = ""
line_text = ''
line_lang = ''
for span in line['spans']:
span_type = span['type']
if span_type == ContentType.Text:
line_text += span['content'].strip()
if line_text != "":
if line_text != '':
line_lang = detect_lang(line_text)
for span in line['spans']:
span_type = span['type']
......@@ -193,7 +206,8 @@ def merge_para_with_text(para_block):
# language = detect_lang(content)
language = detect_language(content)
if language == 'en': # 只对英文长词进行分词处理,中文分词会丢失文本
content = ocr_escape_special_markdown_char(split_long_words(content))
content = ocr_escape_special_markdown_char(
split_long_words(content))
else:
content = ocr_escape_special_markdown_char(content)
elif span_type == ContentType.InlineEquation:
......@@ -227,12 +241,13 @@ def para_to_standard_format(para, img_buket_path):
for span in line['spans']:
language = ''
span_type = span.get('type')
content = ""
content = ''
if span_type == ContentType.Text:
content = span['content']
language = detect_lang(content)
if language == 'en': # 只对英文长词进行分词处理,中文分词会丢失文本
content = ocr_escape_special_markdown_char(split_long_words(content))
content = ocr_escape_special_markdown_char(
split_long_words(content))
else:
content = ocr_escape_special_markdown_char(content)
elif span_type == ContentType.InlineEquation:
......@@ -245,7 +260,7 @@ def para_to_standard_format(para, img_buket_path):
para_content = {
'type': 'text',
'text': para_text,
'inline_equation_num': inline_equation_num
'inline_equation_num': inline_equation_num,
}
return para_content
......@@ -256,37 +271,35 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
para_content = {
'type': 'text',
'text': merge_para_with_text(para_block),
'page_idx': page_idx
'page_idx': page_idx,
}
elif para_type == BlockType.Title:
para_content = {
'type': 'text',
'text': merge_para_with_text(para_block),
'text_level': 1,
'page_idx': page_idx
'page_idx': page_idx,
}
elif para_type == BlockType.InterlineEquation:
para_content = {
'type': 'equation',
'text': merge_para_with_text(para_block),
'text_format': "latex",
'page_idx': page_idx
'text_format': 'latex',
'page_idx': page_idx,
}
elif para_type == BlockType.Image:
para_content = {
'type': 'image',
'page_idx': page_idx
}
para_content = {'type': 'image', 'page_idx': page_idx}
for block in para_block['blocks']:
if block['type'] == BlockType.ImageBody:
para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path'])
para_content['img_path'] = join_path(
img_buket_path,
block['lines'][0]['spans'][0]['image_path'])
if block['type'] == BlockType.ImageCaption:
para_content['img_caption'] = merge_para_with_text(block)
if block['type'] == BlockType.ImageFootnote:
para_content['img_footnote'] = merge_para_with_text(block)
elif para_type == BlockType.Table:
para_content = {
'type': 'table',
'page_idx': page_idx
}
para_content = {'type': 'table', 'page_idx': page_idx}
for block in para_block['blocks']:
if block['type'] == BlockType.TableBody:
if block["lines"][0]["spans"][0].get('latex', ''):
......@@ -305,17 +318,18 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
def make_standard_format_with_para(pdf_info_dict: list, img_buket_path: str):
content_list = []
for page_info in pdf_info_dict:
paras_of_layout = page_info.get("para_blocks")
paras_of_layout = page_info.get('para_blocks')
if not paras_of_layout:
continue
for para_block in paras_of_layout:
para_content = para_to_standard_format_v2(para_block, img_buket_path)
para_content = para_to_standard_format_v2(para_block,
img_buket_path)
content_list.append(para_content)
return content_list
def line_to_standard_format(line, img_buket_path):
line_text = ""
line_text = ''
inline_equation_num = 0
for span in line['spans']:
if not span.get('content'):
......@@ -325,13 +339,15 @@ def line_to_standard_format(line, img_buket_path):
if span['type'] == ContentType.Image:
content = {
'type': 'image',
'img_path': join_path(img_buket_path, span['image_path'])
'img_path': join_path(img_buket_path,
span['image_path']),
}
return content
elif span['type'] == ContentType.Table:
content = {
'type': 'table',
'img_path': join_path(img_buket_path, span['image_path'])
'img_path': join_path(img_buket_path,
span['image_path']),
}
return content
else:
......@@ -339,36 +355,33 @@ def line_to_standard_format(line, img_buket_path):
interline_equation = span['content']
content = {
'type': 'equation',
'latex': f"$$\n{interline_equation}\n$$"
'latex': f'$$\n{interline_equation}\n$$'
}
return content
elif span['type'] == ContentType.InlineEquation:
inline_equation = span['content']
line_text += f"${inline_equation}$"
line_text += f'${inline_equation}$'
inline_equation_num += 1
elif span['type'] == ContentType.Text:
text_content = ocr_escape_special_markdown_char(span['content']) # 转义特殊符号
text_content = ocr_escape_special_markdown_char(
span['content']) # 转义特殊符号
line_text += text_content
content = {
'type': 'text',
'text': line_text,
'inline_equation_num': inline_equation_num
'inline_equation_num': inline_equation_num,
}
return content
def ocr_mk_mm_standard_format(pdf_info_dict: list):
"""
content_list
type string image/text/table/equation(行间的单独拿出来,行内的和text合并)
latex string latex文本字段。
text string 纯文本格式的文本数据。
md string markdown格式的文本数据。
img_path string s3://full/path/to/img.jpg
"""
"""content_list type string
image/text/table/equation(行间的单独拿出来,行内的和text合并) latex string
latex文本字段。 text string 纯文本格式的文本数据。 md string
markdown格式的文本数据。 img_path string s3://full/path/to/img.jpg."""
content_list = []
for page_info in pdf_info_dict:
blocks = page_info.get("preproc_blocks")
blocks = page_info.get('preproc_blocks')
if not blocks:
continue
for block in blocks:
......@@ -378,34 +391,42 @@ def ocr_mk_mm_standard_format(pdf_info_dict: list):
return content_list
def union_make(pdf_info_dict: list, make_mode: str, drop_mode: str, img_buket_path: str = ""):
def union_make(pdf_info_dict: list,
make_mode: str,
drop_mode: str,
img_buket_path: str = ''):
output_content = []
for page_info in pdf_info_dict:
if page_info.get("need_drop", False):
drop_reason = page_info.get("drop_reason")
if page_info.get('need_drop', False):
drop_reason = page_info.get('drop_reason')
if drop_mode == DropMode.NONE:
pass
elif drop_mode == DropMode.WHOLE_PDF:
raise Exception(f"drop_mode is {DropMode.WHOLE_PDF} , drop_reason is {drop_reason}")
raise Exception((f'drop_mode is {DropMode.WHOLE_PDF} ,'
f'drop_reason is {drop_reason}'))
elif drop_mode == DropMode.SINGLE_PAGE:
logger.warning(f"drop_mode is {DropMode.SINGLE_PAGE} , drop_reason is {drop_reason}")
logger.warning((f'drop_mode is {DropMode.SINGLE_PAGE} ,'
f'drop_reason is {drop_reason}'))
continue
else:
raise Exception(f"drop_mode can not be null")
raise Exception('drop_mode can not be null')
paras_of_layout = page_info.get("para_blocks")
page_idx = page_info.get("page_idx")
paras_of_layout = page_info.get('para_blocks')
page_idx = page_info.get('page_idx')
if not paras_of_layout:
continue
if make_mode == MakeMode.MM_MD:
page_markdown = ocr_mk_markdown_with_para_core_v2(paras_of_layout, "mm", img_buket_path)
page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'mm', img_buket_path)
output_content.extend(page_markdown)
elif make_mode == MakeMode.NLP_MD:
page_markdown = ocr_mk_markdown_with_para_core_v2(paras_of_layout, "nlp")
page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'nlp')
output_content.extend(page_markdown)
elif make_mode == MakeMode.STANDARD_FORMAT:
for para_block in paras_of_layout:
para_content = para_to_standard_format_v2(para_block, img_buket_path, page_idx)
para_content = para_to_standard_format_v2(
para_block, img_buket_path, page_idx)
output_content.append(para_content)
if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
return '\n\n'.join(output_content)
......
"""
对pdf上的box进行layout识别,并对内部组成的box进行排序
"""
"""对pdf上的box进行layout识别,并对内部组成的box进行排序."""
from loguru import logger
from magic_pdf.layout.bbox_sort import CONTENT_IDX, CONTENT_TYPE_IDX, X0_EXT_IDX, X0_IDX, X1_EXT_IDX, X1_IDX, Y0_EXT_IDX, Y0_IDX, Y1_EXT_IDX, Y1_IDX, paper_bbox_sort
from magic_pdf.layout.layout_det_utils import find_all_left_bbox_direct, find_all_right_bbox_direct, find_bottom_bbox_direct_from_left_edge, find_bottom_bbox_direct_from_right_edge, find_top_bbox_direct_from_left_edge, find_top_bbox_direct_from_right_edge, find_all_top_bbox_direct, find_all_bottom_bbox_direct, get_left_edge_bboxes, get_right_edge_bboxes
from magic_pdf.libs.boxbase import get_bbox_in_boundry
from magic_pdf.layout.bbox_sort import (CONTENT_IDX, CONTENT_TYPE_IDX,
X0_EXT_IDX, X0_IDX, X1_EXT_IDX, X1_IDX,
Y0_EXT_IDX, Y0_IDX, Y1_EXT_IDX, Y1_IDX,
paper_bbox_sort)
from magic_pdf.layout.layout_det_utils import (
find_all_bottom_bbox_direct, find_all_left_bbox_direct,
find_all_right_bbox_direct, find_all_top_bbox_direct,
find_bottom_bbox_direct_from_left_edge,
find_bottom_bbox_direct_from_right_edge,
find_top_bbox_direct_from_left_edge, find_top_bbox_direct_from_right_edge,
get_left_edge_bboxes, get_right_edge_bboxes)
from magic_pdf.libs.boxbase import get_bbox_in_boundary
LAYOUT_V = 'V'
LAYOUT_H = 'H'
LAYOUT_UNPROC = 'U'
LAYOUT_BAD = 'B'
LAYOUT_V = "V"
LAYOUT_H = "H"
LAYOUT_UNPROC = "U"
LAYOUT_BAD = "B"
def _is_single_line_text(bbox):
"""
检查bbox里面的文字是否只有一行
"""
"""检查bbox里面的文字是否只有一行."""
return True # TODO
box_type = bbox[CONTENT_TYPE_IDX]
if box_type != 'text':
return False
paras = bbox[CONTENT_IDX]["paras"]
text_content = ""
paras = bbox[CONTENT_IDX]['paras']
text_content = ''
for para_id, para in paras.items(): # 拼装内部的段落文本
is_title = para['is_title']
if is_title!=0:
if is_title != 0:
text_content += f"## {para['text']}"
else:
text_content += para["text"]
text_content += "\n\n"
text_content += para['text']
text_content += '\n\n'
return bbox[CONTENT_TYPE_IDX] == 'text' and len(text_content.split("\n\n")) <= 1
return bbox[CONTENT_TYPE_IDX] == 'text' and len(text_content.split('\n\n')) <= 1
def _horizontal_split(bboxes:list, boundry:tuple, avg_font_size=20)-> list:
def _horizontal_split(bboxes: list, boundary: tuple, avg_font_size=20) -> list:
"""
对bboxes进行水平切割
方法是:找到左侧和右侧都没有被直接遮挡的box,然后进行扩展,之后进行切割
......@@ -43,14 +49,14 @@ def _horizontal_split(bboxes:list, boundry:tuple, avg_font_size=20)-> list:
"""
sorted_layout_blocks = [] # 这是要最终返回的值
bound_x0, bound_y0, bound_x1, bound_y1 = boundry
all_bboxes = get_bbox_in_boundry(bboxes, boundry)
#all_bboxes = paper_bbox_sort(all_bboxes, abs(bound_x1-bound_x0), abs(bound_y1-bound_x0)) # 大致拍下序, 这个是基于直接遮挡的。
bound_x0, bound_y0, bound_x1, bound_y1 = boundary
all_bboxes = get_bbox_in_boundary(bboxes, boundary)
# all_bboxes = paper_bbox_sort(all_bboxes, abs(bound_x1-bound_x0), abs(bound_y1-bound_x0)) # 大致拍下序, 这个是基于直接遮挡的。
"""
首先在水平方向上扩展独占一行的bbox
"""
last_h_split_line_y1 = bound_y0 #记录下上次的水平分割线
last_h_split_line_y1 = bound_y0 # 记录下上次的水平分割线
for i, bbox in enumerate(all_bboxes):
left_nearest_bbox = find_all_left_bbox_direct(bbox, all_bboxes) # 非扩展线
right_nearest_bbox = find_all_right_bbox_direct(bbox, all_bboxes)
......@@ -62,16 +68,20 @@ def _horizontal_split(bboxes:list, boundry:tuple, avg_font_size=20)-> list:
3. TODO 加强条件:这个bbox上方和下方是同一列column,那么就不能算作独占一行
"""
# 先检查这个bbox里是否只包含一行文字
is_single_line = _is_single_line_text(bbox)
# is_single_line = _is_single_line_text(bbox)
"""
这里有个点需要注意,当页面内容不是居中的时候,第一次调用传递的是page的boundry,这个时候mid_x就不是中心线了.
所以这里计算出最紧致的boundry,然后再计算mid_x
这里有个点需要注意,当页面内容不是居中的时候,第一次调用传递的是page的boundary,这个时候mid_x就不是中心线了.
所以这里计算出最紧致的boundary,然后再计算mid_x
"""
boundry_real_x0, boundry_real_x1 = min([bbox[X0_IDX] for bbox in all_bboxes]), max([bbox[X1_IDX] for bbox in all_bboxes])
mid_x = (boundry_real_x0+boundry_real_x1)/2
boundary_real_x0, boundary_real_x1 = min(
[bbox[X0_IDX] for bbox in all_bboxes]
), max([bbox[X1_IDX] for bbox in all_bboxes])
mid_x = (boundary_real_x0 + boundary_real_x1) / 2
# 检查这个box是否内容在中心线有交
# 必须跨过去2个字符的宽度
is_cross_boundry_mid_line = min(mid_x-bbox[X0_IDX], bbox[X1_IDX]-mid_x) > avg_font_size*2
is_cross_boundary_mid_line = (
min(mid_x - bbox[X0_IDX], bbox[X1_IDX] - mid_x) > avg_font_size * 2
)
"""
检查条件2
"""
......@@ -84,11 +94,11 @@ def _horizontal_split(bboxes:list, boundry:tuple, avg_font_size=20)-> list:
"""
以迭代的方式向上找,查找范围是[bound_x0, last_h_sp, bound_x1, bbox[Y0_IDX]]
"""
#先确定上方的y0, y0
# 先确定上方的y0, y0
b_y0, b_y1 = last_h_split_line_y1, bbox[Y0_IDX]
#然后从box开始逐个向上找到所有与box在x上有交集的box
# 然后从box开始逐个向上找到所有与box在x上有交集的box
box_to_check = [bound_x0, b_y0, bound_x1, b_y1]
bbox_in_bound_check = get_bbox_in_boundry(all_bboxes, box_to_check)
bbox_in_bound_check = get_bbox_in_boundary(all_bboxes, box_to_check)
bboxes_on_top = []
virtual_box = bbox
......@@ -96,33 +106,61 @@ def _horizontal_split(bboxes:list, boundry:tuple, avg_font_size=20)-> list:
b_on_top = find_all_top_bbox_direct(virtual_box, bbox_in_bound_check)
if b_on_top is not None:
bboxes_on_top.append(b_on_top)
virtual_box = [min([virtual_box[X0_IDX], b_on_top[X0_IDX]]), min(virtual_box[Y0_IDX], b_on_top[Y0_IDX]), max([virtual_box[X1_IDX], b_on_top[X1_IDX]]), b_y1]
virtual_box = [
min([virtual_box[X0_IDX], b_on_top[X0_IDX]]),
min(virtual_box[Y0_IDX], b_on_top[Y0_IDX]),
max([virtual_box[X1_IDX], b_on_top[X1_IDX]]),
b_y1,
]
else:
break
# 随后确定这些box的最小x0, 最大x1
if len(bboxes_on_top)>0 and len(bboxes_on_top) != len(bbox_in_bound_check):# virtual_box可能会膨胀到占满整个区域,这实际上就不能属于一个col了。
if len(bboxes_on_top) > 0 and len(bboxes_on_top) != len(
bbox_in_bound_check
): # virtual_box可能会膨胀到占满整个区域,这实际上就不能属于一个col了。
min_x0, max_x1 = virtual_box[X0_IDX], virtual_box[X1_IDX]
# 然后采用一种比较粗糙的方法,看min_x0,max_x1是否与位于[bound_x0, last_h_sp, bound_x1, bbox[Y0_IDX]]之间的box有相交
if not any([b[X0_IDX] <= min_x0-1 <= b[X1_IDX] or b[X0_IDX] <= max_x1+1 <= b[X1_IDX] for b in bbox_in_bound_check]):
if not any(
[
b[X0_IDX] <= min_x0 - 1 <= b[X1_IDX]
or b[X0_IDX] <= max_x1 + 1 <= b[X1_IDX]
for b in bbox_in_bound_check
]
):
# 其上,下都不能被扩展成行,暂时只检查一下上方 TODO
top_nearest_bbox = find_all_top_bbox_direct(bbox, bboxes)
bottom_nearest_bbox = find_all_bottom_bbox_direct(bbox, bboxes)
if not any([
top_nearest_bbox is not None and (find_all_left_bbox_direct(top_nearest_bbox, bboxes) is None and find_all_right_bbox_direct(top_nearest_bbox, bboxes) is None),
bottom_nearest_bbox is not None and (find_all_left_bbox_direct(bottom_nearest_bbox, bboxes) is None and find_all_right_bbox_direct(bottom_nearest_bbox, bboxes) is None),
top_nearest_bbox is None or bottom_nearest_bbox is None
]):
if not any(
[
top_nearest_bbox is not None
and (
find_all_left_bbox_direct(top_nearest_bbox, bboxes)
is None
and find_all_right_bbox_direct(top_nearest_bbox, bboxes)
is None
),
bottom_nearest_bbox is not None
and (
find_all_left_bbox_direct(bottom_nearest_bbox, bboxes)
is None
and find_all_right_bbox_direct(
bottom_nearest_bbox, bboxes
)
is None
),
top_nearest_bbox is None or bottom_nearest_bbox is None,
]
):
is_belong_to_col = True
# 检查是否能被下方col吸收 TODO
"""
这里为什么没有is_cross_boundry_mid_line的条件呢?
这里为什么没有is_cross_boundary_mid_line的条件呢?
确实有些杂志左右两栏宽度不是对称的。
"""
if not is_belong_to_col or is_cross_boundry_mid_line:
if not is_belong_to_col or is_cross_boundary_mid_line:
bbox[X0_EXT_IDX] = bound_x0
bbox[Y0_EXT_IDX] = bbox[Y0_IDX]
bbox[X1_EXT_IDX] = bound_x1
......@@ -142,13 +180,12 @@ def _horizontal_split(bboxes:list, boundry:tuple, avg_font_size=20)-> list:
if bbox[X0_EXT_IDX] == bound_x0 and bbox[X1_EXT_IDX] == bound_x1:
h_bbox_group.append(bbox)
else:
if len(h_bbox_group)>0:
if len(h_bbox_group) > 0:
h_bboxes.append(h_bbox_group)
h_bbox_group = []
# 最后一个group
if len(h_bbox_group)>0:
if len(h_bbox_group) > 0:
h_bboxes.append(h_bbox_group)
"""
现在h_bboxes里面是所有的group了,每个group都是一个list
对h_bboxes里的每个group进行计算放回到sorted_layouts里
......@@ -157,9 +194,13 @@ def _horizontal_split(bboxes:list, boundry:tuple, avg_font_size=20)-> list:
for gp in h_bboxes:
gp.sort(key=lambda x: x[Y0_IDX])
# 然后计算这个group的layout_bbox,也就是最小的x0,y0, 最大的x1,y1
x0, y0, x1, y1 = gp[0][X0_EXT_IDX], gp[0][Y0_EXT_IDX], gp[-1][X1_EXT_IDX], gp[-1][Y1_EXT_IDX]
x0, y0, x1, y1 = (
gp[0][X0_EXT_IDX],
gp[0][Y0_EXT_IDX],
gp[-1][X1_EXT_IDX],
gp[-1][Y1_EXT_IDX],
)
h_layouts.append([x0, y0, x1, y1, LAYOUT_H]) # 水平的布局
"""
接下来利用这些连续的水平bbox的layout_bbox的y0, y1,从水平上切分开其余的为几个部分
"""
......@@ -172,37 +213,48 @@ def _horizontal_split(bboxes:list, boundry:tuple, avg_font_size=20)-> list:
unsplited_bboxes = []
for i in range(0, len(h_split_lines), 2):
start_y0, start_y1 = h_split_lines[i:i+2]
start_y0, start_y1 = h_split_lines[i : i + 2]
# 然后找出[start_y0, start_y1]之间的其他bbox,这些组成一个未分割板块
bboxes_in_block = [bbox for bbox in all_bboxes if bbox[Y0_IDX]>=start_y0 and bbox[Y1_IDX]<=start_y1]
bboxes_in_block = [
bbox
for bbox in all_bboxes
if bbox[Y0_IDX] >= start_y0 and bbox[Y1_IDX] <= start_y1
]
unsplited_bboxes.append(bboxes_in_block)
# 接着把未处理的加入到h_layouts里
for bboxes_in_block in unsplited_bboxes:
if len(bboxes_in_block) == 0:
continue
x0, y0, x1, y1 = bound_x0, min([bbox[Y0_IDX] for bbox in bboxes_in_block]), bound_x1, max([bbox[Y1_IDX] for bbox in bboxes_in_block])
x0, y0, x1, y1 = (
bound_x0,
min([bbox[Y0_IDX] for bbox in bboxes_in_block]),
bound_x1,
max([bbox[Y1_IDX] for bbox in bboxes_in_block]),
)
h_layouts.append([x0, y0, x1, y1, LAYOUT_UNPROC])
h_layouts.sort(key=lambda x: x[1]) # 按照y0排序, 也就是从上到下的顺序
"""
转换成如下格式返回
"""
for layout in h_layouts:
sorted_layout_blocks.append({
"layout_bbox": layout[:4],
"layout_label":layout[4],
"sub_layout":[],
})
sorted_layout_blocks.append(
{
'layout_bbox': layout[:4],
'layout_label': layout[4],
'sub_layout': [],
}
)
return sorted_layout_blocks
###############################################################################################
#
# 垂直方向的处理
#
#
###############################################################################################
def _vertical_align_split_v1(bboxes:list, boundry:tuple)-> list:
def _vertical_align_split_v1(bboxes: list, boundary: tuple) -> list:
"""
计算垂直方向上的对齐, 并分割bboxes成layout。负责对一列多行的进行列维度分割。
如果不能完全分割,剩余部分作为layout_lable为u的layout返回
......@@ -215,84 +267,123 @@ def _vertical_align_split_v1(bboxes:list, boundry:tuple)-> list:
此函数会将:以上布局将会切分出来2列
"""
sorted_layout_blocks = [] # 这是要最终返回的值
new_boundry = [boundry[0], boundry[1], boundry[2], boundry[3]]
new_boundary = [boundary[0], boundary[1], boundary[2], boundary[3]]
v_blocks = []
"""
先从左到右切分
"""
while True:
all_bboxes = get_bbox_in_boundry(bboxes, new_boundry)
all_bboxes = get_bbox_in_boundary(bboxes, new_boundary)
left_edge_bboxes = get_left_edge_bboxes(all_bboxes)
if len(left_edge_bboxes) == 0:
break
right_split_line_x1 = max([bbox[X1_IDX] for bbox in left_edge_bboxes])+1
right_split_line_x1 = max([bbox[X1_IDX] for bbox in left_edge_bboxes]) + 1
# 然后检查这条线能不与其他bbox的左边界相交或者重合
if any([bbox[X0_IDX] <= right_split_line_x1 <= bbox[X1_IDX] for bbox in all_bboxes]):
if any(
[bbox[X0_IDX] <= right_split_line_x1 <= bbox[X1_IDX] for bbox in all_bboxes]
):
# 垂直切分线与某些box发生相交,说明无法完全垂直方向切分。
break
else: # 说明成功分割出一列
# 找到左侧边界最靠左的bbox作为layout的x0
layout_x0 = min([bbox[X0_IDX] for bbox in left_edge_bboxes]) # 这里主要是为了画出来有一定间距
v_blocks.append([layout_x0, new_boundry[1], right_split_line_x1, new_boundry[3], LAYOUT_V])
new_boundry[0] = right_split_line_x1 # 更新边界
layout_x0 = min(
[bbox[X0_IDX] for bbox in left_edge_bboxes]
) # 这里主要是为了画出来有一定间距
v_blocks.append(
[
layout_x0,
new_boundary[1],
right_split_line_x1,
new_boundary[3],
LAYOUT_V,
]
)
new_boundary[0] = right_split_line_x1 # 更新边界
"""
再从右到左切, 此时如果还是无法完全切分,那么剩余部分作为layout_lable为u的layout返回
"""
unsplited_block = []
while True:
all_bboxes = get_bbox_in_boundry(bboxes, new_boundry)
all_bboxes = get_bbox_in_boundary(bboxes, new_boundary)
right_edge_bboxes = get_right_edge_bboxes(all_bboxes)
if len(right_edge_bboxes) == 0:
break
left_split_line_x0 = min([bbox[X0_IDX] for bbox in right_edge_bboxes])-1
left_split_line_x0 = min([bbox[X0_IDX] for bbox in right_edge_bboxes]) - 1
# 然后检查这条线能不与其他bbox的左边界相交或者重合
if any([bbox[X0_IDX] <= left_split_line_x0 <= bbox[X1_IDX] for bbox in all_bboxes]):
if any(
[bbox[X0_IDX] <= left_split_line_x0 <= bbox[X1_IDX] for bbox in all_bboxes]
):
# 这里是余下的
unsplited_block.append([new_boundry[0], new_boundry[1], new_boundry[2], new_boundry[3], LAYOUT_UNPROC])
unsplited_block.append(
[
new_boundary[0],
new_boundary[1],
new_boundary[2],
new_boundary[3],
LAYOUT_UNPROC,
]
)
break
else:
# 找到右侧边界最靠右的bbox作为layout的x1
layout_x1 = max([bbox[X1_IDX] for bbox in right_edge_bboxes])
v_blocks.append([left_split_line_x0, new_boundry[1], layout_x1, new_boundry[3], LAYOUT_V])
new_boundry[2] = left_split_line_x0 # 更新右边界
v_blocks.append(
[
left_split_line_x0,
new_boundary[1],
layout_x1,
new_boundary[3],
LAYOUT_V,
]
)
new_boundary[2] = left_split_line_x0 # 更新右边界
"""
最后拼装成layout格式返回
"""
for block in v_blocks:
sorted_layout_blocks.append({
"layout_bbox": block[:4],
"layout_label":block[4],
"sub_layout":[],
})
sorted_layout_blocks.append(
{
'layout_bbox': block[:4],
'layout_label': block[4],
'sub_layout': [],
}
)
for block in unsplited_block:
sorted_layout_blocks.append({
"layout_bbox": block[:4],
"layout_label":block[4],
"sub_layout":[],
})
sorted_layout_blocks.append(
{
'layout_bbox': block[:4],
'layout_label': block[4],
'sub_layout': [],
}
)
# 按照x0排序
sorted_layout_blocks.sort(key=lambda x: x['layout_bbox'][0])
return sorted_layout_blocks
def _vertical_align_split_v2(bboxes:list, boundry:tuple)-> list:
"""
改进的 _vertical_align_split算法,原算法会因为第二列的box由于左侧没有遮挡被认为是左侧的一部分,导致整个layout多列被识别为一列。
利用从左上角的box开始向下看的方法,不断扩展w_x0, w_x1,直到不能继续向下扩展,或者到达边界下边界
"""
def _vertical_align_split_v2(bboxes: list, boundary: tuple) -> list:
"""改进的
_vertical_align_split算法,原算法会因为第二列的box由于左侧没有遮挡被认为是左侧的一部分,导致整个layout多列被识别为一列
利用从左上角的box开始向下看的方法,不断扩展w_x0, w_x1,直到不能继续向下扩展,或者到达边界下边界。"""
sorted_layout_blocks = [] # 这是要最终返回的值
new_boundry = [boundry[0], boundry[1], boundry[2], boundry[3]]
new_boundary = [boundary[0], boundary[1], boundary[2], boundary[3]]
bad_boxes = [] # 被割中的box
v_blocks = []
while True:
all_bboxes = get_bbox_in_boundry(bboxes, new_boundry)
all_bboxes = get_bbox_in_boundary(bboxes, new_boundary)
if len(all_bboxes) == 0:
break
left_top_box = min(all_bboxes, key=lambda x: (x[X0_IDX],x[Y0_IDX]))# 这里应该加强,检查一下必须是在第一列的 TODO
start_box = [left_top_box[X0_IDX], left_top_box[Y0_IDX], left_top_box[X1_IDX], left_top_box[Y1_IDX]]
left_top_box = min(
all_bboxes, key=lambda x: (x[X0_IDX], x[Y0_IDX])
) # 这里应该加强,检查一下必须是在第一列的 TODO
start_box = [
left_top_box[X0_IDX],
left_top_box[Y0_IDX],
left_top_box[X1_IDX],
left_top_box[Y1_IDX],
]
w_x0, w_x1 = left_top_box[X0_IDX], left_top_box[X1_IDX]
"""
然后沿着这个box线向下找最近的那个box, 然后扩展w_x0, w_x1
......@@ -303,96 +394,138 @@ def _vertical_align_split_v2(bboxes:list, boundry:tuple)-> list:
"""
while left_top_box is not None: # 向下去找
virtual_box = [w_x0, left_top_box[Y0_IDX], w_x1, left_top_box[Y1_IDX]]
left_top_box = find_bottom_bbox_direct_from_left_edge(virtual_box, all_bboxes)
left_top_box = find_bottom_bbox_direct_from_left_edge(
virtual_box, all_bboxes
)
if left_top_box:
w_x0, w_x1 = min(virtual_box[X0_IDX], left_top_box[X0_IDX]), max([virtual_box[X1_IDX], left_top_box[X1_IDX]])
w_x0, w_x1 = min(virtual_box[X0_IDX], left_top_box[X0_IDX]), max(
[virtual_box[X1_IDX], left_top_box[X1_IDX]]
)
# 万一这个初始的box在column中间,那么还要向上看
start_box = [w_x0, start_box[Y0_IDX], w_x1, start_box[Y1_IDX]] # 扩展一下宽度更鲁棒
start_box = [
w_x0,
start_box[Y0_IDX],
w_x1,
start_box[Y1_IDX],
] # 扩展一下宽度更鲁棒
left_top_box = find_top_bbox_direct_from_left_edge(start_box, all_bboxes)
while left_top_box is not None: # 向上去找
virtual_box = [w_x0, left_top_box[Y0_IDX], w_x1, left_top_box[Y1_IDX]]
left_top_box = find_top_bbox_direct_from_left_edge(virtual_box, all_bboxes)
if left_top_box:
w_x0, w_x1 = min(virtual_box[X0_IDX], left_top_box[X0_IDX]), max([virtual_box[X1_IDX], left_top_box[X1_IDX]])
w_x0, w_x1 = min(virtual_box[X0_IDX], left_top_box[X0_IDX]), max(
[virtual_box[X1_IDX], left_top_box[X1_IDX]]
)
# 检查相交
if any([bbox[X0_IDX] <= w_x1+1 <= bbox[X1_IDX] for bbox in all_bboxes]):
if any([bbox[X0_IDX] <= w_x1 + 1 <= bbox[X1_IDX] for bbox in all_bboxes]):
for b in all_bboxes:
if b[X0_IDX] <= w_x1+1 <= b[X1_IDX]:
if b[X0_IDX] <= w_x1 + 1 <= b[X1_IDX]:
bad_boxes.append([b[X0_IDX], b[Y0_IDX], b[X1_IDX], b[Y1_IDX]])
break
else: # 说明成功分割出一列
v_blocks.append([w_x0, new_boundry[1], w_x1, new_boundry[3], LAYOUT_V])
new_boundry[0] = w_x1 # 更新边界
v_blocks.append([w_x0, new_boundary[1], w_x1, new_boundary[3], LAYOUT_V])
new_boundary[0] = w_x1 # 更新边界
"""
接着开始从右上角的box扫描
"""
w_x0 , w_x1 = 0, 0
w_x0, w_x1 = 0, 0
unsplited_block = []
while True:
all_bboxes = get_bbox_in_boundry(bboxes, new_boundry)
all_bboxes = get_bbox_in_boundary(bboxes, new_boundary)
if len(all_bboxes) == 0:
break
# 先找到X1最大的
bbox_list_sorted = sorted(all_bboxes, key=lambda bbox: bbox[X1_IDX], reverse=True)
bbox_list_sorted = sorted(
all_bboxes, key=lambda bbox: bbox[X1_IDX], reverse=True
)
# Then, find the boxes with the smallest Y0 value
bigest_x1 = bbox_list_sorted[0][X1_IDX]
boxes_with_bigest_x1 = [bbox for bbox in bbox_list_sorted if bbox[X1_IDX] == bigest_x1] # 也就是最靠右的那些
right_top_box = min(boxes_with_bigest_x1, key=lambda bbox: bbox[Y0_IDX]) # y0最小的那个
start_box = [right_top_box[X0_IDX], right_top_box[Y0_IDX], right_top_box[X1_IDX], right_top_box[Y1_IDX]]
boxes_with_bigest_x1 = [
bbox for bbox in bbox_list_sorted if bbox[X1_IDX] == bigest_x1
] # 也就是最靠右的那些
right_top_box = min(
boxes_with_bigest_x1, key=lambda bbox: bbox[Y0_IDX]
) # y0最小的那个
start_box = [
right_top_box[X0_IDX],
right_top_box[Y0_IDX],
right_top_box[X1_IDX],
right_top_box[Y1_IDX],
]
w_x0, w_x1 = right_top_box[X0_IDX], right_top_box[X1_IDX]
while right_top_box is not None:
virtual_box = [w_x0, right_top_box[Y0_IDX], w_x1, right_top_box[Y1_IDX]]
right_top_box = find_bottom_bbox_direct_from_right_edge(virtual_box, all_bboxes)
right_top_box = find_bottom_bbox_direct_from_right_edge(
virtual_box, all_bboxes
)
if right_top_box:
w_x0, w_x1 = min([w_x0, right_top_box[X0_IDX]]), max([w_x1, right_top_box[X1_IDX]])
w_x0, w_x1 = min([w_x0, right_top_box[X0_IDX]]), max(
[w_x1, right_top_box[X1_IDX]]
)
# 在向上扫描
start_box = [w_x0, start_box[Y0_IDX], w_x1, start_box[Y1_IDX]] # 扩展一下宽度更鲁棒
start_box = [
w_x0,
start_box[Y0_IDX],
w_x1,
start_box[Y1_IDX],
] # 扩展一下宽度更鲁棒
right_top_box = find_top_bbox_direct_from_right_edge(start_box, all_bboxes)
while right_top_box is not None:
virtual_box = [w_x0, right_top_box[Y0_IDX], w_x1, right_top_box[Y1_IDX]]
right_top_box = find_top_bbox_direct_from_right_edge(virtual_box, all_bboxes)
right_top_box = find_top_bbox_direct_from_right_edge(
virtual_box, all_bboxes
)
if right_top_box:
w_x0, w_x1 = min([w_x0, right_top_box[X0_IDX]]), max([w_x1, right_top_box[X1_IDX]])
w_x0, w_x1 = min([w_x0, right_top_box[X0_IDX]]), max(
[w_x1, right_top_box[X1_IDX]]
)
# 检查是否与其他box相交, 垂直切分线与某些box发生相交,说明无法完全垂直方向切分。
if any([bbox[X0_IDX] <= w_x0-1 <= bbox[X1_IDX] for bbox in all_bboxes]):
unsplited_block.append([new_boundry[0], new_boundry[1], new_boundry[2], new_boundry[3], LAYOUT_UNPROC])
if any([bbox[X0_IDX] <= w_x0 - 1 <= bbox[X1_IDX] for bbox in all_bboxes]):
unsplited_block.append(
[
new_boundary[0],
new_boundary[1],
new_boundary[2],
new_boundary[3],
LAYOUT_UNPROC,
]
)
for b in all_bboxes:
if b[X0_IDX] <= w_x0-1 <= b[X1_IDX]:
if b[X0_IDX] <= w_x0 - 1 <= b[X1_IDX]:
bad_boxes.append([b[X0_IDX], b[Y0_IDX], b[X1_IDX], b[Y1_IDX]])
break
else: # 说明成功分割出一列
v_blocks.append([w_x0, new_boundry[1], w_x1, new_boundry[3], LAYOUT_V])
new_boundry[2] = w_x0
v_blocks.append([w_x0, new_boundary[1], w_x1, new_boundary[3], LAYOUT_V])
new_boundary[2] = w_x0
"""转换数据结构"""
for block in v_blocks:
sorted_layout_blocks.append({
"layout_bbox": block[:4],
"layout_label":block[4],
"sub_layout":[],
})
sorted_layout_blocks.append(
{
'layout_bbox': block[:4],
'layout_label': block[4],
'sub_layout': [],
}
)
for block in unsplited_block:
sorted_layout_blocks.append({
"layout_bbox": block[:4],
"layout_label":block[4],
"sub_layout":[],
"bad_boxes": bad_boxes # 记录下来,这个box是被割中的
})
sorted_layout_blocks.append(
{
'layout_bbox': block[:4],
'layout_label': block[4],
'sub_layout': [],
'bad_boxes': bad_boxes, # 记录下来,这个box是被割中的
}
)
# 按照x0排序
sorted_layout_blocks.sort(key=lambda x: x['layout_bbox'][0])
return sorted_layout_blocks
def _try_horizontal_mult_column_split(bboxes:list, boundry:tuple)-> list:
def _try_horizontal_mult_column_split(bboxes: list, boundary: tuple) -> list:
"""
尝试水平切分,如果切分不动,那就当一个BAD_LAYOUT返回
------------------
......@@ -406,9 +539,7 @@ def _try_horizontal_mult_column_split(bboxes:list, boundry:tuple)-> list:
pass
def _vertical_split(bboxes:list, boundry:tuple)-> list:
def _vertical_split(bboxes: list, boundary: tuple) -> list:
"""
从垂直方向进行切割,分block
这个版本里,如果垂直切分不动,那就当一个BAD_LAYOUT返回
......@@ -425,8 +556,8 @@ def _vertical_split(bboxes:list, boundry:tuple)-> list:
"""
sorted_layout_blocks = [] # 这是要最终返回的值
bound_x0, bound_y0, bound_x1, bound_y1 = boundry
all_bboxes = get_bbox_in_boundry(bboxes, boundry)
bound_x0, bound_y0, bound_x1, bound_y1 = boundary
all_bboxes = get_bbox_in_boundary(bboxes, boundary)
"""
all_bboxes = fix_vertical_bbox_pos(all_bboxes) # 垂直方向解覆盖
all_bboxes = fix_hor_bbox_pos(all_bboxes) # 水平解覆盖
......@@ -436,7 +567,7 @@ def _vertical_split(bboxes:list, boundry:tuple)-> list:
如果遇到互相重叠的bbox, 那么会把面积较小的box进行压缩,从而避免重叠。对布局切分来说带来正反馈。
"""
#all_bboxes = paper_bbox_sort(all_bboxes, abs(bound_x1-bound_x0), abs(bound_y1-bound_x0)) # 大致拍下序, 这个是基于直接遮挡的。
# all_bboxes = paper_bbox_sort(all_bboxes, abs(bound_x1-bound_x0), abs(bound_y1-bound_x0)) # 大致拍下序, 这个是基于直接遮挡的。
"""
首先在垂直方向上扩展独占一行的bbox
......@@ -444,12 +575,21 @@ def _vertical_split(bboxes:list, boundry:tuple)-> list:
for bbox in all_bboxes:
top_nearest_bbox = find_all_top_bbox_direct(bbox, all_bboxes) # 非扩展线
bottom_nearest_bbox = find_all_bottom_bbox_direct(bbox, all_bboxes)
if top_nearest_bbox is None and bottom_nearest_bbox is None and not any([b[X0_IDX]<bbox[X1_IDX]<b[X1_IDX] or b[X0_IDX]<bbox[X0_IDX]<b[X1_IDX] for b in all_bboxes]): # 独占一列, 且不和其他重叠
if (
top_nearest_bbox is None
and bottom_nearest_bbox is None
and not any(
[
b[X0_IDX] < bbox[X1_IDX] < b[X1_IDX]
or b[X0_IDX] < bbox[X0_IDX] < b[X1_IDX]
for b in all_bboxes
]
)
): # 独占一列, 且不和其他重叠
bbox[X0_EXT_IDX] = bbox[X0_IDX]
bbox[Y0_EXT_IDX] = bound_y0
bbox[X1_EXT_IDX] = bbox[X1_IDX]
bbox[Y1_EXT_IDX] = bound_y1
"""
此时独占一列的被成功扩展到指定的边界上,这个时候利用边界条件合并连续的bbox,成为一个group
然后合并所有连续垂直方向的bbox.
......@@ -460,18 +600,21 @@ def _vertical_split(bboxes:list, boundry:tuple)-> list:
for box in all_bboxes:
if box[Y0_EXT_IDX] == bound_y0 and box[Y1_EXT_IDX] == bound_y1:
v_bboxes.append(box)
"""
现在v_bboxes里面是所有的group了,每个group都是一个list
对v_bboxes里的每个group进行计算放回到sorted_layouts里
"""
v_layouts = []
for vbox in v_bboxes:
#gp.sort(key=lambda x: x[X0_IDX])
# gp.sort(key=lambda x: x[X0_IDX])
# 然后计算这个group的layout_bbox,也就是最小的x0,y0, 最大的x1,y1
x0, y0, x1, y1 = vbox[X0_EXT_IDX], vbox[Y0_EXT_IDX], vbox[X1_EXT_IDX], vbox[Y1_EXT_IDX]
x0, y0, x1, y1 = (
vbox[X0_EXT_IDX],
vbox[Y0_EXT_IDX],
vbox[X1_EXT_IDX],
vbox[Y1_EXT_IDX],
)
v_layouts.append([x0, y0, x1, y1, LAYOUT_V]) # 垂直的布局
"""
接下来利用这些连续的垂直bbox的layout_bbox的x0, x1,从垂直上切分开其余的为几个部分
"""
......@@ -484,26 +627,38 @@ def _vertical_split(bboxes:list, boundry:tuple)-> list:
unsplited_bboxes = []
for i in range(0, len(v_split_lines), 2):
start_x0, start_x1 = v_split_lines[i:i+2]
start_x0, start_x1 = v_split_lines[i : i + 2]
# 然后找出[start_x0, start_x1]之间的其他bbox,这些组成一个未分割板块
bboxes_in_block = [bbox for bbox in all_bboxes if bbox[X0_IDX]>=start_x0 and bbox[X1_IDX]<=start_x1]
bboxes_in_block = [
bbox
for bbox in all_bboxes
if bbox[X0_IDX] >= start_x0 and bbox[X1_IDX] <= start_x1
]
unsplited_bboxes.append(bboxes_in_block)
# 接着把未处理的加入到v_layouts里
for bboxes_in_block in unsplited_bboxes:
if len(bboxes_in_block) == 0:
continue
x0, y0, x1, y1 = min([bbox[X0_IDX] for bbox in bboxes_in_block]), bound_y0, max([bbox[X1_IDX] for bbox in bboxes_in_block]), bound_y1
v_layouts.append([x0, y0, x1, y1, LAYOUT_UNPROC]) # 说明这篇区域未能够分析出可靠的版面
x0, y0, x1, y1 = (
min([bbox[X0_IDX] for bbox in bboxes_in_block]),
bound_y0,
max([bbox[X1_IDX] for bbox in bboxes_in_block]),
bound_y1,
)
v_layouts.append(
[x0, y0, x1, y1, LAYOUT_UNPROC]
) # 说明这篇区域未能够分析出可靠的版面
v_layouts.sort(key=lambda x: x[0]) # 按照x0排序, 也就是从左到右的顺序
for layout in v_layouts:
sorted_layout_blocks.append({
"layout_bbox": layout[:4],
"layout_label":layout[4],
"sub_layout":[],
})
sorted_layout_blocks.append(
{
'layout_bbox': layout[:4],
'layout_label': layout[4],
'sub_layout': [],
}
)
"""
至此,垂直方向切成了2种类型,其一是独占一列的,其二是未处理的。
下面对这些未处理的进行垂直方向切分,这个切分要切出来类似“吕”这种类型的垂直方向的布局
......@@ -513,24 +668,32 @@ def _vertical_split(bboxes:list, boundry:tuple)-> list:
x0, y0, x1, y1 = layout['layout_bbox']
v_split_layouts = _vertical_align_split_v2(bboxes, [x0, y0, x1, y1])
sorted_layout_blocks[i] = {
"layout_bbox": [x0, y0, x1, y1],
"layout_label": LAYOUT_H,
"sub_layout": v_split_layouts
'layout_bbox': [x0, y0, x1, y1],
'layout_label': LAYOUT_H,
'sub_layout': v_split_layouts,
}
layout['layout_label'] = LAYOUT_H # 被垂线切分成了水平布局
return sorted_layout_blocks
def split_layout(bboxes:list, boundry:tuple, page_num:int)-> list:
def split_layout(bboxes: list, boundary: tuple, page_num: int) -> list:
"""
把bboxes切割成layout
return:
[
{
"layout_bbox": [x0, y0, x1, y1],
"layout_bbox": [x0,y0,x1,y1],
"layout_label":"u|v|h|b", 未处理|垂直|水平|BAD_LAYOUT
"sub_layout": [] #每个元素都是[x0, y0, x1, y1, block_content, idx_x, idx_y, content_type, ext_x0, ext_y0, ext_x1, ext_y1], 并且顺序就是阅读顺序
"sub_layout":[] #每个元素都是[
x0,y0,
x1,y1,
block_content,
idx_x,idx_y,
content_type,
ext_x0,ext_y0,
ext_x1,ext_y1
], 并且顺序就是阅读顺序
}
]
example:
......@@ -567,27 +730,27 @@ def split_layout(bboxes:list, boundry:tuple, page_num:int)-> list:
"""
sorted_layouts = [] # 最终返回的结果
boundry_x0, boundry_y0, boundry_x1, boundry_y1 = boundry
if len(bboxes) <=1:
boundary_x0, boundary_y0, boundary_x1, boundary_y1 = boundary
if len(bboxes) <= 1:
return [
{
"layout_bbox": [boundry_x0, boundry_y0, boundry_x1, boundry_y1],
"layout_label": LAYOUT_V,
"sub_layout":[]
'layout_bbox': [boundary_x0, boundary_y0, boundary_x1, boundary_y1],
'layout_label': LAYOUT_V,
'sub_layout': [],
}
]
"""
接下来按照先水平后垂直的顺序进行切分
"""
bboxes = paper_bbox_sort(bboxes, boundry_x1-boundry_x0, boundry_y1-boundry_y0)
sorted_layouts = _horizontal_split(bboxes, boundry) # 通过水平分割出来的layout
bboxes = paper_bbox_sort(
bboxes, boundary_x1 - boundary_x0, boundary_y1 - boundary_y0
)
sorted_layouts = _horizontal_split(bboxes, boundary) # 通过水平分割出来的layout
for i, layout in enumerate(sorted_layouts):
x0, y0, x1, y1 = layout['layout_bbox']
layout_type = layout['layout_label']
if layout_type == LAYOUT_UNPROC: # 说明是非独占单行的,这些需要垂直切分
v_split_layouts = _vertical_split(bboxes, [x0, y0, x1, y1])
"""
最后这里有个逻辑问题:如果这个函数只分离出来了一个column layout,那么这个layout分割肯定超出了算法能力范围。因为我们假定的是传进来的
box已经把行全部剥离了,所以这里必须十多个列才可以。如果只剥离出来一个layout,并且是多个box,那么就说明这个layout是无法分割的,标记为LAYOUT_UNPROC
......@@ -596,18 +759,16 @@ def split_layout(bboxes:list, boundry:tuple, page_num:int)-> list:
if len(v_split_layouts) == 1:
if len(v_split_layouts[0]['sub_layout']) == 0:
layout_label = LAYOUT_UNPROC
#logger.warning(f"WARNING: pageno={page_num}, 无法分割的layout: ", v_split_layouts)
# logger.warning(f"WARNING: pageno={page_num}, 无法分割的layout: ", v_split_layouts)
"""
组合起来最终的layout
"""
sorted_layouts[i] = {
"layout_bbox": [x0, y0, x1, y1],
"layout_label": layout_label,
"sub_layout": v_split_layouts
'layout_bbox': [x0, y0, x1, y1],
'layout_label': layout_label,
'sub_layout': v_split_layouts,
}
layout['layout_label'] = LAYOUT_H
"""
水平和垂直方向都切分完毕了。此时还有一些未处理的,这些未处理的可能是因为水平和垂直方向都无法切分。
这些最后调用_try_horizontal_mult_block_split做一次水平多个block的联合切分,如果也不能切分最终就当做BAD_LAYOUT返回
......@@ -617,7 +778,7 @@ def split_layout(bboxes:list, boundry:tuple, page_num:int)-> list:
return sorted_layouts
def get_bboxes_layout(all_boxes:list, boundry:tuple, page_id:int):
def get_bboxes_layout(all_boxes: list, boundary: tuple, page_id: int):
"""
对利用layout排序之后的box,进行排序
return:
......@@ -628,10 +789,10 @@ def get_bboxes_layout(all_boxes:list, boundry:tuple, page_id:int):
},
]
"""
def _preorder_traversal(layout):
"""
对sorted_layouts的叶子节点,也就是len(sub_layout)==0的节点进行排序。排序按照前序遍历的顺序,也就是从上到下,从左到右的顺序
"""
"""对sorted_layouts的叶子节点,也就是len(sub_layout)==0的节点进行排序。排序按照前序遍历的顺序,也就是从上到
下,从左到右的顺序."""
sorted_layout_blocks = []
for layout in layout:
sub_layout = layout['sub_layout']
......@@ -641,71 +802,89 @@ def get_bboxes_layout(all_boxes:list, boundry:tuple, page_id:int):
s = _preorder_traversal(sub_layout)
sorted_layout_blocks.extend(s)
return sorted_layout_blocks
# -------------------------------------------------------------------------------------------------------------------------
sorted_layouts = split_layout(all_boxes, boundry, page_id)# 先切分成layout,得到一个Tree
sorted_layouts = split_layout(
all_boxes, boundary, page_id
) # 先切分成layout,得到一个Tree
total_sorted_layout_blocks = _preorder_traversal(sorted_layouts)
return total_sorted_layout_blocks, sorted_layouts
def get_columns_cnt_of_layout(layout_tree):
"""
获取一个layout的宽度
"""
"""获取一个layout的宽度."""
max_width_list = [0] # 初始化一个元素,防止max,min函数报错
for items in layout_tree: # 针对每一层(横切)计算列数,横着的算一列
layout_type = items['layout_label']
sub_layouts = items['sub_layout']
if len(sub_layouts)==0:
if len(sub_layouts) == 0:
max_width_list.append(1)
else:
if layout_type == LAYOUT_H:
max_width_list.append(1)
else:
width = 0
for l in sub_layouts:
if len(l['sub_layout']) == 0:
for sub_layout in sub_layouts:
if len(sub_layout['sub_layout']) == 0:
width += 1
else:
for lay in l['sub_layout']:
for lay in sub_layout['sub_layout']:
width += get_columns_cnt_of_layout([lay])
max_width_list.append(width)
return max(max_width_list)
def sort_with_layout(bboxes: list, page_width, page_height) -> (list, list):
"""输入是一个bbox的list.
def sort_with_layout(bboxes:list, page_width, page_height) -> (list,list):
"""
输入是一个bbox的list.
获取到输入之后,先进行layout切分,然后对这些bbox进行排序。返回排序后的bboxes
"""
new_bboxes = []
for box in bboxes:
# new_bboxes.append([box[0], box[1], box[2], box[3], None, None, None, 'text', None, None, None, None])
new_bboxes.append([box[0], box[1], box[2], box[3], None, None, None, 'text', None, None, None, None, box[4]])
new_bboxes.append(
[
box[0],
box[1],
box[2],
box[3],
None,
None,
None,
'text',
None,
None,
None,
None,
box[4],
]
)
layout_bboxes, _ = get_bboxes_layout(new_bboxes, [0, 0, page_width, page_height], 0)
if any([lay['layout_label']==LAYOUT_UNPROC for lay in layout_bboxes]):
logger.warning(f"drop this pdf, reason: 复杂版面")
return None,None
layout_bboxes, _ = get_bboxes_layout(
new_bboxes, tuple([0, 0, page_width, page_height]), 0
)
if any([lay['layout_label'] == LAYOUT_UNPROC for lay in layout_bboxes]):
logger.warning('drop this pdf, reason: 复杂版面')
return None, None
sorted_bboxes = []
# 利用layout bbox每次框定一些box,然后排序
for layout in layout_bboxes:
lbox = layout['layout_bbox']
bbox_in_layout = get_bbox_in_boundry(new_bboxes, lbox)
sorted_bbox = paper_bbox_sort(bbox_in_layout, lbox[2]-lbox[0], lbox[3]-lbox[1])
bbox_in_layout = get_bbox_in_boundary(new_bboxes, lbox)
sorted_bbox = paper_bbox_sort(
bbox_in_layout, lbox[2] - lbox[0], lbox[3] - lbox[1]
)
sorted_bboxes.extend(sorted_bbox)
return sorted_bboxes, layout_bboxes
def sort_text_block(text_block, layout_bboxes):
"""
对一页的text_block进行排序
"""
"""对一页的text_block进行排序."""
sorted_text_bbox = []
all_text_bbox = []
# 做一个box=>text的映射
......@@ -722,10 +901,20 @@ def sort_text_block(text_block, layout_bboxes):
# 按照layout_bboxes的顺序,对text_block进行排序
for layout in layout_bboxes:
layout_box = layout['layout_bbox']
text_bbox_in_layout = get_bbox_in_boundry(all_text_bbox, [layout_box[0]-1, layout_box[1]-1, layout_box[2]+1, layout_box[3]+1])
#sorted_bbox = paper_bbox_sort(text_bbox_in_layout, layout_box[2]-layout_box[0], layout_box[3]-layout_box[1])
text_bbox_in_layout.sort(key = lambda x: x[1]) # 一个layout内部的box,按照y0自上而下排序
#sorted_bbox = [[b] for b in text_blocks_to_sort]
text_bbox_in_layout = get_bbox_in_boundary(
all_text_bbox,
[
layout_box[0] - 1,
layout_box[1] - 1,
layout_box[2] + 1,
layout_box[3] + 1,
],
)
# sorted_bbox = paper_bbox_sort(text_bbox_in_layout, layout_box[2]-layout_box[0], layout_box[3]-layout_box[1])
text_bbox_in_layout.sort(
key=lambda x: x[1]
) # 一个layout内部的box,按照y0自上而下排序
# sorted_bbox = [[b] for b in text_blocks_to_sort]
for sb in text_bbox_in_layout:
sorted_text_bbox.append(box_to_text[(sb[0], sb[1], sb[2], sb[3])])
......
from loguru import logger
import math
def _is_in_or_part_overlap(box1, box2) -> bool:
"""
两个bbox是否有部分重叠或者包含
"""
"""两个bbox是否有部分重叠或者包含."""
if box1 is None or box2 is None:
return False
......@@ -18,11 +14,11 @@ def _is_in_or_part_overlap(box1, box2) -> bool:
y1_1 < y0_2 or # box1在box2的上边
y0_1 > y1_2) # box1在box2的下边
def _is_in_or_part_overlap_with_area_ratio(box1, box2, area_ratio_threshold=0.6):
"""
判断box1是否在box2里面,或者box1和box2有部分重叠,且重叠面积占box1的比例超过area_ratio_threshold
"""
def _is_in_or_part_overlap_with_area_ratio(box1,
box2,
area_ratio_threshold=0.6):
"""判断box1是否在box2里面,或者box1和box2有部分重叠,且重叠面积占box1的比例超过area_ratio_threshold."""
if box1 is None or box2 is None:
return False
......@@ -46,9 +42,7 @@ def _is_in_or_part_overlap_with_area_ratio(box1, box2, area_ratio_threshold=0.6)
def _is_in(box1, box2) -> bool:
"""
box1是否完全在box2里面
"""
"""box1是否完全在box2里面."""
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
......@@ -57,49 +51,48 @@ def _is_in(box1, box2) -> bool:
x1_1 <= x1_2 and # box1的右边界不在box2的右边外
y1_1 <= y1_2) # box1的下边界不在box2的下边外
def _is_part_overlap(box1, box2) -> bool:
"""
两个bbox是否有部分重叠,但不完全包含
"""
"""两个bbox是否有部分重叠,但不完全包含."""
if box1 is None or box2 is None:
return False
return _is_in_or_part_overlap(box1, box2) and not _is_in(box1, box2)
def _left_intersect(left_box, right_box):
"检查两个box的左边界是否有交集,也就是left_box的右边界是否在right_box的左边界内"
"""检查两个box的左边界是否有交集,也就是left_box的右边界是否在right_box的左边界内."""
if left_box is None or right_box is None:
return False
x0_1, y0_1, x1_1, y1_1 = left_box
x0_2, y0_2, x1_2, y1_2 = right_box
return x1_1>x0_2 and x0_1<x0_2 and (y0_1<=y0_2<=y1_1 or y0_1<=y1_2<=y1_1)
return x1_1 > x0_2 and x0_1 < x0_2 and (y0_1 <= y0_2 <= y1_1
or y0_1 <= y1_2 <= y1_1)
def _right_intersect(left_box, right_box):
"""
检查box是否在右侧边界有交集,也就是left_box的左边界是否在right_box的右边界内
"""
"""检查box是否在右侧边界有交集,也就是left_box的左边界是否在right_box的右边界内."""
if left_box is None or right_box is None:
return False
x0_1, y0_1, x1_1, y1_1 = left_box
x0_2, y0_2, x1_2, y1_2 = right_box
return x0_1<x1_2 and x1_1>x1_2 and (y0_1<=y0_2<=y1_1 or y0_1<=y1_2<=y1_1)
return x0_1 < x1_2 and x1_1 > x1_2 and (y0_1 <= y0_2 <= y1_1
or y0_1 <= y1_2 <= y1_1)
def _is_vertical_full_overlap(box1, box2, x_torlence=2):
"""
x方向上:要么box1包含box2, 要么box2包含box1。不能部分包含
y方向上:box1和box2有重叠
"""
"""x方向上:要么box1包含box2, 要么box2包含box1。不能部分包含 y方向上:box1和box2有重叠."""
# 解析box的坐标
x11, y11, x12, y12 = box1 # 左上角和右下角的坐标 (x1, y1, x2, y2)
x21, y21, x22, y22 = box2
# 在x轴方向上,box1是否包含box2 或 box2包含box1
contains_in_x = (x11-x_torlence <= x21 and x12+x_torlence >= x22) or (x21-x_torlence <= x11 and x22+x_torlence >= x12)
contains_in_x = (x11 - x_torlence <= x21 and x12 + x_torlence >= x22) or (
x21 - x_torlence <= x11 and x22 + x_torlence >= x12)
# 在y轴方向上,box1和box2是否有重叠
overlap_in_y = not (y12 < y21 or y11 > y22)
......@@ -108,26 +101,31 @@ def _is_vertical_full_overlap(box1, box2, x_torlence=2):
def _is_bottom_full_overlap(box1, box2, y_tolerance=2):
"""
检查box1下方和box2的上方有轻微的重叠,轻微程度收到y_tolerance的限制
这个函数和_is_vertical-full_overlap的区别是,这个函数允许box1和box2在x方向上有轻微的重叠,允许一定的模糊度
"""
"""检查box1下方和box2的上方有轻微的重叠,轻微程度收到y_tolerance的限制 这个函数和_is_vertical-
full_overlap的区别是,这个函数允许box1和box2在x方向上有轻微的重叠,允许一定的模糊度."""
if box1 is None or box2 is None:
return False
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
tolerance_margin = 2
is_xdir_full_overlap = ((x0_1-tolerance_margin<=x0_2<=x1_1+tolerance_margin and x0_1-tolerance_margin<=x1_2<=x1_1+tolerance_margin) or (x0_2-tolerance_margin<=x0_1<=x1_2+tolerance_margin and x0_2-tolerance_margin<=x1_1<=x1_2+tolerance_margin))
is_xdir_full_overlap = (
(x0_1 - tolerance_margin <= x0_2 <= x1_1 + tolerance_margin
and x0_1 - tolerance_margin <= x1_2 <= x1_1 + tolerance_margin)
or (x0_2 - tolerance_margin <= x0_1 <= x1_2 + tolerance_margin
and x0_2 - tolerance_margin <= x1_1 <= x1_2 + tolerance_margin))
return y0_2<y1_1 and 0<(y1_1-y0_2)<y_tolerance and is_xdir_full_overlap
return y0_2 < y1_1 and 0 < (y1_1 -
y0_2) < y_tolerance and is_xdir_full_overlap
def _is_left_overlap(
box1,
box2,
):
"""检查box1的左侧是否和box2有重叠 在Y方向上可以是部分重叠或者是完全重叠。不分box1和box2的上下关系,也就是无论box1在box2下
方还是box2在box1下方,都可以检测到重叠。 X方向上."""
def _is_left_overlap(box1, box2,):
"""
检查box1的左侧是否和box2有重叠
在Y方向上可以是部分重叠或者是完全重叠。不分box1和box2的上下关系,也就是无论box1在box2下方还是box2在box1下方,都可以检测到重叠。
X方向上
"""
def __overlap_y(Ay1, Ay2, By1, By2):
return max(0, min(Ay2, By2) - max(Ay1, By1))
......@@ -138,31 +136,31 @@ def _is_left_overlap(box1, box2,):
x0_2, y0_2, x1_2, y1_2 = box2
y_overlap_len = __overlap_y(y0_1, y1_1, y0_2, y1_2)
ratio_1 = 1.0 * y_overlap_len / (y1_1 - y0_1) if y1_1-y0_1!=0 else 0
ratio_2 = 1.0 * y_overlap_len / (y1_2 - y0_2) if y1_2-y0_2!=0 else 0
ratio_1 = 1.0 * y_overlap_len / (y1_1 - y0_1) if y1_1 - y0_1 != 0 else 0
ratio_2 = 1.0 * y_overlap_len / (y1_2 - y0_2) if y1_2 - y0_2 != 0 else 0
vertical_overlap_cond = ratio_1 >= 0.5 or ratio_2 >= 0.5
#vertical_overlap_cond = y0_1<=y0_2<=y1_1 or y0_1<=y1_2<=y1_1 or y0_2<=y0_1<=y1_2 or y0_2<=y1_1<=y1_2
return x0_1<=x0_2<=x1_1 and vertical_overlap_cond
# vertical_overlap_cond = y0_1<=y0_2<=y1_1 or y0_1<=y1_2<=y1_1 or y0_2<=y0_1<=y1_2 or y0_2<=y1_1<=y1_2
return x0_1 <= x0_2 <= x1_1 and vertical_overlap_cond
def __is_overlaps_y_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.8):
def __is_overlaps_y_exceeds_threshold(bbox1,
bbox2,
overlap_ratio_threshold=0.8):
"""检查两个bbox在y轴上是否有重叠,并且该重叠区域的高度占两个bbox高度更低的那个超过80%"""
_, y0_1, _, y1_1 = bbox1
_, y0_2, _, y1_2 = bbox2
overlap = max(0, min(y1_1, y1_2) - max(y0_1, y0_2))
height1, height2 = y1_1 - y0_1, y1_2 - y0_2
max_height = max(height1, height2)
# max_height = max(height1, height2)
min_height = min(height1, height2)
return (overlap / min_height) > overlap_ratio_threshold
def calculate_iou(bbox1, bbox2):
"""
计算两个边界框的交并比(IOU)。
"""计算两个边界框的交并比(IOU)。
Args:
bbox1 (list[float]): 第一个边界框的坐标,格式为 [x1, y1, x2, y2],其中 (x1, y1) 为左上角坐标,(x2, y2) 为右下角坐标。
......@@ -170,7 +168,6 @@ def calculate_iou(bbox1, bbox2):
Returns:
float: 两个边界框的交并比(IOU),取值范围为 [0, 1]。
"""
# Determine the coordinates of the intersection rectangle
x_left = max(bbox1[0], bbox2[0])
......@@ -190,14 +187,13 @@ def calculate_iou(bbox1, bbox2):
# Compute the intersection over union by taking the intersection area
# and dividing it by the sum of both areas minus the intersection area
iou = intersection_area / float(bbox1_area + bbox2_area - intersection_area)
iou = intersection_area / float(bbox1_area + bbox2_area -
intersection_area)
return iou
def calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2):
"""
计算box1和box2的重叠面积占最小面积的box的比例
"""
"""计算box1和box2的重叠面积占最小面积的box的比例."""
# Determine the coordinates of the intersection rectangle
x_left = max(bbox1[0], bbox2[0])
y_top = max(bbox1[1], bbox2[1])
......@@ -209,16 +205,16 @@ def calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2):
# The area of overlap area
intersection_area = (x_right - x_left) * (y_bottom - y_top)
min_box_area = min([(bbox1[2]-bbox1[0])*(bbox1[3]-bbox1[1]), (bbox2[3]-bbox2[1])*(bbox2[2]-bbox2[0])])
if min_box_area==0:
min_box_area = min([(bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]),
(bbox2[3] - bbox2[1]) * (bbox2[2] - bbox2[0])])
if min_box_area == 0:
return 0
else:
return intersection_area / min_box_area
def calculate_overlap_area_in_bbox1_area_ratio(bbox1, bbox2):
"""
计算box1和box2的重叠面积占bbox1的比例
"""
"""计算box1和box2的重叠面积占bbox1的比例."""
# Determine the coordinates of the intersection rectangle
x_left = max(bbox1[0], bbox2[0])
y_top = max(bbox1[1], bbox2[1])
......@@ -230,7 +226,7 @@ def calculate_overlap_area_in_bbox1_area_ratio(bbox1, bbox2):
# The area of overlap area
intersection_area = (x_right - x_left) * (y_bottom - y_top)
bbox1_area = (bbox1[2]-bbox1[0])*(bbox1[3]-bbox1[1])
bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
if bbox1_area == 0:
return 0
else:
......@@ -238,11 +234,8 @@ def calculate_overlap_area_in_bbox1_area_ratio(bbox1, bbox2):
def get_minbox_if_overlap_by_ratio(bbox1, bbox2, ratio):
"""
通过calculate_overlap_area_2_minbox_area_ratio计算两个bbox重叠的面积占最小面积的box的比例
如果比例大于ratio,则返回小的那个bbox,
否则返回None
"""
"""通过calculate_overlap_area_2_minbox_area_ratio计算两个bbox重叠的面积占最小面积的box的比例
如果比例大于ratio,则返回小的那个bbox, 否则返回None."""
x1_min, y1_min, x1_max, y1_max = bbox1
x2_min, y2_min, x2_max, y2_max = bbox2
area1 = (x1_max - x1_min) * (y1_max - y1_min)
......@@ -256,33 +249,47 @@ def get_minbox_if_overlap_by_ratio(bbox1, bbox2, ratio):
else:
return None
def get_bbox_in_boundry(bboxes:list, boundry:tuple)-> list:
x0, y0, x1, y1 = boundry
new_boxes = [box for box in bboxes if box[0] >= x0 and box[1] >= y0 and box[2] <= x1 and box[3] <= y1]
def get_bbox_in_boundary(bboxes: list, boundary: tuple) -> list:
x0, y0, x1, y1 = boundary
new_boxes = [
box for box in bboxes
if box[0] >= x0 and box[1] >= y0 and box[2] <= x1 and box[3] <= y1
]
return new_boxes
def is_vbox_on_side(bbox, width, height, side_threshold=0.2):
"""
判断一个bbox是否在pdf页面的边缘
"""
"""判断一个bbox是否在pdf页面的边缘."""
x0, x1 = bbox[0], bbox[2]
if x1<=width*side_threshold or x0>=width*(1-side_threshold):
if x1 <= width * side_threshold or x0 >= width * (1 - side_threshold):
return True
return False
def find_top_nearest_text_bbox(pymu_blocks, obj_bbox):
tolerance_margin = 4
top_boxes = [box for box in pymu_blocks if obj_bbox[1]-box['bbox'][3] >=-tolerance_margin and not _is_in(box['bbox'], obj_bbox)]
top_boxes = [
box for box in pymu_blocks
if obj_bbox[1] - box['bbox'][3] >= -tolerance_margin
and not _is_in(box['bbox'], obj_bbox)
]
# 然后找到X方向上有互相重叠的
top_boxes = [box for box in top_boxes if any([obj_bbox[0]-tolerance_margin <=box['bbox'][0]<=obj_bbox[2]+tolerance_margin,
obj_bbox[0]-tolerance_margin <=box['bbox'][2]<=obj_bbox[2]+tolerance_margin,
box['bbox'][0]-tolerance_margin <=obj_bbox[0]<=box['bbox'][2]+tolerance_margin,
box['bbox'][0]-tolerance_margin <=obj_bbox[2]<=box['bbox'][2]+tolerance_margin
])]
top_boxes = [
box for box in top_boxes if any([
obj_bbox[0] - tolerance_margin <= box['bbox'][0] <= obj_bbox[2] +
tolerance_margin, obj_bbox[0] -
tolerance_margin <= box['bbox'][2] <= obj_bbox[2] +
tolerance_margin, box['bbox'][0] -
tolerance_margin <= obj_bbox[0] <= box['bbox'][2] +
tolerance_margin, box['bbox'][0] -
tolerance_margin <= obj_bbox[2] <= box['bbox'][2] +
tolerance_margin
])
]
# 然后找到y1最大的那个
if len(top_boxes)>0:
if len(top_boxes) > 0:
top_boxes.sort(key=lambda x: x['bbox'][3], reverse=True)
return top_boxes[0]
else:
......@@ -290,35 +297,46 @@ def find_top_nearest_text_bbox(pymu_blocks, obj_bbox):
def find_bottom_nearest_text_bbox(pymu_blocks, obj_bbox):
bottom_boxes = [box for box in pymu_blocks if box['bbox'][1] - obj_bbox[3]>=-2 and not _is_in(box['bbox'], obj_bbox)]
bottom_boxes = [
box for box in pymu_blocks if box['bbox'][1] -
obj_bbox[3] >= -2 and not _is_in(box['bbox'], obj_bbox)
]
# 然后找到X方向上有互相重叠的
bottom_boxes = [box for box in bottom_boxes if any([obj_bbox[0]-2 <=box['bbox'][0]<=obj_bbox[2]+2,
obj_bbox[0]-2 <=box['bbox'][2]<=obj_bbox[2]+2,
box['bbox'][0]-2 <=obj_bbox[0]<=box['bbox'][2]+2,
box['bbox'][0]-2 <=obj_bbox[2]<=box['bbox'][2]+2
])]
bottom_boxes = [
box for box in bottom_boxes if any([
obj_bbox[0] - 2 <= box['bbox'][0] <= obj_bbox[2] + 2, obj_bbox[0] -
2 <= box['bbox'][2] <= obj_bbox[2] + 2, box['bbox'][0] -
2 <= obj_bbox[0] <= box['bbox'][2] + 2, box['bbox'][0] -
2 <= obj_bbox[2] <= box['bbox'][2] + 2
])
]
# 然后找到y0最小的那个
if len(bottom_boxes)>0:
if len(bottom_boxes) > 0:
bottom_boxes.sort(key=lambda x: x['bbox'][1], reverse=False)
return bottom_boxes[0]
else:
return None
def find_left_nearest_text_bbox(pymu_blocks, obj_bbox):
"""
寻找左侧最近的文本block
"""
left_boxes = [box for box in pymu_blocks if obj_bbox[0]-box['bbox'][2]>=-2 and not _is_in(box['bbox'], obj_bbox)]
"""寻找左侧最近的文本block."""
left_boxes = [
box for box in pymu_blocks if obj_bbox[0] -
box['bbox'][2] >= -2 and not _is_in(box['bbox'], obj_bbox)
]
# 然后找到X方向上有互相重叠的
left_boxes = [box for box in left_boxes if any([obj_bbox[1]-2 <=box['bbox'][1]<=obj_bbox[3]+2,
obj_bbox[1]-2 <=box['bbox'][3]<=obj_bbox[3]+2,
box['bbox'][1]-2 <=obj_bbox[1]<=box['bbox'][3]+2,
box['bbox'][1]-2 <=obj_bbox[3]<=box['bbox'][3]+2
])]
left_boxes = [
box for box in left_boxes if any([
obj_bbox[1] - 2 <= box['bbox'][1] <= obj_bbox[3] + 2, obj_bbox[1] -
2 <= box['bbox'][3] <= obj_bbox[3] + 2, box['bbox'][1] -
2 <= obj_bbox[1] <= box['bbox'][3] + 2, box['bbox'][1] -
2 <= obj_bbox[3] <= box['bbox'][3] + 2
])
]
# 然后找到x1最大的那个
if len(left_boxes)>0:
if len(left_boxes) > 0:
left_boxes.sort(key=lambda x: x['bbox'][2], reverse=True)
return left_boxes[0]
else:
......@@ -326,19 +344,23 @@ def find_left_nearest_text_bbox(pymu_blocks, obj_bbox):
def find_right_nearest_text_bbox(pymu_blocks, obj_bbox):
"""
寻找右侧最近的文本block
"""
right_boxes = [box for box in pymu_blocks if box['bbox'][0]-obj_bbox[2]>=-2 and not _is_in(box['bbox'], obj_bbox)]
"""寻找右侧最近的文本block."""
right_boxes = [
box for box in pymu_blocks if box['bbox'][0] -
obj_bbox[2] >= -2 and not _is_in(box['bbox'], obj_bbox)
]
# 然后找到X方向上有互相重叠的
right_boxes = [box for box in right_boxes if any([obj_bbox[1]-2 <=box['bbox'][1]<=obj_bbox[3]+2,
obj_bbox[1]-2 <=box['bbox'][3]<=obj_bbox[3]+2,
box['bbox'][1]-2 <=obj_bbox[1]<=box['bbox'][3]+2,
box['bbox'][1]-2 <=obj_bbox[3]<=box['bbox'][3]+2
])]
right_boxes = [
box for box in right_boxes if any([
obj_bbox[1] - 2 <= box['bbox'][1] <= obj_bbox[3] + 2, obj_bbox[1] -
2 <= box['bbox'][3] <= obj_bbox[3] + 2, box['bbox'][1] -
2 <= obj_bbox[1] <= box['bbox'][3] + 2, box['bbox'][1] -
2 <= obj_bbox[3] <= box['bbox'][3] + 2
])
]
# 然后找到x0最小的那个
if len(right_boxes)>0:
if len(right_boxes) > 0:
right_boxes.sort(key=lambda x: x['bbox'][0], reverse=False)
return right_boxes[0]
else:
......@@ -346,8 +368,7 @@ def find_right_nearest_text_bbox(pymu_blocks, obj_bbox):
def bbox_relative_pos(bbox1, bbox2):
"""
判断两个矩形框的相对位置关系
"""判断两个矩形框的相对位置关系.
Args:
bbox1: 一个四元组,表示第一个矩形框的左上角和右下角的坐标,格式为(x1, y1, x1b, y1b)
......@@ -357,7 +378,6 @@ def bbox_relative_pos(bbox1, bbox2):
一个四元组,表示矩形框1相对于矩形框2的位置关系,格式为(left, right, bottom, top)
其中,left表示矩形框1是否在矩形框2的左侧,right表示矩形框1是否在矩形框2的右侧,
bottom表示矩形框1是否在矩形框2的下方,top表示矩形框1是否在矩形框2的上方
"""
x1, y1, x1b, y1b = bbox1
x2, y2, x2b, y2b = bbox2
......@@ -368,9 +388,9 @@ def bbox_relative_pos(bbox1, bbox2):
top = y1b < y2
return left, right, bottom, top
def bbox_distance(bbox1, bbox2):
"""
计算两个矩形框的距离。
"""计算两个矩形框的距离。
Args:
bbox1 (tuple): 第一个矩形框的坐标,格式为 (x1, y1, x2, y2),其中 (x1, y1) 为左上角坐标,(x2, y2) 为右下角坐标。
......@@ -378,10 +398,11 @@ def bbox_distance(bbox1, bbox2):
Returns:
float: 矩形框之间的距离。
"""
def dist(point1, point2):
return math.sqrt((point1[0]-point2[0])**2 + (point1[1]-point2[1])**2)
return math.sqrt((point1[0] - point2[0])**2 +
(point1[1] - point2[1])**2)
x1, y1, x1b, y1b = bbox1
x2, y2, x2b, y2b = bbox2
......@@ -404,5 +425,4 @@ def bbox_distance(bbox1, bbox2):
return y1 - y2b
elif top:
return y2 - y1b
else: # rectangles intersect
return 0
\ No newline at end of file
return 0.0
......@@ -71,6 +71,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
tables_list, tables_body_list = [], []
tables_caption_list, tables_footnote_list = [], []
imgs_list, imgs_body_list, imgs_caption_list = [], [], []
imgs_footnote_list = []
titles_list = []
texts_list = []
interequations_list = []
......@@ -78,7 +79,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
page_layout_list = []
page_dropped_list = []
tables, tables_body, tables_caption, tables_footnote = [], [], [], []
imgs, imgs_body, imgs_caption = [], [], []
imgs, imgs_body, imgs_caption, imgs_footnote = [], [], [], []
titles = []
texts = []
interequations = []
......@@ -108,6 +109,8 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
imgs_body.append(bbox)
elif nested_block['type'] == BlockType.ImageCaption:
imgs_caption.append(bbox)
elif nested_block['type'] == BlockType.ImageFootnote:
imgs_footnote.append(bbox)
elif block['type'] == BlockType.Title:
titles.append(bbox)
elif block['type'] == BlockType.Text:
......@@ -121,6 +124,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
imgs_list.append(imgs)
imgs_body_list.append(imgs_body)
imgs_caption_list.append(imgs_caption)
imgs_footnote_list.append(imgs_footnote)
titles_list.append(titles)
texts_list.append(texts)
interequations_list.append(interequations)
......@@ -142,6 +146,8 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
draw_bbox_without_number(i, imgs_body_list, page, [153, 255, 51], True)
draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255],
True)
draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102],
True),
draw_bbox_without_number(i, titles_list, page, [102, 102, 255], True)
draw_bbox_without_number(i, texts_list, page, [153, 0, 76], True)
draw_bbox_without_number(i, interequations_list, page, [0, 255, 0],
......@@ -241,7 +247,7 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
dropped_bbox_list = []
tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
imgs_body_list, imgs_caption_list = [], []
imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], []
titles_list = []
texts_list = []
interequations_list = []
......@@ -250,7 +256,7 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
for i in range(len(model_list)):
page_dropped_list = []
tables_body, tables_caption, tables_footnote = [], [], []
imgs_body, imgs_caption = [], []
imgs_body, imgs_caption, imgs_footnote = [], [], []
titles = []
texts = []
interequations = []
......@@ -277,6 +283,8 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
interequations.append(bbox)
elif layout_det['category_id'] == CategoryId.Abandon:
page_dropped_list.append(bbox)
elif layout_det['category_id'] == CategoryId.ImageFootnote:
imgs_footnote.append(bbox)
tables_body_list.append(tables_body)
tables_caption_list.append(tables_caption)
......@@ -287,6 +295,7 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
texts_list.append(texts)
interequations_list.append(interequations)
dropped_bbox_list.append(page_dropped_list)
imgs_footnote_list.append(imgs_footnote)
for i, page in enumerate(pdf_docs):
draw_bbox_with_number(i, dropped_bbox_list, page, [158, 158, 158],
......@@ -299,6 +308,8 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
draw_bbox_with_number(i, imgs_body_list, page, [153, 255, 51], True)
draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255],
True)
draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102],
True)
draw_bbox_with_number(i, titles_list, page, [102, 102, 255], True)
draw_bbox_with_number(i, texts_list, page, [153, 0, 76], True)
draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True)
......
class ContentType:
Image = "image"
Table = "table"
Text = "text"
InlineEquation = "inline_equation"
InterlineEquation = "interline_equation"
Image = 'image'
Table = 'table'
Text = 'text'
InlineEquation = 'inline_equation'
InterlineEquation = 'interline_equation'
class BlockType:
Image = "image"
ImageBody = "image_body"
ImageCaption = "image_caption"
Table = "table"
TableBody = "table_body"
TableCaption = "table_caption"
TableFootnote = "table_footnote"
Text = "text"
Title = "title"
InterlineEquation = "interline_equation"
Footnote = "footnote"
Discarded = "discarded"
Image = 'image'
ImageBody = 'image_body'
ImageCaption = 'image_caption'
ImageFootnote = 'image_footnote'
Table = 'table'
TableBody = 'table_body'
TableCaption = 'table_caption'
TableFootnote = 'table_footnote'
Text = 'text'
Title = 'title'
InterlineEquation = 'interline_equation'
Footnote = 'footnote'
Discarded = 'discarded'
class CategoryId:
......@@ -33,3 +35,4 @@ class CategoryId:
InlineEquation = 13
InterlineEquation_YOLO = 14
OcrText = 15
ImageFootnote = 101
import json
import math
from magic_pdf.libs.commons import fitz
from loguru import logger
from magic_pdf.libs.commons import join_path
from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
bbox_relative_pos, calculate_iou,
calculate_overlap_area_in_bbox1_area_ratio)
from magic_pdf.libs.commons import fitz, join_path
from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.libs.ocr_content_type import ContentType
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
from magic_pdf.libs.local_math import float_gt
from magic_pdf.libs.boxbase import (
_is_in,
bbox_relative_pos,
bbox_distance,
_is_part_overlap,
calculate_overlap_area_in_bbox1_area_ratio,
calculate_iou,
)
from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
from magic_pdf.libs.ocr_content_type import CategoryId, ContentType
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
CAPATION_OVERLAP_AREA_RATIO = 0.6
class MagicModel:
"""
每个函数没有得到元素的时候返回空list
"""
"""每个函数没有得到元素的时候返回空list."""
def __fix_axis(self):
for model_page_info in self.__model_list:
need_remove_list = []
page_no = model_page_info["page_info"]["page_no"]
page_no = model_page_info['page_info']['page_no']
horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
model_page_info, self.__docs[page_no]
)
layout_dets = model_page_info["layout_dets"]
layout_dets = model_page_info['layout_dets']
for layout_det in layout_dets:
if layout_det.get("bbox") is not None:
if layout_det.get('bbox') is not None:
# 兼容直接输出bbox的模型数据,如paddle
x0, y0, x1, y1 = layout_det["bbox"]
x0, y0, x1, y1 = layout_det['bbox']
else:
# 兼容直接输出poly的模型数据,如xxx
x0, y0, _, _, x1, y1, _, _ = layout_det["poly"]
x0, y0, _, _, x1, y1, _, _ = layout_det['poly']
bbox = [
int(x0 / horizontal_scale_ratio),
......@@ -52,7 +40,7 @@ class MagicModel:
int(x1 / horizontal_scale_ratio),
int(y1 / vertical_scale_ratio),
]
layout_det["bbox"] = bbox
layout_det['bbox'] = bbox
# 删除高度或者宽度小于等于0的spans
if bbox[2] - bbox[0] <= 0 or bbox[3] - bbox[1] <= 0:
need_remove_list.append(layout_det)
......@@ -62,9 +50,9 @@ class MagicModel:
def __fix_by_remove_low_confidence(self):
for model_page_info in self.__model_list:
need_remove_list = []
layout_dets = model_page_info["layout_dets"]
layout_dets = model_page_info['layout_dets']
for layout_det in layout_dets:
if layout_det["score"] <= 0.05:
if layout_det['score'] <= 0.05:
need_remove_list.append(layout_det)
else:
continue
......@@ -74,12 +62,12 @@ class MagicModel:
def __fix_by_remove_high_iou_and_low_confidence(self):
for model_page_info in self.__model_list:
need_remove_list = []
layout_dets = model_page_info["layout_dets"]
layout_dets = model_page_info['layout_dets']
for layout_det1 in layout_dets:
for layout_det2 in layout_dets:
if layout_det1 == layout_det2:
continue
if layout_det1["category_id"] in [
if layout_det1['category_id'] in [
0,
1,
2,
......@@ -90,12 +78,12 @@ class MagicModel:
7,
8,
9,
] and layout_det2["category_id"] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
] and layout_det2['category_id'] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
if (
calculate_iou(layout_det1["bbox"], layout_det2["bbox"])
calculate_iou(layout_det1['bbox'], layout_det2['bbox'])
> 0.9
):
if layout_det1["score"] < layout_det2["score"]:
if layout_det1['score'] < layout_det2['score']:
layout_det_need_remove = layout_det1
else:
layout_det_need_remove = layout_det2
......@@ -118,6 +106,67 @@ class MagicModel:
self.__fix_by_remove_low_confidence()
"""删除高iou(>0.9)数据中置信度较低的那个"""
self.__fix_by_remove_high_iou_and_low_confidence()
self.__fix_footnote()
def __fix_footnote(self):
# 3: figure, 5: table, 7: footnote
for model_page_info in self.__model_list:
footnotes = []
figures = []
tables = []
for obj in model_page_info['layout_dets']:
if obj['category_id'] == 7:
footnotes.append(obj)
elif obj['category_id'] == 3:
figures.append(obj)
elif obj['category_id'] == 5:
tables.append(obj)
if len(footnotes) * len(figures) == 0:
continue
dis_figure_footnote = {}
dis_table_footnote = {}
for i in range(len(footnotes)):
for j in range(len(figures)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], figures[j]['bbox']
),
)
)
)
if pos_flag_count > 1:
continue
dis_figure_footnote[i] = min(
bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
dis_figure_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
for j in range(len(tables)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], tables[j]['bbox']
),
)
)
)
if pos_flag_count > 1:
continue
dis_table_footnote[i] = min(
bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
dis_table_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
footnotes[i]['category_id'] = CategoryId.ImageFootnote
def __reduct_overlap(self, bboxes):
N = len(bboxes)
......@@ -126,76 +175,77 @@ class MagicModel:
for j in range(N):
if i == j:
continue
if _is_in(bboxes[i]["bbox"], bboxes[j]["bbox"]):
if _is_in(bboxes[i]['bbox'], bboxes[j]['bbox']):
keep[i] = False
return [bboxes[i] for i in range(N) if keep[i]]
def __tie_up_category_by_distance(
self, page_no, subject_category_id, object_category_id
):
"""
假定每个 subject 最多有一个 object (可以有多个相邻的 object 合并为单个 object),每个 object 只能属于一个 subject
"""
"""假定每个 subject 最多有一个 object (可以有多个相邻的 object 合并为单个 object),每个 object
只能属于一个 subject."""
ret = []
MAX_DIS_OF_POINT = 10**9 + 7
"""
subject 和 object 的 bbox 会合并成一个大的 bbox (named: merged bbox)。
筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
再求出筛选出的 subjects 和 object 的最短距离
"""
# subject 和 object 的 bbox 会合并成一个大的 bbox (named: merged bbox)。 筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
# 再求出筛选出的 subjects 和 object 的最短距离!
def may_find_other_nearest_bbox(subject_idx, object_idx):
ret = float("inf")
ret = float('inf')
x0 = min(
all_bboxes[subject_idx]["bbox"][0], all_bboxes[object_idx]["bbox"][0]
all_bboxes[subject_idx]['bbox'][0], all_bboxes[object_idx]['bbox'][0]
)
y0 = min(
all_bboxes[subject_idx]["bbox"][1], all_bboxes[object_idx]["bbox"][1]
all_bboxes[subject_idx]['bbox'][1], all_bboxes[object_idx]['bbox'][1]
)
x1 = max(
all_bboxes[subject_idx]["bbox"][2], all_bboxes[object_idx]["bbox"][2]
all_bboxes[subject_idx]['bbox'][2], all_bboxes[object_idx]['bbox'][2]
)
y1 = max(
all_bboxes[subject_idx]["bbox"][3], all_bboxes[object_idx]["bbox"][3]
all_bboxes[subject_idx]['bbox'][3], all_bboxes[object_idx]['bbox'][3]
)
object_area = abs(
all_bboxes[object_idx]["bbox"][2] - all_bboxes[object_idx]["bbox"][0]
all_bboxes[object_idx]['bbox'][2] - all_bboxes[object_idx]['bbox'][0]
) * abs(
all_bboxes[object_idx]["bbox"][3] - all_bboxes[object_idx]["bbox"][1]
all_bboxes[object_idx]['bbox'][3] - all_bboxes[object_idx]['bbox'][1]
)
for i in range(len(all_bboxes)):
if (
i == subject_idx
or all_bboxes[i]["category_id"] != subject_category_id
or all_bboxes[i]['category_id'] != subject_category_id
):
continue
if _is_part_overlap([x0, y0, x1, y1], all_bboxes[i]["bbox"]) or _is_in(
all_bboxes[i]["bbox"], [x0, y0, x1, y1]
if _is_part_overlap([x0, y0, x1, y1], all_bboxes[i]['bbox']) or _is_in(
all_bboxes[i]['bbox'], [x0, y0, x1, y1]
):
i_area = abs(
all_bboxes[i]["bbox"][2] - all_bboxes[i]["bbox"][0]
) * abs(all_bboxes[i]["bbox"][3] - all_bboxes[i]["bbox"][1])
all_bboxes[i]['bbox'][2] - all_bboxes[i]['bbox'][0]
) * abs(all_bboxes[i]['bbox'][3] - all_bboxes[i]['bbox'][1])
if i_area >= object_area:
ret = min(float("inf"), dis[i][object_idx])
ret = min(float('inf'), dis[i][object_idx])
return ret
def expand_bbbox(idxes):
x0s = [all_bboxes[idx]["bbox"][0] for idx in idxes]
y0s = [all_bboxes[idx]["bbox"][1] for idx in idxes]
x1s = [all_bboxes[idx]["bbox"][2] for idx in idxes]
y1s = [all_bboxes[idx]["bbox"][3] for idx in idxes]
x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
x1s = [all_bboxes[idx]['bbox'][2] for idx in idxes]
y1s = [all_bboxes[idx]['bbox'][3] for idx in idxes]
return min(x0s), min(y0s), max(x1s), max(y1s)
subjects = self.__reduct_overlap(
list(
map(
lambda x: {"bbox": x["bbox"], "score": x["score"]},
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x["category_id"] == subject_category_id,
self.__model_list[page_no]["layout_dets"],
lambda x: x['category_id'] == subject_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
......@@ -204,10 +254,10 @@ class MagicModel:
objects = self.__reduct_overlap(
list(
map(
lambda x: {"bbox": x["bbox"], "score": x["score"]},
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x["category_id"] == object_category_id,
self.__model_list[page_no]["layout_dets"],
lambda x: x['category_id'] == object_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
......@@ -215,7 +265,7 @@ class MagicModel:
subject_object_relation_map = {}
subjects.sort(
key=lambda x: x["bbox"][0] ** 2 + x["bbox"][1] ** 2
key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2
) # get the distance !
all_bboxes = []
......@@ -223,18 +273,18 @@ class MagicModel:
for v in subjects:
all_bboxes.append(
{
"category_id": subject_category_id,
"bbox": v["bbox"],
"score": v["score"],
'category_id': subject_category_id,
'bbox': v['bbox'],
'score': v['score'],
}
)
for v in objects:
all_bboxes.append(
{
"category_id": object_category_id,
"bbox": v["bbox"],
"score": v["score"],
'category_id': object_category_id,
'bbox': v['bbox'],
'score': v['score'],
}
)
......@@ -244,18 +294,18 @@ class MagicModel:
for i in range(N):
for j in range(i):
if (
all_bboxes[i]["category_id"] == subject_category_id
and all_bboxes[j]["category_id"] == subject_category_id
all_bboxes[i]['category_id'] == subject_category_id
and all_bboxes[j]['category_id'] == subject_category_id
):
continue
dis[i][j] = bbox_distance(all_bboxes[i]["bbox"], all_bboxes[j]["bbox"])
dis[i][j] = bbox_distance(all_bboxes[i]['bbox'], all_bboxes[j]['bbox'])
dis[j][i] = dis[i][j]
used = set()
for i in range(N):
# 求第 i 个 subject 所关联的 object
if all_bboxes[i]["category_id"] != subject_category_id:
if all_bboxes[i]['category_id'] != subject_category_id:
continue
seen = set()
candidates = []
......@@ -267,7 +317,7 @@ class MagicModel:
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
all_bboxes[i]["bbox"], all_bboxes[j]["bbox"]
all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
),
)
)
......@@ -275,25 +325,28 @@ class MagicModel:
if pos_flag_count > 1:
continue
if (
all_bboxes[j]["category_id"] != object_category_id
all_bboxes[j]['category_id'] != object_category_id
or j in used
or dis[i][j] == MAX_DIS_OF_POINT
):
continue
left, right, _, _ = bbox_relative_pos(
all_bboxes[i]["bbox"], all_bboxes[j]["bbox"]
all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
) # 由 pos_flag_count 相关逻辑保证本段逻辑准确性
if left or right:
one_way_dis = all_bboxes[i]["bbox"][2] - all_bboxes[i]["bbox"][0]
one_way_dis = all_bboxes[i]['bbox'][2] - all_bboxes[i]['bbox'][0]
else:
one_way_dis = all_bboxes[i]["bbox"][3] - all_bboxes[i]["bbox"][1]
one_way_dis = all_bboxes[i]['bbox'][3] - all_bboxes[i]['bbox'][1]
if dis[i][j] > one_way_dis:
continue
arr.append((dis[i][j], j))
arr.sort(key=lambda x: x[0])
if len(arr) > 0:
# bug: 离该subject 最近的 object 可能跨越了其它的 subject 。比如 [this subect] [some sbuject] [the nearest objec of subject]
"""
bug: 离该subject 最近的 object 可能跨越了其它的 subject。
比如 [this subect] [some sbuject] [the nearest object of subject]
"""
if may_find_other_nearest_bbox(i, arr[0][1]) >= arr[0][0]:
candidates.append(arr[0][1])
......@@ -308,7 +361,7 @@ class MagicModel:
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
all_bboxes[j]["bbox"], all_bboxes[k]["bbox"]
all_bboxes[j]['bbox'], all_bboxes[k]['bbox']
),
)
)
......@@ -318,7 +371,7 @@ class MagicModel:
continue
if (
all_bboxes[k]["category_id"] != object_category_id
all_bboxes[k]['category_id'] != object_category_id
or k in used
or k in seen
or dis[j][k] == MAX_DIS_OF_POINT
......@@ -327,17 +380,19 @@ class MagicModel:
continue
is_nearest = True
for l in range(i + 1, N):
if l in (j, k) or l in used or l in seen:
for ni in range(i + 1, N):
if ni in (j, k) or ni in used or ni in seen:
continue
if not float_gt(dis[l][k], dis[j][k]):
if not float_gt(dis[ni][k], dis[j][k]):
is_nearest = False
break
if is_nearest:
nx0, ny0, nx1, ny1 = expand_bbbox(list(seen) + [k])
n_dis = bbox_distance(all_bboxes[i]["bbox"], [nx0, ny0, nx1, ny1])
n_dis = bbox_distance(
all_bboxes[i]['bbox'], [nx0, ny0, nx1, ny1]
)
if float_gt(dis[i][j], n_dis):
continue
tmp.append(k)
......@@ -350,7 +405,7 @@ class MagicModel:
# 已经获取到某个 figure 下所有的最靠近的 captions,以及最靠近这些 captions 的 captions 。
# 先扩一下 bbox,
ox0, oy0, ox1, oy1 = expand_bbbox(list(seen) + [i])
ix0, iy0, ix1, iy1 = all_bboxes[i]["bbox"]
ix0, iy0, ix1, iy1 = all_bboxes[i]['bbox']
# 分成了 4 个截取空间,需要计算落在每个截取空间下 objects 合并后占据的矩形面积
caption_poses = [
......@@ -366,17 +421,17 @@ class MagicModel:
for idx in seen:
if (
calculate_overlap_area_in_bbox1_area_ratio(
all_bboxes[idx]["bbox"], bbox
all_bboxes[idx]['bbox'], bbox
)
> CAPATION_OVERLAP_AREA_RATIO
):
embed_arr.append(idx)
if len(embed_arr) > 0:
embed_x0 = min([all_bboxes[idx]["bbox"][0] for idx in embed_arr])
embed_y0 = min([all_bboxes[idx]["bbox"][1] for idx in embed_arr])
embed_x1 = max([all_bboxes[idx]["bbox"][2] for idx in embed_arr])
embed_y1 = max([all_bboxes[idx]["bbox"][3] for idx in embed_arr])
embed_x0 = min([all_bboxes[idx]['bbox'][0] for idx in embed_arr])
embed_y0 = min([all_bboxes[idx]['bbox'][1] for idx in embed_arr])
embed_x1 = max([all_bboxes[idx]['bbox'][2] for idx in embed_arr])
embed_y1 = max([all_bboxes[idx]['bbox'][3] for idx in embed_arr])
caption_areas.append(
int(abs(embed_x1 - embed_x0) * abs(embed_y1 - embed_y0))
)
......@@ -391,7 +446,7 @@ class MagicModel:
for j in seen:
if (
calculate_overlap_area_in_bbox1_area_ratio(
all_bboxes[j]["bbox"], caption_bbox
all_bboxes[j]['bbox'], caption_bbox
)
> CAPATION_OVERLAP_AREA_RATIO
):
......@@ -400,30 +455,30 @@ class MagicModel:
for i in sorted(subject_object_relation_map.keys()):
result = {
"subject_body": all_bboxes[i]["bbox"],
"all": all_bboxes[i]["bbox"],
"score": all_bboxes[i]["score"],
'subject_body': all_bboxes[i]['bbox'],
'all': all_bboxes[i]['bbox'],
'score': all_bboxes[i]['score'],
}
if len(subject_object_relation_map[i]) > 0:
x0 = min(
[all_bboxes[j]["bbox"][0] for j in subject_object_relation_map[i]]
[all_bboxes[j]['bbox'][0] for j in subject_object_relation_map[i]]
)
y0 = min(
[all_bboxes[j]["bbox"][1] for j in subject_object_relation_map[i]]
[all_bboxes[j]['bbox'][1] for j in subject_object_relation_map[i]]
)
x1 = max(
[all_bboxes[j]["bbox"][2] for j in subject_object_relation_map[i]]
[all_bboxes[j]['bbox'][2] for j in subject_object_relation_map[i]]
)
y1 = max(
[all_bboxes[j]["bbox"][3] for j in subject_object_relation_map[i]]
)
result["object_body"] = [x0, y0, x1, y1]
result["all"] = [
min(x0, all_bboxes[i]["bbox"][0]),
min(y0, all_bboxes[i]["bbox"][1]),
max(x1, all_bboxes[i]["bbox"][2]),
max(y1, all_bboxes[i]["bbox"][3]),
[all_bboxes[j]['bbox'][3] for j in subject_object_relation_map[i]]
)
result['object_body'] = [x0, y0, x1, y1]
result['all'] = [
min(x0, all_bboxes[i]['bbox'][0]),
min(y0, all_bboxes[i]['bbox'][1]),
max(x1, all_bboxes[i]['bbox'][2]),
max(y1, all_bboxes[i]['bbox'][3]),
]
ret.append(result)
......@@ -432,7 +487,7 @@ class MagicModel:
for i in subject_object_relation_map.keys():
for j in subject_object_relation_map[i]:
total_subject_object_dis += bbox_distance(
all_bboxes[i]["bbox"], all_bboxes[j]["bbox"]
all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
)
# 计算未匹配的 subject 和 object 的距离(非精确版)
......@@ -444,12 +499,12 @@ class MagicModel:
]
)
for i in range(N):
if all_bboxes[i]["category_id"] != object_category_id or i in used:
if all_bboxes[i]['category_id'] != object_category_id or i in used:
continue
candidates = []
for j in range(N):
if (
all_bboxes[j]["category_id"] != subject_category_id
all_bboxes[j]['category_id'] != subject_category_id
or j in with_caption_subject
):
continue
......@@ -461,18 +516,28 @@ class MagicModel:
return ret, total_subject_object_dis
def get_imgs(self, page_no: int):
figure_captions, _ = self.__tie_up_category_by_distance(
page_no, 3, 4
with_captions, _ = self.__tie_up_category_by_distance(page_no, 3, 4)
with_footnotes, _ = self.__tie_up_category_by_distance(
page_no, 3, CategoryId.ImageFootnote
)
return [
{
"bbox": record["all"],
"img_body_bbox": record["subject_body"],
"img_caption_bbox": record.get("object_body", None),
"score": record["score"],
ret = []
N, M = len(with_captions), len(with_footnotes)
assert N == M
for i in range(N):
record = {
'score': with_captions[i]['score'],
'img_caption_bbox': with_captions[i].get('object_body', None),
'img_body_bbox': with_captions[i]['subject_body'],
'img_footnote_bbox': with_footnotes[i].get('object_body', None),
}
for record in figure_captions
]
x0 = min(with_captions[i]['all'][0], with_footnotes[i]['all'][0])
y0 = min(with_captions[i]['all'][1], with_footnotes[i]['all'][1])
x1 = max(with_captions[i]['all'][2], with_footnotes[i]['all'][2])
y1 = max(with_captions[i]['all'][3], with_footnotes[i]['all'][3])
record['bbox'] = [x0, y0, x1, y1]
ret.append(record)
return ret
def get_tables(
self, page_no: int
......@@ -484,26 +549,26 @@ class MagicModel:
assert N == M
for i in range(N):
record = {
"score": with_captions[i]["score"],
"table_caption_bbox": with_captions[i].get("object_body", None),
"table_body_bbox": with_captions[i]["subject_body"],
"table_footnote_bbox": with_footnotes[i].get("object_body", None),
'score': with_captions[i]['score'],
'table_caption_bbox': with_captions[i].get('object_body', None),
'table_body_bbox': with_captions[i]['subject_body'],
'table_footnote_bbox': with_footnotes[i].get('object_body', None),
}
x0 = min(with_captions[i]["all"][0], with_footnotes[i]["all"][0])
y0 = min(with_captions[i]["all"][1], with_footnotes[i]["all"][1])
x1 = max(with_captions[i]["all"][2], with_footnotes[i]["all"][2])
y1 = max(with_captions[i]["all"][3], with_footnotes[i]["all"][3])
record["bbox"] = [x0, y0, x1, y1]
x0 = min(with_captions[i]['all'][0], with_footnotes[i]['all'][0])
y0 = min(with_captions[i]['all'][1], with_footnotes[i]['all'][1])
x1 = max(with_captions[i]['all'][2], with_footnotes[i]['all'][2])
y1 = max(with_captions[i]['all'][3], with_footnotes[i]['all'][3])
record['bbox'] = [x0, y0, x1, y1]
ret.append(record)
return ret
def get_equations(self, page_no: int) -> list: # 有坐标,也有字
inline_equations = self.__get_blocks_by_type(
ModelBlockTypeEnum.EMBEDDING.value, page_no, ["latex"]
ModelBlockTypeEnum.EMBEDDING.value, page_no, ['latex']
)
interline_equations = self.__get_blocks_by_type(
ModelBlockTypeEnum.ISOLATED.value, page_no, ["latex"]
ModelBlockTypeEnum.ISOLATED.value, page_no, ['latex']
)
interline_equations_blocks = self.__get_blocks_by_type(
ModelBlockTypeEnum.ISOLATE_FORMULA.value, page_no
......@@ -525,17 +590,18 @@ class MagicModel:
def get_ocr_text(self, page_no: int) -> list: # paddle 搞的,有字也有坐标
text_spans = []
model_page_info = self.__model_list[page_no]
layout_dets = model_page_info["layout_dets"]
layout_dets = model_page_info['layout_dets']
for layout_det in layout_dets:
if layout_det["category_id"] == "15":
if layout_det['category_id'] == '15':
span = {
"bbox": layout_det["bbox"],
"content": layout_det["text"],
'bbox': layout_det['bbox'],
'content': layout_det['text'],
}
text_spans.append(span)
return text_spans
def get_all_spans(self, page_no: int) -> list:
def remove_duplicate_spans(spans):
new_spans = []
for span in spans:
......@@ -545,7 +611,7 @@ class MagicModel:
all_spans = []
model_page_info = self.__model_list[page_no]
layout_dets = model_page_info["layout_dets"]
layout_dets = model_page_info['layout_dets']
allow_category_id_list = [3, 5, 13, 14, 15]
"""当成span拼接的"""
# 3: 'image', # 图片
......@@ -554,11 +620,11 @@ class MagicModel:
# 14: 'interline_equation', # 行间公式
# 15: 'text', # ocr识别文本
for layout_det in layout_dets:
category_id = layout_det["category_id"]
category_id = layout_det['category_id']
if category_id in allow_category_id_list:
span = {"bbox": layout_det["bbox"], "score": layout_det["score"]}
span = {'bbox': layout_det['bbox'], 'score': layout_det['score']}
if category_id == 3:
span["type"] = ContentType.Image
span['type'] = ContentType.Image
elif category_id == 5:
# 获取table模型结果
latex = layout_det.get("latex", None)
......@@ -569,14 +635,14 @@ class MagicModel:
span["html"] = html
span["type"] = ContentType.Table
elif category_id == 13:
span["content"] = layout_det["latex"]
span["type"] = ContentType.InlineEquation
span['content'] = layout_det['latex']
span['type'] = ContentType.InlineEquation
elif category_id == 14:
span["content"] = layout_det["latex"]
span["type"] = ContentType.InterlineEquation
span['content'] = layout_det['latex']
span['type'] = ContentType.InterlineEquation
elif category_id == 15:
span["content"] = layout_det["text"]
span["type"] = ContentType.Text
span['content'] = layout_det['text']
span['type'] = ContentType.Text
all_spans.append(span)
return remove_duplicate_spans(all_spans)
......@@ -593,19 +659,19 @@ class MagicModel:
) -> list:
blocks = []
for page_dict in self.__model_list:
layout_dets = page_dict.get("layout_dets", [])
page_info = page_dict.get("page_info", {})
page_number = page_info.get("page_no", -1)
layout_dets = page_dict.get('layout_dets', [])
page_info = page_dict.get('page_info', {})
page_number = page_info.get('page_no', -1)
if page_no != page_number:
continue
for item in layout_dets:
category_id = item.get("category_id", -1)
bbox = item.get("bbox", None)
category_id = item.get('category_id', -1)
bbox = item.get('bbox', None)
if category_id == type:
block = {
"bbox": bbox,
"score": item.get("score"),
'bbox': bbox,
'score': item.get('score'),
}
for col in extra_col:
block[col] = item.get(col, None)
......@@ -616,28 +682,28 @@ class MagicModel:
return self.__model_list[page_no]
if __name__ == "__main__":
drw = DiskReaderWriter(r"D:/project/20231108code-clean")
if __name__ == '__main__':
drw = DiskReaderWriter(r'D:/project/20231108code-clean')
if 0:
pdf_file_path = r"linshixuqiu\19983-00.pdf"
model_file_path = r"linshixuqiu\19983-00_new.json"
pdf_file_path = r'linshixuqiu\19983-00.pdf'
model_file_path = r'linshixuqiu\19983-00_new.json'
pdf_bytes = drw.read(pdf_file_path, AbsReaderWriter.MODE_BIN)
model_json_txt = drw.read(model_file_path, AbsReaderWriter.MODE_TXT)
model_list = json.loads(model_json_txt)
write_path = r"D:\project\20231108code-clean\linshixuqiu\19983-00"
img_bucket_path = "imgs"
write_path = r'D:\project\20231108code-clean\linshixuqiu\19983-00'
img_bucket_path = 'imgs'
img_writer = DiskReaderWriter(join_path(write_path, img_bucket_path))
pdf_docs = fitz.open("pdf", pdf_bytes)
pdf_docs = fitz.open('pdf', pdf_bytes)
magic_model = MagicModel(model_list, pdf_docs)
if 1:
model_list = json.loads(
drw.read("/opt/data/pdf/20240418/j.chroma.2009.03.042.json")
drw.read('/opt/data/pdf/20240418/j.chroma.2009.03.042.json')
)
pdf_bytes = drw.read(
"/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf", AbsReaderWriter.MODE_BIN
'/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf', AbsReaderWriter.MODE_BIN
)
pdf_docs = fitz.open("pdf", pdf_bytes)
pdf_docs = fitz.open('pdf', pdf_bytes)
magic_model = MagicModel(model_list, pdf_docs)
for i in range(7):
print(magic_model.get_imgs(i))
from loguru import logger
from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold, get_minbox_if_overlap_by_ratio, \
calculate_overlap_area_in_bbox1_area_ratio, _is_in_or_part_overlap_with_area_ratio
from magic_pdf.libs.boxbase import (__is_overlaps_y_exceeds_threshold,
_is_in_or_part_overlap_with_area_ratio,
calculate_overlap_area_in_bbox1_area_ratio)
from magic_pdf.libs.drop_tag import DropTag
from magic_pdf.libs.ocr_content_type import ContentType, BlockType
from magic_pdf.pre_proc.ocr_span_list_modify import modify_y_axis, modify_inline_equation
from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox_for_span
from magic_pdf.libs.ocr_content_type import BlockType, ContentType
# 将每一个line中的span从左到右排序
......@@ -21,8 +18,8 @@ def line_sort_spans_by_left_to_right(lines):
max(span['bbox'][3] for span in line), # y1
]
line_objects.append({
"bbox": line_bbox,
"spans": line,
'bbox': line_bbox,
'spans': line,
})
return line_objects
......@@ -39,16 +36,21 @@ def merge_spans_to_line(spans):
for span in spans[1:]:
# 如果当前的span类型为"interline_equation" 或者 当前行中已经有"interline_equation"
# image和table类型,同上
if span['type'] in [ContentType.InterlineEquation, ContentType.Image, ContentType.Table] or any(
s['type'] in [ContentType.InterlineEquation, ContentType.Image, ContentType.Table] for s in
current_line):
if span['type'] in [
ContentType.InterlineEquation, ContentType.Image,
ContentType.Table
] or any(s['type'] in [
ContentType.InterlineEquation, ContentType.Image,
ContentType.Table
] for s in current_line):
# 则开始新行
lines.append(current_line)
current_line = [span]
continue
# 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行
if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox']):
if __is_overlaps_y_exceeds_threshold(span['bbox'],
current_line[-1]['bbox']):
current_line.append(span)
else:
# 否则,开始新行
......@@ -71,7 +73,8 @@ def merge_spans_to_line_by_layout(spans, layout_bboxes):
# 遍历spans,将每个span放入对应的layout中
layout_sapns = []
for span in spans:
if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], layout_bbox) > 0.6:
if calculate_overlap_area_in_bbox1_area_ratio(
span['bbox'], layout_bbox) > 0.6:
layout_sapns.append(span)
# 如果layout_sapns不为空,则放入new_spans中
if len(layout_sapns) > 0:
......@@ -99,12 +102,10 @@ def merge_lines_to_block(lines):
# 目前不做block拼接,先做个结构,每个block中只有一个line,block的bbox就是line的bbox
blocks = []
for line in lines:
blocks.append(
{
"bbox": line["bbox"],
"lines": [line],
}
)
blocks.append({
'bbox': line['bbox'],
'lines': [line],
})
return blocks
......@@ -121,7 +122,8 @@ def sort_blocks_by_layout(all_bboxes, layout_bboxes):
if block[7] == BlockType.Footnote:
continue
block_bbox = block[:4]
if calculate_overlap_area_in_bbox1_area_ratio(block_bbox, layout_bbox) > 0.8:
if calculate_overlap_area_in_bbox1_area_ratio(
block_bbox, layout_bbox) > 0.8:
layout_blocks.append(block)
# 如果layout_blocks不为空,则放入new_blocks中
......@@ -134,7 +136,8 @@ def sort_blocks_by_layout(all_bboxes, layout_bboxes):
# 如果new_blocks不为空,则对new_blocks中每个block进行排序
if len(new_blocks) > 0:
for bboxes_in_layout_block in new_blocks:
bboxes_in_layout_block.sort(key=lambda x: x[1]) # 一个layout内部的box,按照y0自上而下排序
bboxes_in_layout_block.sort(
key=lambda x: x[1]) # 一个layout内部的box,按照y0自上而下排序
sort_blocks.extend(bboxes_in_layout_block)
# sort_blocks中已经包含了当前页面所有最终留下的block,且已经排好了顺序
......@@ -142,9 +145,7 @@ def sort_blocks_by_layout(all_bboxes, layout_bboxes):
def fill_spans_in_blocks(blocks, spans, radio):
'''
将allspans中的span按位置关系,放入blocks中
'''
"""将allspans中的span按位置关系,放入blocks中."""
block_with_spans = []
for block in blocks:
block_type = block[7]
......@@ -156,17 +157,15 @@ def fill_spans_in_blocks(blocks, spans, radio):
block_spans = []
for span in spans:
span_bbox = span['bbox']
if calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > radio:
if calculate_overlap_area_in_bbox1_area_ratio(
span_bbox, block_bbox) > radio:
block_spans.append(span)
'''行内公式调整, 高度调整至与同行文字高度一致(优先左侧, 其次右侧)'''
# displayed_list = []
# text_inline_lines = []
# modify_y_axis(block_spans, displayed_list, text_inline_lines)
'''模型识别错误的行间公式, type类型转换成行内公式'''
# block_spans = modify_inline_equation(block_spans, displayed_list, text_inline_lines)
'''bbox去除粘连''' # 去粘连会影响span的bbox,导致后续fill的时候出错
# block_spans = remove_overlap_between_bbox_for_span(block_spans)
......@@ -182,12 +181,9 @@ def fill_spans_in_blocks(blocks, spans, radio):
def fix_block_spans(block_with_spans, img_blocks, table_blocks):
'''
1、img_block和table_block因为包含caption和footnote的关系,存在block的嵌套关系
"""1、img_block和table_block因为包含caption和footnote的关系,存在block的嵌套关系
需要将caption和footnote的text_span放入相应img_block和table_block内的
caption_block和footnote_block中
2、同时需要删除block中的spans字段
'''
caption_block和footnote_block中 2、同时需要删除block中的spans字段."""
fix_blocks = []
for block in block_with_spans:
block_type = block['type']
......@@ -218,16 +214,13 @@ def merge_spans_to_block(spans: list, block_bbox: list, block_type: str):
block_spans = []
# 如果有img_caption,则将img_block中的text_spans放入img_caption_block中
for span in spans:
if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block_bbox) > 0.6:
if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'],
block_bbox) > 0.6:
block_spans.append(span)
block_lines = merge_spans_to_line(block_spans)
# 对line中的span进行排序
sort_block_lines = line_sort_spans_by_left_to_right(block_lines)
block = {
'bbox': block_bbox,
'type': block_type,
'lines': sort_block_lines
}
block = {'bbox': block_bbox, 'type': block_type, 'lines': sort_block_lines}
return block, block_spans
......@@ -237,11 +230,7 @@ def make_body_block(span: dict, block_bbox: list, block_type: str):
'bbox': block_bbox,
'spans': [span],
}
body_block = {
'bbox': block_bbox,
'type': block_type,
'lines': [body_line]
}
body_block = {'bbox': block_bbox, 'type': block_type, 'lines': [body_line]}
return body_block
......@@ -249,13 +238,16 @@ def fix_image_block(block, img_blocks):
block['blocks'] = []
# 遍历img_blocks,找到与当前block匹配的img_block
for img_block in img_blocks:
if _is_in_or_part_overlap_with_area_ratio(block['bbox'], img_block['bbox'], 0.95):
if _is_in_or_part_overlap_with_area_ratio(block['bbox'],
img_block['bbox'], 0.95):
# 创建img_body_block
for span in block['spans']:
if span['type'] == ContentType.Image and img_block['img_body_bbox'] == span['bbox']:
if span['type'] == ContentType.Image and img_block[
'img_body_bbox'] == span['bbox']:
# 创建img_body_block
img_body_block = make_body_block(span, img_block['img_body_bbox'], BlockType.ImageBody)
img_body_block = make_body_block(
span, img_block['img_body_bbox'], BlockType.ImageBody)
block['blocks'].append(img_body_block)
# 从spans中移除img_body_block中已经放入的span
......@@ -265,10 +257,15 @@ def fix_image_block(block, img_blocks):
# 根据list长度,判断img_block中是否有img_caption
if img_block['img_caption_bbox'] is not None:
img_caption_block, img_caption_spans = merge_spans_to_block(
block['spans'], img_block['img_caption_bbox'], BlockType.ImageCaption
)
block['spans'], img_block['img_caption_bbox'],
BlockType.ImageCaption)
block['blocks'].append(img_caption_block)
if img_block['img_footnote_bbox'] is not None:
img_footnote_block, img_footnote_spans = merge_spans_to_block(
block['spans'], img_block['img_footnote_bbox'],
BlockType.ImageFootnote)
block['blocks'].append(img_footnote_block)
break
del block['spans']
return block
......@@ -278,13 +275,17 @@ def fix_table_block(block, table_blocks):
block['blocks'] = []
# 遍历table_blocks,找到与当前block匹配的table_block
for table_block in table_blocks:
if _is_in_or_part_overlap_with_area_ratio(block['bbox'], table_block['bbox'], 0.95):
if _is_in_or_part_overlap_with_area_ratio(block['bbox'],
table_block['bbox'], 0.95):
# 创建table_body_block
for span in block['spans']:
if span['type'] == ContentType.Table and table_block['table_body_bbox'] == span['bbox']:
if span['type'] == ContentType.Table and table_block[
'table_body_bbox'] == span['bbox']:
# 创建table_body_block
table_body_block = make_body_block(span, table_block['table_body_bbox'], BlockType.TableBody)
table_body_block = make_body_block(
span, table_block['table_body_bbox'],
BlockType.TableBody)
block['blocks'].append(table_body_block)
# 从spans中移除img_body_block中已经放入的span
......@@ -294,8 +295,8 @@ def fix_table_block(block, table_blocks):
# 根据list长度,判断table_block中是否有caption
if table_block['table_caption_bbox'] is not None:
table_caption_block, table_caption_spans = merge_spans_to_block(
block['spans'], table_block['table_caption_bbox'], BlockType.TableCaption
)
block['spans'], table_block['table_caption_bbox'],
BlockType.TableCaption)
block['blocks'].append(table_caption_block)
# 如果table_caption_block_spans不为空
......@@ -307,8 +308,8 @@ def fix_table_block(block, table_blocks):
# 根据list长度,判断table_block中是否有table_note
if table_block['table_footnote_bbox'] is not None:
table_footnote_block, table_footnote_spans = merge_spans_to_block(
block['spans'], table_block['table_footnote_bbox'], BlockType.TableFootnote
)
block['spans'], table_block['table_footnote_bbox'],
BlockType.TableFootnote)
block['blocks'].append(table_footnote_block)
break
......
# 欢迎来到 MinerU 项目列表
## 项目列表
- [llama_index_rag](./llama_index_rag/README.md): 基于 llama_index 构建轻量级 RAG 系统
import pytest
import os
from magic_pdf.libs.boxbase import _is_in_or_part_overlap, _is_in_or_part_overlap_with_area_ratio, _is_in, \
_is_part_overlap, _left_intersect, _right_intersect, _is_vertical_full_overlap, _is_bottom_full_overlap, \
_is_left_overlap, __is_overlaps_y_exceeds_threshold, calculate_iou, calculate_overlap_area_2_minbox_area_ratio, \
calculate_overlap_area_in_bbox1_area_ratio, get_minbox_if_overlap_by_ratio, get_bbox_in_boundry, \
find_top_nearest_text_bbox, find_bottom_nearest_text_bbox, find_left_nearest_text_bbox, \
find_right_nearest_text_bbox, bbox_relative_pos, bbox_distance
from magic_pdf.libs.commons import mymax, join_path, get_top_percent_list
import pytest
from magic_pdf.libs.boxbase import (__is_overlaps_y_exceeds_threshold,
_is_bottom_full_overlap, _is_in,
_is_in_or_part_overlap,
_is_in_or_part_overlap_with_area_ratio,
_is_left_overlap, _is_part_overlap,
_is_vertical_full_overlap, _left_intersect,
_right_intersect, bbox_distance,
bbox_relative_pos, calculate_iou,
calculate_overlap_area_2_minbox_area_ratio,
calculate_overlap_area_in_bbox1_area_ratio,
find_bottom_nearest_text_bbox,
find_left_nearest_text_bbox,
find_right_nearest_text_bbox,
find_top_nearest_text_bbox,
get_bbox_in_boundary,
get_minbox_if_overlap_by_ratio)
from magic_pdf.libs.commons import get_top_percent_list, join_path, mymax
from magic_pdf.libs.config_reader import get_s3_config
from magic_pdf.libs.path_utils import parse_s3path
# 输入一个列表,如果列表空返回0,否则返回最大元素
@pytest.mark.parametrize("list_input, target_num",
@pytest.mark.parametrize('list_input, target_num',
[
([0, 0, 0, 0], 0),
([0], 0),
......@@ -29,7 +41,7 @@ def test_list_max(list_input: list, target_num) -> None:
# 连接多个参数生成路径信息,使用"/"作为连接符,生成的结果需要是一个合法路径
@pytest.mark.parametrize("path_input, target_path", [
@pytest.mark.parametrize('path_input, target_path', [
(['https:', '', 'www.baidu.com'], 'https://www.baidu.com'),
(['https:', 'www.baidu.com'], 'https:/www.baidu.com'),
(['D:', 'file', 'pythonProject', 'demo' + '.py'], 'D:/file/pythonProject/demo.py'),
......@@ -42,7 +54,7 @@ def test_join_path(path_input: list, target_path: str) -> None:
# 获取列表中前百分之多少的元素
@pytest.mark.parametrize("num_list, percent, target_num_list", [
@pytest.mark.parametrize('num_list, percent, target_num_list', [
([], 0.75, []),
([-5, -10, 9, 3, 7, -7, 0, 23, -1, -11], 0.8, [23, 9, 7, 3, 0, -1, -5, -7]),
([-5, -10, 9, 3, 7, -7, 0, 23, -1, -11], 0, []),
......@@ -57,9 +69,9 @@ def test_get_top_percent_list(num_list: list, percent: float, target_num_list: l
# 输入一个s3路径,返回bucket名字和其余部分(key)
@pytest.mark.parametrize("s3_path, target_data", [
("s3://bucket/path/to/my/file.txt", "bucket"),
("s3a://bucket1/path/to/my/file2.txt", "bucket1"),
@pytest.mark.parametrize('s3_path, target_data', [
('s3://bucket/path/to/my/file.txt', 'bucket'),
('s3a://bucket1/path/to/my/file2.txt', 'bucket1'),
# ("/path/to/my/file1.txt", "path"),
# ("bucket/path/to/my/file2.txt", "bucket"),
])
......@@ -76,7 +88,7 @@ def test_parse_s3path(s3_path: str, target_data: str):
# 2个box是否处于包含或者部分重合关系。
# 如果某边界重合算重合。
# 部分边界重合,其他在内部也算包含
@pytest.mark.parametrize("box1, box2, target_bool", [
@pytest.mark.parametrize('box1, box2, target_bool', [
((120, 133, 223, 248), (128, 168, 269, 295), True),
((137, 53, 245, 157), (134, 11, 200, 147), True), # 部分重合
((137, 56, 211, 116), (140, 66, 202, 199), True), # 部分重合
......@@ -101,7 +113,7 @@ def test_is_in_or_part_overlap(box1: tuple, box2: tuple, target_bool: bool) -> N
# 如果box1在box2内部,返回True
# 如果是部分重合的,则重合面积占box1的比例大于阈值时候返回True
@pytest.mark.parametrize("box1, box2, target_bool", [
@pytest.mark.parametrize('box1, box2, target_bool', [
((35, 28, 108, 90), (47, 60, 83, 96), False), # 包含 box1 up box2, box2 多半,box1少半
((65, 151, 92, 177), (49, 99, 105, 198), True), # 包含 box1 in box2
((80, 62, 112, 84), (74, 40, 144, 111), True), # 包含 box1 in box2
......@@ -119,7 +131,7 @@ def test_is_in_or_part_overlap_with_area_ratio(box1: tuple, box2: tuple, target_
# box1在box2内部或者box2在box1内部返回True。如果部分边界重合也算作包含。
@pytest.mark.parametrize("box1, box2, target_bool", [
@pytest.mark.parametrize('box1, box2, target_bool', [
# ((), (), "Error"), # Error
((65, 151, 92, 177), (49, 99, 105, 198), True), # 包含 box1 in box2
((80, 62, 112, 84), (74, 40, 144, 111), True), # 包含 box1 in box2
......@@ -141,7 +153,7 @@ def test_is_in(box1: tuple, box2: tuple, target_bool: bool) -> None:
# 仅仅是部分包含关系,返回True,如果是完全包含关系则返回False
@pytest.mark.parametrize("box1, box2, target_bool", [
@pytest.mark.parametrize('box1, box2, target_bool', [
((65, 151, 92, 177), (49, 99, 105, 198), False), # 包含 box1 in box2
((80, 62, 112, 84), (74, 40, 144, 111), False), # 包含 box1 in box2
# ((76, 140, 154, 277), (121, 326, 192, 384), False), # 分离 Error
......@@ -161,7 +173,7 @@ def test_is_part_overlap(box1: tuple, box2: tuple, target_bool: bool) -> None:
# left_box右侧是否和right_box左侧有部分重叠
@pytest.mark.parametrize("box1, box2, target_bool", [
@pytest.mark.parametrize('box1, box2, target_bool', [
(None, None, False),
((88, 81, 222, 173), (60, 221, 123, 358), False), # 分离
((121, 149, 184, 289), (172, 130, 230, 268), True), # box1 left bottom box2 相交
......@@ -179,7 +191,7 @@ def test_left_intersect(box1: tuple, box2: tuple, target_bool: bool) -> None:
# left_box左侧是否和right_box右侧部分重叠
@pytest.mark.parametrize("box1, box2, target_bool", [
@pytest.mark.parametrize('box1, box2, target_bool', [
(None, None, False),
((88, 81, 222, 173), (60, 221, 123, 358), False), # 分离
((121, 149, 184, 289), (172, 130, 230, 268), False), # box1 left bottom box2 相交
......@@ -198,7 +210,7 @@ def test_right_intersect(box1: tuple, box2: tuple, target_bool: bool) -> None:
# x方向上:要么box1包含box2, 要么box2包含box1。不能部分包含
# y方向上:box1和box2有重叠
@pytest.mark.parametrize("box1, box2, target_bool", [
@pytest.mark.parametrize('box1, box2, target_bool', [
# (None, None, False), # Error
((35, 28, 108, 90), (47, 60, 83, 96), True), # box1 top box2, x:box2 in box1, y:有重叠
((35, 28, 98, 90), (27, 60, 103, 96), True), # box1 top box2, x:box1 in box2, y:有重叠
......@@ -214,7 +226,7 @@ def test_is_vertical_full_overlap(box1: tuple, box2: tuple, target_bool: bool) -
# 检查box1下方和box2的上方有轻微的重叠,轻微程度收到y_tolerance的限制
@pytest.mark.parametrize("box1, box2, target_bool", [
@pytest.mark.parametrize('box1, box2, target_bool', [
(None, None, False),
((35, 28, 108, 90), (47, 89, 83, 116), True), # box1 top box2, y:有重叠
((35, 28, 108, 90), (47, 60, 83, 96), False), # box1 top box2, y:有重叠且过多
......@@ -228,7 +240,7 @@ def test_is_bottom_full_overlap(box1: tuple, box2: tuple, target_bool: bool) ->
# 检查box1的左侧是否和box2有重叠
@pytest.mark.parametrize("box1, box2, target_bool", [
@pytest.mark.parametrize('box1, box2, target_bool', [
(None, None, False),
((88, 81, 222, 173), (60, 221, 123, 358), False), # 分离
# ((121, 149, 184, 289), (172, 130, 230, 268), False), # box1 left bottom box2 相交 Error
......@@ -245,7 +257,7 @@ def test_is_left_overlap(box1: tuple, box2: tuple, target_bool: bool) -> None:
# 查两个bbox在y轴上是否有重叠,并且该重叠区域的高度占两个bbox高度更低的那个超过阈值
@pytest.mark.parametrize("box1, box2, target_bool", [
@pytest.mark.parametrize('box1, box2, target_bool', [
# (None, None, "Error"), # Error
((51, 69, 192, 147), (75, 48, 132, 187), True), # y: box1 in box2
((51, 39, 192, 197), (75, 48, 132, 187), True), # y: box2 in box1
......@@ -260,7 +272,7 @@ def test_is_overlaps_y_exceeds_threshold(box1: tuple, box2: tuple, target_bool:
# Determine the coordinates of the intersection rectangle
@pytest.mark.parametrize("box1, box2, target_num", [
@pytest.mark.parametrize('box1, box2, target_num', [
# (None, None, "Error"), # Error
((88, 81, 222, 173), (60, 221, 123, 358), 0.0), # 分离
((76, 140, 154, 277), (121, 326, 192, 384), 0.0), # 分离
......@@ -276,7 +288,7 @@ def test_calculate_iou(box1: tuple, box2: tuple, target_num: float) -> None:
# 计算box1和box2的重叠面积占最小面积的box的比例
@pytest.mark.parametrize("box1, box2, target_num", [
@pytest.mark.parametrize('box1, box2, target_num', [
# (None, None, "Error"), # Error
((142, 109, 238, 164), (134, 211, 224, 270), 0.0), # 分离
((88, 81, 222, 173), (60, 221, 123, 358), 0.0), # 分离
......@@ -295,7 +307,7 @@ def test_calculate_overlap_area_2_minbox_area_ratio(box1: tuple, box2: tuple, ta
# 计算box1和box2的重叠面积占bbox1的比例
@pytest.mark.parametrize("box1, box2, target_num", [
@pytest.mark.parametrize('box1, box2, target_num', [
# (None, None, "Error"), # Error
((142, 109, 238, 164), (134, 211, 224, 270), 0.0), # 分离
((88, 81, 222, 173), (60, 221, 123, 358), 0.0), # 分离
......@@ -315,7 +327,7 @@ def test_calculate_overlap_area_in_bbox1_area_ratio(box1: tuple, box2: tuple, ta
# 计算两个bbox重叠的面积占最小面积的box的比例,如果比例大于ratio,则返回小的那个bbox,否则返回None
@pytest.mark.parametrize("box1, box2, ratio, target_box", [
@pytest.mark.parametrize('box1, box2, ratio, target_box', [
# (None, None, 0.8, "Error"), # Error
((142, 109, 238, 164), (134, 211, 224, 270), 0.0, None), # 分离
((109, 126, 204, 245), (110, 127, 232, 206), 0.5, (110, 127, 232, 206)),
......@@ -331,7 +343,7 @@ def test_get_minbox_if_overlap_by_ratio(box1: tuple, box2: tuple, ratio: float,
# 根据boundry获取在这个范围内的所有的box的列表,完全包含关系
@pytest.mark.parametrize("boxes, boundry, target_boxs", [
@pytest.mark.parametrize('boxes, boundary, target_boxs', [
# ([], (), "Error"), # Error
([], (110, 340, 209, 387), []),
([(142, 109, 238, 164)], (134, 211, 224, 270), []), # 分离
......@@ -347,15 +359,15 @@ def test_get_minbox_if_overlap_by_ratio(box1: tuple, box2: tuple, ratio: float,
(137, 29, 287, 87)], (30, 20, 200, 320),
[(81, 280, 123, 315), (46, 99, 133, 148), (33, 156, 97, 211)]),
])
def test_get_bbox_in_boundry(boxes: list, boundry: tuple, target_boxs: list) -> None:
assert target_boxs == get_bbox_in_boundry(boxes, boundry)
def test_get_bbox_in_boundary(boxes: list, boundary: tuple, target_boxs: list) -> None:
assert target_boxs == get_bbox_in_boundary(boxes, boundary)
# 寻找上方距离最近的box,margin 4个单位, x方向有重合,y方向最近的
@pytest.mark.parametrize("pymu_blocks, obj_box, target_boxs", [
([{"bbox": (81, 280, 123, 315)}, {"bbox": (282, 203, 342, 247)}, {"bbox": (183, 100, 300, 155)},
{"bbox": (46, 99, 133, 148)}, {"bbox": (33, 156, 97, 211)},
{"bbox": (137, 29, 287, 87)}], (81, 280, 123, 315), {"bbox": (33, 156, 97, 211)}),
@pytest.mark.parametrize('pymu_blocks, obj_box, target_boxs', [
([{'bbox': (81, 280, 123, 315)}, {'bbox': (282, 203, 342, 247)}, {'bbox': (183, 100, 300, 155)},
{'bbox': (46, 99, 133, 148)}, {'bbox': (33, 156, 97, 211)},
{'bbox': (137, 29, 287, 87)}], (81, 280, 123, 315), {'bbox': (33, 156, 97, 211)}),
# ([{"bbox": (168, 120, 263, 159)},
# {"bbox": (231, 61, 279, 159)},
# {"bbox": (35, 85, 136, 110)},
......@@ -363,46 +375,46 @@ def test_get_bbox_in_boundry(boxes: list, boundry: tuple, target_boxs: list) ->
# {"bbox": (144, 264, 188, 323)},
# {"bbox": (62, 37, 126, 64)}], (228, 193, 347, 225),
# [{"bbox": (168, 120, 263, 159)}, {"bbox": (231, 61, 279, 159)}]), # y:方向最近的有两个,x: 两个均有重合 Error
([{"bbox": (35, 85, 136, 159)},
{"bbox": (168, 120, 263, 159)},
{"bbox": (231, 61, 279, 118)},
{"bbox": (228, 193, 347, 225)},
{"bbox": (144, 264, 188, 323)},
{"bbox": (62, 37, 126, 64)}], (228, 193, 347, 225),
{"bbox": (168, 120, 263, 159)},), # y:方向最近的有两个,x:只有一个有重合
([{"bbox": (239, 115, 379, 167)},
{"bbox": (33, 237, 104, 262)},
{"bbox": (124, 288, 168, 325)},
{"bbox": (242, 291, 379, 340)},
{"bbox": (55, 117, 121, 154)},
{"bbox": (266, 183, 384, 217)}, ], (124, 288, 168, 325), {'bbox': (55, 117, 121, 154)}),
([{"bbox": (239, 115, 379, 167)},
{"bbox": (33, 237, 104, 262)},
{"bbox": (124, 288, 168, 325)},
{"bbox": (242, 291, 379, 340)},
{"bbox": (55, 117, 119, 154)},
{"bbox": (266, 183, 384, 217)}, ], (124, 288, 168, 325), None), # x没有重合
([{"bbox": (80, 90, 249, 200)},
{"bbox": (183, 100, 240, 155)}, ], (183, 100, 240, 155), None), # 包含
([{'bbox': (35, 85, 136, 159)},
{'bbox': (168, 120, 263, 159)},
{'bbox': (231, 61, 279, 118)},
{'bbox': (228, 193, 347, 225)},
{'bbox': (144, 264, 188, 323)},
{'bbox': (62, 37, 126, 64)}], (228, 193, 347, 225),
{'bbox': (168, 120, 263, 159)},), # y:方向最近的有两个,x:只有一个有重合
([{'bbox': (239, 115, 379, 167)},
{'bbox': (33, 237, 104, 262)},
{'bbox': (124, 288, 168, 325)},
{'bbox': (242, 291, 379, 340)},
{'bbox': (55, 117, 121, 154)},
{'bbox': (266, 183, 384, 217)}, ], (124, 288, 168, 325), {'bbox': (55, 117, 121, 154)}),
([{'bbox': (239, 115, 379, 167)},
{'bbox': (33, 237, 104, 262)},
{'bbox': (124, 288, 168, 325)},
{'bbox': (242, 291, 379, 340)},
{'bbox': (55, 117, 119, 154)},
{'bbox': (266, 183, 384, 217)}, ], (124, 288, 168, 325), None), # x没有重合
([{'bbox': (80, 90, 249, 200)},
{'bbox': (183, 100, 240, 155)}, ], (183, 100, 240, 155), None), # 包含
])
def test_find_top_nearest_text_bbox(pymu_blocks: list, obj_box: tuple, target_boxs: dict) -> None:
assert target_boxs == find_top_nearest_text_bbox(pymu_blocks, obj_box)
# 寻找下方距离自己最近的box, x方向有重合,y方向最近的
@pytest.mark.parametrize("pymu_blocks, obj_box, target_boxs", [
([{"bbox": (165, 96, 300, 114)},
{"bbox": (11, 157, 139, 201)},
{"bbox": (124, 208, 265, 262)},
{"bbox": (124, 283, 248, 306)},
{"bbox": (39, 267, 84, 301)},
{"bbox": (36, 89, 114, 145)}, ], (165, 96, 300, 114), {"bbox": (124, 208, 265, 262)}),
([{"bbox": (187, 37, 303, 49)},
{"bbox": (2, 227, 90, 283)},
{"bbox": (158, 174, 200, 212)},
{"bbox": (259, 174, 324, 228)},
{"bbox": (205, 61, 316, 97)},
{"bbox": (295, 248, 374, 287)}, ], (205, 61, 316, 97), {"bbox": (259, 174, 324, 228)}), # y有两个最近的, x只有一个重合
@pytest.mark.parametrize('pymu_blocks, obj_box, target_boxs', [
([{'bbox': (165, 96, 300, 114)},
{'bbox': (11, 157, 139, 201)},
{'bbox': (124, 208, 265, 262)},
{'bbox': (124, 283, 248, 306)},
{'bbox': (39, 267, 84, 301)},
{'bbox': (36, 89, 114, 145)}, ], (165, 96, 300, 114), {'bbox': (124, 208, 265, 262)}),
([{'bbox': (187, 37, 303, 49)},
{'bbox': (2, 227, 90, 283)},
{'bbox': (158, 174, 200, 212)},
{'bbox': (259, 174, 324, 228)},
{'bbox': (205, 61, 316, 97)},
{'bbox': (295, 248, 374, 287)}, ], (205, 61, 316, 97), {'bbox': (259, 174, 324, 228)}), # y有两个最近的, x只有一个重合
# ([{"bbox": (187, 37, 303, 49)},
# {"bbox": (2, 227, 90, 283)},
# {"bbox": (259, 174, 324, 228)},
......@@ -410,31 +422,31 @@ def test_find_top_nearest_text_bbox(pymu_blocks: list, obj_box: tuple, target_bo
# {"bbox": (295, 248, 374, 287)},
# {"bbox": (158, 174, 209, 212)}, ], (205, 61, 316, 97),
# [{"bbox": (259, 174, 324, 228)}, {"bbox": (158, 174, 209, 212)}]), # x有重合,y有两个最近的 Error
([{"bbox": (287, 132, 398, 191)},
{"bbox": (44, 141, 163, 188)},
{"bbox": (132, 191, 240, 241)},
{"bbox": (81, 25, 142, 67)},
{"bbox": (74, 297, 116, 314)},
{"bbox": (77, 84, 224, 107)}, ], (287, 132, 398, 191), None), # x没有重合
([{"bbox": (80, 90, 249, 200)},
{"bbox": (183, 100, 240, 155)}, ], (183, 100, 240, 155), None), # 包含
([{'bbox': (287, 132, 398, 191)},
{'bbox': (44, 141, 163, 188)},
{'bbox': (132, 191, 240, 241)},
{'bbox': (81, 25, 142, 67)},
{'bbox': (74, 297, 116, 314)},
{'bbox': (77, 84, 224, 107)}, ], (287, 132, 398, 191), None), # x没有重合
([{'bbox': (80, 90, 249, 200)},
{'bbox': (183, 100, 240, 155)}, ], (183, 100, 240, 155), None), # 包含
])
def test_find_bottom_nearest_text_bbox(pymu_blocks: list, obj_box: tuple, target_boxs: dict) -> None:
assert target_boxs == find_bottom_nearest_text_bbox(pymu_blocks, obj_box)
# 寻找左侧距离自己最近的box, y方向有重叠,x方向最近
@pytest.mark.parametrize("pymu_blocks, obj_box, target_boxs", [
([{"bbox": (80, 90, 249, 200)}, {"bbox": (183, 100, 240, 155)}], (183, 100, 240, 155), None), # 包含
([{"bbox": (28, 90, 77, 126)}, {"bbox": (35, 84, 84, 120)}], (35, 84, 84, 120), None), # y:重叠,x:重叠大于2
([{"bbox": (28, 90, 77, 126)}, {"bbox": (75, 84, 134, 120)}], (75, 84, 134, 120), {"bbox": (28, 90, 77, 126)}),
@pytest.mark.parametrize('pymu_blocks, obj_box, target_boxs', [
([{'bbox': (80, 90, 249, 200)}, {'bbox': (183, 100, 240, 155)}], (183, 100, 240, 155), None), # 包含
([{'bbox': (28, 90, 77, 126)}, {'bbox': (35, 84, 84, 120)}], (35, 84, 84, 120), None), # y:重叠,x:重叠大于2
([{'bbox': (28, 90, 77, 126)}, {'bbox': (75, 84, 134, 120)}], (75, 84, 134, 120), {'bbox': (28, 90, 77, 126)}),
# y:重叠,x:重叠小于等于2
([{"bbox": (239, 115, 379, 167)},
{"bbox": (33, 237, 104, 262)},
{"bbox": (124, 288, 168, 325)},
{"bbox": (242, 291, 379, 340)},
{"bbox": (55, 113, 161, 154)},
{"bbox": (266, 123, 384, 217)}], (266, 123, 384, 217), {"bbox": (55, 113, 161, 154)}), # y重叠,x left
([{'bbox': (239, 115, 379, 167)},
{'bbox': (33, 237, 104, 262)},
{'bbox': (124, 288, 168, 325)},
{'bbox': (242, 291, 379, 340)},
{'bbox': (55, 113, 161, 154)},
{'bbox': (266, 123, 384, 217)}], (266, 123, 384, 217), {'bbox': (55, 113, 161, 154)}), # y重叠,x left
# ([{"bbox": (136, 219, 268, 240)},
# {"bbox": (169, 115, 268, 181)},
# {"bbox": (33, 237, 104, 262)},
......@@ -448,17 +460,17 @@ def test_find_left_nearest_text_bbox(pymu_blocks: list, obj_box: tuple, target_b
# 寻找右侧距离自己最近的box, y方向有重叠,x方向最近
@pytest.mark.parametrize("pymu_blocks, obj_box, target_boxs", [
([{"bbox": (80, 90, 249, 200)}, {"bbox": (183, 100, 240, 155)}], (183, 100, 240, 155), None), # 包含
([{"bbox": (28, 90, 77, 126)}, {"bbox": (35, 84, 84, 120)}], (28, 90, 77, 126), None), # y:重叠,x:重叠大于2
([{"bbox": (28, 90, 77, 126)}, {"bbox": (75, 84, 134, 120)}], (28, 90, 77, 126), {"bbox": (75, 84, 134, 120)}),
@pytest.mark.parametrize('pymu_blocks, obj_box, target_boxs', [
([{'bbox': (80, 90, 249, 200)}, {'bbox': (183, 100, 240, 155)}], (183, 100, 240, 155), None), # 包含
([{'bbox': (28, 90, 77, 126)}, {'bbox': (35, 84, 84, 120)}], (28, 90, 77, 126), None), # y:重叠,x:重叠大于2
([{'bbox': (28, 90, 77, 126)}, {'bbox': (75, 84, 134, 120)}], (28, 90, 77, 126), {'bbox': (75, 84, 134, 120)}),
# y:重叠,x:重叠小于等于2
([{"bbox": (239, 115, 379, 167)},
{"bbox": (33, 237, 104, 262)},
{"bbox": (124, 288, 168, 325)},
{"bbox": (242, 291, 379, 340)},
{"bbox": (55, 113, 161, 154)},
{"bbox": (266, 123, 384, 217)}], (55, 113, 161, 154), {"bbox": (239, 115, 379, 167)}), # y重叠,x right
([{'bbox': (239, 115, 379, 167)},
{'bbox': (33, 237, 104, 262)},
{'bbox': (124, 288, 168, 325)},
{'bbox': (242, 291, 379, 340)},
{'bbox': (55, 113, 161, 154)},
{'bbox': (266, 123, 384, 217)}], (55, 113, 161, 154), {'bbox': (239, 115, 379, 167)}), # y重叠,x right
# ([{"bbox": (169, 115, 298, 181)},
# {"bbox": (169, 219, 268, 240)},
# {"bbox": (33, 177, 104, 262)},
......@@ -472,7 +484,7 @@ def test_find_right_nearest_text_bbox(pymu_blocks: list, obj_box: tuple, target_
# 判断两个矩形框的相对位置关系 (left, right, bottom, top)
@pytest.mark.parametrize("box1, box2, target_box", [
@pytest.mark.parametrize('box1, box2, target_box', [
# (None, None, "Error"), # Error
((80, 90, 249, 200), (183, 100, 240, 155), (False, False, False, False)), # 包含
# ((124, 81, 222, 173), (60, 221, 123, 358), (False, True, False, True)), # 分离,右上 Error
......@@ -494,7 +506,7 @@ def test_bbox_relative_pos(box1: tuple, box2: tuple, target_box: tuple) -> None:
"""
@pytest.mark.parametrize("box1, box2, target_num", [
@pytest.mark.parametrize('box1, box2, target_num', [
# (None, None, "Error"), # Error
((80, 90, 249, 200), (183, 100, 240, 155), 0.0), # 包含
((142, 109, 238, 164), (134, 211, 224, 270), 47.0), # 分离,上
......@@ -514,7 +526,8 @@ def test_bbox_relative_pos(box1: tuple, box2: tuple, target_box: tuple) -> None:
def test_bbox_distance(box1: tuple, box2: tuple, target_num: float) -> None:
assert target_num - bbox_distance(box1, box2) < 1
@pytest.mark.skip(reason="skip")
@pytest.mark.skip(reason='skip')
# 根据bucket_name获取s3配置ak,sk,endpoint
def test_get_s3_config() -> None:
bucket_name = os.getenv('bucket_name')
......@@ -522,7 +535,6 @@ def test_get_s3_config() -> None:
assert convert_string_to_list(target_data) == list(get_s3_config(bucket_name))
def convert_string_to_list(s):
cleaned_s = s.strip("'")
items = cleaned_s.split(',')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment