Commit 4b372f3f authored by myhloli

feat(ocr): pass language parameter for custom model init

Pass the optional `lang` parameter from `doc_analyze` through `ModelSingleton.get_model` to `custom_model_init` to support language-specific OCR configurations. When a language is specified, it is forwarded to the underlying OCR models, which can improve OCR accuracy when processing PDFs.
parent f07c2673
...@@ -57,14 +57,14 @@ class ModelSingleton: ...@@ -57,14 +57,14 @@ class ModelSingleton:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
return cls._instance return cls._instance
def get_model(self, ocr: bool, show_log: bool): def get_model(self, ocr: bool, show_log: bool, lang):
key = (ocr, show_log) key = (ocr, show_log, lang)
if key not in self._models: if key not in self._models:
self._models[key] = custom_model_init(ocr=ocr, show_log=show_log) self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang)
return self._models[key] return self._models[key]
def custom_model_init(ocr: bool = False, show_log: bool = False): def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
model = None model = None
if model_config.__model_mode__ == "lite": if model_config.__model_mode__ == "lite":
...@@ -78,7 +78,7 @@ def custom_model_init(ocr: bool = False, show_log: bool = False): ...@@ -78,7 +78,7 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
model_init_start = time.time() model_init_start = time.time()
if model == MODEL.Paddle: if model == MODEL.Paddle:
from magic_pdf.model.pp_structure_v2 import CustomPaddleModel from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log) custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
elif model == MODEL.PEK: elif model == MODEL.PEK:
from magic_pdf.model.pdf_extract_kit import CustomPEKModel from magic_pdf.model.pdf_extract_kit import CustomPEKModel
# 从配置文件读取model-dir和device # 从配置文件读取model-dir和device
...@@ -89,7 +89,9 @@ def custom_model_init(ocr: bool = False, show_log: bool = False): ...@@ -89,7 +89,9 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
"show_log": show_log, "show_log": show_log,
"models_dir": local_models_dir, "models_dir": local_models_dir,
"device": device, "device": device,
"table_config": table_config} "table_config": table_config,
"lang": lang,
}
custom_model = CustomPEKModel(**model_input) custom_model = CustomPEKModel(**model_input)
else: else:
logger.error("Not allow model_name!") logger.error("Not allow model_name!")
...@@ -104,10 +106,10 @@ def custom_model_init(ocr: bool = False, show_log: bool = False): ...@@ -104,10 +106,10 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
start_page_id=0, end_page_id=None): start_page_id=0, end_page_id=None, lang=None):
model_manager = ModelSingleton() model_manager = ModelSingleton()
custom_model = model_manager.get_model(ocr, show_log) custom_model = model_manager.get_model(ocr, show_log, lang)
images = load_images_from_pdf(pdf_bytes) images = load_images_from_pdf(pdf_bytes)
......
...@@ -74,8 +74,11 @@ def layout_model_init(weight, config_file, device): ...@@ -74,8 +74,11 @@ def layout_model_init(weight, config_file, device):
return model return model
def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3): def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None):
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh) if lang is not None:
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang)
else:
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
return model return model
...@@ -134,7 +137,8 @@ def atom_model_init(model_name: str, **kwargs): ...@@ -134,7 +137,8 @@ def atom_model_init(model_name: str, **kwargs):
elif model_name == AtomicModel.OCR: elif model_name == AtomicModel.OCR:
atom_model = ocr_model_init( atom_model = ocr_model_init(
kwargs.get("ocr_show_log"), kwargs.get("ocr_show_log"),
kwargs.get("det_db_box_thresh") kwargs.get("det_db_box_thresh"),
kwargs.get("lang")
) )
elif model_name == AtomicModel.Table: elif model_name == AtomicModel.Table:
atom_model = table_model_init( atom_model = table_model_init(
...@@ -177,9 +181,10 @@ class CustomPEKModel: ...@@ -177,9 +181,10 @@ class CustomPEKModel:
self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE) self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
self.table_model_type = self.table_config.get("model", TABLE_MASTER) self.table_model_type = self.table_config.get("model", TABLE_MASTER)
self.apply_ocr = ocr self.apply_ocr = ocr
self.lang = kwargs.get("lang", None)
logger.info( logger.info(
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}".format( "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}, lang: {}".format(
self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table, self.lang
) )
) )
assert self.apply_layout, "DocAnalysis must contain layout model." assert self.apply_layout, "DocAnalysis must contain layout model."
...@@ -229,7 +234,8 @@ class CustomPEKModel: ...@@ -229,7 +234,8 @@ class CustomPEKModel:
self.ocr_model = atom_model_manager.get_atom_model( self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR, atom_model_name=AtomicModel.OCR,
ocr_show_log=show_log, ocr_show_log=show_log,
det_db_box_thresh=0.3 det_db_box_thresh=0.3,
lang=self.lang
) )
# init table model # init table model
if self.apply_table: if self.apply_table:
......
...@@ -18,8 +18,11 @@ def region_to_bbox(region): ...@@ -18,8 +18,11 @@ def region_to_bbox(region):
class CustomPaddleModel: class CustomPaddleModel:
def __init__(self, ocr: bool = False, show_log: bool = False): def __init__(self, ocr: bool = False, show_log: bool = False, lang=None):
self.model = PPStructure(table=False, ocr=ocr, show_log=show_log) if lang is not None:
self.model = PPStructure(table=False, ocr=ocr, show_log=show_log, lang=lang)
else:
self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
def __call__(self, img): def __call__(self, img):
try: try:
......
...@@ -17,7 +17,7 @@ class AbsPipe(ABC): ...@@ -17,7 +17,7 @@ class AbsPipe(ABC):
PIP_TXT = "txt" PIP_TXT = "txt"
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False, def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None): start_page_id=0, end_page_id=None, lang=None):
self.pdf_bytes = pdf_bytes self.pdf_bytes = pdf_bytes
self.model_list = model_list self.model_list = model_list
self.image_writer = image_writer self.image_writer = image_writer
...@@ -25,6 +25,7 @@ class AbsPipe(ABC): ...@@ -25,6 +25,7 @@ class AbsPipe(ABC):
self.is_debug = is_debug self.is_debug = is_debug
self.start_page_id = start_page_id self.start_page_id = start_page_id
self.end_page_id = end_page_id self.end_page_id = end_page_id
self.lang = lang
def get_compress_pdf_mid_data(self): def get_compress_pdf_mid_data(self):
return JsonCompressor.compress_json(self.pdf_mid_data) return JsonCompressor.compress_json(self.pdf_mid_data)
......
...@@ -10,15 +10,16 @@ from magic_pdf.user_api import parse_ocr_pdf ...@@ -10,15 +10,16 @@ from magic_pdf.user_api import parse_ocr_pdf
class OCRPipe(AbsPipe): class OCRPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False, def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None): start_page_id=0, end_page_id=None, lang=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id) super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang)
def pipe_classify(self): def pipe_classify(self):
pass pass
def pipe_analyze(self): def pipe_analyze(self):
self.model_list = doc_analyze(self.pdf_bytes, ocr=True, self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
start_page_id=self.start_page_id, end_page_id=self.end_page_id) start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
def pipe_parse(self): def pipe_parse(self):
self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug, self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
......
...@@ -11,15 +11,16 @@ from magic_pdf.user_api import parse_txt_pdf ...@@ -11,15 +11,16 @@ from magic_pdf.user_api import parse_txt_pdf
class TXTPipe(AbsPipe): class TXTPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False, def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None): start_page_id=0, end_page_id=None, lang=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id) super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang)
def pipe_classify(self): def pipe_classify(self):
pass pass
def pipe_analyze(self): def pipe_analyze(self):
self.model_list = doc_analyze(self.pdf_bytes, ocr=False, self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
start_page_id=self.start_page_id, end_page_id=self.end_page_id) start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
def pipe_parse(self): def pipe_parse(self):
self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug, self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
......
...@@ -14,9 +14,9 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf ...@@ -14,9 +14,9 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
class UNIPipe(AbsPipe): class UNIPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False, def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None): start_page_id=0, end_page_id=None, lang=None):
self.pdf_type = jso_useful_key["_pdf_type"] self.pdf_type = jso_useful_key["_pdf_type"]
super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id) super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id, lang)
if len(self.model_list) == 0: if len(self.model_list) == 0:
self.input_model_is_empty = True self.input_model_is_empty = True
else: else:
...@@ -28,10 +28,12 @@ class UNIPipe(AbsPipe): ...@@ -28,10 +28,12 @@ class UNIPipe(AbsPipe):
def pipe_analyze(self): def pipe_analyze(self):
if self.pdf_type == self.PIP_TXT: if self.pdf_type == self.PIP_TXT:
self.model_list = doc_analyze(self.pdf_bytes, ocr=False, self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
start_page_id=self.start_page_id, end_page_id=self.end_page_id) start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
elif self.pdf_type == self.PIP_OCR: elif self.pdf_type == self.PIP_OCR:
self.model_list = doc_analyze(self.pdf_bytes, ocr=True, self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
start_page_id=self.start_page_id, end_page_id=self.end_page_id) start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
def pipe_parse(self): def pipe_parse(self):
if self.pdf_type == self.PIP_TXT: if self.pdf_type == self.PIP_TXT:
......
...@@ -44,6 +44,18 @@ auto: automatically choose the best method for parsing pdf from ocr and txt. ...@@ -44,6 +44,18 @@ auto: automatically choose the best method for parsing pdf from ocr and txt.
without method specified, auto will be used by default.""", without method specified, auto will be used by default.""",
default='auto', default='auto',
) )
@click.option(
'-l',
'--lang',
'lang',
type=str,
help="""
Input the languages in the pdf (if known) to improve OCR accuracy. Optional.
You should input "Abbreviation" with language form url:
https://paddlepaddle.github.io/PaddleOCR/en/ppocr/blog/multi_languages.html#5-support-languages-and-abbreviations
""",
default=None,
)
@click.option( @click.option(
'-d', '-d',
'--debug', '--debug',
...@@ -68,7 +80,7 @@ without method specified, auto will be used by default.""", ...@@ -68,7 +80,7 @@ without method specified, auto will be used by default.""",
help='The ending page for PDF parsing, beginning from 0.', help='The ending page for PDF parsing, beginning from 0.',
default=None, default=None,
) )
def cli(path, output_dir, method, debug_able, start_page_id, end_page_id): def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
model_config.__use_inside_model__ = True model_config.__use_inside_model__ = True
model_config.__model_mode__ = 'full' model_config.__model_mode__ = 'full'
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
...@@ -90,6 +102,7 @@ def cli(path, output_dir, method, debug_able, start_page_id, end_page_id): ...@@ -90,6 +102,7 @@ def cli(path, output_dir, method, debug_able, start_page_id, end_page_id):
debug_able, debug_able,
start_page_id=start_page_id, start_page_id=start_page_id,
end_page_id=end_page_id, end_page_id=end_page_id,
lang=lang
) )
except Exception as e: except Exception as e:
......
...@@ -44,6 +44,7 @@ def do_parse( ...@@ -44,6 +44,7 @@ def do_parse(
f_draw_model_bbox=False, f_draw_model_bbox=False,
start_page_id=0, start_page_id=0,
end_page_id=None, end_page_id=None,
lang=None,
): ):
if debug_able: if debug_able:
logger.warning("debug mode is on") logger.warning("debug mode is on")
...@@ -61,13 +62,13 @@ def do_parse( ...@@ -61,13 +62,13 @@ def do_parse(
if parse_method == 'auto': if parse_method == 'auto':
jso_useful_key = {'_pdf_type': '', 'model_list': model_list} jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True, pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id) start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
elif parse_method == 'txt': elif parse_method == 'txt':
pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True, pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id) start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
elif parse_method == 'ocr': elif parse_method == 'ocr':
pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True, pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id) start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
else: else:
logger.error('unknown parse method') logger.error('unknown parse method')
exit(1) exit(1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment