Unverified Commit 8786d208 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub

Merge pull request #698 from myhloli/dev

feat(layoutreader): support local model directory and improve model loading
parents 3fb0494b 0b2b0cef
......@@ -395,6 +395,7 @@ This project currently uses PyMuPDF to achieve advanced functionality. However,
- [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy)
- [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)
- [PyMuPDF](https://github.com/pymupdf/PyMuPDF)
- [layoutreader](https://github.com/ppaanngggg/layoutreader)
- [fast-langdetect](https://github.com/LlmKira/fast-langdetect)
- [pdfminer.six](https://github.com/pdfminer/pdfminer.six)
......
......@@ -400,6 +400,7 @@ TODO
- [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy)
- [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)
- [PyMuPDF](https://github.com/pymupdf/PyMuPDF)
- [layoutreader](https://github.com/ppaanngggg/layoutreader)
- [fast-langdetect](https://github.com/LlmKira/fast-langdetect)
- [pdfminer.six](https://github.com/pdfminer/pdfminer.six)
......
# use modelscope sdk download models
from modelscope import snapshot_download
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
print(f"model dir is: {model_dir}/models")
from huggingface_hub import snapshot_download
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
layoutreader_model_dir = snapshot_download('hantian/layoutreader')
print(f"model dir is: {model_dir}/models")
......@@ -38,6 +38,13 @@ python脚本执行完毕后,会输出模型下载目录
如此前通过 git lfs 下载过模型文件,可以进入到之前的下载目录中,通过`git pull`命令更新模型。
> 0.9.x及以后版本由于新增layout排序模型,且该模型和此前的模型不在同一仓库,不能通过`git pull`命令更新,需要单独下载。
>
>```
>from modelscope import snapshot_download
>snapshot_download('ppaanngggg/layoutreader')
>```
## 2. 通过 Hugging Face 或 Model Scope 下载过模型
如此前通过 HuggingFace 或 Model Scope 下载过模型,可以重复执行此前的模型下载python脚本,将会自动将模型目录更新到最新版本。
......@@ -4,6 +4,7 @@
"bucket-name-2":["ak", "sk", "endpoint"]
},
"models-dir":"/tmp/models",
"layoutreader-model-dir":"/tmp/layoutreader",
"device-mode":"cpu",
"table-config": {
"model": "TableMaster",
......
......@@ -67,6 +67,18 @@ def get_local_models_dir():
return models_dir
def get_local_layoutreader_model_dir():
config = read_config()
layoutreader_model_dir = config.get("layoutreader-model-dir")
if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir):
home_dir = os.path.expanduser("~")
layoutreader_at_modelscope_dir_path = os.path.join(home_dir, ".cache/modelscope/hub/ppaanngggg/layoutreader")
logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default")
return layoutreader_at_modelscope_dir_path
else:
return layoutreader_model_dir
def get_device():
config = read_config()
device = config.get("device-mode")
......
import os
import statistics
import time
......@@ -9,6 +10,7 @@ import torch
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.commons import fitz, get_delta_time
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.drop_reason import DropReason
from magic_pdf.libs.hash_utils import compute_md5
......@@ -95,7 +97,7 @@ def replace_text_span(pymu_spans, ocr_spans):
return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans
def model_init(model_name: str, local_path=None):
def model_init(model_name: str):
from transformers import LayoutLMv3ForTokenClassification
if torch.cuda.is_available():
device = torch.device("cuda")
......@@ -108,9 +110,13 @@ def model_init(model_name: str, local_path=None):
supports_bfloat16 = False
if model_name == "layoutreader":
if local_path:
model = LayoutLMv3ForTokenClassification.from_pretrained(local_path)
# 检测modelscope的缓存目录是否存在
layoutreader_model_dir = get_local_layoutreader_model_dir()
if os.path.exists(layoutreader_model_dir):
model = LayoutLMv3ForTokenClassification.from_pretrained(layoutreader_model_dir)
else:
logger.warning(
f"local layoutreader model not exists, use online model from huggingface")
model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
# 检查设备是否支持 bfloat16
if supports_bfloat16:
......@@ -131,12 +137,9 @@ class ModelSingleton:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(self, model_name: str, local_path=None):
def get_model(self, model_name: str):
if model_name not in self._models:
if local_path:
self._models[model_name] = model_init(model_name=model_name, local_path=local_path)
else:
self._models[model_name] = model_init(model_name=model_name)
self._models[model_name] = model_init(model_name=model_name)
return self._models[model_name]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment