Unverified Commit aac91094 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub

refactor(pdf_extract_kit): implement singleton pattern for atomic models (#533)

Refactor the pdf_extract_kit module to utilize a singleton pattern when initializing
atomic models. This change ensures that atomic models are instantiated at most once,
optimizing memory usage and reducing redundant initialization steps. The AtomModelSingleton
class now manages the instantiation and retrieval of atomic models, improving the
overall structure and efficiency of the codebase.
parent 03469909
class MODEL: class MODEL:
Paddle = "pp_structure_v2" Paddle = "pp_structure_v2"
PEK = "pdf_extract_kit" PEK = "pdf_extract_kit"
class AtomicModel:
Layout = "layout"
MFD = "mfd"
MFR = "mfr"
OCR = "ocr"
Table = "table"
...@@ -3,6 +3,7 @@ import os ...@@ -3,6 +3,7 @@ import os
import time import time
from magic_pdf.libs.Constants import * from magic_pdf.libs.Constants import *
from magic_pdf.model.model_list import AtomicModel
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
try: try:
...@@ -64,7 +65,8 @@ def mfr_model_init(weight_dir, cfg_path, _device_='cpu'): ...@@ -64,7 +65,8 @@ def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
model = task.build_model(cfg) model = task.build_model(cfg)
model = model.to(_device_) model = model.to(_device_)
vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval) vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
return model, vis_processor mfr_transform = transforms.Compose([vis_processor, ])
return [model, mfr_transform]
def layout_model_init(weight, config_file, device): def layout_model_init(weight, config_file, device):
...@@ -72,6 +74,11 @@ def layout_model_init(weight, config_file, device): ...@@ -72,6 +74,11 @@ def layout_model_init(weight, config_file, device):
return model return model
def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3):
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
return model
class MathDataset(Dataset): class MathDataset(Dataset):
def __init__(self, image_paths, transform=None): def __init__(self, image_paths, transform=None):
self.image_paths = image_paths self.image_paths = image_paths
...@@ -91,6 +98,58 @@ class MathDataset(Dataset): ...@@ -91,6 +98,58 @@ class MathDataset(Dataset):
return image return image
class AtomModelSingleton:
_instance = None
_models = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_atom_model(self, atom_model_name: str, **kwargs):
if atom_model_name not in self._models:
self._models[atom_model_name] = atom_model_init(model_name=atom_model_name, **kwargs)
return self._models[atom_model_name]
def atom_model_init(model_name: str, **kwargs):
if model_name == AtomicModel.Layout:
atom_model = layout_model_init(
kwargs.get("layout_weights"),
kwargs.get("layout_config_file"),
kwargs.get("device")
)
elif model_name == AtomicModel.MFD:
atom_model = mfd_model_init(
kwargs.get("mfd_weights")
)
elif model_name == AtomicModel.MFR:
atom_model = mfr_model_init(
kwargs.get("mfr_weight_dir"),
kwargs.get("mfr_cfg_path"),
kwargs.get("device")
)
elif model_name == AtomicModel.OCR:
atom_model = ocr_model_init(
kwargs.get("ocr_show_log"),
kwargs.get("det_db_box_thresh")
)
elif model_name == AtomicModel.Table:
atom_model = table_model_init(
kwargs.get("table_model_type"),
kwargs.get("table_model_path"),
kwargs.get("table_max_time"),
kwargs.get("device")
)
else:
logger.error("model name not allow")
exit(1)
return atom_model
class CustomPEKModel: class CustomPEKModel:
def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs): def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
...@@ -130,32 +189,60 @@ class CustomPEKModel: ...@@ -130,32 +189,60 @@ class CustomPEKModel:
models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models")) models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
logger.info("using models_dir: {}".format(models_dir)) logger.info("using models_dir: {}".format(models_dir))
atom_model_manager = AtomModelSingleton()
# 初始化公式识别 # 初始化公式识别
if self.apply_formula: if self.apply_formula:
# 初始化公式检测模型 # 初始化公式检测模型
self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"]))) # self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
self.mfd_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFD,
mfd_weights=str(os.path.join(models_dir, self.configs["weights"]["mfd"]))
)
# 初始化公式解析模型 # 初始化公式解析模型
mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"])) mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml")) mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device) # self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
self.mfr_transform = transforms.Compose([mfr_vis_processors, ]) # self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFR,
mfr_weight_dir=mfr_weight_dir,
mfr_cfg_path=mfr_cfg_path,
device=self.device
)
# 初始化layout模型 # 初始化layout模型
self.layout_model = Layoutlmv3_Predictor( # self.layout_model = Layoutlmv3_Predictor(
str(os.path.join(models_dir, self.configs['weights']['layout'])), # str(os.path.join(models_dir, self.configs['weights']['layout'])),
str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")), # str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
# device=self.device
# )
self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
layout_weights=str(os.path.join(models_dir, self.configs['weights']['layout'])),
layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
device=self.device device=self.device
) )
# 初始化ocr # 初始化ocr
if self.apply_ocr: if self.apply_ocr:
self.ocr_model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=0.3) # self.ocr_model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=0.3)
self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
ocr_show_log=show_log,
det_db_box_thresh=0.3
)
# init table model # init table model
if self.apply_table: if self.apply_table:
table_model_dir = self.configs["weights"][self.table_model_type] table_model_dir = self.configs["weights"][self.table_model_type]
self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)), # self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
max_time=self.table_max_time, _device_=self.device) # max_time=self.table_max_time, _device_=self.device)
self.table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Table,
table_model_type=self.table_model_type,
table_model_path=str(os.path.join(models_dir, table_model_dir)),
table_max_time=self.table_max_time,
device=self.device
)
logger.info('DocAnalysis init done!') logger.info('DocAnalysis init done!')
def __call__(self, image): def __call__(self, image):
...@@ -291,11 +378,11 @@ class CustomPEKModel: ...@@ -291,11 +378,11 @@ class CustomPEKModel:
logger.info("------------------table recognition processing begins-----------------") logger.info("------------------table recognition processing begins-----------------")
latex_code = None latex_code = None
html_code = None html_code = None
with torch.no_grad(): if self.table_model_type == STRUCT_EQTABLE:
if self.table_model_type == STRUCT_EQTABLE: with torch.no_grad():
latex_code = self.table_model.image2latex(new_image)[0] latex_code = self.table_model.image2latex(new_image)[0]
else: else:
html_code = self.table_model.img2html(new_image) html_code = self.table_model.img2html(new_image)
run_time = time.time() - single_table_start_time run_time = time.time() - single_table_start_time
logger.info(f"------------table recognition processing ends within {run_time}s-----") logger.info(f"------------table recognition processing ends within {run_time}s-----")
if run_time > self.table_max_time: if run_time > self.table_max_time:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment