Commit 695b3579 authored by myhloli's avatar myhloli

feat(config-reader): add models-dir and device-mode configurations

Add new configuration options for custom model directories and device mode selection. This allows users to specify the directory where models are stored
and choose between CPU and GPU modes for model inference. The configurations
are read from a JSON file and can be easily extended to support additional
options in the future.
parent 45e7fbd2
...@@ -3,5 +3,7 @@ ...@@ -3,5 +3,7 @@
"bucket-name-1":["ak", "sk", "endpoint"], "bucket-name-1":["ak", "sk", "endpoint"],
"bucket-name-2":["ak", "sk", "endpoint"] "bucket-name-2":["ak", "sk", "endpoint"]
}, },
"temp-output-dir":"/tmp" "temp-output-dir":"/tmp",
"models-dir":"/tmp/models",
"device-mode":"cpu"
} }
\ No newline at end of file
...@@ -33,13 +33,15 @@ from magic_pdf.libs.draw_bbox import draw_layout_bbox, draw_span_bbox ...@@ -33,13 +33,15 @@ from magic_pdf.libs.draw_bbox import draw_layout_bbox, draw_span_bbox
from magic_pdf.pipe.UNIPipe import UNIPipe from magic_pdf.pipe.UNIPipe import UNIPipe
from magic_pdf.pipe.OCRPipe import OCRPipe from magic_pdf.pipe.OCRPipe import OCRPipe
from magic_pdf.pipe.TXTPipe import TXTPipe from magic_pdf.pipe.TXTPipe import TXTPipe
from magic_pdf.libs.config_reader import get_s3_config
from magic_pdf.libs.path_utils import ( from magic_pdf.libs.path_utils import (
parse_s3path, parse_s3path,
parse_s3_range_params, parse_s3_range_params,
remove_non_official_s3_args, remove_non_official_s3_args,
) )
from magic_pdf.libs.config_reader import get_local_dir from magic_pdf.libs.config_reader import (
get_local_dir,
get_s3_config,
)
from magic_pdf.rw.S3ReaderWriter import S3ReaderWriter from magic_pdf.rw.S3ReaderWriter import S3ReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
......
...@@ -59,5 +59,15 @@ def get_local_dir(): ...@@ -59,5 +59,15 @@ def get_local_dir():
return config.get("temp-output-dir", "/tmp") return config.get("temp-output-dir", "/tmp")
def get_local_models_dir():
    """Return the models directory from the configuration.

    Looks up the "models-dir" key in the dict returned by ``read_config()``
    and falls back to "/tmp/models" when the key is absent.
    """
    return read_config().get("models-dir", "/tmp/models")
def get_device():
    """Return the device mode from the configuration.

    Looks up the "device-mode" key in the dict returned by ``read_config()``
    and falls back to "cpu" when the key is absent.
    """
    return read_config().get("device-mode", "cpu")
if __name__ == "__main__": if __name__ == "__main__":
ak, sk, endpoint = get_s3_config("llm-raw") ak, sk, endpoint = get_s3_config("llm-raw")
__use_inside_model__ = False __use_inside_model__ = True
__model_mode__ = "lite" __model_mode__ = "full"
...@@ -3,6 +3,8 @@ import time ...@@ -3,6 +3,8 @@ import time
import fitz import fitz
import numpy as np import numpy as np
from loguru import logger from loguru import logger
from magic_pdf.libs.config_reader import get_local_models_dir, get_device
from magic_pdf.model.model_list import MODEL from magic_pdf.model.model_list import MODEL
import magic_pdf.model as model_config import magic_pdf.model as model_config
...@@ -61,7 +63,10 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False): ...@@ -61,7 +63,10 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False):
custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log) custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log)
elif model == MODEL.PEK: elif model == MODEL.PEK:
from magic_pdf.model.pdf_extract_kit import CustomPEKModel from magic_pdf.model.pdf_extract_kit import CustomPEKModel
custom_model = CustomPEKModel(ocr=ocr, show_log=show_log) # 从配置文件读取model-dir和device
local_models_dir = get_local_models_dir()
device = get_device()
custom_model = CustomPEKModel(ocr=ocr, show_log=show_log, models_dir=local_models_dir, device=device)
else: else:
logger.error("Not allow model_name!") logger.error("Not allow model_name!")
exit(1) exit(1)
......
...@@ -7,6 +7,7 @@ import yaml ...@@ -7,6 +7,7 @@ import yaml
from PIL import Image from PIL import Image
from ultralytics import YOLO from ultralytics import YOLO
from loguru import logger from loguru import logger
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
from unimernet.common.config import Config from unimernet.common.config import Config
import unimernet.tasks as tasks import unimernet.tasks as tasks
...@@ -84,23 +85,26 @@ class CustomPEKModel: ...@@ -84,23 +85,26 @@ class CustomPEKModel:
) )
assert self.apply_layout, "DocAnalysis must contain layout model." assert self.apply_layout, "DocAnalysis must contain layout model."
# 初始化解析方案 # 初始化解析方案
self.device = self.configs["config"]["device"] self.device = kwargs.get("device", self.configs["config"]["device"])
logger.info("using device: {}".format(self.device)) logger.info("using device: {}".format(self.device))
models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
# 初始化layout模型 # 初始化layout模型
self.layout_model = layout_model_init( self.layout_model = layout_model_init(
os.path.join(root_dir, self.configs['weights']['layout']), os.path.join(models_dir, self.configs['weights']['layout']),
os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml"), os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml"),
device=self.device device=self.device
) )
# 初始化公式识别 # 初始化公式识别
if self.apply_formula: if self.apply_formula:
# 初始化公式检测模型 # 初始化公式检测模型
self.mfd_model = YOLO(model=str(os.path.join(root_dir, self.configs["weights"]["mfd"]))) self.mfd_model = YOLO(model=str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
# 初始化公式解析模型 # 初始化公式解析模型
mfr_config_path = os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml') mfr_config_path = os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml')
self.mfr_model, mfr_vis_processors = mfr_model_init( self.mfr_model, mfr_vis_processors = mfr_model_init(
os.path.join(root_dir, self.configs["weights"]["mfr"]), mfr_config_path, os.path.join(models_dir, self.configs["weights"]["mfr"]),
device=self.device) mfr_config_path,
device=self.device
)
self.mfr_transform = transforms.Compose([mfr_vis_processors, ]) self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
# 初始化ocr # 初始化ocr
if self.apply_ocr: if self.apply_ocr:
......
...@@ -4,6 +4,6 @@ config: ...@@ -4,6 +4,6 @@ config:
formula: True formula: True
weights: weights:
layout: resources/models/Layout/model_final.pth layout: Layout/model_final.pth
mfd: resources/models/MFD/weights.pt mfd: MFD/weights.pt
mfr: resources/models/MFR/UniMERNet mfr: MFR/UniMERNet
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment