Unverified Commit 37925f36 authored by Kaiwen Liu's avatar Kaiwen Liu Committed by GitHub

feat(model inference): add table recognition and conversion to LaTeX (#284)

* # add table recognition using struct-eqtable
## Changelog
31/07/20204
- Support table recognition. Table images will be converted into html.

### how to use the new feature:
set the attribute 'table-mode' to 'true' in magic-pdf.json

### caution:
it takes 200s to 500s to convert a single table image using cpu

* # add table recognition using struct-eqtable
## Changelog
31/07/20204
- Support table recognition. Table images will be converted into LaTex.

### how to use the new feature:
set the attribute 'table-mode' to 'true' in magic-pdf.json

### caution:
it takes 200s to 500s to convert a single table image using cpu

* # feat(model inference): add table recognition and convertion to LaTeX

# What's Changed

### New Features

- Add table content recognition, we use weights of [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy) to convert table image to LaTex.

### Instruction

- pip install pypandoc struct-eqtable==0.1.0
- Download [StructEqTable weights](https://huggingface.co/wanderkid/PDF-Extract-Kit/tree/main/models/TabRec) and put it under models/ directory.
- Edit 'table-mode' value to turn on table recognition function which is turned off by default.
- If you did not download any models before, refer to [how to download models](docs/how_to_download_models_zh_cn.md)。

* add table recognition and convertion to LaTeX

* add table recognition and conversion to LaTeX

* add table recognition and conversion to LaTeX

* add table recognition and conversion to LaTeX

---------
Co-authored-by: 's avatarliukaiwen <liukaiwen@pjlab.org.cn>
parent 41737adf
...@@ -92,6 +92,7 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c ...@@ -92,6 +92,7 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c
- 保留原文档的结构,包括标题、段落、列表等 - 保留原文档的结构,包括标题、段落、列表等
- 提取图像、图片标题、表格、表格标题 - 提取图像、图片标题、表格、表格标题
- 自动识别文档中的公式并将公式转换成latex - 自动识别文档中的公式并将公式转换成latex
- 自动识别文档中的表格并将表格转换成latex
- 乱码PDF自动检测并启用OCR - 乱码PDF自动检测并启用OCR
- 支持CPU和GPU环境 - 支持CPU和GPU环境
- 支持windows/linux/mac平台 - 支持windows/linux/mac平台
...@@ -274,7 +275,7 @@ TODO ...@@ -274,7 +275,7 @@ TODO
- [ ] 正文中列表识别 - [ ] 正文中列表识别
- [ ] 正文中代码块识别 - [ ] 正文中代码块识别
- [ ] 目录识别 - [ ] 目录识别
- [ ] 表格识别 - [x] 表格识别
- [ ] 化学式识别 - [ ] 化学式识别
- [ ] 几何图形识别 - [ ] 几何图形识别
...@@ -311,6 +312,7 @@ The project currently leverages PyMuPDF to deliver advanced functionalities; how ...@@ -311,6 +312,7 @@ The project currently leverages PyMuPDF to deliver advanced functionalities; how
- [PyMuPDF](https://github.com/pymupdf/PyMuPDF) - [PyMuPDF](https://github.com/pymupdf/PyMuPDF)
- [fast-langdetect](https://github.com/LlmKira/fast-langdetect) - [fast-langdetect](https://github.com/LlmKira/fast-langdetect)
- [pdfminer.six](https://github.com/pdfminer/pdfminer.six) - [pdfminer.six](https://github.com/pdfminer/pdfminer.six)
- [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy)
# Citation # Citation
......
...@@ -82,6 +82,16 @@ git lfs clone https://www.modelscope.cn/wanderkid/PDF-Extract-Kit.git ...@@ -82,6 +82,16 @@ git lfs clone https://www.modelscope.cn/wanderkid/PDF-Extract-Kit.git
│ ├── README.md │ ├── README.md
│ ├── tokenizer_config.json │ ├── tokenizer_config.json
│ └── tokenizer.json │ └── tokenizer.json
│── TabRec
│ └─StructEqTable
│ ├── config.json
│ ├── generation_config.json
│ ├── model.safetensors
│ ├── preprocessor_config.json
│ ├── special_tokens_map.json
│ ├── spiece.model
│ ├── tokenizer.json
│ └── tokenizer_config.json
└── README.md └── README.md
``` ```
......
...@@ -4,5 +4,9 @@ ...@@ -4,5 +4,9 @@
"bucket-name-2":["ak", "sk", "endpoint"] "bucket-name-2":["ak", "sk", "endpoint"]
}, },
"models-dir":"/tmp/models", "models-dir":"/tmp/models",
"device-mode":"cpu" "device-mode":"cpu",
"table-config": {
"is_table_recog_enable": false,
"max_time": 400
}
} }
\ No newline at end of file
...@@ -120,15 +120,20 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""): ...@@ -120,15 +120,20 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
if mode == 'nlp': if mode == 'nlp':
continue continue
elif mode == 'mm': elif mode == 'mm':
table_caption = ''
for block in para_block['blocks']: # 1st.拼table_caption for block in para_block['blocks']: # 1st.拼table_caption
if block['type'] == BlockType.TableCaption: if block['type'] == BlockType.TableCaption:
para_text += merge_para_with_text(block) table_caption = merge_para_with_text(block)
for block in para_block['blocks']: # 2nd.拼table_body for block in para_block['blocks']: # 2nd.拼table_body
if block['type'] == BlockType.TableBody: if block['type'] == BlockType.TableBody:
for line in block['lines']: for line in block['lines']:
for span in line['spans']: for span in line['spans']:
if span['type'] == ContentType.Table: if span['type'] == ContentType.Table:
para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n" # if processed by table model
if span.get('latex', ''):
para_text += f"\n\n$\n {span['latex']}\n$\n\n"
else:
para_text += f"\n![{table_caption}]({join_path(img_buket_path, span['image_path'])}) \n"
for block in para_block['blocks']: # 3rd.拼table_footnote for block in para_block['blocks']: # 3rd.拼table_footnote
if block['type'] == BlockType.TableFootnote: if block['type'] == BlockType.TableFootnote:
para_text += merge_para_with_text(block) para_text += merge_para_with_text(block)
...@@ -249,6 +254,8 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx): ...@@ -249,6 +254,8 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
} }
for block in para_block['blocks']: for block in para_block['blocks']:
if block['type'] == BlockType.TableBody: if block['type'] == BlockType.TableBody:
if block["lines"][0]["spans"][0].get('latex', ''):
para_content['table_body'] = f"\n\n$\n {block['lines'][0]['spans'][0]['content']}\n$\n\n"
para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path']) para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path'])
if block['type'] == BlockType.TableCaption: if block['type'] == BlockType.TableCaption:
para_content['table_caption'] = merge_para_with_text(block) para_content['table_caption'] = merge_para_with_text(block)
......
...@@ -76,6 +76,11 @@ def get_device(): ...@@ -76,6 +76,11 @@ def get_device():
else: else:
return device return device
def get_table_recog_config():
config = read_config()
table_config = config.get("table-config")
return table_config
if __name__ == "__main__": if __name__ == "__main__":
ak, sk, endpoint = get_s3_config("llm-raw") ak, sk, endpoint = get_s3_config("llm-raw")
...@@ -4,7 +4,7 @@ import fitz ...@@ -4,7 +4,7 @@ import fitz
import numpy as np import numpy as np
from loguru import logger from loguru import logger
from magic_pdf.libs.config_reader import get_local_models_dir, get_device from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config
from magic_pdf.model.model_list import MODEL from magic_pdf.model.model_list import MODEL
import magic_pdf.model as model_config import magic_pdf.model as model_config
...@@ -84,7 +84,13 @@ def custom_model_init(ocr: bool = False, show_log: bool = False): ...@@ -84,7 +84,13 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
# 从配置文件读取model-dir和device # 从配置文件读取model-dir和device
local_models_dir = get_local_models_dir() local_models_dir = get_local_models_dir()
device = get_device() device = get_device()
custom_model = CustomPEKModel(ocr=ocr, show_log=show_log, models_dir=local_models_dir, device=device) table_config = get_table_recog_config()
model_input = {"ocr": ocr,
"show_log": show_log,
"models_dir": local_models_dir,
"device": device,
"table_config": table_config}
custom_model = CustomPEKModel(**model_input)
else: else:
logger.error("Not allow model_name!") logger.error("Not allow model_name!")
exit(1) exit(1)
......
...@@ -560,6 +560,10 @@ class MagicModel: ...@@ -560,6 +560,10 @@ class MagicModel:
if category_id == 3: if category_id == 3:
span["type"] = ContentType.Image span["type"] = ContentType.Image
elif category_id == 5: elif category_id == 5:
# 获取table模型结果
latex = layout_det.get("latex", None)
if latex:
span["latex"] = latex
span["type"] = ContentType.Table span["type"] = ContentType.Table
elif category_id == 13: elif category_id == 13:
span["content"] = layout_det["latex"] span["content"] = layout_det["latex"]
......
from loguru import logger from loguru import logger
import os import os
import time import time
from pypandoc import convert_text
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
try: try:
...@@ -10,6 +11,7 @@ try: ...@@ -10,6 +11,7 @@ try:
import numpy as np import numpy as np
import torch import torch
import torchtext import torchtext
if torchtext.__version__ >= "0.18.0": if torchtext.__version__ >= "0.18.0":
torchtext.disable_torchtext_deprecation_warning() torchtext.disable_torchtext_deprecation_warning()
from PIL import Image from PIL import Image
...@@ -30,6 +32,12 @@ except ImportError as e: ...@@ -30,6 +32,12 @@ except ImportError as e:
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
def table_model_init(model_path, max_time=400, _device_='cpu'):
table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
return table_model
def mfd_model_init(weight): def mfd_model_init(weight):
...@@ -95,6 +103,8 @@ class CustomPEKModel: ...@@ -95,6 +103,8 @@ class CustomPEKModel:
# 初始化解析配置 # 初始化解析配置
self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"]) self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"]) self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
self.apply_table = self.table_config.get("is_table_recog_enable", False)
self.apply_ocr = ocr self.apply_ocr = ocr
logger.info( logger.info(
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}".format( "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}".format(
...@@ -129,6 +139,11 @@ class CustomPEKModel: ...@@ -129,6 +139,11 @@ class CustomPEKModel:
if self.apply_ocr: if self.apply_ocr:
self.ocr_model = ModifiedPaddleOCR(show_log=show_log) self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
# init structeqtable
if self.apply_table:
max_time = self.table_config.get("max_time", 400)
self.table_model = table_model_init(str(os.path.join(models_dir, self.configs["weights"]["table"])),
max_time=max_time, _device_=self.device)
logger.info('DocAnalysis init done!') logger.info('DocAnalysis init done!')
def __call__(self, image): def __call__(self, image):
...@@ -249,4 +264,32 @@ class CustomPEKModel: ...@@ -249,4 +264,32 @@ class CustomPEKModel:
ocr_cost = round(time.time() - ocr_start, 2) ocr_cost = round(time.time() - ocr_start, 2)
logger.info(f"ocr cost: {ocr_cost}") logger.info(f"ocr cost: {ocr_cost}")
# 表格识别 table recognition
if self.apply_table:
pil_img = Image.fromarray(image)
for layout in layout_res:
if layout.get("category_id", -1) == 5:
poly = layout["poly"]
xmin, ymin = int(poly[0]), int(poly[1])
xmax, ymax = int(poly[4]), int(poly[5])
paste_x = 50
paste_y = 50
# 创建一个宽高各多50的白色背景 create a whiteboard with 50 larger width and length
new_width = xmax - xmin + paste_x * 2
new_height = ymax - ymin + paste_y * 2
new_image = Image.new('RGB', (new_width, new_height), 'white')
# 裁剪图像 crop image
crop_box = (xmin, ymin, xmax, ymax)
cropped_img = pil_img.crop(crop_box)
new_image.paste(cropped_img, (paste_x, paste_y))
start_time = time.time()
logger.info("------------------table recognition processing begins-----------------")
latex_code = self.table_model.image2latex(new_image)[0]
end_time = time.time()
run_time = end_time - start_time
logger.info(f"------------table recognition processing ends within {run_time}s-----")
layout["latex"] = latex_code
return layout_res return layout_res
from struct_eqtable.model import StructTable
from pypandoc import convert_text
class StructTableModel:
def __init__(self, model_path, max_new_tokens=2048, max_time=400, device = 'cpu'):
# init
self.model_path = model_path
self.max_new_tokens = max_new_tokens # maximum output tokens length
self.max_time = max_time # timeout for processing in seconds
if device == 'cuda':
self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time).cuda()
else:
self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time)
def image2latex(self, image) -> str:
#
table_latex = self.model.forward(image)
return table_latex
def image2html(self, image) -> str:
table_latex = self.image2latex(image)
table_html = convert_text(table_latex, 'html', format='latex')
return table_html
...@@ -2,8 +2,12 @@ config: ...@@ -2,8 +2,12 @@ config:
device: cpu device: cpu
layout: True layout: True
formula: True formula: True
table_config:
is_table_recog_enable: False
max_time: 400
weights: weights:
layout: Layout/model_final.pth layout: Layout/model_final.pth
mfd: MFD/weights.pt mfd: MFD/weights.pt
mfr: MFR/UniMERNet mfr: MFR/UniMERNet
table: TabRec/StructEqTable
\ No newline at end of file
...@@ -14,3 +14,4 @@ tqdm ...@@ -14,3 +14,4 @@ tqdm
htmltabletomd htmltabletomd
pypandoc pypandoc
pyopenssl==24.0.0 pyopenssl==24.0.0
struct-eqtable==0.1.0
\ No newline at end of file
...@@ -8,4 +8,6 @@ fast-langdetect==0.2.0 ...@@ -8,4 +8,6 @@ fast-langdetect==0.2.0
wordninja>=2.0.0 wordninja>=2.0.0
scikit-learn>=1.0.2 scikit-learn>=1.0.2
pdfminer.six==20231228 pdfminer.six==20231228
pypandoc
struct-eqtable==0.1.0
# The requirements.txt must ensure that only necessary external dependencies are introduced. If there are new dependencies to add, please contact the project administrator. # The requirements.txt must ensure that only necessary external dependencies are introduced. If there are new dependencies to add, please contact the project administrator.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment