Commit ded2818a authored by myhloli

feat(layoutreader): support local model directory and improve model loading

- Add function to get local LayoutReader model directory
- Check and use local model directory if available
- Fall back to online model if local directory not found
- Update model initialization to support local path
- Refactor model loading in singleton class
parent 3fb0494b
# use modelscope sdk download models # use modelscope sdk download models
from modelscope import snapshot_download from modelscope import snapshot_download
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit') model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
print(f"model dir is: {model_dir}/models") print(f"model dir is: {model_dir}/models")
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit') model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
layoutreader_model_dir = snapshot_download('hantian/layoutreader')
print(f"model dir is: {model_dir}/models") print(f"model dir is: {model_dir}/models")
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
"bucket-name-2":["ak", "sk", "endpoint"] "bucket-name-2":["ak", "sk", "endpoint"]
}, },
"models-dir":"/tmp/models", "models-dir":"/tmp/models",
"layoutreader-model-dir":"/tmp/layoutreader",
"device-mode":"cpu", "device-mode":"cpu",
"table-config": { "table-config": {
"model": "TableMaster", "model": "TableMaster",
......
...@@ -67,6 +67,18 @@ def get_local_models_dir(): ...@@ -67,6 +67,18 @@ def get_local_models_dir():
return models_dir return models_dir
def get_local_layoutreader_model_dir():
    """Return the directory to load the LayoutReader model from.

    Reads the ``layoutreader-model-dir`` entry from the configuration
    returned by ``read_config()``. When that entry is set and the path
    exists on disk, it is returned unchanged. Otherwise a warning is
    logged and the modelscope cache location under the current user's
    home directory is returned.

    Note that the fallback path is returned without checking that it
    exists; callers that need the directory to be present must verify
    it themselves.
    """
    configured_dir = read_config().get("layoutreader-model-dir")

    # Fast path: the configured directory is usable as-is.
    if configured_dir is not None and os.path.exists(configured_dir):
        return configured_dir

    # Missing or stale setting: point at the modelscope hub cache instead.
    fallback_dir = os.path.join(
        os.path.expanduser("~"),
        ".cache/modelscope/hub/ppaanngggg/layoutreader",
    )
    logger.warning(f"'layoutreader-model-dir' not exists, use {fallback_dir} as default")
    return fallback_dir
def get_device(): def get_device():
config = read_config() config = read_config()
device = config.get("device-mode") device = config.get("device-mode")
......
import os
import statistics import statistics
import time import time
...@@ -9,6 +10,7 @@ import torch ...@@ -9,6 +10,7 @@ import torch
from magic_pdf.libs.clean_memory import clean_memory from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.commons import fitz, get_delta_time from magic_pdf.libs.commons import fitz, get_delta_time
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
from magic_pdf.libs.convert_utils import dict_to_list from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.drop_reason import DropReason from magic_pdf.libs.drop_reason import DropReason
from magic_pdf.libs.hash_utils import compute_md5 from magic_pdf.libs.hash_utils import compute_md5
...@@ -95,7 +97,7 @@ def replace_text_span(pymu_spans, ocr_spans): ...@@ -95,7 +97,7 @@ def replace_text_span(pymu_spans, ocr_spans):
return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans
def model_init(model_name: str, local_path=None): def model_init(model_name: str):
from transformers import LayoutLMv3ForTokenClassification from transformers import LayoutLMv3ForTokenClassification
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
...@@ -108,9 +110,13 @@ def model_init(model_name: str, local_path=None): ...@@ -108,9 +110,13 @@ def model_init(model_name: str, local_path=None):
supports_bfloat16 = False supports_bfloat16 = False
if model_name == "layoutreader": if model_name == "layoutreader":
if local_path: # 检测modelscope的缓存目录是否存在
model = LayoutLMv3ForTokenClassification.from_pretrained(local_path) layoutreader_model_dir = get_local_layoutreader_model_dir()
if os.path.exists(layoutreader_model_dir):
model = LayoutLMv3ForTokenClassification.from_pretrained(layoutreader_model_dir)
else: else:
logger.warning(
f"local layoutreader model not exists, use online model from huggingface")
model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader") model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
# 检查设备是否支持 bfloat16 # 检查设备是否支持 bfloat16
if supports_bfloat16: if supports_bfloat16:
...@@ -131,11 +137,8 @@ class ModelSingleton: ...@@ -131,11 +137,8 @@ class ModelSingleton:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
return cls._instance return cls._instance
def get_model(self, model_name: str, local_path=None): def get_model(self, model_name: str):
if model_name not in self._models: if model_name not in self._models:
if local_path:
self._models[model_name] = model_init(model_name=model_name, local_path=local_path)
else:
self._models[model_name] = model_init(model_name=model_name) self._models[model_name] = model_init(model_name=model_name)
return self._models[model_name] return self._models[model_name]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment