Commit 1279f2cd authored by myhloli's avatar myhloli

feat(model): add support for DocLayout-YOLO model

- Add new layout model option: DocLayout-YOLO
- Implement model initialization and prediction for DocLayout-YOLO
- Update configuration options to include new model
- Modify existing code to support both LayoutLMv3 and DocLayout-YOLO models
- Update Gradio app to support more custom switches
parent 790691d6
......@@ -7,7 +7,7 @@
"layoutreader-model-dir":"/tmp/layoutreader",
"device-mode":"cpu",
"layout-config": {
"model": "doclayout_yolo"
"model": "layoutlmv3"
},
"formula-config": {
"mfd_model": "yolo_v8_mfd",
......
......@@ -10,18 +10,12 @@ block维度自定义字段
# block中lines是否被删除
LINES_DELETED = "lines_deleted"
# struct eqtable
STRUCT_EQTABLE = "struct_eqtable"
# table recognition max time default value
TABLE_MAX_TIME_VALUE = 400
# pp_table_result_max_length
TABLE_MAX_LEN = 480
# pp table structure algorithm
TABLE_MASTER = "TableMaster"
# table master structure dict
TABLE_MASTER_DICT = "table_master_structure_dict.txt"
......@@ -38,3 +32,16 @@ REC_MODEL_DIR = "ch_PP-OCRv3_rec_infer"
REC_CHAR_DICT = "ppocr_keys_v1.txt"
class MODEL_NAME:
    """String identifiers of the selectable models.

    The values are plain lowercase strings; the config readers use them as
    the default ``model`` / ``mfd_model`` / ``mfr_model`` entries.
    """
    # pp table structure algorithm
    TABLE_MASTER = "tablemaster"
    # struct eqtable
    STRUCT_EQTABLE = "struct_eqtable"

    # layout models
    DocLayout_YOLO = "doclayout_yolo"

    LAYOUTLMv3 = "layoutlmv3"

    # formula models: the default "mfd_model" and "mfr_model" respectively
    YOLO_V8_MFD = "yolo_v8_mfd"

    UniMerNet_v2_Small = "unimernet_small"
\ No newline at end of file
......@@ -445,3 +445,38 @@ def get_overlap_area(bbox1, bbox2):
# The area of overlap area
return (x_right - x_left) * (y_bottom - y_top)
def calculate_vertical_projection_overlap_ratio(block1, block2):
    """
    Calculate how much of block1's horizontal extent is overlapped by block2.

    Both blocks are projected onto the x-axis. The result is the length of the
    shared x-interval divided by the width of block1, so the measure is
    asymmetric: swapping the arguments generally changes the result.

    Args:
        block1 (iterable): First block. Its first four items are
            (x0, y0, x1, y1); any further items are ignored.
        block2 (iterable): Second block, same layout as block1.

    Returns:
        float: Overlap length divided by block1's width. 0.0 when the two
        x-projections do not overlap or when block1 has zero width.

    Raises:
        ValueError: If either block has fewer than four items.
    """
    # Star-unpacking lets callers pass full bbox rows (coordinates followed by
    # type, score, ...) without repacking them into 4-tuples first.
    x0_1, _, x1_1, _, *_ = block1
    x0_2, _, x1_2, _, *_ = block2

    # Intersection of the two x-intervals
    x_left = max(x0_1, x0_2)
    x_right = min(x1_1, x1_2)

    if x_right < x_left:
        return 0.0

    # Length of the intersection
    intersection_length = x_right - x_left

    # Length of the x-axis projection of the first block
    block1_length = x1_1 - x0_1

    # A zero-width block cannot be covered by anything; avoid dividing by zero
    if block1_length == 0:
        return 0.0

    # Proportion of block1's x-projection covered by the intersection
    return intersection_length / block1_length
......@@ -8,6 +8,7 @@ import os
from loguru import logger
from magic_pdf.libs.Constants import MODEL_NAME
from magic_pdf.libs.commons import parse_bucket_key
# 定义配置文件名常量
......@@ -94,10 +95,30 @@ def get_table_recog_config():
table_config = config.get("table-config")
if table_config is None:
logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
return json.loads('{"is_table_recog_enable": false, "max_time": 400}')
return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}')
else:
return table_config
def get_layout_config():
    """Return the layout settings stored in the configuration file.

    Returns:
        The value found under the "layout-config" key. When that key is
        missing, a warning is logged and a new dict selecting
        MODEL_NAME.LAYOUTLMv3 as the layout model is returned.
    """
    config = read_config()
    layout_config = config.get("layout-config")
    if layout_config is None:
        logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default")
        # Build the default directly rather than formatting JSON text and
        # parsing it back, which would break on a value containing a quote.
        return {"model": MODEL_NAME.LAYOUTLMv3}
    return layout_config
def get_formula_config():
    """Return the formula settings stored in the configuration file.

    Returns:
        The value found under the "formula-config" key. When that key is
        missing, a warning is logged and a new dict is returned that enables
        formula processing with MODEL_NAME.YOLO_V8_MFD as "mfd_model" and
        MODEL_NAME.UniMerNet_v2_Small as "mfr_model".
    """
    config = read_config()
    formula_config = config.get("formula-config")
    if formula_config is None:
        logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
        # Build the default directly rather than formatting JSON text and
        # parsing it back, which would break on a value containing a quote.
        return {
            "mfd_model": MODEL_NAME.YOLO_V8_MFD,
            "mfr_model": MODEL_NAME.UniMerNet_v2_Small,
            "enable": True,
        }
    return formula_config
if __name__ == "__main__":
ak, sk, endpoint = get_s3_config("llm-raw")
......@@ -5,7 +5,8 @@ import numpy as np
from loguru import logger
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config
from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config, get_layout_config, \
get_formula_config
from magic_pdf.model.model_list import MODEL
import magic_pdf.model as model_config
......@@ -68,14 +69,17 @@ class ModelSingleton:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(self, ocr: bool, show_log: bool, lang=None):
key = (ocr, show_log, lang)
def get_model(self, ocr: bool, show_log: bool, lang=None, layout_model=None, formula_enable=None, table_enable=None):
key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
if key not in self._models:
self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang)
self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang, layout_model=layout_model,
formula_enable=formula_enable, table_enable=table_enable)
return self._models[key]
def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None,
layout_model=None, formula_enable=None, table_enable=None):
model = None
if model_config.__model_mode__ == "lite":
......@@ -95,14 +99,30 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
# 从配置文件读取model-dir和device
local_models_dir = get_local_models_dir()
device = get_device()
layout_config = get_layout_config()
if layout_model is not None:
layout_config["model"] = layout_model
formula_config = get_formula_config()
if formula_enable is not None:
formula_config["enable"] = formula_enable
table_config = get_table_recog_config()
model_input = {"ocr": ocr,
"show_log": show_log,
"models_dir": local_models_dir,
"device": device,
"table_config": table_config,
"lang": lang,
}
if table_enable is not None:
table_config["enable"] = table_enable
model_input = {
"ocr": ocr,
"show_log": show_log,
"models_dir": local_models_dir,
"device": device,
"table_config": table_config,
"layout_config": layout_config,
"formula_config": formula_config,
"lang": lang,
}
custom_model = CustomPEKModel(**model_input)
else:
logger.error("Not allow model_name!")
......@@ -117,10 +137,14 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
start_page_id=0, end_page_id=None, lang=None):
start_page_id=0, end_page_id=None, lang=None,
layout_model=None, formula_enable=None, table_enable=None):
if lang == "":
lang = None
model_manager = ModelSingleton()
custom_model = model_manager.get_model(ocr, show_log, lang)
custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)
with fitz.open("pdf", pdf_bytes) as doc:
pdf_page_num = doc.page_count
......
This diff is collapsed.
......@@ -52,11 +52,11 @@ class ppTableModel(object):
rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR)
rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT)
device = kwargs.get("device", "cpu")
use_gpu = True if device == "cuda" else False
use_gpu = True if device.startswith("cuda") else False
config = {
"use_gpu": use_gpu,
"table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN),
"table_algorithm": TABLE_MASTER,
"table_algorithm": "TableMaster",
"table_model_dir": table_model_dir,
"table_char_dict_path": table_char_dict_path,
"det_model_dir": det_model_dir,
......
......@@ -17,7 +17,7 @@ class AbsPipe(ABC):
PIP_TXT = "txt"
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None):
start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None):
self.pdf_bytes = pdf_bytes
self.model_list = model_list
self.image_writer = image_writer
......@@ -26,6 +26,9 @@ class AbsPipe(ABC):
self.start_page_id = start_page_id
self.end_page_id = end_page_id
self.lang = lang
self.layout_model = layout_model
self.formula_enable = formula_enable
self.table_enable = table_enable
def get_compress_pdf_mid_data(self):
return JsonCompressor.compress_json(self.pdf_mid_data)
......
......@@ -10,8 +10,10 @@ from magic_pdf.user_api import parse_ocr_pdf
class OCRPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang)
start_page_id=0, end_page_id=None, lang=None,
layout_model=None, formula_enable=None, table_enable=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
layout_model, formula_enable, table_enable)
def pipe_classify(self):
pass
......@@ -19,12 +21,14 @@ class OCRPipe(AbsPipe):
def pipe_analyze(self):
self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_parse(self):
self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
......
......@@ -11,8 +11,10 @@ from magic_pdf.user_api import parse_txt_pdf
class TXTPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang)
start_page_id=0, end_page_id=None, lang=None,
layout_model=None, formula_enable=None, table_enable=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
layout_model, formula_enable, table_enable)
def pipe_classify(self):
pass
......@@ -20,12 +22,14 @@ class TXTPipe(AbsPipe):
def pipe_analyze(self):
self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_parse(self):
self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
......
......@@ -14,9 +14,11 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
class UNIPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None):
start_page_id=0, end_page_id=None, lang=None,
layout_model=None, formula_enable=None, table_enable=None):
self.pdf_type = jso_useful_key["_pdf_type"]
super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id, lang)
super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id,
lang, layout_model, formula_enable, table_enable)
if len(self.model_list) == 0:
self.input_model_is_empty = True
else:
......@@ -29,18 +31,21 @@ class UNIPipe(AbsPipe):
if self.pdf_type == self.PIP_TXT:
self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
elif self.pdf_type == self.PIP_OCR:
self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_parse(self):
if self.pdf_type == self.PIP_TXT:
self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer,
is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty,
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
elif self.pdf_type == self.PIP_OCR:
self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer,
is_debug=self.is_debug,
......
from loguru import logger
from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio, calculate_overlap_area_in_bbox1_area_ratio, \
calculate_iou
calculate_iou, calculate_vertical_projection_overlap_ratio
from magic_pdf.libs.drop_tag import DropTag
from magic_pdf.libs.ocr_content_type import BlockType
from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox_for_block
......@@ -97,12 +97,20 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b
# 通过后续大框套小框逻辑删除
'''discarded_blocks中只保留宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的(限定footnote)'''
footnote_blocks = []
for discarded in discarded_blocks:
x0, y0, x1, y1 = discarded['bbox']
all_discarded_blocks.append([x0, y0, x1, y1, None, None, None, BlockType.Discarded, None, None, None, None, discarded["score"]])
# 将footnote加入到all_bboxes中,用来计算layout
# if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
# all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Footnote, None, None, None, None, discarded["score"]])
if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
footnote_blocks.append([x0, y0, x1, y1])
'''移除在footnote下面的任何框'''
need_remove_blocks = find_blocks_under_footnote(all_bboxes, footnote_blocks)
if len(need_remove_blocks) > 0:
for block in need_remove_blocks:
all_bboxes.remove(block)
all_discarded_blocks.append(block)
'''经过以上处理后,还存在大框套小框的情况,则删除小框'''
all_bboxes = remove_overlaps_min_blocks(all_bboxes)
......@@ -113,6 +121,20 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b
return all_bboxes, all_discarded_blocks
def find_blocks_under_footnote(all_bboxes, footnote_blocks):
    """Collect the blocks of all_bboxes that sit underneath a footnote.

    A block qualifies when, for at least one footnote, its top edge (y0) is
    at or below the footnote's bottom edge (y1) and the footnote overlaps at
    least 80% of the block's x-axis projection.

    Args:
        all_bboxes: Block rows whose first four items are (x0, y0, x1, y1).
        footnote_blocks: Footnote boxes, each exactly (x0, y0, x1, y1).

    Returns:
        list: The qualifying rows (the same objects as in all_bboxes), in
        input order, with rows equal to an already collected one left out.
    """
    blocks_below = []
    for candidate in all_bboxes:
        left, top, right, bottom = candidate[:4]
        candidate_bbox = (left, top, right, bottom)

        under_some_footnote = False
        for footnote in footnote_blocks:
            _, _, _, footnote_bottom = footnote
            # Position is tested first, so the overlap ratio is only computed
            # for blocks that really start below this footnote.
            if top < footnote_bottom:
                continue
            if calculate_vertical_projection_overlap_ratio(candidate_bbox, footnote) >= 0.8:
                under_some_footnote = True
                break

        if under_some_footnote and candidate not in blocks_below:
            blocks_below.append(candidate)
    return blocks_below
def fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes):
# 先提取所有text和interline block
text_blocks = []
......
config:
device: cpu
layout: True
formula: True
table_config:
model: TableMaster
is_table_recog_enable: False
max_time: 400
weights:
layout: Layout/model_final.pth
mfd: MFD/weights.pt
mfr: MFR/unimernet_small
layoutlmv3: Layout/LayoutLMv3/model_final.pth
doclayout_yolo: Layout/YOLO/doclayout_yolo_ft.pt
yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
unimernet_small: MFR/unimernet_small
struct_eqtable: TabRec/StructEqTable
TableMaster: TabRec/TableMaster
\ No newline at end of file
tablemaster: TabRec/TableMaster
\ No newline at end of file
......@@ -46,10 +46,12 @@ def do_parse(
start_page_id=0,
end_page_id=None,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
if debug_able:
logger.warning('debug mode is on')
# f_dump_content_list = True
f_draw_model_bbox = True
f_draw_line_sort_bbox = True
......@@ -64,13 +66,16 @@ def do_parse(
if parse_method == 'auto':
jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
elif parse_method == 'txt':
pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
elif parse_method == 'ocr':
pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
else:
logger.error('unknown parse method')
exit(1)
......
......@@ -101,11 +101,19 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False):
logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr")
if input_model_is_empty:
pdf_models = doc_analyze(pdf_bytes,
ocr=True,
start_page_id=start_page_id,
end_page_id=end_page_id,
lang=lang)
layout_model = kwargs.get("layout_model", None)
formula_enable = kwargs.get("formula_enable", None)
table_enable = kwargs.get("table_enable", None)
pdf_models = doc_analyze(
pdf_bytes,
ocr=True,
start_page_id=start_page_id,
end_page_id=end_page_id,
lang=lang,
layout_model=layout_model,
formula_enable=formula_enable,
table_enable=table_enable,
)
pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
if pdf_info_dict is None:
raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.")
......
......@@ -5,16 +5,21 @@ import requests
from modelscope import snapshot_download
def download_json(url, timeout=30):
    """Download the document at url and return it parsed as JSON.

    Args:
        url: Address of the JSON document.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, requests would wait indefinitely on an unresponsive host.

    Returns:
        The decoded JSON value.

    Raises:
        requests.exceptions.Timeout: If the server does not answer in time.
        requests.exceptions.HTTPError: If the response has a 4xx/5xx status.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()  # fail loudly instead of parsing an error page
    return response.json()
def download_and_modify_json(url, local_filename, modifications):
if os.path.exists(local_filename):
data = json.load(open(local_filename))
config_version = data.get('config_version', '0.0.0')
if config_version < '1.0.0':
data = download_json(url)
else:
# 下载JSON文件
response = requests.get(url)
response.raise_for_status() # 检查请求是否成功
# 解析JSON内容
data = response.json()
data = download_json(url)
# 修改内容
for key, value in modifications.items():
......@@ -26,13 +31,21 @@ def download_and_modify_json(url, local_filename, modifications):
if __name__ == '__main__':
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
mineru_patterns = [
"models/Layout/LayoutLMv3/*",
"models/Layout/YOLO/*",
"models/MFD/YOLO/*",
"models/MFR/unimernet_small/*",
"models/TabRec/TableMaster/*",
"models/TabRec/StructEqTable/*",
]
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
model_dir = model_dir + '/models'
print(f'model_dir is: {model_dir}')
print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
json_url = 'https://gitee.com/myhloli/MinerU/raw/master/magic-pdf.template.json'
json_url = 'https://gitee.com/myhloli/MinerU/raw/dev/magic-pdf.template.json'
config_file_name = 'magic-pdf.json'
home_dir = os.path.expanduser('~')
config_file = os.path.join(home_dir, config_file_name)
......
......@@ -5,16 +5,21 @@ import requests
from huggingface_hub import snapshot_download
def download_json(url, timeout=30):
    """Download the document at url and return it parsed as JSON.

    Args:
        url: Address of the JSON document.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, requests would wait indefinitely on an unresponsive host.

    Returns:
        The decoded JSON value.

    Raises:
        requests.exceptions.Timeout: If the server does not answer in time.
        requests.exceptions.HTTPError: If the response has a 4xx/5xx status.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()  # fail loudly instead of parsing an error page
    return response.json()
def download_and_modify_json(url, local_filename, modifications):
if os.path.exists(local_filename):
data = json.load(open(local_filename))
config_version = data.get('config_version', '0.0.0')
if config_version < '1.0.0':
data = download_json(url)
else:
# 下载JSON文件
response = requests.get(url)
response.raise_for_status() # 检查请求是否成功
# 解析JSON内容
data = response.json()
data = download_json(url)
# 修改内容
for key, value in modifications.items():
......@@ -26,13 +31,28 @@ def download_and_modify_json(url, local_filename, modifications):
if __name__ == '__main__':
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
layoutreader_model_dir = snapshot_download('hantian/layoutreader')
mineru_patterns = [
"models/Layout/LayoutLMv3/*",
"models/Layout/YOLO/*",
"models/MFD/YOLO/*",
"models/MFR/unimernet_small/*",
"models/TabRec/TableMaster/*",
"models/TabRec/StructEqTable/*",
]
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
layoutreader_pattern = [
"*.json",
"*.safetensors",
]
layoutreader_model_dir = snapshot_download('hantian/layoutreader', allow_patterns=layoutreader_pattern)
model_dir = model_dir + '/models'
print(f'model_dir is: {model_dir}')
print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
json_url = 'https://github.com/opendatalab/MinerU/raw/master/magic-pdf.template.json'
json_url = 'https://github.com/opendatalab/MinerU/raw/dev/magic-pdf.template.json'
config_file_name = 'magic-pdf.json'
home_dir = os.path.expanduser('~')
config_file = os.path.join(home_dir, config_file_name)
......
......@@ -23,7 +23,7 @@ def read_fn(path):
return disk_rw.read(os.path.basename(path), AbsReaderWriter.MODE_BIN)
def parse_pdf(doc_path, output_dir, end_page_id, is_ocr):
def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, layout_mode, formula_enable, table_enable, language):
os.makedirs(output_dir, exist_ok=True)
try:
......@@ -42,6 +42,10 @@ def parse_pdf(doc_path, output_dir, end_page_id, is_ocr):
parse_method,
False,
end_page_id=end_page_id,
layout_model=layout_mode,
formula_enable=formula_enable,
table_enable=table_enable,
lang=language,
)
return local_md_dir, file_name
except Exception as e:
......@@ -93,9 +97,10 @@ def replace_image_with_base64(markdown_text, image_dir_path):
return re.sub(pattern, replace, markdown_text)
def to_markdown(file_path, end_pages, is_ocr):
def to_markdown(file_path, end_pages, is_ocr, layout_mode, formula_enable, table_enable, language):
# 获取识别的md文件以及压缩包文件路径
local_md_dir, file_name = parse_pdf(file_path, './output', end_pages - 1, is_ocr)
local_md_dir, file_name = parse_pdf(file_path, './output', end_pages - 1, is_ocr,
layout_mode, formula_enable, table_enable, language)
archive_zip_path = os.path.join("./output", compute_sha256(local_md_dir) + ".zip")
zip_archive_success = compress_directory_to_zip(local_md_dir, archive_zip_path)
if zip_archive_success == 0:
......@@ -138,6 +143,27 @@ with open("header.html", "r") as file:
header = file.read()
# Language codes offered in the "Language" dropdown, grouped by script family.
other_lang = ['ch', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka']

latin_lang = [
    'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr',
    'ga', 'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv',
    'mi', 'ms', 'mt', 'nl', 'no', 'oc', 'pi', 'pl', 'pt', 'ro',
    'rs_latin', 'sk', 'sl', 'sq', 'sv', 'sw', 'tl', 'tr', 'uz', 'vi',
    'french', 'german',
]

arabic_lang = ['ar', 'fa', 'ug', 'ur']

cyrillic_lang = [
    'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady',
    'kbd', 'ava', 'dar', 'inh', 'che', 'lbe', 'lez', 'tab',
]

devanagari_lang = [
    'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho',
    'mah', 'sck', 'new', 'gom', 'sa', 'bgc',
]

# The leading empty string is the "no language chosen" entry; the remaining
# codes follow in a fixed group order.
all_lang = [
    "",
    *other_lang,
    *latin_lang,
    *arabic_lang,
    *cyrillic_lang,
    *devanagari_lang,
]
if __name__ == "__main__":
with gr.Blocks() as demo:
gr.HTML(header)
......@@ -145,8 +171,14 @@ if __name__ == "__main__":
with gr.Column(variant='panel', scale=5):
pdf_show = gr.Markdown()
max_pages = gr.Slider(1, 10, 5, step=1, label="Max convert pages")
with gr.Row() as bu_flow:
is_ocr = gr.Checkbox(label="Force enable OCR")
with gr.Row():
layout_mode = gr.Dropdown(["layoutlmv3", "doclayout_yolo"], label="Layout model", value="layoutlmv3")
language = gr.Dropdown(all_lang, label="Language", value="")
with gr.Row():
formula_enable = gr.Checkbox(label="Enable formula recognition", value=True)
is_ocr = gr.Checkbox(label="Force enable OCR", value=False)
table_enable = gr.Checkbox(label="Enable table recognition(test)", value=False)
with gr.Row():
change_bu = gr.Button("Convert")
clear_bu = gr.ClearButton([pdf_show], value="Clear")
pdf_show = PDF(label="Please upload pdf", interactive=True, height=800)
......@@ -166,7 +198,8 @@ if __name__ == "__main__":
latex_delimiters=latex_delimiters, line_breaks=True)
with gr.Tab("Markdown text"):
md_text = gr.TextArea(lines=45, show_copy_button=True)
change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr], outputs=[md, md_text, output_file, pdf_show])
change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr, layout_mode, formula_enable, table_enable, language],
outputs=[md, md_text, output_file, pdf_show])
clear_bu.add([md, pdf_show, md_text, output_file, is_ocr])
demo.launch()
\ No newline at end of file
demo.launch(server_name="0.0.0.0")
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment