Commit b9dfdea3 authored by myhloli

refactor(pdf_parse_union_core_v2): implement model initialization within class

Refactored model initialization to be handled by a singleton class so that model
instances are reused across calls, avoiding redundant initializations. Also commented
out noisy logger calls to keep logging behavior consistent.
parent b2790f6f
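
For orientation, a minimal usage sketch of the refactored API (not part of the commit): it assumes ModelSingleton and do_predict from the diff below are in scope, and example_boxes is a hypothetical input already scaled to the 0-1000 coordinate range expected by layoutreader.

import torch

# Every ModelSingleton() call returns the same instance; the layoutreader
# model is loaded once on the first get_model() call and cached afterwards.
model_manager = ModelSingleton()
model = model_manager.get_model("layoutreader")
assert model is ModelSingleton().get_model("layoutreader")  # reused, not re-initialized

# Hypothetical line boxes: [left, top, right, bottom], scaled to 0-1000.
example_boxes = [[100, 50, 900, 80], [100, 90, 880, 120]]

# Inference only, so gradients are disabled, as in the updated parse_page_core.
with torch.no_grad():
    orders = do_predict(example_boxes, model)  # reading-order indices for the boxes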
@@ -94,11 +94,39 @@ def replace_text_span(pymu_spans, ocr_spans):
     return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans
 
 
-def do_predict(boxes: List[List[int]]) -> List[int]:
+def model_init(model_name: str):
     from transformers import LayoutLMv3ForTokenClassification
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if model_name == "layoutreader":
+        model = (
+            LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
+            # .bfloat16()
+            .to(device)
+            .eval()
+        )
+    else:
+        logger.error("model name not allow")
+        exit(1)
+    return model
+
+
+class ModelSingleton:
+    _instance = None
+    _models = {}
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def get_model(self, model_name: str):
+        if model_name not in self._models:
+            self._models[model_name] = model_init(model_name=model_name)
+        return self._models[model_name]
+
+
+def do_predict(boxes: List[List[int]], model) -> List[int]:
     from magic_pdf.v3.helpers import prepare_inputs, boxes2inputs, parse_logits
-    model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
-    # model.to("cuda")
     inputs = boxes2inputs(boxes)
     inputs = prepare_inputs(inputs, model)
     logits = model(**inputs).logits.cpu().squeeze(0)
@@ -184,7 +212,7 @@ def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter,
     x_scale = 1000.0 / page_w
     y_scale = 1000.0 / page_h
     boxes = []
-    logger.info(f"Scale: {x_scale}, {y_scale}, Boxes len: {len(page_line_list)}")
+    # logger.info(f"Scale: {x_scale}, {y_scale}, Boxes len: {len(page_line_list)}")
     for left, top, right, bottom in page_line_list:
         left = round(left * x_scale)
         top = round(top * y_scale)
@@ -194,9 +222,12 @@ def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter,
             1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0
         ), f"Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}"
         boxes.append([left, top, right, bottom])
+    model_manager = ModelSingleton()
+    model = model_manager.get_model("layoutreader")
     layoutreader_start = time.time()
-    orders = do_predict(boxes)
-    logger.info(f"layoutreader cost time{time.time() - layoutreader_start}")
+    with torch.no_grad():
+        orders = do_predict(boxes, model)
+    # logger.info(f"layoutreader cost time{time.time() - layoutreader_start}")
     sorted_bboxes = [page_line_list[i] for i in orders]
     '''Use the median of each block's lines to determine the block ordering'''