Commit dbe628ee authored by liukaiwen

add table recognition and conversion to LaTeX

parent 78238f39
...@@ -6,5 +6,8 @@ ...@@ -6,5 +6,8 @@
"temp-output-dir":"/tmp", "temp-output-dir":"/tmp",
"models-dir":"/tmp/models", "models-dir":"/tmp/models",
"device-mode":"cpu", "device-mode":"cpu",
"table-mode":"false" "table-config": {
"is_table_recog_enable": false,
"max_time": 400
}
} }
\ No newline at end of file
...@@ -120,19 +120,21 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""): ...@@ -120,19 +120,21 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
if mode == 'nlp': if mode == 'nlp':
continue continue
elif mode == 'mm': elif mode == 'mm':
table_caption = ''
for block in para_block['blocks']: # 1st.拼table_caption for block in para_block['blocks']: # 1st.拼table_caption
if block['type'] == BlockType.TableCaption: if block['type'] == BlockType.TableCaption:
para_text += merge_para_with_text(block) table_caption = merge_para_with_text(block)
para_text += table_caption
for block in para_block['blocks']: # 2nd.拼table_body for block in para_block['blocks']: # 2nd.拼table_body
if block['type'] == BlockType.TableBody: if block['type'] == BlockType.TableBody:
for line in block['lines']: for line in block['lines']:
for span in line['spans']: for span in line['spans']:
if span['type'] == ContentType.Table: if span['type'] == ContentType.Table:
# if processed by table model # if processed by table model
if span.get('content', ''): if span.get('latex', ''):
para_text += f"\n\n$\n {span['content']}\n$\n\n" para_text += f"\n\n$\n {span['latex']}\n$\n\n"
else: else:
para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n" para_text += f"\n![{table_caption}]({join_path(img_buket_path, span['image_path'])}) \n"
for block in para_block['blocks']: # 3rd.拼table_footnote for block in para_block['blocks']: # 3rd.拼table_footnote
if block['type'] == BlockType.TableFootnote: if block['type'] == BlockType.TableFootnote:
para_text += merge_para_with_text(block) para_text += merge_para_with_text(block)
...@@ -253,7 +255,7 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx): ...@@ -253,7 +255,7 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
} }
for block in para_block['blocks']: for block in para_block['blocks']:
if block['type'] == BlockType.TableBody: if block['type'] == BlockType.TableBody:
if block["lines"][0]["spans"][0].get('content', ''): if block["lines"][0]["spans"][0].get('latex', ''):
para_content['table_body'] = f"\n\n$\n {block['lines'][0]['spans'][0]['content']}\n$\n\n" para_content['table_body'] = f"\n\n$\n {block['lines'][0]['spans'][0]['content']}\n$\n\n"
para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path']) para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path'])
if block['type'] == BlockType.TableCaption: if block['type'] == BlockType.TableCaption:
......
...@@ -86,22 +86,10 @@ def get_device(): ...@@ -86,22 +86,10 @@ def get_device():
else: else:
return device return device
def get_table_mode(): def get_table_recog_config():
config = read_config() config = read_config()
table_mode = config.get("table-mode") table_config = config.get("table-config")
if table_mode is None: return table_config
logger.warning(f"'table-mode' not found in {CONFIG_FILE_NAME}, use 'False' as default")
return False
else:
table_mode = table_mode.lower()
if table_mode == "true":
boolean_value = True
elif table_mode == "False":
boolean_value = False
else:
logger.warning(f"invalid 'table-mode' value in {CONFIG_FILE_NAME}, use 'False' as default")
boolean_value = False
return boolean_value
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -4,7 +4,7 @@ import fitz ...@@ -4,7 +4,7 @@ import fitz
import numpy as np import numpy as np
from loguru import logger from loguru import logger
from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_mode from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config
from magic_pdf.model.model_list import MODEL from magic_pdf.model.model_list import MODEL
import magic_pdf.model as model_config import magic_pdf.model as model_config
...@@ -84,12 +84,12 @@ def custom_model_init(ocr: bool = False, show_log: bool = False): ...@@ -84,12 +84,12 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
# 从配置文件读取model-dir和device # 从配置文件读取model-dir和device
local_models_dir = get_local_models_dir() local_models_dir = get_local_models_dir()
device = get_device() device = get_device()
table_mode = get_table_mode() table_config = get_table_recog_config()
model_input = {"ocr": ocr, model_input = {"ocr": ocr,
"show_log": show_log, "show_log": show_log,
"models_dir": local_models_dir, "models_dir": local_models_dir,
"device": device, "device": device,
"table_mode": table_mode} "table_config": table_config}
custom_model = CustomPEKModel(**model_input) custom_model = CustomPEKModel(**model_input)
else: else:
logger.error("Not allow model_name!") logger.error("Not allow model_name!")
......
...@@ -563,7 +563,7 @@ class MagicModel: ...@@ -563,7 +563,7 @@ class MagicModel:
# 获取table模型结果 # 获取table模型结果
latex = layout_det.get("latex", None) latex = layout_det.get("latex", None)
if latex: if latex:
span["content"] = latex span["latex"] = latex
span["type"] = ContentType.Table span["type"] = ContentType.Table
elif category_id == 13: elif category_id == 13:
span["content"] = layout_det["latex"] span["content"] = layout_det["latex"]
......
...@@ -35,8 +35,8 @@ from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR ...@@ -35,8 +35,8 @@ from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
def table_model_init(model_path, _device_ = 'cpu'): def table_model_init(model_path, max_time=400, _device_='cpu'):
table_model = StructTableModel(model_path, device = _device_) table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
return table_model return table_model
...@@ -103,7 +103,7 @@ class CustomPEKModel: ...@@ -103,7 +103,7 @@ class CustomPEKModel:
# 初始化解析配置 # 初始化解析配置
self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"]) self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"]) self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
self.apply_table = kwargs.get("table_mode", self.configs["config"]["table"]) self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
self.apply_ocr = ocr self.apply_ocr = ocr
logger.info( logger.info(
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}".format( "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}".format(
...@@ -139,8 +139,10 @@ class CustomPEKModel: ...@@ -139,8 +139,10 @@ class CustomPEKModel:
self.ocr_model = ModifiedPaddleOCR(show_log=show_log) self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
# init structeqtable # init structeqtable
if self.apply_table: if self.table_config.get("is_table_recog_enable", False):
self.table_model = table_model_init(str(os.path.join(models_dir, self.configs["weights"]["table"])), _device_=self.device) max_time = self.table_config.get("max_time", 400)
self.table_model = table_model_init(str(os.path.join(models_dir, self.configs["weights"]["table"])),
max_time=max_time, _device_=self.device)
logger.info('DocAnalysis init done!') logger.info('DocAnalysis init done!')
def __call__(self, image): def __call__(self, image):
...@@ -282,12 +284,11 @@ class CustomPEKModel: ...@@ -282,12 +284,11 @@ class CustomPEKModel:
cropped_img = pil_img.crop(crop_box) cropped_img = pil_img.crop(crop_box)
new_image.paste(cropped_img, (paste_x, paste_y)) new_image.paste(cropped_img, (paste_x, paste_y))
start_time = time.time() start_time = time.time()
print("------------------table recognition processing begins-----------------") logger.info("------------------table recognition processing begins-----------------")
latex_code = self.table_model.image2latex(new_image)[0] latex_code = self.table_model.image2latex(new_image)[0]
end_time = time.time() end_time = time.time()
run_time = end_time - start_time run_time = end_time - start_time
print(f"------------table recognition processing ends within {run_time}s-----") logger.info(f"------------table recognition processing ends within {run_time}s-----")
layout["latex"] = latex_code layout["latex"] = latex_code
return layout_res return layout_res
...@@ -2,7 +2,9 @@ config: ...@@ -2,7 +2,9 @@ config:
device: cpu device: cpu
layout: True layout: True
formula: True formula: True
table: False table_config:
is_table_recog_enable: False
max_time: 400
weights: weights:
layout: Layout/model_final.pth layout: Layout/model_final.pth
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment