Commit 2145a8b6 authored by myhloli's avatar myhloli

fix(pdf_parse): handle blocks without lines and enable bf16 on compatible devices

Blocks without lines are now correctly indexed even when they contain textual content rendered
as images. The sorting logic has been updated to accommodate this scenario. Additionally, the
LayoutLMv3 model initialization has been enhanced to utilize bfloat16 precision on devices that
support it, offering potential performance benefits on supported hardware.
parent 177ab08e
......@@ -94,16 +94,27 @@ def replace_text_span(pymu_spans, ocr_spans):
return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans
def model_init(model_name: str):
def model_init(model_name: str, local_path=None):
from transformers import LayoutLMv3ForTokenClassification
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
device = torch.device("cuda")
if torch.cuda.is_bf16_supported():
supports_bfloat16 = True
else:
supports_bfloat16 = False
else:
device = torch.device("cpu")
supports_bfloat16 = False
if model_name == "layoutreader":
model = (
LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
# .bfloat16()
.to(device)
.eval()
)
if local_path:
model = LayoutLMv3ForTokenClassification.from_pretrained(local_path)
else:
model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
# 检查设备是否支持 bfloat16
if supports_bfloat16:
model.bfloat16()
model.to(device).eval()
else:
logger.error("model name not allow")
exit(1)
......@@ -119,9 +130,12 @@ class ModelSingleton:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(self, model_name: str):
def get_model(self, model_name: str, local_path=None):
if model_name not in self._models:
self._models[model_name] = model_init(model_name=model_name)
if local_path:
self._models[model_name] = model_init(model_name=model_name, local_path=local_path)
else:
self._models[model_name] = model_init(model_name=model_name)
return self._models[model_name]
......@@ -134,13 +148,11 @@ def do_predict(boxes: List[List[int]], model) -> List[int]:
def cal_block_index(fix_blocks, sorted_bboxes):
block_without_lines = []
for block in fix_blocks:
if block['type'] in ['text', 'title', 'interline_equation']:
line_index_list = []
if len(block['lines']) == 0:
block_without_lines.append(block)
continue
block['index'] = sorted_bboxes.index(block['bbox'])
else:
for line in block['lines']:
line['index'] = sorted_bboxes.index(line['bbox'])
......@@ -151,10 +163,6 @@ def cal_block_index(fix_blocks, sorted_bboxes):
elif block['type'] in ['table', 'image']:
block['index'] = sorted_bboxes.index(block['bbox'])
'''移除没有line的block'''
for block in block_without_lines:
fix_blocks.remove(block)
return fix_blocks
......@@ -162,9 +170,13 @@ def sort_lines_by_model(fix_blocks, page_w, page_h):
page_line_list = []
for block in fix_blocks:
if block['type'] in ['text', 'title', 'interline_equation']:
for line in block['lines']:
bbox = line['bbox']
if len(block['lines']) == 0: # 没有line的block(一般是图片形式的文本块),就直接用block的bbox来排序
bbox = block['bbox']
page_line_list.append(bbox)
else:
for line in block['lines']:
bbox = line['bbox']
page_line_list.append(bbox)
elif block['type'] in ['table', 'image']: # 简单的把表和图都当成一个line处理
bbox = block['bbox']
page_line_list.append(bbox)
......@@ -175,6 +187,23 @@ def sort_lines_by_model(fix_blocks, page_w, page_h):
boxes = []
# logger.info(f"Scale: {x_scale}, {y_scale}, Boxes len: {len(page_line_list)}")
for left, top, right, bottom in page_line_list:
if left < 0:
logger.warning(
f"left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
left = 0
if right > page_w:
logger.warning(
f"right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
right = page_w
if top < 0:
logger.warning(
f"top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
top = 0
if bottom > page_h:
logger.warning(
f"bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
bottom = page_h
left = round(left * x_scale)
top = round(top * y_scale)
right = round(right * x_scale)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment