Commit 4b372f3f authored by myhloli's avatar myhloli

feat(ocr): pass language parameter for custom model init

Pass the `lang` parameter to `custom_model_init` in `doc_analyze` to support language-specific OCR configurations. This enhancement allows the use of language information to improve OCR accuracy when processing PDFs.
parent f07c2673
......@@ -57,14 +57,14 @@ class ModelSingleton:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(self, ocr: bool, show_log: bool):
key = (ocr, show_log)
def get_model(self, ocr: bool, show_log: bool, lang):
key = (ocr, show_log, lang)
if key not in self._models:
self._models[key] = custom_model_init(ocr=ocr, show_log=show_log)
self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang)
return self._models[key]
def custom_model_init(ocr: bool = False, show_log: bool = False):
def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
model = None
if model_config.__model_mode__ == "lite":
......@@ -78,7 +78,7 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
model_init_start = time.time()
if model == MODEL.Paddle:
from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log)
custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
elif model == MODEL.PEK:
from magic_pdf.model.pdf_extract_kit import CustomPEKModel
# 从配置文件读取model-dir和device
......@@ -89,7 +89,9 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
"show_log": show_log,
"models_dir": local_models_dir,
"device": device,
"table_config": table_config}
"table_config": table_config,
"lang": lang,
}
custom_model = CustomPEKModel(**model_input)
else:
logger.error("Not allow model_name!")
......@@ -104,10 +106,10 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
start_page_id=0, end_page_id=None):
start_page_id=0, end_page_id=None, lang=None):
model_manager = ModelSingleton()
custom_model = model_manager.get_model(ocr, show_log)
custom_model = model_manager.get_model(ocr, show_log, lang)
images = load_images_from_pdf(pdf_bytes)
......
......@@ -74,8 +74,11 @@ def layout_model_init(weight, config_file, device):
return model
def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3):
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None):
if lang is not None:
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang)
else:
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
return model
......@@ -134,7 +137,8 @@ def atom_model_init(model_name: str, **kwargs):
elif model_name == AtomicModel.OCR:
atom_model = ocr_model_init(
kwargs.get("ocr_show_log"),
kwargs.get("det_db_box_thresh")
kwargs.get("det_db_box_thresh"),
kwargs.get("lang")
)
elif model_name == AtomicModel.Table:
atom_model = table_model_init(
......@@ -177,9 +181,10 @@ class CustomPEKModel:
self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
self.table_model_type = self.table_config.get("model", TABLE_MASTER)
self.apply_ocr = ocr
self.lang = kwargs.get("lang", None)
logger.info(
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}".format(
self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}, lang: {}".format(
self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table, self.lang
)
)
assert self.apply_layout, "DocAnalysis must contain layout model."
......@@ -229,7 +234,8 @@ class CustomPEKModel:
self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
ocr_show_log=show_log,
det_db_box_thresh=0.3
det_db_box_thresh=0.3,
lang=self.lang
)
# init table model
if self.apply_table:
......
......@@ -18,8 +18,11 @@ def region_to_bbox(region):
class CustomPaddleModel:
def __init__(self, ocr: bool = False, show_log: bool = False):
self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
def __init__(self, ocr: bool = False, show_log: bool = False, lang=None):
if lang is not None:
self.model = PPStructure(table=False, ocr=ocr, show_log=show_log, lang=lang)
else:
self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
def __call__(self, img):
try:
......
......@@ -17,7 +17,7 @@ class AbsPipe(ABC):
PIP_TXT = "txt"
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None):
start_page_id=0, end_page_id=None, lang=None):
self.pdf_bytes = pdf_bytes
self.model_list = model_list
self.image_writer = image_writer
......@@ -25,6 +25,7 @@ class AbsPipe(ABC):
self.is_debug = is_debug
self.start_page_id = start_page_id
self.end_page_id = end_page_id
self.lang = lang
def get_compress_pdf_mid_data(self):
return JsonCompressor.compress_json(self.pdf_mid_data)
......
......@@ -10,15 +10,16 @@ from magic_pdf.user_api import parse_ocr_pdf
class OCRPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id)
start_page_id=0, end_page_id=None, lang=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang)
def pipe_classify(self):
pass
def pipe_analyze(self):
self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
start_page_id=self.start_page_id, end_page_id=self.end_page_id)
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
def pipe_parse(self):
self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
......
......@@ -11,15 +11,16 @@ from magic_pdf.user_api import parse_txt_pdf
class TXTPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id)
start_page_id=0, end_page_id=None, lang=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang)
def pipe_classify(self):
pass
def pipe_analyze(self):
self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
start_page_id=self.start_page_id, end_page_id=self.end_page_id)
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
def pipe_parse(self):
self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
......
......@@ -14,9 +14,9 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
class UNIPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None):
start_page_id=0, end_page_id=None, lang=None):
self.pdf_type = jso_useful_key["_pdf_type"]
super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id)
super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id, lang)
if len(self.model_list) == 0:
self.input_model_is_empty = True
else:
......@@ -28,10 +28,12 @@ class UNIPipe(AbsPipe):
def pipe_analyze(self):
if self.pdf_type == self.PIP_TXT:
self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
start_page_id=self.start_page_id, end_page_id=self.end_page_id)
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
elif self.pdf_type == self.PIP_OCR:
self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
start_page_id=self.start_page_id, end_page_id=self.end_page_id)
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
def pipe_parse(self):
if self.pdf_type == self.PIP_TXT:
......
......@@ -44,6 +44,18 @@ auto: automatically choose the best method for parsing pdf from ocr and txt.
without method specified, auto will be used by default.""",
default='auto',
)
@click.option(
'-l',
'--lang',
'lang',
type=str,
help="""
Input the languages in the pdf (if known) to improve OCR accuracy. Optional.
You should input "Abbreviation" with language form url:
https://paddlepaddle.github.io/PaddleOCR/en/ppocr/blog/multi_languages.html#5-support-languages-and-abbreviations
""",
default=None,
)
@click.option(
'-d',
'--debug',
......@@ -68,7 +80,7 @@ without method specified, auto will be used by default.""",
help='The ending page for PDF parsing, beginning from 0.',
default=None,
)
def cli(path, output_dir, method, debug_able, start_page_id, end_page_id):
def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
model_config.__use_inside_model__ = True
model_config.__model_mode__ = 'full'
os.makedirs(output_dir, exist_ok=True)
......@@ -90,6 +102,7 @@ def cli(path, output_dir, method, debug_able, start_page_id, end_page_id):
debug_able,
start_page_id=start_page_id,
end_page_id=end_page_id,
lang=lang
)
except Exception as e:
......
......@@ -44,6 +44,7 @@ def do_parse(
f_draw_model_bbox=False,
start_page_id=0,
end_page_id=None,
lang=None,
):
if debug_able:
logger.warning("debug mode is on")
......@@ -61,13 +62,13 @@ def do_parse(
if parse_method == 'auto':
jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id)
start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
elif parse_method == 'txt':
pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id)
start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
elif parse_method == 'ocr':
pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id)
start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
else:
logger.error('unknown parse method')
exit(1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment