Commit a3358878 authored by liukaiwen's avatar liukaiwen

feat: merge formula update

parent 763688c0
This diff is collapsed.
...@@ -5,6 +5,7 @@ import time ...@@ -5,6 +5,7 @@ import time
from magic_pdf.libs.Constants import * from magic_pdf.libs.Constants import *
from magic_pdf.libs.clean_memory import clean_memory from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.model.model_list import AtomicModel from magic_pdf.model.model_list import AtomicModel
from .mfr_cudagraph import GraphRunner
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
...@@ -67,6 +68,11 @@ def mfr_model_init(weight_dir, cfg_path, _device_='cpu'): ...@@ -67,6 +68,11 @@ def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
model = task.build_model(cfg) model = task.build_model(cfg)
model.to(_device_) model.to(_device_)
model.eval() model.eval()
model = model.to(_device_)
if 'cuda' in _device_:
decoder_runner = GraphRunner(model.model.model.decoder.model.decoder, max_batchs=128, max_kvlens=256,
device=_device_)
model.model.model.decoder.model.decoder = decoder_runner
vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval) vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
mfr_transform = transforms.Compose([vis_processor, ]) mfr_transform = transforms.Compose([vis_processor, ])
return [model, mfr_transform] return [model, mfr_transform]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment