Unverified commit 03469909, authored by icecraft and committed by GitHub

Feat/support footnote in figure (#532)

* feat: support figure footnote

* feat: using the relative position to combine footnote, table, image

* feat: add the readme of projects

* fix: code spell in unittest

---------
Co-authored-by: icecraft <xurui1@pjlab.org.cn>
parent 4331b837
...@@ -3,6 +3,7 @@ repos: ...@@ -3,6 +3,7 @@ repos:
rev: 5.0.4 rev: 5.0.4
hooks: hooks:
- id: flake8 - id: flake8
args: ["--max-line-length=120", "--ignore=E131,E125,W503,W504,E203"]
- repo: https://github.com/PyCQA/isort - repo: https://github.com/PyCQA/isort
rev: 5.11.5 rev: 5.11.5
hooks: hooks:
...@@ -11,6 +12,7 @@ repos: ...@@ -11,6 +12,7 @@ repos:
rev: v0.32.0 rev: v0.32.0
hooks: hooks:
- id: yapf - id: yapf
args: ["--style={based_on_style: google, column_limit: 120, indent_width: 4}"]
- repo: https://github.com/codespell-project/codespell - repo: https://github.com/codespell-project/codespell
rev: v2.2.1 rev: v2.2.1
hooks: hooks:
...@@ -41,4 +43,4 @@ repos: ...@@ -41,4 +43,4 @@ repos:
rev: v1.3.1 rev: v1.3.1
hooks: hooks:
- id: docformatter - id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"] args: ["--in-place", "--wrap-descriptions", "119"]
import re
import wordninja
from loguru import logger from loguru import logger
from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
from magic_pdf.libs.commons import join_path from magic_pdf.libs.commons import join_path
from magic_pdf.libs.language import detect_lang from magic_pdf.libs.language import detect_lang
from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
from magic_pdf.libs.ocr_content_type import ContentType, BlockType from magic_pdf.libs.ocr_content_type import BlockType, ContentType
import wordninja
import re
def __is_hyphen_at_line_end(line): def __is_hyphen_at_line_end(line):
...@@ -37,8 +38,9 @@ def split_long_words(text): ...@@ -37,8 +38,9 @@ def split_long_words(text):
def ocr_mk_mm_markdown_with_para(pdf_info_list: list, img_buket_path): def ocr_mk_mm_markdown_with_para(pdf_info_list: list, img_buket_path):
markdown = [] markdown = []
for page_info in pdf_info_list: for page_info in pdf_info_list:
paras_of_layout = page_info.get("para_blocks") paras_of_layout = page_info.get('para_blocks')
page_markdown = ocr_mk_markdown_with_para_core_v2(paras_of_layout, "mm", img_buket_path) page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'mm', img_buket_path)
markdown.extend(page_markdown) markdown.extend(page_markdown)
return '\n\n'.join(markdown) return '\n\n'.join(markdown)
...@@ -46,29 +48,34 @@ def ocr_mk_mm_markdown_with_para(pdf_info_list: list, img_buket_path): ...@@ -46,29 +48,34 @@ def ocr_mk_mm_markdown_with_para(pdf_info_list: list, img_buket_path):
def ocr_mk_nlp_markdown_with_para(pdf_info_dict: list): def ocr_mk_nlp_markdown_with_para(pdf_info_dict: list):
markdown = [] markdown = []
for page_info in pdf_info_dict: for page_info in pdf_info_dict:
paras_of_layout = page_info.get("para_blocks") paras_of_layout = page_info.get('para_blocks')
page_markdown = ocr_mk_markdown_with_para_core_v2(paras_of_layout, "nlp") page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'nlp')
markdown.extend(page_markdown) markdown.extend(page_markdown)
return '\n\n'.join(markdown) return '\n\n'.join(markdown)
def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list, img_buket_path): def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list,
img_buket_path):
markdown_with_para_and_pagination = [] markdown_with_para_and_pagination = []
page_no = 0 page_no = 0
for page_info in pdf_info_dict: for page_info in pdf_info_dict:
paras_of_layout = page_info.get("para_blocks") paras_of_layout = page_info.get('para_blocks')
if not paras_of_layout: if not paras_of_layout:
continue continue
page_markdown = ocr_mk_markdown_with_para_core_v2(paras_of_layout, "mm", img_buket_path) page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'mm', img_buket_path)
markdown_with_para_and_pagination.append({ markdown_with_para_and_pagination.append({
'page_no': page_no, 'page_no':
'md_content': '\n\n'.join(page_markdown) page_no,
'md_content':
'\n\n'.join(page_markdown)
}) })
page_no += 1 page_no += 1
return markdown_with_para_and_pagination return markdown_with_para_and_pagination
def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=""): def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=''):
page_markdown = [] page_markdown = []
for paras in paras_of_layout: for paras in paras_of_layout:
for para in paras: for para in paras:
...@@ -81,8 +88,9 @@ def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=""): ...@@ -81,8 +88,9 @@ def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=""):
if span_type == ContentType.Text: if span_type == ContentType.Text:
content = span['content'] content = span['content']
language = detect_lang(content) language = detect_lang(content)
if language == 'en': # 只对英文长词进行分词处理,中文分词会丢失文本 if (language == 'en'): # 只对英文长词进行分词处理,中文分词会丢失文本
content = ocr_escape_special_markdown_char(split_long_words(content)) content = ocr_escape_special_markdown_char(
split_long_words(content))
else: else:
content = ocr_escape_special_markdown_char(content) content = ocr_escape_special_markdown_char(content)
elif span_type == ContentType.InlineEquation: elif span_type == ContentType.InlineEquation:
...@@ -106,7 +114,9 @@ def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=""): ...@@ -106,7 +114,9 @@ def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=""):
return page_markdown return page_markdown
def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""): def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
mode,
img_buket_path=''):
page_markdown = [] page_markdown = []
for para_block in paras_of_layout: for para_block in paras_of_layout:
para_text = '' para_text = ''
...@@ -114,7 +124,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""): ...@@ -114,7 +124,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
if para_type == BlockType.Text: if para_type == BlockType.Text:
para_text = merge_para_with_text(para_block) para_text = merge_para_with_text(para_block)
elif para_type == BlockType.Title: elif para_type == BlockType.Title:
para_text = f"# {merge_para_with_text(para_block)}" para_text = f'# {merge_para_with_text(para_block)}'
elif para_type == BlockType.InterlineEquation: elif para_type == BlockType.InterlineEquation:
para_text = merge_para_with_text(para_block) para_text = merge_para_with_text(para_block)
elif para_type == BlockType.Image: elif para_type == BlockType.Image:
...@@ -130,11 +140,13 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""): ...@@ -130,11 +140,13 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
for block in para_block['blocks']: # 2nd.拼image_caption for block in para_block['blocks']: # 2nd.拼image_caption
if block['type'] == BlockType.ImageCaption: if block['type'] == BlockType.ImageCaption:
para_text += merge_para_with_text(block) para_text += merge_para_with_text(block)
for block in para_block['blocks']: # 2nd.拼image_caption
if block['type'] == BlockType.ImageFootnote:
para_text += merge_para_with_text(block)
elif para_type == BlockType.Table: elif para_type == BlockType.Table:
if mode == 'nlp': if mode == 'nlp':
continue continue
elif mode == 'mm': elif mode == 'mm':
table_caption = ''
for block in para_block['blocks']: # 1st.拼table_caption for block in para_block['blocks']: # 1st.拼table_caption
if block['type'] == BlockType.TableCaption: if block['type'] == BlockType.TableCaption:
para_text += merge_para_with_text(block) para_text += merge_para_with_text(block)
...@@ -163,6 +175,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""): ...@@ -163,6 +175,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
def merge_para_with_text(para_block): def merge_para_with_text(para_block):
def detect_language(text): def detect_language(text):
en_pattern = r'[a-zA-Z]+' en_pattern = r'[a-zA-Z]+'
en_matches = re.findall(en_pattern, text) en_matches = re.findall(en_pattern, text)
...@@ -171,19 +184,19 @@ def merge_para_with_text(para_block): ...@@ -171,19 +184,19 @@ def merge_para_with_text(para_block):
if en_length / len(text) >= 0.5: if en_length / len(text) >= 0.5:
return 'en' return 'en'
else: else:
return "unknown" return 'unknown'
else: else:
return "empty" return 'empty'
para_text = '' para_text = ''
for line in para_block['lines']: for line in para_block['lines']:
line_text = "" line_text = ''
line_lang = "" line_lang = ''
for span in line['spans']: for span in line['spans']:
span_type = span['type'] span_type = span['type']
if span_type == ContentType.Text: if span_type == ContentType.Text:
line_text += span['content'].strip() line_text += span['content'].strip()
if line_text != "": if line_text != '':
line_lang = detect_lang(line_text) line_lang = detect_lang(line_text)
for span in line['spans']: for span in line['spans']:
span_type = span['type'] span_type = span['type']
...@@ -193,7 +206,8 @@ def merge_para_with_text(para_block): ...@@ -193,7 +206,8 @@ def merge_para_with_text(para_block):
# language = detect_lang(content) # language = detect_lang(content)
language = detect_language(content) language = detect_language(content)
if language == 'en': # 只对英文长词进行分词处理,中文分词会丢失文本 if language == 'en': # 只对英文长词进行分词处理,中文分词会丢失文本
content = ocr_escape_special_markdown_char(split_long_words(content)) content = ocr_escape_special_markdown_char(
split_long_words(content))
else: else:
content = ocr_escape_special_markdown_char(content) content = ocr_escape_special_markdown_char(content)
elif span_type == ContentType.InlineEquation: elif span_type == ContentType.InlineEquation:
...@@ -227,12 +241,13 @@ def para_to_standard_format(para, img_buket_path): ...@@ -227,12 +241,13 @@ def para_to_standard_format(para, img_buket_path):
for span in line['spans']: for span in line['spans']:
language = '' language = ''
span_type = span.get('type') span_type = span.get('type')
content = "" content = ''
if span_type == ContentType.Text: if span_type == ContentType.Text:
content = span['content'] content = span['content']
language = detect_lang(content) language = detect_lang(content)
if language == 'en': # 只对英文长词进行分词处理,中文分词会丢失文本 if language == 'en': # 只对英文长词进行分词处理,中文分词会丢失文本
content = ocr_escape_special_markdown_char(split_long_words(content)) content = ocr_escape_special_markdown_char(
split_long_words(content))
else: else:
content = ocr_escape_special_markdown_char(content) content = ocr_escape_special_markdown_char(content)
elif span_type == ContentType.InlineEquation: elif span_type == ContentType.InlineEquation:
...@@ -245,7 +260,7 @@ def para_to_standard_format(para, img_buket_path): ...@@ -245,7 +260,7 @@ def para_to_standard_format(para, img_buket_path):
para_content = { para_content = {
'type': 'text', 'type': 'text',
'text': para_text, 'text': para_text,
'inline_equation_num': inline_equation_num 'inline_equation_num': inline_equation_num,
} }
return para_content return para_content
...@@ -256,37 +271,35 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx): ...@@ -256,37 +271,35 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
para_content = { para_content = {
'type': 'text', 'type': 'text',
'text': merge_para_with_text(para_block), 'text': merge_para_with_text(para_block),
'page_idx': page_idx 'page_idx': page_idx,
} }
elif para_type == BlockType.Title: elif para_type == BlockType.Title:
para_content = { para_content = {
'type': 'text', 'type': 'text',
'text': merge_para_with_text(para_block), 'text': merge_para_with_text(para_block),
'text_level': 1, 'text_level': 1,
'page_idx': page_idx 'page_idx': page_idx,
} }
elif para_type == BlockType.InterlineEquation: elif para_type == BlockType.InterlineEquation:
para_content = { para_content = {
'type': 'equation', 'type': 'equation',
'text': merge_para_with_text(para_block), 'text': merge_para_with_text(para_block),
'text_format': "latex", 'text_format': 'latex',
'page_idx': page_idx 'page_idx': page_idx,
} }
elif para_type == BlockType.Image: elif para_type == BlockType.Image:
para_content = { para_content = {'type': 'image', 'page_idx': page_idx}
'type': 'image',
'page_idx': page_idx
}
for block in para_block['blocks']: for block in para_block['blocks']:
if block['type'] == BlockType.ImageBody: if block['type'] == BlockType.ImageBody:
para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path']) para_content['img_path'] = join_path(
img_buket_path,
block['lines'][0]['spans'][0]['image_path'])
if block['type'] == BlockType.ImageCaption: if block['type'] == BlockType.ImageCaption:
para_content['img_caption'] = merge_para_with_text(block) para_content['img_caption'] = merge_para_with_text(block)
if block['type'] == BlockType.ImageFootnote:
para_content['img_footnote'] = merge_para_with_text(block)
elif para_type == BlockType.Table: elif para_type == BlockType.Table:
para_content = { para_content = {'type': 'table', 'page_idx': page_idx}
'type': 'table',
'page_idx': page_idx
}
for block in para_block['blocks']: for block in para_block['blocks']:
if block['type'] == BlockType.TableBody: if block['type'] == BlockType.TableBody:
if block["lines"][0]["spans"][0].get('latex', ''): if block["lines"][0]["spans"][0].get('latex', ''):
...@@ -305,17 +318,18 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx): ...@@ -305,17 +318,18 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
def make_standard_format_with_para(pdf_info_dict: list, img_buket_path: str): def make_standard_format_with_para(pdf_info_dict: list, img_buket_path: str):
content_list = [] content_list = []
for page_info in pdf_info_dict: for page_info in pdf_info_dict:
paras_of_layout = page_info.get("para_blocks") paras_of_layout = page_info.get('para_blocks')
if not paras_of_layout: if not paras_of_layout:
continue continue
for para_block in paras_of_layout: for para_block in paras_of_layout:
para_content = para_to_standard_format_v2(para_block, img_buket_path) para_content = para_to_standard_format_v2(para_block,
img_buket_path)
content_list.append(para_content) content_list.append(para_content)
return content_list return content_list
def line_to_standard_format(line, img_buket_path): def line_to_standard_format(line, img_buket_path):
line_text = "" line_text = ''
inline_equation_num = 0 inline_equation_num = 0
for span in line['spans']: for span in line['spans']:
if not span.get('content'): if not span.get('content'):
...@@ -325,13 +339,15 @@ def line_to_standard_format(line, img_buket_path): ...@@ -325,13 +339,15 @@ def line_to_standard_format(line, img_buket_path):
if span['type'] == ContentType.Image: if span['type'] == ContentType.Image:
content = { content = {
'type': 'image', 'type': 'image',
'img_path': join_path(img_buket_path, span['image_path']) 'img_path': join_path(img_buket_path,
span['image_path']),
} }
return content return content
elif span['type'] == ContentType.Table: elif span['type'] == ContentType.Table:
content = { content = {
'type': 'table', 'type': 'table',
'img_path': join_path(img_buket_path, span['image_path']) 'img_path': join_path(img_buket_path,
span['image_path']),
} }
return content return content
else: else:
...@@ -339,36 +355,33 @@ def line_to_standard_format(line, img_buket_path): ...@@ -339,36 +355,33 @@ def line_to_standard_format(line, img_buket_path):
interline_equation = span['content'] interline_equation = span['content']
content = { content = {
'type': 'equation', 'type': 'equation',
'latex': f"$$\n{interline_equation}\n$$" 'latex': f'$$\n{interline_equation}\n$$'
} }
return content return content
elif span['type'] == ContentType.InlineEquation: elif span['type'] == ContentType.InlineEquation:
inline_equation = span['content'] inline_equation = span['content']
line_text += f"${inline_equation}$" line_text += f'${inline_equation}$'
inline_equation_num += 1 inline_equation_num += 1
elif span['type'] == ContentType.Text: elif span['type'] == ContentType.Text:
text_content = ocr_escape_special_markdown_char(span['content']) # 转义特殊符号 text_content = ocr_escape_special_markdown_char(
span['content']) # 转义特殊符号
line_text += text_content line_text += text_content
content = { content = {
'type': 'text', 'type': 'text',
'text': line_text, 'text': line_text,
'inline_equation_num': inline_equation_num 'inline_equation_num': inline_equation_num,
} }
return content return content
def ocr_mk_mm_standard_format(pdf_info_dict: list): def ocr_mk_mm_standard_format(pdf_info_dict: list):
""" """content_list type string
content_list image/text/table/equation(行间的单独拿出来,行内的和text合并) latex string
type string image/text/table/equation(行间的单独拿出来,行内的和text合并) latex文本字段。 text string 纯文本格式的文本数据。 md string
latex string latex文本字段。 markdown格式的文本数据。 img_path string s3://full/path/to/img.jpg."""
text string 纯文本格式的文本数据。
md string markdown格式的文本数据。
img_path string s3://full/path/to/img.jpg
"""
content_list = [] content_list = []
for page_info in pdf_info_dict: for page_info in pdf_info_dict:
blocks = page_info.get("preproc_blocks") blocks = page_info.get('preproc_blocks')
if not blocks: if not blocks:
continue continue
for block in blocks: for block in blocks:
...@@ -378,34 +391,42 @@ def ocr_mk_mm_standard_format(pdf_info_dict: list): ...@@ -378,34 +391,42 @@ def ocr_mk_mm_standard_format(pdf_info_dict: list):
return content_list return content_list
def union_make(pdf_info_dict: list, make_mode: str, drop_mode: str, img_buket_path: str = ""): def union_make(pdf_info_dict: list,
make_mode: str,
drop_mode: str,
img_buket_path: str = ''):
output_content = [] output_content = []
for page_info in pdf_info_dict: for page_info in pdf_info_dict:
if page_info.get("need_drop", False): if page_info.get('need_drop', False):
drop_reason = page_info.get("drop_reason") drop_reason = page_info.get('drop_reason')
if drop_mode == DropMode.NONE: if drop_mode == DropMode.NONE:
pass pass
elif drop_mode == DropMode.WHOLE_PDF: elif drop_mode == DropMode.WHOLE_PDF:
raise Exception(f"drop_mode is {DropMode.WHOLE_PDF} , drop_reason is {drop_reason}") raise Exception((f'drop_mode is {DropMode.WHOLE_PDF} ,'
f'drop_reason is {drop_reason}'))
elif drop_mode == DropMode.SINGLE_PAGE: elif drop_mode == DropMode.SINGLE_PAGE:
logger.warning(f"drop_mode is {DropMode.SINGLE_PAGE} , drop_reason is {drop_reason}") logger.warning((f'drop_mode is {DropMode.SINGLE_PAGE} ,'
f'drop_reason is {drop_reason}'))
continue continue
else: else:
raise Exception(f"drop_mode can not be null") raise Exception('drop_mode can not be null')
paras_of_layout = page_info.get("para_blocks") paras_of_layout = page_info.get('para_blocks')
page_idx = page_info.get("page_idx") page_idx = page_info.get('page_idx')
if not paras_of_layout: if not paras_of_layout:
continue continue
if make_mode == MakeMode.MM_MD: if make_mode == MakeMode.MM_MD:
page_markdown = ocr_mk_markdown_with_para_core_v2(paras_of_layout, "mm", img_buket_path) page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'mm', img_buket_path)
output_content.extend(page_markdown) output_content.extend(page_markdown)
elif make_mode == MakeMode.NLP_MD: elif make_mode == MakeMode.NLP_MD:
page_markdown = ocr_mk_markdown_with_para_core_v2(paras_of_layout, "nlp") page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'nlp')
output_content.extend(page_markdown) output_content.extend(page_markdown)
elif make_mode == MakeMode.STANDARD_FORMAT: elif make_mode == MakeMode.STANDARD_FORMAT:
for para_block in paras_of_layout: for para_block in paras_of_layout:
para_content = para_to_standard_format_v2(para_block, img_buket_path, page_idx) para_content = para_to_standard_format_v2(
para_block, img_buket_path, page_idx)
output_content.append(para_content) output_content.append(para_content)
if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]: if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
return '\n\n'.join(output_content) return '\n\n'.join(output_content)
......
""" """对pdf上的box进行layout识别,并对内部组成的box进行排序."""
对pdf上的box进行layout识别,并对内部组成的box进行排序
"""
from loguru import logger from loguru import logger
from magic_pdf.layout.bbox_sort import CONTENT_IDX, CONTENT_TYPE_IDX, X0_EXT_IDX, X0_IDX, X1_EXT_IDX, X1_IDX, Y0_EXT_IDX, Y0_IDX, Y1_EXT_IDX, Y1_IDX, paper_bbox_sort
from magic_pdf.layout.layout_det_utils import find_all_left_bbox_direct, find_all_right_bbox_direct, find_bottom_bbox_direct_from_left_edge, find_bottom_bbox_direct_from_right_edge, find_top_bbox_direct_from_left_edge, find_top_bbox_direct_from_right_edge, find_all_top_bbox_direct, find_all_bottom_bbox_direct, get_left_edge_bboxes, get_right_edge_bboxes
from magic_pdf.libs.boxbase import get_bbox_in_boundry
from magic_pdf.layout.bbox_sort import (CONTENT_IDX, CONTENT_TYPE_IDX,
X0_EXT_IDX, X0_IDX, X1_EXT_IDX, X1_IDX,
Y0_EXT_IDX, Y0_IDX, Y1_EXT_IDX, Y1_IDX,
paper_bbox_sort)
from magic_pdf.layout.layout_det_utils import (
find_all_bottom_bbox_direct, find_all_left_bbox_direct,
find_all_right_bbox_direct, find_all_top_bbox_direct,
find_bottom_bbox_direct_from_left_edge,
find_bottom_bbox_direct_from_right_edge,
find_top_bbox_direct_from_left_edge, find_top_bbox_direct_from_right_edge,
get_left_edge_bboxes, get_right_edge_bboxes)
from magic_pdf.libs.boxbase import get_bbox_in_boundary
LAYOUT_V = 'V'
LAYOUT_H = 'H'
LAYOUT_UNPROC = 'U'
LAYOUT_BAD = 'B'
LAYOUT_V = "V"
LAYOUT_H = "H"
LAYOUT_UNPROC = "U"
LAYOUT_BAD = "B"
def _is_single_line_text(bbox): def _is_single_line_text(bbox):
""" """检查bbox里面的文字是否只有一行."""
检查bbox里面的文字是否只有一行 return True # TODO
"""
return True # TODO
box_type = bbox[CONTENT_TYPE_IDX] box_type = bbox[CONTENT_TYPE_IDX]
if box_type != 'text': if box_type != 'text':
return False return False
paras = bbox[CONTENT_IDX]["paras"] paras = bbox[CONTENT_IDX]['paras']
text_content = "" text_content = ''
for para_id, para in paras.items(): # 拼装内部的段落文本 for para_id, para in paras.items(): # 拼装内部的段落文本
is_title = para['is_title'] is_title = para['is_title']
if is_title!=0: if is_title != 0:
text_content += f"## {para['text']}" text_content += f"## {para['text']}"
else: else:
text_content += para["text"] text_content += para['text']
text_content += "\n\n" text_content += '\n\n'
return bbox[CONTENT_TYPE_IDX] == 'text' and len(text_content.split("\n\n")) <= 1
return bbox[CONTENT_TYPE_IDX] == 'text' and len(text_content.split('\n\n')) <= 1
def _horizontal_split(bboxes:list, boundry:tuple, avg_font_size=20)-> list:
def _horizontal_split(bboxes: list, boundary: tuple, avg_font_size=20) -> list:
""" """
对bboxes进行水平切割 对bboxes进行水平切割
方法是:找到左侧和右侧都没有被直接遮挡的box,然后进行扩展,之后进行切割 方法是:找到左侧和右侧都没有被直接遮挡的box,然后进行扩展,之后进行切割
return: return:
返回几个大的Layout区域 [[x0, y0, x1, y1, "h|u|v"], ], h代表水平,u代表未探测的,v代表垂直布局 返回几个大的Layout区域 [[x0, y0, x1, y1, "h|u|v"], ], h代表水平,u代表未探测的,v代表垂直布局
""" """
sorted_layout_blocks = [] # 这是要最终返回的值 sorted_layout_blocks = [] # 这是要最终返回的值
bound_x0, bound_y0, bound_x1, bound_y1 = boundry bound_x0, bound_y0, bound_x1, bound_y1 = boundary
all_bboxes = get_bbox_in_boundry(bboxes, boundry) all_bboxes = get_bbox_in_boundary(bboxes, boundary)
#all_bboxes = paper_bbox_sort(all_bboxes, abs(bound_x1-bound_x0), abs(bound_y1-bound_x0)) # 大致拍下序, 这个是基于直接遮挡的。 # all_bboxes = paper_bbox_sort(all_bboxes, abs(bound_x1-bound_x0), abs(bound_y1-bound_x0)) # 大致拍下序, 这个是基于直接遮挡的。
""" """
首先在水平方向上扩展独占一行的bbox 首先在水平方向上扩展独占一行的bbox
""" """
last_h_split_line_y1 = bound_y0 #记录下上次的水平分割线 last_h_split_line_y1 = bound_y0 # 记录下上次的水平分割线
for i, bbox in enumerate(all_bboxes): for i, bbox in enumerate(all_bboxes):
left_nearest_bbox = find_all_left_bbox_direct(bbox, all_bboxes) # 非扩展线 left_nearest_bbox = find_all_left_bbox_direct(bbox, all_bboxes) # 非扩展线
right_nearest_bbox = find_all_right_bbox_direct(bbox, all_bboxes) right_nearest_bbox = find_all_right_bbox_direct(bbox, all_bboxes)
if left_nearest_bbox is None and right_nearest_bbox is None: # 独占一行 if left_nearest_bbox is None and right_nearest_bbox is None: # 独占一行
""" """
然而,如果只是孤立的一行文字,那么就还要满足以下几个条件才可以: 然而,如果只是孤立的一行文字,那么就还要满足以下几个条件才可以:
1. bbox和中心线相交。或者 1. bbox和中心线相交。或者
...@@ -62,16 +68,20 @@ def _horizontal_split(bboxes:list, boundry:tuple, avg_font_size=20)-> list: ...@@ -62,16 +68,20 @@ def _horizontal_split(bboxes:list, boundry:tuple, avg_font_size=20)-> list:
3. TODO 加强条件:这个bbox上方和下方是同一列column,那么就不能算作独占一行 3. TODO 加强条件:这个bbox上方和下方是同一列column,那么就不能算作独占一行
""" """
# 先检查这个bbox里是否只包含一行文字 # 先检查这个bbox里是否只包含一行文字
is_single_line = _is_single_line_text(bbox) # is_single_line = _is_single_line_text(bbox)
""" """
这里有个点需要注意,当页面内容不是居中的时候,第一次调用传递的是page的boundry,这个时候mid_x就不是中心线了. 这里有个点需要注意,当页面内容不是居中的时候,第一次调用传递的是page的boundary,这个时候mid_x就不是中心线了.
所以这里计算出最紧致的boundry,然后再计算mid_x 所以这里计算出最紧致的boundary,然后再计算mid_x
""" """
boundry_real_x0, boundry_real_x1 = min([bbox[X0_IDX] for bbox in all_bboxes]), max([bbox[X1_IDX] for bbox in all_bboxes]) boundary_real_x0, boundary_real_x1 = min(
mid_x = (boundry_real_x0+boundry_real_x1)/2 [bbox[X0_IDX] for bbox in all_bboxes]
), max([bbox[X1_IDX] for bbox in all_bboxes])
mid_x = (boundary_real_x0 + boundary_real_x1) / 2
# 检查这个box是否内容在中心线有交 # 检查这个box是否内容在中心线有交
# 必须跨过去2个字符的宽度 # 必须跨过去2个字符的宽度
is_cross_boundry_mid_line = min(mid_x-bbox[X0_IDX], bbox[X1_IDX]-mid_x) > avg_font_size*2 is_cross_boundary_mid_line = (
min(mid_x - bbox[X0_IDX], bbox[X1_IDX] - mid_x) > avg_font_size * 2
)
""" """
检查条件2 检查条件2
""" """
...@@ -84,50 +94,78 @@ def _horizontal_split(bboxes:list, boundry:tuple, avg_font_size=20)-> list: ...@@ -84,50 +94,78 @@ def _horizontal_split(bboxes:list, boundry:tuple, avg_font_size=20)-> list:
""" """
以迭代的方式向上找,查找范围是[bound_x0, last_h_sp, bound_x1, bbox[Y0_IDX]] 以迭代的方式向上找,查找范围是[bound_x0, last_h_sp, bound_x1, bbox[Y0_IDX]]
""" """
#先确定上方的y0, y0 # 先确定上方的y0, y0
b_y0, b_y1 = last_h_split_line_y1, bbox[Y0_IDX] b_y0, b_y1 = last_h_split_line_y1, bbox[Y0_IDX]
#然后从box开始逐个向上找到所有与box在x上有交集的box # 然后从box开始逐个向上找到所有与box在x上有交集的box
box_to_check = [bound_x0, b_y0, bound_x1, b_y1] box_to_check = [bound_x0, b_y0, bound_x1, b_y1]
bbox_in_bound_check = get_bbox_in_boundry(all_bboxes, box_to_check) bbox_in_bound_check = get_bbox_in_boundary(all_bboxes, box_to_check)
bboxes_on_top = [] bboxes_on_top = []
virtual_box = bbox virtual_box = bbox
while True: while True:
b_on_top = find_all_top_bbox_direct(virtual_box, bbox_in_bound_check) b_on_top = find_all_top_bbox_direct(virtual_box, bbox_in_bound_check)
if b_on_top is not None: if b_on_top is not None:
bboxes_on_top.append(b_on_top) bboxes_on_top.append(b_on_top)
virtual_box = [min([virtual_box[X0_IDX], b_on_top[X0_IDX]]), min(virtual_box[Y0_IDX], b_on_top[Y0_IDX]), max([virtual_box[X1_IDX], b_on_top[X1_IDX]]), b_y1] virtual_box = [
min([virtual_box[X0_IDX], b_on_top[X0_IDX]]),
min(virtual_box[Y0_IDX], b_on_top[Y0_IDX]),
max([virtual_box[X1_IDX], b_on_top[X1_IDX]]),
b_y1,
]
else: else:
break break
# 随后确定这些box的最小x0, 最大x1 # 随后确定这些box的最小x0, 最大x1
if len(bboxes_on_top)>0 and len(bboxes_on_top) != len(bbox_in_bound_check):# virtual_box可能会膨胀到占满整个区域,这实际上就不能属于一个col了。 if len(bboxes_on_top) > 0 and len(bboxes_on_top) != len(
bbox_in_bound_check
): # virtual_box可能会膨胀到占满整个区域,这实际上就不能属于一个col了。
min_x0, max_x1 = virtual_box[X0_IDX], virtual_box[X1_IDX] min_x0, max_x1 = virtual_box[X0_IDX], virtual_box[X1_IDX]
# 然后采用一种比较粗糙的方法,看min_x0,max_x1是否与位于[bound_x0, last_h_sp, bound_x1, bbox[Y0_IDX]]之间的box有相交 # 然后采用一种比较粗糙的方法,看min_x0,max_x1是否与位于[bound_x0, last_h_sp, bound_x1, bbox[Y0_IDX]]之间的box有相交
if not any([b[X0_IDX] <= min_x0-1 <= b[X1_IDX] or b[X0_IDX] <= max_x1+1 <= b[X1_IDX] for b in bbox_in_bound_check]): if not any(
[
b[X0_IDX] <= min_x0 - 1 <= b[X1_IDX]
or b[X0_IDX] <= max_x1 + 1 <= b[X1_IDX]
for b in bbox_in_bound_check
]
):
# 其上,下都不能被扩展成行,暂时只检查一下上方 TODO # 其上,下都不能被扩展成行,暂时只检查一下上方 TODO
top_nearest_bbox = find_all_top_bbox_direct(bbox, bboxes) top_nearest_bbox = find_all_top_bbox_direct(bbox, bboxes)
bottom_nearest_bbox = find_all_bottom_bbox_direct(bbox, bboxes) bottom_nearest_bbox = find_all_bottom_bbox_direct(bbox, bboxes)
if not any([ if not any(
top_nearest_bbox is not None and (find_all_left_bbox_direct(top_nearest_bbox, bboxes) is None and find_all_right_bbox_direct(top_nearest_bbox, bboxes) is None), [
bottom_nearest_bbox is not None and (find_all_left_bbox_direct(bottom_nearest_bbox, bboxes) is None and find_all_right_bbox_direct(bottom_nearest_bbox, bboxes) is None), top_nearest_bbox is not None
top_nearest_bbox is None or bottom_nearest_bbox is None and (
]): find_all_left_bbox_direct(top_nearest_bbox, bboxes)
is_belong_to_col = True is None
and find_all_right_bbox_direct(top_nearest_bbox, bboxes)
is None
),
bottom_nearest_bbox is not None
and (
find_all_left_bbox_direct(bottom_nearest_bbox, bboxes)
is None
and find_all_right_bbox_direct(
bottom_nearest_bbox, bboxes
)
is None
),
top_nearest_bbox is None or bottom_nearest_bbox is None,
]
):
is_belong_to_col = True
# 检查是否能被下方col吸收 TODO # 检查是否能被下方col吸收 TODO
""" """
这里为什么没有is_cross_boundry_mid_line的条件呢? 这里为什么没有is_cross_boundary_mid_line的条件呢?
确实有些杂志左右两栏宽度不是对称的。 确实有些杂志左右两栏宽度不是对称的。
""" """
if not is_belong_to_col or is_cross_boundry_mid_line: if not is_belong_to_col or is_cross_boundary_mid_line:
bbox[X0_EXT_IDX] = bound_x0 bbox[X0_EXT_IDX] = bound_x0
bbox[Y0_EXT_IDX] = bbox[Y0_IDX] bbox[Y0_EXT_IDX] = bbox[Y0_IDX]
bbox[X1_EXT_IDX] = bound_x1 bbox[X1_EXT_IDX] = bound_x1
bbox[Y1_EXT_IDX] = bbox[Y1_IDX] bbox[Y1_EXT_IDX] = bbox[Y1_IDX]
last_h_split_line_y1 = bbox[Y1_IDX] # 更新这条线 last_h_split_line_y1 = bbox[Y1_IDX] # 更新这条线
else: else:
continue continue
""" """
...@@ -142,13 +180,12 @@ def _horizontal_split(bboxes:list, boundry:tuple, avg_font_size=20)-> list: ...@@ -142,13 +180,12 @@ def _horizontal_split(bboxes:list, boundry:tuple, avg_font_size=20)-> list:
if bbox[X0_EXT_IDX] == bound_x0 and bbox[X1_EXT_IDX] == bound_x1: if bbox[X0_EXT_IDX] == bound_x0 and bbox[X1_EXT_IDX] == bound_x1:
h_bbox_group.append(bbox) h_bbox_group.append(bbox)
else: else:
if len(h_bbox_group)>0: if len(h_bbox_group) > 0:
h_bboxes.append(h_bbox_group) h_bboxes.append(h_bbox_group)
h_bbox_group = [] h_bbox_group = []
# 最后一个group # 最后一个group
if len(h_bbox_group)>0: if len(h_bbox_group) > 0:
h_bboxes.append(h_bbox_group) h_bboxes.append(h_bbox_group)
""" """
现在h_bboxes里面是所有的group了,每个group都是一个list 现在h_bboxes里面是所有的group了,每个group都是一个list
对h_bboxes里的每个group进行计算放回到sorted_layouts里 对h_bboxes里的每个group进行计算放回到sorted_layouts里
...@@ -157,52 +194,67 @@ def _horizontal_split(bboxes:list, boundry:tuple, avg_font_size=20)-> list: ...@@ -157,52 +194,67 @@ def _horizontal_split(bboxes:list, boundry:tuple, avg_font_size=20)-> list:
for gp in h_bboxes: for gp in h_bboxes:
gp.sort(key=lambda x: x[Y0_IDX]) gp.sort(key=lambda x: x[Y0_IDX])
# 然后计算这个group的layout_bbox,也就是最小的x0,y0, 最大的x1,y1 # 然后计算这个group的layout_bbox,也就是最小的x0,y0, 最大的x1,y1
x0, y0, x1, y1 = gp[0][X0_EXT_IDX], gp[0][Y0_EXT_IDX], gp[-1][X1_EXT_IDX], gp[-1][Y1_EXT_IDX] x0, y0, x1, y1 = (
h_layouts.append([x0, y0, x1, y1, LAYOUT_H]) # 水平的布局 gp[0][X0_EXT_IDX],
gp[0][Y0_EXT_IDX],
gp[-1][X1_EXT_IDX],
gp[-1][Y1_EXT_IDX],
)
h_layouts.append([x0, y0, x1, y1, LAYOUT_H]) # 水平的布局
""" """
接下来利用这些连续的水平bbox的layout_bbox的y0, y1,从水平上切分开其余的为几个部分 接下来利用这些连续的水平bbox的layout_bbox的y0, y1,从水平上切分开其余的为几个部分
""" """
h_split_lines = [bound_y0] h_split_lines = [bound_y0]
for gp in h_bboxes: # gp是一个list[bbox_list] for gp in h_bboxes: # gp是一个list[bbox_list]
y0, y1 = gp[0][1], gp[-1][3] y0, y1 = gp[0][1], gp[-1][3]
h_split_lines.append(y0) h_split_lines.append(y0)
h_split_lines.append(y1) h_split_lines.append(y1)
h_split_lines.append(bound_y1) h_split_lines.append(bound_y1)
unsplited_bboxes = [] unsplited_bboxes = []
for i in range(0, len(h_split_lines), 2): for i in range(0, len(h_split_lines), 2):
start_y0, start_y1 = h_split_lines[i:i+2] start_y0, start_y1 = h_split_lines[i : i + 2]
# 然后找出[start_y0, start_y1]之间的其他bbox,这些组成一个未分割板块 # 然后找出[start_y0, start_y1]之间的其他bbox,这些组成一个未分割板块
bboxes_in_block = [bbox for bbox in all_bboxes if bbox[Y0_IDX]>=start_y0 and bbox[Y1_IDX]<=start_y1] bboxes_in_block = [
bbox
for bbox in all_bboxes
if bbox[Y0_IDX] >= start_y0 and bbox[Y1_IDX] <= start_y1
]
unsplited_bboxes.append(bboxes_in_block) unsplited_bboxes.append(bboxes_in_block)
# 接着把未处理的加入到h_layouts里 # 接着把未处理的加入到h_layouts里
for bboxes_in_block in unsplited_bboxes: for bboxes_in_block in unsplited_bboxes:
if len(bboxes_in_block) == 0: if len(bboxes_in_block) == 0:
continue continue
x0, y0, x1, y1 = bound_x0, min([bbox[Y0_IDX] for bbox in bboxes_in_block]), bound_x1, max([bbox[Y1_IDX] for bbox in bboxes_in_block]) x0, y0, x1, y1 = (
bound_x0,
min([bbox[Y0_IDX] for bbox in bboxes_in_block]),
bound_x1,
max([bbox[Y1_IDX] for bbox in bboxes_in_block]),
)
h_layouts.append([x0, y0, x1, y1, LAYOUT_UNPROC]) h_layouts.append([x0, y0, x1, y1, LAYOUT_UNPROC])
h_layouts.sort(key=lambda x: x[1]) # 按照y0排序, 也就是从上到下的顺序 h_layouts.sort(key=lambda x: x[1]) # 按照y0排序, 也就是从上到下的顺序
""" """
转换成如下格式返回 转换成如下格式返回
""" """
for layout in h_layouts: for layout in h_layouts:
sorted_layout_blocks.append({ sorted_layout_blocks.append(
"layout_bbox": layout[:4], {
"layout_label":layout[4], 'layout_bbox': layout[:4],
"sub_layout":[], 'layout_label': layout[4],
}) 'sub_layout': [],
}
)
return sorted_layout_blocks return sorted_layout_blocks
############################################################################################### ###############################################################################################
# #
# 垂直方向的处理 # 垂直方向的处理
# #
# #
############################################################################################### ###############################################################################################
def _vertical_align_split_v1(bboxes: list, boundary: tuple) -> list:
    """Split the bboxes inside ``boundary`` into column layouts using vertical cut lines.

    Columns are peeled off from the left side first, then from the right side.
    Whatever cannot be separated cleanly is returned as one layout labelled
    ``LAYOUT_UNPROC``.

    Args:
        bboxes: candidate boxes; each is indexed with ``X0_IDX`` / ``X1_IDX``.
        boundary: ``(x0, y0, x1, y1)`` region to split. It is copied, never mutated.

    Returns:
        A list of ``{'layout_bbox': [x0, y0, x1, y1], 'layout_label': ..., 'sub_layout': []}``
        dicts sorted by ``x0`` (left to right).
    """
    # Working copy of the region; its left/right edges shrink as columns are cut off.
    new_boundary = [boundary[0], boundary[1], boundary[2], boundary[3]]
    v_blocks = []

    # Phase 1: cut columns off from left to right.
    while True:
        all_bboxes = get_bbox_in_boundary(bboxes, new_boundary)
        left_edge_bboxes = get_left_edge_bboxes(all_bboxes)
        if len(left_edge_bboxes) == 0:
            break
        # Candidate cut line: one unit to the right of the widest left-edge box.
        right_split_line_x1 = max(bbox[X1_IDX] for bbox in left_edge_bboxes) + 1
        # The cut is only valid if it does not touch or cross any box.
        if any(bbox[X0_IDX] <= right_split_line_x1 <= bbox[X1_IDX] for bbox in all_bboxes):
            # The vertical line hits a box: no further clean cut from the left.
            break
        # A column was separated; its x0 is the left-most edge of the left-edge boxes.
        layout_x0 = min(bbox[X0_IDX] for bbox in left_edge_bboxes)
        v_blocks.append([layout_x0, new_boundary[1], right_split_line_x1, new_boundary[3], LAYOUT_V])
        new_boundary[0] = right_split_line_x1  # move the left edge past the column

    # Phase 2: cut from right to left. If a cut fails here, the remaining region
    # is reported as a single unprocessed layout.
    unsplited_block = []
    while True:
        all_bboxes = get_bbox_in_boundary(bboxes, new_boundary)
        right_edge_bboxes = get_right_edge_bboxes(all_bboxes)
        if len(right_edge_bboxes) == 0:
            break
        # Candidate cut line: one unit to the left of the right-edge boxes.
        left_split_line_x0 = min(bbox[X0_IDX] for bbox in right_edge_bboxes) - 1
        if any(bbox[X0_IDX] <= left_split_line_x0 <= bbox[X1_IDX] for bbox in all_bboxes):
            # Everything still inside the working boundary is the leftover.
            unsplited_block.append(
                [new_boundary[0], new_boundary[1], new_boundary[2], new_boundary[3], LAYOUT_UNPROC]
            )
            break
        # A column was separated; its x1 is the right-most edge of the right-edge boxes.
        layout_x1 = max(bbox[X1_IDX] for bbox in right_edge_bboxes)
        v_blocks.append([left_split_line_x0, new_boundary[1], layout_x1, new_boundary[3], LAYOUT_V])
        new_boundary[2] = left_split_line_x0  # move the right edge before the column

    # Assemble the layout dicts (columns first, then the leftover) and order by x0.
    sorted_layout_blocks = [
        {
            'layout_bbox': block[:4],
            'layout_label': block[4],
            'sub_layout': [],
        }
        for block in v_blocks + unsplited_block
    ]
    sorted_layout_blocks.sort(key=lambda x: x['layout_bbox'][0])
    return sorted_layout_blocks
def _vertical_align_split_v2(bboxes: list, boundary: tuple) -> list:
    """Improved variant of the vertical alignment split.

    The earlier algorithm could treat a second-column box as part of the left
    column when nothing stood to its left, collapsing a multi-column layout
    into one column. This version starts from a corner box, follows the boxes
    directly below and above it while widening a window ``[w_x0, w_x1]``, and
    then checks whether a vertical cut next to that window is clean.

    Args:
        bboxes: candidate boxes, indexed with ``X0_IDX`` / ``Y0_IDX`` / ``X1_IDX`` / ``Y1_IDX``.
        boundary: ``(x0, y0, x1, y1)`` region to split. It is copied, never mutated.

    Returns:
        Layout dicts sorted by ``x0``. Clean columns are labelled ``LAYOUT_V``;
        a leftover region is labelled ``LAYOUT_UNPROC`` and additionally carries
        ``'bad_boxes'``, the boxes that were hit by a cut line.
    """
    new_boundary = [boundary[0], boundary[1], boundary[2], boundary[3]]
    bad_boxes = []  # boxes crossed by an attempted cut line
    v_blocks = []

    # Phase 1: grow columns starting from the top-left box.
    while True:
        all_bboxes = get_bbox_in_boundary(bboxes, new_boundary)
        if len(all_bboxes) == 0:
            break
        # TODO (from the original): verify that this box really belongs to the first column.
        left_top_box = min(all_bboxes, key=lambda x: (x[X0_IDX], x[Y0_IDX]))
        start_box = [
            left_top_box[X0_IDX],
            left_top_box[Y0_IDX],
            left_top_box[X1_IDX],
            left_top_box[Y1_IDX],
        ]
        w_x0, w_x1 = left_top_box[X0_IDX], left_top_box[X1_IDX]

        # Walk downwards: repeatedly take the box directly below the current
        # window and widen the window to include it.
        while left_top_box is not None:
            virtual_box = [w_x0, left_top_box[Y0_IDX], w_x1, left_top_box[Y1_IDX]]
            left_top_box = find_bottom_bbox_direct_from_left_edge(virtual_box, all_bboxes)
            if left_top_box:
                w_x0, w_x1 = (
                    min(virtual_box[X0_IDX], left_top_box[X0_IDX]),
                    max([virtual_box[X1_IDX], left_top_box[X1_IDX]]),
                )

        # The starting box may sit in the middle of the column, so also walk
        # upwards, using the widened window for robustness.
        # NOTE(review): the first box found above start_box is not merged into
        # w_x0/w_x1, only the ones found after it are. Kept as in the original; confirm intent.
        start_box = [w_x0, start_box[Y0_IDX], w_x1, start_box[Y1_IDX]]
        left_top_box = find_top_bbox_direct_from_left_edge(start_box, all_bboxes)
        while left_top_box is not None:
            virtual_box = [w_x0, left_top_box[Y0_IDX], w_x1, left_top_box[Y1_IDX]]
            left_top_box = find_top_bbox_direct_from_left_edge(virtual_box, all_bboxes)
            if left_top_box:
                w_x0, w_x1 = (
                    min(virtual_box[X0_IDX], left_top_box[X0_IDX]),
                    max([virtual_box[X1_IDX], left_top_box[X1_IDX]]),
                )

        # A cut just right of the window is only valid if it crosses no box.
        hit_boxes = [
            [b[X0_IDX], b[Y0_IDX], b[X1_IDX], b[Y1_IDX]]
            for b in all_bboxes
            if b[X0_IDX] <= w_x1 + 1 <= b[X1_IDX]
        ]
        if hit_boxes:
            bad_boxes.extend(hit_boxes)
            break
        # A column was separated cleanly; continue to its right.
        v_blocks.append([w_x0, new_boundary[1], w_x1, new_boundary[3], LAYOUT_V])
        new_boundary[0] = w_x1

    # Phase 2: grow columns starting from the top-right box.
    unsplited_block = []
    while True:
        all_bboxes = get_bbox_in_boundary(bboxes, new_boundary)
        if len(all_bboxes) == 0:
            break
        # Among the right-most boxes (largest x1), start from the top one (smallest y0).
        bbox_list_sorted = sorted(all_bboxes, key=lambda bbox: bbox[X1_IDX], reverse=True)
        bigest_x1 = bbox_list_sorted[0][X1_IDX]
        boxes_with_bigest_x1 = [bbox for bbox in bbox_list_sorted if bbox[X1_IDX] == bigest_x1]
        right_top_box = min(boxes_with_bigest_x1, key=lambda bbox: bbox[Y0_IDX])
        start_box = [
            right_top_box[X0_IDX],
            right_top_box[Y0_IDX],
            right_top_box[X1_IDX],
            right_top_box[Y1_IDX],
        ]
        w_x0, w_x1 = right_top_box[X0_IDX], right_top_box[X1_IDX]

        # Walk downwards along the right edge.
        while right_top_box is not None:
            virtual_box = [w_x0, right_top_box[Y0_IDX], w_x1, right_top_box[Y1_IDX]]
            right_top_box = find_bottom_bbox_direct_from_right_edge(virtual_box, all_bboxes)
            if right_top_box:
                w_x0, w_x1 = (
                    min([w_x0, right_top_box[X0_IDX]]),
                    max([w_x1, right_top_box[X1_IDX]]),
                )

        # Then walk upwards, again with the widened window.
        start_box = [w_x0, start_box[Y0_IDX], w_x1, start_box[Y1_IDX]]
        right_top_box = find_top_bbox_direct_from_right_edge(start_box, all_bboxes)
        while right_top_box is not None:
            virtual_box = [w_x0, right_top_box[Y0_IDX], w_x1, right_top_box[Y1_IDX]]
            right_top_box = find_top_bbox_direct_from_right_edge(virtual_box, all_bboxes)
            if right_top_box:
                w_x0, w_x1 = (
                    min([w_x0, right_top_box[X0_IDX]]),
                    max([w_x1, right_top_box[X1_IDX]]),
                )

        # If the cut just left of the window crosses a box, the region cannot
        # be fully split vertically: report what is left as unprocessed.
        hit_boxes = [
            [b[X0_IDX], b[Y0_IDX], b[X1_IDX], b[Y1_IDX]]
            for b in all_bboxes
            if b[X0_IDX] <= w_x0 - 1 <= b[X1_IDX]
        ]
        if hit_boxes:
            unsplited_block.append(
                [new_boundary[0], new_boundary[1], new_boundary[2], new_boundary[3], LAYOUT_UNPROC]
            )
            bad_boxes.extend(hit_boxes)
            break
        # A column was separated cleanly; continue to its left.
        v_blocks.append([w_x0, new_boundary[1], w_x1, new_boundary[3], LAYOUT_V])
        new_boundary[2] = w_x0

    # Convert to the layout dict structure.
    sorted_layout_blocks = [
        {
            'layout_bbox': block[:4],
            'layout_label': block[4],
            'sub_layout': [],
        }
        for block in v_blocks
    ]
    for block in unsplited_block:
        sorted_layout_blocks.append(
            {
                'layout_bbox': block[:4],
                'layout_label': block[4],
                'sub_layout': [],
                'bad_boxes': bad_boxes,  # record which boxes were hit by a cut line
            }
        )

    # Order left to right by x0.
    sorted_layout_blocks.sort(key=lambda x: x['layout_bbox'][0])
    return sorted_layout_blocks
def _try_horizontal_mult_column_split(bboxes:list, boundry:tuple)-> list: def _try_horizontal_mult_column_split(bboxes: list, boundary: tuple) -> list:
""" """
尝试水平切分,如果切分不动,那就当一个BAD_LAYOUT返回 尝试水平切分,如果切分不动,那就当一个BAD_LAYOUT返回
------------------ ------------------
...@@ -406,51 +539,58 @@ def _try_horizontal_mult_column_split(bboxes:list, boundry:tuple)-> list: ...@@ -406,51 +539,58 @@ def _try_horizontal_mult_column_split(bboxes:list, boundry:tuple)-> list:
pass pass
def _vertical_split(bboxes: list, boundary: tuple) -> list:
def _vertical_split(bboxes:list, boundry:tuple)-> list:
""" """
从垂直方向进行切割,分block 从垂直方向进行切割,分block
这个版本里,如果垂直切分不动,那就当一个BAD_LAYOUT返回 这个版本里,如果垂直切分不动,那就当一个BAD_LAYOUT返回
-------------------------- --------------------------
| | | | | |
| | | | | |
| | | |
这种列是此函数要切分的 -> | | 这种列是此函数要切分的 -> | |
| | | |
| | | | | |
| | | | | |
------------------------- -------------------------
""" """
sorted_layout_blocks = [] # 这是要最终返回的值 sorted_layout_blocks = [] # 这是要最终返回的值
bound_x0, bound_y0, bound_x1, bound_y1 = boundry bound_x0, bound_y0, bound_x1, bound_y1 = boundary
all_bboxes = get_bbox_in_boundry(bboxes, boundry) all_bboxes = get_bbox_in_boundary(bboxes, boundary)
""" """
all_bboxes = fix_vertical_bbox_pos(all_bboxes) # 垂直方向解覆盖 all_bboxes = fix_vertical_bbox_pos(all_bboxes) # 垂直方向解覆盖
all_bboxes = fix_hor_bbox_pos(all_bboxes) # 水平解覆盖 all_bboxes = fix_hor_bbox_pos(all_bboxes) # 水平解覆盖
这两行代码目前先不执行,因为公式检测,表格检测还不是很成熟,导致非常多的textblock参与了运算,时间消耗太大。 这两行代码目前先不执行,因为公式检测,表格检测还不是很成熟,导致非常多的textblock参与了运算,时间消耗太大。
这两行代码的作用是: 这两行代码的作用是:
如果遇到互相重叠的bbox, 那么会把面积较小的box进行压缩,从而避免重叠。对布局切分来说带来正反馈。 如果遇到互相重叠的bbox, 那么会把面积较小的box进行压缩,从而避免重叠。对布局切分来说带来正反馈。
""" """
#all_bboxes = paper_bbox_sort(all_bboxes, abs(bound_x1-bound_x0), abs(bound_y1-bound_x0)) # 大致拍下序, 这个是基于直接遮挡的。 # all_bboxes = paper_bbox_sort(all_bboxes, abs(bound_x1-bound_x0), abs(bound_y1-bound_x0)) # 大致拍下序, 这个是基于直接遮挡的。
""" """
首先在垂直方向上扩展独占一行的bbox 首先在垂直方向上扩展独占一行的bbox
""" """
for bbox in all_bboxes: for bbox in all_bboxes:
top_nearest_bbox = find_all_top_bbox_direct(bbox, all_bboxes) # 非扩展线 top_nearest_bbox = find_all_top_bbox_direct(bbox, all_bboxes) # 非扩展线
bottom_nearest_bbox = find_all_bottom_bbox_direct(bbox, all_bboxes) bottom_nearest_bbox = find_all_bottom_bbox_direct(bbox, all_bboxes)
if top_nearest_bbox is None and bottom_nearest_bbox is None and not any([b[X0_IDX]<bbox[X1_IDX]<b[X1_IDX] or b[X0_IDX]<bbox[X0_IDX]<b[X1_IDX] for b in all_bboxes]): # 独占一列, 且不和其他重叠 if (
top_nearest_bbox is None
and bottom_nearest_bbox is None
and not any(
[
b[X0_IDX] < bbox[X1_IDX] < b[X1_IDX]
or b[X0_IDX] < bbox[X0_IDX] < b[X1_IDX]
for b in all_bboxes
]
)
): # 独占一列, 且不和其他重叠
bbox[X0_EXT_IDX] = bbox[X0_IDX] bbox[X0_EXT_IDX] = bbox[X0_IDX]
bbox[Y0_EXT_IDX] = bound_y0 bbox[Y0_EXT_IDX] = bound_y0
bbox[X1_EXT_IDX] = bbox[X1_IDX] bbox[X1_EXT_IDX] = bbox[X1_IDX]
bbox[Y1_EXT_IDX] = bound_y1 bbox[Y1_EXT_IDX] = bound_y1
"""
"""
此时独占一列的被成功扩展到指定的边界上,这个时候利用边界条件合并连续的bbox,成为一个group 此时独占一列的被成功扩展到指定的边界上,这个时候利用边界条件合并连续的bbox,成为一个group
然后合并所有连续垂直方向的bbox. 然后合并所有连续垂直方向的bbox.
""" """
...@@ -458,20 +598,23 @@ def _vertical_split(bboxes:list, boundry:tuple)-> list: ...@@ -458,20 +598,23 @@ def _vertical_split(bboxes:list, boundry:tuple)-> list:
# fix: 这里水平方向的列不要合并成一个行,因为需要保证返回给下游的最小block,总是可以无脑从上到下阅读文字。 # fix: 这里水平方向的列不要合并成一个行,因为需要保证返回给下游的最小block,总是可以无脑从上到下阅读文字。
v_bboxes = [] v_bboxes = []
for box in all_bboxes: for box in all_bboxes:
if box[Y0_EXT_IDX] == bound_y0 and box[Y1_EXT_IDX] == bound_y1: if box[Y0_EXT_IDX] == bound_y0 and box[Y1_EXT_IDX] == bound_y1:
v_bboxes.append(box) v_bboxes.append(box)
""" """
现在v_bboxes里面是所有的group了,每个group都是一个list 现在v_bboxes里面是所有的group了,每个group都是一个list
对v_bboxes里的每个group进行计算放回到sorted_layouts里 对v_bboxes里的每个group进行计算放回到sorted_layouts里
""" """
v_layouts = [] v_layouts = []
for vbox in v_bboxes: for vbox in v_bboxes:
#gp.sort(key=lambda x: x[X0_IDX]) # gp.sort(key=lambda x: x[X0_IDX])
# 然后计算这个group的layout_bbox,也就是最小的x0,y0, 最大的x1,y1 # 然后计算这个group的layout_bbox,也就是最小的x0,y0, 最大的x1,y1
x0, y0, x1, y1 = vbox[X0_EXT_IDX], vbox[Y0_EXT_IDX], vbox[X1_EXT_IDX], vbox[Y1_EXT_IDX] x0, y0, x1, y1 = (
v_layouts.append([x0, y0, x1, y1, LAYOUT_V]) # 垂直的布局 vbox[X0_EXT_IDX],
vbox[Y0_EXT_IDX],
vbox[X1_EXT_IDX],
vbox[Y1_EXT_IDX],
)
v_layouts.append([x0, y0, x1, y1, LAYOUT_V]) # 垂直的布局
""" """
接下来利用这些连续的垂直bbox的layout_bbox的x0, x1,从垂直上切分开其余的为几个部分 接下来利用这些连续的垂直bbox的layout_bbox的x0, x1,从垂直上切分开其余的为几个部分
""" """
...@@ -481,29 +624,41 @@ def _vertical_split(bboxes:list, boundry:tuple)-> list: ...@@ -481,29 +624,41 @@ def _vertical_split(bboxes:list, boundry:tuple)-> list:
v_split_lines.append(x0) v_split_lines.append(x0)
v_split_lines.append(x1) v_split_lines.append(x1)
v_split_lines.append(bound_x1) v_split_lines.append(bound_x1)
unsplited_bboxes = [] unsplited_bboxes = []
for i in range(0, len(v_split_lines), 2): for i in range(0, len(v_split_lines), 2):
start_x0, start_x1 = v_split_lines[i:i+2] start_x0, start_x1 = v_split_lines[i : i + 2]
# 然后找出[start_x0, start_x1]之间的其他bbox,这些组成一个未分割板块 # 然后找出[start_x0, start_x1]之间的其他bbox,这些组成一个未分割板块
bboxes_in_block = [bbox for bbox in all_bboxes if bbox[X0_IDX]>=start_x0 and bbox[X1_IDX]<=start_x1] bboxes_in_block = [
bbox
for bbox in all_bboxes
if bbox[X0_IDX] >= start_x0 and bbox[X1_IDX] <= start_x1
]
unsplited_bboxes.append(bboxes_in_block) unsplited_bboxes.append(bboxes_in_block)
# 接着把未处理的加入到v_layouts里 # 接着把未处理的加入到v_layouts里
for bboxes_in_block in unsplited_bboxes: for bboxes_in_block in unsplited_bboxes:
if len(bboxes_in_block) == 0: if len(bboxes_in_block) == 0:
continue continue
x0, y0, x1, y1 = min([bbox[X0_IDX] for bbox in bboxes_in_block]), bound_y0, max([bbox[X1_IDX] for bbox in bboxes_in_block]), bound_y1 x0, y0, x1, y1 = (
v_layouts.append([x0, y0, x1, y1, LAYOUT_UNPROC]) # 说明这篇区域未能够分析出可靠的版面 min([bbox[X0_IDX] for bbox in bboxes_in_block]),
bound_y0,
v_layouts.sort(key=lambda x: x[0]) # 按照x0排序, 也就是从左到右的顺序 max([bbox[X1_IDX] for bbox in bboxes_in_block]),
bound_y1,
)
v_layouts.append(
[x0, y0, x1, y1, LAYOUT_UNPROC]
) # 说明这篇区域未能够分析出可靠的版面
v_layouts.sort(key=lambda x: x[0]) # 按照x0排序, 也就是从左到右的顺序
for layout in v_layouts: for layout in v_layouts:
sorted_layout_blocks.append({ sorted_layout_blocks.append(
"layout_bbox": layout[:4], {
"layout_label":layout[4], 'layout_bbox': layout[:4],
"sub_layout":[], 'layout_label': layout[4],
}) 'sub_layout': [],
}
)
""" """
至此,垂直方向切成了2种类型,其一是独占一列的,其二是未处理的。 至此,垂直方向切成了2种类型,其一是独占一列的,其二是未处理的。
下面对这些未处理的进行垂直方向切分,这个切分要切出来类似“吕”这种类型的垂直方向的布局 下面对这些未处理的进行垂直方向切分,这个切分要切出来类似“吕”这种类型的垂直方向的布局
...@@ -513,24 +668,32 @@ def _vertical_split(bboxes:list, boundry:tuple)-> list: ...@@ -513,24 +668,32 @@ def _vertical_split(bboxes:list, boundry:tuple)-> list:
x0, y0, x1, y1 = layout['layout_bbox'] x0, y0, x1, y1 = layout['layout_bbox']
v_split_layouts = _vertical_align_split_v2(bboxes, [x0, y0, x1, y1]) v_split_layouts = _vertical_align_split_v2(bboxes, [x0, y0, x1, y1])
sorted_layout_blocks[i] = { sorted_layout_blocks[i] = {
"layout_bbox": [x0, y0, x1, y1], 'layout_bbox': [x0, y0, x1, y1],
"layout_label": LAYOUT_H, 'layout_label': LAYOUT_H,
"sub_layout": v_split_layouts 'sub_layout': v_split_layouts,
} }
layout['layout_label'] = LAYOUT_H # 被垂线切分成了水平布局 layout['layout_label'] = LAYOUT_H # 被垂线切分成了水平布局
return sorted_layout_blocks return sorted_layout_blocks
def split_layout(bboxes:list, boundry:tuple, page_num:int)-> list:
def split_layout(bboxes: list, boundary: tuple, page_num: int) -> list:
""" """
把bboxes切割成layout 把bboxes切割成layout
return: return:
[ [
{ {
"layout_bbox": [x0, y0, x1, y1], "layout_bbox": [x0,y0,x1,y1],
"layout_label":"u|v|h|b", 未处理|垂直|水平|BAD_LAYOUT "layout_label":"u|v|h|b", 未处理|垂直|水平|BAD_LAYOUT
"sub_layout": [] #每个元素都是[x0, y0, x1, y1, block_content, idx_x, idx_y, content_type, ext_x0, ext_y0, ext_x1, ext_y1], 并且顺序就是阅读顺序 "sub_layout":[] #每个元素都是[
x0,y0,
x1,y1,
block_content,
idx_x,idx_y,
content_type,
ext_x0,ext_y0,
ext_x1,ext_y1
], 并且顺序就是阅读顺序
} }
] ]
example: example:
...@@ -539,7 +702,7 @@ def split_layout(bboxes:list, boundry:tuple, page_num:int)-> list: ...@@ -539,7 +702,7 @@ def split_layout(bboxes:list, boundry:tuple, page_num:int)-> list:
"layout_bbox": [0, 0, 100, 100], "layout_bbox": [0, 0, 100, 100],
"layout_label":"u|v|h|b", "layout_label":"u|v|h|b",
"sub_layout":[ "sub_layout":[
] ]
}, },
{ {
...@@ -559,35 +722,35 @@ def split_layout(bboxes:list, boundry:tuple, page_num:int)-> list: ...@@ -559,35 +722,35 @@ def split_layout(bboxes:list, boundry:tuple, page_num:int)-> list:
"layout_bbox": [0, 0, 100, 100], "layout_bbox": [0, 0, 100, 100],
"layout_label":"u|v|h|b", "layout_label":"u|v|h|b",
"sub_layout":[ "sub_layout":[
] ]
} }
} }
] ]
""" """
sorted_layouts = [] # 最终返回的结果 sorted_layouts = [] # 最终返回的结果
boundry_x0, boundry_y0, boundry_x1, boundry_y1 = boundry boundary_x0, boundary_y0, boundary_x1, boundary_y1 = boundary
if len(bboxes) <=1: if len(bboxes) <= 1:
return [ return [
{ {
"layout_bbox": [boundry_x0, boundry_y0, boundry_x1, boundry_y1], 'layout_bbox': [boundary_x0, boundary_y0, boundary_x1, boundary_y1],
"layout_label": LAYOUT_V, 'layout_label': LAYOUT_V,
"sub_layout":[] 'sub_layout': [],
} }
] ]
""" """
接下来按照先水平后垂直的顺序进行切分 接下来按照先水平后垂直的顺序进行切分
""" """
bboxes = paper_bbox_sort(bboxes, boundry_x1-boundry_x0, boundry_y1-boundry_y0) bboxes = paper_bbox_sort(
sorted_layouts = _horizontal_split(bboxes, boundry) # 通过水平分割出来的layout bboxes, boundary_x1 - boundary_x0, boundary_y1 - boundary_y0
)
sorted_layouts = _horizontal_split(bboxes, boundary) # 通过水平分割出来的layout
for i, layout in enumerate(sorted_layouts): for i, layout in enumerate(sorted_layouts):
x0, y0, x1, y1 = layout['layout_bbox'] x0, y0, x1, y1 = layout['layout_bbox']
layout_type = layout['layout_label'] layout_type = layout['layout_label']
if layout_type == LAYOUT_UNPROC: # 说明是非独占单行的,这些需要垂直切分 if layout_type == LAYOUT_UNPROC: # 说明是非独占单行的,这些需要垂直切分
v_split_layouts = _vertical_split(bboxes, [x0, y0, x1, y1]) v_split_layouts = _vertical_split(bboxes, [x0, y0, x1, y1])
""" """
最后这里有个逻辑问题:如果这个函数只分离出来了一个column layout,那么这个layout分割肯定超出了算法能力范围。因为我们假定的是传进来的 最后这里有个逻辑问题:如果这个函数只分离出来了一个column layout,那么这个layout分割肯定超出了算法能力范围。因为我们假定的是传进来的
box已经把行全部剥离了,所以这里必须十多个列才可以。如果只剥离出来一个layout,并且是多个box,那么就说明这个layout是无法分割的,标记为LAYOUT_UNPROC box已经把行全部剥离了,所以这里必须十多个列才可以。如果只剥离出来一个layout,并且是多个box,那么就说明这个layout是无法分割的,标记为LAYOUT_UNPROC
...@@ -596,28 +759,26 @@ def split_layout(bboxes:list, boundry:tuple, page_num:int)-> list: ...@@ -596,28 +759,26 @@ def split_layout(bboxes:list, boundry:tuple, page_num:int)-> list:
if len(v_split_layouts) == 1: if len(v_split_layouts) == 1:
if len(v_split_layouts[0]['sub_layout']) == 0: if len(v_split_layouts[0]['sub_layout']) == 0:
layout_label = LAYOUT_UNPROC layout_label = LAYOUT_UNPROC
#logger.warning(f"WARNING: pageno={page_num}, 无法分割的layout: ", v_split_layouts) # logger.warning(f"WARNING: pageno={page_num}, 无法分割的layout: ", v_split_layouts)
""" """
组合起来最终的layout 组合起来最终的layout
""" """
sorted_layouts[i] = { sorted_layouts[i] = {
"layout_bbox": [x0, y0, x1, y1], 'layout_bbox': [x0, y0, x1, y1],
"layout_label": layout_label, 'layout_label': layout_label,
"sub_layout": v_split_layouts 'sub_layout': v_split_layouts,
} }
layout['layout_label'] = LAYOUT_H layout['layout_label'] = LAYOUT_H
""" """
水平和垂直方向都切分完毕了。此时还有一些未处理的,这些未处理的可能是因为水平和垂直方向都无法切分。 水平和垂直方向都切分完毕了。此时还有一些未处理的,这些未处理的可能是因为水平和垂直方向都无法切分。
这些最后调用_try_horizontal_mult_block_split做一次水平多个block的联合切分,如果也不能切分最终就当做BAD_LAYOUT返回 这些最后调用_try_horizontal_mult_block_split做一次水平多个block的联合切分,如果也不能切分最终就当做BAD_LAYOUT返回
""" """
# TODO # TODO
return sorted_layouts return sorted_layouts
def get_bboxes_layout(all_boxes:list, boundry:tuple, page_id:int): def get_bboxes_layout(all_boxes: list, boundary: tuple, page_id: int):
""" """
对利用layout排序之后的box,进行排序 对利用layout排序之后的box,进行排序
return: return:
...@@ -628,10 +789,10 @@ def get_bboxes_layout(all_boxes:list, boundry:tuple, page_id:int): ...@@ -628,10 +789,10 @@ def get_bboxes_layout(all_boxes:list, boundry:tuple, page_id:int):
}, },
] ]
""" """
def _preorder_traversal(layout): def _preorder_traversal(layout):
""" """对sorted_layouts的叶子节点,也就是len(sub_layout)==0的节点进行排序。排序按照前序遍历的顺序,也就是从上到
对sorted_layouts的叶子节点,也就是len(sub_layout)==0的节点进行排序。排序按照前序遍历的顺序,也就是从上到下,从左到右的顺序 下,从左到右的顺序."""
"""
sorted_layout_blocks = [] sorted_layout_blocks = []
for layout in layout: for layout in layout:
sub_layout = layout['sub_layout'] sub_layout = layout['sub_layout']
...@@ -641,71 +802,89 @@ def get_bboxes_layout(all_boxes:list, boundry:tuple, page_id:int): ...@@ -641,71 +802,89 @@ def get_bboxes_layout(all_boxes:list, boundry:tuple, page_id:int):
s = _preorder_traversal(sub_layout) s = _preorder_traversal(sub_layout)
sorted_layout_blocks.extend(s) sorted_layout_blocks.extend(s)
return sorted_layout_blocks return sorted_layout_blocks
# ------------------------------------------------------------------------------------------------------------------------- # -------------------------------------------------------------------------------------------------------------------------
sorted_layouts = split_layout(all_boxes, boundry, page_id)# 先切分成layout,得到一个Tree sorted_layouts = split_layout(
total_sorted_layout_blocks = _preorder_traversal(sorted_layouts) all_boxes, boundary, page_id
) # 先切分成layout,得到一个Tree
total_sorted_layout_blocks = _preorder_traversal(sorted_layouts)
return total_sorted_layout_blocks, sorted_layouts return total_sorted_layout_blocks, sorted_layouts
def get_columns_cnt_of_layout(layout_tree):
    """Return the column count ("width") of a layout tree.

    Each top-level node is scored independently and the largest score wins:

    * a node without sub-layouts, or one labelled ``LAYOUT_H``, counts as 1;
    * otherwise each child counts as 1 when it has no sub-layouts, else as the
      sum of the recursive counts of its own sub-layouts.

    Args:
        layout_tree: iterable of dicts with ``'layout_label'`` and ``'sub_layout'`` keys.

    Returns:
        The maximum per-node column count, or 0 for an empty tree.
    """
    max_width = 0  # an empty tree yields 0
    for node in layout_tree:  # one horizontal slice at a time
        layout_type = node['layout_label']
        sub_layouts = node['sub_layout']
        if len(sub_layouts) == 0 or layout_type == LAYOUT_H:
            width = 1
        else:
            width = 0
            for sub_layout in sub_layouts:
                nested = sub_layout['sub_layout']
                if len(nested) == 0:
                    width += 1
                else:
                    width += sum(get_columns_cnt_of_layout([lay]) for lay in nested)
        max_width = max(max_width, width)
    return max_width
def sort_with_layout(bboxes: list, page_width, page_height) -> tuple:
    """Split a page into layouts and return its bboxes sorted by those layouts.

    Args:
        bboxes: input boxes; elements 0-3 are copied as the coordinates and
            element 4 is carried along as the last field of the internal box.
        page_width: page width, used as the right edge of the page boundary.
        page_height: page height, used as the bottom edge of the page boundary.

    Returns:
        ``(sorted_bboxes, layout_bboxes)``, or ``(None, None)`` when any
        resulting layout is labelled ``LAYOUT_UNPROC`` (page considered too complex).
    """
    # Expand every input box to the 13-field internal form; only the
    # coordinates, the 'text' content type and the trailing box[4] are filled in.
    new_bboxes = [
        [box[0], box[1], box[2], box[3], None, None, None, 'text', None, None, None, None, box[4]]
        for box in bboxes
    ]

    layout_bboxes, _ = get_bboxes_layout(new_bboxes, (0, 0, page_width, page_height), 0)
    if any(lay['layout_label'] == LAYOUT_UNPROC for lay in layout_bboxes):
        logger.warning('drop this pdf, reason: 复杂版面')
        return None, None

    sorted_bboxes = []
    # Use each layout bbox to select the boxes it frames, then sort those boxes.
    for layout in layout_bboxes:
        lbox = layout['layout_bbox']
        bbox_in_layout = get_bbox_in_boundary(new_bboxes, lbox)
        sorted_bbox = paper_bbox_sort(bbox_in_layout, lbox[2] - lbox[0], lbox[3] - lbox[1])
        sorted_bboxes.extend(sorted_bbox)

    return sorted_bboxes, layout_bboxes
def sort_text_block(text_block, layout_bboxes): def sort_text_block(text_block, layout_bboxes):
""" """对一页的text_block进行排序."""
对一页的text_block进行排序
"""
sorted_text_bbox = [] sorted_text_bbox = []
all_text_bbox = [] all_text_bbox = []
# 做一个box=>text的映射 # 做一个box=>text的映射
...@@ -714,19 +893,29 @@ def sort_text_block(text_block, layout_bboxes): ...@@ -714,19 +893,29 @@ def sort_text_block(text_block, layout_bboxes):
box = blk['bbox'] box = blk['bbox']
box_to_text[(box[0], box[1], box[2], box[3])] = blk box_to_text[(box[0], box[1], box[2], box[3])] = blk
all_text_bbox.append(box) all_text_bbox.append(box)
# text_blocks_to_sort = [] # text_blocks_to_sort = []
# for box in box_to_text.keys(): # for box in box_to_text.keys():
# text_blocks_to_sort.append([box[0], box[1], box[2], box[3], None, None, None, 'text', None, None, None, None]) # text_blocks_to_sort.append([box[0], box[1], box[2], box[3], None, None, None, 'text', None, None, None, None])
# 按照layout_bboxes的顺序,对text_block进行排序 # 按照layout_bboxes的顺序,对text_block进行排序
for layout in layout_bboxes: for layout in layout_bboxes:
layout_box = layout['layout_bbox'] layout_box = layout['layout_bbox']
text_bbox_in_layout = get_bbox_in_boundry(all_text_bbox, [layout_box[0]-1, layout_box[1]-1, layout_box[2]+1, layout_box[3]+1]) text_bbox_in_layout = get_bbox_in_boundary(
#sorted_bbox = paper_bbox_sort(text_bbox_in_layout, layout_box[2]-layout_box[0], layout_box[3]-layout_box[1]) all_text_bbox,
text_bbox_in_layout.sort(key = lambda x: x[1]) # 一个layout内部的box,按照y0自上而下排序 [
#sorted_bbox = [[b] for b in text_blocks_to_sort] layout_box[0] - 1,
layout_box[1] - 1,
layout_box[2] + 1,
layout_box[3] + 1,
],
)
# sorted_bbox = paper_bbox_sort(text_bbox_in_layout, layout_box[2]-layout_box[0], layout_box[3]-layout_box[1])
text_bbox_in_layout.sort(
key=lambda x: x[1]
) # 一个layout内部的box,按照y0自上而下排序
# sorted_bbox = [[b] for b in text_blocks_to_sort]
for sb in text_bbox_in_layout: for sb in text_bbox_in_layout:
sorted_text_bbox.append(box_to_text[(sb[0], sb[1], sb[2], sb[3])]) sorted_text_bbox.append(box_to_text[(sb[0], sb[1], sb[2], sb[3])])
return sorted_text_bbox return sorted_text_bbox
from loguru import logger
import math import math
def _is_in_or_part_overlap(box1, box2) -> bool: def _is_in_or_part_overlap(box1, box2) -> bool:
""" """两个bbox是否有部分重叠或者包含."""
两个bbox是否有部分重叠或者包含
"""
if box1 is None or box2 is None: if box1 is None or box2 is None:
return False return False
x0_1, y0_1, x1_1, y1_1 = box1 x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2 x0_2, y0_2, x1_2, y1_2 = box2
return not (x1_1 < x0_2 or # box1在box2的左边 return not (x1_1 < x0_2 or # box1在box2的左边
x0_1 > x1_2 or # box1在box2的右边 x0_1 > x1_2 or # box1在box2的右边
y1_1 < y0_2 or # box1在box2的上边 y1_1 < y0_2 or # box1在box2的上边
y0_1 > y1_2) # box1在box2的下边 y0_1 > y1_2) # box1在box2的下边
def _is_in_or_part_overlap_with_area_ratio(box1,
                                           box2,
                                           area_ratio_threshold=0.6):
    """Return whether box1 lies in / overlaps box2 by more than a given share of box1.

    Args:
        box1: Box as (x0, y0, x1, y1), or None. The overlap is measured
            relative to this box's area.
        box2: Box as (x0, y0, x1, y1), or None.
        area_ratio_threshold: The overlap area divided by box1's area must be
            strictly greater than this value.

    Returns:
        True when the boxes overlap and overlap_area / box1_area exceeds
        ``area_ratio_threshold``; False when either box is None or the boxes
        are strictly separated. A zero-area box1 is treated as having an
        overlap ratio of 0 instead of raising ZeroDivisionError.
    """
    if box1 is None or box2 is None:
        return False

    x0_1, y0_1, x1_1, y1_1 = box1
    x0_2, y0_2, x1_2, y1_2 = box2

    # Same separation test as _is_in_or_part_overlap: touching boxes pass.
    if x1_1 < x0_2 or x0_1 > x1_2 or y1_1 < y0_2 or y0_1 > y1_2:
        return False

    # Area of the intersection rectangle.
    overlap_width = min(x1_1, x1_2) - max(x0_1, x0_2)
    overlap_height = min(y1_1, y1_2) - max(y0_1, y0_2)
    overlap_area = overlap_width * overlap_height

    box1_area = (x1_1 - x0_1) * (y1_1 - y0_1)
    if box1_area == 0:
        # Degenerate box1 (a line or a point): the ratio is undefined, so use
        # 0, the same convention calculate_overlap_area_in_bbox1_area_ratio
        # applies to a zero-area bbox1.
        return 0 > area_ratio_threshold

    return overlap_area / box1_area > area_ratio_threshold
def _is_in(box1, box2) -> bool:
    """Return whether box1 lies completely inside box2 (shared edges allowed).

    Args:
        box1: Candidate inner box as (x0, y0, x1, y1).
        box2: Candidate outer box as (x0, y0, x1, y1).

    Returns:
        True when every edge of box1 is on or within the matching edge of
        box2. Neither argument is checked for None.
    """
    x0_1, y0_1, x1_1, y1_1 = box1
    x0_2, y0_2, x1_2, y1_2 = box2

    return (x0_1 >= x0_2 and  # box1's x0 is not before box2's x0
            y0_1 >= y0_2 and  # box1's y0 is not before box2's y0
            x1_1 <= x1_2 and  # box1's x1 is not past box2's x1
            y1_1 <= y1_2)  # box1's y1 is not past box2's y1
def _is_part_overlap(box1, box2) -> bool:
    """Return whether the boxes overlap while box1 is not fully inside box2.

    Containment is only tested in one direction: when box2 lies entirely
    inside box1 the result is still True.

    Args:
        box1: Box as (x0, y0, x1, y1), or None.
        box2: Box as (x0, y0, x1, y1), or None.

    Returns:
        False when either box is None, when the boxes are strictly separated,
        or when box1 is contained in box2; True otherwise.
    """
    if box1 is None or box2 is None:
        return False

    return _is_in_or_part_overlap(box1, box2) and not _is_in(box1, box2)
def _left_intersect(left_box, right_box):
    """Return whether left_box straddles the left edge of right_box.

    Args:
        left_box: Box as (x0, y0, x1, y1), or None.
        right_box: Box as (x0, y0, x1, y1), or None.

    Returns:
        True when right_box's x0 lies strictly between left_box's x0 and x1
        and at least one of right_box's y edges (y0 or y1) falls within
        left_box's y range; False otherwise, including when either box is
        None.
    """
    if left_box is None or right_box is None:
        return False

    x0_1, y0_1, x1_1, y1_1 = left_box
    x0_2, y0_2, _, y1_2 = right_box

    crosses_left_edge = x1_1 > x0_2 and x0_1 < x0_2
    y_edge_inside = y0_1 <= y0_2 <= y1_1 or y0_1 <= y1_2 <= y1_1
    return crosses_left_edge and y_edge_inside
def _right_intersect(left_box, right_box):
    """Return whether left_box straddles the right edge of right_box.

    Args:
        left_box: Box as (x0, y0, x1, y1), or None.
        right_box: Box as (x0, y0, x1, y1), or None.

    Returns:
        True when right_box's x1 lies strictly between left_box's x0 and x1
        and at least one of right_box's y edges (y0 or y1) falls within
        left_box's y range; False otherwise, including when either box is
        None.
    """
    if left_box is None or right_box is None:
        return False

    x0_1, y0_1, x1_1, y1_1 = left_box
    _, y0_2, x1_2, y1_2 = right_box

    crosses_right_edge = x0_1 < x1_2 and x1_1 > x1_2
    y_edge_inside = y0_1 <= y0_2 <= y1_1 or y0_1 <= y1_2 <= y1_1
    return crosses_right_edge and y_edge_inside
def _is_vertical_full_overlap(box1, box2, x_torlence=2):
    """Return whether one box spans the other in x while they overlap in y.

    Args:
        box1: Box as (x1, y1, x2, y2): top-left and bottom-right corners.
        box2: Box in the same format.
        x_torlence: Slack added on both sides of the containing box's x range
            before testing containment (the spelling is part of the existing
            interface).

    Returns:
        True when, in x, box1 contains box2 or box2 contains box1 (within the
        tolerance) and the y ranges are not strictly separated.
    """
    x11, y11, x12, y12 = box1
    x21, y21, x22, y22 = box2

    # x axis: full containment in either direction; partial overlap is not enough.
    box1_spans_box2 = x11 - x_torlence <= x21 and x12 + x_torlence >= x22
    box2_spans_box1 = x21 - x_torlence <= x11 and x22 + x_torlence >= x12
    contains_in_x = box1_spans_box2 or box2_spans_box1

    # y axis: any overlap, touching included.
    overlap_in_y = not (y12 < y21 or y11 > y22)

    return contains_in_x and overlap_in_y
def _is_bottom_full_overlap(box1, box2, y_tolerance=2):
    """Return whether box1's bottom edge dips slightly into the top of box2.

    Args:
        box1: Upper box as (x0, y0, x1, y1), or None.
        box2: Lower box as (x0, y0, x1, y1), or None.
        y_tolerance: The vertical overlap ``box1.y1 - box2.y0`` must be
            strictly between 0 and this value.

    Returns:
        True when the vertical overlap is within (0, y_tolerance) and, in x,
        one box's two edges both fall inside the other box's x range widened
        by a fixed margin of 2. False otherwise, including when either box is
        None.
    """
    if box1 is None or box2 is None:
        return False

    x0_1, _, x1_1, y1_1 = box1
    x0_2, y0_2, x1_2, _ = box2

    tolerance_margin = 2
    box2_inside_box1 = (
        x0_1 - tolerance_margin <= x0_2 <= x1_1 + tolerance_margin
        and x0_1 - tolerance_margin <= x1_2 <= x1_1 + tolerance_margin)
    box1_inside_box2 = (
        x0_2 - tolerance_margin <= x0_1 <= x1_2 + tolerance_margin
        and x0_2 - tolerance_margin <= x1_1 <= x1_2 + tolerance_margin)
    is_xdir_full_overlap = box2_inside_box1 or box1_inside_box2

    # A positive overlap already implies box2 starts above box1's bottom edge,
    # so no separate ``y0_2 < y1_1`` test is needed.
    return 0 < (y1_1 - y0_2) < y_tolerance and is_xdir_full_overlap
def _is_left_overlap(
    box1,
    box2,
):
    """Return whether box2 starts inside box1's x range with enough y overlap.

    The vertical test does not depend on which box is higher: it only looks at
    the length of the shared y interval.

    Args:
        box1: Box as (x0, y0, x1, y1), or None.
        box2: Box as (x0, y0, x1, y1), or None.

    Returns:
        True when box2's x0 lies within [box1.x0, box1.x1] and the shared y
        interval covers at least half the height of box1 or of box2. False
        otherwise, including when either box is None. A zero-height box
        contributes a ratio of 0.
    """
    if box1 is None or box2 is None:
        return False

    x0_1, y0_1, x1_1, y1_1 = box1
    x0_2, y0_2, _, y1_2 = box2

    # Length of the shared y interval, clamped at 0 for separated boxes.
    y_overlap_len = max(0, min(y1_1, y1_2) - max(y0_1, y0_2))

    height_1 = y1_1 - y0_1
    height_2 = y1_2 - y0_2
    ratio_1 = 1.0 * y_overlap_len / height_1 if height_1 != 0 else 0
    ratio_2 = 1.0 * y_overlap_len / height_2 if height_2 != 0 else 0
    vertical_overlap_cond = ratio_1 >= 0.5 or ratio_2 >= 0.5

    return x0_1 <= x0_2 <= x1_1 and vertical_overlap_cond
def __is_overlaps_y_exceeds_threshold(bbox1,
                                      bbox2,
                                      overlap_ratio_threshold=0.8):
    """Return whether the y overlap exceeds a share of the shorter box's height.

    Args:
        bbox1: Box as (x0, y0, x1, y1); only y0 and y1 are used.
        bbox2: Box as (x0, y0, x1, y1); only y0 and y1 are used.
        overlap_ratio_threshold: The shared y length divided by the smaller of
            the two heights must be strictly greater than this value.

    Returns:
        True when overlap / min(height1, height2) exceeds the threshold.
        False when the shorter box has zero height, since the ratio is then
        undefined (this previously raised ZeroDivisionError).
    """
    _, y0_1, _, y1_1 = bbox1
    _, y0_2, _, y1_2 = bbox2

    # Length of the shared y interval, clamped at 0 for separated boxes.
    overlap = max(0, min(y1_1, y1_2) - max(y0_1, y0_2))
    min_height = min(y1_1 - y0_1, y1_2 - y0_2)
    if min_height == 0:
        return False

    return (overlap / min_height) > overlap_ratio_threshold
def calculate_iou(bbox1, bbox2): def calculate_iou(bbox1, bbox2):
""" """计算两个边界框的交并比(IOU)。
计算两个边界框的交并比(IOU)。
Args: Args:
bbox1 (list[float]): 第一个边界框的坐标,格式为 [x1, y1, x2, y2],其中 (x1, y1) 为左上角坐标,(x2, y2) 为右下角坐标。 bbox1 (list[float]): 第一个边界框的坐标,格式为 [x1, y1, x2, y2],其中 (x1, y1) 为左上角坐标,(x2, y2) 为右下角坐标。
...@@ -170,7 +168,6 @@ def calculate_iou(bbox1, bbox2): ...@@ -170,7 +168,6 @@ def calculate_iou(bbox1, bbox2):
Returns: Returns:
float: 两个边界框的交并比(IOU),取值范围为 [0, 1]。 float: 两个边界框的交并比(IOU),取值范围为 [0, 1]。
""" """
# Determine the coordinates of the intersection rectangle # Determine the coordinates of the intersection rectangle
x_left = max(bbox1[0], bbox2[0]) x_left = max(bbox1[0], bbox2[0])
...@@ -188,16 +185,15 @@ def calculate_iou(bbox1, bbox2): ...@@ -188,16 +185,15 @@ def calculate_iou(bbox1, bbox2):
bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
# Compute the intersection over union by taking the intersection area # Compute the intersection over union by taking the intersection area
# and dividing it by the sum of both areas minus the intersection area # and dividing it by the sum of both areas minus the intersection area
iou = intersection_area / float(bbox1_area + bbox2_area - intersection_area) iou = intersection_area / float(bbox1_area + bbox2_area -
intersection_area)
return iou return iou
def calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2): def calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2):
""" """计算box1和box2的重叠面积占最小面积的box的比例."""
计算box1和box2的重叠面积占最小面积的box的比例
"""
# Determine the coordinates of the intersection rectangle # Determine the coordinates of the intersection rectangle
x_left = max(bbox1[0], bbox2[0]) x_left = max(bbox1[0], bbox2[0])
y_top = max(bbox1[1], bbox2[1]) y_top = max(bbox1[1], bbox2[1])
...@@ -209,16 +205,16 @@ def calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2): ...@@ -209,16 +205,16 @@ def calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2):
# The area of overlap area # The area of overlap area
intersection_area = (x_right - x_left) * (y_bottom - y_top) intersection_area = (x_right - x_left) * (y_bottom - y_top)
min_box_area = min([(bbox1[2]-bbox1[0])*(bbox1[3]-bbox1[1]), (bbox2[3]-bbox2[1])*(bbox2[2]-bbox2[0])]) min_box_area = min([(bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]),
if min_box_area==0: (bbox2[3] - bbox2[1]) * (bbox2[2] - bbox2[0])])
if min_box_area == 0:
return 0 return 0
else: else:
return intersection_area / min_box_area return intersection_area / min_box_area
def calculate_overlap_area_in_bbox1_area_ratio(bbox1, bbox2): def calculate_overlap_area_in_bbox1_area_ratio(bbox1, bbox2):
""" """计算box1和box2的重叠面积占bbox1的比例."""
计算box1和box2的重叠面积占bbox1的比例
"""
# Determine the coordinates of the intersection rectangle # Determine the coordinates of the intersection rectangle
x_left = max(bbox1[0], bbox2[0]) x_left = max(bbox1[0], bbox2[0])
y_top = max(bbox1[1], bbox2[1]) y_top = max(bbox1[1], bbox2[1])
...@@ -230,7 +226,7 @@ def calculate_overlap_area_in_bbox1_area_ratio(bbox1, bbox2): ...@@ -230,7 +226,7 @@ def calculate_overlap_area_in_bbox1_area_ratio(bbox1, bbox2):
# The area of overlap area # The area of overlap area
intersection_area = (x_right - x_left) * (y_bottom - y_top) intersection_area = (x_right - x_left) * (y_bottom - y_top)
bbox1_area = (bbox1[2]-bbox1[0])*(bbox1[3]-bbox1[1]) bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
if bbox1_area == 0: if bbox1_area == 0:
return 0 return 0
else: else:
...@@ -238,11 +234,8 @@ def calculate_overlap_area_in_bbox1_area_ratio(bbox1, bbox2): ...@@ -238,11 +234,8 @@ def calculate_overlap_area_in_bbox1_area_ratio(bbox1, bbox2):
def get_minbox_if_overlap_by_ratio(bbox1, bbox2, ratio): def get_minbox_if_overlap_by_ratio(bbox1, bbox2, ratio):
""" """通过calculate_overlap_area_2_minbox_area_ratio计算两个bbox重叠的面积占最小面积的box的比例
通过calculate_overlap_area_2_minbox_area_ratio计算两个bbox重叠的面积占最小面积的box的比例 如果比例大于ratio,则返回小的那个bbox, 否则返回None."""
如果比例大于ratio,则返回小的那个bbox,
否则返回None
"""
x1_min, y1_min, x1_max, y1_max = bbox1 x1_min, y1_min, x1_max, y1_max = bbox1
x2_min, y2_min, x2_max, y2_max = bbox2 x2_min, y2_min, x2_max, y2_max = bbox2
area1 = (x1_max - x1_min) * (y1_max - y1_min) area1 = (x1_max - x1_min) * (y1_max - y1_min)
...@@ -256,89 +249,118 @@ def get_minbox_if_overlap_by_ratio(bbox1, bbox2, ratio): ...@@ -256,89 +249,118 @@ def get_minbox_if_overlap_by_ratio(bbox1, bbox2, ratio):
else: else:
return None return None
def get_bbox_in_boundary(bboxes: list, boundary: tuple) -> list:
    """Return the bboxes that lie completely inside a boundary rectangle.

    Args:
        bboxes: Boxes indexable as box[0..3] = x0, y0, x1, y1.
        boundary: Enclosing rectangle as (x0, y0, x1, y1).

    Returns:
        A new list with the boxes whose four edges are on or within the
        boundary, in their original order.
    """
    x0, y0, x1, y1 = boundary
    new_boxes = [
        box for box in bboxes
        if box[0] >= x0 and box[1] >= y0 and box[2] <= x1 and box[3] <= y1
    ]
    return new_boxes


# Former (misspelled) name, kept so existing callers keep working.
get_bbox_in_boundry = get_bbox_in_boundary
def is_vbox_on_side(bbox, width, height, side_threshold=0.2):
    """Return whether a bbox sits in the left or right margin band of a page.

    Args:
        bbox: Box indexable as bbox[0] = x0 and bbox[2] = x1.
        width: Page width.
        height: Page height. Not used by the check; kept for the existing
            call signature.
        side_threshold: Fraction of the page width that counts as a margin
            band on each side.

    Returns:
        True when the box ends within the left band (x1 <= width * threshold)
        or starts within the right band (x0 >= width * (1 - threshold)).
    """
    x0, x1 = bbox[0], bbox[2]
    return x1 <= width * side_threshold or x0 >= width * (1 - side_threshold)
def find_top_nearest_text_bbox(pymu_blocks, obj_bbox):
    """Find the block that ends closest above ``obj_bbox`` and overlaps it in x.

    Args:
        pymu_blocks: Iterable of dicts, each with a 'bbox' entry indexable as
            bbox[0..3] = x0, y0, x1, y1.
        obj_bbox: Reference box as (x0, y0, x1, y1).

    Returns:
        The candidate block with the largest bbox y1 (the first one on ties),
        or None when no block qualifies.
    """
    tolerance_margin = 4
    obj_x0, obj_y0, obj_x1 = obj_bbox[0], obj_bbox[1], obj_bbox[2]

    def _overlaps_in_x(bbox):
        # An end point of either x range falls inside the other range,
        # widened by the tolerance.
        return (obj_x0 - tolerance_margin <= bbox[0] <= obj_x1 + tolerance_margin
                or obj_x0 - tolerance_margin <= bbox[2] <= obj_x1 + tolerance_margin
                or bbox[0] - tolerance_margin <= obj_x0 <= bbox[2] + tolerance_margin
                or bbox[0] - tolerance_margin <= obj_x1 <= bbox[2] + tolerance_margin)

    top_boxes = [
        box for box in pymu_blocks
        # The block ends at or before the top of obj_bbox (within tolerance) ...
        if obj_y0 - box['bbox'][3] >= -tolerance_margin
        # ... is not fully contained in obj_bbox ...
        and not _is_in(box['bbox'], obj_bbox)
        # ... and shares some horizontal extent with it.
        and _overlaps_in_x(box['bbox'])
    ]

    if not top_boxes:
        return None
    # Nearest above = largest y1; max() keeps the first block on ties, like a
    # stable descending sort followed by [0].
    return max(top_boxes, key=lambda box: box['bbox'][3])
def find_bottom_nearest_text_bbox(pymu_blocks, obj_bbox):
    """Find the block that starts closest below ``obj_bbox`` and overlaps it in x.

    Args:
        pymu_blocks: Iterable of dicts, each with a 'bbox' entry indexable as
            bbox[0..3] = x0, y0, x1, y1.
        obj_bbox: Reference box as (x0, y0, x1, y1).

    Returns:
        The candidate block with the smallest bbox y0 (the first one on
        ties), or None when no block qualifies.
    """
    tolerance_margin = 2
    obj_x0, obj_x1, obj_y1 = obj_bbox[0], obj_bbox[2], obj_bbox[3]

    def _overlaps_in_x(bbox):
        # An end point of either x range falls inside the other range,
        # widened by the tolerance.
        return (obj_x0 - tolerance_margin <= bbox[0] <= obj_x1 + tolerance_margin
                or obj_x0 - tolerance_margin <= bbox[2] <= obj_x1 + tolerance_margin
                or bbox[0] - tolerance_margin <= obj_x0 <= bbox[2] + tolerance_margin
                or bbox[0] - tolerance_margin <= obj_x1 <= bbox[2] + tolerance_margin)

    bottom_boxes = [
        box for box in pymu_blocks
        # The block starts at or after the bottom of obj_bbox (within tolerance) ...
        if box['bbox'][1] - obj_y1 >= -tolerance_margin
        # ... is not fully contained in obj_bbox ...
        and not _is_in(box['bbox'], obj_bbox)
        # ... and shares some horizontal extent with it.
        and _overlaps_in_x(box['bbox'])
    ]

    if not bottom_boxes:
        return None
    # Nearest below = smallest y0; min() keeps the first block on ties, like a
    # stable ascending sort followed by [0].
    return min(bottom_boxes, key=lambda box: box['bbox'][1])
def find_left_nearest_text_bbox(pymu_blocks, obj_bbox):
    """Find the block that ends closest to the left of ``obj_bbox`` and overlaps it in y.

    Args:
        pymu_blocks: Iterable of dicts, each with a 'bbox' entry indexable as
            bbox[0..3] = x0, y0, x1, y1.
        obj_bbox: Reference box as (x0, y0, x1, y1).

    Returns:
        The candidate block with the largest bbox x1 (the first one on ties),
        or None when no block qualifies.
    """
    tolerance_margin = 2
    obj_x0, obj_y0, obj_y1 = obj_bbox[0], obj_bbox[1], obj_bbox[3]

    def _overlaps_in_y(bbox):
        # An end point of either y range falls inside the other range,
        # widened by the tolerance.
        return (obj_y0 - tolerance_margin <= bbox[1] <= obj_y1 + tolerance_margin
                or obj_y0 - tolerance_margin <= bbox[3] <= obj_y1 + tolerance_margin
                or bbox[1] - tolerance_margin <= obj_y0 <= bbox[3] + tolerance_margin
                or bbox[1] - tolerance_margin <= obj_y1 <= bbox[3] + tolerance_margin)

    left_boxes = [
        box for box in pymu_blocks
        # The block ends at or before the left edge of obj_bbox (within tolerance) ...
        if obj_x0 - box['bbox'][2] >= -tolerance_margin
        # ... is not fully contained in obj_bbox ...
        and not _is_in(box['bbox'], obj_bbox)
        # ... and shares some vertical extent with it.
        and _overlaps_in_y(box['bbox'])
    ]

    if not left_boxes:
        return None
    # Nearest on the left = largest x1; max() keeps the first block on ties,
    # like a stable descending sort followed by [0].
    return max(left_boxes, key=lambda box: box['bbox'][2])
def find_right_nearest_text_bbox(pymu_blocks, obj_bbox): def find_right_nearest_text_bbox(pymu_blocks, obj_bbox):
""" """寻找右侧最近的文本block."""
寻找右侧最近的文本block right_boxes = [
""" box for box in pymu_blocks if box['bbox'][0] -
right_boxes = [box for box in pymu_blocks if box['bbox'][0]-obj_bbox[2]>=-2 and not _is_in(box['bbox'], obj_bbox)] obj_bbox[2] >= -2 and not _is_in(box['bbox'], obj_bbox)
]
# 然后找到X方向上有互相重叠的 # 然后找到X方向上有互相重叠的
right_boxes = [box for box in right_boxes if any([obj_bbox[1]-2 <=box['bbox'][1]<=obj_bbox[3]+2, right_boxes = [
obj_bbox[1]-2 <=box['bbox'][3]<=obj_bbox[3]+2, box for box in right_boxes if any([
box['bbox'][1]-2 <=obj_bbox[1]<=box['bbox'][3]+2, obj_bbox[1] - 2 <= box['bbox'][1] <= obj_bbox[3] + 2, obj_bbox[1] -
box['bbox'][1]-2 <=obj_bbox[3]<=box['bbox'][3]+2 2 <= box['bbox'][3] <= obj_bbox[3] + 2, box['bbox'][1] -
])] 2 <= obj_bbox[1] <= box['bbox'][3] + 2, box['bbox'][1] -
2 <= obj_bbox[3] <= box['bbox'][3] + 2
])
]
# 然后找到x0最小的那个 # 然后找到x0最小的那个
if len(right_boxes)>0: if len(right_boxes) > 0:
right_boxes.sort(key=lambda x: x['bbox'][0], reverse=False) right_boxes.sort(key=lambda x: x['bbox'][0], reverse=False)
return right_boxes[0] return right_boxes[0]
else: else:
...@@ -346,8 +368,7 @@ def find_right_nearest_text_bbox(pymu_blocks, obj_bbox): ...@@ -346,8 +368,7 @@ def find_right_nearest_text_bbox(pymu_blocks, obj_bbox):
def bbox_relative_pos(bbox1, bbox2):
    """Report on which sides two rectangles are strictly separated.

    Args:
        bbox1: First rectangle as (x1, y1, x1b, y1b): one corner followed by
            the opposite corner, with x1 <= x1b and y1 <= y1b expected.
        bbox2: Second rectangle as (x2, y2, x2b, y2b), same convention.

    Returns:
        A tuple ``(left, right, bottom, top)`` of booleans:

        * left:   bbox2 ends before bbox1 starts on x (``x2b < x1``).
        * right:  bbox1 ends before bbox2 starts on x (``x1b < x2``).
        * bottom: bbox2 ends before bbox1 starts on y (``y2b < y1``).
        * top:    bbox1 ends before bbox2 starts on y (``y1b < y2``).

        All four are False when the rectangles overlap or touch on both axes.
        Note that ``left``/``right`` describe where bbox2 sits relative to
        bbox1's left/right edges, while ``bottom``/``top`` read as "bbox1 is
        below/above bbox2" when y grows downwards (assumed page convention;
        confirm against callers).
    """
    x1, y1, x1b, y1b = bbox1
    x2, y2, x2b, y2b = bbox2

    left = x2b < x1
    right = x1b < x2
    bottom = y2b < y1
    top = y1b < y2
    return left, right, bottom, top
def bbox_distance(bbox1, bbox2): def bbox_distance(bbox1, bbox2):
""" """计算两个矩形框的距离。
计算两个矩形框的距离。
Args: Args:
bbox1 (tuple): 第一个矩形框的坐标,格式为 (x1, y1, x2, y2),其中 (x1, y1) 为左上角坐标,(x2, y2) 为右下角坐标。 bbox1 (tuple): 第一个矩形框的坐标,格式为 (x1, y1, x2, y2),其中 (x1, y1) 为左上角坐标,(x2, y2) 为右下角坐标。
...@@ -378,16 +398,17 @@ def bbox_distance(bbox1, bbox2): ...@@ -378,16 +398,17 @@ def bbox_distance(bbox1, bbox2):
Returns: Returns:
float: 矩形框之间的距离。 float: 矩形框之间的距离。
""" """
def dist(point1, point2): def dist(point1, point2):
return math.sqrt((point1[0]-point2[0])**2 + (point1[1]-point2[1])**2) return math.sqrt((point1[0] - point2[0])**2 +
(point1[1] - point2[1])**2)
x1, y1, x1b, y1b = bbox1 x1, y1, x1b, y1b = bbox1
x2, y2, x2b, y2b = bbox2 x2, y2, x2b, y2b = bbox2
left, right, bottom, top = bbox_relative_pos(bbox1, bbox2) left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
if top and left: if top and left:
return dist((x1, y1b), (x2b, y2)) return dist((x1, y1b), (x2b, y2))
elif left and bottom: elif left and bottom:
...@@ -404,5 +425,4 @@ def bbox_distance(bbox1, bbox2): ...@@ -404,5 +425,4 @@ def bbox_distance(bbox1, bbox2):
return y1 - y2b return y1 - y2b
elif top: elif top:
return y2 - y1b return y2 - y1b
else: # rectangles intersect return 0.0
return 0
\ No newline at end of file
...@@ -71,6 +71,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -71,6 +71,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
tables_list, tables_body_list = [], [] tables_list, tables_body_list = [], []
tables_caption_list, tables_footnote_list = [], [] tables_caption_list, tables_footnote_list = [], []
imgs_list, imgs_body_list, imgs_caption_list = [], [], [] imgs_list, imgs_body_list, imgs_caption_list = [], [], []
imgs_footnote_list = []
titles_list = [] titles_list = []
texts_list = [] texts_list = []
interequations_list = [] interequations_list = []
...@@ -78,7 +79,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -78,7 +79,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
page_layout_list = [] page_layout_list = []
page_dropped_list = [] page_dropped_list = []
tables, tables_body, tables_caption, tables_footnote = [], [], [], [] tables, tables_body, tables_caption, tables_footnote = [], [], [], []
imgs, imgs_body, imgs_caption = [], [], [] imgs, imgs_body, imgs_caption, imgs_footnote = [], [], [], []
titles = [] titles = []
texts = [] texts = []
interequations = [] interequations = []
...@@ -108,6 +109,8 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -108,6 +109,8 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
imgs_body.append(bbox) imgs_body.append(bbox)
elif nested_block['type'] == BlockType.ImageCaption: elif nested_block['type'] == BlockType.ImageCaption:
imgs_caption.append(bbox) imgs_caption.append(bbox)
elif nested_block['type'] == BlockType.ImageFootnote:
imgs_footnote.append(bbox)
elif block['type'] == BlockType.Title: elif block['type'] == BlockType.Title:
titles.append(bbox) titles.append(bbox)
elif block['type'] == BlockType.Text: elif block['type'] == BlockType.Text:
...@@ -121,6 +124,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -121,6 +124,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
imgs_list.append(imgs) imgs_list.append(imgs)
imgs_body_list.append(imgs_body) imgs_body_list.append(imgs_body)
imgs_caption_list.append(imgs_caption) imgs_caption_list.append(imgs_caption)
imgs_footnote_list.append(imgs_footnote)
titles_list.append(titles) titles_list.append(titles)
texts_list.append(texts) texts_list.append(texts)
interequations_list.append(interequations) interequations_list.append(interequations)
...@@ -142,6 +146,8 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -142,6 +146,8 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
draw_bbox_without_number(i, imgs_body_list, page, [153, 255, 51], True) draw_bbox_without_number(i, imgs_body_list, page, [153, 255, 51], True)
draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255], draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255],
True) True)
draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102],
True),
draw_bbox_without_number(i, titles_list, page, [102, 102, 255], True) draw_bbox_without_number(i, titles_list, page, [102, 102, 255], True)
draw_bbox_without_number(i, texts_list, page, [153, 0, 76], True) draw_bbox_without_number(i, texts_list, page, [153, 0, 76], True)
draw_bbox_without_number(i, interequations_list, page, [0, 255, 0], draw_bbox_without_number(i, interequations_list, page, [0, 255, 0],
...@@ -241,7 +247,7 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -241,7 +247,7 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename): def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
dropped_bbox_list = [] dropped_bbox_list = []
tables_body_list, tables_caption_list, tables_footnote_list = [], [], [] tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
imgs_body_list, imgs_caption_list = [], [] imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], []
titles_list = [] titles_list = []
texts_list = [] texts_list = []
interequations_list = [] interequations_list = []
...@@ -250,7 +256,7 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename): ...@@ -250,7 +256,7 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
for i in range(len(model_list)): for i in range(len(model_list)):
page_dropped_list = [] page_dropped_list = []
tables_body, tables_caption, tables_footnote = [], [], [] tables_body, tables_caption, tables_footnote = [], [], []
imgs_body, imgs_caption = [], [] imgs_body, imgs_caption, imgs_footnote = [], [], []
titles = [] titles = []
texts = [] texts = []
interequations = [] interequations = []
...@@ -277,6 +283,8 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename): ...@@ -277,6 +283,8 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
interequations.append(bbox) interequations.append(bbox)
elif layout_det['category_id'] == CategoryId.Abandon: elif layout_det['category_id'] == CategoryId.Abandon:
page_dropped_list.append(bbox) page_dropped_list.append(bbox)
elif layout_det['category_id'] == CategoryId.ImageFootnote:
imgs_footnote.append(bbox)
tables_body_list.append(tables_body) tables_body_list.append(tables_body)
tables_caption_list.append(tables_caption) tables_caption_list.append(tables_caption)
...@@ -287,6 +295,7 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename): ...@@ -287,6 +295,7 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
texts_list.append(texts) texts_list.append(texts)
interequations_list.append(interequations) interequations_list.append(interequations)
dropped_bbox_list.append(page_dropped_list) dropped_bbox_list.append(page_dropped_list)
imgs_footnote_list.append(imgs_footnote)
for i, page in enumerate(pdf_docs): for i, page in enumerate(pdf_docs):
draw_bbox_with_number(i, dropped_bbox_list, page, [158, 158, 158], draw_bbox_with_number(i, dropped_bbox_list, page, [158, 158, 158],
...@@ -299,6 +308,8 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename): ...@@ -299,6 +308,8 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
draw_bbox_with_number(i, imgs_body_list, page, [153, 255, 51], True) draw_bbox_with_number(i, imgs_body_list, page, [153, 255, 51], True)
draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255], draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255],
True) True)
draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102],
True)
draw_bbox_with_number(i, titles_list, page, [102, 102, 255], True) draw_bbox_with_number(i, titles_list, page, [102, 102, 255], True)
draw_bbox_with_number(i, texts_list, page, [153, 0, 76], True) draw_bbox_with_number(i, texts_list, page, [153, 0, 76], True)
draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True) draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True)
......
class ContentType:
    """String tags for the kinds of content a span can hold."""

    Image = 'image'
    Table = 'table'
    Text = 'text'
    InlineEquation = 'inline_equation'
    InterlineEquation = 'interline_equation'
class BlockType:
    """String tags for layout block types, including image/table sub-blocks."""

    # Image block and its nested parts.
    Image = 'image'
    ImageBody = 'image_body'
    ImageCaption = 'image_caption'
    ImageFootnote = 'image_footnote'
    # Table block and its nested parts.
    Table = 'table'
    TableBody = 'table_body'
    TableCaption = 'table_caption'
    TableFootnote = 'table_footnote'
    # Other block types.
    Text = 'text'
    Title = 'title'
    InterlineEquation = 'interline_equation'
    Footnote = 'footnote'
    Discarded = 'discarded'
class CategoryId: class CategoryId:
...@@ -33,3 +35,4 @@ class CategoryId: ...@@ -33,3 +35,4 @@ class CategoryId:
InlineEquation = 13 InlineEquation = 13
InterlineEquation_YOLO = 14 InterlineEquation_YOLO = 14
OcrText = 15 OcrText = 15
ImageFootnote = 101
import json import json
import math
from magic_pdf.libs.commons import fitz from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
from loguru import logger bbox_relative_pos, calculate_iou,
calculate_overlap_area_in_bbox1_area_ratio)
from magic_pdf.libs.commons import join_path from magic_pdf.libs.commons import fitz, join_path
from magic_pdf.libs.coordinate_transform import get_scale_ratio from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.libs.ocr_content_type import ContentType
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
from magic_pdf.libs.local_math import float_gt from magic_pdf.libs.local_math import float_gt
from magic_pdf.libs.boxbase import (
_is_in,
bbox_relative_pos,
bbox_distance,
_is_part_overlap,
calculate_overlap_area_in_bbox1_area_ratio,
calculate_iou,
)
from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
from magic_pdf.libs.ocr_content_type import CategoryId, ContentType
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
CAPATION_OVERLAP_AREA_RATIO = 0.6 CAPATION_OVERLAP_AREA_RATIO = 0.6
class MagicModel: class MagicModel:
""" """每个函数没有得到元素的时候返回空list."""
每个函数没有得到元素的时候返回空list
"""
def __fix_axis(self): def __fix_axis(self):
for model_page_info in self.__model_list: for model_page_info in self.__model_list:
need_remove_list = [] need_remove_list = []
page_no = model_page_info["page_info"]["page_no"] page_no = model_page_info['page_info']['page_no']
horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio( horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
model_page_info, self.__docs[page_no] model_page_info, self.__docs[page_no]
) )
layout_dets = model_page_info["layout_dets"] layout_dets = model_page_info['layout_dets']
for layout_det in layout_dets: for layout_det in layout_dets:
if layout_det.get("bbox") is not None: if layout_det.get('bbox') is not None:
# 兼容直接输出bbox的模型数据,如paddle # 兼容直接输出bbox的模型数据,如paddle
x0, y0, x1, y1 = layout_det["bbox"] x0, y0, x1, y1 = layout_det['bbox']
else: else:
# 兼容直接输出poly的模型数据,如xxx # 兼容直接输出poly的模型数据,如xxx
x0, y0, _, _, x1, y1, _, _ = layout_det["poly"] x0, y0, _, _, x1, y1, _, _ = layout_det['poly']
bbox = [ bbox = [
int(x0 / horizontal_scale_ratio), int(x0 / horizontal_scale_ratio),
...@@ -52,7 +40,7 @@ class MagicModel: ...@@ -52,7 +40,7 @@ class MagicModel:
int(x1 / horizontal_scale_ratio), int(x1 / horizontal_scale_ratio),
int(y1 / vertical_scale_ratio), int(y1 / vertical_scale_ratio),
] ]
layout_det["bbox"] = bbox layout_det['bbox'] = bbox
# 删除高度或者宽度小于等于0的spans # 删除高度或者宽度小于等于0的spans
if bbox[2] - bbox[0] <= 0 or bbox[3] - bbox[1] <= 0: if bbox[2] - bbox[0] <= 0 or bbox[3] - bbox[1] <= 0:
need_remove_list.append(layout_det) need_remove_list.append(layout_det)
...@@ -62,9 +50,9 @@ class MagicModel: ...@@ -62,9 +50,9 @@ class MagicModel:
def __fix_by_remove_low_confidence(self): def __fix_by_remove_low_confidence(self):
for model_page_info in self.__model_list: for model_page_info in self.__model_list:
need_remove_list = [] need_remove_list = []
layout_dets = model_page_info["layout_dets"] layout_dets = model_page_info['layout_dets']
for layout_det in layout_dets: for layout_det in layout_dets:
if layout_det["score"] <= 0.05: if layout_det['score'] <= 0.05:
need_remove_list.append(layout_det) need_remove_list.append(layout_det)
else: else:
continue continue
...@@ -74,12 +62,12 @@ class MagicModel: ...@@ -74,12 +62,12 @@ class MagicModel:
def __fix_by_remove_high_iou_and_low_confidence(self): def __fix_by_remove_high_iou_and_low_confidence(self):
for model_page_info in self.__model_list: for model_page_info in self.__model_list:
need_remove_list = [] need_remove_list = []
layout_dets = model_page_info["layout_dets"] layout_dets = model_page_info['layout_dets']
for layout_det1 in layout_dets: for layout_det1 in layout_dets:
for layout_det2 in layout_dets: for layout_det2 in layout_dets:
if layout_det1 == layout_det2: if layout_det1 == layout_det2:
continue continue
if layout_det1["category_id"] in [ if layout_det1['category_id'] in [
0, 0,
1, 1,
2, 2,
...@@ -90,12 +78,12 @@ class MagicModel: ...@@ -90,12 +78,12 @@ class MagicModel:
7, 7,
8, 8,
9, 9,
] and layout_det2["category_id"] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]: ] and layout_det2['category_id'] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
if ( if (
calculate_iou(layout_det1["bbox"], layout_det2["bbox"]) calculate_iou(layout_det1['bbox'], layout_det2['bbox'])
> 0.9 > 0.9
): ):
if layout_det1["score"] < layout_det2["score"]: if layout_det1['score'] < layout_det2['score']:
layout_det_need_remove = layout_det1 layout_det_need_remove = layout_det1
else: else:
layout_det_need_remove = layout_det2 layout_det_need_remove = layout_det2
...@@ -118,6 +106,67 @@ class MagicModel: ...@@ -118,6 +106,67 @@ class MagicModel:
self.__fix_by_remove_low_confidence() self.__fix_by_remove_low_confidence()
"""删除高iou(>0.9)数据中置信度较低的那个""" """删除高iou(>0.9)数据中置信度较低的那个"""
self.__fix_by_remove_high_iou_and_low_confidence() self.__fix_by_remove_high_iou_and_low_confidence()
self.__fix_footnote()
def __fix_footnote(self):
    """Re-label footnotes that lie closer to a figure than to any table.

    For every page in the model list, footnote detections (category 7) are
    compared with the figure (category 3) and table (category 5) detections
    of the same page. A figure/table is only eligible for a footnote when
    ``bbox_relative_pos`` reports at most one truthy flag for the pair;
    pairs with two or more flags set are ignored. When the nearest eligible
    figure is strictly closer than the nearest eligible table (or no table
    is eligible), the footnote's ``category_id`` is rewritten in place to
    ``CategoryId.ImageFootnote``.

    Footnotes with no eligible figure are left untouched. Previously this
    case raised ``KeyError`` because the figure-distance map had no entry
    for that footnote.
    """
    no_match = float('inf')

    def nearest_distance(footnote_bbox, candidates):
        """Return the smallest distance to an eligible candidate, or inf."""
        best = no_match
        for candidate in candidates:
            pos_flag_count = sum(
                1
                for flag in bbox_relative_pos(footnote_bbox, candidate['bbox'])
                if flag
            )
            # More than one flag set: skip this pair
            # (presumably a diagonal placement -- TODO confirm).
            if pos_flag_count > 1:
                continue
            best = min(best, bbox_distance(candidate['bbox'], footnote_bbox))
        return best

    # 3: figure, 5: table, 7: footnote
    for model_page_info in self.__model_list:
        footnotes = []
        figures = []
        tables = []
        for obj in model_page_info['layout_dets']:
            if obj['category_id'] == 7:
                footnotes.append(obj)
            elif obj['category_id'] == 3:
                figures.append(obj)
            elif obj['category_id'] == 5:
                tables.append(obj)

        # Nothing to re-label unless the page has both footnotes and figures.
        if not footnotes or not figures:
            continue

        for footnote in footnotes:
            figure_dis = nearest_distance(footnote['bbox'], figures)
            if figure_dis == no_match:
                # No eligible figure for this footnote: keep its category.
                continue
            table_dis = nearest_distance(footnote['bbox'], tables)
            if table_dis > figure_dis:
                footnote['category_id'] = CategoryId.ImageFootnote
def __reduct_overlap(self, bboxes): def __reduct_overlap(self, bboxes):
N = len(bboxes) N = len(bboxes)
...@@ -126,76 +175,77 @@ class MagicModel: ...@@ -126,76 +175,77 @@ class MagicModel:
for j in range(N): for j in range(N):
if i == j: if i == j:
continue continue
if _is_in(bboxes[i]["bbox"], bboxes[j]["bbox"]): if _is_in(bboxes[i]['bbox'], bboxes[j]['bbox']):
keep[i] = False keep[i] = False
return [bboxes[i] for i in range(N) if keep[i]] return [bboxes[i] for i in range(N) if keep[i]]
def __tie_up_category_by_distance( def __tie_up_category_by_distance(
self, page_no, subject_category_id, object_category_id self, page_no, subject_category_id, object_category_id
): ):
""" """假定每个 subject 最多有一个 object (可以有多个相邻的 object 合并为单个 object),每个 object
假定每个 subject 最多有一个 object (可以有多个相邻的 object 合并为单个 object),每个 object 只能属于一个 subject 只能属于一个 subject."""
"""
ret = [] ret = []
MAX_DIS_OF_POINT = 10**9 + 7 MAX_DIS_OF_POINT = 10**9 + 7
"""
subject 和 object 的 bbox 会合并成一个大的 bbox (named: merged bbox)。
筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
再求出筛选出的 subjects 和 object 的最短距离
"""
# subject 和 object 的 bbox 会合并成一个大的 bbox (named: merged bbox)。 筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
# 再求出筛选出的 subjects 和 object 的最短距离!
def may_find_other_nearest_bbox(subject_idx, object_idx): def may_find_other_nearest_bbox(subject_idx, object_idx):
ret = float("inf") ret = float('inf')
x0 = min( x0 = min(
all_bboxes[subject_idx]["bbox"][0], all_bboxes[object_idx]["bbox"][0] all_bboxes[subject_idx]['bbox'][0], all_bboxes[object_idx]['bbox'][0]
) )
y0 = min( y0 = min(
all_bboxes[subject_idx]["bbox"][1], all_bboxes[object_idx]["bbox"][1] all_bboxes[subject_idx]['bbox'][1], all_bboxes[object_idx]['bbox'][1]
) )
x1 = max( x1 = max(
all_bboxes[subject_idx]["bbox"][2], all_bboxes[object_idx]["bbox"][2] all_bboxes[subject_idx]['bbox'][2], all_bboxes[object_idx]['bbox'][2]
) )
y1 = max( y1 = max(
all_bboxes[subject_idx]["bbox"][3], all_bboxes[object_idx]["bbox"][3] all_bboxes[subject_idx]['bbox'][3], all_bboxes[object_idx]['bbox'][3]
) )
object_area = abs( object_area = abs(
all_bboxes[object_idx]["bbox"][2] - all_bboxes[object_idx]["bbox"][0] all_bboxes[object_idx]['bbox'][2] - all_bboxes[object_idx]['bbox'][0]
) * abs( ) * abs(
all_bboxes[object_idx]["bbox"][3] - all_bboxes[object_idx]["bbox"][1] all_bboxes[object_idx]['bbox'][3] - all_bboxes[object_idx]['bbox'][1]
) )
for i in range(len(all_bboxes)): for i in range(len(all_bboxes)):
if ( if (
i == subject_idx i == subject_idx
or all_bboxes[i]["category_id"] != subject_category_id or all_bboxes[i]['category_id'] != subject_category_id
): ):
continue continue
if _is_part_overlap([x0, y0, x1, y1], all_bboxes[i]["bbox"]) or _is_in( if _is_part_overlap([x0, y0, x1, y1], all_bboxes[i]['bbox']) or _is_in(
all_bboxes[i]["bbox"], [x0, y0, x1, y1] all_bboxes[i]['bbox'], [x0, y0, x1, y1]
): ):
i_area = abs( i_area = abs(
all_bboxes[i]["bbox"][2] - all_bboxes[i]["bbox"][0] all_bboxes[i]['bbox'][2] - all_bboxes[i]['bbox'][0]
) * abs(all_bboxes[i]["bbox"][3] - all_bboxes[i]["bbox"][1]) ) * abs(all_bboxes[i]['bbox'][3] - all_bboxes[i]['bbox'][1])
if i_area >= object_area: if i_area >= object_area:
ret = min(float("inf"), dis[i][object_idx]) ret = min(float('inf'), dis[i][object_idx])
return ret return ret
def expand_bbbox(idxes): def expand_bbbox(idxes):
x0s = [all_bboxes[idx]["bbox"][0] for idx in idxes] x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
y0s = [all_bboxes[idx]["bbox"][1] for idx in idxes] y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
x1s = [all_bboxes[idx]["bbox"][2] for idx in idxes] x1s = [all_bboxes[idx]['bbox'][2] for idx in idxes]
y1s = [all_bboxes[idx]["bbox"][3] for idx in idxes] y1s = [all_bboxes[idx]['bbox'][3] for idx in idxes]
return min(x0s), min(y0s), max(x1s), max(y1s) return min(x0s), min(y0s), max(x1s), max(y1s)
subjects = self.__reduct_overlap( subjects = self.__reduct_overlap(
list( list(
map( map(
lambda x: {"bbox": x["bbox"], "score": x["score"]}, lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter( filter(
lambda x: x["category_id"] == subject_category_id, lambda x: x['category_id'] == subject_category_id,
self.__model_list[page_no]["layout_dets"], self.__model_list[page_no]['layout_dets'],
), ),
) )
) )
...@@ -204,10 +254,10 @@ class MagicModel: ...@@ -204,10 +254,10 @@ class MagicModel:
objects = self.__reduct_overlap( objects = self.__reduct_overlap(
list( list(
map( map(
lambda x: {"bbox": x["bbox"], "score": x["score"]}, lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter( filter(
lambda x: x["category_id"] == object_category_id, lambda x: x['category_id'] == object_category_id,
self.__model_list[page_no]["layout_dets"], self.__model_list[page_no]['layout_dets'],
), ),
) )
) )
...@@ -215,7 +265,7 @@ class MagicModel: ...@@ -215,7 +265,7 @@ class MagicModel:
subject_object_relation_map = {} subject_object_relation_map = {}
subjects.sort( subjects.sort(
key=lambda x: x["bbox"][0] ** 2 + x["bbox"][1] ** 2 key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2
) # get the distance ! ) # get the distance !
all_bboxes = [] all_bboxes = []
...@@ -223,18 +273,18 @@ class MagicModel: ...@@ -223,18 +273,18 @@ class MagicModel:
for v in subjects: for v in subjects:
all_bboxes.append( all_bboxes.append(
{ {
"category_id": subject_category_id, 'category_id': subject_category_id,
"bbox": v["bbox"], 'bbox': v['bbox'],
"score": v["score"], 'score': v['score'],
} }
) )
for v in objects: for v in objects:
all_bboxes.append( all_bboxes.append(
{ {
"category_id": object_category_id, 'category_id': object_category_id,
"bbox": v["bbox"], 'bbox': v['bbox'],
"score": v["score"], 'score': v['score'],
} }
) )
...@@ -244,18 +294,18 @@ class MagicModel: ...@@ -244,18 +294,18 @@ class MagicModel:
for i in range(N): for i in range(N):
for j in range(i): for j in range(i):
if ( if (
all_bboxes[i]["category_id"] == subject_category_id all_bboxes[i]['category_id'] == subject_category_id
and all_bboxes[j]["category_id"] == subject_category_id and all_bboxes[j]['category_id'] == subject_category_id
): ):
continue continue
dis[i][j] = bbox_distance(all_bboxes[i]["bbox"], all_bboxes[j]["bbox"]) dis[i][j] = bbox_distance(all_bboxes[i]['bbox'], all_bboxes[j]['bbox'])
dis[j][i] = dis[i][j] dis[j][i] = dis[i][j]
used = set() used = set()
for i in range(N): for i in range(N):
# 求第 i 个 subject 所关联的 object # 求第 i 个 subject 所关联的 object
if all_bboxes[i]["category_id"] != subject_category_id: if all_bboxes[i]['category_id'] != subject_category_id:
continue continue
seen = set() seen = set()
candidates = [] candidates = []
...@@ -267,7 +317,7 @@ class MagicModel: ...@@ -267,7 +317,7 @@ class MagicModel:
map( map(
lambda x: 1 if x else 0, lambda x: 1 if x else 0,
bbox_relative_pos( bbox_relative_pos(
all_bboxes[i]["bbox"], all_bboxes[j]["bbox"] all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
), ),
) )
) )
...@@ -275,25 +325,28 @@ class MagicModel: ...@@ -275,25 +325,28 @@ class MagicModel:
if pos_flag_count > 1: if pos_flag_count > 1:
continue continue
if ( if (
all_bboxes[j]["category_id"] != object_category_id all_bboxes[j]['category_id'] != object_category_id
or j in used or j in used
or dis[i][j] == MAX_DIS_OF_POINT or dis[i][j] == MAX_DIS_OF_POINT
): ):
continue continue
left, right, _, _ = bbox_relative_pos( left, right, _, _ = bbox_relative_pos(
all_bboxes[i]["bbox"], all_bboxes[j]["bbox"] all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
) # 由 pos_flag_count 相关逻辑保证本段逻辑准确性 ) # 由 pos_flag_count 相关逻辑保证本段逻辑准确性
if left or right: if left or right:
one_way_dis = all_bboxes[i]["bbox"][2] - all_bboxes[i]["bbox"][0] one_way_dis = all_bboxes[i]['bbox'][2] - all_bboxes[i]['bbox'][0]
else: else:
one_way_dis = all_bboxes[i]["bbox"][3] - all_bboxes[i]["bbox"][1] one_way_dis = all_bboxes[i]['bbox'][3] - all_bboxes[i]['bbox'][1]
if dis[i][j] > one_way_dis: if dis[i][j] > one_way_dis:
continue continue
arr.append((dis[i][j], j)) arr.append((dis[i][j], j))
arr.sort(key=lambda x: x[0]) arr.sort(key=lambda x: x[0])
if len(arr) > 0: if len(arr) > 0:
# bug: 离该subject 最近的 object 可能跨越了其它的 subject 。比如 [this subect] [some sbuject] [the nearest objec of subject] """
bug: 离该subject 最近的 object 可能跨越了其它的 subject。
比如 [this subect] [some sbuject] [the nearest object of subject]
"""
if may_find_other_nearest_bbox(i, arr[0][1]) >= arr[0][0]: if may_find_other_nearest_bbox(i, arr[0][1]) >= arr[0][0]:
candidates.append(arr[0][1]) candidates.append(arr[0][1])
...@@ -308,7 +361,7 @@ class MagicModel: ...@@ -308,7 +361,7 @@ class MagicModel:
map( map(
lambda x: 1 if x else 0, lambda x: 1 if x else 0,
bbox_relative_pos( bbox_relative_pos(
all_bboxes[j]["bbox"], all_bboxes[k]["bbox"] all_bboxes[j]['bbox'], all_bboxes[k]['bbox']
), ),
) )
) )
...@@ -318,7 +371,7 @@ class MagicModel: ...@@ -318,7 +371,7 @@ class MagicModel:
continue continue
if ( if (
all_bboxes[k]["category_id"] != object_category_id all_bboxes[k]['category_id'] != object_category_id
or k in used or k in used
or k in seen or k in seen
or dis[j][k] == MAX_DIS_OF_POINT or dis[j][k] == MAX_DIS_OF_POINT
...@@ -327,17 +380,19 @@ class MagicModel: ...@@ -327,17 +380,19 @@ class MagicModel:
continue continue
is_nearest = True is_nearest = True
for l in range(i + 1, N): for ni in range(i + 1, N):
if l in (j, k) or l in used or l in seen: if ni in (j, k) or ni in used or ni in seen:
continue continue
if not float_gt(dis[l][k], dis[j][k]): if not float_gt(dis[ni][k], dis[j][k]):
is_nearest = False is_nearest = False
break break
if is_nearest: if is_nearest:
nx0, ny0, nx1, ny1 = expand_bbbox(list(seen) + [k]) nx0, ny0, nx1, ny1 = expand_bbbox(list(seen) + [k])
n_dis = bbox_distance(all_bboxes[i]["bbox"], [nx0, ny0, nx1, ny1]) n_dis = bbox_distance(
all_bboxes[i]['bbox'], [nx0, ny0, nx1, ny1]
)
if float_gt(dis[i][j], n_dis): if float_gt(dis[i][j], n_dis):
continue continue
tmp.append(k) tmp.append(k)
...@@ -350,7 +405,7 @@ class MagicModel: ...@@ -350,7 +405,7 @@ class MagicModel:
# 已经获取到某个 figure 下所有的最靠近的 captions,以及最靠近这些 captions 的 captions 。 # 已经获取到某个 figure 下所有的最靠近的 captions,以及最靠近这些 captions 的 captions 。
# 先扩一下 bbox, # 先扩一下 bbox,
ox0, oy0, ox1, oy1 = expand_bbbox(list(seen) + [i]) ox0, oy0, ox1, oy1 = expand_bbbox(list(seen) + [i])
ix0, iy0, ix1, iy1 = all_bboxes[i]["bbox"] ix0, iy0, ix1, iy1 = all_bboxes[i]['bbox']
# 分成了 4 个截取空间,需要计算落在每个截取空间下 objects 合并后占据的矩形面积 # 分成了 4 个截取空间,需要计算落在每个截取空间下 objects 合并后占据的矩形面积
caption_poses = [ caption_poses = [
...@@ -366,17 +421,17 @@ class MagicModel: ...@@ -366,17 +421,17 @@ class MagicModel:
for idx in seen: for idx in seen:
if ( if (
calculate_overlap_area_in_bbox1_area_ratio( calculate_overlap_area_in_bbox1_area_ratio(
all_bboxes[idx]["bbox"], bbox all_bboxes[idx]['bbox'], bbox
) )
> CAPATION_OVERLAP_AREA_RATIO > CAPATION_OVERLAP_AREA_RATIO
): ):
embed_arr.append(idx) embed_arr.append(idx)
if len(embed_arr) > 0: if len(embed_arr) > 0:
embed_x0 = min([all_bboxes[idx]["bbox"][0] for idx in embed_arr]) embed_x0 = min([all_bboxes[idx]['bbox'][0] for idx in embed_arr])
embed_y0 = min([all_bboxes[idx]["bbox"][1] for idx in embed_arr]) embed_y0 = min([all_bboxes[idx]['bbox'][1] for idx in embed_arr])
embed_x1 = max([all_bboxes[idx]["bbox"][2] for idx in embed_arr]) embed_x1 = max([all_bboxes[idx]['bbox'][2] for idx in embed_arr])
embed_y1 = max([all_bboxes[idx]["bbox"][3] for idx in embed_arr]) embed_y1 = max([all_bboxes[idx]['bbox'][3] for idx in embed_arr])
caption_areas.append( caption_areas.append(
int(abs(embed_x1 - embed_x0) * abs(embed_y1 - embed_y0)) int(abs(embed_x1 - embed_x0) * abs(embed_y1 - embed_y0))
) )
...@@ -391,7 +446,7 @@ class MagicModel: ...@@ -391,7 +446,7 @@ class MagicModel:
for j in seen: for j in seen:
if ( if (
calculate_overlap_area_in_bbox1_area_ratio( calculate_overlap_area_in_bbox1_area_ratio(
all_bboxes[j]["bbox"], caption_bbox all_bboxes[j]['bbox'], caption_bbox
) )
> CAPATION_OVERLAP_AREA_RATIO > CAPATION_OVERLAP_AREA_RATIO
): ):
...@@ -400,30 +455,30 @@ class MagicModel: ...@@ -400,30 +455,30 @@ class MagicModel:
for i in sorted(subject_object_relation_map.keys()): for i in sorted(subject_object_relation_map.keys()):
result = { result = {
"subject_body": all_bboxes[i]["bbox"], 'subject_body': all_bboxes[i]['bbox'],
"all": all_bboxes[i]["bbox"], 'all': all_bboxes[i]['bbox'],
"score": all_bboxes[i]["score"], 'score': all_bboxes[i]['score'],
} }
if len(subject_object_relation_map[i]) > 0: if len(subject_object_relation_map[i]) > 0:
x0 = min( x0 = min(
[all_bboxes[j]["bbox"][0] for j in subject_object_relation_map[i]] [all_bboxes[j]['bbox'][0] for j in subject_object_relation_map[i]]
) )
y0 = min( y0 = min(
[all_bboxes[j]["bbox"][1] for j in subject_object_relation_map[i]] [all_bboxes[j]['bbox'][1] for j in subject_object_relation_map[i]]
) )
x1 = max( x1 = max(
[all_bboxes[j]["bbox"][2] for j in subject_object_relation_map[i]] [all_bboxes[j]['bbox'][2] for j in subject_object_relation_map[i]]
) )
y1 = max( y1 = max(
[all_bboxes[j]["bbox"][3] for j in subject_object_relation_map[i]] [all_bboxes[j]['bbox'][3] for j in subject_object_relation_map[i]]
) )
result["object_body"] = [x0, y0, x1, y1] result['object_body'] = [x0, y0, x1, y1]
result["all"] = [ result['all'] = [
min(x0, all_bboxes[i]["bbox"][0]), min(x0, all_bboxes[i]['bbox'][0]),
min(y0, all_bboxes[i]["bbox"][1]), min(y0, all_bboxes[i]['bbox'][1]),
max(x1, all_bboxes[i]["bbox"][2]), max(x1, all_bboxes[i]['bbox'][2]),
max(y1, all_bboxes[i]["bbox"][3]), max(y1, all_bboxes[i]['bbox'][3]),
] ]
ret.append(result) ret.append(result)
...@@ -432,7 +487,7 @@ class MagicModel: ...@@ -432,7 +487,7 @@ class MagicModel:
for i in subject_object_relation_map.keys(): for i in subject_object_relation_map.keys():
for j in subject_object_relation_map[i]: for j in subject_object_relation_map[i]:
total_subject_object_dis += bbox_distance( total_subject_object_dis += bbox_distance(
all_bboxes[i]["bbox"], all_bboxes[j]["bbox"] all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
) )
# 计算未匹配的 subject 和 object 的距离(非精确版) # 计算未匹配的 subject 和 object 的距离(非精确版)
...@@ -444,12 +499,12 @@ class MagicModel: ...@@ -444,12 +499,12 @@ class MagicModel:
] ]
) )
for i in range(N): for i in range(N):
if all_bboxes[i]["category_id"] != object_category_id or i in used: if all_bboxes[i]['category_id'] != object_category_id or i in used:
continue continue
candidates = [] candidates = []
for j in range(N): for j in range(N):
if ( if (
all_bboxes[j]["category_id"] != subject_category_id all_bboxes[j]['category_id'] != subject_category_id
or j in with_caption_subject or j in with_caption_subject
): ):
continue continue
...@@ -461,18 +516,28 @@ class MagicModel: ...@@ -461,18 +516,28 @@ class MagicModel:
return ret, total_subject_object_dis return ret, total_subject_object_dis
def get_imgs(self, page_no: int):
    """Build one record per figure on ``page_no`` with caption and footnote.

    Figures (category 3) are tied up twice: once with captions (category 4)
    and once with image footnotes (``CategoryId.ImageFootnote``). The two
    result lists are paired by position.

    Returns:
        A list of dicts with keys ``score``, ``img_caption_bbox``,
        ``img_body_bbox``, ``img_footnote_bbox`` and ``bbox``. The caption
        and footnote boxes are ``None`` when nothing was attached. ``bbox``
        is the union of the ``all`` boxes of both tie-up results.
    """
    with_captions, _ = self.__tie_up_category_by_distance(page_no, 3, 4)
    with_footnotes, _ = self.__tie_up_category_by_distance(
        page_no, 3, CategoryId.ImageFootnote
    )
    # Both passes use the same subject category, so the results are expected
    # to line up index by index (presumably same subjects in same order).
    assert len(with_captions) == len(with_footnotes)

    ret = []
    for captioned, footnoted in zip(with_captions, with_footnotes):
        caption_all = captioned['all']
        footnote_all = footnoted['all']
        record = {
            'score': captioned['score'],
            'img_caption_bbox': captioned.get('object_body', None),
            'img_body_bbox': captioned['subject_body'],
            'img_footnote_bbox': footnoted.get('object_body', None),
        }
        # Overall box: smallest rectangle covering both "all" regions.
        record['bbox'] = [
            min(caption_all[0], footnote_all[0]),
            min(caption_all[1], footnote_all[1]),
            max(caption_all[2], footnote_all[2]),
            max(caption_all[3], footnote_all[3]),
        ]
        ret.append(record)
    return ret
def get_tables( def get_tables(
self, page_no: int self, page_no: int
...@@ -484,26 +549,26 @@ class MagicModel: ...@@ -484,26 +549,26 @@ class MagicModel:
assert N == M assert N == M
for i in range(N): for i in range(N):
record = { record = {
"score": with_captions[i]["score"], 'score': with_captions[i]['score'],
"table_caption_bbox": with_captions[i].get("object_body", None), 'table_caption_bbox': with_captions[i].get('object_body', None),
"table_body_bbox": with_captions[i]["subject_body"], 'table_body_bbox': with_captions[i]['subject_body'],
"table_footnote_bbox": with_footnotes[i].get("object_body", None), 'table_footnote_bbox': with_footnotes[i].get('object_body', None),
} }
x0 = min(with_captions[i]["all"][0], with_footnotes[i]["all"][0]) x0 = min(with_captions[i]['all'][0], with_footnotes[i]['all'][0])
y0 = min(with_captions[i]["all"][1], with_footnotes[i]["all"][1]) y0 = min(with_captions[i]['all'][1], with_footnotes[i]['all'][1])
x1 = max(with_captions[i]["all"][2], with_footnotes[i]["all"][2]) x1 = max(with_captions[i]['all'][2], with_footnotes[i]['all'][2])
y1 = max(with_captions[i]["all"][3], with_footnotes[i]["all"][3]) y1 = max(with_captions[i]['all'][3], with_footnotes[i]['all'][3])
record["bbox"] = [x0, y0, x1, y1] record['bbox'] = [x0, y0, x1, y1]
ret.append(record) ret.append(record)
return ret return ret
def get_equations(self, page_no: int) -> list: # 有坐标,也有字 def get_equations(self, page_no: int) -> list: # 有坐标,也有字
inline_equations = self.__get_blocks_by_type( inline_equations = self.__get_blocks_by_type(
ModelBlockTypeEnum.EMBEDDING.value, page_no, ["latex"] ModelBlockTypeEnum.EMBEDDING.value, page_no, ['latex']
) )
interline_equations = self.__get_blocks_by_type( interline_equations = self.__get_blocks_by_type(
ModelBlockTypeEnum.ISOLATED.value, page_no, ["latex"] ModelBlockTypeEnum.ISOLATED.value, page_no, ['latex']
) )
interline_equations_blocks = self.__get_blocks_by_type( interline_equations_blocks = self.__get_blocks_by_type(
ModelBlockTypeEnum.ISOLATE_FORMULA.value, page_no ModelBlockTypeEnum.ISOLATE_FORMULA.value, page_no
...@@ -525,17 +590,18 @@ class MagicModel: ...@@ -525,17 +590,18 @@ class MagicModel:
def get_ocr_text(self, page_no: int) -> list:
    """Return the OCR text spans of page ``page_no``.

    The original note says these detections come from paddle and carry both
    text and coordinates.

    Each returned span is ``{'bbox': ..., 'content': ...}``, taken from the
    ``bbox`` and ``text`` fields of the page's ``layout_dets`` entries whose
    ``category_id`` is 15. Both the integer ``15`` (the form used by
    ``get_all_spans``) and the legacy string ``'15'`` are accepted.
    Previously only the string form matched, so int-typed detections were
    silently dropped.

    Returns an empty list when the page has no such detections.
    """
    layout_dets = self.__model_list[page_no]['layout_dets']
    return [
        {
            'bbox': layout_det['bbox'],
            'content': layout_det['text'],
        }
        for layout_det in layout_dets
        if layout_det['category_id'] in (15, '15')
    ]
def get_all_spans(self, page_no: int) -> list: def get_all_spans(self, page_no: int) -> list:
def remove_duplicate_spans(spans): def remove_duplicate_spans(spans):
new_spans = [] new_spans = []
for span in spans: for span in spans:
...@@ -545,7 +611,7 @@ class MagicModel: ...@@ -545,7 +611,7 @@ class MagicModel:
all_spans = [] all_spans = []
model_page_info = self.__model_list[page_no] model_page_info = self.__model_list[page_no]
layout_dets = model_page_info["layout_dets"] layout_dets = model_page_info['layout_dets']
allow_category_id_list = [3, 5, 13, 14, 15] allow_category_id_list = [3, 5, 13, 14, 15]
"""当成span拼接的""" """当成span拼接的"""
# 3: 'image', # 图片 # 3: 'image', # 图片
...@@ -554,11 +620,11 @@ class MagicModel: ...@@ -554,11 +620,11 @@ class MagicModel:
# 14: 'interline_equation', # 行间公式 # 14: 'interline_equation', # 行间公式
# 15: 'text', # ocr识别文本 # 15: 'text', # ocr识别文本
for layout_det in layout_dets: for layout_det in layout_dets:
category_id = layout_det["category_id"] category_id = layout_det['category_id']
if category_id in allow_category_id_list: if category_id in allow_category_id_list:
span = {"bbox": layout_det["bbox"], "score": layout_det["score"]} span = {'bbox': layout_det['bbox'], 'score': layout_det['score']}
if category_id == 3: if category_id == 3:
span["type"] = ContentType.Image span['type'] = ContentType.Image
elif category_id == 5: elif category_id == 5:
# 获取table模型结果 # 获取table模型结果
latex = layout_det.get("latex", None) latex = layout_det.get("latex", None)
...@@ -569,14 +635,14 @@ class MagicModel: ...@@ -569,14 +635,14 @@ class MagicModel:
span["html"] = html span["html"] = html
span["type"] = ContentType.Table span["type"] = ContentType.Table
elif category_id == 13: elif category_id == 13:
span["content"] = layout_det["latex"] span['content'] = layout_det['latex']
span["type"] = ContentType.InlineEquation span['type'] = ContentType.InlineEquation
elif category_id == 14: elif category_id == 14:
span["content"] = layout_det["latex"] span['content'] = layout_det['latex']
span["type"] = ContentType.InterlineEquation span['type'] = ContentType.InterlineEquation
elif category_id == 15: elif category_id == 15:
span["content"] = layout_det["text"] span['content'] = layout_det['text']
span["type"] = ContentType.Text span['type'] = ContentType.Text
all_spans.append(span) all_spans.append(span)
return remove_duplicate_spans(all_spans) return remove_duplicate_spans(all_spans)
...@@ -593,19 +659,19 @@ class MagicModel: ...@@ -593,19 +659,19 @@ class MagicModel:
) -> list: ) -> list:
blocks = [] blocks = []
for page_dict in self.__model_list: for page_dict in self.__model_list:
layout_dets = page_dict.get("layout_dets", []) layout_dets = page_dict.get('layout_dets', [])
page_info = page_dict.get("page_info", {}) page_info = page_dict.get('page_info', {})
page_number = page_info.get("page_no", -1) page_number = page_info.get('page_no', -1)
if page_no != page_number: if page_no != page_number:
continue continue
for item in layout_dets: for item in layout_dets:
category_id = item.get("category_id", -1) category_id = item.get('category_id', -1)
bbox = item.get("bbox", None) bbox = item.get('bbox', None)
if category_id == type: if category_id == type:
block = { block = {
"bbox": bbox, 'bbox': bbox,
"score": item.get("score"), 'score': item.get('score'),
} }
for col in extra_col: for col in extra_col:
block[col] = item.get(col, None) block[col] = item.get(col, None)
...@@ -616,28 +682,28 @@ class MagicModel: ...@@ -616,28 +682,28 @@ class MagicModel:
return self.__model_list[page_no] return self.__model_list[page_no]
if __name__ == "__main__": if __name__ == '__main__':
drw = DiskReaderWriter(r"D:/project/20231108code-clean") drw = DiskReaderWriter(r'D:/project/20231108code-clean')
if 0: if 0:
pdf_file_path = r"linshixuqiu\19983-00.pdf" pdf_file_path = r'linshixuqiu\19983-00.pdf'
model_file_path = r"linshixuqiu\19983-00_new.json" model_file_path = r'linshixuqiu\19983-00_new.json'
pdf_bytes = drw.read(pdf_file_path, AbsReaderWriter.MODE_BIN) pdf_bytes = drw.read(pdf_file_path, AbsReaderWriter.MODE_BIN)
model_json_txt = drw.read(model_file_path, AbsReaderWriter.MODE_TXT) model_json_txt = drw.read(model_file_path, AbsReaderWriter.MODE_TXT)
model_list = json.loads(model_json_txt) model_list = json.loads(model_json_txt)
write_path = r"D:\project\20231108code-clean\linshixuqiu\19983-00" write_path = r'D:\project\20231108code-clean\linshixuqiu\19983-00'
img_bucket_path = "imgs" img_bucket_path = 'imgs'
img_writer = DiskReaderWriter(join_path(write_path, img_bucket_path)) img_writer = DiskReaderWriter(join_path(write_path, img_bucket_path))
pdf_docs = fitz.open("pdf", pdf_bytes) pdf_docs = fitz.open('pdf', pdf_bytes)
magic_model = MagicModel(model_list, pdf_docs) magic_model = MagicModel(model_list, pdf_docs)
if 1: if 1:
model_list = json.loads( model_list = json.loads(
drw.read("/opt/data/pdf/20240418/j.chroma.2009.03.042.json") drw.read('/opt/data/pdf/20240418/j.chroma.2009.03.042.json')
) )
pdf_bytes = drw.read( pdf_bytes = drw.read(
"/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf", AbsReaderWriter.MODE_BIN '/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf', AbsReaderWriter.MODE_BIN
) )
pdf_docs = fitz.open("pdf", pdf_bytes) pdf_docs = fitz.open('pdf', pdf_bytes)
magic_model = MagicModel(model_list, pdf_docs) magic_model = MagicModel(model_list, pdf_docs)
for i in range(7): for i in range(7):
print(magic_model.get_imgs(i)) print(magic_model.get_imgs(i))
from loguru import logger from magic_pdf.libs.boxbase import (__is_overlaps_y_exceeds_threshold,
_is_in_or_part_overlap_with_area_ratio,
from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold, get_minbox_if_overlap_by_ratio, \ calculate_overlap_area_in_bbox1_area_ratio)
calculate_overlap_area_in_bbox1_area_ratio, _is_in_or_part_overlap_with_area_ratio
from magic_pdf.libs.drop_tag import DropTag from magic_pdf.libs.drop_tag import DropTag
from magic_pdf.libs.ocr_content_type import ContentType, BlockType from magic_pdf.libs.ocr_content_type import BlockType, ContentType
from magic_pdf.pre_proc.ocr_span_list_modify import modify_y_axis, modify_inline_equation
from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox_for_span
# 将每一个line中的span从左到右排序 # 将每一个line中的span从左到右排序
def line_sort_spans_by_left_to_right(lines): def line_sort_spans_by_left_to_right(lines):
line_objects = [] line_objects = []
for line in lines: for line in lines:
# 按照x0坐标排序 # 按照x0坐标排序
line.sort(key=lambda span: span['bbox'][0]) line.sort(key=lambda span: span['bbox'][0])
line_bbox = [ line_bbox = [
min(span['bbox'][0] for span in line), # x0 min(span['bbox'][0] for span in line), # x0
...@@ -21,8 +18,8 @@ def line_sort_spans_by_left_to_right(lines): ...@@ -21,8 +18,8 @@ def line_sort_spans_by_left_to_right(lines):
max(span['bbox'][3] for span in line), # y1 max(span['bbox'][3] for span in line), # y1
] ]
line_objects.append({ line_objects.append({
"bbox": line_bbox, 'bbox': line_bbox,
"spans": line, 'spans': line,
}) })
return line_objects return line_objects
...@@ -39,16 +36,21 @@ def merge_spans_to_line(spans): ...@@ -39,16 +36,21 @@ def merge_spans_to_line(spans):
for span in spans[1:]: for span in spans[1:]:
# 如果当前的span类型为"interline_equation" 或者 当前行中已经有"interline_equation" # 如果当前的span类型为"interline_equation" 或者 当前行中已经有"interline_equation"
# image和table类型,同上 # image和table类型,同上
if span['type'] in [ContentType.InterlineEquation, ContentType.Image, ContentType.Table] or any( if span['type'] in [
s['type'] in [ContentType.InterlineEquation, ContentType.Image, ContentType.Table] for s in ContentType.InterlineEquation, ContentType.Image,
current_line): ContentType.Table
] or any(s['type'] in [
ContentType.InterlineEquation, ContentType.Image,
ContentType.Table
] for s in current_line):
# 则开始新行 # 则开始新行
lines.append(current_line) lines.append(current_line)
current_line = [span] current_line = [span]
continue continue
# 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行 # 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行
if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox']): if __is_overlaps_y_exceeds_threshold(span['bbox'],
current_line[-1]['bbox']):
current_line.append(span) current_line.append(span)
else: else:
# 否则,开始新行 # 否则,开始新行
...@@ -71,7 +73,8 @@ def merge_spans_to_line_by_layout(spans, layout_bboxes): ...@@ -71,7 +73,8 @@ def merge_spans_to_line_by_layout(spans, layout_bboxes):
# 遍历spans,将每个span放入对应的layout中 # 遍历spans,将每个span放入对应的layout中
layout_sapns = [] layout_sapns = []
for span in spans: for span in spans:
if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], layout_bbox) > 0.6: if calculate_overlap_area_in_bbox1_area_ratio(
span['bbox'], layout_bbox) > 0.6:
layout_sapns.append(span) layout_sapns.append(span)
# 如果layout_sapns不为空,则放入new_spans中 # 如果layout_sapns不为空,则放入new_spans中
if len(layout_sapns) > 0: if len(layout_sapns) > 0:
...@@ -99,12 +102,10 @@ def merge_lines_to_block(lines): ...@@ -99,12 +102,10 @@ def merge_lines_to_block(lines):
# 目前不做block拼接,先做个结构,每个block中只有一个line,block的bbox就是line的bbox # 目前不做block拼接,先做个结构,每个block中只有一个line,block的bbox就是line的bbox
blocks = [] blocks = []
for line in lines: for line in lines:
blocks.append( blocks.append({
{ 'bbox': line['bbox'],
"bbox": line["bbox"], 'lines': [line],
"lines": [line], })
}
)
return blocks return blocks
...@@ -121,7 +122,8 @@ def sort_blocks_by_layout(all_bboxes, layout_bboxes): ...@@ -121,7 +122,8 @@ def sort_blocks_by_layout(all_bboxes, layout_bboxes):
if block[7] == BlockType.Footnote: if block[7] == BlockType.Footnote:
continue continue
block_bbox = block[:4] block_bbox = block[:4]
if calculate_overlap_area_in_bbox1_area_ratio(block_bbox, layout_bbox) > 0.8: if calculate_overlap_area_in_bbox1_area_ratio(
block_bbox, layout_bbox) > 0.8:
layout_blocks.append(block) layout_blocks.append(block)
# 如果layout_blocks不为空,则放入new_blocks中 # 如果layout_blocks不为空,则放入new_blocks中
...@@ -134,7 +136,8 @@ def sort_blocks_by_layout(all_bboxes, layout_bboxes): ...@@ -134,7 +136,8 @@ def sort_blocks_by_layout(all_bboxes, layout_bboxes):
# 如果new_blocks不为空,则对new_blocks中每个block进行排序 # 如果new_blocks不为空,则对new_blocks中每个block进行排序
if len(new_blocks) > 0: if len(new_blocks) > 0:
for bboxes_in_layout_block in new_blocks: for bboxes_in_layout_block in new_blocks:
bboxes_in_layout_block.sort(key=lambda x: x[1]) # 一个layout内部的box,按照y0自上而下排序 bboxes_in_layout_block.sort(
key=lambda x: x[1]) # 一个layout内部的box,按照y0自上而下排序
sort_blocks.extend(bboxes_in_layout_block) sort_blocks.extend(bboxes_in_layout_block)
# sort_blocks中已经包含了当前页面所有最终留下的block,且已经排好了顺序 # sort_blocks中已经包含了当前页面所有最终留下的block,且已经排好了顺序
...@@ -142,9 +145,7 @@ def sort_blocks_by_layout(all_bboxes, layout_bboxes): ...@@ -142,9 +145,7 @@ def sort_blocks_by_layout(all_bboxes, layout_bboxes):
def fill_spans_in_blocks(blocks, spans, radio): def fill_spans_in_blocks(blocks, spans, radio):
''' """将allspans中的span按位置关系,放入blocks中."""
将allspans中的span按位置关系,放入blocks中
'''
block_with_spans = [] block_with_spans = []
for block in blocks: for block in blocks:
block_type = block[7] block_type = block[7]
...@@ -156,17 +157,15 @@ def fill_spans_in_blocks(blocks, spans, radio): ...@@ -156,17 +157,15 @@ def fill_spans_in_blocks(blocks, spans, radio):
block_spans = [] block_spans = []
for span in spans: for span in spans:
span_bbox = span['bbox'] span_bbox = span['bbox']
if calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > radio: if calculate_overlap_area_in_bbox1_area_ratio(
span_bbox, block_bbox) > radio:
block_spans.append(span) block_spans.append(span)
'''行内公式调整, 高度调整至与同行文字高度一致(优先左侧, 其次右侧)''' '''行内公式调整, 高度调整至与同行文字高度一致(优先左侧, 其次右侧)'''
# displayed_list = [] # displayed_list = []
# text_inline_lines = [] # text_inline_lines = []
# modify_y_axis(block_spans, displayed_list, text_inline_lines) # modify_y_axis(block_spans, displayed_list, text_inline_lines)
'''模型识别错误的行间公式, type类型转换成行内公式''' '''模型识别错误的行间公式, type类型转换成行内公式'''
# block_spans = modify_inline_equation(block_spans, displayed_list, text_inline_lines) # block_spans = modify_inline_equation(block_spans, displayed_list, text_inline_lines)
'''bbox去除粘连''' # 去粘连会影响span的bbox,导致后续fill的时候出错 '''bbox去除粘连''' # 去粘连会影响span的bbox,导致后续fill的时候出错
# block_spans = remove_overlap_between_bbox_for_span(block_spans) # block_spans = remove_overlap_between_bbox_for_span(block_spans)
...@@ -182,12 +181,9 @@ def fill_spans_in_blocks(blocks, spans, radio): ...@@ -182,12 +181,9 @@ def fill_spans_in_blocks(blocks, spans, radio):
def fix_block_spans(block_with_spans, img_blocks, table_blocks): def fix_block_spans(block_with_spans, img_blocks, table_blocks):
''' """1、img_block和table_block因为包含caption和footnote的关系,存在block的嵌套关系
1、img_block和table_block因为包含caption和footnote的关系,存在block的嵌套关系 需要将caption和footnote的text_span放入相应img_block和table_block内的
需要将caption和footnote的text_span放入相应img_block和table_block内的 caption_block和footnote_block中 2、同时需要删除block中的spans字段."""
caption_block和footnote_block中
2、同时需要删除block中的spans字段
'''
fix_blocks = [] fix_blocks = []
for block in block_with_spans: for block in block_with_spans:
block_type = block['type'] block_type = block['type']
...@@ -218,16 +214,13 @@ def merge_spans_to_block(spans: list, block_bbox: list, block_type: str): ...@@ -218,16 +214,13 @@ def merge_spans_to_block(spans: list, block_bbox: list, block_type: str):
block_spans = [] block_spans = []
# 如果有img_caption,则将img_block中的text_spans放入img_caption_block中 # 如果有img_caption,则将img_block中的text_spans放入img_caption_block中
for span in spans: for span in spans:
if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block_bbox) > 0.6: if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'],
block_bbox) > 0.6:
block_spans.append(span) block_spans.append(span)
block_lines = merge_spans_to_line(block_spans) block_lines = merge_spans_to_line(block_spans)
# 对line中的span进行排序 # 对line中的span进行排序
sort_block_lines = line_sort_spans_by_left_to_right(block_lines) sort_block_lines = line_sort_spans_by_left_to_right(block_lines)
block = { block = {'bbox': block_bbox, 'type': block_type, 'lines': sort_block_lines}
'bbox': block_bbox,
'type': block_type,
'lines': sort_block_lines
}
return block, block_spans return block, block_spans
...@@ -237,11 +230,7 @@ def make_body_block(span: dict, block_bbox: list, block_type: str): ...@@ -237,11 +230,7 @@ def make_body_block(span: dict, block_bbox: list, block_type: str):
'bbox': block_bbox, 'bbox': block_bbox,
'spans': [span], 'spans': [span],
} }
body_block = { body_block = {'bbox': block_bbox, 'type': block_type, 'lines': [body_line]}
'bbox': block_bbox,
'type': block_type,
'lines': [body_line]
}
return body_block return body_block
...@@ -249,13 +238,16 @@ def fix_image_block(block, img_blocks): ...@@ -249,13 +238,16 @@ def fix_image_block(block, img_blocks):
block['blocks'] = [] block['blocks'] = []
# 遍历img_blocks,找到与当前block匹配的img_block # 遍历img_blocks,找到与当前block匹配的img_block
for img_block in img_blocks: for img_block in img_blocks:
if _is_in_or_part_overlap_with_area_ratio(block['bbox'], img_block['bbox'], 0.95): if _is_in_or_part_overlap_with_area_ratio(block['bbox'],
img_block['bbox'], 0.95):
# 创建img_body_block # 创建img_body_block
for span in block['spans']: for span in block['spans']:
if span['type'] == ContentType.Image and img_block['img_body_bbox'] == span['bbox']: if span['type'] == ContentType.Image and img_block[
'img_body_bbox'] == span['bbox']:
# 创建img_body_block # 创建img_body_block
img_body_block = make_body_block(span, img_block['img_body_bbox'], BlockType.ImageBody) img_body_block = make_body_block(
span, img_block['img_body_bbox'], BlockType.ImageBody)
block['blocks'].append(img_body_block) block['blocks'].append(img_body_block)
# 从spans中移除img_body_block中已经放入的span # 从spans中移除img_body_block中已经放入的span
...@@ -265,10 +257,15 @@ def fix_image_block(block, img_blocks): ...@@ -265,10 +257,15 @@ def fix_image_block(block, img_blocks):
# 根据list长度,判断img_block中是否有img_caption # 根据list长度,判断img_block中是否有img_caption
if img_block['img_caption_bbox'] is not None: if img_block['img_caption_bbox'] is not None:
img_caption_block, img_caption_spans = merge_spans_to_block( img_caption_block, img_caption_spans = merge_spans_to_block(
block['spans'], img_block['img_caption_bbox'], BlockType.ImageCaption block['spans'], img_block['img_caption_bbox'],
) BlockType.ImageCaption)
block['blocks'].append(img_caption_block) block['blocks'].append(img_caption_block)
if img_block['img_footnote_bbox'] is not None:
img_footnote_block, img_footnote_spans = merge_spans_to_block(
block['spans'], img_block['img_footnote_bbox'],
BlockType.ImageFootnote)
block['blocks'].append(img_footnote_block)
break break
del block['spans'] del block['spans']
return block return block
...@@ -278,13 +275,17 @@ def fix_table_block(block, table_blocks): ...@@ -278,13 +275,17 @@ def fix_table_block(block, table_blocks):
block['blocks'] = [] block['blocks'] = []
# 遍历table_blocks,找到与当前block匹配的table_block # 遍历table_blocks,找到与当前block匹配的table_block
for table_block in table_blocks: for table_block in table_blocks:
if _is_in_or_part_overlap_with_area_ratio(block['bbox'], table_block['bbox'], 0.95): if _is_in_or_part_overlap_with_area_ratio(block['bbox'],
table_block['bbox'], 0.95):
# 创建table_body_block # 创建table_body_block
for span in block['spans']: for span in block['spans']:
if span['type'] == ContentType.Table and table_block['table_body_bbox'] == span['bbox']: if span['type'] == ContentType.Table and table_block[
'table_body_bbox'] == span['bbox']:
# 创建table_body_block # 创建table_body_block
table_body_block = make_body_block(span, table_block['table_body_bbox'], BlockType.TableBody) table_body_block = make_body_block(
span, table_block['table_body_bbox'],
BlockType.TableBody)
block['blocks'].append(table_body_block) block['blocks'].append(table_body_block)
# 从spans中移除img_body_block中已经放入的span # 从spans中移除img_body_block中已经放入的span
...@@ -294,8 +295,8 @@ def fix_table_block(block, table_blocks): ...@@ -294,8 +295,8 @@ def fix_table_block(block, table_blocks):
# 根据list长度,判断table_block中是否有caption # 根据list长度,判断table_block中是否有caption
if table_block['table_caption_bbox'] is not None: if table_block['table_caption_bbox'] is not None:
table_caption_block, table_caption_spans = merge_spans_to_block( table_caption_block, table_caption_spans = merge_spans_to_block(
block['spans'], table_block['table_caption_bbox'], BlockType.TableCaption block['spans'], table_block['table_caption_bbox'],
) BlockType.TableCaption)
block['blocks'].append(table_caption_block) block['blocks'].append(table_caption_block)
# 如果table_caption_block_spans不为空 # 如果table_caption_block_spans不为空
...@@ -307,8 +308,8 @@ def fix_table_block(block, table_blocks): ...@@ -307,8 +308,8 @@ def fix_table_block(block, table_blocks):
# 根据list长度,判断table_block中是否有table_note # 根据list长度,判断table_block中是否有table_note
if table_block['table_footnote_bbox'] is not None: if table_block['table_footnote_bbox'] is not None:
table_footnote_block, table_footnote_spans = merge_spans_to_block( table_footnote_block, table_footnote_spans = merge_spans_to_block(
block['spans'], table_block['table_footnote_bbox'], BlockType.TableFootnote block['spans'], table_block['table_footnote_bbox'],
) BlockType.TableFootnote)
block['blocks'].append(table_footnote_block) block['blocks'].append(table_footnote_block)
break break
......
# 欢迎来到 MinerU 项目列表
## 项目列表
- [llama_index_rag](./llama_index_rag/README.md): 基于 llama_index 构建轻量级 RAG 系统
import pytest
import os import os
from magic_pdf.libs.boxbase import _is_in_or_part_overlap, _is_in_or_part_overlap_with_area_ratio, _is_in, \
_is_part_overlap, _left_intersect, _right_intersect, _is_vertical_full_overlap, _is_bottom_full_overlap, \ import pytest
_is_left_overlap, __is_overlaps_y_exceeds_threshold, calculate_iou, calculate_overlap_area_2_minbox_area_ratio, \
calculate_overlap_area_in_bbox1_area_ratio, get_minbox_if_overlap_by_ratio, get_bbox_in_boundry, \ from magic_pdf.libs.boxbase import (__is_overlaps_y_exceeds_threshold,
find_top_nearest_text_bbox, find_bottom_nearest_text_bbox, find_left_nearest_text_bbox, \ _is_bottom_full_overlap, _is_in,
find_right_nearest_text_bbox, bbox_relative_pos, bbox_distance _is_in_or_part_overlap,
from magic_pdf.libs.commons import mymax, join_path, get_top_percent_list _is_in_or_part_overlap_with_area_ratio,
_is_left_overlap, _is_part_overlap,
_is_vertical_full_overlap, _left_intersect,
_right_intersect, bbox_distance,
bbox_relative_pos, calculate_iou,
calculate_overlap_area_2_minbox_area_ratio,
calculate_overlap_area_in_bbox1_area_ratio,
find_bottom_nearest_text_bbox,
find_left_nearest_text_bbox,
find_right_nearest_text_bbox,
find_top_nearest_text_bbox,
get_bbox_in_boundary,
get_minbox_if_overlap_by_ratio)
from magic_pdf.libs.commons import get_top_percent_list, join_path, mymax
from magic_pdf.libs.config_reader import get_s3_config from magic_pdf.libs.config_reader import get_s3_config
from magic_pdf.libs.path_utils import parse_s3path from magic_pdf.libs.path_utils import parse_s3path
# 输入一个列表,如果列表空返回0,否则返回最大元素 # 输入一个列表,如果列表空返回0,否则返回最大元素
@pytest.mark.parametrize("list_input, target_num", @pytest.mark.parametrize('list_input, target_num',
[ [
([0, 0, 0, 0], 0), ([0, 0, 0, 0], 0),
([0], 0), ([0], 0),
...@@ -29,7 +41,7 @@ def test_list_max(list_input: list, target_num) -> None: ...@@ -29,7 +41,7 @@ def test_list_max(list_input: list, target_num) -> None:
# 连接多个参数生成路径信息,使用"/"作为连接符,生成的结果需要是一个合法路径 # 连接多个参数生成路径信息,使用"/"作为连接符,生成的结果需要是一个合法路径
@pytest.mark.parametrize("path_input, target_path", [ @pytest.mark.parametrize('path_input, target_path', [
(['https:', '', 'www.baidu.com'], 'https://www.baidu.com'), (['https:', '', 'www.baidu.com'], 'https://www.baidu.com'),
(['https:', 'www.baidu.com'], 'https:/www.baidu.com'), (['https:', 'www.baidu.com'], 'https:/www.baidu.com'),
(['D:', 'file', 'pythonProject', 'demo' + '.py'], 'D:/file/pythonProject/demo.py'), (['D:', 'file', 'pythonProject', 'demo' + '.py'], 'D:/file/pythonProject/demo.py'),
...@@ -42,7 +54,7 @@ def test_join_path(path_input: list, target_path: str) -> None: ...@@ -42,7 +54,7 @@ def test_join_path(path_input: list, target_path: str) -> None:
# 获取列表中前百分之多少的元素 # 获取列表中前百分之多少的元素
@pytest.mark.parametrize("num_list, percent, target_num_list", [ @pytest.mark.parametrize('num_list, percent, target_num_list', [
([], 0.75, []), ([], 0.75, []),
([-5, -10, 9, 3, 7, -7, 0, 23, -1, -11], 0.8, [23, 9, 7, 3, 0, -1, -5, -7]), ([-5, -10, 9, 3, 7, -7, 0, 23, -1, -11], 0.8, [23, 9, 7, 3, 0, -1, -5, -7]),
([-5, -10, 9, 3, 7, -7, 0, 23, -1, -11], 0, []), ([-5, -10, 9, 3, 7, -7, 0, 23, -1, -11], 0, []),
...@@ -57,9 +69,9 @@ def test_get_top_percent_list(num_list: list, percent: float, target_num_list: l ...@@ -57,9 +69,9 @@ def test_get_top_percent_list(num_list: list, percent: float, target_num_list: l
# 输入一个s3路径,返回bucket名字和其余部分(key) # 输入一个s3路径,返回bucket名字和其余部分(key)
@pytest.mark.parametrize("s3_path, target_data", [ @pytest.mark.parametrize('s3_path, target_data', [
("s3://bucket/path/to/my/file.txt", "bucket"), ('s3://bucket/path/to/my/file.txt', 'bucket'),
("s3a://bucket1/path/to/my/file2.txt", "bucket1"), ('s3a://bucket1/path/to/my/file2.txt', 'bucket1'),
# ("/path/to/my/file1.txt", "path"), # ("/path/to/my/file1.txt", "path"),
# ("bucket/path/to/my/file2.txt", "bucket"), # ("bucket/path/to/my/file2.txt", "bucket"),
]) ])
...@@ -76,7 +88,7 @@ def test_parse_s3path(s3_path: str, target_data: str): ...@@ -76,7 +88,7 @@ def test_parse_s3path(s3_path: str, target_data: str):
# 2个box是否处于包含或者部分重合关系。 # 2个box是否处于包含或者部分重合关系。
# 如果某边界重合算重合。 # 如果某边界重合算重合。
# 部分边界重合,其他在内部也算包含 # 部分边界重合,其他在内部也算包含
@pytest.mark.parametrize("box1, box2, target_bool", [ @pytest.mark.parametrize('box1, box2, target_bool', [
((120, 133, 223, 248), (128, 168, 269, 295), True), ((120, 133, 223, 248), (128, 168, 269, 295), True),
((137, 53, 245, 157), (134, 11, 200, 147), True), # 部分重合 ((137, 53, 245, 157), (134, 11, 200, 147), True), # 部分重合
((137, 56, 211, 116), (140, 66, 202, 199), True), # 部分重合 ((137, 56, 211, 116), (140, 66, 202, 199), True), # 部分重合
...@@ -101,7 +113,7 @@ def test_is_in_or_part_overlap(box1: tuple, box2: tuple, target_bool: bool) -> N ...@@ -101,7 +113,7 @@ def test_is_in_or_part_overlap(box1: tuple, box2: tuple, target_bool: bool) -> N
# 如果box1在box2内部,返回True # 如果box1在box2内部,返回True
# 如果是部分重合的,则重合面积占box1的比例大于阈值时候返回True # 如果是部分重合的,则重合面积占box1的比例大于阈值时候返回True
@pytest.mark.parametrize("box1, box2, target_bool", [ @pytest.mark.parametrize('box1, box2, target_bool', [
((35, 28, 108, 90), (47, 60, 83, 96), False), # 包含 box1 up box2, box2 多半,box1少半 ((35, 28, 108, 90), (47, 60, 83, 96), False), # 包含 box1 up box2, box2 多半,box1少半
((65, 151, 92, 177), (49, 99, 105, 198), True), # 包含 box1 in box2 ((65, 151, 92, 177), (49, 99, 105, 198), True), # 包含 box1 in box2
((80, 62, 112, 84), (74, 40, 144, 111), True), # 包含 box1 in box2 ((80, 62, 112, 84), (74, 40, 144, 111), True), # 包含 box1 in box2
...@@ -119,7 +131,7 @@ def test_is_in_or_part_overlap_with_area_ratio(box1: tuple, box2: tuple, target_ ...@@ -119,7 +131,7 @@ def test_is_in_or_part_overlap_with_area_ratio(box1: tuple, box2: tuple, target_
# box1在box2内部或者box2在box1内部返回True。如果部分边界重合也算作包含。 # box1在box2内部或者box2在box1内部返回True。如果部分边界重合也算作包含。
@pytest.mark.parametrize("box1, box2, target_bool", [ @pytest.mark.parametrize('box1, box2, target_bool', [
# ((), (), "Error"), # Error # ((), (), "Error"), # Error
((65, 151, 92, 177), (49, 99, 105, 198), True), # 包含 box1 in box2 ((65, 151, 92, 177), (49, 99, 105, 198), True), # 包含 box1 in box2
((80, 62, 112, 84), (74, 40, 144, 111), True), # 包含 box1 in box2 ((80, 62, 112, 84), (74, 40, 144, 111), True), # 包含 box1 in box2
...@@ -141,7 +153,7 @@ def test_is_in(box1: tuple, box2: tuple, target_bool: bool) -> None: ...@@ -141,7 +153,7 @@ def test_is_in(box1: tuple, box2: tuple, target_bool: bool) -> None:
# 仅仅是部分包含关系,返回True,如果是完全包含关系则返回False # 仅仅是部分包含关系,返回True,如果是完全包含关系则返回False
@pytest.mark.parametrize("box1, box2, target_bool", [ @pytest.mark.parametrize('box1, box2, target_bool', [
((65, 151, 92, 177), (49, 99, 105, 198), False), # 包含 box1 in box2 ((65, 151, 92, 177), (49, 99, 105, 198), False), # 包含 box1 in box2
((80, 62, 112, 84), (74, 40, 144, 111), False), # 包含 box1 in box2 ((80, 62, 112, 84), (74, 40, 144, 111), False), # 包含 box1 in box2
# ((76, 140, 154, 277), (121, 326, 192, 384), False), # 分离 Error # ((76, 140, 154, 277), (121, 326, 192, 384), False), # 分离 Error
...@@ -161,7 +173,7 @@ def test_is_part_overlap(box1: tuple, box2: tuple, target_bool: bool) -> None: ...@@ -161,7 +173,7 @@ def test_is_part_overlap(box1: tuple, box2: tuple, target_bool: bool) -> None:
# left_box右侧是否和right_box左侧有部分重叠 # left_box右侧是否和right_box左侧有部分重叠
@pytest.mark.parametrize("box1, box2, target_bool", [ @pytest.mark.parametrize('box1, box2, target_bool', [
(None, None, False), (None, None, False),
((88, 81, 222, 173), (60, 221, 123, 358), False), # 分离 ((88, 81, 222, 173), (60, 221, 123, 358), False), # 分离
((121, 149, 184, 289), (172, 130, 230, 268), True), # box1 left bottom box2 相交 ((121, 149, 184, 289), (172, 130, 230, 268), True), # box1 left bottom box2 相交
...@@ -179,7 +191,7 @@ def test_left_intersect(box1: tuple, box2: tuple, target_bool: bool) -> None: ...@@ -179,7 +191,7 @@ def test_left_intersect(box1: tuple, box2: tuple, target_bool: bool) -> None:
# left_box左侧是否和right_box右侧部分重叠 # left_box左侧是否和right_box右侧部分重叠
@pytest.mark.parametrize("box1, box2, target_bool", [ @pytest.mark.parametrize('box1, box2, target_bool', [
(None, None, False), (None, None, False),
((88, 81, 222, 173), (60, 221, 123, 358), False), # 分离 ((88, 81, 222, 173), (60, 221, 123, 358), False), # 分离
((121, 149, 184, 289), (172, 130, 230, 268), False), # box1 left bottom box2 相交 ((121, 149, 184, 289), (172, 130, 230, 268), False), # box1 left bottom box2 相交
...@@ -198,7 +210,7 @@ def test_right_intersect(box1: tuple, box2: tuple, target_bool: bool) -> None: ...@@ -198,7 +210,7 @@ def test_right_intersect(box1: tuple, box2: tuple, target_bool: bool) -> None:
# x方向上:要么box1包含box2, 要么box2包含box1。不能部分包含 # x方向上:要么box1包含box2, 要么box2包含box1。不能部分包含
# y方向上:box1和box2有重叠 # y方向上:box1和box2有重叠
@pytest.mark.parametrize("box1, box2, target_bool", [ @pytest.mark.parametrize('box1, box2, target_bool', [
# (None, None, False), # Error # (None, None, False), # Error
((35, 28, 108, 90), (47, 60, 83, 96), True), # box1 top box2, x:box2 in box1, y:有重叠 ((35, 28, 108, 90), (47, 60, 83, 96), True), # box1 top box2, x:box2 in box1, y:有重叠
((35, 28, 98, 90), (27, 60, 103, 96), True), # box1 top box2, x:box1 in box2, y:有重叠 ((35, 28, 98, 90), (27, 60, 103, 96), True), # box1 top box2, x:box1 in box2, y:有重叠
...@@ -214,7 +226,7 @@ def test_is_vertical_full_overlap(box1: tuple, box2: tuple, target_bool: bool) - ...@@ -214,7 +226,7 @@ def test_is_vertical_full_overlap(box1: tuple, box2: tuple, target_bool: bool) -
# 检查box1下方和box2的上方有轻微的重叠,轻微程度收到y_tolerance的限制 # 检查box1下方和box2的上方有轻微的重叠,轻微程度收到y_tolerance的限制
@pytest.mark.parametrize("box1, box2, target_bool", [ @pytest.mark.parametrize('box1, box2, target_bool', [
(None, None, False), (None, None, False),
((35, 28, 108, 90), (47, 89, 83, 116), True), # box1 top box2, y:有重叠 ((35, 28, 108, 90), (47, 89, 83, 116), True), # box1 top box2, y:有重叠
((35, 28, 108, 90), (47, 60, 83, 96), False), # box1 top box2, y:有重叠且过多 ((35, 28, 108, 90), (47, 60, 83, 96), False), # box1 top box2, y:有重叠且过多
...@@ -228,7 +240,7 @@ def test_is_bottom_full_overlap(box1: tuple, box2: tuple, target_bool: bool) -> ...@@ -228,7 +240,7 @@ def test_is_bottom_full_overlap(box1: tuple, box2: tuple, target_bool: bool) ->
# 检查box1的左侧是否和box2有重叠 # 检查box1的左侧是否和box2有重叠
@pytest.mark.parametrize("box1, box2, target_bool", [ @pytest.mark.parametrize('box1, box2, target_bool', [
(None, None, False), (None, None, False),
((88, 81, 222, 173), (60, 221, 123, 358), False), # 分离 ((88, 81, 222, 173), (60, 221, 123, 358), False), # 分离
# ((121, 149, 184, 289), (172, 130, 230, 268), False), # box1 left bottom box2 相交 Error # ((121, 149, 184, 289), (172, 130, 230, 268), False), # box1 left bottom box2 相交 Error
...@@ -245,7 +257,7 @@ def test_is_left_overlap(box1: tuple, box2: tuple, target_bool: bool) -> None: ...@@ -245,7 +257,7 @@ def test_is_left_overlap(box1: tuple, box2: tuple, target_bool: bool) -> None:
# 查两个bbox在y轴上是否有重叠,并且该重叠区域的高度占两个bbox高度更低的那个超过阈值 # 查两个bbox在y轴上是否有重叠,并且该重叠区域的高度占两个bbox高度更低的那个超过阈值
@pytest.mark.parametrize("box1, box2, target_bool", [ @pytest.mark.parametrize('box1, box2, target_bool', [
# (None, None, "Error"), # Error # (None, None, "Error"), # Error
((51, 69, 192, 147), (75, 48, 132, 187), True), # y: box1 in box2 ((51, 69, 192, 147), (75, 48, 132, 187), True), # y: box1 in box2
((51, 39, 192, 197), (75, 48, 132, 187), True), # y: box2 in box1 ((51, 39, 192, 197), (75, 48, 132, 187), True), # y: box2 in box1
...@@ -260,7 +272,7 @@ def test_is_overlaps_y_exceeds_threshold(box1: tuple, box2: tuple, target_bool: ...@@ -260,7 +272,7 @@ def test_is_overlaps_y_exceeds_threshold(box1: tuple, box2: tuple, target_bool:
# Determine the coordinates of the intersection rectangle # Determine the coordinates of the intersection rectangle
@pytest.mark.parametrize("box1, box2, target_num", [ @pytest.mark.parametrize('box1, box2, target_num', [
# (None, None, "Error"), # Error # (None, None, "Error"), # Error
((88, 81, 222, 173), (60, 221, 123, 358), 0.0), # 分离 ((88, 81, 222, 173), (60, 221, 123, 358), 0.0), # 分离
((76, 140, 154, 277), (121, 326, 192, 384), 0.0), # 分离 ((76, 140, 154, 277), (121, 326, 192, 384), 0.0), # 分离
...@@ -276,7 +288,7 @@ def test_calculate_iou(box1: tuple, box2: tuple, target_num: float) -> None: ...@@ -276,7 +288,7 @@ def test_calculate_iou(box1: tuple, box2: tuple, target_num: float) -> None:
# 计算box1和box2的重叠面积占最小面积的box的比例 # 计算box1和box2的重叠面积占最小面积的box的比例
@pytest.mark.parametrize("box1, box2, target_num", [ @pytest.mark.parametrize('box1, box2, target_num', [
# (None, None, "Error"), # Error # (None, None, "Error"), # Error
((142, 109, 238, 164), (134, 211, 224, 270), 0.0), # 分离 ((142, 109, 238, 164), (134, 211, 224, 270), 0.0), # 分离
((88, 81, 222, 173), (60, 221, 123, 358), 0.0), # 分离 ((88, 81, 222, 173), (60, 221, 123, 358), 0.0), # 分离
...@@ -295,7 +307,7 @@ def test_calculate_overlap_area_2_minbox_area_ratio(box1: tuple, box2: tuple, ta ...@@ -295,7 +307,7 @@ def test_calculate_overlap_area_2_minbox_area_ratio(box1: tuple, box2: tuple, ta
# 计算box1和box2的重叠面积占bbox1的比例 # 计算box1和box2的重叠面积占bbox1的比例
@pytest.mark.parametrize("box1, box2, target_num", [ @pytest.mark.parametrize('box1, box2, target_num', [
# (None, None, "Error"), # Error # (None, None, "Error"), # Error
((142, 109, 238, 164), (134, 211, 224, 270), 0.0), # 分离 ((142, 109, 238, 164), (134, 211, 224, 270), 0.0), # 分离
((88, 81, 222, 173), (60, 221, 123, 358), 0.0), # 分离 ((88, 81, 222, 173), (60, 221, 123, 358), 0.0), # 分离
...@@ -315,7 +327,7 @@ def test_calculate_overlap_area_in_bbox1_area_ratio(box1: tuple, box2: tuple, ta ...@@ -315,7 +327,7 @@ def test_calculate_overlap_area_in_bbox1_area_ratio(box1: tuple, box2: tuple, ta
# 计算两个bbox重叠的面积占最小面积的box的比例,如果比例大于ratio,则返回小的那个bbox,否则返回None # 计算两个bbox重叠的面积占最小面积的box的比例,如果比例大于ratio,则返回小的那个bbox,否则返回None
@pytest.mark.parametrize("box1, box2, ratio, target_box", [ @pytest.mark.parametrize('box1, box2, ratio, target_box', [
# (None, None, 0.8, "Error"), # Error # (None, None, 0.8, "Error"), # Error
((142, 109, 238, 164), (134, 211, 224, 270), 0.0, None), # 分离 ((142, 109, 238, 164), (134, 211, 224, 270), 0.0, None), # 分离
((109, 126, 204, 245), (110, 127, 232, 206), 0.5, (110, 127, 232, 206)), ((109, 126, 204, 245), (110, 127, 232, 206), 0.5, (110, 127, 232, 206)),
...@@ -331,7 +343,7 @@ def test_get_minbox_if_overlap_by_ratio(box1: tuple, box2: tuple, ratio: float, ...@@ -331,7 +343,7 @@ def test_get_minbox_if_overlap_by_ratio(box1: tuple, box2: tuple, ratio: float,
# 根据boundry获取在这个范围内的所有的box的列表,完全包含关系 # 根据boundry获取在这个范围内的所有的box的列表,完全包含关系
@pytest.mark.parametrize("boxes, boundry, target_boxs", [ @pytest.mark.parametrize('boxes, boundary, target_boxs', [
# ([], (), "Error"), # Error # ([], (), "Error"), # Error
([], (110, 340, 209, 387), []), ([], (110, 340, 209, 387), []),
([(142, 109, 238, 164)], (134, 211, 224, 270), []), # 分离 ([(142, 109, 238, 164)], (134, 211, 224, 270), []), # 分离
...@@ -347,15 +359,15 @@ def test_get_minbox_if_overlap_by_ratio(box1: tuple, box2: tuple, ratio: float, ...@@ -347,15 +359,15 @@ def test_get_minbox_if_overlap_by_ratio(box1: tuple, box2: tuple, ratio: float,
(137, 29, 287, 87)], (30, 20, 200, 320), (137, 29, 287, 87)], (30, 20, 200, 320),
[(81, 280, 123, 315), (46, 99, 133, 148), (33, 156, 97, 211)]), [(81, 280, 123, 315), (46, 99, 133, 148), (33, 156, 97, 211)]),
]) ])
def test_get_bbox_in_boundry(boxes: list, boundry: tuple, target_boxs: list) -> None: def test_get_bbox_in_boundary(boxes: list, boundary: tuple, target_boxs: list) -> None:
assert target_boxs == get_bbox_in_boundry(boxes, boundry) assert target_boxs == get_bbox_in_boundary(boxes, boundary)
# 寻找上方距离最近的box,margin 4个单位, x方向有重合,y方向最近的 # 寻找上方距离最近的box,margin 4个单位, x方向有重合,y方向最近的
@pytest.mark.parametrize("pymu_blocks, obj_box, target_boxs", [ @pytest.mark.parametrize('pymu_blocks, obj_box, target_boxs', [
([{"bbox": (81, 280, 123, 315)}, {"bbox": (282, 203, 342, 247)}, {"bbox": (183, 100, 300, 155)}, ([{'bbox': (81, 280, 123, 315)}, {'bbox': (282, 203, 342, 247)}, {'bbox': (183, 100, 300, 155)},
{"bbox": (46, 99, 133, 148)}, {"bbox": (33, 156, 97, 211)}, {'bbox': (46, 99, 133, 148)}, {'bbox': (33, 156, 97, 211)},
{"bbox": (137, 29, 287, 87)}], (81, 280, 123, 315), {"bbox": (33, 156, 97, 211)}), {'bbox': (137, 29, 287, 87)}], (81, 280, 123, 315), {'bbox': (33, 156, 97, 211)}),
# ([{"bbox": (168, 120, 263, 159)}, # ([{"bbox": (168, 120, 263, 159)},
# {"bbox": (231, 61, 279, 159)}, # {"bbox": (231, 61, 279, 159)},
# {"bbox": (35, 85, 136, 110)}, # {"bbox": (35, 85, 136, 110)},
...@@ -363,46 +375,46 @@ def test_get_bbox_in_boundry(boxes: list, boundry: tuple, target_boxs: list) -> ...@@ -363,46 +375,46 @@ def test_get_bbox_in_boundry(boxes: list, boundry: tuple, target_boxs: list) ->
# {"bbox": (144, 264, 188, 323)}, # {"bbox": (144, 264, 188, 323)},
# {"bbox": (62, 37, 126, 64)}], (228, 193, 347, 225), # {"bbox": (62, 37, 126, 64)}], (228, 193, 347, 225),
# [{"bbox": (168, 120, 263, 159)}, {"bbox": (231, 61, 279, 159)}]), # y:方向最近的有两个,x: 两个均有重合 Error # [{"bbox": (168, 120, 263, 159)}, {"bbox": (231, 61, 279, 159)}]), # y:方向最近的有两个,x: 两个均有重合 Error
([{"bbox": (35, 85, 136, 159)}, ([{'bbox': (35, 85, 136, 159)},
{"bbox": (168, 120, 263, 159)}, {'bbox': (168, 120, 263, 159)},
{"bbox": (231, 61, 279, 118)}, {'bbox': (231, 61, 279, 118)},
{"bbox": (228, 193, 347, 225)}, {'bbox': (228, 193, 347, 225)},
{"bbox": (144, 264, 188, 323)}, {'bbox': (144, 264, 188, 323)},
{"bbox": (62, 37, 126, 64)}], (228, 193, 347, 225), {'bbox': (62, 37, 126, 64)}], (228, 193, 347, 225),
{"bbox": (168, 120, 263, 159)},), # y:方向最近的有两个,x:只有一个有重合 {'bbox': (168, 120, 263, 159)},), # y:方向最近的有两个,x:只有一个有重合
([{"bbox": (239, 115, 379, 167)}, ([{'bbox': (239, 115, 379, 167)},
{"bbox": (33, 237, 104, 262)}, {'bbox': (33, 237, 104, 262)},
{"bbox": (124, 288, 168, 325)}, {'bbox': (124, 288, 168, 325)},
{"bbox": (242, 291, 379, 340)}, {'bbox': (242, 291, 379, 340)},
{"bbox": (55, 117, 121, 154)}, {'bbox': (55, 117, 121, 154)},
{"bbox": (266, 183, 384, 217)}, ], (124, 288, 168, 325), {'bbox': (55, 117, 121, 154)}), {'bbox': (266, 183, 384, 217)}, ], (124, 288, 168, 325), {'bbox': (55, 117, 121, 154)}),
([{"bbox": (239, 115, 379, 167)}, ([{'bbox': (239, 115, 379, 167)},
{"bbox": (33, 237, 104, 262)}, {'bbox': (33, 237, 104, 262)},
{"bbox": (124, 288, 168, 325)}, {'bbox': (124, 288, 168, 325)},
{"bbox": (242, 291, 379, 340)}, {'bbox': (242, 291, 379, 340)},
{"bbox": (55, 117, 119, 154)}, {'bbox': (55, 117, 119, 154)},
{"bbox": (266, 183, 384, 217)}, ], (124, 288, 168, 325), None), # x没有重合 {'bbox': (266, 183, 384, 217)}, ], (124, 288, 168, 325), None), # x没有重合
([{"bbox": (80, 90, 249, 200)}, ([{'bbox': (80, 90, 249, 200)},
{"bbox": (183, 100, 240, 155)}, ], (183, 100, 240, 155), None), # 包含 {'bbox': (183, 100, 240, 155)}, ], (183, 100, 240, 155), None), # 包含
]) ])
def test_find_top_nearest_text_bbox(pymu_blocks: list, obj_box: tuple, target_boxs: dict) -> None: def test_find_top_nearest_text_bbox(pymu_blocks: list, obj_box: tuple, target_boxs: dict) -> None:
assert target_boxs == find_top_nearest_text_bbox(pymu_blocks, obj_box) assert target_boxs == find_top_nearest_text_bbox(pymu_blocks, obj_box)
# 寻找下方距离自己最近的box, x方向有重合,y方向最近的 # 寻找下方距离自己最近的box, x方向有重合,y方向最近的
@pytest.mark.parametrize("pymu_blocks, obj_box, target_boxs", [ @pytest.mark.parametrize('pymu_blocks, obj_box, target_boxs', [
([{"bbox": (165, 96, 300, 114)}, ([{'bbox': (165, 96, 300, 114)},
{"bbox": (11, 157, 139, 201)}, {'bbox': (11, 157, 139, 201)},
{"bbox": (124, 208, 265, 262)}, {'bbox': (124, 208, 265, 262)},
{"bbox": (124, 283, 248, 306)}, {'bbox': (124, 283, 248, 306)},
{"bbox": (39, 267, 84, 301)}, {'bbox': (39, 267, 84, 301)},
{"bbox": (36, 89, 114, 145)}, ], (165, 96, 300, 114), {"bbox": (124, 208, 265, 262)}), {'bbox': (36, 89, 114, 145)}, ], (165, 96, 300, 114), {'bbox': (124, 208, 265, 262)}),
([{"bbox": (187, 37, 303, 49)}, ([{'bbox': (187, 37, 303, 49)},
{"bbox": (2, 227, 90, 283)}, {'bbox': (2, 227, 90, 283)},
{"bbox": (158, 174, 200, 212)}, {'bbox': (158, 174, 200, 212)},
{"bbox": (259, 174, 324, 228)}, {'bbox': (259, 174, 324, 228)},
{"bbox": (205, 61, 316, 97)}, {'bbox': (205, 61, 316, 97)},
{"bbox": (295, 248, 374, 287)}, ], (205, 61, 316, 97), {"bbox": (259, 174, 324, 228)}), # y有两个最近的, x只有一个重合 {'bbox': (295, 248, 374, 287)}, ], (205, 61, 316, 97), {'bbox': (259, 174, 324, 228)}), # y有两个最近的, x只有一个重合
# ([{"bbox": (187, 37, 303, 49)}, # ([{"bbox": (187, 37, 303, 49)},
# {"bbox": (2, 227, 90, 283)}, # {"bbox": (2, 227, 90, 283)},
# {"bbox": (259, 174, 324, 228)}, # {"bbox": (259, 174, 324, 228)},
...@@ -410,31 +422,31 @@ def test_find_top_nearest_text_bbox(pymu_blocks: list, obj_box: tuple, target_bo ...@@ -410,31 +422,31 @@ def test_find_top_nearest_text_bbox(pymu_blocks: list, obj_box: tuple, target_bo
# {"bbox": (295, 248, 374, 287)}, # {"bbox": (295, 248, 374, 287)},
# {"bbox": (158, 174, 209, 212)}, ], (205, 61, 316, 97), # {"bbox": (158, 174, 209, 212)}, ], (205, 61, 316, 97),
# [{"bbox": (259, 174, 324, 228)}, {"bbox": (158, 174, 209, 212)}]), # x有重合,y有两个最近的 Error # [{"bbox": (259, 174, 324, 228)}, {"bbox": (158, 174, 209, 212)}]), # x有重合,y有两个最近的 Error
([{"bbox": (287, 132, 398, 191)}, ([{'bbox': (287, 132, 398, 191)},
{"bbox": (44, 141, 163, 188)}, {'bbox': (44, 141, 163, 188)},
{"bbox": (132, 191, 240, 241)}, {'bbox': (132, 191, 240, 241)},
{"bbox": (81, 25, 142, 67)}, {'bbox': (81, 25, 142, 67)},
{"bbox": (74, 297, 116, 314)}, {'bbox': (74, 297, 116, 314)},
{"bbox": (77, 84, 224, 107)}, ], (287, 132, 398, 191), None), # x没有重合 {'bbox': (77, 84, 224, 107)}, ], (287, 132, 398, 191), None), # x没有重合
([{"bbox": (80, 90, 249, 200)}, ([{'bbox': (80, 90, 249, 200)},
{"bbox": (183, 100, 240, 155)}, ], (183, 100, 240, 155), None), # 包含 {'bbox': (183, 100, 240, 155)}, ], (183, 100, 240, 155), None), # 包含
]) ])
def test_find_bottom_nearest_text_bbox(pymu_blocks: list, obj_box: tuple, target_boxs: dict) -> None: def test_find_bottom_nearest_text_bbox(pymu_blocks: list, obj_box: tuple, target_boxs: dict) -> None:
assert target_boxs == find_bottom_nearest_text_bbox(pymu_blocks, obj_box) assert target_boxs == find_bottom_nearest_text_bbox(pymu_blocks, obj_box)
# 寻找左侧距离自己最近的box, y方向有重叠,x方向最近 # 寻找左侧距离自己最近的box, y方向有重叠,x方向最近
@pytest.mark.parametrize("pymu_blocks, obj_box, target_boxs", [ @pytest.mark.parametrize('pymu_blocks, obj_box, target_boxs', [
([{"bbox": (80, 90, 249, 200)}, {"bbox": (183, 100, 240, 155)}], (183, 100, 240, 155), None), # 包含 ([{'bbox': (80, 90, 249, 200)}, {'bbox': (183, 100, 240, 155)}], (183, 100, 240, 155), None), # 包含
([{"bbox": (28, 90, 77, 126)}, {"bbox": (35, 84, 84, 120)}], (35, 84, 84, 120), None), # y:重叠,x:重叠大于2 ([{'bbox': (28, 90, 77, 126)}, {'bbox': (35, 84, 84, 120)}], (35, 84, 84, 120), None), # y:重叠,x:重叠大于2
([{"bbox": (28, 90, 77, 126)}, {"bbox": (75, 84, 134, 120)}], (75, 84, 134, 120), {"bbox": (28, 90, 77, 126)}), ([{'bbox': (28, 90, 77, 126)}, {'bbox': (75, 84, 134, 120)}], (75, 84, 134, 120), {'bbox': (28, 90, 77, 126)}),
# y:重叠,x:重叠小于等于2 # y:重叠,x:重叠小于等于2
([{"bbox": (239, 115, 379, 167)}, ([{'bbox': (239, 115, 379, 167)},
{"bbox": (33, 237, 104, 262)}, {'bbox': (33, 237, 104, 262)},
{"bbox": (124, 288, 168, 325)}, {'bbox': (124, 288, 168, 325)},
{"bbox": (242, 291, 379, 340)}, {'bbox': (242, 291, 379, 340)},
{"bbox": (55, 113, 161, 154)}, {'bbox': (55, 113, 161, 154)},
{"bbox": (266, 123, 384, 217)}], (266, 123, 384, 217), {"bbox": (55, 113, 161, 154)}), # y重叠,x left {'bbox': (266, 123, 384, 217)}], (266, 123, 384, 217), {'bbox': (55, 113, 161, 154)}), # y重叠,x left
# ([{"bbox": (136, 219, 268, 240)}, # ([{"bbox": (136, 219, 268, 240)},
# {"bbox": (169, 115, 268, 181)}, # {"bbox": (169, 115, 268, 181)},
# {"bbox": (33, 237, 104, 262)}, # {"bbox": (33, 237, 104, 262)},
...@@ -448,17 +460,17 @@ def test_find_left_nearest_text_bbox(pymu_blocks: list, obj_box: tuple, target_b ...@@ -448,17 +460,17 @@ def test_find_left_nearest_text_bbox(pymu_blocks: list, obj_box: tuple, target_b
# 寻找右侧距离自己最近的box, y方向有重叠,x方向最近 # 寻找右侧距离自己最近的box, y方向有重叠,x方向最近
@pytest.mark.parametrize("pymu_blocks, obj_box, target_boxs", [ @pytest.mark.parametrize('pymu_blocks, obj_box, target_boxs', [
([{"bbox": (80, 90, 249, 200)}, {"bbox": (183, 100, 240, 155)}], (183, 100, 240, 155), None), # 包含 ([{'bbox': (80, 90, 249, 200)}, {'bbox': (183, 100, 240, 155)}], (183, 100, 240, 155), None), # 包含
([{"bbox": (28, 90, 77, 126)}, {"bbox": (35, 84, 84, 120)}], (28, 90, 77, 126), None), # y:重叠,x:重叠大于2 ([{'bbox': (28, 90, 77, 126)}, {'bbox': (35, 84, 84, 120)}], (28, 90, 77, 126), None), # y:重叠,x:重叠大于2
([{"bbox": (28, 90, 77, 126)}, {"bbox": (75, 84, 134, 120)}], (28, 90, 77, 126), {"bbox": (75, 84, 134, 120)}), ([{'bbox': (28, 90, 77, 126)}, {'bbox': (75, 84, 134, 120)}], (28, 90, 77, 126), {'bbox': (75, 84, 134, 120)}),
# y:重叠,x:重叠小于等于2 # y:重叠,x:重叠小于等于2
([{"bbox": (239, 115, 379, 167)}, ([{'bbox': (239, 115, 379, 167)},
{"bbox": (33, 237, 104, 262)}, {'bbox': (33, 237, 104, 262)},
{"bbox": (124, 288, 168, 325)}, {'bbox': (124, 288, 168, 325)},
{"bbox": (242, 291, 379, 340)}, {'bbox': (242, 291, 379, 340)},
{"bbox": (55, 113, 161, 154)}, {'bbox': (55, 113, 161, 154)},
{"bbox": (266, 123, 384, 217)}], (55, 113, 161, 154), {"bbox": (239, 115, 379, 167)}), # y重叠,x right {'bbox': (266, 123, 384, 217)}], (55, 113, 161, 154), {'bbox': (239, 115, 379, 167)}), # y重叠,x right
# ([{"bbox": (169, 115, 298, 181)}, # ([{"bbox": (169, 115, 298, 181)},
# {"bbox": (169, 219, 268, 240)}, # {"bbox": (169, 219, 268, 240)},
# {"bbox": (33, 177, 104, 262)}, # {"bbox": (33, 177, 104, 262)},
...@@ -472,7 +484,7 @@ def test_find_right_nearest_text_bbox(pymu_blocks: list, obj_box: tuple, target_ ...@@ -472,7 +484,7 @@ def test_find_right_nearest_text_bbox(pymu_blocks: list, obj_box: tuple, target_
# 判断两个矩形框的相对位置关系 (left, right, bottom, top) # 判断两个矩形框的相对位置关系 (left, right, bottom, top)
@pytest.mark.parametrize("box1, box2, target_box", [ @pytest.mark.parametrize('box1, box2, target_box', [
# (None, None, "Error"), # Error # (None, None, "Error"), # Error
((80, 90, 249, 200), (183, 100, 240, 155), (False, False, False, False)), # 包含 ((80, 90, 249, 200), (183, 100, 240, 155), (False, False, False, False)), # 包含
# ((124, 81, 222, 173), (60, 221, 123, 358), (False, True, False, True)), # 分离,右上 Error # ((124, 81, 222, 173), (60, 221, 123, 358), (False, True, False, True)), # 分离,右上 Error
...@@ -494,7 +506,7 @@ def test_bbox_relative_pos(box1: tuple, box2: tuple, target_box: tuple) -> None: ...@@ -494,7 +506,7 @@ def test_bbox_relative_pos(box1: tuple, box2: tuple, target_box: tuple) -> None:
""" """
@pytest.mark.parametrize("box1, box2, target_num", [ @pytest.mark.parametrize('box1, box2, target_num', [
# (None, None, "Error"), # Error # (None, None, "Error"), # Error
((80, 90, 249, 200), (183, 100, 240, 155), 0.0), # 包含 ((80, 90, 249, 200), (183, 100, 240, 155), 0.0), # 包含
((142, 109, 238, 164), (134, 211, 224, 270), 47.0), # 分离,上 ((142, 109, 238, 164), (134, 211, 224, 270), 47.0), # 分离,上
...@@ -514,17 +526,17 @@ def test_bbox_relative_pos(box1: tuple, box2: tuple, target_box: tuple) -> None: ...@@ -514,17 +526,17 @@ def test_bbox_relative_pos(box1: tuple, box2: tuple, target_box: tuple) -> None:
def test_bbox_distance(box1: tuple, box2: tuple, target_num: float) -> None:
    """Check that ``bbox_distance(box1, box2)`` is within 1 of ``target_num``.

    Args:
        box1: First bounding box, supplied by the parametrize table.
        box2: Second bounding box, supplied by the parametrize table.
        target_num: Expected distance between the two boxes.
    """
    # Compare the absolute difference. Without abs(), any computed distance
    # larger than target_num yields a negative difference, which is always
    # "< 1", so the test could never fail on an over-estimate.
    assert abs(target_num - bbox_distance(box1, box2)) < 1
# Fetch the s3 configuration (ak, sk, endpoint) for a given bucket name.
@pytest.mark.skip(reason='skip')
def test_get_s3_config() -> None:
    """Compare get_s3_config() output with the values given in the environment.

    Reads the ``bucket_name`` and ``target_data`` environment variables;
    ``target_data`` is parsed with ``convert_string_to_list`` and compared
    against the result of ``get_s3_config`` converted to a list.
    """
    bucket = os.getenv('bucket_name')
    expected = convert_string_to_list(os.getenv('target_data'))
    actual = list(get_s3_config(bucket))
    assert expected == actual
def convert_string_to_list(s):
    """Split a comma-separated string into a list of whitespace-trimmed items.

    Leading and trailing single quotes around the whole string are removed
    before splitting, so ``"'a, b'"`` and ``"a, b"`` give the same result.

    Args:
        s: The comma-separated string, or ``None`` (for example the result
            of ``os.getenv`` for an unset variable).

    Returns:
        A list of the stripped items; an empty list when ``s`` is ``None``.
    """
    # os.getenv() returns None for an unset variable; treat that as
    # "no items" rather than failing with AttributeError on .strip().
    if s is None:
        return []
    return [item.strip() for item in s.strip("'").split(',')]
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment