Unverified Commit 6d571e2e authored by Kaiwen Liu's avatar Kaiwen Liu Committed by GitHub

Merge pull request #7 from opendatalab/dev

Dev
parents a3358878 37c335ae
......@@ -43,3 +43,8 @@ projects/web/node_modules
projects/web/dist
projects/web_demo/web_demo/static/
cli_debug/
debug_utils/
# sphinx docs
_build/
......@@ -3,7 +3,7 @@ repos:
rev: 5.0.4
hooks:
- id: flake8
args: ["--max-line-length=120", "--ignore=E131,E125,W503,W504,E203"]
args: ["--max-line-length=150", "--ignore=E131,E125,W503,W504,E203"]
- repo: https://github.com/PyCQA/isort
rev: 5.11.5
hooks:
......@@ -12,11 +12,12 @@ repos:
rev: v0.32.0
hooks:
- id: yapf
args: ["--style={based_on_style: google, column_limit: 120, indent_width: 4}"]
args: ["--style={based_on_style: google, column_limit: 150, indent_width: 4}"]
- repo: https://github.com/codespell-project/codespell
rev: v2.2.1
hooks:
- id: codespell
args: ['--skip', '*.json']
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
hooks:
......
version: 2
build:
os: ubuntu-22.04
tools:
python: "3.10"
formats:
- epub
python:
install:
- requirements: docs/zh_cn/requirements.txt
sphinx:
configuration: docs/zh_cn/conf.py
......@@ -41,6 +41,17 @@
</div>
# Changelog
- 2024/10/31 0.9.0 released. This is a major new version with extensive code refactoring, addressing numerous issues, improving performance, reducing hardware requirements, and enhancing usability:
- Refactored the sorting module code to use [layoutreader](https://github.com/ppaanngggg/layoutreader) for reading order sorting, ensuring high accuracy in various layouts.
- Refactored the paragraph concatenation module to achieve good results in cross-column, cross-page, cross-figure, and cross-table scenarios.
- Refactored the list and table of contents recognition functions, significantly improving the accuracy of list blocks and table of contents blocks, as well as the parsing of corresponding text paragraphs.
- Refactored the matching logic for figures, tables, and descriptive text, greatly enhancing the accuracy of matching captions and footnotes to figures and tables, and reducing the loss rate of descriptive text to zero.
- Added multi-language support for OCR, supporting detection and recognition of 84 languages.For the list of supported languages, see [OCR Language Support List](https://paddlepaddle.github.io/PaddleOCR/latest/en/ppocr/blog/multi_languages.html#5-support-languages-and-abbreviations).
- Added memory recycling logic and other memory optimization measures, significantly reducing memory usage. The memory requirement for enabling all acceleration features except table acceleration (layout/formula/OCR) has been reduced from 16GB to 8GB, and the memory requirement for enabling all acceleration features has been reduced from 24GB to 10GB.
- Optimized configuration file feature switches, adding an independent formula detection switch to significantly improve speed and parsing results when formula detection is not needed.
- Integrated [PDF-Extract-Kit 1.0](https://github.com/opendatalab/PDF-Extract-Kit):
- Added the self-developed `doclayout_yolo` model, which speeds up processing by more than 10 times compared to the original solution while maintaining similar parsing effects, and can be freely switched with `layoutlmv3` via the configuration file.
- Upgraded formula parsing to `unimernet 0.2.1`, improving formula parsing accuracy while significantly reducing memory usage.
- 2024/09/27 Version 0.8.1 released, Fixed some bugs, and providing a [localized deployment version](projects/web_demo/README.md) of the [online demo](https://opendatalab.com/OpenSourceTools/Extractor/PDF/) and the [front-end interface](projects/web/README.md).
- 2024/09/09: Version 0.8.0 released, supporting fast deployment with Dockerfile, and launching demos on Huggingface and Modelscope.
- 2024/08/30: Version 0.7.1 released, add paddle tablemaster table recognition option
......@@ -69,6 +80,7 @@
<ul>
<li><a href="#command-line">Command Line</a></li>
<li><a href="#api">API</a></li>
<li><a href="#deploy-derived-projects">Deploy Derived Projects</a></li>
<li><a href="#development-guide">Development Guide</a></li>
</ul>
</li>
......@@ -100,15 +112,18 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c
## Key Features
- Removes elements such as headers, footers, footnotes, and page numbers while maintaining semantic continuity
- Outputs text in a human-readable order from multi-column documents
- Retains the original structure of the document, including titles, paragraphs, and lists
- Extracts images, image captions, tables, and table captions
- Automatically recognizes formulas in the document and converts them to LaTeX
- Automatically recognizes tables in the document and converts them to LaTeX
- Automatically detects and enables OCR for corrupted PDFs
- Supports both CPU and GPU environments
- Supports Windows, Linux, and Mac platforms
- Remove headers, footers, footnotes, page numbers, etc., to ensure semantic coherence.
- Output text in human-readable order, suitable for single-column, multi-column, and complex layouts.
- Preserve the structure of the original document, including headings, paragraphs, lists, etc.
- Extract images, image descriptions, tables, table titles, and footnotes.
- Automatically recognize and convert formulas in the document to LaTeX format.
- Automatically recognize and convert tables in the document to LaTeX or HTML format.
- Automatically detect scanned PDFs and garbled PDFs and enable OCR functionality.
- OCR supports detection and recognition of 84 languages.
- Supports multiple output formats, such as multimodal and NLP Markdown, JSON sorted by reading order, and rich intermediate formats.
- Supports various visualization results, including layout visualization and span visualization, for efficient confirmation of output quality.
- Supports both CPU and GPU environments.
- Compatible with Windows, Linux, and Mac platforms.
## Quick Start
......@@ -139,8 +154,8 @@ In non-mainline environments, due to the diversity of hardware and software conf
</tr>
<tr>
<td colspan="3">CPU</td>
<td>x86_64</td>
<td>x86_64</td>
<td>x86_64(unsupported ARM Linux)</td>
<td>x86_64(unsupported ARM Windows)</td>
<td>x86_64 / arm64</td>
</tr>
<tr>
......@@ -149,7 +164,7 @@ In non-mainline environments, due to the diversity of hardware and software conf
</tr>
<tr>
<td colspan="3">Python Version</td>
<td colspan="3">3.10</td>
<td colspan="3">3.10(Please make sure to create a Python 3.10 virtual environment using conda)</td>
</tr>
<tr>
<td colspan="3">Nvidia Driver Version</td>
......@@ -166,22 +181,24 @@ In non-mainline environments, due to the diversity of hardware and software conf
<tr>
<td rowspan="2">GPU Hardware Support List</td>
<td colspan="2">Minimum Requirement 8G+ VRAM</td>
<td colspan="2">3060ti/3070/3080/3080ti/4060/4070/4070ti<br>
8G VRAM only enables layout and formula recognition acceleration</td>
<td colspan="2">3060ti/3070/4060<br>
8G VRAM enables layout, formula recognition acceleration and OCR acceleration</td>
<td rowspan="2">None</td>
</tr>
<tr>
<td colspan="2">Recommended Configuration 16G+ VRAM</td>
<td colspan="2">3090/3090ti/4070ti super/4080/4090<br>
16G or more can enable layout, formula recognition, and OCR acceleration simultaneously<br>
24G or more can enable layout, formula recognition, OCR acceleration and table recognition simultaneously
<td colspan="2">Recommended Configuration 10G+ VRAM</td>
<td colspan="2">3080/3080ti/3090/3090ti/4070/4070ti/4070tisuper/4080/4090<br>
10G VRAM or more can enable layout, formula recognition, OCR acceleration and table recognition acceleration simultaneously
</td>
</tr>
</table>
### Online Demo
Stable Version (Stable version verified by QA):
[![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://opendatalab.com/OpenSourceTools/Extractor/PDF)
Test Version (Synced with dev branch updates, testing new features):
[![HuggingFace](https://img.shields.io/badge/Demo_on_HuggingFace-yellow.svg?logo=&labelColor=white)](https://huggingface.co/spaces/opendatalab/MinerU)
[![ModelScope](https://img.shields.io/badge/Demo_on_ModelScope-purple?logo=&labelColor=white)](https://www.modelscope.cn/studios/OpenDataLab/MinerU)
......@@ -199,37 +216,31 @@ pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com
Refer to [How to Download Model Files](docs/how_to_download_models_en.md) for detailed instructions.
> ❗️After downloading the models, please make sure to verify the completeness of the model files.
>
> Check if the model file sizes match the description on the webpage. If possible, use sha256 to verify the integrity of the files.
#### 3. Copy and configure the template file
You can find the `magic-pdf.template.json` template configuration file in the root directory of the repository.
#### 3. Modify the Configuration File for Additional Configuration
> ❗️Make sure to execute the following command to copy the configuration file to your **user directory**; otherwise, the program will not run.
>
> The user directory for Windows is `C:\Users\YourUsername`, for Linux it is `/home/YourUsername`, and for macOS it is `/Users/YourUsername`.
After completing the [2. Download model weight files](#2-download-model-weight-files) step, the script will automatically generate a `magic-pdf.json` file in the user directory and configure the default model path.
You can find the `magic-pdf.json` file in your 【user directory】.
```bash
cp magic-pdf.template.json ~/magic-pdf.json
```
> The user directory for Windows is "C:\\Users\\username", for Linux it is "/home/username", and for macOS it is "/Users/username".
Find the `magic-pdf.json` file in your user directory and configure the "models-dir" path to point to the directory where the model weight files were downloaded in [Step 2](#2-download-model-weight-files).
You can modify certain configurations in this file to enable or disable features, such as table recognition:
> ❗️Make sure to correctly configure the **absolute path** to the model weight files directory, otherwise the program will not run because it can't find the model files.
>
> On Windows, this path should include the drive letter and all backslashes (`\`) in the path should be replaced with forward slashes (`/`) to avoid syntax errors in the JSON file due to escape sequences.
>
> For example: If the models are stored in the "models" directory at the root of the D drive, the "model-dir" value should be `D:/models`.
> If the following items are not present in the JSON, please manually add the required items and remove the comment content (standard JSON does not support comments).
```json
{
// other config
"models-dir": "D:/models",
"layout-config": {
"model": "layoutlmv3" // Please change to "doclayout_yolo" when using doclayout_yolo.
},
"formula-config": {
"mfd_model": "yolo_v8_mfd",
"mfr_model": "unimernet_small",
"enable": true // The formula recognition feature is enabled by default. If you need to disable it, please change the value here to "false".
},
"table-config": {
"model": "TableMaster", // Another option of this value is 'struct_eqtable'
"is_table_recog_enable": false, // Table recognition is disabled by default, modify this value to enable it
"model": "tablemaster", // When using structEqTable, please change to "struct_eqtable".
"enable": false, // The table recognition feature is disabled by default. If you need to enable it, please change the value here to "true".
"max_time": 400
}
}
......@@ -278,8 +289,8 @@ Options:
-l, --lang TEXT Input the languages in the pdf (if known) to
improve OCR accuracy. Optional. You should
input "Abbreviation" with language form url: ht
tps://paddlepaddle.github.io/PaddleOCR/en/ppocr
/blog/multi_languages.html#5-support-languages-
tps://paddlepaddle.github.io/PaddleOCR/latest/en
/ppocr/blog/multi_languages.html#5-support-languages-
and-abbreviations
-d, --debug BOOLEAN Enables detailed debugging information during
the execution of the CLI commands.
......@@ -303,11 +314,12 @@ The results will be saved in the `{some_output_dir}` directory. The output file
```text
├── some_pdf.md # markdown file
├── images # directory for storing images
├── some_pdf_layout.pdf # layout diagram
├── some_pdf_layout.pdf # layout diagram (Include layout reading order)
├── some_pdf_middle.json # MinerU intermediate processing result
├── some_pdf_model.json # model inference result
├── some_pdf_origin.pdf # original PDF file
└── some_pdf_spans.pdf # smallest granularity bbox position information diagram
├── some_pdf_spans.pdf # smallest granularity bbox position information diagram
└── some_pdf_content_list.json # Rich text JSON arranged in reading order
```
For more information about the output files, please refer to the [Output File Description](docs/output_file_en_us.md).
......@@ -347,29 +359,38 @@ For detailed implementation, refer to:
- [demo.py Simplest Processing Method](demo/demo.py)
- [magic_pdf_parse_main.py More Detailed Processing Workflow](demo/magic_pdf_parse_main.py)
### Deploy Derived Projects
Derived projects include secondary development projects based on MinerU by project developers and community developers,
such as application interfaces based on Gradio, RAG based on llama, web demos similar to the official website, lightweight multi-GPU load balancing client/server ends, etc.
These projects may offer more features and a better user experience.
For specific deployment methods, please refer to the [Derived Project README](projects/README.md)
### Development Guide
TODO
# TODO
- [ ] Semantic-based reading order
- [ ] List recognition within the text
- [ ] Code block recognition within the text
- [ ] Table of contents recognition
- [x] Table recognition
- [ ] [Chemical formula recognition](docs/chemical_knowledge_introduction/introduction.pdf)
- [ ] Geometric shape recognition
- 🗹 Reading order based on the model
- 🗹 Recognition of `index` and `list` in the main text
- 🗹 Table recognition
- ☐ Code block recognition in the main text
- ☐ [Chemical formula recognition](docs/chemical_knowledge_introduction/introduction.pdf)
- ☐ Geometric shape recognition
# Known Issues
- Reading order is segmented based on rules, which can cause disordered sequences in some cases
- Vertical text is not supported
- Lists, code blocks, and table of contents are not yet supported in the layout model
- Comic books, art books, elementary school textbooks, and exercise books are not well-parsed yet
- Enabling OCR may produce better results in PDFs with a high density of formulas
- If you are processing PDFs with a large number of formulas, it is strongly recommended to enable the OCR function. When using PyMuPDF to extract text, overlapping text lines can occur, leading to inaccurate formula insertion positions.
- Reading order is determined by the model based on the spatial distribution of readable content, and may be out of order in some areas under extremely complex layouts.
- Vertical text is not supported.
- Tables of contents and lists are recognized through rules, and some uncommon list formats may not be recognized.
- Only one level of headings is supported; hierarchical headings are not currently supported.
- Code blocks are not yet supported in the layout model.
- Comic books, art albums, primary school textbooks, and exercises cannot be parsed well.
- Table recognition may result in row/column recognition errors in complex tables.
- OCR recognition may produce inaccurate characters in PDFs of lesser-known languages (e.g., diacritical marks in Latin script, easily confused characters in Arabic script).
- Some formulas may not render correctly in Markdown.
# FAQ
......@@ -395,6 +416,7 @@ This project currently uses PyMuPDF to achieve advanced functionality. However,
- [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy)
- [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)
- [PyMuPDF](https://github.com/pymupdf/PyMuPDF)
- [layoutreader](https://github.com/ppaanngggg/layoutreader)
- [fast-langdetect](https://github.com/LlmKira/fast-langdetect)
- [pdfminer.six](https://github.com/pdfminer/pdfminer.six)
......
......@@ -41,6 +41,18 @@
</div>
# 更新记录
- 2024/10/31 0.9.0发布,这是我们进行了大量代码重构的全新版本,解决了众多问题,提升了性能,降低了硬件需求,并提供了更丰富的易用性:
- 重构排序模块代码,使用 [layoutreader](https://github.com/ppaanngggg/layoutreader) 进行阅读顺序排序,确保在各种排版下都能实现极高准确率
- 重构段落拼接模块,在跨栏、跨页、跨图、跨表情况下均能实现良好的段落拼接效果
- 重构列表和目录识别功能,极大提升列表块和目录块识别的准确率及对应文本段落的解析效果
- 重构图、表与描述性文本的匹配逻辑,大幅提升 caption 和 footnote 与图表的匹配准确率,并将描述性文本的丢失率降至零
- 增加 OCR 的多语言支持,支持 84 种语言的检测与识别,语言支持列表详见 [OCR 语言支持列表](https://paddlepaddle.github.io/PaddleOCR/latest/ppocr/blog/multi_languages.html#5)
- 增加显存回收逻辑及其他显存优化措施,大幅降低显存使用需求。开启除表格加速外的全部加速功能(layout/公式/OCR)的显存需求从16GB降至8GB,开启全部加速功能的显存需求从24GB降至10GB
- 优化配置文件的功能开关,增加独立的公式检测开关,无需公式检测时可大幅提升速度和解析效果
- 集成 [PDF-Extract-Kit 1.0](https://github.com/opendatalab/PDF-Extract-Kit)
- 加入自研的 `doclayout_yolo` 模型,在相近解析效果情况下比原方案提速10倍以上,可通过配置文件与 `layoutlmv3` 自由切换
- 公式解析升级至 `unimernet 0.2.1`,在提升公式解析准确率的同时,大幅降低显存需求
- 2024/09/27 0.8.1发布,修复了一些bug,同时提供了[在线demo](https://opendatalab.com/OpenSourceTools/Extractor/PDF/)[本地化部署版本](projects/web_demo/README_zh-CN.md)[前端界面](projects/web/README_zh-CN.md)
- 2024/09/09 0.8.0发布,支持Dockerfile快速部署,同时上线了huggingface、modelscope demo
- 2024/08/30 0.7.1发布,集成了paddle tablemaster表格识别功能
......@@ -69,6 +81,7 @@
<ul>
<li><a href="#命令行">命令行</a></li>
<li><a href="#api">API</a></li>
<li><a href="#部署衍生项目">部署衍生项目</a></li>
<li><a href="#二次开发">二次开发</a></li>
</ul>
</li>
......@@ -100,15 +113,18 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c
## 主要功能
- 删除页眉、页脚、脚注、页码等元素,保持语义连贯
- 对多栏输出符合人类阅读顺序的文本
- 删除页眉、页脚、脚注、页码等元素,确保语义连贯
- 输出符合人类阅读顺序的文本,适用于单栏、多栏及复杂排版
- 保留原文档的结构,包括标题、段落、列表等
- 提取图像、图片标题、表格、表格标题
- 自动识别文档中的公式并将公式转换成latex
- 自动识别文档中的表格并将表格转换成latex
- 乱码PDF自动检测并启用OCR
- 提取图像、图片描述、表格、表格标题及脚注
- 自动识别并转换文档中的公式为LaTeX格式
- 自动识别并转换文档中的表格为LaTeX或HTML格式
- 自动检测扫描版PDF和乱码PDF,并启用OCR功能
- OCR支持84种语言的检测与识别
- 支持多种输出格式,如多模态与NLP的Markdown、按阅读顺序排序的JSON、含有丰富信息的中间格式等
- 支持多种可视化结果,包括layout可视化、span可视化等,便于高效确认输出效果与质检
- 支持CPU和GPU环境
- 支持windows/linux/mac平台
- 兼容Windows、Linux和Mac平台
## 快速开始
......@@ -139,8 +155,8 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c
</tr>
<tr>
<td colspan="3">CPU</td>
<td>x86_64</td>
<td>x86_64</td>
<td>x86_64(暂不支持ARM Linux)</td>
<td>x86_64(暂不支持ARM Windows)</td>
<td>x86_64 / arm64</td>
</tr>
<tr>
......@@ -149,7 +165,7 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c
</tr>
<tr>
<td colspan="3">python版本</td>
<td colspan="3">3.10</td>
<td colspan="3">3.10 (请务必通过conda创建3.10虚拟环境)</td>
</tr>
<tr>
<td colspan="3">Nvidia Driver 版本</td>
......@@ -166,24 +182,27 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c
<tr>
<td rowspan="2">GPU硬件支持列表</td>
<td colspan="2">最低要求 8G+显存</td>
<td colspan="2">3060ti/3070/3080/3080ti/4060/4070/4070ti<br>
8G显存仅可开启lavout和公式识别加速</td>
<td colspan="2">3060ti/3070/4060<br>
8G显存可开启layout、公式识别和ocr加速</td>
<td rowspan="2">None</td>
</tr>
<tr>
<td colspan="2">推荐配置 16G+显存</td>
<td colspan="2">3090/3090ti/4070tisuper/4080/4090<br>
16G及以上可以同时开启layout,公式识别和ocr加速<br>
24G及以上可以同时开启layout,公式识别,ocr加速和表格识别
<td colspan="2">推荐配置 10G+显存</td>
<td colspan="2">3080/3080ti/3090/3090ti/4070/4070ti/4070tisuper/4080/4090<br>
10G显存及以上可以同时开启layout、公式识别和ocr加速和表格识别加速<br>
</td>
</tr>
</table>
### 在线体验
稳定版(经过QA验证的稳定版本):
[![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://opendatalab.com/OpenSourceTools/Extractor/PDF)
[![ModelScope](https://img.shields.io/badge/Demo_on_ModelScope-purple?logo=&labelColor=white)](https://www.modelscope.cn/studios/OpenDataLab/MinerU)
测试版(同步dev分支更新,测试新特性):
[![HuggingFace](https://img.shields.io/badge/Demo_on_HuggingFace-yellow.svg?logo=&labelColor=white)](https://huggingface.co/spaces/opendatalab/MinerU)
[![ModelScope](https://img.shields.io/badge/Demo_on_ModelScope-purple?logo=&labelColor=white)](https://www.modelscope.cn/studios/OpenDataLab/MinerU)
### 使用CPU快速体验
......@@ -201,38 +220,30 @@ pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i h
详细参考 [如何下载模型文件](docs/how_to_download_models_zh_cn.md)
> ❗️模型下载后请务必检查模型文件是否下载完整
>
> 请检查目录下的模型文件大小与网页上描述是否一致,如果可以的话,最好通过sha256校验模型是否下载完整
#### 3. 拷贝配置文件并进行配置
#### 3. 修改配置文件以进行额外配置
在仓库根目录可以获得 [magic-pdf.template.json](magic-pdf.template.json) 配置模版文件
> ❗️务必执行以下命令将配置文件拷贝到【用户目录】下,否则程序将无法运行
>
完成[2. 下载模型权重文件](#2-下载模型权重文件)步骤后,脚本会自动生成用户目录下的magic-pdf.json文件,并自动配置默认模型路径。
您可在【用户目录】下找到magic-pdf.json文件。
> windows的用户目录为 "C:\\Users\\用户名", linux用户目录为 "/home/用户名", macOS用户目录为 "/Users/用户名"
```bash
cp magic-pdf.template.json ~/magic-pdf.json
```
在用户目录中找到magic-pdf.json文件并配置"models-dir"为[2. 下载模型权重文件](#2-下载模型权重文件)中下载的模型权重文件所在目录
您可修改该文件中的部分配置实现功能的开关,如表格识别功能:
> ❗️务必正确配置模型权重文件所在目录的【绝对路径】,否则会因为找不到模型文件而导致程序无法运行
>
> windows系统中此路径应包含盘符,且需把路径中所有的`"\"`替换为`"/"`,否则会因为转义原因导致json文件语法错误。
>
> 例如:模型放在D盘根目录的models目录,则model-dir的值应为"D:/models"
>如json内没有如下项目,请手动添加需要的项目,并删除注释内容(标准json不支持注释)
```json
{
// other config
"models-dir": "D:/models",
"layout-config": {
"model": "layoutlmv3" // 使用doclayout_yolo请修改为“doclayout_yolo"
},
"formula-config": {
"mfd_model": "yolo_v8_mfd",
"mfr_model": "unimernet_small",
"enable": true // 公式识别功能默认是开启的,如果需要关闭请修改此处的值为"false"
},
"table-config": {
"model": "TableMaster", // 使用structEqTable请修改为'struct_eqtable'
"is_table_recog_enable": false, // 表格识别功能默认是关闭的,如果需要修改此处的值
"model": "tablemaster", // 使用structEqTable请修改为"struct_eqtable"
"enable": false, // 表格识别功能默认是关闭的,如果需要开启请修改此处的值为"true"
"max_time": 400
}
}
......@@ -282,8 +293,8 @@ Options:
-l, --lang TEXT Input the languages in the pdf (if known) to
improve OCR accuracy. Optional. You should
input "Abbreviation" with language form url: ht
tps://paddlepaddle.github.io/PaddleOCR/en/ppocr
/blog/multi_languages.html#5-support-languages-
tps://paddlepaddle.github.io/PaddleOCR/latest/en
/ppocr/blog/multi_languages.html#5-support-languages-
and-abbreviations
-d, --debug BOOLEAN Enables detailed debugging information during
the execution of the CLI commands.
......@@ -307,11 +318,12 @@ magic-pdf -p {some_pdf} -o {some_output_dir} -m auto
```text
├── some_pdf.md # markdown 文件
├── images # 存放图片目录
├── some_pdf_layout.pdf # layout 绘图
├── some_pdf_layout.pdf # layout 绘图 (包含layout阅读顺序)
├── some_pdf_middle.json # minerU 中间处理结果
├── some_pdf_model.json # 模型推理结果
├── some_pdf_origin.pdf # 原 pdf 文件
└── some_pdf_spans.pdf # 最小粒度的bbox位置信息绘图
├── some_pdf_spans.pdf # 最小粒度的bbox位置信息绘图
└── some_pdf_content_list.json # 按阅读顺序排列的富文本json
```
更多有关输出文件的信息,请参考[输出文件说明](docs/output_file_zh_cn.md)
......@@ -351,29 +363,38 @@ md_content = pipe.pipe_mk_markdown(image_dir, drop_mode="none")
- [demo.py 最简单的处理方式](demo/demo.py)
- [magic_pdf_parse_main.py 能够更清晰看到处理流程](demo/magic_pdf_parse_main.py)
### 部署衍生项目
衍生项目包含项目开发者和社群开发者们基于MinerU的二次开发项目,
例如基于Gradio的应用界面、基于llama的RAG、官网同款web demo、轻量级的多卡负载均衡c/s端等,
这些项目可能会提供更多的功能和更好的用户体验。
具体部署方式请参考 [衍生项目readme](projects/README_zh-CN.md)
### 二次开发
TODO
# TODO
- [ ] 基于语义的阅读顺序
- [ ] 正文中列表识别
- [ ] 正文中代码块识别
- [ ] 目录识别
- [x] 表格识别
- [ ] [化学式识别](docs/chemical_knowledge_introduction/introduction.pdf)
- [ ] 几何图形识别
- 🗹 基于模型的阅读顺序
- 🗹 正文中目录、列表识别
- 🗹 表格识别
- ☐ 正文中代码块识别
- ☐ [化学式识别](docs/chemical_knowledge_introduction/introduction.pdf)
- ☐ 几何图形识别
# Known Issues
- 阅读顺序基于规则的分割,在一些情况下会乱序
- 阅读顺序基于模型对可阅读内容在空间中的分布进行排序,在极端复杂的排版下可能会部分区域乱序
- 不支持竖排文字
- 列表、代码块、目录在layout模型里还没有支持
- 目录和列表通过规则进行识别,少部分不常见的列表形式可能无法识别
- 标题只有一级,目前不支持标题分级
- 代码块在layout模型里还没有支持
- 漫画书、艺术图册、小学教材、习题尚不能很好解析
- 在一些公式密集的PDF上强制启用OCR效果会更好
- 如果您要处理包含大量公式的pdf,强烈建议开启OCR功能。使用pymuPDF提取文字的时候会出现文本行互相重叠的情况导致公式插入位置不准确。
- 表格识别在复杂表格上可能会出现行/列识别错误
- 在小语种PDF上,OCR识别可能会出现字符不准确的情况(如拉丁文的重音符号、阿拉伯文易混淆字符等)
- 部分公式可能会无法在markdown中渲染
# FAQ
......@@ -400,6 +421,7 @@ TODO
- [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy)
- [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)
- [PyMuPDF](https://github.com/pymupdf/PyMuPDF)
- [layoutreader](https://github.com/ppaanngggg/layoutreader)
- [fast-langdetect](https://github.com/LlmKira/fast-langdetect)
- [pdfminer.six](https://github.com/pdfminer/pdfminer.six)
......
import os
import json
from loguru import logger
from magic_pdf.pipe.UNIPipe import UNIPipe
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
import magic_pdf.model as model_config
model_config.__use_inside_model__ = True
try:
current_script_dir = os.path.dirname(os.path.abspath(__file__))
demo_name = "demo1"
pdf_path = os.path.join(current_script_dir, f"{demo_name}.pdf")
model_path = os.path.join(current_script_dir, f"{demo_name}.json")
pdf_bytes = open(pdf_path, "rb").read()
# model_json = json.loads(open(model_path, "r", encoding="utf-8").read())
model_json = [] # model_json传空list使用内置模型解析
jso_useful_key = {"_pdf_type": "", "model_list": model_json}
jso_useful_key = {"_pdf_type": "", "model_list": []}
local_image_dir = os.path.join(current_script_dir, 'images')
image_dir = str(os.path.basename(local_image_dir))
image_writer = DiskReaderWriter(local_image_dir)
pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
pipe.pipe_classify()
"""如果没有传入有效的模型数据,则使用内置model解析"""
if len(model_json) == 0:
if model_config.__use_inside_model__:
pipe.pipe_analyze()
else:
logger.error("need model list input")
exit(1)
pipe.pipe_parse()
md_content = pipe.pipe_mk_markdown(image_dir, drop_mode="none")
with open(f"{demo_name}.md", "w", encoding="utf-8") as f:
......
......@@ -4,13 +4,12 @@ import copy
from loguru import logger
from magic_pdf.libs.draw_bbox import draw_layout_bbox, draw_span_bbox
from magic_pdf.pipe.UNIPipe import UNIPipe
from magic_pdf.pipe.OCRPipe import OCRPipe
from magic_pdf.pipe.TXTPipe import TXTPipe
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
import magic_pdf.model as model_config
model_config.__use_inside_model__ = True
# todo: 设备类型选择 (?)
......@@ -47,11 +46,20 @@ def json_md_dump(
)
# 可视化
def draw_visualization_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name):
# 画布局框,附带排序结果
draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
# 画 span 框
draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
def pdf_parse_main(
pdf_path: str,
parse_method: str = 'auto',
model_json_path: str = None,
is_json_md_dump: bool = True,
is_draw_visualization_bbox: bool = True,
output_dir: str = None
):
"""
......@@ -108,11 +116,7 @@ def pdf_parse_main(
# 如果没有传入模型数据,则使用内置模型解析
if not model_json:
if model_config.__use_inside_model__:
pipe.pipe_analyze() # 解析
else:
logger.error("need model list input")
exit(1)
# 执行解析
pipe.pipe_parse()
......@@ -121,10 +125,11 @@ def pdf_parse_main(
content_list = pipe.pipe_mk_uni_format(image_path_parent, drop_mode="none")
md_content = pipe.pipe_mk_markdown(image_path_parent, drop_mode="none")
if is_json_md_dump:
json_md_dump(pipe, md_writer, pdf_name, content_list, md_content)
if is_draw_visualization_bbox:
draw_visualization_bbox(pipe.pdf_mid_data['pdf_info'], pdf_bytes, output_path, pdf_name)
except Exception as e:
logger.exception(e)
......@@ -132,5 +137,5 @@ def pdf_parse_main(
# 测试
if __name__ == '__main__':
pdf_path = r"C:\Users\XYTK2\Desktop\2024-2016-gb-cd-300.pdf"
pdf_path = r"D:\project\20240617magicpdf\Magic-PDF\demo\demo1.pdf"
pdf_parse_main(pdf_path)
......@@ -38,17 +38,22 @@ sudo apt-get install libgl1-mesa-glx
Reference: https://github.com/opendatalab/MinerU/issues/388
### 5. Encountered error `ModuleNotFoundError: No module named 'fairscale'`
You need to uninstall the module and reinstall it:
```bash
pip uninstall fairscale
pip install fairscale
```
Reference: https://github.com/opendatalab/MinerU/issues/411
### 6. On some newer devices like the H100, the text parsed during OCR using CUDA acceleration is garbled.
The compatibility of cuda11 with new graphics cards is poor, and the CUDA version used by Paddle needs to be upgraded.
```bash
pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu123/
```
Reference: https://github.com/opendatalab/MinerU/issues/558
# 常见问题解答
### 1.在较新版本的mac上使用命令安装pip install magic-pdf[full] zsh: no matches found: magic-pdf[full]
### 1.在较新版本的mac上使用命令安装pip install magic-pdf\[full\] zsh: no matches found: magic-pdf\[full\]
在 macOS 上,默认的 shell 从 Bash 切换到了 Z shell,而 Z shell 对于某些类型的字符串匹配有特殊的处理逻辑,这可能导致no matches found错误。
可以通过在命令行禁用globbing特性,再尝试运行安装命令
```bash
setopt no_nomatch
pip install magic-pdf[full]
......@@ -17,11 +18,13 @@ pip install magic-pdf[full]
### 3.模型文件应该下载到哪里/models-dir的配置应该怎么填
模型文件的路径输入是在"magic-pdf.json"中通过
```json
{
"models-dir": "/tmp/models"
}
```
进行配置的。
这个路径是绝对路径而不是相对路径,绝对路径的获取可在models目录中通过命令 "pwd" 获取。
参考:https://github.com/opendatalab/MinerU/issues/155#issuecomment-2230216874
......@@ -29,23 +32,30 @@ pip install magic-pdf[full]
### 4.在WSL2的Ubuntu22.04中遇到报错`ImportError: libGL.so.1: cannot open shared object file: No such file or directory`
WSL2的Ubuntu22.04中缺少`libgl`库,可通过以下命令安装`libgl`库解决:
```bash
sudo apt-get install libgl1-mesa-glx
```
参考:https://github.com/opendatalab/MinerU/issues/388
### 5.遇到报错 `ModuleNotFoundError : Nomodulenamed 'fairscale'`
需要卸载该模块并重新安装
```bash
pip uninstall fairscale
pip install fairscale
```
参考:https://github.com/opendatalab/MinerU/issues/411
### 6.在部分较新的设备如H100上,使用CUDA加速OCR时解析出的文字乱码。
cuda11对新显卡的兼容性不好,需要升级paddle使用的cuda版本
```bash
pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu123/
```
参考:https://github.com/opendatalab/MinerU/issues/558
# Ubuntu 22.04 LTS
### 1. Check if NVIDIA Drivers Are Installed
```sh
nvidia-smi
```
If you see information similar to the following, it means that the NVIDIA drivers are already installed, and you can skip Step 2.
```plaintext
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 537.34 Driver Version: 537.34 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name TCC/WDDM | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 3060 Ti WDDM | 00000000:01:00.0 On | N/A |
| 0% 51C P8 12W / 200W | 1489MiB / 8192MiB | 5% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
```
```sh
nvidia-smi
```
If you see information similar to the following, it means that the NVIDIA drivers are already installed, and you can skip Step 2.
Notice:`CUDA Version` should be >= 12.1, If the displayed version number is less than 12.1, please upgrade the driver.
```plaintext
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 537.34 Driver Version: 537.34 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name TCC/WDDM | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 3060 Ti WDDM | 00000000:01:00.0 On | N/A |
| 0% 51C P8 12W / 200W | 1489MiB / 8192MiB | 5% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
```
### 2. Install the Driver
If no driver is installed, use the following command:
```sh
sudo apt-get update
sudo apt-get install nvidia-driver-545
```
Install the proprietary driver and restart your computer after installation.
```sh
reboot
```
If no driver is installed, use the following command:
```sh
sudo apt-get update
sudo apt-get install nvidia-driver-545
```
Install the proprietary driver and restart your computer after installation.
```sh
reboot
```
### 3. Install Anaconda
If Anaconda is already installed, skip this step.
```sh
wget https://repo.anaconda.com/archive/Anaconda3-2024.06-1-Linux-x86_64.sh
bash Anaconda3-2024.06-1-Linux-x86_64.sh
```
In the final step, enter `yes`, close the terminal, and reopen it.
If Anaconda is already installed, skip this step.
```sh
wget https://repo.anaconda.com/archive/Anaconda3-2024.06-1-Linux-x86_64.sh
bash Anaconda3-2024.06-1-Linux-x86_64.sh
```
In the final step, enter `yes`, close the terminal, and reopen it.
### 4. Create an Environment Using Conda
Specify Python version 3.10.
```sh
conda create -n MinerU python=3.10
conda activate MinerU
```
Specify Python version 3.10.
```sh
conda create -n MinerU python=3.10
conda activate MinerU
```
### 5. Install Applications
```sh
pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com
```
```sh
pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com
```
❗ After installation, make sure to check the version of `magic-pdf` using the following command:
```sh
magic-pdf --version
```
If the version number is less than 0.7.0, please report the issue.
```sh
magic-pdf --version
```
If the version number is less than 0.7.0, please report the issue.
### 6. Download Models
Refer to detailed instructions on [how to download model files](how_to_download_models_en.md).
After downloading, move the `models` directory to an SSD with more space.
❗ After downloading the models, ensure they are complete:
- Check that the file sizes match the description on the website.
- If possible, verify the integrity using SHA256.
Refer to detailed instructions on [how to download model files](how_to_download_models_en.md).
### 7. Configuration Before First Run
Obtain the configuration template file `magic-pdf.template.json` from the root directory of the repository.
## 7. Understand the Location of the Configuration File
❗ Execute the following command to copy the configuration file to your home directory, otherwise the program will not run:
```sh
wget https://github.com/opendatalab/MinerU/raw/master/magic-pdf.template.json
cp magic-pdf.template.json ~/magic-pdf.json
```
Find the `magic-pdf.json` file in your home directory and configure `"models-dir"` to be the directory where the model weights from Step 6 were downloaded.
After completing the [6. Download Models](#6-download-models) step, the script will automatically generate a `magic-pdf.json` file in the user directory and configure the default model path.
You can find the `magic-pdf.json` file in your user directory.
❗ Correctly specify the absolute path of the directory containing the model weights; otherwise, the program will fail due to missing model files.
```json
{
"models-dir": "/tmp/models"
}
```
> The user directory for Linux is "/home/username".
### 8. First Run
Download a sample file from the repository and test it.
```sh
wget https://github.com/opendatalab/MinerU/raw/master/demo/small_ocr.pdf
magic-pdf -p small_ocr.pdf
```
Download a sample file from the repository and test it.
```sh
wget https://github.com/opendatalab/MinerU/raw/master/demo/small_ocr.pdf
magic-pdf -p small_ocr.pdf
```
### 9. Test CUDA Acceleration
If your graphics card has at least 8GB of VRAM, follow these steps to test CUDA acceleration:
If your graphics card has at least **8GB** of VRAM, follow these steps to test CUDA acceleration:
1. Modify the value of `"device-mode"` in the `magic-pdf.json` configuration file located in your home directory.
```json
......@@ -105,8 +110,6 @@ If your graphics card has at least 8GB of VRAM, follow these steps to test CUDA
### 10. Enable CUDA Acceleration for OCR
❗ The following operations require a graphics card with at least 16GB of VRAM; otherwise, the program may crash or experience reduced performance.
1. Download `paddlepaddle-gpu`. Installation will automatically enable OCR acceleration.
```sh
python -m pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/
......
# Ubuntu 22.04 LTS
## 1. 检测是否已安装nvidia驱动
```bash
nvidia-smi
```
如果看到类似如下的信息,说明已经安装了nvidia驱动,可以跳过步骤2
注意:`CUDA Version` 显示的版本号应 >= 12.1,如显示的版本号小于12.1,请升级驱动
```plaintext
```
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 537.34 Driver Version: 537.34 CUDA Version: 12.2 |
......@@ -18,96 +24,108 @@ nvidia-smi
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
```
## 2. 安装驱动
如没有驱动,则通过如下命令
```bash
sudo apt-get update
sudo apt-get install nvidia-driver-545
```
安装专有驱动,安装完成后,重启电脑
```bash
reboot
```
## 3. 安装anacoda
如果已安装conda,可以跳过本步骤
```bash
wget -U NoSuchBrowser/1.0 https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/Anaconda3-2024.06-1-Linux-x86_64.sh
bash Anaconda3-2024.06-1-Linux-x86_64.sh
```
最后一步输入yes,关闭终端重新打开
## 4. 使用conda 创建环境
需指定python版本为3.10
```bash
conda create -n MinerU python=3.10
conda activate MinerU
```
## 5. 安装应用
```bash
pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i https://pypi.tuna.tsinghua.edu.cn/simple
```
> ❗️下载完成后,务必通过以下命令确认magic-pdf的版本是否正确
>
> ```bash
> magic-pdf --version
>```
> ```
>
> 如果版本号小于0.7.0,请到issue中向我们反馈
## 6. 下载模型
详细参考 [如何下载模型文件](how_to_download_models_zh_cn.md)
下载后请将models目录移动到空间较大的ssd磁盘目录
> ❗️模型下载后请务必检查模型文件是否下载完整
>
> 请检查目录下的模型文件大小与网页上描述是否一致,如果可以的话,最好通过sha256校验模型是否下载完整
>
## 7. 第一次运行前的配置
在仓库根目录可以获得 [magic-pdf.template.json](../magic-pdf.template.json) 配置模版文件
> ❗️务必执行以下命令将配置文件拷贝到【用户目录】下,否则程序将无法运行
>
> linux用户目录为 "/home/用户名"
```bash
wget https://gitee.com/myhloli/MinerU/raw/master/magic-pdf.template.json
cp magic-pdf.template.json ~/magic-pdf.json
```
在用户目录中找到magic-pdf.json文件并配置"models-dir"为[6. 下载模型](#6-下载模型)中下载的模型权重文件所在目录
> ❗️务必正确配置模型权重文件所在目录的【绝对路径】,否则会因为找不到模型文件而导致程序无法运行
>
```json
{
"models-dir": "/tmp/models"
}
```
## 7. 了解配置文件存放的位置
完成[6.下载模型](#6-下载模型)步骤后,脚本会自动生成用户目录下的magic-pdf.json文件,并自动配置默认模型路径。
您可在【用户目录】下找到magic-pdf.json文件。
> linux用户目录为 "/home/用户名"
## 8. 第一次运行
从仓库中下载样本文件,并测试
```bash
wget https://gitee.com/myhloli/MinerU/raw/master/demo/small_ocr.pdf
magic-pdf -p small_ocr.pdf
```
## 9. 测试CUDA加速
如果您的显卡显存大于等于8G,可以进行以下流程,测试CUDA解析加速效果
如果您的显卡显存大于等于 **8GB** ,可以进行以下流程,测试CUDA解析加速效果
**1.修改【用户目录】中配置文件magic-pdf.json中"device-mode"的值**
```json
{
"device-mode":"cuda"
}
```
**2.运行以下命令测试cuda加速效果**
```bash
magic-pdf -p small_ocr.pdf
```
> 提示:CUDA加速是否生效可以根据log中输出的各个阶段cost耗时来简单判断,通常情况下,`layout detection cost` 和 `mfr time` 应提速10倍以上。
## 10. 为ocr开启cuda加速
> ❗️以下操作需显卡显存大于等于16G才可进行,否则会因为显存不足导致程序崩溃或运行速度下降
**1.下载paddlepaddle-gpu, 安装完成后会自动开启ocr加速**
```bash
python -m pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/
```
**2.运行以下命令测试ocr加速效果**
```bash
magic-pdf -p small_ocr.pdf
```
> 提示:CUDA加速是否生效可以根据log中输出的各个阶段cost耗时来简单判断,通常情况下,`ocr cost`应提速10倍以上。
# Windows 10/11
### 1. Install CUDA and cuDNN
Required versions: CUDA 11.8 + cuDNN 8.7.0
- CUDA 11.8: https://developer.nvidia.com/cuda-11-8-0-download-archive
- cuDNN v8.7.0 (November 28th, 2022), for CUDA 11.x: https://developer.nvidia.com/rdp/cudnn-archive
- CUDA 11.8: https://developer.nvidia.com/cuda-11-8-0-download-archive
- cuDNN v8.7.0 (November 28th, 2022), for CUDA 11.x: https://developer.nvidia.com/rdp/cudnn-archive
### 2. Install Anaconda
If Anaconda is already installed, you can skip this step.
If Anaconda is already installed, you can skip this step.
Download link: https://repo.anaconda.com/archive/Anaconda3-2024.06-1-Windows-x86_64.exe
### 3. Create an Environment Using Conda
Python version must be 3.10.
```
conda create -n MinerU python=3.10
conda activate MinerU
```
### 4. Install Applications
```
pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com
```
>❗️After installation, verify the version of `magic-pdf`:
> ```bash
> magic-pdf --version
> ```
> If the version number is less than 0.7.0, please report it in the issues section.
Python version must be 3.10.
### 5. Download Models
Refer to detailed instructions on [how to download model files](how_to_download_models_en.md).
After downloading, move the `models` directory to an SSD with more space.
```
conda create -n MinerU python=3.10
conda activate MinerU
```
>❗ After downloading the models, ensure they are complete:
>- Check that the file sizes match the description on the website.
>- If possible, verify the integrity using SHA256.
### 4. Install Applications
### 6. Configuration Before the First Run
Obtain the configuration template file `magic-pdf.template.json` from the repository root directory.
```
pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com
```
>❗️Execute the following command to copy the configuration file to your user directory, or the program will not run.
>
> In Windows, user directory is "C:\Users\username"
> ❗️After installation, verify the version of `magic-pdf`:
>
> ```bash
> magic-pdf --version
> ```
>
> If the version number is less than 0.7.0, please report it in the issues section.
```powershell
(New-Object System.Net.WebClient).DownloadFile('https://github.com/opendatalab/MinerU/raw/master/magic-pdf.template.json', 'magic-pdf.template.json')
cp magic-pdf.template.json ~/magic-pdf.json
```
### 5. Download Models
Find the `magic-pdf.json` file in your user directory and configure `"models-dir"` to point to the directory where the model weights from step 5 were downloaded.
Refer to detailed instructions on [how to download model files](how_to_download_models_en.md).
> ❗️Ensure the absolute path of the model weights directory is correctly configured, or the program will fail to run due to not finding the model files.
>
> In Windows, this path should include the drive letter and replace all `"\"` to `"/"`.
>
> Example: If the models are placed in the root directory of drive D, the value for `model-dir` should be `"D:/models"`.
### 6. Understand the Location of the Configuration File
```json
{
"models-dir": "/tmp/models"
}
```
After completing the [5. Download Models](#5-download-models) step, the script will automatically generate a `magic-pdf.json` file in the user directory and configure the default model path.
You can find the `magic-pdf.json` file in your 【user directory】 .
> The user directory for Windows is "C:/Users/username".
### 7. First Run
Download a sample file from the repository and test it.
```powershell
(New-Object System.Net.WebClient).DownloadFile('https://github.com/opendatalab/MinerU/raw/master/demo/small_ocr.pdf', 'small_ocr.pdf')
Download a sample file from the repository and test it.
```powershell
wget https://github.com/opendatalab/MinerU/raw/master/demo/small_ocr.pdf -O small_ocr.pdf
magic-pdf -p small_ocr.pdf
```
```
### 8. Test CUDA Acceleration
If your graphics card has at least 8GB of VRAM, follow these steps to test CUDA-accelerated parsing performance.
1. **Overwrite the installation of torch and torchvision** supporting CUDA.
If your graphics card has at least 8GB of VRAM, follow these steps to test CUDA-accelerated parsing performance.
1. **Overwrite the installation of torch and torchvision** supporting CUDA.
```
pip install --force-reinstall torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu118
```
>❗️Ensure the following versions are specified in the command:
>```
> ❗️Ensure the following versions are specified in the command:
>
> ```
> torch==2.3.1 torchvision==0.18.1
>```
>These are the highest versions we support. Installing higher versions without specifying them will cause the program to fail.
2. **Modify the value of `"device-mode"`** in the `magic-pdf.json` configuration file located in your user directory.
> ```
>
> These are the highest versions we support. Installing higher versions without specifying them will cause the program to fail.
2. **Modify the value of `"device-mode"`** in the `magic-pdf.json` configuration file located in your user directory.
```json
{
"device-mode": "cuda"
}
```
3. **Run the following command to test CUDA acceleration**:
3. **Run the following command to test CUDA acceleration**:
```
magic-pdf -p small_ocr.pdf
```
### 9. Enable CUDA Acceleration for OCR
>❗️This operation requires at least 16GB of VRAM on your graphics card, otherwise it will cause the program to crash or slow down.
1. **Download paddlepaddle-gpu**, which will automatically enable OCR acceleration upon installation.
1. **Download paddlepaddle-gpu**, which will automatically enable OCR acceleration upon installation.
```
pip install paddlepaddle-gpu==2.6.1
```
2. **Run the following command to test OCR acceleration**:
2. **Run the following command to test OCR acceleration**:
```
magic-pdf -p small_ocr.pdf
```
......@@ -3,103 +3,106 @@
## 1. 安装cuda和cuDNN
需要安装的版本 CUDA 11.8 + cuDNN 8.7.0
- CUDA 11.8 https://developer.nvidia.com/cuda-11-8-0-download-archive
- cuDNN v8.7.0 (November 28th, 2022), for CUDA 11.x https://developer.nvidia.com/rdp/cudnn-archive
## 2. 安装anaconda
如果已安装conda,可以跳过本步骤
下载链接:
https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/Anaconda3-2024.06-1-Windows-x86_64.exe
## 3. 使用conda 创建环境
需指定python版本为3.10
```bash
conda create -n MinerU python=3.10
conda activate MinerU
```
## 4. 安装应用
```bash
pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i https://pypi.tuna.tsinghua.edu.cn/simple
```
> ❗️下载完成后,务必通过以下命令确认magic-pdf的版本是否正确
>
> ```bash
> magic-pdf --version
>```
> ```
>
> 如果版本号小于0.7.0,请到issue中向我们反馈
## 5. 下载模型
详细参考 [如何下载模型文件](how_to_download_models_zh_cn.md)
下载后请将models目录移动到空间较大的ssd磁盘目录
> ❗️模型下载后请务必检查模型文件是否下载完整
>
> 请检查目录下的模型文件大小与网页上描述是否一致,如果可以的话,最好通过sha256校验模型是否下载完整
## 6. 第一次运行前的配置
在仓库根目录可以获得 [magic-pdf.template.json](../magic-pdf.template.json) 配置模版文件
> ❗️务必执行以下命令将配置文件拷贝到【用户目录】下,否则程序将无法运行
>
> windows用户目录为 "C:\Users\用户名"
```powershell
(New-Object System.Net.WebClient).DownloadFile('https://gitee.com/myhloli/MinerU/raw/master/magic-pdf.template.json', 'magic-pdf.template.json')
cp magic-pdf.template.json ~/magic-pdf.json
```
## 6. 了解配置文件存放的位置
在用户目录中找到magic-pdf.json文件并配置"models-dir"为[5. 下载模型](#5-下载模型)中下载的模型权重文件所在目录
> ❗️务必正确配置模型权重文件所在目录的【绝对路径】,否则会因为找不到模型文件而导致程序无法运行
>
> windows系统中此路径应包含盘符,且需把路径中所有的`"\"`替换为`"/"`,否则会因为转义原因导致json文件语法错误。
>
> 例如:模型放在D盘根目录的models目录,则model-dir的值应为"D:/models"
```json
{
"models-dir": "/tmp/models"
}
```
完成[5.下载模型](#5-下载模型)步骤后,脚本会自动生成用户目录下的magic-pdf.json文件,并自动配置默认模型路径。
您可在【用户目录】下找到magic-pdf.json文件。
> windows用户目录为 "C:/Users/用户名"
## 7. 第一次运行
从仓库中下载样本文件,并测试
```powershell
(New-Object System.Net.WebClient).DownloadFile('https://gitee.com/myhloli/MinerU/raw/master/demo/small_ocr.pdf', 'small_ocr.pdf')
magic-pdf -p small_ocr.pdf
wget https://github.com/opendatalab/MinerU/raw/master/demo/small_ocr.pdf -O small_ocr.pdf
magic-pdf -p small_ocr.pdf
```
## 8. 测试CUDA加速
如果您的显卡显存大于等于8G,可以进行以下流程,测试CUDA解析加速效果
如果您的显卡显存大于等于 **8GB** ,可以进行以下流程,测试CUDA解析加速效果
**1.覆盖安装支持cuda的torch和torchvision**
```bash
pip install --force-reinstall torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu118
```
> ❗️务必在命令中指定以下版本
>
> ```bash
> torch==2.3.1 torchvision==0.18.1
> ```
>
> 这是我们支持的最高版本,如果不指定版本会自动安装更高版本导致程序无法运行
**2.修改【用户目录】中配置文件magic-pdf.json中"device-mode"的值**
```json
{
"device-mode":"cuda"
}
```
**3.运行以下命令测试cuda加速效果**
```bash
magic-pdf -p small_ocr.pdf
```
> 提示:CUDA加速是否生效可以根据log中输出的各个阶段cost耗时来简单判断,通常情况下,`layout detection cost` 和 `mfr time` 应提速10倍以上。
> 提示:CUDA加速是否生效可以根据log中输出的各个阶段的耗时来简单判断,通常情况下,`layout detection time` 和 `mfr time` 应提速10倍以上。
## 9. 为ocr开启cuda加速
> ❗️以下操作需显卡显存大于等于16G才可进行,否则会因为显存不足导致程序崩溃或运行速度下降
**1.下载paddlepaddle-gpu, 安装完成后会自动开启ocr加速**
```bash
pip install paddlepaddle-gpu==2.6.1
```
**2.运行以下命令测试ocr加速效果**
```bash
magic-pdf -p small_ocr.pdf
```
> 提示:CUDA加速是否生效可以根据log中输出的各个阶段cost耗时来简单判断,通常情况下,`ocr cost`应提速10倍以上。
> 提示:CUDA加速是否生效可以根据log中输出的各个阶段cost耗时来简单判断,通常情况下,`ocr time`应提速10倍以上。
# use modelscope sdk download models
import json
import os
import requests
from modelscope import snapshot_download
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
print(f"model dir is: {model_dir}/models")
def download_json(url):
# 下载JSON文件
response = requests.get(url)
response.raise_for_status() # 检查请求是否成功
return response.json()
def download_and_modify_json(url, local_filename, modifications):
if os.path.exists(local_filename):
data = json.load(open(local_filename))
config_version = data.get('config_version', '0.0.0')
if config_version < '1.0.0':
data = download_json(url)
else:
data = download_json(url)
# 修改内容
for key, value in modifications.items():
data[key] = value
# 保存修改后的内容
with open(local_filename, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=4)
if __name__ == '__main__':
mineru_patterns = [
"models/Layout/LayoutLMv3/*",
"models/Layout/YOLO/*",
"models/MFD/YOLO/*",
"models/MFR/unimernet_small/*",
"models/TabRec/TableMaster/*",
"models/TabRec/StructEqTable/*",
]
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
model_dir = model_dir + '/models'
print(f'model_dir is: {model_dir}')
print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
json_url = 'https://gitee.com/myhloli/MinerU/raw/dev/magic-pdf.template.json'
config_file_name = 'magic-pdf.json'
home_dir = os.path.expanduser('~')
config_file = os.path.join(home_dir, config_file_name)
json_mods = {
'models-dir': model_dir,
'layoutreader-model-dir': layoutreader_model_dir,
}
download_and_modify_json(json_url, config_file, json_mods)
print(f'The configuration file has been configured successfully, the path is: {config_file}')
import json
import os
import requests
from huggingface_hub import snapshot_download
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
print(f"model dir is: {model_dir}/models")
def download_json(url):
# 下载JSON文件
response = requests.get(url)
response.raise_for_status() # 检查请求是否成功
return response.json()
def download_and_modify_json(url, local_filename, modifications):
if os.path.exists(local_filename):
data = json.load(open(local_filename))
config_version = data.get('config_version', '0.0.0')
if config_version < '1.0.0':
data = download_json(url)
else:
data = download_json(url)
# 修改内容
for key, value in modifications.items():
data[key] = value
# 保存修改后的内容
with open(local_filename, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=4)
if __name__ == '__main__':
mineru_patterns = [
"models/Layout/LayoutLMv3/*",
"models/Layout/YOLO/*",
"models/MFD/YOLO/*",
"models/MFR/unimernet_small/*",
"models/TabRec/TableMaster/*",
"models/TabRec/StructEqTable/*",
]
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
layoutreader_pattern = [
"*.json",
"*.safetensors",
]
layoutreader_model_dir = snapshot_download('hantian/layoutreader', allow_patterns=layoutreader_pattern)
model_dir = model_dir + '/models'
print(f'model_dir is: {model_dir}')
print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
json_url = 'https://github.com/opendatalab/MinerU/raw/dev/magic-pdf.template.json'
config_file_name = 'magic-pdf.json'
home_dir = os.path.expanduser('~')
config_file = os.path.join(home_dir, config_file_name)
json_mods = {
'models-dir': model_dir,
'layoutreader-model-dir': layoutreader_model_dir,
}
download_and_modify_json(json_url, config_file, json_mods)
print(f'The configuration file has been configured successfully, the path is: {config_file}')
Model downloads are divided into initial downloads and updates to the model directory. Please refer to the corresponding documentation for instructions on how to proceed.
# Initial download of model files
### 1. Download the Model from Hugging Face
Use a Python Script to Download Model Files from Hugging Face
```bash
pip install huggingface_hub
wget https://github.com/opendatalab/MinerU/raw/master/docs/download_models_hf.py
wget https://github.com/opendatalab/MinerU/raw/master/docs/download_models_hf.py -O download_models_hf.py
python download_models_hf.py
```
After the Python script finishes executing, it will output the directory where the models are downloaded.
### 2. To modify the model path address in the configuration file
Additionally, in `~/magic-pdf.json`, update the model directory path to the absolute path of the `models` directory output by the previous Python script. Otherwise, you will encounter an error indicating that the model cannot be loaded.
The Python script will automatically download the model files and configure the model directory in the configuration file.
The configuration file can be found in the user directory, with the filename `magic-pdf.json`.
# How to update models previously downloaded
## 1. Models downloaded via Git LFS
>Due to feedback from some users that downloading model files using git lfs was incomplete or resulted in corrupted model files, this method is no longer recommended.
> Due to feedback from some users that downloading model files using git lfs was incomplete or resulted in corrupted model files, this method is no longer recommended.
When magic-pdf <= 0.8.1, if you have previously downloaded the model files via git lfs, you can navigate to the previous download directory and update the models using the `git pull` command.
If you previously downloaded model files via git lfs, you can navigate to the previous download directory and use the `git pull` command to update the model.
> For versions 0.9.x and later, due to the repository change and the addition of the layout sorting model in PDF-Extract-Kit 1.0, the models cannot be updated using the `git pull` command. Instead, a Python script must be used for one-click updates.
## 2. Models downloaded via Hugging Face or Model Scope
......
......@@ -8,9 +8,8 @@
<summary>方法一:从 Hugging Face 下载模型</summary>
<p>使用python脚本 从Hugging Face下载模型文件</p>
<pre><code>pip install huggingface_hub
wget https://gitee.com/myhloli/MinerU/raw/master/docs/download_models_hf.py
wget https://gitee.com/myhloli/MinerU/raw/master/docs/download_models_hf.py -O download_models_hf.py
python download_models_hf.py</code></pre>
<p>python脚本执行完毕后,会输出模型下载目录</p>
</details>
## 方法二:从 ModelScope 下载模型
......@@ -19,24 +18,26 @@ python download_models_hf.py</code></pre>
```bash
pip install modelscope
wget https://gitee.com/myhloli/MinerU/raw/master/docs/download_models.py
wget https://gitee.com/myhloli/MinerU/raw/master/docs/download_models.py -O download_models.py
python download_models.py
```
python脚本执行完毕后,会输出模型下载目录
python脚本会自动下载模型文件并配置好配置文件中的模型目录
## 下载完成后的操作:修改magic-pdf.json中的模型路径
`~/magic-pdf.json`里修改模型的目录指向上一步脚本输出的models目录的绝对路径,否则会报模型无法加载的错误。
配置文件可以在用户目录中找到,文件名为`magic-pdf.json`
> windows的用户目录为 "C:\\Users\\用户名", linux用户目录为 "/home/用户名", macOS用户目录为 "/Users/用户名"
# 此前下载过模型,如何更新
## 1. 通过git lfs下载过模型
>由于部分用户反馈通过git lfs下载模型文件遇到下载不全和模型文件损坏情况,现已不推荐使用该方式下载。
> 由于部分用户反馈通过git lfs下载模型文件遇到下载不全和模型文件损坏情况,现已不推荐使用该方式下载。
当magic-pdf <= 0.8.1时,如此前通过 git lfs 下载过模型文件,可以进入到之前的下载目录中,通过`git pull`命令更新模型。
> 0.9.x及以后版本由于PDF-Extract-Kit 1.0更换仓库和新增layout排序模型,不能通过`git pull`命令更新,需要使用python脚本一键更新。
如此前通过 git lfs 下载过模型文件,可以进入到之前的下载目录中,通过`git pull`命令更新模型。
## 2. 通过 Hugging Face 或 Model Scope 下载过模型
......
......@@ -4,10 +4,20 @@
"bucket-name-2":["ak", "sk", "endpoint"]
},
"models-dir":"/tmp/models",
"layoutreader-model-dir":"/tmp/layoutreader",
"device-mode":"cpu",
"layout-config": {
"model": "layoutlmv3"
},
"formula-config": {
"mfd_model": "yolo_v8_mfd",
"mfr_model": "unimernet_small",
"enable": true
},
"table-config": {
"model": "TableMaster",
"is_table_recog_enable": false,
"model": "tablemaster",
"enable": false,
"max_time": 400
}
},
"config_version": "1.0.0"
}
\ No newline at end of file
import enum
class SupportedPdfParseMethod(enum.Enum):
OCR = 'ocr'
TXT = 'txt'
class FileNotExisted(Exception):
def __init__(self, path):
self.path = path
def __str__(self):
return f'File {self.path} does not exist.'
class InvalidConfig(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return f'Invalid config: {self.msg}'
class InvalidParams(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return f'Invalid params: {self.msg}'
class EmptyData(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return f'Empty data: {self.msg}'
from magic_pdf.data.data_reader_writer.filebase import \
FileBasedDataReader # noqa: F401
from magic_pdf.data.data_reader_writer.filebase import \
FileBasedDataWriter # noqa: F401
from magic_pdf.data.data_reader_writer.multi_bucket_s3 import \
MultiBucketS3DataReader # noqa: F401
from magic_pdf.data.data_reader_writer.multi_bucket_s3 import \
MultiBucketS3DataWriter # noqa: F401
from magic_pdf.data.data_reader_writer.s3 import S3DataReader # noqa: F401
from magic_pdf.data.data_reader_writer.s3 import S3DataWriter # noqa: F401
from magic_pdf.data.data_reader_writer.base import DataReader # noqa: F401
from magic_pdf.data.data_reader_writer.base import DataWriter # noqa: F401
\ No newline at end of file
from abc import ABC, abstractmethod
class DataReader(ABC):
def read(self, path: str) -> bytes:
"""Read the file.
Args:
path (str): file path to read
Returns:
bytes: the content of the file
"""
return self.read_at(path)
@abstractmethod
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read the file at offset and limit.
Args:
path (str): the file path
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the length of bytes want to read. Defaults to -1.
Returns:
bytes: the content of the file
"""
pass
class DataWriter(ABC):
@abstractmethod
def write(self, path: str, data: bytes) -> None:
"""Write the data to the file.
Args:
path (str): the target file where to write
data (bytes): the data want to write
"""
pass
def write_string(self, path: str, data: str) -> None:
"""Write the data to file, the data will be encoded to bytes.
Args:
path (str): the target file where to write
data (str): the data want to write
"""
self.write(path, data.encode())
import os
from magic_pdf.data.data_reader_writer.base import DataReader, DataWriter
class FileBasedDataReader(DataReader):
def __init__(self, parent_dir: str = ''):
"""Initialized with parent_dir.
Args:
parent_dir (str, optional): the parent directory that may be used within methods. Defaults to ''.
"""
self._parent_dir = parent_dir
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read at offset and limit.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the length of bytes want to read. Defaults to -1.
Returns:
bytes: the content of file
"""
fn_path = path
if not os.path.isabs(fn_path) and len(self._parent_dir) > 0:
fn_path = os.path.join(self._parent_dir, path)
with open(fn_path, 'rb') as f:
f.seek(offset)
if limit == -1:
return f.read()
else:
return f.read(limit)
class FileBasedDataWriter(DataWriter):
def __init__(self, parent_dir: str = '') -> None:
"""Initialized with parent_dir.
Args:
parent_dir (str, optional): the parent directory that may be used within methods. Defaults to ''.
"""
self._parent_dir = parent_dir
def write(self, path: str, data: bytes) -> None:
"""Write file with data.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
fn_path = path
if not os.path.isabs(fn_path) and len(self._parent_dir) > 0:
fn_path = os.path.join(self._parent_dir, path)
with open(fn_path, 'wb') as f:
f.write(data)
from magic_pdf.config.exceptions import InvalidConfig, InvalidParams
from magic_pdf.data.data_reader_writer.base import DataReader, DataWriter
from magic_pdf.data.io.s3 import S3Reader, S3Writer
from magic_pdf.data.schemas import S3Config
from magic_pdf.libs.path_utils import (parse_s3_range_params, parse_s3path,
remove_non_official_s3_args)
class MultiS3Mixin:
def __init__(self, default_bucket: str, s3_configs: list[S3Config]):
"""Initialized with multiple s3 configs.
Args:
default_bucket (str): the default bucket name of the relative path
s3_configs (list[S3Config]): list of s3 configs, the bucket_name must be unique in the list.
Raises:
InvalidConfig: default bucket config not in s3_configs
InvalidConfig: bucket name not unique in s3_configs
InvalidConfig: default bucket must be provided
"""
if len(default_bucket) == 0:
raise InvalidConfig('default_bucket must be provided')
found_default_bucket_config = False
for conf in s3_configs:
if conf.bucket_name == default_bucket:
found_default_bucket_config = True
break
if not found_default_bucket_config:
raise InvalidConfig(
f'default_bucket: {default_bucket} config must be provided in s3_configs: {s3_configs}'
)
uniq_bucket = set([conf.bucket_name for conf in s3_configs])
if len(uniq_bucket) != len(s3_configs):
raise InvalidConfig(
f'the bucket_name in s3_configs: {s3_configs} must be unique'
)
self.default_bucket = default_bucket
self.s3_configs = s3_configs
self._s3_clients_h: dict = {}
class MultiBucketS3DataReader(DataReader, MultiS3Mixin):
def read(self, path: str) -> bytes:
"""Read the path from s3, select diffect bucket client for each request
based on the path, also support range read.
Args:
path (str): the s3 path of file, the path must be in the format of s3://bucket_name/path?offset,limit
for example: s3://bucket_name/path?0,100
Returns:
bytes: the content of s3 file
"""
may_range_params = parse_s3_range_params(path)
if may_range_params is None or 2 != len(may_range_params):
byte_start, byte_len = 0, -1
else:
byte_start, byte_len = int(may_range_params[0]), int(may_range_params[1])
path = remove_non_official_s3_args(path)
return self.read_at(path, byte_start, byte_len)
def __get_s3_client(self, bucket_name: str):
if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
raise InvalidParams(
f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
)
if bucket_name not in self._s3_clients_h:
conf = next(
filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
)
self._s3_clients_h[bucket_name] = S3Reader(
bucket_name,
conf.access_key,
conf.secret_key,
conf.endpoint_url,
conf.addressing_style,
)
return self._s3_clients_h[bucket_name]
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read the file with offset and limit, select diffect bucket client
for each request based on the path.
Args:
path (str): the file path
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the number of bytes want to read. Defaults to -1 which means infinite.
Returns:
bytes: the file content
"""
if path.startswith('s3://'):
bucket_name, path = parse_s3path(path)
s3_reader = self.__get_s3_client(bucket_name)
else:
s3_reader = self.__get_s3_client(self.default_bucket)
return s3_reader.read_at(path, offset, limit)
class MultiBucketS3DataWriter(DataWriter, MultiS3Mixin):
def __get_s3_client(self, bucket_name: str):
if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
raise InvalidParams(
f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
)
if bucket_name not in self._s3_clients_h:
conf = next(
filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
)
self._s3_clients_h[bucket_name] = S3Writer(
bucket_name,
conf.access_key,
conf.secret_key,
conf.endpoint_url,
conf.addressing_style,
)
return self._s3_clients_h[bucket_name]
def write(self, path: str, data: bytes) -> None:
"""Write file with data, also select diffect bucket client for each
request based on the path.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
if path.startswith('s3://'):
bucket_name, path = parse_s3path(path)
s3_writer = self.__get_s3_client(bucket_name)
else:
s3_writer = self.__get_s3_client(self.default_bucket)
return s3_writer.write(path, data)
from magic_pdf.data.data_reader_writer.multi_bucket_s3 import (
MultiBucketS3DataReader, MultiBucketS3DataWriter)
from magic_pdf.data.schemas import S3Config
class S3DataReader(MultiBucketS3DataReader):
def __init__(
self,
bucket: str,
ak: str,
sk: str,
endpoint_url: str,
addressing_style: str = 'auto',
):
"""s3 reader client.
Args:
bucket (str): bucket name
ak (str): access key
sk (str): secret key
endpoint_url (str): endpoint url of s3
addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
"""
super().__init__(
bucket,
[
S3Config(
bucket_name=bucket,
access_key=ak,
secret_key=sk,
endpoint_url=endpoint_url,
addressing_style=addressing_style,
)
],
)
class S3DataWriter(MultiBucketS3DataWriter):
def __init__(
self,
bucket: str,
ak: str,
sk: str,
endpoint_url: str,
addressing_style: str = 'auto',
):
"""s3 writer client.
Args:
bucket (str): bucket name
ak (str): access key
sk (str): secret key
endpoint_url (str): endpoint url of s3
addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
"""
super().__init__(
bucket,
[
S3Config(
bucket_name=bucket,
access_key=ak,
secret_key=sk,
endpoint_url=endpoint_url,
addressing_style=addressing_style,
)
],
)
from abc import ABC, abstractmethod
from typing import Iterator
import fitz
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.schemas import PageInfo
from magic_pdf.data.utils import fitz_doc_to_image
class PageableData(ABC):
@abstractmethod
def get_image(self) -> dict:
"""Transform data to image."""
pass
@abstractmethod
def get_doc(self) -> fitz.Page:
"""Get the pymudoc page."""
pass
@abstractmethod
def get_page_info(self) -> PageInfo:
"""Get the page info of the page.
Returns:
PageInfo: the page info of this page
"""
pass
class Dataset(ABC):
@abstractmethod
def __len__(self) -> int:
"""The length of the dataset."""
pass
@abstractmethod
def __iter__(self) -> Iterator[PageableData]:
"""Yield the page data."""
pass
@abstractmethod
def supported_methods(self) -> list[SupportedPdfParseMethod]:
"""The methods that this dataset support.
Returns:
list[SupportedPdfParseMethod]: The supported methods, Valid methods are: OCR, TXT
"""
pass
@abstractmethod
def data_bits(self) -> bytes:
"""The bits used to create this dataset."""
pass
@abstractmethod
def get_page(self, page_id: int) -> PageableData:
"""Get the page indexed by page_id.
Args:
page_id (int): the index of the page
Returns:
PageableData: the page doc object
"""
pass
class PymuDocDataset(Dataset):
def __init__(self, bits: bytes):
"""Initialize the dataset, which wraps the pymudoc documents.
Args:
bits (bytes): the bytes of the pdf
"""
self._records = [Doc(v) for v in fitz.open('pdf', bits)]
self._data_bits = bits
self._raw_data = bits
def __len__(self) -> int:
"""The page number of the pdf."""
return len(self._records)
def __iter__(self) -> Iterator[PageableData]:
"""Yield the page doc object."""
return iter(self._records)
def supported_methods(self) -> list[SupportedPdfParseMethod]:
"""The method supported by this dataset.
Returns:
list[SupportedPdfParseMethod]: the supported methods
"""
return [SupportedPdfParseMethod.OCR, SupportedPdfParseMethod.TXT]
def data_bits(self) -> bytes:
"""The pdf bits used to create this dataset."""
return self._data_bits
def get_page(self, page_id: int) -> PageableData:
"""The page doc object.
Args:
page_id (int): the page doc index
Returns:
PageableData: the page doc object
"""
return self._records[page_id]
class ImageDataset(Dataset):
def __init__(self, bits: bytes):
"""Initialize the dataset, which wraps the pymudoc documents.
Args:
bits (bytes): the bytes of the photo which will be converted to pdf first. then converted to pymudoc.
"""
pdf_bytes = fitz.open(stream=bits).convert_to_pdf()
self._records = [Doc(v) for v in fitz.open('pdf', pdf_bytes)]
self._raw_data = bits
self._data_bits = pdf_bytes
def __len__(self) -> int:
"""The length of the dataset."""
return len(self._records)
def __iter__(self) -> Iterator[PageableData]:
"""Yield the page object."""
return iter(self._records)
def supported_methods(self):
"""The method supported by this dataset.
Returns:
list[SupportedPdfParseMethod]: the supported methods
"""
return [SupportedPdfParseMethod.OCR]
def data_bits(self) -> bytes:
"""The pdf bits used to create this dataset."""
return self._data_bits
def get_page(self, page_id: int) -> PageableData:
"""The page doc object.
Args:
page_id (int): the page doc index
Returns:
PageableData: the page doc object
"""
return self._records[page_id]
class Doc(PageableData):
"""Initialized with pymudoc object."""
def __init__(self, doc: fitz.Page):
self._doc = doc
def get_image(self):
"""Return the imge info.
Returns:
dict: {
img: np.ndarray,
width: int,
height: int
}
"""
return fitz_doc_to_image(self._doc)
def get_doc(self) -> fitz.Page:
"""Get the pymudoc object.
Returns:
fitz.Page: the pymudoc object
"""
return self._doc
def get_page_info(self) -> PageInfo:
"""Get the page info of the page.
Returns:
PageInfo: the page info of this page
"""
page_w = self._doc.rect.width
page_h = self._doc.rect.height
return PageInfo(w=page_w, h=page_h)
def __getattr__(self, name):
if hasattr(self._doc, name):
return getattr(self._doc, name)
from abc import ABC, abstractmethod
class IOReader(ABC):
@abstractmethod
def read(self, path: str) -> bytes:
"""Read the file.
Args:
path (str): file path to read
Returns:
bytes: the content of the file
"""
pass
@abstractmethod
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read at offset and limit.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the length of bytes want to read. Defaults to -1.
Returns:
bytes: the content of file
"""
pass
class IOWriter:
@abstractmethod
def write(self, path: str, data: bytes) -> None:
"""Write file with data.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
pass
import io
import requests
from magic_pdf.data.io.base import IOReader, IOWriter
class HttpReader(IOReader):
def read(self, url: str) -> bytes:
"""Read the file.
Args:
path (str): file path to read
Returns:
bytes: the content of the file
"""
return requests.get(url).content
def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
"""Not Implemented."""
raise NotImplementedError
class HttpWriter(IOWriter):
def write(self, url: str, data: bytes) -> None:
"""Write file with data.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
files = {'file': io.BytesIO(data)}
response = requests.post(url, files=files)
assert 300 > response.status_code and response.status_code > 199
import boto3
from botocore.config import Config
from magic_pdf.data.io.base import IOReader, IOWriter
class S3Reader(IOReader):
def __init__(
self,
bucket: str,
ak: str,
sk: str,
endpoint_url: str,
addressing_style: str = 'auto',
):
"""s3 reader client.
Args:
bucket (str): bucket name
ak (str): access key
sk (str): secret key
endpoint_url (str): endpoint url of s3
addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
"""
self._bucket = bucket
self._ak = ak
self._sk = sk
self._s3_client = boto3.client(
service_name='s3',
aws_access_key_id=ak,
aws_secret_access_key=sk,
endpoint_url=endpoint_url,
config=Config(
s3={'addressing_style': addressing_style},
retries={'max_attempts': 5, 'mode': 'standard'},
),
)
def read(self, key: str) -> bytes:
"""Read the file.
Args:
path (str): file path to read
Returns:
bytes: the content of the file
"""
return self.read_at(key)
def read_at(self, key: str, offset: int = 0, limit: int = -1) -> bytes:
"""Read at offset and limit.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
offset (int, optional): the number of bytes skipped. Defaults to 0.
limit (int, optional): the length of bytes want to read. Defaults to -1.
Returns:
bytes: the content of file
"""
if limit > -1:
range_header = f'bytes={offset}-{offset+limit-1}'
res = self._s3_client.get_object(
Bucket=self._bucket, Key=key, Range=range_header
)
else:
res = self._s3_client.get_object(
Bucket=self._bucket, Key=key, Range=f'bytes={offset}-'
)
return res['Body'].read()
class S3Writer(IOWriter):
def __init__(
self,
bucket: str,
ak: str,
sk: str,
endpoint_url: str,
addressing_style: str = 'auto',
):
"""s3 reader client.
Args:
bucket (str): bucket name
ak (str): access key
sk (str): secret key
endpoint_url (str): endpoint url of s3
addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
"""
self._bucket = bucket
self._ak = ak
self._sk = sk
self._s3_client = boto3.client(
service_name='s3',
aws_access_key_id=ak,
aws_secret_access_key=sk,
endpoint_url=endpoint_url,
config=Config(
s3={'addressing_style': addressing_style},
retries={'max_attempts': 5, 'mode': 'standard'},
),
)
def write(self, key: str, data: bytes):
"""Write file with data.
Args:
path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
data (bytes): the data want to write
"""
self._s3_client.put_object(Bucket=self._bucket, Key=key, Body=data)
import json
import os
from pathlib import Path
from magic_pdf.config.exceptions import EmptyData, InvalidParams
from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
MultiBucketS3DataReader)
from magic_pdf.data.dataset import ImageDataset, PymuDocDataset
def read_jsonl(
s3_path_or_local: str, s3_client: MultiBucketS3DataReader | None = None
) -> list[PymuDocDataset]:
"""Read the jsonl file and return the list of PymuDocDataset.
Args:
s3_path_or_local (str): local file or s3 path
s3_client (MultiBucketS3DataReader | None, optional): s3 client that support multiple bucket. Defaults to None.
Raises:
InvalidParams: if s3_path_or_local is s3 path but s3_client is not provided.
EmptyData: if no pdf file location is provided in some line of jsonl file.
InvalidParams: if the file location is s3 path but s3_client is not provided
Returns:
list[PymuDocDataset]: each line in the jsonl file will be converted to a PymuDocDataset
"""
bits_arr = []
if s3_path_or_local.startswith('s3://'):
if s3_client is None:
raise InvalidParams('s3_client is required when s3_path is provided')
jsonl_bits = s3_client.read(s3_path_or_local)
else:
jsonl_bits = FileBasedDataReader('').read(s3_path_or_local)
jsonl_d = [
json.loads(line) for line in jsonl_bits.decode().split('\n') if line.strip()
]
for d in jsonl_d[:5]:
pdf_path = d.get('file_location', '') or d.get('path', '')
if len(pdf_path) == 0:
raise EmptyData('pdf file location is empty')
if pdf_path.startswith('s3://'):
if s3_client is None:
raise InvalidParams('s3_client is required when s3_path is provided')
bits_arr.append(s3_client.read(pdf_path))
else:
bits_arr.append(FileBasedDataReader('').read(pdf_path))
return [PymuDocDataset(bits) for bits in bits_arr]
def read_local_pdfs(path: str) -> list[PymuDocDataset]:
"""Read pdf from path or directory.
Args:
path (str): pdf file path or directory that contains pdf files
Returns:
list[PymuDocDataset]: each pdf file will converted to a PymuDocDataset
"""
if os.path.isdir(path):
reader = FileBasedDataReader(path)
return [
PymuDocDataset(reader.read(doc_path.name))
for doc_path in Path(path).glob('*.pdf')
]
else:
reader = FileBasedDataReader()
bits = reader.read(path)
return [PymuDocDataset(bits)]
def read_local_images(path: str, suffixes: list[str]) -> list[ImageDataset]:
"""Read images from path or directory.
Args:
path (str): image file path or directory that contains image files
suffixes (list[str]): the suffixes of the image files used to filter the files. Example: ['jpg', 'png']
Returns:
list[ImageDataset]: each image file will converted to a ImageDataset
"""
if os.path.isdir(path):
imgs_bits = []
s_suffixes = set(suffixes)
reader = FileBasedDataReader(path)
for root, _, files in os.walk(path):
for file in files:
suffix = file.split('.')
if suffix[-1] in s_suffixes:
imgs_bits.append(reader.read(file))
return [ImageDataset(bits) for bits in imgs_bits]
else:
reader = FileBasedDataReader()
bits = reader.read(path)
return [ImageDataset(bits)]
from pydantic import BaseModel, Field
class S3Config(BaseModel):
bucket_name: str = Field(description='s3 bucket name', min_length=1)
access_key: str = Field(description='s3 access key', min_length=1)
secret_key: str = Field(description='s3 secret key', min_length=1)
endpoint_url: str = Field(description='s3 endpoint url', min_length=1)
addressing_style: str = Field(description='s3 addressing style', default='auto', min_length=1)
class PageInfo(BaseModel):
w: float = Field(description='the width of page')
h: float = Field(description='the height of page')
import fitz
import numpy as np
from magic_pdf.utils.annotations import ImportPIL
@ImportPIL
def fitz_doc_to_image(doc, dpi=200) -> dict:
"""Convert fitz.Document to image, Then convert the image to numpy array.
Args:
doc (_type_): pymudoc page
dpi (int, optional): reset the dpi of dpi. Defaults to 200.
Returns:
dict: {'img': numpy array, 'width': width, 'height': height }
"""
from PIL import Image
mat = fitz.Matrix(dpi / 72, dpi / 72)
pm = doc.get_pixmap(matrix=mat, alpha=False)
# If the width or height exceeds 9000 after scaling, do not scale further.
if pm.width > 9000 or pm.height > 9000:
pm = doc.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
img = np.array(img)
img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
return img_dict
import re
import wordninja
from loguru import logger
from magic_pdf.libs.commons import join_path
......@@ -8,6 +7,7 @@ from magic_pdf.libs.language import detect_lang
from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
from magic_pdf.libs.ocr_content_type import BlockType, ContentType
from magic_pdf.para.para_split_v3 import ListLineTag
def __is_hyphen_at_line_end(line):
......@@ -24,37 +24,6 @@ def __is_hyphen_at_line_end(line):
return bool(re.search(r'[A-Za-z]+-\s*$', line))
def split_long_words(text):
segments = text.split(' ')
for i in range(len(segments)):
words = re.findall(r'\w+|[^\w]', segments[i], re.UNICODE)
for j in range(len(words)):
if len(words[j]) > 10:
words[j] = ' '.join(wordninja.split(words[j]))
segments[i] = ''.join(words)
return ' '.join(segments)
def ocr_mk_mm_markdown_with_para(pdf_info_list: list, img_buket_path):
markdown = []
for page_info in pdf_info_list:
paras_of_layout = page_info.get('para_blocks')
page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'mm', img_buket_path)
markdown.extend(page_markdown)
return '\n\n'.join(markdown)
def ocr_mk_nlp_markdown_with_para(pdf_info_dict: list):
markdown = []
for page_info in pdf_info_dict:
paras_of_layout = page_info.get('para_blocks')
page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'nlp')
markdown.extend(page_markdown)
return '\n\n'.join(markdown)
def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list,
img_buket_path):
markdown_with_para_and_pagination = []
......@@ -75,61 +44,20 @@ def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list,
return markdown_with_para_and_pagination
def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=''):
page_markdown = []
for paras in paras_of_layout:
for para in paras:
para_text = ''
for line in para:
for span in line['spans']:
span_type = span.get('type')
content = ''
language = ''
if span_type == ContentType.Text:
content = span['content']
language = detect_lang(content)
if (language == 'en'): # 只对英文长词进行分词处理,中文分词会丢失文本
content = ocr_escape_special_markdown_char(
split_long_words(content))
else:
content = ocr_escape_special_markdown_char(content)
elif span_type == ContentType.InlineEquation:
content = f"${span['content']}$"
elif span_type == ContentType.InterlineEquation:
content = f"\n$$\n{span['content']}\n$$\n"
elif span_type in [ContentType.Image, ContentType.Table]:
if mode == 'mm':
content = f"\n![]({join_path(img_buket_path, span['image_path'])})\n"
elif mode == 'nlp':
pass
if content != '':
if language == 'en': # 英文语境下 content间需要空格分隔
para_text += content + ' '
else: # 中文语境下,content间不需要空格分隔
para_text += content
if para_text.strip() == '':
continue
else:
page_markdown.append(para_text.strip() + ' ')
return page_markdown
def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
mode,
img_buket_path='',
parse_type="auto",
lang=None
):
page_markdown = []
for para_block in paras_of_layout:
para_text = ''
para_type = para_block['type']
if para_type == BlockType.Text:
para_text = merge_para_with_text(para_block, parse_type=parse_type, lang=lang)
if para_type in [BlockType.Text, BlockType.List, BlockType.Index]:
para_text = merge_para_with_text(para_block)
elif para_type == BlockType.Title:
para_text = f'# {merge_para_with_text(para_block, parse_type=parse_type, lang=lang)}'
para_text = f'# {merge_para_with_text(para_block)}'
elif para_type == BlockType.InterlineEquation:
para_text = merge_para_with_text(para_block, parse_type=parse_type, lang=lang)
para_text = merge_para_with_text(para_block)
elif para_type == BlockType.Image:
if mode == 'nlp':
continue
......@@ -142,17 +70,17 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n"
for block in para_block['blocks']: # 2nd.拼image_caption
if block['type'] == BlockType.ImageCaption:
para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
for block in para_block['blocks']: # 2nd.拼image_caption
para_text += merge_para_with_text(block) + ' \n'
for block in para_block['blocks']: # 3rd.拼image_footnote
if block['type'] == BlockType.ImageFootnote:
para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
para_text += merge_para_with_text(block) + ' \n'
elif para_type == BlockType.Table:
if mode == 'nlp':
continue
elif mode == 'mm':
for block in para_block['blocks']: # 1st.拼table_caption
if block['type'] == BlockType.TableCaption:
para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
para_text += merge_para_with_text(block) + ' \n'
for block in para_block['blocks']: # 2nd.拼table_body
if block['type'] == BlockType.TableBody:
for line in block['lines']:
......@@ -167,7 +95,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n"
for block in para_block['blocks']: # 3rd.拼table_footnote
if block['type'] == BlockType.TableFootnote:
para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
para_text += merge_para_with_text(block) + ' \n'
if para_text.strip() == '':
continue
......@@ -177,9 +105,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
return page_markdown
def merge_para_with_text(para_block, parse_type="auto", lang=None):
def detect_language(text):
def detect_language(text):
en_pattern = r'[a-zA-Z]+'
en_matches = re.findall(en_pattern, text)
en_length = sum(len(match) for match in en_matches)
......@@ -191,8 +117,14 @@ def merge_para_with_text(para_block, parse_type="auto", lang=None):
else:
return 'empty'
def merge_para_with_text(para_block):
para_text = ''
for line in para_block['lines']:
for i, line in enumerate(para_block['lines']):
if i >= 1 and line.get(ListLineTag.IS_LIST_START_LINE, False):
para_text += ' \n'
line_text = ''
line_lang = ''
for span in line['spans']:
......@@ -202,21 +134,11 @@ def merge_para_with_text(para_block, parse_type="auto", lang=None):
if line_text != '':
line_lang = detect_lang(line_text)
for span in line['spans']:
span_type = span['type']
content = ''
if span_type == ContentType.Text:
content = span['content']
# language = detect_lang(content)
language = detect_language(content)
# 判断是否小语种
if lang is not None and lang != 'en':
content = ocr_escape_special_markdown_char(content)
else: # 非小语种逻辑
if language == 'en' and parse_type == 'ocr': # 只对英文长词进行分词处理,中文分词会丢失文本
content = ocr_escape_special_markdown_char(
split_long_words(content))
else:
content = ocr_escape_special_markdown_char(content)
content = ocr_escape_special_markdown_char(span['content'])
elif span_type == ContentType.InlineEquation:
content = f" ${span['content']}$ "
elif span_type == ContentType.InterlineEquation:
......@@ -237,74 +159,39 @@ def merge_para_with_text(para_block, parse_type="auto", lang=None):
return para_text
def para_to_standard_format(para, img_buket_path):
para_content = {}
if len(para) == 1:
para_content = line_to_standard_format(para[0], img_buket_path)
elif len(para) > 1:
para_text = ''
inline_equation_num = 0
for line in para:
for span in line['spans']:
language = ''
span_type = span.get('type')
content = ''
if span_type == ContentType.Text:
content = span['content']
language = detect_lang(content)
if language == 'en': # 只对英文长词进行分词处理,中文分词会丢失文本
content = ocr_escape_special_markdown_char(
split_long_words(content))
else:
content = ocr_escape_special_markdown_char(content)
elif span_type == ContentType.InlineEquation:
content = f"${span['content']}$"
inline_equation_num += 1
if language == 'en': # 英文语境下 content间需要空格分隔
para_text += content + ' '
else: # 中文语境下,content间不需要空格分隔
para_text += content
para_content = {
'type': 'text',
'text': para_text,
'inline_equation_num': inline_equation_num,
}
return para_content
def para_to_standard_format_v2(para_block, img_buket_path, page_idx, parse_type="auto", lang=None, drop_reason=None):
def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason=None):
para_type = para_block['type']
para_content = {}
if para_type == BlockType.Text:
if para_type in [BlockType.Text, BlockType.List, BlockType.Index]:
para_content = {
'type': 'text',
'text': merge_para_with_text(para_block, parse_type=parse_type, lang=lang),
'text': merge_para_with_text(para_block),
}
elif para_type == BlockType.Title:
para_content = {
'type': 'text',
'text': merge_para_with_text(para_block, parse_type=parse_type, lang=lang),
'text': merge_para_with_text(para_block),
'text_level': 1,
}
elif para_type == BlockType.InterlineEquation:
para_content = {
'type': 'equation',
'text': merge_para_with_text(para_block, parse_type=parse_type, lang=lang),
'text': merge_para_with_text(para_block),
'text_format': 'latex',
}
elif para_type == BlockType.Image:
para_content = {'type': 'image'}
para_content = {'type': 'image', 'img_caption': [], 'img_footnote': []}
for block in para_block['blocks']:
if block['type'] == BlockType.ImageBody:
para_content['img_path'] = join_path(
img_buket_path,
block['lines'][0]['spans'][0]['image_path'])
if block['type'] == BlockType.ImageCaption:
para_content['img_caption'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
para_content['img_caption'].append(merge_para_with_text(block))
if block['type'] == BlockType.ImageFootnote:
para_content['img_footnote'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
para_content['img_footnote'].append(merge_para_with_text(block))
elif para_type == BlockType.Table:
para_content = {'type': 'table'}
para_content = {'type': 'table', 'table_caption': [], 'table_footnote': []}
for block in para_block['blocks']:
if block['type'] == BlockType.TableBody:
if block["lines"][0]["spans"][0].get('latex', ''):
......@@ -313,9 +200,9 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, parse_type=
para_content['table_body'] = f"\n\n{block['lines'][0]['spans'][0]['html']}\n\n"
para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path'])
if block['type'] == BlockType.TableCaption:
para_content['table_caption'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
para_content['table_caption'].append(merge_para_with_text(block))
if block['type'] == BlockType.TableFootnote:
para_content['table_footnote'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
para_content['table_footnote'].append(merge_para_with_text(block))
para_content['page_idx'] = page_idx
......@@ -325,88 +212,11 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, parse_type=
return para_content
def make_standard_format_with_para(pdf_info_dict: list, img_buket_path: str):
content_list = []
for page_info in pdf_info_dict:
paras_of_layout = page_info.get('para_blocks')
if not paras_of_layout:
continue
for para_block in paras_of_layout:
para_content = para_to_standard_format_v2(para_block,
img_buket_path)
content_list.append(para_content)
return content_list
def line_to_standard_format(line, img_buket_path):
line_text = ''
inline_equation_num = 0
for span in line['spans']:
if not span.get('content'):
if not span.get('image_path'):
continue
else:
if span['type'] == ContentType.Image:
content = {
'type': 'image',
'img_path': join_path(img_buket_path,
span['image_path']),
}
return content
elif span['type'] == ContentType.Table:
content = {
'type': 'table',
'img_path': join_path(img_buket_path,
span['image_path']),
}
return content
else:
if span['type'] == ContentType.InterlineEquation:
interline_equation = span['content']
content = {
'type': 'equation',
'latex': f'$$\n{interline_equation}\n$$'
}
return content
elif span['type'] == ContentType.InlineEquation:
inline_equation = span['content']
line_text += f'${inline_equation}$'
inline_equation_num += 1
elif span['type'] == ContentType.Text:
text_content = ocr_escape_special_markdown_char(
span['content']) # 转义特殊符号
line_text += text_content
content = {
'type': 'text',
'text': line_text,
'inline_equation_num': inline_equation_num,
}
return content
def ocr_mk_mm_standard_format(pdf_info_dict: list):
"""content_list type string
image/text/table/equation(行间的单独拿出来,行内的和text合并) latex string
latex文本字段。 text string 纯文本格式的文本数据。 md string
markdown格式的文本数据。 img_path string s3://full/path/to/img.jpg."""
content_list = []
for page_info in pdf_info_dict:
blocks = page_info.get('preproc_blocks')
if not blocks:
continue
for block in blocks:
for line in block['lines']:
content = line_to_standard_format(line)
content_list.append(content)
return content_list
def union_make(pdf_info_dict: list,
make_mode: str,
drop_mode: str,
img_buket_path: str = '',
parse_type: str = "auto",
lang=None):
):
output_content = []
for page_info in pdf_info_dict:
drop_reason_flag = False
......@@ -433,20 +243,20 @@ def union_make(pdf_info_dict: list,
continue
if make_mode == MakeMode.MM_MD:
page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'mm', img_buket_path, parse_type=parse_type, lang=lang)
paras_of_layout, 'mm', img_buket_path)
output_content.extend(page_markdown)
elif make_mode == MakeMode.NLP_MD:
page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'nlp', parse_type=parse_type, lang=lang)
paras_of_layout, 'nlp')
output_content.extend(page_markdown)
elif make_mode == MakeMode.STANDARD_FORMAT:
for para_block in paras_of_layout:
if drop_reason_flag:
para_content = para_to_standard_format_v2(
para_block, img_buket_path, page_idx, parse_type=parse_type, lang=lang, drop_reason=drop_reason)
para_block, img_buket_path, page_idx)
else:
para_content = para_to_standard_format_v2(
para_block, img_buket_path, page_idx, parse_type=parse_type, lang=lang)
para_block, img_buket_path, page_idx)
output_content.append(para_content)
if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
return '\n\n'.join(output_content)
......
......@@ -10,18 +10,12 @@ block维度自定义字段
# block中lines是否被删除
LINES_DELETED = "lines_deleted"
# struct eqtable
STRUCT_EQTABLE = "struct_eqtable"
# table recognition max time default value
TABLE_MAX_TIME_VALUE = 400
# pp_table_result_max_length
TABLE_MAX_LEN = 480
# pp table structure algorithm
TABLE_MASTER = "TableMaster"
# table master structure dict
TABLE_MASTER_DICT = "table_master_structure_dict.txt"
......@@ -38,3 +32,16 @@ REC_MODEL_DIR = "ch_PP-OCRv3_rec_infer"
REC_CHAR_DICT = "ppocr_keys_v1.txt"
class MODEL_NAME:
# pp table structure algorithm
TABLE_MASTER = "tablemaster"
# struct eqtable
STRUCT_EQTABLE = "struct_eqtable"
DocLayout_YOLO = "doclayout_yolo"
LAYOUTLMv3 = "layoutlmv3"
YOLO_V8_MFD = "yolo_v8_mfd"
UniMerNet_v2_Small = "unimernet_small"
\ No newline at end of file
......@@ -445,3 +445,38 @@ def get_overlap_area(bbox1, bbox2):
# The area of overlap area
return (x_right - x_left) * (y_bottom - y_top)
def calculate_vertical_projection_overlap_ratio(block1, block2):
"""
Calculate the proportion of the x-axis covered by the vertical projection of two blocks.
Args:
block1 (tuple): Coordinates of the first block (x0, y0, x1, y1).
block2 (tuple): Coordinates of the second block (x0, y0, x1, y1).
Returns:
float: The proportion of the x-axis covered by the vertical projection of the two blocks.
"""
x0_1, _, x1_1, _ = block1
x0_2, _, x1_2, _ = block2
# Calculate the intersection of the x-coordinates
x_left = max(x0_1, x0_2)
x_right = min(x1_1, x1_2)
if x_right < x_left:
return 0.0
# Length of the intersection
intersection_length = x_right - x_left
# Length of the x-axis projection of the first block
block1_length = x1_1 - x0_1
if block1_length == 0:
return 0.0
# Proportion of the x-axis covered by the intersection
# logger.info(f"intersection_length: {intersection_length}, block1_length: {block1_length}")
return intersection_length / block1_length
"""
根据bucket的名字返回对应的s3 AK, SK,endpoint三元组
"""
"""根据bucket的名字返回对应的s3 AK, SK,endpoint三元组."""
import json
import os
from loguru import logger
from magic_pdf.libs.Constants import MODEL_NAME
from magic_pdf.libs.commons import parse_bucket_key
# 定义配置文件名常量
CONFIG_FILE_NAME = "magic-pdf.json"
CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'magic-pdf.json')
def read_config():
home_dir = os.path.expanduser("~")
if os.path.isabs(CONFIG_FILE_NAME):
config_file = CONFIG_FILE_NAME
else:
home_dir = os.path.expanduser('~')
config_file = os.path.join(home_dir, CONFIG_FILE_NAME)
if not os.path.exists(config_file):
raise FileNotFoundError(f"{config_file} not found")
raise FileNotFoundError(f'{config_file} not found')
with open(config_file, "r", encoding="utf-8") as f:
with open(config_file, 'r', encoding='utf-8') as f:
config = json.load(f)
return config
def get_s3_config(bucket_name: str):
"""
~/magic-pdf.json 读出来
"""
"""~/magic-pdf.json 读出来."""
config = read_config()
bucket_info = config.get("bucket_info")
bucket_info = config.get('bucket_info')
if bucket_name not in bucket_info:
access_key, secret_key, storage_endpoint = bucket_info["[default]"]
access_key, secret_key, storage_endpoint = bucket_info['[default]']
else:
access_key, secret_key, storage_endpoint = bucket_info[bucket_name]
if access_key is None or secret_key is None or storage_endpoint is None:
raise Exception(f"ak, sk or endpoint not found in {CONFIG_FILE_NAME}")
raise Exception(f'ak, sk or endpoint not found in {CONFIG_FILE_NAME}')
# logger.info(f"get_s3_config: ak={access_key}, sk={secret_key}, endpoint={storage_endpoint}")
......@@ -49,7 +47,7 @@ def get_s3_config(bucket_name: str):
def get_s3_config_dict(path: str):
access_key, secret_key, storage_endpoint = get_s3_config(get_bucket_name(path))
return {"ak": access_key, "sk": secret_key, "endpoint": storage_endpoint}
return {'ak': access_key, 'sk': secret_key, 'endpoint': storage_endpoint}
def get_bucket_name(path):
......@@ -59,33 +57,65 @@ def get_bucket_name(path):
def get_local_models_dir():
config = read_config()
models_dir = config.get("models-dir")
models_dir = config.get('models-dir')
if models_dir is None:
logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
return "/tmp/models"
return '/tmp/models'
else:
return models_dir
def get_local_layoutreader_model_dir():
config = read_config()
layoutreader_model_dir = config.get('layoutreader-model-dir')
if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir):
home_dir = os.path.expanduser('~')
layoutreader_at_modelscope_dir_path = os.path.join(home_dir, '.cache/modelscope/hub/ppaanngggg/layoutreader')
logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default")
return layoutreader_at_modelscope_dir_path
else:
return layoutreader_model_dir
def get_device():
config = read_config()
device = config.get("device-mode")
device = config.get('device-mode')
if device is None:
logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default")
return "cpu"
return 'cpu'
else:
return device
def get_table_recog_config():
config = read_config()
table_config = config.get("table-config")
table_config = config.get('table-config')
if table_config is None:
logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
return json.loads('{"is_table_recog_enable": false, "max_time": 400}')
return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}')
else:
return table_config
def get_layout_config():
config = read_config()
layout_config = config.get("layout-config")
if layout_config is None:
logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default")
return json.loads(f'{{"model": "{MODEL_NAME.LAYOUTLMv3}"}}')
else:
return layout_config
def get_formula_config():
config = read_config()
formula_config = config.get("formula-config")
if formula_config is None:
logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
return json.loads(f'{{"mfd_model": "{MODEL_NAME.YOLO_V8_MFD}","mfr_model": "{MODEL_NAME.UniMerNet_v2_Small}","enable": true}}')
else:
return formula_config
if __name__ == "__main__":
ak, sk, endpoint = get_s3_config("llm-raw")
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.libs.commons import fitz # PyMuPDF
from magic_pdf.libs.Constants import CROSS_PAGE
from magic_pdf.libs.ocr_content_type import BlockType, CategoryId, ContentType
......@@ -62,7 +63,7 @@ def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config, draw_bbox
overlay=True,
) # Draw the rectangle
page.insert_text(
(x1+2, y0 + 10), str(j + 1), fontsize=10, color=new_rgb
(x1 + 2, y0 + 10), str(j + 1), fontsize=10, color=new_rgb
) # Insert the index in the top left corner of the rectangle
......@@ -75,6 +76,8 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
titles_list = []
texts_list = []
interequations_list = []
lists_list = []
indexs_list = []
for page in pdf_info:
page_dropped_list = []
......@@ -83,6 +86,8 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
titles = []
texts = []
interequations = []
lists = []
indices = []
for dropped_bbox in page['discarded_blocks']:
page_dropped_list.append(dropped_bbox['bbox'])
......@@ -115,6 +120,11 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
texts.append(bbox)
elif block['type'] == BlockType.InterlineEquation:
interequations.append(bbox)
elif block['type'] == BlockType.List:
lists.append(bbox)
elif block['type'] == BlockType.Index:
indices.append(bbox)
tables_list.append(tables)
tables_body_list.append(tables_body)
tables_caption_list.append(tables_caption)
......@@ -126,42 +136,62 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
titles_list.append(titles)
texts_list.append(texts)
interequations_list.append(interequations)
lists_list.append(lists)
indexs_list.append(indices)
layout_bbox_list = []
table_type_order = {
'table_caption': 1,
'table_body': 2,
'table_footnote': 3
}
for page in pdf_info:
page_block_list = []
for block in page['para_blocks']:
if block['type'] in [
BlockType.Text,
BlockType.Title,
BlockType.InterlineEquation,
BlockType.List,
BlockType.Index,
]:
bbox = block['bbox']
page_block_list.append(bbox)
elif block['type'] in [BlockType.Image]:
for sub_block in block['blocks']:
bbox = sub_block['bbox']
page_block_list.append(bbox)
elif block['type'] in [BlockType.Table]:
sorted_blocks = sorted(block['blocks'], key=lambda x: table_type_order[x['type']])
for sub_block in sorted_blocks:
bbox = sub_block['bbox']
page_block_list.append(bbox)
layout_bbox_list.append(page_block_list)
pdf_docs = fitz.open('pdf', pdf_bytes)
for i, page in enumerate(pdf_docs):
draw_bbox_without_number(i, dropped_bbox_list, page, [158, 158, 158],
True)
draw_bbox_without_number(i, tables_list, page, [153, 153, 0],
True) # color !
draw_bbox_without_number(i, tables_body_list, page, [204, 204, 0],
True)
draw_bbox_without_number(i, tables_caption_list, page, [255, 255, 102],
True)
draw_bbox_without_number(i, tables_footnote_list, page,
[229, 255, 204], True)
draw_bbox_without_number(i, imgs_list, page, [51, 102, 0], True)
draw_bbox_without_number(i, dropped_bbox_list, page, [158, 158, 158], True)
# draw_bbox_without_number(i, tables_list, page, [153, 153, 0], True) # color !
draw_bbox_without_number(i, tables_body_list, page, [204, 204, 0], True)
draw_bbox_without_number(i, tables_caption_list, page, [255, 255, 102], True)
draw_bbox_without_number(i, tables_footnote_list, page, [229, 255, 204], True)
# draw_bbox_without_number(i, imgs_list, page, [51, 102, 0], True)
draw_bbox_without_number(i, imgs_body_list, page, [153, 255, 51], True)
draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255],
True)
draw_bbox_without_number(i, imgs_footnote_list, page, [255, 178, 102],
True),
draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255], True)
draw_bbox_without_number(i, imgs_footnote_list, page, [255, 178, 102], True),
draw_bbox_without_number(i, titles_list, page, [102, 102, 255], True)
draw_bbox_without_number(i, texts_list, page, [153, 0, 76], True)
draw_bbox_without_number(i, interequations_list, page, [0, 255, 0],
True)
draw_bbox_without_number(i, interequations_list, page, [0, 255, 0], True)
draw_bbox_without_number(i, lists_list, page, [40, 169, 92], True)
draw_bbox_without_number(i, indexs_list, page, [40, 169, 92], True)
draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False, draw_bbox=False)
draw_bbox_with_number(
i, layout_bbox_list, page, [255, 0, 0], False, draw_bbox=False
)
# Save the PDF
pdf_docs.save(f'{out_path}/{filename}_layout.pdf')
......@@ -224,6 +254,8 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
BlockType.Text,
BlockType.Title,
BlockType.InterlineEquation,
BlockType.List,
BlockType.Index,
]:
for line in block['lines']:
for span in line['spans']:
......@@ -260,7 +292,7 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
texts_list = []
interequations_list = []
pdf_docs = fitz.open('pdf', pdf_bytes)
magic_model = MagicModel(model_list, pdf_docs)
magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))
for i in range(len(model_list)):
page_dropped_list = []
tables_body, tables_caption, tables_footnote = [], [], []
......@@ -286,8 +318,7 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
imgs_body.append(bbox)
elif layout_det['category_id'] == CategoryId.ImageCaption:
imgs_caption.append(bbox)
elif layout_det[
'category_id'] == CategoryId.InterlineEquation_YOLO:
elif layout_det['category_id'] == CategoryId.InterlineEquation_YOLO:
interequations.append(bbox)
elif layout_det['category_id'] == CategoryId.Abandon:
page_dropped_list.append(bbox)
......@@ -306,18 +337,15 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
imgs_footnote_list.append(imgs_footnote)
for i, page in enumerate(pdf_docs):
draw_bbox_with_number(i, dropped_bbox_list, page, [158, 158, 158],
True) # color !
draw_bbox_with_number(
i, dropped_bbox_list, page, [158, 158, 158], True
) # color !
draw_bbox_with_number(i, tables_body_list, page, [204, 204, 0], True)
draw_bbox_with_number(i, tables_caption_list, page, [255, 255, 102],
True)
draw_bbox_with_number(i, tables_footnote_list, page, [229, 255, 204],
True)
draw_bbox_with_number(i, tables_caption_list, page, [255, 255, 102], True)
draw_bbox_with_number(i, tables_footnote_list, page, [229, 255, 204], True)
draw_bbox_with_number(i, imgs_body_list, page, [153, 255, 51], True)
draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255],
True)
draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102],
True)
draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255], True)
draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102], True)
draw_bbox_with_number(i, titles_list, page, [102, 102, 255], True)
draw_bbox_with_number(i, texts_list, page, [153, 0, 76], True)
draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True)
......@@ -332,19 +360,23 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
for page in pdf_info:
page_line_list = []
for block in page['preproc_blocks']:
if block['type'] in ['text', 'title', 'interline_equation']:
if block['type'] in [BlockType.Text, BlockType.Title, BlockType.InterlineEquation]:
for line in block['lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
if block['type'] in ['table', 'image']:
bbox = block['bbox']
index = block['index']
if block['type'] in [BlockType.Image, BlockType.Table]:
for sub_block in block['blocks']:
if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
for line in sub_block['virtual_lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
elif sub_block['type'] in [BlockType.ImageCaption, BlockType.TableCaption, BlockType.ImageFootnote, BlockType.TableFootnote]:
for line in sub_block['lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
# for line in block['lines']:
# bbox = line['bbox']
# index = line['index']
# page_line_list.append({'index': index, 'bbox': bbox})
sorted_bboxes = sorted(page_line_list, key=lambda x: x['index'])
layout_bbox_list.append(sorted_bbox['bbox'] for sorted_bbox in sorted_bboxes)
pdf_docs = fitz.open('pdf', pdf_bytes)
......
......@@ -20,6 +20,8 @@ class BlockType:
InterlineEquation = 'interline_equation'
Footnote = 'footnote'
Discarded = 'discarded'
List = 'list'
Index = 'index'
class CategoryId:
......
......@@ -4,7 +4,9 @@ import fitz
import numpy as np
from loguru import logger
from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config, get_layout_config, \
get_formula_config
from magic_pdf.model.model_list import MODEL
import magic_pdf.model as model_config
......@@ -23,7 +25,7 @@ def remove_duplicates_dicts(lst):
return unique_dicts
def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list:
try:
from PIL import Image
except ImportError:
......@@ -32,7 +34,14 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
images = []
with fitz.open("pdf", pdf_bytes) as doc:
pdf_page_num = doc.page_count
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
if end_page_id > pdf_page_num - 1:
logger.warning("end_page_id is out of range, use images length")
end_page_id = pdf_page_num - 1
for index in range(0, doc.page_count):
if start_page_id <= index <= end_page_id:
page = doc[index]
mat = fitz.Matrix(dpi / 72, dpi / 72)
pm = page.get_pixmap(matrix=mat, alpha=False)
......@@ -44,6 +53,9 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
img = np.array(img)
img_dict = {"img": img, "width": pm.width, "height": pm.height}
else:
img_dict = {"img": [], "width": 0, "height": 0}
images.append(img_dict)
return images
......@@ -57,14 +69,17 @@ class ModelSingleton:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(self, ocr: bool, show_log: bool, lang=None):
key = (ocr, show_log, lang)
def get_model(self, ocr: bool, show_log: bool, lang=None, layout_model=None, formula_enable=None, table_enable=None):
key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
if key not in self._models:
self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang)
self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang, layout_model=layout_model,
formula_enable=formula_enable, table_enable=table_enable)
return self._models[key]
def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None,
layout_model=None, formula_enable=None, table_enable=None):
model = None
if model_config.__model_mode__ == "lite":
......@@ -84,14 +99,30 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
# 从配置文件读取model-dir和device
local_models_dir = get_local_models_dir()
device = get_device()
layout_config = get_layout_config()
if layout_model is not None:
layout_config["model"] = layout_model
formula_config = get_formula_config()
if formula_enable is not None:
formula_config["enable"] = formula_enable
table_config = get_table_recog_config()
model_input = {"ocr": ocr,
if table_enable is not None:
table_config["enable"] = table_enable
model_input = {
"ocr": ocr,
"show_log": show_log,
"models_dir": local_models_dir,
"device": device,
"table_config": table_config,
"layout_config": layout_config,
"formula_config": formula_config,
"lang": lang,
}
custom_model = CustomPEKModel(**model_input)
else:
logger.error("Not allow model_name!")
......@@ -106,19 +137,23 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
start_page_id=0, end_page_id=None, lang=None):
model_manager = ModelSingleton()
custom_model = model_manager.get_model(ocr, show_log, lang)
start_page_id=0, end_page_id=None, lang=None,
layout_model=None, formula_enable=None, table_enable=None):
images = load_images_from_pdf(pdf_bytes)
if lang == "":
lang = None
# end_page_id = end_page_id if end_page_id else len(images) - 1
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(images) - 1
model_manager = ModelSingleton()
custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)
if end_page_id > len(images) - 1:
with fitz.open("pdf", pdf_bytes) as doc:
pdf_page_num = doc.page_count
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
if end_page_id > pdf_page_num - 1:
logger.warning("end_page_id is out of range, use images length")
end_page_id = len(images) - 1
end_page_id = pdf_page_num - 1
images = load_images_from_pdf(pdf_bytes, start_page_id=start_page_id, end_page_id=end_page_id)
model_json = []
doc_analyze_start = time.time()
......@@ -135,6 +170,11 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
page_dict = {"layout_dets": result, "page_info": page_info}
model_json.append(page_dict)
gc_start = time.time()
clean_memory()
gc_time = round(time.time() - gc_start, 2)
logger.info(f"gc time: {gc_time}")
doc_analyze_time = round(time.time() - doc_analyze_start, 2)
doc_analyze_speed = round( (end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
logger.info(f"doc analyze time: {round(time.time() - doc_analyze_start, 2)},"
......
import json
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
bbox_relative_pos, box_area, calculate_iou,
calculate_overlap_area_in_bbox1_area_ratio,
......@@ -9,6 +10,7 @@ from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.libs.local_math import float_gt
from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
from magic_pdf.libs.ocr_content_type import CategoryId, ContentType
from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
......@@ -24,7 +26,7 @@ class MagicModel:
need_remove_list = []
page_no = model_page_info['page_info']['page_no']
horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
model_page_info, self.__docs[page_no]
model_page_info, self.__docs.get_page(page_no)
)
layout_dets = model_page_info['layout_dets']
for layout_det in layout_dets:
......@@ -99,7 +101,7 @@ class MagicModel:
for need_remove in need_remove_list:
layout_dets.remove(need_remove)
def __init__(self, model_list: list, docs: fitz.Document):
def __init__(self, model_list: list, docs: Dataset):
self.__model_list = model_list
self.__docs = docs
"""为所有模型数据添加bbox信息(缩放,poly->bbox)"""
......@@ -119,15 +121,13 @@ class MagicModel:
if left or right:
l1 = bbox1[3] - bbox1[1]
l2 = bbox2[3] - bbox2[1]
minL, maxL = min(l1, l2), max(l1, l2)
if (maxL - minL) / minL > 0.5:
return float('inf')
if bottom or top:
else:
l1 = bbox1[2] - bbox1[0]
l2 = bbox2[2] - bbox2[0]
minL, maxL = min(l1, l2), max(l1, l2)
if (maxL - minL) / minL > 0.5:
if l2 > l1 and (l2 - l1) / l1 > 0.3:
return float('inf')
return bbox_distance(bbox1, bbox2)
def __fix_footnote(self):
......@@ -215,9 +215,8 @@ class MagicModel:
筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
再求出筛选出的 subjects 和 object 的最短距离
"""
def search_overlap_between_boxes(
subject_idx, object_idx
):
def search_overlap_between_boxes(subject_idx, object_idx):
idxes = [subject_idx, object_idx]
x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
......@@ -245,9 +244,9 @@ class MagicModel:
for other_object in other_objects:
ratio = max(
ratio,
get_overlap_area(
merged_bbox, other_object['bbox']
) * 1.0 / box_area(all_bboxes[object_idx]['bbox'])
get_overlap_area(merged_bbox, other_object['bbox'])
* 1.0
/ box_area(all_bboxes[object_idx]['bbox']),
)
if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO:
break
......@@ -365,12 +364,17 @@ class MagicModel:
if all_bboxes[j]['category_id'] == subject_category_id:
subject_idx, object_idx = j, i
if search_overlap_between_boxes(subject_idx, object_idx) >= MERGE_BOX_OVERLAP_AREA_RATIO:
if (
search_overlap_between_boxes(subject_idx, object_idx)
>= MERGE_BOX_OVERLAP_AREA_RATIO
):
dis[i][j] = float('inf')
dis[j][i] = dis[i][j]
continue
dis[i][j] = self._bbox_distance(all_bboxes[i]['bbox'], all_bboxes[j]['bbox'])
dis[i][j] = self._bbox_distance(
all_bboxes[subject_idx]['bbox'], all_bboxes[object_idx]['bbox']
)
dis[j][i] = dis[i][j]
used = set()
......@@ -461,7 +465,7 @@ class MagicModel:
if is_nearest:
nx0, ny0, nx1, ny1 = expand_bbbox(list(seen) + [k])
n_dis = self._bbox_distance(
n_dis = bbox_distance(
all_bboxes[i]['bbox'], [nx0, ny0, nx1, ny1]
)
if float_gt(dis[i][j], n_dis):
......@@ -557,7 +561,7 @@ class MagicModel:
# 计算已经配对的 distance 距离
for i in subject_object_relation_map.keys():
for j in subject_object_relation_map[i]:
total_subject_object_dis += self._bbox_distance(
total_subject_object_dis += bbox_distance(
all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
)
......@@ -586,6 +590,245 @@ class MagicModel:
with_caption_subject.add(j)
return ret, total_subject_object_dis
def __tie_up_category_by_distance_v2(
self, page_no, subject_category_id, object_category_id
):
AXIS_MULPLICITY = 0.5
subjects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == subject_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
objects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == object_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
M = len(objects)
subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
sub_obj_map_h = {i: [] for i in range(len(subjects))}
dis_by_directions = {
'top': [[-1, float('inf')]] * M,
'bottom': [[-1, float('inf')]] * M,
'left': [[-1, float('inf')]] * M,
'right': [[-1, float('inf')]] * M,
}
for i, obj in enumerate(objects):
l_x_axis, l_y_axis = (
obj['bbox'][2] - obj['bbox'][0],
obj['bbox'][3] - obj['bbox'][1],
)
axis_unit = min(l_x_axis, l_y_axis)
for j, sub in enumerate(subjects):
bbox1, bbox2, _ = _remove_overlap_between_bbox(
objects[i]['bbox'], subjects[j]['bbox']
)
left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
flags = [left, right, bottom, top]
if sum([1 if v else 0 for v in flags]) > 1:
continue
if left:
if dis_by_directions['left'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['left'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if right:
if dis_by_directions['right'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['right'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if bottom:
if dis_by_directions['bottom'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['bottom'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if top:
if dis_by_directions['top'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['top'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if dis_by_directions['left'][i][1] != float('inf') or dis_by_directions[
'right'
][i][1] != float('inf'):
if dis_by_directions['left'][i][1] != float(
'inf'
) and dis_by_directions['right'][i][1] != float('inf'):
if AXIS_MULPLICITY * axis_unit >= abs(
dis_by_directions['left'][i][1]
- dis_by_directions['right'][i][1]
):
left_sub_bbox = subjects[dis_by_directions['left'][i][0]][
'bbox'
]
right_sub_bbox = subjects[dis_by_directions['right'][i][0]][
'bbox'
]
left_sub_bbox_y_axis = left_sub_bbox[3] - left_sub_bbox[1]
right_sub_bbox_y_axis = right_sub_bbox[3] - right_sub_bbox[1]
if (
abs(left_sub_bbox_y_axis - l_y_axis)
+ dis_by_directions['left'][i][0]
> abs(right_sub_bbox_y_axis - l_y_axis)
+ dis_by_directions['right'][i][0]
):
left_or_right = dis_by_directions['right'][i]
else:
left_or_right = dis_by_directions['left'][i]
else:
left_or_right = dis_by_directions['left'][i]
if left_or_right[1] > dis_by_directions['right'][i][1]:
left_or_right = dis_by_directions['right'][i]
else:
left_or_right = dis_by_directions['left'][i]
if left_or_right[1] == float('inf'):
left_or_right = dis_by_directions['right'][i]
else:
left_or_right = [-1, float('inf')]
if dis_by_directions['top'][i][1] != float('inf') or dis_by_directions[
'bottom'
][i][1] != float('inf'):
if dis_by_directions['top'][i][1] != float('inf') and dis_by_directions[
'bottom'
][i][1] != float('inf'):
if AXIS_MULPLICITY * axis_unit >= abs(
dis_by_directions['top'][i][1]
- dis_by_directions['bottom'][i][1]
):
top_bottom = subjects[dis_by_directions['bottom'][i][0]]['bbox']
bottom_top = subjects[dis_by_directions['top'][i][0]]['bbox']
top_bottom_x_axis = top_bottom[2] - top_bottom[0]
bottom_top_x_axis = bottom_top[2] - bottom_top[0]
if abs(top_bottom_x_axis - l_x_axis) + dis_by_directions['bottom'][i][1] > abs(
bottom_top_x_axis - l_x_axis
) + dis_by_directions['top'][i][1]:
top_or_bottom = dis_by_directions['top'][i]
else:
top_or_bottom = dis_by_directions['bottom'][i]
else:
top_or_bottom = dis_by_directions['top'][i]
if top_or_bottom[1] > dis_by_directions['bottom'][i][1]:
top_or_bottom = dis_by_directions['bottom'][i]
else:
top_or_bottom = dis_by_directions['top'][i]
if top_or_bottom[1] == float('inf'):
top_or_bottom = dis_by_directions['bottom'][i]
else:
top_or_bottom = [-1, float('inf')]
if left_or_right[1] != float('inf') or top_or_bottom[1] != float('inf'):
if left_or_right[1] != float('inf') and top_or_bottom[1] != float(
'inf'
):
if AXIS_MULPLICITY * axis_unit >= abs(
left_or_right[1] - top_or_bottom[1]
):
y_axis_bbox = subjects[left_or_right[0]]['bbox']
x_axis_bbox = subjects[top_or_bottom[0]]['bbox']
if (
abs((x_axis_bbox[2] - x_axis_bbox[0]) - l_x_axis) / l_x_axis
> abs((y_axis_bbox[3] - y_axis_bbox[1]) - l_y_axis)
/ l_y_axis
):
sub_obj_map_h[left_or_right[0]].append(i)
else:
sub_obj_map_h[top_or_bottom[0]].append(i)
else:
if left_or_right[1] > top_or_bottom[1]:
sub_obj_map_h[top_or_bottom[0]].append(i)
else:
sub_obj_map_h[left_or_right[0]].append(i)
else:
if left_or_right[1] != float('inf'):
sub_obj_map_h[left_or_right[0]].append(i)
else:
sub_obj_map_h[top_or_bottom[0]].append(i)
ret = []
for i in sub_obj_map_h.keys():
ret.append(
{
'sub_bbox': {
'bbox': subjects[i]['bbox'],
'score': subjects[i]['score'],
},
'obj_bboxes': [
{'score': objects[j]['score'], 'bbox': objects[j]['bbox']}
for j in sub_obj_map_h[i]
],
'sub_idx': i,
}
)
return ret
def get_imgs_v2(self, page_no: int):
with_captions = self.__tie_up_category_by_distance_v2(page_no, 3, 4)
with_footnotes = self.__tie_up_category_by_distance_v2(
page_no, 3, CategoryId.ImageFootnote
)
ret = []
for v in with_captions:
record = {
'image_body': v['sub_bbox'],
'image_caption_list': v['obj_bboxes'],
}
filter_idx = v['sub_idx']
d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
record['image_footnote_list'] = d['obj_bboxes']
ret.append(record)
return ret
def get_tables_v2(self, page_no: int) -> list:
with_captions = self.__tie_up_category_by_distance_v2(page_no, 5, 6)
with_footnotes = self.__tie_up_category_by_distance_v2(page_no, 5, 7)
ret = []
for v in with_captions:
record = {
'table_body': v['sub_bbox'],
'table_caption_list': v['obj_bboxes'],
}
filter_idx = v['sub_idx']
d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
record['table_footnote_list'] = d['obj_bboxes']
ret.append(record)
return ret
def get_imgs(self, page_no: int):
with_captions, _ = self.__tie_up_category_by_distance(page_no, 3, 4)
with_footnotes, _ = self.__tie_up_category_by_distance(
......@@ -719,10 +962,10 @@ class MagicModel:
def get_page_size(self, page_no: int): # 获取页面宽高
# 获取当前页的page对象
page = self.__docs[page_no]
page = self.__docs.get_page(page_no).get_page_info()
# 获取当前页的宽高
page_w = page.rect.width
page_h = page.rect.height
page_w = page.w
page_h = page.h
return page_w, page_h
def __get_blocks_by_type(
......
......@@ -26,6 +26,7 @@ try:
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor
from doclayout_yolo import YOLOv10
except ImportError as e:
logger.exception(e)
......@@ -42,7 +43,7 @@ from magic_pdf.model.ppTableModel import ppTableModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
if table_model_type == STRUCT_EQTABLE:
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
else:
config = {
......@@ -83,11 +84,16 @@ def layout_model_init(weight, config_file, device):
return model
def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None):
def doclayout_yolo_model_init(weight):
model = YOLOv10(weight)
return model
def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
if lang is not None:
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang)
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
else:
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
return model
......@@ -120,19 +126,27 @@ class AtomModelSingleton:
return cls._instance
def get_atom_model(self, atom_model_name: str, **kwargs):
if atom_model_name not in self._models:
self._models[atom_model_name] = atom_model_init(model_name=atom_model_name, **kwargs)
return self._models[atom_model_name]
lang = kwargs.get("lang", None)
layout_model_name = kwargs.get("layout_model_name", None)
key = (atom_model_name, layout_model_name, lang)
if key not in self._models:
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
return self._models[key]
def atom_model_init(model_name: str, **kwargs):
if model_name == AtomicModel.Layout:
if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
atom_model = layout_model_init(
kwargs.get("layout_weights"),
kwargs.get("layout_config_file"),
kwargs.get("device")
)
elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
atom_model = doclayout_yolo_model_init(
kwargs.get("doclayout_yolo_weights"),
)
elif model_name == AtomicModel.MFD:
atom_model = mfd_model_init(
kwargs.get("mfd_weights")
......@@ -151,7 +165,7 @@ def atom_model_init(model_name: str, **kwargs):
)
elif model_name == AtomicModel.Table:
atom_model = table_model_init(
kwargs.get("table_model_type"),
kwargs.get("table_model_name"),
kwargs.get("table_model_path"),
kwargs.get("table_max_time"),
kwargs.get("device")
......@@ -199,23 +213,35 @@ class CustomPEKModel:
with open(config_path, "r", encoding='utf-8') as f:
self.configs = yaml.load(f, Loader=yaml.FullLoader)
# 初始化解析配置
self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
# layout config
self.layout_config = kwargs.get("layout_config")
self.layout_model_name = self.layout_config.get("model", MODEL_NAME.DocLayout_YOLO)
# formula config
self.formula_config = kwargs.get("formula_config")
self.mfd_model_name = self.formula_config.get("mfd_model", MODEL_NAME.YOLO_V8_MFD)
self.mfr_model_name = self.formula_config.get("mfr_model", MODEL_NAME.UniMerNet_v2_Small)
self.apply_formula = self.formula_config.get("enable", True)
# table config
self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
self.apply_table = self.table_config.get("is_table_recog_enable", False)
self.table_config = kwargs.get("table_config")
self.apply_table = self.table_config.get("enable", False)
self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
self.table_model_type = self.table_config.get("model", TABLE_MASTER)
self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER)
# ocr config
self.apply_ocr = ocr
self.lang = kwargs.get("lang", None)
logger.info(
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}, lang: {}".format(
self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table, self.lang
"DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
"apply_table: {}, table_model: {}, lang: {}".format(
self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang
)
)
assert self.apply_layout, "DocAnalysis must contain layout model."
# 初始化解析方案
self.device = kwargs.get("device", self.configs["config"]["device"])
self.device = kwargs.get("device", "cpu")
logger.info("using device: {}".format(self.device))
models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
logger.info("using models_dir: {}".format(models_dir))
......@@ -224,17 +250,16 @@ class CustomPEKModel:
# 初始化公式识别
if self.apply_formula:
# 初始化公式检测模型
# self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
self.mfd_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFD,
mfd_weights=str(os.path.join(models_dir, self.configs["weights"]["mfd"]))
mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
)
# 初始化公式解析模型
mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
# self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
# self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFR,
mfr_weight_dir=mfr_weight_dir,
......@@ -243,17 +268,20 @@ class CustomPEKModel:
)
# 初始化layout模型
# self.layout_model = Layoutlmv3_Predictor(
# str(os.path.join(models_dir, self.configs['weights']['layout'])),
# str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
# device=self.device
# )
if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
layout_weights=str(os.path.join(models_dir, self.configs['weights']['layout'])),
layout_model_name=MODEL_NAME.LAYOUTLMv3,
layout_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
device=self.device
)
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
layout_model_name=MODEL_NAME.DocLayout_YOLO,
doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
)
# 初始化ocr
if self.apply_ocr:
......@@ -266,12 +294,10 @@ class CustomPEKModel:
)
# init table model
if self.apply_table:
table_model_dir = self.configs["weights"][self.table_model_type]
# self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
# max_time=self.table_max_time, _device_=self.device)
table_model_dir = self.configs["weights"][self.table_model_name]
self.table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Table,
table_model_type=self.table_model_type,
table_model_name=self.table_model_name,
table_model_path=str(os.path.join(models_dir, table_model_dir)),
table_max_time=self.table_max_time,
device=self.device
......@@ -288,7 +314,21 @@ class CustomPEKModel:
# layout检测
layout_start = time.time()
if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
# layoutlmv3
layout_res = self.layout_model(image, ignore_catids=[])
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
# doclayout_yolo
layout_res = []
doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
new_item = {
'category_id': int(cla.item()),
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
'score': round(float(conf.item()), 3),
}
layout_res.append(new_item)
layout_cost = round(time.time() - layout_start, 2)
logger.info(f"layout detection time: {layout_cost}")
......@@ -297,7 +337,7 @@ class CustomPEKModel:
if self.apply_formula:
# 公式检测
mfd_start = time.time()
mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
......@@ -309,7 +349,6 @@ class CustomPEKModel:
}
layout_res.append(new_item)
latex_filling_list.append(new_item)
# bbox_img = get_croped_image(pil_img, [xmin, ymin, xmax, ymax])
bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
mf_image_list.append(bbox_img)
......@@ -346,7 +385,7 @@ class CustomPEKModel:
if torch.cuda.is_available():
properties = torch.cuda.get_device_properties(self.device)
total_memory = properties.total_memory / (1024 ** 3) # 将字节转换为 GB
if total_memory <= 8:
if total_memory <= 10:
gc_start = time.time()
clean_memory()
gc_time = round(time.time() - gc_start, 2)
......@@ -411,7 +450,7 @@ class CustomPEKModel:
# logger.info("------------------table recognition processing begins-----------------")
latex_code = None
html_code = None
if self.table_model_type == STRUCT_EQTABLE:
if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
with torch.no_grad():
latex_code = self.table_model.image2latex(new_image)[0]
else:
......
......@@ -52,11 +52,11 @@ class ppTableModel(object):
rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR)
rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT)
device = kwargs.get("device", "cpu")
use_gpu = True if device == "cuda" else False
use_gpu = True if device.startswith("cuda") else False
config = {
"use_gpu": use_gpu,
"table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN),
"table_algorithm": TABLE_MASTER,
"table_algorithm": "TableMaster",
"table_model_dir": table_model_dir,
"table_char_dict_path": table_char_dict_path,
"det_model_dir": det_model_dir,
......
import copy
from loguru import logger
from magic_pdf.libs.Constants import LINES_DELETED, CROSS_PAGE
from magic_pdf.libs.ocr_content_type import BlockType, ContentType
LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';')
LIST_END_FLAG = ('.', '。', ';', ';')
class ListLineTag:
IS_LIST_START_LINE = "is_list_start_line"
IS_LIST_END_LINE = "is_list_end_line"
def __process_blocks(blocks):
# 对所有block预处理
# 1.通过title和interline_equation将block分组
# 2.bbox边界根据line信息重置
result = []
current_group = []
for i in range(len(blocks)):
current_block = blocks[i]
# 如果当前块是 text 类型
if current_block['type'] == 'text':
current_block["bbox_fs"] = copy.deepcopy(current_block["bbox"])
if 'lines' in current_block and len(current_block["lines"]) > 0:
current_block['bbox_fs'] = [min([line['bbox'][0] for line in current_block['lines']]),
min([line['bbox'][1] for line in current_block['lines']]),
max([line['bbox'][2] for line in current_block['lines']]),
max([line['bbox'][3] for line in current_block['lines']])]
current_group.append(current_block)
# 检查下一个块是否存在
if i + 1 < len(blocks):
next_block = blocks[i + 1]
# 如果下一个块不是 text 类型且是 title 或 interline_equation 类型
if next_block['type'] in ['title', 'interline_equation']:
result.append(current_group)
current_group = []
# 处理最后一个 group
if current_group:
result.append(current_group)
return result
def __is_list_or_index_block(block):
# 一个block如果是list block 应该同时满足以下特征
# 1.block内有多个line 2.block 内有多个line左侧顶格写 3.block内有多个line 右侧不顶格(狗牙状)
# 1.block内有多个line 2.block 内有多个line左侧顶格写 3.多个line以endflag结尾
# 1.block内有多个line 2.block 内有多个line左侧顶格写 3.block内有多个line 左侧不顶格
# index block 是一种特殊的list block
# 一个block如果是index block 应该同时满足以下特征
# 1.block内有多个line 2.block 内有多个line两侧均顶格写 3.line的开头或者结尾均为数字
if len(block['lines']) >= 2:
first_line = block['lines'][0]
line_height = first_line['bbox'][3] - first_line['bbox'][1]
block_weight = block['bbox_fs'][2] - block['bbox_fs'][0]
left_close_num = 0
left_not_close_num = 0
right_not_close_num = 0
right_close_num = 0
lines_text_list = []
multiple_para_flag = False
last_line = block['lines'][-1]
# 如果首行左边不顶格而右边顶格,末行左边顶格而右边不顶格 (第一行可能可以右边不顶格)
if (first_line['bbox'][0] - block['bbox_fs'][0] > line_height / 2 and
# block['bbox_fs'][2] - first_line['bbox'][2] < line_height and
abs(last_line['bbox'][0] - block['bbox_fs'][0]) < line_height / 2 and
block['bbox_fs'][2] - last_line['bbox'][2] > line_height
):
multiple_para_flag = True
for line in block['lines']:
line_text = ""
for span in line['spans']:
span_type = span['type']
if span_type == ContentType.Text:
line_text += span['content'].strip()
lines_text_list.append(line_text)
# 计算line左侧顶格数量是否大于2,是否顶格用abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height/2 来判断
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
left_close_num += 1
elif line['bbox'][0] - block['bbox_fs'][0] > line_height:
# logger.info(f"{line_text}, {block['bbox_fs']}, {line['bbox']}")
left_not_close_num += 1
# 计算右侧是否顶格
if abs(block['bbox_fs'][2] - line['bbox'][2]) < line_height:
right_close_num += 1
else:
# 右侧不顶格情况下是否有一段距离,拍脑袋用0.3block宽度做阈值
closed_area = 0.3 * block_weight
# closed_area = 5 * line_height
if block['bbox_fs'][2] - line['bbox'][2] > closed_area:
right_not_close_num += 1
# 判断lines_text_list中的元素是否有超过80%都以LIST_END_FLAG结尾
line_end_flag = False
# 判断lines_text_list中的元素是否有超过80%都以数字开头或都以数字结尾
line_num_flag = False
num_start_count = 0
num_end_count = 0
flag_end_count = 0
if len(lines_text_list) > 0:
for line_text in lines_text_list:
if len(line_text) > 0:
if line_text[-1] in LIST_END_FLAG:
flag_end_count += 1
if line_text[0].isdigit():
num_start_count += 1
if line_text[-1].isdigit():
num_end_count += 1
if flag_end_count / len(lines_text_list) >= 0.8:
line_end_flag = True
if num_start_count / len(lines_text_list) >= 0.8 or num_end_count / len(lines_text_list) >= 0.8:
line_num_flag = True
# 有的目录右侧不贴边, 目前认为左边或者右边有一边全贴边,且符合数字规则极为index
if ((left_close_num/len(block['lines']) >= 0.8 or right_close_num/len(block['lines']) >= 0.8)
and line_num_flag
):
for line in block['lines']:
line[ListLineTag.IS_LIST_START_LINE] = True
return BlockType.Index
elif left_close_num >= 2 and (
right_not_close_num >= 2 or line_end_flag or left_not_close_num >= 2) and not multiple_para_flag:
# 处理一种特殊的没有缩进的list,所有行都贴左边,通过右边的空隙判断是否是item尾
if left_close_num / len(block['lines']) > 0.9:
# 这种是每个item只有一行,且左边都贴边的短item list
if flag_end_count == 0 and right_close_num / len(block['lines']) < 0.5:
for line in block['lines']:
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
line[ListLineTag.IS_LIST_START_LINE] = True
# 这种是大部分line item 都有结束标识符的情况,按结束标识符区分不同item
elif line_end_flag:
for i, line in enumerate(block['lines']):
if lines_text_list[i][-1] in LIST_END_FLAG:
line[ListLineTag.IS_LIST_END_LINE] = True
if i + 1 < len(block['lines']):
block['lines'][i+1][ListLineTag.IS_LIST_START_LINE] = True
# line item基本没有结束标识符,而且也没有缩进,按右侧空隙判断哪些是item end
else:
line_start_flag = False
for i, line in enumerate(block['lines']):
if line_start_flag:
line[ListLineTag.IS_LIST_START_LINE] = True
line_start_flag = False
elif abs(block['bbox_fs'][2] - line['bbox'][2]) > line_height:
line[ListLineTag.IS_LIST_END_LINE] = True
line_start_flag = True
# 一种有缩进的特殊有序list,start line 左侧不贴边且以数字开头,end line 以 IS_LIST_END_LINE 结尾且数量和start line 一致
elif num_start_count >= 2 and num_start_count == flag_end_count: # 简单一点先不考虑左侧不贴边的情况
for i, line in enumerate(block['lines']):
if lines_text_list[i][0].isdigit():
line[ListLineTag.IS_LIST_START_LINE] = True
if lines_text_list[i][-1] in LIST_END_FLAG:
line[ListLineTag.IS_LIST_END_LINE] = True
else:
# 正常有缩进的list处理
for line in block['lines']:
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
line[ListLineTag.IS_LIST_START_LINE] = True
if abs(block['bbox_fs'][2] - line['bbox'][2]) > line_height:
line[ListLineTag.IS_LIST_END_LINE] = True
return BlockType.List
else:
return BlockType.Text
else:
return BlockType.Text
def __merge_2_text_blocks(block1, block2):
if len(block1['lines']) > 0:
first_line = block1['lines'][0]
line_height = first_line['bbox'][3] - first_line['bbox'][1]
block1_weight = block1['bbox'][2] - block1['bbox'][0]
block2_weight = block2['bbox'][2] - block2['bbox'][0]
min_block_weight = min(block1_weight, block2_weight)
if abs(block1['bbox_fs'][0] - first_line['bbox'][0]) < line_height / 2:
last_line = block2['lines'][-1]
if len(last_line['spans']) > 0:
last_span = last_line['spans'][-1]
line_height = last_line['bbox'][3] - last_line['bbox'][1]
if (abs(block2['bbox_fs'][2] - last_line['bbox'][2]) < line_height and
not last_span['content'].endswith(LINE_STOP_FLAG) and
# 两个block宽度差距超过2倍也不合并
abs(block1_weight - block2_weight) < min_block_weight
):
if block1['page_num'] != block2['page_num']:
for line in block1['lines']:
for span in line['spans']:
span[CROSS_PAGE] = True
block2['lines'].extend(block1['lines'])
block1['lines'] = []
block1[LINES_DELETED] = True
return block1, block2
def __merge_2_list_blocks(block1, block2):
if block1['page_num'] != block2['page_num']:
for line in block1['lines']:
for span in line['spans']:
span[CROSS_PAGE] = True
block2['lines'].extend(block1['lines'])
block1['lines'] = []
block1[LINES_DELETED] = True
return block1, block2
def __is_list_group(text_blocks_group):
# list group的特征是一个group内的所有block都满足以下条件
# 1.每个block都不超过3行 2. 每个block 的左边界都比较接近(逻辑简单点先不加这个规则)
for block in text_blocks_group:
if len(block['lines']) > 3:
return False
return True
def __para_merge_page(blocks):
page_text_blocks_groups = __process_blocks(blocks)
for text_blocks_group in page_text_blocks_groups:
if len(text_blocks_group) > 0:
# 需要先在合并前对所有block判断是否为list or index block
for block in text_blocks_group:
block_type = __is_list_or_index_block(block)
block['type'] = block_type
# logger.info(f"{block['type']}:{block}")
if len(text_blocks_group) > 1:
# 在合并前判断这个group 是否是一个 list group
is_list_group = __is_list_group(text_blocks_group)
# 倒序遍历
for i in range(len(text_blocks_group) - 1, -1, -1):
current_block = text_blocks_group[i]
# 检查是否有前一个块
if i - 1 >= 0:
prev_block = text_blocks_group[i - 1]
if current_block['type'] == 'text' and prev_block['type'] == 'text' and not is_list_group:
__merge_2_text_blocks(current_block, prev_block)
elif (
(current_block['type'] == BlockType.List and prev_block['type'] == BlockType.List) or
(current_block['type'] == BlockType.Index and prev_block['type'] == BlockType.Index)
):
__merge_2_list_blocks(current_block, prev_block)
else:
continue
def para_split(pdf_info_dict, debug_mode=False):
all_blocks = []
for page_num, page in pdf_info_dict.items():
blocks = copy.deepcopy(page['preproc_blocks'])
for block in blocks:
block['page_num'] = page_num
all_blocks.extend(blocks)
__para_merge_page(all_blocks)
for page_num, page in pdf_info_dict.items():
page['para_blocks'] = []
for block in all_blocks:
if block['page_num'] == page_num:
page['para_blocks'].append(block)
if __name__ == '__main__':
input_blocks = []
# 调用函数
groups = __process_blocks(input_blocks)
for group_index, group in enumerate(groups):
print(f"Group {group_index}: {group}")
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
......@@ -8,10 +10,11 @@ def parse_pdf_by_ocr(pdf_bytes,
end_page_id=None,
debug_mode=False,
):
return pdf_parse_union(pdf_bytes,
dataset = PymuDocDataset(pdf_bytes)
return pdf_parse_union(dataset,
model_list,
imageWriter,
"ocr",
SupportedPdfParseMethod.OCR,
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=debug_mode,
......
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
......@@ -9,10 +11,11 @@ def parse_pdf_by_txt(
end_page_id=None,
debug_mode=False,
):
return pdf_parse_union(pdf_bytes,
dataset = PymuDocDataset(pdf_bytes)
return pdf_parse_union(dataset,
model_list,
imageWriter,
"txt",
SupportedPdfParseMethod.TXT,
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=debug_mode,
......
import copy
import os
import statistics
import time
from loguru import logger
from typing import List
import torch
from loguru import logger
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import Dataset, PageableData
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.commons import fitz, get_delta_time
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.drop_reason import DropReason
from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.local_math import float_equal
from magic_pdf.libs.ocr_content_type import ContentType
from magic_pdf.libs.ocr_content_type import ContentType, BlockType
from magic_pdf.model.magic_model import MagicModel
from magic_pdf.para.para_split_v3 import para_split
from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
from magic_pdf.pre_proc.construct_page_dict import \
ocr_construct_page_component_v2
from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
from magic_pdf.pre_proc.equations_replace import remove_chars_in_text_blocks, replace_equations_in_textblock, \
combine_chars_to_pymudict
from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split_v2
from magic_pdf.pre_proc.ocr_dict_merge import fill_spans_in_blocks, fix_block_spans, fix_discarded_block
from magic_pdf.pre_proc.ocr_span_list_modify import remove_overlaps_min_spans, get_qa_need_list_v2, \
remove_overlaps_low_confidence_spans
from magic_pdf.pre_proc.resolve_bbox_conflict import check_useful_block_horizontal_overlap
from magic_pdf.pre_proc.equations_replace import (
combine_chars_to_pymudict, remove_chars_in_text_blocks,
replace_equations_in_textblock)
from magic_pdf.pre_proc.ocr_detect_all_bboxes import \
ocr_prepare_bboxes_for_layout_split_v2
from magic_pdf.pre_proc.ocr_dict_merge import (fill_spans_in_blocks,
fix_block_spans,
fix_discarded_block, fix_block_spans_v2)
from magic_pdf.pre_proc.ocr_span_list_modify import (
get_qa_need_list_v2, remove_overlaps_low_confidence_spans,
remove_overlaps_min_spans)
from magic_pdf.pre_proc.resolve_bbox_conflict import \
check_useful_block_horizontal_overlap
def remove_horizontal_overlap_block_which_smaller(all_bboxes):
useful_blocks = []
for bbox in all_bboxes:
useful_blocks.append({
"bbox": bbox[:4]
})
is_useful_block_horz_overlap, smaller_bbox, bigger_bbox = check_useful_block_horizontal_overlap(useful_blocks)
useful_blocks.append({'bbox': bbox[:4]})
is_useful_block_horz_overlap, smaller_bbox, bigger_bbox = (
check_useful_block_horizontal_overlap(useful_blocks)
)
if is_useful_block_horz_overlap:
logger.warning(
f"skip this page, reason: {DropReason.USEFUL_BLOCK_HOR_OVERLAP}, smaller bbox is {smaller_bbox}, bigger bbox is {bigger_bbox}")
f'skip this page, reason: {DropReason.USEFUL_BLOCK_HOR_OVERLAP}, smaller bbox is {smaller_bbox}, bigger bbox is {bigger_bbox}'
) # noqa: E501
for bbox in all_bboxes.copy():
if smaller_bbox == bbox[:4]:
all_bboxes.remove(bbox)
......@@ -44,27 +56,27 @@ def remove_horizontal_overlap_block_which_smaller(all_bboxes):
return is_useful_block_horz_overlap, all_bboxes
def __replace_STX_ETX(text_str:str):
""" Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks.
Drawback: This issue is only observed in English text; it has not been found in Chinese text so far.
def __replace_STX_ETX(text_str: str):
"""Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks.
Drawback: This issue is only observed in English text; it has not been found in Chinese text so far.
Args:
text_str (str): raw text
Returns:
_type_: replaced text
"""
""" # noqa: E501
if text_str:
s = text_str.replace('\u0002', "'")
s = s.replace("\u0003", "'")
s = s.replace('\u0003', "'")
return s
return text_str
def txt_spans_extract(pdf_page, inline_equations, interline_equations):
text_raw_blocks = pdf_page.get_text("dict", flags=fitz.TEXTFLAGS_TEXT)["blocks"]
char_level_text_blocks = pdf_page.get_text("rawdict", flags=fitz.TEXTFLAGS_TEXT)[
"blocks"
text_raw_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
char_level_text_blocks = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)[
'blocks'
]
text_blocks = combine_chars_to_pymudict(text_raw_blocks, char_level_text_blocks)
text_blocks = replace_equations_in_textblock(
......@@ -74,50 +86,63 @@ def txt_spans_extract(pdf_page, inline_equations, interline_equations):
text_blocks = remove_chars_in_text_blocks(text_blocks)
spans = []
for v in text_blocks:
for line in v["lines"]:
for span in line["spans"]:
bbox = span["bbox"]
for line in v['lines']:
for span in line['spans']:
bbox = span['bbox']
if float_equal(bbox[0], bbox[2]) or float_equal(bbox[1], bbox[3]):
continue
if span.get('type') not in (ContentType.InlineEquation, ContentType.InterlineEquation):
if span.get('type') not in (
ContentType.InlineEquation,
ContentType.InterlineEquation,
):
spans.append(
{
"bbox": list(span["bbox"]),
"content": __replace_STX_ETX(span["text"]),
"type": ContentType.Text,
"score": 1.0,
'bbox': list(span['bbox']),
'content': __replace_STX_ETX(span['text']),
'type': ContentType.Text,
'score': 1.0,
}
)
return spans
def replace_text_span(pymu_spans, ocr_spans):
return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans
return list(filter(lambda x: x['type'] != ContentType.Text, ocr_spans)) + pymu_spans
def model_init(model_name: str, local_path=None):
def model_init(model_name: str):
from transformers import LayoutLMv3ForTokenClassification
if torch.cuda.is_available():
device = torch.device("cuda")
device = torch.device('cuda')
if torch.cuda.is_bf16_supported():
supports_bfloat16 = True
else:
supports_bfloat16 = False
else:
device = torch.device("cpu")
device = torch.device('cpu')
supports_bfloat16 = False
if model_name == "layoutreader":
if local_path:
model = LayoutLMv3ForTokenClassification.from_pretrained(local_path)
if model_name == 'layoutreader':
# 检测modelscope的缓存目录是否存在
layoutreader_model_dir = get_local_layoutreader_model_dir()
if os.path.exists(layoutreader_model_dir):
model = LayoutLMv3ForTokenClassification.from_pretrained(
layoutreader_model_dir
)
else:
model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
logger.warning(
'local layoutreader model not exists, use online model from huggingface'
)
model = LayoutLMv3ForTokenClassification.from_pretrained(
'hantian/layoutreader'
)
# 检查设备是否支持 bfloat16
if supports_bfloat16:
model.bfloat16()
model.to(device).eval()
else:
logger.error("model name not allow")
logger.error('model name not allow')
exit(1)
return model
......@@ -131,17 +156,16 @@ class ModelSingleton:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(self, model_name: str, local_path=None):
def get_model(self, model_name: str):
if model_name not in self._models:
if local_path:
self._models[model_name] = model_init(model_name=model_name, local_path=local_path)
else:
self._models[model_name] = model_init(model_name=model_name)
return self._models[model_name]
def do_predict(boxes: List[List[int]], model) -> List[int]:
from magic_pdf.model.v3.helpers import prepare_inputs, boxes2inputs, parse_logits
from magic_pdf.model.v3.helpers import (boxes2inputs, parse_logits,
prepare_inputs)
inputs = boxes2inputs(boxes)
inputs = prepare_inputs(inputs, model)
logits = model(**inputs).logits.cpu().squeeze(0)
......@@ -150,19 +174,6 @@ def do_predict(boxes: List[List[int]], model) -> List[int]:
def cal_block_index(fix_blocks, sorted_bboxes):
for block in fix_blocks:
# if block['type'] in ['text', 'title', 'interline_equation']:
# line_index_list = []
# if len(block['lines']) == 0:
# block['index'] = sorted_bboxes.index(block['bbox'])
# else:
# for line in block['lines']:
# line['index'] = sorted_bboxes.index(line['bbox'])
# line_index_list.append(line['index'])
# median_value = statistics.median(line_index_list)
# block['index'] = median_value
#
# elif block['type'] in ['table', 'image']:
# block['index'] = sorted_bboxes.index(block['bbox'])
line_index_list = []
if len(block['lines']) == 0:
......@@ -174,9 +185,11 @@ def cal_block_index(fix_blocks, sorted_bboxes):
median_value = statistics.median(line_index_list)
block['index'] = median_value
# 删除图表block中的虚拟line信息
if block['type'] in ['table', 'image']:
del block['lines']
# 删除图表body block中的虚拟line信息, 并用real_lines信息回填
if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
block['virtual_lines'] = copy.deepcopy(block['lines'])
block['lines'] = copy.deepcopy(block['real_lines'])
del block['real_lines']
return fix_blocks
......@@ -189,19 +202,20 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
block_weight = x1 - x0
# 如果block高度小于n行正文,则直接返回block的bbox
if line_height*3 < block_height:
if block_height > page_h*0.25 and page_w*0.5 > block_weight > page_w*0.25: # 可能是双列结构,可以切细点
lines = int(block_height/line_height)+1
if line_height * 3 < block_height:
if (
block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25
): # 可能是双列结构,可以切细点
lines = int(block_height / line_height) + 1
else:
# 如果block的宽度超过0.4页面宽度,则将block分成3行
if block_weight > page_w*0.4:
# 如果block的宽度超过0.4页面宽度,则将block分成3行(是一种复杂布局,图不能切的太细)
if block_weight > page_w * 0.4:
line_height = (y1 - y0) / 3
lines = 3
elif block_weight > page_w*0.25: # 否则将block分成两行
line_height = (y1 - y0) / 2
lines = 2
elif block_weight > page_w * 0.25: # (可能是三列结构,也切细点)
lines = int(block_height / line_height) + 1
else: # 判断长宽比
if block_height/block_weight > 1.2: # 细长的不分
if block_height / block_weight > 1.2: # 细长的不分
return [[x0, y0, x1, y1]]
else: # 不细长的还是分成两行
line_height = (y1 - y0) / 2
......@@ -225,7 +239,11 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
page_line_list = []
for block in fix_blocks:
if block['type'] in ['text', 'title', 'interline_equation']:
if block['type'] in [
BlockType.Text, BlockType.Title, BlockType.InterlineEquation,
BlockType.ImageCaption, BlockType.ImageFootnote,
BlockType.TableCaption, BlockType.TableFootnote
]:
if len(block['lines']) == 0:
bbox = block['bbox']
lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
......@@ -236,8 +254,9 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
for line in block['lines']:
bbox = line['bbox']
page_line_list.append(bbox)
elif block['type'] in ['table', 'image']:
elif block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
bbox = block['bbox']
block["real_lines"] = copy.deepcopy(block['lines'])
lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
block['lines'] = []
for line in lines:
......@@ -252,19 +271,23 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
for left, top, right, bottom in page_line_list:
if left < 0:
logger.warning(
f"left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
f'left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
left = 0
if right > page_w:
logger.warning(
f"right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
f'right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
right = page_w
if top < 0:
logger.warning(
f"top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
f'top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
top = 0
if bottom > page_h:
logger.warning(
f"bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
f'bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
bottom = page_h
left = round(left * x_scale)
......@@ -273,10 +296,10 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
bottom = round(bottom * y_scale)
assert (
1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0
), f"Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}"
), f'Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}' # noqa: E126, E121
boxes.append([left, top, right, bottom])
model_manager = ModelSingleton()
model = model_manager.get_model("layoutreader")
model = model_manager.get_model('layoutreader')
with torch.no_grad():
orders = do_predict(boxes, model)
sorted_bboxes = [page_line_list[i] for i in orders]
......@@ -287,159 +310,282 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
def get_line_height(blocks):
page_line_height_list = []
for block in blocks:
if block['type'] in ['text', 'title', 'interline_equation']:
if block['type'] in [
BlockType.Text, BlockType.Title,
BlockType.ImageCaption, BlockType.ImageFootnote,
BlockType.TableCaption, BlockType.TableFootnote
]:
for line in block['lines']:
bbox = line['bbox']
page_line_height_list.append(int(bbox[3]-bbox[1]))
page_line_height_list.append(int(bbox[3] - bbox[1]))
if len(page_line_height_list) > 0:
return statistics.median(page_line_height_list)
else:
return 10
def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode):
def process_groups(groups, body_key, caption_key, footnote_key):
body_blocks = []
caption_blocks = []
footnote_blocks = []
for i, group in enumerate(groups):
group[body_key]['group_id'] = i
body_blocks.append(group[body_key])
for caption_block in group[caption_key]:
caption_block['group_id'] = i
caption_blocks.append(caption_block)
for footnote_block in group[footnote_key]:
footnote_block['group_id'] = i
footnote_blocks.append(footnote_block)
return body_blocks, caption_blocks, footnote_blocks
def process_block_list(blocks, body_type, block_type):
indices = [block['index'] for block in blocks]
median_index = statistics.median(indices)
body_bbox = next((block['bbox'] for block in blocks if block.get('type') == body_type), [])
return {
'type': block_type,
'bbox': body_bbox,
'blocks': blocks,
'index': median_index,
}
def revert_group_blocks(blocks):
image_groups = {}
table_groups = {}
new_blocks = []
for block in blocks:
if block['type'] in [BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote]:
group_id = block['group_id']
if group_id not in image_groups:
image_groups[group_id] = []
image_groups[group_id].append(block)
elif block['type'] in [BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote]:
group_id = block['group_id']
if group_id not in table_groups:
table_groups[group_id] = []
table_groups[group_id].append(block)
else:
new_blocks.append(block)
for group_id, blocks in image_groups.items():
new_blocks.append(process_block_list(blocks, BlockType.ImageBody, BlockType.Image))
for group_id, blocks in table_groups.items():
new_blocks.append(process_block_list(blocks, BlockType.TableBody, BlockType.Table))
return new_blocks
def parse_page_core(
page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
):
need_drop = False
drop_reason = []
'''从magic_model对象中获取后面会用到的区块信息'''
img_blocks = magic_model.get_imgs(page_id)
table_blocks = magic_model.get_tables(page_id)
"""从magic_model对象中获取后面会用到的区块信息"""
# img_blocks = magic_model.get_imgs(page_id)
# table_blocks = magic_model.get_tables(page_id)
img_groups = magic_model.get_imgs_v2(page_id)
table_groups = magic_model.get_tables_v2(page_id)
img_body_blocks, img_caption_blocks, img_footnote_blocks = process_groups(
img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
)
table_body_blocks, table_caption_blocks, table_footnote_blocks = process_groups(
table_groups, 'table_body', 'table_caption_list', 'table_footnote_list'
)
discarded_blocks = magic_model.get_discarded(page_id)
text_blocks = magic_model.get_text_blocks(page_id)
title_blocks = magic_model.get_title_blocks(page_id)
inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations(page_id)
inline_equations, interline_equations, interline_equation_blocks = (
magic_model.get_equations(page_id)
)
page_w, page_h = magic_model.get_page_size(page_id)
spans = magic_model.get_all_spans(page_id)
'''根据parse_mode,构造spans'''
if parse_mode == "txt":
"""根据parse_mode,构造spans"""
if parse_mode == SupportedPdfParseMethod.TXT:
"""ocr 中文本类的 span 用 pymu spans 替换!"""
pymu_spans = txt_spans_extract(
pdf_docs[page_id], inline_equations, interline_equations
)
pymu_spans = txt_spans_extract(page_doc, inline_equations, interline_equations)
spans = replace_text_span(pymu_spans, spans)
elif parse_mode == "ocr":
elif parse_mode == SupportedPdfParseMethod.OCR:
pass
else:
raise Exception("parse_mode must be txt or ocr")
raise Exception('parse_mode must be txt or ocr')
'''删除重叠spans中置信度较低的那些'''
"""删除重叠spans中置信度较低的那些"""
spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
'''删除重叠spans中较小的那些'''
"""删除重叠spans中较小的那些"""
spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
'''对image和table截图'''
spans = ocr_cut_image_and_table(spans, pdf_docs[page_id], page_id, pdf_bytes_md5, imageWriter)
"""对image和table截图"""
spans = ocr_cut_image_and_table(
spans, page_doc, page_id, pdf_bytes_md5, imageWriter
)
'''将所有区块的bbox整理到一起'''
"""将所有区块的bbox整理到一起"""
# interline_equation_blocks参数不够准,后面切换到interline_equations上
interline_equation_blocks = []
if len(interline_equation_blocks) > 0:
all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
interline_equation_blocks, page_w, page_h)
img_body_blocks, img_caption_blocks, img_footnote_blocks,
table_body_blocks, table_caption_blocks, table_footnote_blocks,
discarded_blocks,
text_blocks,
title_blocks,
interline_equation_blocks,
page_w,
page_h,
)
else:
all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
interline_equations, page_w, page_h)
img_body_blocks, img_caption_blocks, img_footnote_blocks,
table_body_blocks, table_caption_blocks, table_footnote_blocks,
discarded_blocks,
text_blocks,
title_blocks,
interline_equations,
page_w,
page_h,
)
'''先处理不需要排版的discarded_blocks'''
discarded_block_with_spans, spans = fill_spans_in_blocks(all_discarded_blocks, spans, 0.4)
"""先处理不需要排版的discarded_blocks"""
discarded_block_with_spans, spans = fill_spans_in_blocks(
all_discarded_blocks, spans, 0.4
)
fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
'''如果当前页面没有bbox则跳过'''
"""如果当前页面没有bbox则跳过"""
if len(all_bboxes) == 0:
logger.warning(f"skip this page, not found useful bbox, page_id: {page_id}")
return ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [],
[], [], interline_equations, fix_discarded_blocks,
need_drop, drop_reason)
logger.warning(f'skip this page, not found useful bbox, page_id: {page_id}')
return ocr_construct_page_component_v2(
[],
[],
page_id,
page_w,
page_h,
[],
[],
[],
interline_equations,
fix_discarded_blocks,
need_drop,
drop_reason,
)
'''将span填入blocks中'''
block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.3)
"""将span填入blocks中"""
block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)
'''对block进行fix操作'''
fix_blocks = fix_block_spans(block_with_spans, img_blocks, table_blocks)
"""对block进行fix操作"""
fix_blocks = fix_block_spans_v2(block_with_spans)
'''获取所有line并计算正文line的高度'''
"""获取所有line并计算正文line的高度"""
line_height = get_line_height(fix_blocks)
'''获取所有line并对line排序'''
"""获取所有line并对line排序"""
sorted_bboxes = sort_lines_by_model(fix_blocks, page_w, page_h, line_height)
'''根据line的中位数算block的序列关系'''
"""根据line的中位数算block的序列关系"""
fix_blocks = cal_block_index(fix_blocks, sorted_bboxes)
'''重排block'''
"""将image和table的block还原回group形式参与后续流程"""
fix_blocks = revert_group_blocks(fix_blocks)
"""重排block"""
sorted_blocks = sorted(fix_blocks, key=lambda b: b['index'])
'''获取QA需要外置的list'''
"""获取QA需要外置的list"""
images, tables, interline_equations = get_qa_need_list_v2(sorted_blocks)
'''构造pdf_info_dict'''
page_info = ocr_construct_page_component_v2(sorted_blocks, [], page_id, page_w, page_h, [],
images, tables, interline_equations, fix_discarded_blocks,
need_drop, drop_reason)
"""构造pdf_info_dict"""
page_info = ocr_construct_page_component_v2(
sorted_blocks,
[],
page_id,
page_w,
page_h,
[],
images,
tables,
interline_equations,
fix_discarded_blocks,
need_drop,
drop_reason,
)
return page_info
def pdf_parse_union(pdf_bytes,
def pdf_parse_union(
dataset: Dataset,
model_list,
imageWriter,
parse_mode,
start_page_id=0,
end_page_id=None,
debug_mode=False,
):
pdf_bytes_md5 = compute_md5(pdf_bytes)
pdf_docs = fitz.open("pdf", pdf_bytes)
):
pdf_bytes_md5 = compute_md5(dataset.data_bits())
'''初始化空的pdf_info_dict'''
"""初始化空的pdf_info_dict"""
pdf_info_dict = {}
'''用model_list和docs对象初始化magic_model'''
magic_model = MagicModel(model_list, pdf_docs)
"""用model_list和docs对象初始化magic_model"""
magic_model = MagicModel(model_list, dataset)
'''根据输入的起始范围解析pdf'''
"""根据输入的起始范围解析pdf"""
# end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(pdf_docs) - 1
end_page_id = (
end_page_id
if end_page_id is not None and end_page_id >= 0
else len(dataset) - 1
)
if end_page_id > len(pdf_docs) - 1:
logger.warning("end_page_id is out of range, use pdf_docs length")
end_page_id = len(pdf_docs) - 1
if end_page_id > len(dataset) - 1:
logger.warning('end_page_id is out of range, use pdf_docs length')
end_page_id = len(dataset) - 1
'''初始化启动时间'''
"""初始化启动时间"""
start_time = time.time()
for page_id, page in enumerate(pdf_docs):
'''debug时输出每页解析的耗时'''
for page_id, page in enumerate(dataset):
"""debug时输出每页解析的耗时."""
if debug_mode:
time_now = time.time()
logger.info(
f"page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}"
f'page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}'
)
start_time = time_now
'''解析pdf中的每一页'''
"""解析pdf中的每一页"""
if start_page_id <= page_id <= end_page_id:
page_info = parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode)
page_info = parse_page_core(
page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
)
else:
page_w = page.rect.width
page_h = page.rect.height
page_info = ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [],
[], [], [], [],
True, "skip page")
pdf_info_dict[f"page_{page_id}"] = page_info
page_info = page.get_page_info()
page_w = page_info.w
page_h = page_info.h
page_info = ocr_construct_page_component_v2(
[], [], page_id, page_w, page_h, [], [], [], [], [], True, 'skip page'
)
pdf_info_dict[f'page_{page_id}'] = page_info
"""分段"""
# para_split(pdf_info_dict, debug_mode=debug_mode)
for page_num, page in pdf_info_dict.items():
page['para_blocks'] = page['preproc_blocks']
para_split(pdf_info_dict, debug_mode=debug_mode)
"""dict转list"""
pdf_info_list = dict_to_list(pdf_info_dict)
new_pdf_info_dict = {
"pdf_info": pdf_info_list,
'pdf_info': pdf_info_list,
}
clean_memory()
......
......@@ -17,7 +17,7 @@ class AbsPipe(ABC):
PIP_TXT = "txt"
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None):
start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None):
self.pdf_bytes = pdf_bytes
self.model_list = model_list
self.image_writer = image_writer
......@@ -26,6 +26,9 @@ class AbsPipe(ABC):
self.start_page_id = start_page_id
self.end_page_id = end_page_id
self.lang = lang
self.layout_model = layout_model
self.formula_enable = formula_enable
self.table_enable = table_enable
def get_compress_pdf_mid_data(self):
return JsonCompressor.compress_json(self.pdf_mid_data)
......@@ -95,9 +98,7 @@ class AbsPipe(ABC):
"""
pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
pdf_info_list = pdf_mid_data["pdf_info"]
parse_type = pdf_mid_data["_parse_type"]
lang = pdf_mid_data.get("_lang", None)
content_list = union_make(pdf_info_list, MakeMode.STANDARD_FORMAT, drop_mode, img_buket_path, parse_type, lang)
content_list = union_make(pdf_info_list, MakeMode.STANDARD_FORMAT, drop_mode, img_buket_path)
return content_list
@staticmethod
......@@ -107,9 +108,7 @@ class AbsPipe(ABC):
"""
pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
pdf_info_list = pdf_mid_data["pdf_info"]
parse_type = pdf_mid_data["_parse_type"]
lang = pdf_mid_data.get("_lang", None)
md_content = union_make(pdf_info_list, md_make_mode, drop_mode, img_buket_path, parse_type, lang)
md_content = union_make(pdf_info_list, md_make_mode, drop_mode, img_buket_path)
return md_content
......@@ -10,8 +10,10 @@ from magic_pdf.user_api import parse_ocr_pdf
class OCRPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang)
start_page_id=0, end_page_id=None, lang=None,
layout_model=None, formula_enable=None, table_enable=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
layout_model, formula_enable, table_enable)
def pipe_classify(self):
pass
......@@ -19,12 +21,14 @@ class OCRPipe(AbsPipe):
def pipe_analyze(self):
self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_parse(self):
self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
......
......@@ -11,8 +11,10 @@ from magic_pdf.user_api import parse_txt_pdf
class TXTPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang)
start_page_id=0, end_page_id=None, lang=None,
layout_model=None, formula_enable=None, table_enable=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
layout_model, formula_enable, table_enable)
def pipe_classify(self):
pass
......@@ -20,12 +22,14 @@ class TXTPipe(AbsPipe):
def pipe_analyze(self):
self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_parse(self):
self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
......
......@@ -14,9 +14,11 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
class UNIPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None):
start_page_id=0, end_page_id=None, lang=None,
layout_model=None, formula_enable=None, table_enable=None):
self.pdf_type = jso_useful_key["_pdf_type"]
super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id, lang)
super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id,
lang, layout_model, formula_enable, table_enable)
if len(self.model_list) == 0:
self.input_model_is_empty = True
else:
......@@ -29,18 +31,21 @@ class UNIPipe(AbsPipe):
if self.pdf_type == self.PIP_TXT:
self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
elif self.pdf_type == self.PIP_OCR:
self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_parse(self):
if self.pdf_type == self.PIP_TXT:
self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer,
is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty,
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
elif self.pdf_type == self.PIP_OCR:
self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer,
is_debug=self.is_debug,
......
from loguru import logger
from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio, calculate_overlap_area_in_bbox1_area_ratio, \
calculate_iou
calculate_iou, calculate_vertical_projection_overlap_ratio
from magic_pdf.libs.drop_tag import DropTag
from magic_pdf.libs.ocr_content_type import BlockType
from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox_for_block
......@@ -60,29 +60,34 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc
return all_bboxes, all_discarded_blocks, drop_reasons
def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_blocks, text_blocks,
title_blocks, interline_equation_blocks, page_w, page_h):
all_bboxes = []
all_discarded_blocks = []
for image in img_blocks:
x0, y0, x1, y1 = image['bbox']
all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Image, None, None, None, None, image["score"]])
def add_bboxes(blocks, block_type, bboxes):
for block in blocks:
x0, y0, x1, y1 = block['bbox']
if block_type in [
BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote,
BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote
]:
bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block["score"], block["group_id"]])
else:
bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block["score"]])
for table in table_blocks:
x0, y0, x1, y1 = table['bbox']
all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Table, None, None, None, None, table["score"]])
for text in text_blocks:
x0, y0, x1, y1 = text['bbox']
all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Text, None, None, None, None, text["score"]])
def ocr_prepare_bboxes_for_layout_split_v2(
img_body_blocks, img_caption_blocks, img_footnote_blocks,
table_body_blocks, table_caption_blocks, table_footnote_blocks,
discarded_blocks, text_blocks, title_blocks, interline_equation_blocks, page_w, page_h
):
all_bboxes = []
for title in title_blocks:
x0, y0, x1, y1 = title['bbox']
all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Title, None, None, None, None, title["score"]])
for interline_equation in interline_equation_blocks:
x0, y0, x1, y1 = interline_equation['bbox']
all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.InterlineEquation, None, None, None, None, interline_equation["score"]])
add_bboxes(img_body_blocks, BlockType.ImageBody, all_bboxes)
add_bboxes(img_caption_blocks, BlockType.ImageCaption, all_bboxes)
add_bboxes(img_footnote_blocks, BlockType.ImageFootnote, all_bboxes)
add_bboxes(table_body_blocks, BlockType.TableBody, all_bboxes)
add_bboxes(table_caption_blocks, BlockType.TableCaption, all_bboxes)
add_bboxes(table_footnote_blocks, BlockType.TableFootnote, all_bboxes)
add_bboxes(text_blocks, BlockType.Text, all_bboxes)
add_bboxes(title_blocks, BlockType.Title, all_bboxes)
add_bboxes(interline_equation_blocks, BlockType.InterlineEquation, all_bboxes)
'''block嵌套问题解决'''
'''文本框与标题框重叠,优先信任文本框'''
......@@ -96,23 +101,47 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b
'''interline_equation框被包含在文本类型框内,且interline_equation比文本区块小很多时信任文本框,这时需要舍弃公式框'''
# 通过后续大框套小框逻辑删除
'''discarded_blocks中只保留宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的(限定footnote)'''
'''discarded_blocks'''
all_discarded_blocks = []
add_bboxes(discarded_blocks, BlockType.Discarded, all_discarded_blocks)
'''footnote识别:宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的'''
footnote_blocks = []
for discarded in discarded_blocks:
x0, y0, x1, y1 = discarded['bbox']
all_discarded_blocks.append([x0, y0, x1, y1, None, None, None, BlockType.Discarded, None, None, None, None, discarded["score"]])
# 将footnote加入到all_bboxes中,用来计算layout
# if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
# all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Footnote, None, None, None, None, discarded["score"]])
if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
footnote_blocks.append([x0, y0, x1, y1])
'''移除在footnote下面的任何框'''
need_remove_blocks = find_blocks_under_footnote(all_bboxes, footnote_blocks)
if len(need_remove_blocks) > 0:
for block in need_remove_blocks:
all_bboxes.remove(block)
all_discarded_blocks.append(block)
'''经过以上处理后,还存在大框套小框的情况,则删除小框'''
all_bboxes = remove_overlaps_min_blocks(all_bboxes)
all_discarded_blocks = remove_overlaps_min_blocks(all_discarded_blocks)
'''将剩余的bbox做分离处理,防止后面分layout时出错'''
# all_bboxes, drop_reasons = remove_overlap_between_bbox_for_block(all_bboxes)
all_bboxes, drop_reasons = remove_overlap_between_bbox_for_block(all_bboxes)
return all_bboxes, all_discarded_blocks
def find_blocks_under_footnote(all_bboxes, footnote_blocks):
need_remove_blocks = []
for block in all_bboxes:
block_x0, block_y0, block_x1, block_y1 = block[:4]
for footnote_bbox in footnote_blocks:
footnote_x0, footnote_y0, footnote_x1, footnote_y1 = footnote_bbox
# 如果footnote的纵向投影覆盖了block的纵向投影的80%且block的y0大于等于footnote的y1
if block_y0 >= footnote_y1 and calculate_vertical_projection_overlap_ratio((block_x0, block_y0, block_x1, block_y1), footnote_bbox) >= 0.8:
if block not in need_remove_blocks:
need_remove_blocks.append(block)
break
return need_remove_blocks
def fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes):
# 先提取所有text和interline block
text_blocks = []
......
......@@ -49,8 +49,7 @@ def merge_spans_to_line(spans):
continue
# 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行
if __is_overlaps_y_exceeds_threshold(span['bbox'],
current_line[-1]['bbox']):
if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], 0.5):
current_line.append(span)
else:
# 否则,开始新行
......@@ -154,6 +153,11 @@ def fill_spans_in_blocks(blocks, spans, radio):
'type': block_type,
'bbox': block_bbox,
}
if block_type in [
BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote,
BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote
]:
block_dict["group_id"] = block[-1]
block_spans = []
for span in spans:
span_bbox = span['bbox']
......@@ -202,6 +206,27 @@ def fix_block_spans(block_with_spans, img_blocks, table_blocks):
return fix_blocks
def fix_block_spans_v2(block_with_spans):
"""1、img_block和table_block因为包含caption和footnote的关系,存在block的嵌套关系
需要将caption和footnote的text_span放入相应img_block和table_block内的
caption_block和footnote_block中 2、同时需要删除block中的spans字段."""
fix_blocks = []
for block in block_with_spans:
block_type = block['type']
if block_type in [BlockType.Text, BlockType.Title,
BlockType.ImageCaption, BlockType.ImageFootnote,
BlockType.TableCaption, BlockType.TableFootnote
]:
block = fix_text_block(block)
elif block_type in [BlockType.InterlineEquation, BlockType.ImageBody, BlockType.TableBody]:
block = fix_interline_block(block)
else:
continue
fix_blocks.append(block)
return fix_blocks
def fix_discarded_block(discarded_block_with_spans):
fix_discarded_blocks = []
for block in discarded_block_with_spans:
......
config:
device: cpu
layout: True
formula: True
table_config:
model: TableMaster
is_table_recog_enable: False
max_time: 400
weights:
layout: Layout/model_final.pth
mfd: MFD/weights.pt
mfr: MFR/unimernet_small
layoutlmv3: Layout/LayoutLMv3/model_final.pth
doclayout_yolo: Layout/YOLO/doclayout_yolo_ft.pt
yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
unimernet_small: MFR/unimernet_small
struct_eqtable: TabRec/StructEqTable
TableMaster: TabRec/TableMaster
\ No newline at end of file
tablemaster: TabRec/TableMaster
\ No newline at end of file
......@@ -52,7 +52,7 @@ without method specified, auto will be used by default.""",
help="""
Input the languages in the pdf (if known) to improve OCR accuracy. Optional.
You should input "Abbreviation" with language form url:
https://paddlepaddle.github.io/PaddleOCR/en/ppocr/blog/multi_languages.html#5-support-languages-and-abbreviations
https://paddlepaddle.github.io/PaddleOCR/latest/en/ppocr/blog/multi_languages.html#5-support-languages-and-abbreviations
""",
default=None,
)
......
......@@ -6,8 +6,8 @@ import click
from loguru import logger
import magic_pdf.model as model_config
from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_span_bbox,
draw_model_bbox, draw_line_sort_bbox)
from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_line_sort_bbox,
draw_model_bbox, draw_span_bbox)
from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
from magic_pdf.pipe.OCRPipe import OCRPipe
from magic_pdf.pipe.TXTPipe import TXTPipe
......@@ -46,10 +46,12 @@ def do_parse(
start_page_id=0,
end_page_id=None,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
if debug_able:
logger.warning('debug mode is on')
# f_dump_content_list = True
f_draw_model_bbox = True
f_draw_line_sort_bbox = True
......@@ -64,13 +66,16 @@ def do_parse(
if parse_method == 'auto':
jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
elif parse_method == 'txt':
pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
elif parse_method == 'ocr':
pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
else:
logger.error('unknown parse method')
exit(1)
......
......@@ -101,11 +101,19 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False):
logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr")
if input_model_is_empty:
pdf_models = doc_analyze(pdf_bytes,
layout_model = kwargs.get("layout_model", None)
formula_enable = kwargs.get("formula_enable", None)
table_enable = kwargs.get("table_enable", None)
pdf_models = doc_analyze(
pdf_bytes,
ocr=True,
start_page_id=start_page_id,
end_page_id=end_page_id,
lang=lang)
lang=lang,
layout_model=layout_model,
formula_enable=formula_enable,
table_enable=table_enable,
)
pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
if pdf_info_dict is None:
raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.")
......
from loguru import logger
def ImportPIL(f):
try:
import PIL # noqa: F401
except ImportError:
logger.error('Pillow not installed, please install by pip.')
exit(1)
return f
version: 2
build:
os: ubuntu-22.04
tools:
python: "3.10"
formats:
- epub
python:
install:
- requirements: docs/requirements.txt
sphinx:
configuration: docs/en/conf.py
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
Data Api
------------------
.. toctree::
:maxdepth: 2
api/dataset.rst
api/data_reader_writer.rst
api/read_api.rst
Data Reader Writer
--------------------
.. autoclass:: magic_pdf.data.data_reader_writer.DataReader
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.data_reader_writer.DataWriter
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.data_reader_writer.S3DataReader
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.data_reader_writer.S3DataWriter
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.data_reader_writer.FileBasedDataReader
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.data_reader_writer.FileBasedDataWriter
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.data_reader_writer.S3DataReader
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.data_reader_writer.S3DataWriter
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.data_reader_writer.MultiBucketS3DataReader
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.data_reader_writer.MultiBucketS3DataWriter
:members:
:inherited-members:
Dataset Api
------------------
.. autoclass:: magic_pdf.data.dataset.PageableData
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.dataset.Dataset
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.dataset.ImageDataset
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.dataset.PymuDocDataset
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.dataset.Doc
:members:
:inherited-members:
read_api Api
------------------
.. automodule:: magic_pdf.data.read_api
:members:
:inherited-members:
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
import os
import subprocess
import sys
from sphinx.ext import autodoc
def install(package):
subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
requirements_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt'))
if os.path.exists(requirements_path):
with open(requirements_path) as f:
packages = f.readlines()
for package in packages:
install(package.strip())
sys.path.insert(0, os.path.abspath('../..'))
# -- Project information -----------------------------------------------------
project = 'MinerU'
copyright = '2024, MinerU Contributors'
author = 'OpenDataLab'
# The full version, including alpha/beta/rc tags
version_file = '../../magic_pdf/libs/version.py'
with open(version_file) as f:
exec(compile(f.read(), version_file, 'exec'))
__version__ = locals()['__version__']
# The short X.Y version
version = __version__
# The full version, including alpha/beta/rc tags
release = __version__
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx.ext.intersphinx',
'sphinx_copybutton',
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'myst_parser',
'sphinxarg.ext',
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# Exclude the prompt "$" when copying code
copybutton_prompt_text = r'\$ '
copybutton_prompt_is_regexp = True
language = 'en'
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_book_theme'
html_logo = '_static/image/logo.png'
html_theme_options = {
'path_to_docs': 'docs/en',
'repository_url': 'https://github.com/opendatalab/MinerU',
'use_repository_button': True,
}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
# html_static_path = ['_static']
# Mock out external dependencies here.
autodoc_mock_imports = [
'cpuinfo',
'torch',
'transformers',
'psutil',
'prometheus_client',
'sentencepiece',
'vllm.cuda_utils',
'vllm._C',
'numpy',
'tqdm',
]
class MockedClassDocumenter(autodoc.ClassDocumenter):
"""Remove note about base class when a class is derived from object."""
def add_line(self, line: str, source: str, *lineno: int) -> None:
if line == ' Bases: :py:class:`object`':
return
super().add_line(line, source, *lineno)
autodoc.ClassDocumenter = MockedClassDocumenter
navigation_with_keys = False
.. xtuner documentation master file, created by
sphinx-quickstart on Tue Jan 9 16:33:06 2024.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
Welcome to the MinerU Documentation
==============================================
.. figure:: ./_static/image/logo.png
:align: center
:alt: mineru
:class: no-scaled-link
.. raw:: html
<p style="text-align:center">
<strong>A one-stop, open-source, high-quality data extraction tool
</strong>
</p>
<p style="text-align:center">
<script async defer src="https://buttons.github.io/buttons.js"></script>
<a class="github-button" href="https://github.com/opendatalab/MinerU" data-show-count="true" data-size="large" aria-label="Star">Star</a>
<a class="github-button" href="https://github.com/opendatalab/MinerU/subscription" data-icon="octicon-eye" data-size="large" aria-label="Watch">Watch</a>
<a class="github-button" href="https://github.com/opendatalab/MinerU/fork" data-icon="octicon-repo-forked" data-size="large" aria-label="Fork">Fork</a>
</p>
API Reference
-------------
If you are looking for information on a specific function, class or
method, this part of the documentation is for you.
.. toctree::
:maxdepth: 2
api
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
if "%1" == "" goto help
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd
boto3>=1.28.43
loguru>=0.6.0
myst-parser
Pillow==8.4.0
pydantic>=2.7.2,<2.8.0
PyMuPDF>=1.24.9
sphinx
sphinx-argparse
sphinx-book-theme
sphinx-copybutton
sphinx_rtd_theme
version: 2
build:
os: ubuntu-22.04
tools:
python: "3.10"
formats:
- epub
python:
install:
- requirements: docs/requirements.txt
sphinx:
configuration: docs/zh_cn/conf.py
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
import os
import subprocess
import sys
from sphinx.ext import autodoc
def install(package):
subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
requirements_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt'))
if os.path.exists(requirements_path):
with open(requirements_path) as f:
packages = f.readlines()
for package in packages:
install(package.strip())
sys.path.insert(0, os.path.abspath('../..'))
# -- Project information -----------------------------------------------------
project = 'MinerU'
copyright = '2024, OpenDataLab'
author = 'MinerU Contributors'
# The full version, including alpha/beta/rc tags
version_file = '../../magic_pdf/libs/version.py'
with open(version_file) as f:
exec(compile(f.read(), version_file, 'exec'))
__version__ = locals()['__version__']
# The short X.Y version
version = __version__
# The full version, including alpha/beta/rc tags
release = __version__
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx.ext.intersphinx',
'sphinx_copybutton',
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'myst_parser',
'sphinxarg.ext',
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# Exclude the prompt "$" when copying code
copybutton_prompt_text = r'\$ '
copybutton_prompt_is_regexp = True
language = 'zh_CN'
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_book_theme'
html_logo = '_static/image/logo.png'
html_theme_options = {
'path_to_docs': 'docs/zh_cn',
'repository_url': 'https://github.com/opendatalab/MinerU',
'use_repository_button': True,
}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
# html_static_path = ['_static']
# Mock out external dependencies here.
autodoc_mock_imports = [
'cpuinfo',
'torch',
'transformers',
'psutil',
'prometheus_client',
'sentencepiece',
'vllm.cuda_utils',
'vllm._C',
'numpy',
'tqdm',
]
class MockedClassDocumenter(autodoc.ClassDocumenter):
"""Remove note about base class when a class is derived from object."""
def add_line(self, line: str, source: str, *lineno: int) -> None:
if line == ' Bases: :py:class:`object`':
return
super().add_line(line, source, *lineno)
autodoc.ClassDocumenter = MockedClassDocumenter
navigation_with_keys = False
.. xtuner documentation master file, created by
sphinx-quickstart on Tue Jan 9 16:33:06 2024.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
欢迎来到 MinerU 的中文文档
==============================================
.. figure:: ./_static/image/logo.png
:align: center
:alt: mineru
:class: no-scaled-link
.. raw:: html
<p style="text-align:center">
<strong> 一站式开源高质量数据提取工具
</strong>
</p>
<p style="text-align:center">
<script async defer src="https://buttons.github.io/buttons.js"></script>
<a class="github-button" href="https://github.com/opendatalab/MinerU" data-show-count="true" data-size="large" aria-label="Star">Star</a>
<a class="github-button" href="https://github.com/opendatalab/MinerU/subscription" data-icon="octicon-eye" data-size="large" aria-label="Watch">Watch</a>
<a class="github-button" href="https://github.com/opendatalab/MinerU/fork" data-icon="octicon-repo-forked" data-size="large" aria-label="Fork">Fork</a>
</p>
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
if "%1" == "" goto help
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd
......@@ -6,5 +6,4 @@
- [gradio_app](./gradio_app/README.md): Build a web app based on gradio
- [web_demo](./web_demo/README.md): MinerU online [demo](https://opendatalab.com/OpenSourceTools/Extractor/PDF/) localized deployment version
- [web_api](./web_api/README.md): Web API Based on FastAPI
- [multi_gpu](./multi_gpu/README.md): Multi-GPU parallel processing based on LitServe
......@@ -6,4 +6,4 @@
- [gradio_app](./gradio_app/README_zh-CN.md): 基于 Gradio 的 Web 应用
- [web_demo](./web_demo/README_zh-CN.md): MinerU在线[demo](https://opendatalab.com/OpenSourceTools/Extractor/PDF/)本地化部署版本
- [web_api](./web_api/README.md): 基于 FastAPI 的 Web API
- [multi_gpu](./multi_gpu/README.md): 基于 LitServe 的多 GPU 并行处理
......@@ -3,10 +3,12 @@
import base64
import os
import time
import uuid
import zipfile
from pathlib import Path
import re
import pymupdf
from loguru import logger
from magic_pdf.libs.hash_utils import compute_sha256
......@@ -23,7 +25,7 @@ def read_fn(path):
return disk_rw.read(os.path.basename(path), AbsReaderWriter.MODE_BIN)
def parse_pdf(doc_path, output_dir, end_page_id, is_ocr):
def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, layout_mode, formula_enable, table_enable, language):
os.makedirs(output_dir, exist_ok=True)
try:
......@@ -42,6 +44,10 @@ def parse_pdf(doc_path, output_dir, end_page_id, is_ocr):
parse_method,
False,
end_page_id=end_page_id,
layout_model=layout_mode,
formula_enable=formula_enable,
table_enable=table_enable,
lang=language,
)
return local_md_dir, file_name
except Exception as e:
......@@ -93,9 +99,10 @@ def replace_image_with_base64(markdown_text, image_dir_path):
return re.sub(pattern, replace, markdown_text)
def to_markdown(file_path, end_pages, is_ocr):
def to_markdown(file_path, end_pages, is_ocr, layout_mode, formula_enable, table_enable, language):
# 获取识别的md文件以及压缩包文件路径
local_md_dir, file_name = parse_pdf(file_path, './output', end_pages - 1, is_ocr)
local_md_dir, file_name = parse_pdf(file_path, './output', end_pages - 1, is_ocr,
layout_mode, formula_enable, table_enable, language)
archive_zip_path = os.path.join("./output", compute_sha256(local_md_dir) + ".zip")
zip_archive_success = compress_directory_to_zip(local_md_dir, archive_zip_path)
if zip_archive_success == 0:
......@@ -138,24 +145,71 @@ with open("header.html", "r") as file:
header = file.read()
latin_lang = [
'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr',
'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
'sw', 'tl', 'tr', 'uz', 'vi', 'french', 'german'
]
arabic_lang = ['ar', 'fa', 'ug', 'ur']
cyrillic_lang = [
'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava',
'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
]
devanagari_lang = [
'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',
'sa', 'bgc'
]
other_lang = ['ch', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka']
all_lang = [""]
all_lang.extend([*other_lang, *latin_lang, *arabic_lang, *cyrillic_lang, *devanagari_lang])
def to_pdf(file_path):
with pymupdf.open(file_path) as f:
if f.is_pdf:
return file_path
else:
pdf_bytes = f.convert_to_pdf()
# 将pdfbytes 写入到uuid.pdf中
# 生成唯一的文件名
unique_filename = f"{uuid.uuid4()}.pdf"
# 构建完整的文件路径
tmp_file_path = os.path.join(os.path.dirname(file_path), unique_filename)
# 将字节数据写入文件
with open(tmp_file_path, 'wb') as tmp_pdf_file:
tmp_pdf_file.write(pdf_bytes)
return tmp_file_path
if __name__ == "__main__":
with gr.Blocks() as demo:
gr.HTML(header)
with gr.Row():
with gr.Column(variant='panel', scale=5):
pdf_show = gr.Markdown()
file = gr.File(label="Please upload a PDF or image", file_types=[".pdf", ".png", ".jpeg", "jpg"])
max_pages = gr.Slider(1, 10, 5, step=1, label="Max convert pages")
with gr.Row() as bu_flow:
is_ocr = gr.Checkbox(label="Force enable OCR")
with gr.Row():
layout_mode = gr.Dropdown(["layoutlmv3", "doclayout_yolo"], label="Layout model", value="layoutlmv3")
language = gr.Dropdown(all_lang, label="Language", value="")
with gr.Row():
formula_enable = gr.Checkbox(label="Enable formula recognition", value=True)
is_ocr = gr.Checkbox(label="Force enable OCR", value=False)
table_enable = gr.Checkbox(label="Enable table recognition(test)", value=False)
with gr.Row():
change_bu = gr.Button("Convert")
clear_bu = gr.ClearButton([pdf_show], value="Clear")
pdf_show = PDF(label="Please upload pdf", interactive=True, height=800)
clear_bu = gr.ClearButton(value="Clear")
pdf_show = PDF(label="PDF preview", interactive=True, height=800)
with gr.Accordion("Examples:"):
example_root = os.path.join(os.path.dirname(__file__), "examples")
gr.Examples(
examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
_.endswith("pdf")],
inputs=pdf_show,
inputs=pdf_show
)
with gr.Column(variant='panel', scale=5):
......@@ -166,7 +220,9 @@ if __name__ == "__main__":
latex_delimiters=latex_delimiters, line_breaks=True)
with gr.Tab("Markdown text"):
md_text = gr.TextArea(lines=45, show_copy_button=True)
change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr], outputs=[md, md_text, output_file, pdf_show])
clear_bu.add([md, pdf_show, md_text, output_file, is_ocr])
file.upload(fn=to_pdf, inputs=file, outputs=pdf_show)
change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr, layout_mode, formula_enable, table_enable, language],
outputs=[md, md_text, output_file, pdf_show])
clear_bu.add([file, md, pdf_show, md_text, output_file, is_ocr, table_enable, language])
demo.launch()
\ No newline at end of file
demo.launch(server_name="0.0.0.0")
\ No newline at end of file
No preview for this file type
## 项目简介
本项目提供基于 LitServe 的多 GPU 并行处理方案。LitServe 是一个简便且灵活的 AI 模型服务引擎,基于 FastAPI 构建。它为 FastAPI 增强了批处理、流式传输和 GPU 自动扩展等功能,无需为每个模型单独重建 FastAPI 服务器。
## 环境配置
请使用以下命令配置所需的环境:
```bash
pip install -U litserve python-multipart filetype
pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com
pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu118
```
## 快速使用
### 1. 启动服务端
以下示例展示了如何启动服务端,支持自定义设置:
```python
server = ls.LitServer(
MinerUAPI(output_dir='/tmp'), # 可自定义输出文件夹
accelerator='cuda', # 启用 GPU 加速
devices='auto', # "auto" 使用所有 GPU
workers_per_device=1, # 每个 GPU 启动一个服务实例
timeout=False # 设置为 False 以禁用超时
)
server.run(port=8000) # 设定服务端口为 8000
```
启动服务端命令:
```bash
python server.py
```
### 2. 启动客户端
以下代码展示了客户端的使用方式,可根据需求修改配置:
```python
files = ['demo/small_ocr.pdf'] # 替换为文件路径,支持 jpg/jpeg、png、pdf 文件
n_jobs = np.clip(len(files), 1, 8) # 设置并发线程数,此处最大为 8,可根据自身修改
results = Parallel(n_jobs, prefer='threads', verbose=10)(
delayed(do_parse)(p) for p in files
)
print(results)
```
启动客户端命令:
```bash
python client.py
```
好了,你的文件会自动在多个 GPU 上并行处理!🍻🍻🍻
import base64
import requests
import numpy as np
from loguru import logger
from joblib import Parallel, delayed
def to_b64(file_path):
try:
with open(file_path, 'rb') as f:
return base64.b64encode(f.read()).decode('utf-8')
except Exception as e:
raise Exception(f'File: {file_path} - Info: {e}')
def do_parse(file_path, url='http://127.0.0.1:8000/predict', **kwargs):
try:
response = requests.post(url, json={
'file': to_b64(file_path),
'kwargs': kwargs
})
if response.status_code == 200:
output = response.json()
output['file_path'] = file_path
return output
else:
raise Exception(response.text)
except Exception as e:
logger.error(f'File: {file_path} - Info: {e}')
if __name__ == '__main__':
files = ['small_ocr.pdf']
n_jobs = np.clip(len(files), 1, 8)
results = Parallel(n_jobs, prefer='threads', verbose=10)(
delayed(do_parse)(p) for p in files
)
print(results)
import os
import fitz
import torch
import base64
import litserve as ls
from uuid import uuid4
from fastapi import HTTPException
from filetype import guess_extension
from magic_pdf.tools.common import do_parse
from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
class MinerUAPI(ls.LitAPI):
def __init__(self, output_dir='/tmp'):
self.output_dir = output_dir
def setup(self, device):
if device.startswith('cuda'):
os.environ['CUDA_VISIBLE_DEVICES'] = device.split(':')[-1]
if torch.cuda.device_count() > 1:
raise RuntimeError("Remove any CUDA actions before setting 'CUDA_VISIBLE_DEVICES'.")
model_manager = ModelSingleton()
model_manager.get_model(True, False)
model_manager.get_model(False, False)
print(f'Model initialization complete on {device}!')
def decode_request(self, request):
file = request['file']
file = self.to_pdf(file)
opts = request.get('kwargs', {})
opts.setdefault('debug_able', False)
opts.setdefault('parse_method', 'auto')
return file, opts
def predict(self, inputs):
try:
do_parse(self.output_dir, pdf_name := str(uuid4()), inputs[0], [], **inputs[1])
return pdf_name
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
finally:
self.clean_memory()
def encode_response(self, response):
return {'output_dir': response}
def clean_memory(self):
import gc
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
gc.collect()
def to_pdf(self, file_base64):
try:
file_bytes = base64.b64decode(file_base64)
file_ext = guess_extension(file_bytes)
with fitz.open(stream=file_bytes, filetype=file_ext) as f:
if f.is_pdf: return f.tobytes()
return f.convert_to_pdf()
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == '__main__':
server = ls.LitServer(
MinerUAPI(output_dir='/tmp'),
accelerator='cuda',
devices='auto',
workers_per_device=1,
timeout=False
)
server.run(port=8000)
......@@ -5,7 +5,6 @@ PyMuPDF>=1.24.9
loguru>=0.6.0
numpy>=1.21.6,<2.0.0
fast-langdetect==0.2.0
wordninja>=2.0.0
scikit-learn>=1.0.2
pdfminer.six==20231228
unimernet==0.2.1
......@@ -15,4 +14,5 @@ paddleocr==2.7.3
paddlepaddle==3.0.0b1
pypandoc
struct-eqtable==0.1.0
doclayout-yolo==0.0.2
detectron2
......@@ -8,7 +8,6 @@ pdfminer.six==20231228
pydantic>=2.7.2,<2.8.0
PyMuPDF>=1.24.9
scikit-learn>=1.0.2
wordninja>=2.0.0
torch>=2.2.2,<=2.3.1
transformers
# The requirements.txt must ensure that only necessary external dependencies are introduced. If there are new dependencies to add, please contact the project administrator.
......@@ -45,6 +45,7 @@ if __name__ == '__main__':
"paddlepaddle==2.6.1;platform_system=='Windows' or platform_system=='Darwin'", # windows版本3.0.0b1效率下降,需锁定2.6.1
"pypandoc", # 表格解析latex转html
"struct-eqtable==0.1.0", # 表格解析
"doclayout_yolo==0.0.2", # doclayout_yolo
"detectron2"
],
},
......
{"track_id":"e8824f5a-9fcb-4ee5-b2d4-6bf2c67019dc","path":"s3://sci-hub/enbook-scimag/78800000/libgen.scimag78872000-78872999/10.1017/cbo9780511770425.012.pdf","file_type":"pdf","content_type":"application/pdf","content_length":80078,"title":"German Idealism and the Concept of Punishment || Conclusion","remark":{"file_id":"scihub_78800000/libgen.scimag78872000-78872999.zip_10.1017/cbo9780511770425.012","file_source_type":"paper","original_file_id":"10.1017/cbo9780511770425.012","file_name":"10.1017/cbo9780511770425.012.pdf","author":"Merle, Jean-Christophe"}}
{"track_id":"e8824f5a-9fcb-4ee5-b2d4-6bf2c67019dc","path":"tests/test_data/assets/pdfs/test_02.pdf","file_type":"pdf","content_type":"application/pdf","content_length":80078,"title":"German Idealism and the Concept of Punishment || Conclusion","remark":{"file_id":"scihub_78800000/libgen.scimag78872000-78872999.zip_10.1017/cbo9780511770425.012","file_source_type":"paper","original_file_id":"10.1017/cbo9780511770425.012","file_name":"10.1017/cbo9780511770425.012.pdf","author":"Merle, Jean-Christophe"}}
import os
import shutil
from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
FileBasedDataWriter)
def test_filebased_reader_writer():
unitest_dir = '/tmp/magic_pdf/unittest/data/filebased_reader_writer'
sub_dir = os.path.join(unitest_dir, 'sub')
abs_fn = os.path.join(unitest_dir, 'abspath.txt')
os.makedirs(sub_dir, exist_ok=True)
writer = FileBasedDataWriter(sub_dir)
reader = FileBasedDataReader(sub_dir)
writer.write('test.txt', b'hello world')
assert reader.read('test.txt') == b'hello world'
writer.write(abs_fn, b'hello world')
assert reader.read(abs_fn) == b'hello world'
shutil.rmtree(unitest_dir)
import json
import os
import fitz
import pytest
from magic_pdf.data.data_reader_writer import (MultiBucketS3DataReader,
MultiBucketS3DataWriter)
from magic_pdf.data.schemas import S3Config
@pytest.mark.skipif(
os.getenv('S3_ACCESS_KEY_2', None) is None, reason='need s3 config!'
)
def test_multi_bucket_s3_reader_writer():
"""test multi bucket s3 reader writer must config s3 config in the
environment export S3_BUCKET=xxx export S3_ACCESS_KEY=xxx export
S3_SECRET_KEY=xxx export S3_ENDPOINT=xxx.
export S3_BUCKET_2=xxx export S3_ACCESS_KEY_2=xxx export S3_SECRET_KEY_2=xxx export S3_ENDPOINT_2=xxx
"""
bucket = os.getenv('S3_BUCKET', '')
ak = os.getenv('S3_ACCESS_KEY', '')
sk = os.getenv('S3_SECRET_KEY', '')
endpoint_url = os.getenv('S3_ENDPOINT', '')
bucket_2 = os.getenv('S3_BUCKET_2', '')
ak_2 = os.getenv('S3_ACCESS_KEY_2', '')
sk_2 = os.getenv('S3_SECRET_KEY_2', '')
endpoint_url_2 = os.getenv('S3_ENDPOINT_2', '')
s3configs = [
S3Config(
bucket_name=bucket, access_key=ak, secret_key=sk, endpoint_url=endpoint_url
),
S3Config(
bucket_name=bucket_2,
access_key=ak_2,
secret_key=sk_2,
endpoint_url=endpoint_url_2,
),
]
reader = MultiBucketS3DataReader(default_bucket=bucket, s3_configs=s3configs)
writer = MultiBucketS3DataWriter(default_bucket=bucket, s3_configs=s3configs)
bits = reader.read('meta-index/scihub/v001/scihub/part-66210c190659-000026.jsonl')
assert bits == reader.read(
f's3://{bucket}/meta-index/scihub/v001/scihub/part-66210c190659-000026.jsonl'
)
bits = reader.read(
f's3://{bucket_2}/enbook-scimag/78800000/libgen.scimag78872000-78872999/10.1017/cbo9780511770425.012.pdf'
)
docs = fitz.open('pdf', bits)
assert len(docs) == 10
bits = reader.read(
'meta-index/scihub/v001/scihub/part-66210c190659-000026.jsonl?bytes=566,713'
)
assert bits == reader.read_at(
'meta-index/scihub/v001/scihub/part-66210c190659-000026.jsonl', 566, 713
)
assert len(json.loads(bits)) > 0
writer.write_string(
'unittest/data/data_reader_writer/multi_bucket_s3_data/test01.txt', 'abc'
)
assert 'abc'.encode() == reader.read(
'unittest/data/data_reader_writer/multi_bucket_s3_data/test01.txt'
)
writer.write(
'unittest/data/data_reader_writer/multi_bucket_s3_data/test02.txt',
'123'.encode(),
)
assert '123'.encode() == reader.read(
'unittest/data/data_reader_writer/multi_bucket_s3_data/test02.txt'
)
import json
import os
import pytest
from magic_pdf.data.data_reader_writer import S3DataReader, S3DataWriter
@pytest.mark.skipif(
os.getenv('S3_ACCESS_KEY', None) is None, reason='need s3 config!'
)
def test_multi_bucket_s3_reader_writer():
"""test multi bucket s3 reader writer must config s3 config in the
environment export S3_BUCKET=xxx export S3_ACCESS_KEY=xxx export
S3_SECRET_KEY=xxx export S3_ENDPOINT=xxx."""
bucket = os.getenv('S3_BUCKET', '')
ak = os.getenv('S3_ACCESS_KEY', '')
sk = os.getenv('S3_SECRET_KEY', '')
endpoint_url = os.getenv('S3_ENDPOINT', '')
reader = S3DataReader(bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint_url)
writer = S3DataWriter(bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint_url)
bits = reader.read('meta-index/scihub/v001/scihub/part-66210c190659-000026.jsonl')
assert bits == reader.read(
f's3://{bucket}/meta-index/scihub/v001/scihub/part-66210c190659-000026.jsonl'
)
bits = reader.read(
'meta-index/scihub/v001/scihub/part-66210c190659-000026.jsonl?bytes=566,713'
)
assert bits == reader.read_at(
'meta-index/scihub/v001/scihub/part-66210c190659-000026.jsonl', 566, 713
)
assert len(json.loads(bits)) > 0
writer.write_string(
'unittest/data/data_reader_writer/multi_bucket_s3_data/test01.txt', 'abc'
)
assert 'abc'.encode() == reader.read(
'unittest/data/data_reader_writer/multi_bucket_s3_data/test01.txt'
)
writer.write(
f'{bucket}/unittest/data/data_reader_writer/multi_bucket_s3_data/test02.txt',
'123'.encode(),
)
assert '123'.encode() == reader.read(
'unittest/data/data_reader_writer/multi_bucket_s3_data/test02.txt'
)
import json
import os
import pytest
from magic_pdf.data.io.s3 import S3Reader, S3Writer
@pytest.mark.skipif(
os.getenv('S3_ACCESS_KEY', None) is None, reason='s3 config not found'
)
def test_s3_reader():
"""test s3 reader.
must config s3 config in the environment export S3_BUCKET=xxx export S3_ACCESS_KEY=xxx export S3_SECRET_KEY=xxx
export S3_ENDPOINT=xxx
"""
bucket = os.getenv('S3_BUCKET', '')
ak = os.getenv('S3_ACCESS_KEY', '')
sk = os.getenv('S3_SECRET_KEY', '')
endpoint_url = os.getenv('S3_ENDPOINT', '')
reader = S3Reader(bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint_url)
bits = reader.read(
'meta-index/scihub/v001/scihub/part-66210c190659-000026.jsonl'
)
assert len(bits) > 0
bits = reader.read_at(
'meta-index/scihub/v001/scihub/part-66210c190659-000026.jsonl',
566,
713,
)
assert len(json.loads(bits)) > 0
@pytest.mark.skipif(
os.getenv('S3_ACCESS_KEY', None) is None, reason='s3 config not found'
)
def test_s3_writer():
"""test s3 reader.
must config s3 config in the environment export S3_BUCKET=xxx export S3_ACCESS_KEY=xxx export S3_SECRET_KEY=xxx
export S3_ENDPOINT=xxx
"""
bucket = os.getenv('S3_BUCKET', '')
ak = os.getenv('S3_ACCESS_KEY', '')
sk = os.getenv('S3_SECRET_KEY', '')
endpoint_url = os.getenv('S3_ENDPOINT', '')
writer = S3Writer(bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint_url)
test_fn = 'unittest/io/test.jsonl'
writer.write(test_fn, '123'.encode())
reader = S3Reader(bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint_url)
bits = reader.read(test_fn)
assert bits.decode() == '123'
from magic_pdf.data.dataset import ImageDataset, PymuDocDataset
def test_pymudataset():
with open('tests/test_data/assets/pdfs/test_01.pdf', 'rb') as f:
bits = f.read()
datasets = PymuDocDataset(bits)
assert len(datasets) > 0
assert datasets.get_page(0).get_page_info().h > 100
def test_imagedataset():
with open('tests/test_data/assets/pngs/test_01.png', 'rb') as f:
bits = f.read()
datasets = ImageDataset(bits)
assert len(datasets) == 1
assert datasets.get_page(0).get_page_info().w > 100
import os
import pytest
from magic_pdf.data.data_reader_writer import MultiBucketS3DataReader
from magic_pdf.data.read_api import (read_jsonl, read_local_images,
read_local_pdfs)
from magic_pdf.data.schemas import S3Config
def test_read_local_pdfs():
datasets = read_local_pdfs('tests/test_data/assets/pdfs')
assert len(datasets) == 2
assert len(datasets[0]) > 0
assert len(datasets[1]) > 0
assert datasets[0].get_page(0).get_page_info().w > 0
assert datasets[0].get_page(0).get_page_info().h > 0
def test_read_local_images():
datasets = read_local_images('tests/test_data/assets/pngs', suffixes=['png'])
assert len(datasets) == 2
assert len(datasets[0]) == 1
assert len(datasets[1]) == 1
assert datasets[0].get_page(0).get_page_info().w > 0
assert datasets[0].get_page(0).get_page_info().h > 0
@pytest.mark.skipif(
os.getenv('S3_ACCESS_KEY_2', None) is None, reason='need s3 config!'
)
def test_read_json():
"""test multi bucket s3 reader writer must config s3 config in the
environment export S3_BUCKET=xxx export S3_ACCESS_KEY=xxx export
S3_SECRET_KEY=xxx export S3_ENDPOINT=xxx.
export S3_BUCKET_2=xxx export S3_ACCESS_KEY_2=xxx export S3_SECRET_KEY_2=xxx export S3_ENDPOINT_2=xxx
"""
bucket = os.getenv('S3_BUCKET', '')
ak = os.getenv('S3_ACCESS_KEY', '')
sk = os.getenv('S3_SECRET_KEY', '')
endpoint_url = os.getenv('S3_ENDPOINT', '')
bucket_2 = os.getenv('S3_BUCKET_2', '')
ak_2 = os.getenv('S3_ACCESS_KEY_2', '')
sk_2 = os.getenv('S3_SECRET_KEY_2', '')
endpoint_url_2 = os.getenv('S3_ENDPOINT_2', '')
s3configs = [
S3Config(
bucket_name=bucket, access_key=ak, secret_key=sk, endpoint_url=endpoint_url
),
S3Config(
bucket_name=bucket_2,
access_key=ak_2,
secret_key=sk_2,
endpoint_url=endpoint_url_2,
),
]
reader = MultiBucketS3DataReader(bucket, s3configs)
datasets = read_jsonl(
f's3://{bucket}/meta-index/scihub/v001/scihub/part-66210c190659-000026.jsonl',
reader,
)
assert len(datasets) > 0
assert len(datasets[0]) == 10
datasets = read_jsonl('tests/test_data/assets/jsonl/test_01.jsonl', reader)
assert len(datasets) == 1
assert len(datasets[0]) == 10
datasets = read_jsonl('tests/test_data/assets/jsonl/test_02.jsonl')
assert len(datasets) == 1
assert len(datasets[0]) == 1
[
{
"layout_dets": [
{
"category_id": 3,
"poly": [
776.7277221679688,
688.448974609375,
1242.224365234375,
688.448974609375,
1242.224365234375,
1182.0628662109375,
776.7277221679688,
1182.0628662109375
],
"score": 0.999997079372406
},
{
"category_id": 3,
"poly": [
775.9269409179688,
1389.754638671875,
1243.672119140625,
1389.754638671875,
1243.672119140625,
1859.716064453125,
775.9269409179688,
1859.716064453125
],
"score": 0.9999949932098389
},
{
"category_id": 1,
"poly": [
752.11572265625,
1939.3634033203125,
1430.1146240234375,
1939.3634033203125,
1430.1146240234375,
2041.1771240234375,
752.11572265625,
2041.1771240234375
],
"score": 0.999975323677063
},
{
"category_id": 3,
"poly": [
46.55152893066406,
686.12939453125,
638.8861083984375,
686.12939453125,
638.8861083984375,
1803.419189453125,
46.55152893066406,
1803.419189453125
],
"score": 0.999961256980896
},
{
"category_id": 3,
"poly": [
33.684722900390625,
150.77980041503906,
1238.0679931640625,
150.77980041503906,
1238.0679931640625,
524.98291015625,
33.684722900390625,
524.98291015625
],
"score": 0.9999504089355469
},
{
"category_id": 1,
"poly": [
24.685693740844727,
1875.9998779296875,
703.5064697265625,
1875.9998779296875,
703.5064697265625,
2050.7431640625,
24.685693740844727,
2050.7431640625
],
"score": 0.9999105334281921
},
{
"category_id": 1,
"poly": [
750.97705078125,
1252.206787109375,
1430.0809326171875,
1252.206787109375,
1430.0809326171875,
1357.2947998046875,
750.97705078125,
1357.2947998046875
],
"score": 0.999853789806366
},
{
"category_id": 4,
"poly": [
904.842041015625,
1213.027099609375,
1273.5655517578125,
1213.027099609375,
1273.5655517578125,
1242.717529296875,
904.842041015625,
1242.717529296875
],
"score": 0.9995817542076111
},
{
"category_id": 4,
"poly": [
905.3208618164062,
1898.5325927734375,
1273.1282958984375,
1898.5325927734375,
1273.1282958984375,
1928.9906005859375,
905.3208618164062,
1928.9906005859375
],
"score": 0.9986443519592285
},
{
"category_id": 4,
"poly": [
372.0135498046875,
556.02685546875,
1084.9647216796875,
556.02685546875,
1084.9647216796875,
586.6792602539062,
372.0135498046875,
586.6792602539062
],
"score": 0.9985352754592896
},
{
"category_id": 2,
"poly": [
1350.63671875,
79.77919006347656,
1379.6220703125,
79.77919006347656,
1379.6220703125,
99.83788299560547,
1350.63671875,
99.83788299560547
],
"score": 0.9973036646842957
},
{
"category_id": 4,
"poly": [
203.2659912109375,
597.2034912109375,
1251.0240478515625,
597.2034912109375,
1251.0240478515625,
657.985595703125,
203.2659912109375,
657.985595703125
],
"score": 0.9622809886932373
},
{
"category_id": 0,
"poly": [
70.87332916259766,
1834.5714111328125,
657.8504638671875,
1834.5714111328125,
657.8504638671875,
1865.07373046875,
70.87332916259766,
1865.07373046875
],
"score": 0.8580453395843506
},
{
"category_id": 1,
"poly": [
189.0360870361328,
597.2406616210938,
1252.3204345703125,
597.2406616210938,
1252.3204345703125,
658.4781494140625,
189.0360870361328,
658.4781494140625
],
"score": 0.3083903193473816
},
{
"category_id": 13,
"poly": [
1190,
1980,
1206,
1980,
1206,
1997,
1190,
1997
],
"score": 0.51,
"latex": ":"
},
{
"category_id": 13,
"poly": [
1219,
1331,
1235,
1331,
1235,
1348,
1219,
1348
],
"score": 0.49,
"latex": ":"
},
{
"category_id": 13,
"poly": [
798,
2016,
813,
2016,
813,
2033,
798,
2033
],
"score": 0.41,
"latex": ":"
},
{
"category_id": 13,
"poly": [
135,
1991,
148,
1991,
148,
2006,
135,
2006
],
"score": 0.39,
"latex": ":"
},
{
"category_id": 13,
"poly": [
400,
1916,
416,
1916,
416,
1933,
400,
1933
],
"score": 0.38,
"latex": ":"
},
{
"category_id": 13,
"poly": [
1148,
1944,
1162,
1944,
1162,
1961,
1148,
1961
],
"score": 0.31,
"latex": ":"
},
{
"category_id": 15,
"poly": [
798.0,
1943.0,
1147.0,
1943.0,
1147.0,
1968.0,
798.0,
1968.0
],
"score": 0.95,
"text": "Fig 4 SSCP analysis of FHIT exon 4. T"
},
{
"category_id": 15,
"poly": [
1163.0,
1943.0,
1425.0,
1943.0,
1425.0,
1968.0,
1163.0,
1968.0
],
"score": 0.96,
"text": "Tumor tissue ; N :Corresponding"
},
{
"category_id": 15,
"poly": [
755.0,
1979.0,
1189.0,
1979.0,
1189.0,
2004.0,
755.0,
2004.0
],
"score": 0.92,
"text": "normal tissue ; M : PBR322/Hae II Marker ; ssDNA"
},
{
"category_id": 15,
"poly": [
1207.0,
1979.0,
1422.0,
1979.0,
1422.0,
2004.0,
1207.0,
2004.0
],
"score": 0.97,
"text": "Single-stranded DNA ; ds-"
},
{
"category_id": 15,
"poly": [
755.0,
2015.0,
797.0,
2015.0,
797.0,
2038.0,
755.0,
2038.0
],
"score": 1.0,
"text": "DNA"
},
{
"category_id": 15,
"poly": [
814.0,
2015.0,
996.0,
2015.0,
996.0,
2038.0,
814.0,
2038.0
],
"score": 0.98,
"text": "Double-stranded DNA"
},
{
"category_id": 15,
"poly": [
71.0,
1880.0,
698.0,
1880.0,
698.0,
1902.0,
71.0,
1902.0
],
"score": 0.96,
"text": "Fig 2Alterations of PCR amplified products of FHIT exon 3,4,5 and"
},
{
"category_id": 15,
"poly": [
28.0,
1916.0,
399.0,
1916.0,
399.0,
1937.0,
28.0,
1937.0
],
"score": 0.98,
"text": "microsatellite marker D3S1300、D3S1312.A"
},
{
"category_id": 15,
"poly": [
417.0,
1916.0,
701.0,
1916.0,
701.0,
1937.0,
417.0,
1937.0
],
"score": 0.9,
"text": "Deletion of exon5(arrows);B :"
},
{
"category_id": 15,
"poly": [
29.0,
1953.0,
700.0,
1953.0,
700.0,
1974.0,
29.0,
1974.0
],
"score": 0.95,
"text": "Deletion of exon 3 A( arrows);C : Deletion of microsatellite marker D3S1300,"
},
{
"category_id": 15,
"poly": [
28.0,
1989.0,
134.0,
1989.0,
134.0,
2014.0,
28.0,
2014.0
],
"score": 1.0,
"text": "D3S1312.T"
},
{
"category_id": 15,
"poly": [
149.0,
1989.0,
696.0,
1989.0,
696.0,
2014.0,
149.0,
2014.0
],
"score": 0.96,
"text": "Tumor ; N : Corresponding normal tissue ; L : Corresponding lymph"
},
{
"category_id": 15,
"poly": [
30.0,
2027.0,
634.0,
2027.0,
634.0,
2047.0,
30.0,
2047.0
],
"score": 0.94,
"text": "node tissue;M :DL2000 DNA marker;L1:Lewis ;A :A549;S SPAC-1"
},
{
"category_id": 15,
"poly": [
801.0,
1259.0,
1427.0,
1259.0,
1427.0,
1280.0,
801.0,
1280.0
],
"score": 0.94,
"text": "Fig 3SSCP analysis of FHIT exon 3.The arrow indicateda deletion of"
},
{
"category_id": 15,
"poly": [
757.0,
1294.0,
1424.0,
1294.0,
1424.0,
1318.0,
757.0,
1318.0
],
"score": 0.96,
"text": "exon 3 of 41T. T : Tumor tissue ; N : Corresponding normal tissue ; M PBR322/"
},
{
"category_id": 15,
"poly": [
755.0,
1329.0,
1218.0,
1329.0,
1218.0,
1355.0,
755.0,
1355.0
],
"score": 0.95,
"text": "Hae Il Marker / ssDNA : Single-stranded DNA ; dsDNA"
},
{
"category_id": 15,
"poly": [
1236.0,
1329.0,
1418.0,
1329.0,
1418.0,
1355.0,
1236.0,
1355.0
],
"score": 1.0,
"text": "Double-strandedDNA"
},
{
"category_id": 15,
"poly": [
910.0,
1217.0,
1269.0,
1217.0,
1269.0,
1241.0,
910.0,
1241.0
],
"score": 1.0,
"text": "图3FHIT基因外显子3的SSCP分析"
},
{
"category_id": 15,
"poly": [
909.0,
1904.0,
1269.0,
1904.0,
1269.0,
1927.0,
909.0,
1927.0
],
"score": 1.0,
"text": "图4FHIT基因外显子4的SSCP分析"
},
{
"category_id": 15,
"poly": [
374.0,
563.0,
1077.0,
563.0,
1077.0,
583.0,
374.0,
583.0
],
"score": 0.99,
"text": "图1FHIT基因外显子3、4、5、8和微卫星灶的PCR扩增产物琼脂糖电泳图"
},
{
"category_id": 15,
"poly": [
1351.0,
81.0,
1376.0,
81.0,
1376.0,
102.0,
1351.0,
102.0
],
"score": 1.0,
"text": "13"
},
{
"category_id": 15,
"poly": [
207.0,
600.0,
1245.0,
600.0,
1245.0,
624.0,
207.0,
624.0
],
"score": 0.96,
"text": "Fig 1 Agarose electrophoresis of PCR products of exor( A)3 ,4 ,5 ,8 and three microsatellite markers( B)of FHIT gene"
},
{
"category_id": 15,
"poly": [
309.0,
634.0,
1142.0,
634.0,
1142.0,
662.0,
309.0,
662.0
],
"score": 0.97,
"text": "M1 :DL2000 DNA marker ; M2 PBR322/Hae Il marker ; T :Tumor ; N :Corresponding normal tissue"
},
{
"category_id": 15,
"poly": [
73.0,
1840.0,
651.0,
1840.0,
651.0,
1864.0,
73.0,
1864.0
],
"score": 1.0,
"text": "图2FHIT基因外显子和微卫星灶PCR扩增产物缺失电泳图"
},
{
"category_id": 15,
"poly": [
207.0,
600.0,
1245.0,
600.0,
1245.0,
625.0,
207.0,
625.0
],
"score": 0.96,
"text": "Fig 1 Agarose electrophoresis of PCR products of exor A)3 ,4 ,5 ,8 and three microsatellite markers( B)of FHIT gene"
},
{
"category_id": 15,
"poly": [
309.0,
635.0,
1142.0,
635.0,
1142.0,
661.0,
309.0,
661.0
],
"score": 0.97,
"text": "M1 :DL2000 DNA marker ; M2 PBR322/Hae Il marker ; T Tumor ; N :Corresponding normal tissue"
}
],
"page_info": {
"page_no": 0,
"height": 2080,
"width": 1472
}
}
]
This source diff could not be displayed because it is too large. You can view the blob instead.
import json
from magic_pdf.data.read_api import read_local_pdfs
from magic_pdf.model.magic_model import MagicModel
def test_magic_model_image_v2():
datasets = read_local_pdfs('tests/test_model/assets/test_01.pdf')
with open('tests/test_model/assets/test_01.model.json') as f:
model_json = json.load(f)
magic_model = MagicModel(model_json, datasets[0])
imgs = magic_model.get_imgs_v2(0)
print(imgs)
tables = magic_model.get_tables_v2(0)
print(tables)
def test_magic_model_table_v2():
datasets = read_local_pdfs('tests/test_model/assets/test_02.pdf')
with open('tests/test_model/assets/test_02.model.json') as f:
model_json = json.load(f)
magic_model = MagicModel(model_json, datasets[0])
tables = magic_model.get_tables_v2(5)
print(tables)
tables = magic_model.get_tables_v2(8)
print(tables)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment