Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
P
pdf-miner
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Qin Kaijie
pdf-miner
Commits
c1ba9dcb
Unverified
Commit
c1ba9dcb
authored
Oct 23, 2024
by
Xiaomeng Zhao
Committed by
GitHub
Oct 23, 2024
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #773 from myhloli/add-doclayout-yolo
feat(model): add support for DocLayout-YOLO model
parents
efb5851f
1279f2cd
Changes
18
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
365 additions
and
130 deletions
+365
-130
magic-pdf.template.json
magic-pdf.template.json
+1
-1
Constants.py
magic_pdf/libs/Constants.py
+13
-6
boxbase.py
magic_pdf/libs/boxbase.py
+35
-0
config_reader.py
magic_pdf/libs/config_reader.py
+22
-1
doc_analyze_by_custom_model.py
magic_pdf/model/doc_analyze_by_custom_model.py
+38
-14
pdf_extract_kit.py
magic_pdf/model/pdf_extract_kit.py
+82
-43
ppTableModel.py
magic_pdf/model/ppTableModel.py
+2
-2
AbsPipe.py
magic_pdf/pipe/AbsPipe.py
+4
-1
OCRPipe.py
magic_pdf/pipe/OCRPipe.py
+8
-4
TXTPipe.py
magic_pdf/pipe/TXTPipe.py
+8
-4
UNIPipe.py
magic_pdf/pipe/UNIPipe.py
+10
-5
ocr_detect_all_bboxes.py
magic_pdf/pre_proc/ocr_detect_all_bboxes.py
+25
-3
model_configs.yaml
magic_pdf/resources/model_config/model_configs.yaml
+5
-13
common.py
magic_pdf/tools/common.py
+9
-4
user_api.py
magic_pdf/user_api.py
+13
-5
download_models.py
old_docs/download_models.py
+21
-8
download_models_hf.py
old_docs/download_models_hf.py
+29
-9
app.py
projects/gradio_app/app.py
+40
-7
No files found.
magic-pdf.template.json
View file @
c1ba9dcb
...
@@ -7,7 +7,7 @@
...
@@ -7,7 +7,7 @@
"layoutreader-model-dir"
:
"/tmp/layoutreader"
,
"layoutreader-model-dir"
:
"/tmp/layoutreader"
,
"device-mode"
:
"cpu"
,
"device-mode"
:
"cpu"
,
"layout-config"
:
{
"layout-config"
:
{
"model"
:
"
doclayout_yolo
"
"model"
:
"
layoutlmv3
"
},
},
"formula-config"
:
{
"formula-config"
:
{
"mfd_model"
:
"yolo_v8_mfd"
,
"mfd_model"
:
"yolo_v8_mfd"
,
...
...
magic_pdf/libs/Constants.py
View file @
c1ba9dcb
...
@@ -10,18 +10,12 @@ block维度自定义字段
...
@@ -10,18 +10,12 @@ block维度自定义字段
# block中lines是否被删除
# block中lines是否被删除
LINES_DELETED
=
"lines_deleted"
LINES_DELETED
=
"lines_deleted"
# struct eqtable
STRUCT_EQTABLE
=
"struct_eqtable"
# table recognition max time default value
# table recognition max time default value
TABLE_MAX_TIME_VALUE
=
400
TABLE_MAX_TIME_VALUE
=
400
# pp_table_result_max_length
# pp_table_result_max_length
TABLE_MAX_LEN
=
480
TABLE_MAX_LEN
=
480
# pp table structure algorithm
TABLE_MASTER
=
"TableMaster"
# table master structure dict
# table master structure dict
TABLE_MASTER_DICT
=
"table_master_structure_dict.txt"
TABLE_MASTER_DICT
=
"table_master_structure_dict.txt"
...
@@ -38,3 +32,16 @@ REC_MODEL_DIR = "ch_PP-OCRv3_rec_infer"
...
@@ -38,3 +32,16 @@ REC_MODEL_DIR = "ch_PP-OCRv3_rec_infer"
REC_CHAR_DICT
=
"ppocr_keys_v1.txt"
REC_CHAR_DICT
=
"ppocr_keys_v1.txt"
class
MODEL_NAME
:
# pp table structure algorithm
TABLE_MASTER
=
"tablemaster"
# struct eqtable
STRUCT_EQTABLE
=
"struct_eqtable"
DocLayout_YOLO
=
"doclayout_yolo"
LAYOUTLMv3
=
"layoutlmv3"
YOLO_V8_MFD
=
"yolo_v8_mfd"
UniMerNet_v2_Small
=
"unimernet_small"
\ No newline at end of file
magic_pdf/libs/boxbase.py
View file @
c1ba9dcb
...
@@ -445,3 +445,38 @@ def get_overlap_area(bbox1, bbox2):
...
@@ -445,3 +445,38 @@ def get_overlap_area(bbox1, bbox2):
# The area of overlap area
# The area of overlap area
return
(
x_right
-
x_left
)
*
(
y_bottom
-
y_top
)
return
(
x_right
-
x_left
)
*
(
y_bottom
-
y_top
)
def
calculate_vertical_projection_overlap_ratio
(
block1
,
block2
):
"""
Calculate the proportion of the x-axis covered by the vertical projection of two blocks.
Args:
block1 (tuple): Coordinates of the first block (x0, y0, x1, y1).
block2 (tuple): Coordinates of the second block (x0, y0, x1, y1).
Returns:
float: The proportion of the x-axis covered by the vertical projection of the two blocks.
"""
x0_1
,
_
,
x1_1
,
_
=
block1
x0_2
,
_
,
x1_2
,
_
=
block2
# Calculate the intersection of the x-coordinates
x_left
=
max
(
x0_1
,
x0_2
)
x_right
=
min
(
x1_1
,
x1_2
)
if
x_right
<
x_left
:
return
0.0
# Length of the intersection
intersection_length
=
x_right
-
x_left
# Length of the x-axis projection of the first block
block1_length
=
x1_1
-
x0_1
if
block1_length
==
0
:
return
0.0
# Proportion of the x-axis covered by the intersection
# logger.info(f"intersection_length: {intersection_length}, block1_length: {block1_length}")
return
intersection_length
/
block1_length
magic_pdf/libs/config_reader.py
View file @
c1ba9dcb
...
@@ -8,6 +8,7 @@ import os
...
@@ -8,6 +8,7 @@ import os
from
loguru
import
logger
from
loguru
import
logger
from
magic_pdf.libs.Constants
import
MODEL_NAME
from
magic_pdf.libs.commons
import
parse_bucket_key
from
magic_pdf.libs.commons
import
parse_bucket_key
# 定义配置文件名常量
# 定义配置文件名常量
...
@@ -94,10 +95,30 @@ def get_table_recog_config():
...
@@ -94,10 +95,30 @@ def get_table_recog_config():
table_config
=
config
.
get
(
"table-config"
)
table_config
=
config
.
get
(
"table-config"
)
if
table_config
is
None
:
if
table_config
is
None
:
logger
.
warning
(
f
"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default"
)
logger
.
warning
(
f
"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default"
)
return
json
.
loads
(
'{"is_table_recog_enable": false, "max_time": 400
}'
)
return
json
.
loads
(
f
'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}
}'
)
else
:
else
:
return
table_config
return
table_config
def
get_layout_config
():
config
=
read_config
()
layout_config
=
config
.
get
(
"layout-config"
)
if
layout_config
is
None
:
logger
.
warning
(
f
"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default"
)
return
json
.
loads
(
f
'{{"model": "{MODEL_NAME.LAYOUTLMv3}"}}'
)
else
:
return
layout_config
def
get_formula_config
():
config
=
read_config
()
formula_config
=
config
.
get
(
"formula-config"
)
if
formula_config
is
None
:
logger
.
warning
(
f
"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default"
)
return
json
.
loads
(
f
'{{"mfd_model": "{MODEL_NAME.YOLO_V8_MFD}","mfr_model": "{MODEL_NAME.UniMerNet_v2_Small}","enable": true}}'
)
else
:
return
formula_config
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
ak
,
sk
,
endpoint
=
get_s3_config
(
"llm-raw"
)
ak
,
sk
,
endpoint
=
get_s3_config
(
"llm-raw"
)
magic_pdf/model/doc_analyze_by_custom_model.py
View file @
c1ba9dcb
...
@@ -5,7 +5,8 @@ import numpy as np
...
@@ -5,7 +5,8 @@ import numpy as np
from
loguru
import
logger
from
loguru
import
logger
from
magic_pdf.libs.clean_memory
import
clean_memory
from
magic_pdf.libs.clean_memory
import
clean_memory
from
magic_pdf.libs.config_reader
import
get_local_models_dir
,
get_device
,
get_table_recog_config
from
magic_pdf.libs.config_reader
import
get_local_models_dir
,
get_device
,
get_table_recog_config
,
get_layout_config
,
\
get_formula_config
from
magic_pdf.model.model_list
import
MODEL
from
magic_pdf.model.model_list
import
MODEL
import
magic_pdf.model
as
model_config
import
magic_pdf.model
as
model_config
...
@@ -68,14 +69,17 @@ class ModelSingleton:
...
@@ -68,14 +69,17 @@ class ModelSingleton:
cls
.
_instance
=
super
()
.
__new__
(
cls
)
cls
.
_instance
=
super
()
.
__new__
(
cls
)
return
cls
.
_instance
return
cls
.
_instance
def
get_model
(
self
,
ocr
:
bool
,
show_log
:
bool
,
lang
=
None
):
def
get_model
(
self
,
ocr
:
bool
,
show_log
:
bool
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
key
=
(
ocr
,
show_log
,
lang
)
key
=
(
ocr
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
if
key
not
in
self
.
_models
:
if
key
not
in
self
.
_models
:
self
.
_models
[
key
]
=
custom_model_init
(
ocr
=
ocr
,
show_log
=
show_log
,
lang
=
lang
)
self
.
_models
[
key
]
=
custom_model_init
(
ocr
=
ocr
,
show_log
=
show_log
,
lang
=
lang
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
return
self
.
_models
[
key
]
return
self
.
_models
[
key
]
def
custom_model_init
(
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
lang
=
None
):
def
custom_model_init
(
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
model
=
None
model
=
None
if
model_config
.
__model_mode__
==
"lite"
:
if
model_config
.
__model_mode__
==
"lite"
:
...
@@ -95,14 +99,30 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
...
@@ -95,14 +99,30 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
# 从配置文件读取model-dir和device
# 从配置文件读取model-dir和device
local_models_dir
=
get_local_models_dir
()
local_models_dir
=
get_local_models_dir
()
device
=
get_device
()
device
=
get_device
()
layout_config
=
get_layout_config
()
if
layout_model
is
not
None
:
layout_config
[
"model"
]
=
layout_model
formula_config
=
get_formula_config
()
if
formula_enable
is
not
None
:
formula_config
[
"enable"
]
=
formula_enable
table_config
=
get_table_recog_config
()
table_config
=
get_table_recog_config
()
model_input
=
{
"ocr"
:
ocr
,
if
table_enable
is
not
None
:
"show_log"
:
show_log
,
table_config
[
"enable"
]
=
table_enable
"models_dir"
:
local_models_dir
,
"device"
:
device
,
model_input
=
{
"table_config"
:
table_config
,
"ocr"
:
ocr
,
"lang"
:
lang
,
"show_log"
:
show_log
,
}
"models_dir"
:
local_models_dir
,
"device"
:
device
,
"table_config"
:
table_config
,
"layout_config"
:
layout_config
,
"formula_config"
:
formula_config
,
"lang"
:
lang
,
}
custom_model
=
CustomPEKModel
(
**
model_input
)
custom_model
=
CustomPEKModel
(
**
model_input
)
else
:
else
:
logger
.
error
(
"Not allow model_name!"
)
logger
.
error
(
"Not allow model_name!"
)
...
@@ -117,10 +137,14 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
...
@@ -117,10 +137,14 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
def
doc_analyze
(
pdf_bytes
:
bytes
,
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
def
doc_analyze
(
pdf_bytes
:
bytes
,
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
):
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
if
lang
==
""
:
lang
=
None
model_manager
=
ModelSingleton
()
model_manager
=
ModelSingleton
()
custom_model
=
model_manager
.
get_model
(
ocr
,
show_log
,
lang
)
custom_model
=
model_manager
.
get_model
(
ocr
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
with
fitz
.
open
(
"pdf"
,
pdf_bytes
)
as
doc
:
with
fitz
.
open
(
"pdf"
,
pdf_bytes
)
as
doc
:
pdf_page_num
=
doc
.
page_count
pdf_page_num
=
doc
.
page_count
...
...
magic_pdf/model/pdf_extract_kit.py
View file @
c1ba9dcb
...
@@ -25,6 +25,7 @@ try:
...
@@ -25,6 +25,7 @@ try:
from
unimernet.common.config
import
Config
from
unimernet.common.config
import
Config
import
unimernet.tasks
as
tasks
import
unimernet.tasks
as
tasks
from
unimernet.processors
import
load_processor
from
unimernet.processors
import
load_processor
from
doclayout_yolo
import
YOLOv10
except
ImportError
as
e
:
except
ImportError
as
e
:
logger
.
exception
(
e
)
logger
.
exception
(
e
)
...
@@ -41,7 +42,7 @@ from magic_pdf.model.ppTableModel import ppTableModel
...
@@ -41,7 +42,7 @@ from magic_pdf.model.ppTableModel import ppTableModel
def
table_model_init
(
table_model_type
,
model_path
,
max_time
,
_device_
=
'cpu'
):
def
table_model_init
(
table_model_type
,
model_path
,
max_time
,
_device_
=
'cpu'
):
if
table_model_type
==
STRUCT_EQTABLE
:
if
table_model_type
==
MODEL_NAME
.
STRUCT_EQTABLE
:
table_model
=
StructTableModel
(
model_path
,
max_time
=
max_time
,
device
=
_device_
)
table_model
=
StructTableModel
(
model_path
,
max_time
=
max_time
,
device
=
_device_
)
else
:
else
:
config
=
{
config
=
{
...
@@ -77,6 +78,11 @@ def layout_model_init(weight, config_file, device):
...
@@ -77,6 +78,11 @@ def layout_model_init(weight, config_file, device):
return
model
return
model
def
doclayout_yolo_model_init
(
weight
):
model
=
YOLOv10
(
weight
)
return
model
def
ocr_model_init
(
show_log
:
bool
=
False
,
det_db_box_thresh
=
0.3
,
lang
=
None
,
use_dilation
=
True
,
det_db_unclip_ratio
=
2.4
):
def
ocr_model_init
(
show_log
:
bool
=
False
,
det_db_box_thresh
=
0.3
,
lang
=
None
,
use_dilation
=
True
,
det_db_unclip_ratio
=
2.4
):
if
lang
is
not
None
:
if
lang
is
not
None
:
model
=
ModifiedPaddleOCR
(
show_log
=
show_log
,
det_db_box_thresh
=
det_db_box_thresh
,
lang
=
lang
,
use_dilation
=
use_dilation
,
det_db_unclip_ratio
=
det_db_unclip_ratio
)
model
=
ModifiedPaddleOCR
(
show_log
=
show_log
,
det_db_box_thresh
=
det_db_box_thresh
,
lang
=
lang
,
use_dilation
=
use_dilation
,
det_db_unclip_ratio
=
det_db_unclip_ratio
)
...
@@ -114,19 +120,27 @@ class AtomModelSingleton:
...
@@ -114,19 +120,27 @@ class AtomModelSingleton:
return
cls
.
_instance
return
cls
.
_instance
def
get_atom_model
(
self
,
atom_model_name
:
str
,
**
kwargs
):
def
get_atom_model
(
self
,
atom_model_name
:
str
,
**
kwargs
):
if
atom_model_name
not
in
self
.
_models
:
lang
=
kwargs
.
get
(
"lang"
,
None
)
self
.
_models
[
atom_model_name
]
=
atom_model_init
(
model_name
=
atom_model_name
,
**
kwargs
)
layout_model_name
=
kwargs
.
get
(
"layout_model_name"
,
None
)
return
self
.
_models
[
atom_model_name
]
key
=
(
atom_model_name
,
layout_model_name
,
lang
)
if
key
not
in
self
.
_models
:
self
.
_models
[
key
]
=
atom_model_init
(
model_name
=
atom_model_name
,
**
kwargs
)
return
self
.
_models
[
key
]
def
atom_model_init
(
model_name
:
str
,
**
kwargs
):
def
atom_model_init
(
model_name
:
str
,
**
kwargs
):
if
model_name
==
AtomicModel
.
Layout
:
if
model_name
==
AtomicModel
.
Layout
:
atom_model
=
layout_model_init
(
if
kwargs
.
get
(
"layout_model_name"
)
==
MODEL_NAME
.
LAYOUTLMv3
:
kwargs
.
get
(
"layout_weights"
),
atom_model
=
layout_model_init
(
kwargs
.
get
(
"layout_config_file"
),
kwargs
.
get
(
"layout_weights"
),
kwargs
.
get
(
"device"
)
kwargs
.
get
(
"layout_config_file"
),
)
kwargs
.
get
(
"device"
)
)
elif
kwargs
.
get
(
"layout_model_name"
)
==
MODEL_NAME
.
DocLayout_YOLO
:
atom_model
=
doclayout_yolo_model_init
(
kwargs
.
get
(
"doclayout_yolo_weights"
),
)
elif
model_name
==
AtomicModel
.
MFD
:
elif
model_name
==
AtomicModel
.
MFD
:
atom_model
=
mfd_model_init
(
atom_model
=
mfd_model_init
(
kwargs
.
get
(
"mfd_weights"
)
kwargs
.
get
(
"mfd_weights"
)
...
@@ -145,7 +159,7 @@ def atom_model_init(model_name: str, **kwargs):
...
@@ -145,7 +159,7 @@ def atom_model_init(model_name: str, **kwargs):
)
)
elif
model_name
==
AtomicModel
.
Table
:
elif
model_name
==
AtomicModel
.
Table
:
atom_model
=
table_model_init
(
atom_model
=
table_model_init
(
kwargs
.
get
(
"table_model_
typ
e"
),
kwargs
.
get
(
"table_model_
nam
e"
),
kwargs
.
get
(
"table_model_path"
),
kwargs
.
get
(
"table_model_path"
),
kwargs
.
get
(
"table_max_time"
),
kwargs
.
get
(
"table_max_time"
),
kwargs
.
get
(
"device"
)
kwargs
.
get
(
"device"
)
...
@@ -193,23 +207,35 @@ class CustomPEKModel:
...
@@ -193,23 +207,35 @@ class CustomPEKModel:
with
open
(
config_path
,
"r"
,
encoding
=
'utf-8'
)
as
f
:
with
open
(
config_path
,
"r"
,
encoding
=
'utf-8'
)
as
f
:
self
.
configs
=
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
)
self
.
configs
=
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
)
# 初始化解析配置
# 初始化解析配置
self
.
apply_layout
=
kwargs
.
get
(
"apply_layout"
,
self
.
configs
[
"config"
][
"layout"
])
self
.
apply_formula
=
kwargs
.
get
(
"apply_formula"
,
self
.
configs
[
"config"
][
"formula"
])
# layout config
self
.
layout_config
=
kwargs
.
get
(
"layout_config"
)
self
.
layout_model_name
=
self
.
layout_config
.
get
(
"model"
,
MODEL_NAME
.
DocLayout_YOLO
)
# formula config
self
.
formula_config
=
kwargs
.
get
(
"formula_config"
)
self
.
mfd_model_name
=
self
.
formula_config
.
get
(
"mfd_model"
,
MODEL_NAME
.
YOLO_V8_MFD
)
self
.
mfr_model_name
=
self
.
formula_config
.
get
(
"mfr_model"
,
MODEL_NAME
.
UniMerNet_v2_Small
)
self
.
apply_formula
=
self
.
formula_config
.
get
(
"enable"
,
True
)
# table config
# table config
self
.
table_config
=
kwargs
.
get
(
"table_config"
,
self
.
configs
[
"config"
][
"table_config"
]
)
self
.
table_config
=
kwargs
.
get
(
"table_config"
)
self
.
apply_table
=
self
.
table_config
.
get
(
"
is_table_recog_
enable"
,
False
)
self
.
apply_table
=
self
.
table_config
.
get
(
"enable"
,
False
)
self
.
table_max_time
=
self
.
table_config
.
get
(
"max_time"
,
TABLE_MAX_TIME_VALUE
)
self
.
table_max_time
=
self
.
table_config
.
get
(
"max_time"
,
TABLE_MAX_TIME_VALUE
)
self
.
table_model_type
=
self
.
table_config
.
get
(
"model"
,
TABLE_MASTER
)
self
.
table_model_name
=
self
.
table_config
.
get
(
"model"
,
MODEL_NAME
.
TABLE_MASTER
)
# ocr config
self
.
apply_ocr
=
ocr
self
.
apply_ocr
=
ocr
self
.
lang
=
kwargs
.
get
(
"lang"
,
None
)
self
.
lang
=
kwargs
.
get
(
"lang"
,
None
)
logger
.
info
(
logger
.
info
(
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}, lang: {}"
.
format
(
"DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
self
.
apply_layout
,
self
.
apply_formula
,
self
.
apply_ocr
,
self
.
apply_table
,
self
.
lang
"apply_table: {}, table_model: {}, lang: {}"
.
format
(
self
.
layout_model_name
,
self
.
apply_formula
,
self
.
apply_ocr
,
self
.
apply_table
,
self
.
table_model_name
,
self
.
lang
)
)
)
)
assert
self
.
apply_layout
,
"DocAnalysis must contain layout model."
# 初始化解析方案
# 初始化解析方案
self
.
device
=
kwargs
.
get
(
"device"
,
self
.
configs
[
"config"
][
"device"
]
)
self
.
device
=
kwargs
.
get
(
"device"
,
"cpu"
)
logger
.
info
(
"using device: {}"
.
format
(
self
.
device
))
logger
.
info
(
"using device: {}"
.
format
(
self
.
device
))
models_dir
=
kwargs
.
get
(
"models_dir"
,
os
.
path
.
join
(
root_dir
,
"resources"
,
"models"
))
models_dir
=
kwargs
.
get
(
"models_dir"
,
os
.
path
.
join
(
root_dir
,
"resources"
,
"models"
))
logger
.
info
(
"using models_dir: {}"
.
format
(
models_dir
))
logger
.
info
(
"using models_dir: {}"
.
format
(
models_dir
))
...
@@ -218,17 +244,16 @@ class CustomPEKModel:
...
@@ -218,17 +244,16 @@ class CustomPEKModel:
# 初始化公式识别
# 初始化公式识别
if
self
.
apply_formula
:
if
self
.
apply_formula
:
# 初始化公式检测模型
# 初始化公式检测模型
# self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
self
.
mfd_model
=
atom_model_manager
.
get_atom_model
(
self
.
mfd_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
MFD
,
atom_model_name
=
AtomicModel
.
MFD
,
mfd_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
"weights"
][
"mfd"
]))
mfd_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
"weights"
][
self
.
mfd_model_name
]))
)
)
# 初始化公式解析模型
# 初始化公式解析模型
mfr_weight_dir
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
"weights"
][
"mfr"
]))
mfr_weight_dir
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
"weights"
][
self
.
mfr_model_name
]))
mfr_cfg_path
=
str
(
os
.
path
.
join
(
model_config_dir
,
"UniMERNet"
,
"demo.yaml"
))
mfr_cfg_path
=
str
(
os
.
path
.
join
(
model_config_dir
,
"UniMERNet"
,
"demo.yaml"
))
# self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
# self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
self
.
mfr_model
,
self
.
mfr_transform
=
atom_model_manager
.
get_atom_model
(
self
.
mfr_model
,
self
.
mfr_transform
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
MFR
,
atom_model_name
=
AtomicModel
.
MFR
,
mfr_weight_dir
=
mfr_weight_dir
,
mfr_weight_dir
=
mfr_weight_dir
,
...
@@ -237,17 +262,20 @@ class CustomPEKModel:
...
@@ -237,17 +262,20 @@ class CustomPEKModel:
)
)
# 初始化layout模型
# 初始化layout模型
# self.layout_model = Layoutlmv3_Predictor(
if
self
.
layout_model_name
==
MODEL_NAME
.
LAYOUTLMv3
:
# str(os.path.join(models_dir, self.configs['weights']['layout'])),
self
.
layout_model
=
atom_model_manager
.
get_atom_model
(
# str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
atom_model_name
=
AtomicModel
.
Layout
,
# device=self.device
layout_model_name
=
MODEL_NAME
.
LAYOUTLMv3
,
# )
layout_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
self
.
layout_model_name
])),
self
.
layout_model
=
atom_model_manager
.
get_atom_model
(
layout_config_file
=
str
(
os
.
path
.
join
(
model_config_dir
,
"layoutlmv3"
,
"layoutlmv3_base_inference.yaml"
)),
atom_model_name
=
AtomicModel
.
Layout
,
device
=
self
.
device
layout_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
'layout'
])),
)
layout_config_file
=
str
(
os
.
path
.
join
(
model_config_dir
,
"layoutlmv3"
,
"layoutlmv3_base_inference.yaml"
)),
elif
self
.
layout_model_name
==
MODEL_NAME
.
DocLayout_YOLO
:
device
=
self
.
device
self
.
layout_model
=
atom_model_manager
.
get_atom_model
(
)
atom_model_name
=
AtomicModel
.
Layout
,
layout_model_name
=
MODEL_NAME
.
DocLayout_YOLO
,
doclayout_yolo_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
self
.
layout_model_name
]))
)
# 初始化ocr
# 初始化ocr
if
self
.
apply_ocr
:
if
self
.
apply_ocr
:
...
@@ -260,12 +288,10 @@ class CustomPEKModel:
...
@@ -260,12 +288,10 @@ class CustomPEKModel:
)
)
# init table model
# init table model
if
self
.
apply_table
:
if
self
.
apply_table
:
table_model_dir
=
self
.
configs
[
"weights"
][
self
.
table_model_type
]
table_model_dir
=
self
.
configs
[
"weights"
][
self
.
table_model_name
]
# self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
# max_time=self.table_max_time, _device_=self.device)
self
.
table_model
=
atom_model_manager
.
get_atom_model
(
self
.
table_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
Table
,
atom_model_name
=
AtomicModel
.
Table
,
table_model_
type
=
self
.
table_model_typ
e
,
table_model_
name
=
self
.
table_model_nam
e
,
table_model_path
=
str
(
os
.
path
.
join
(
models_dir
,
table_model_dir
)),
table_model_path
=
str
(
os
.
path
.
join
(
models_dir
,
table_model_dir
)),
table_max_time
=
self
.
table_max_time
,
table_max_time
=
self
.
table_max_time
,
device
=
self
.
device
device
=
self
.
device
...
@@ -282,7 +308,21 @@ class CustomPEKModel:
...
@@ -282,7 +308,21 @@ class CustomPEKModel:
# layout检测
# layout检测
layout_start
=
time
.
time
()
layout_start
=
time
.
time
()
layout_res
=
self
.
layout_model
(
image
,
ignore_catids
=
[])
if
self
.
layout_model_name
==
MODEL_NAME
.
LAYOUTLMv3
:
# layoutlmv3
layout_res
=
self
.
layout_model
(
image
,
ignore_catids
=
[])
elif
self
.
layout_model_name
==
MODEL_NAME
.
DocLayout_YOLO
:
# doclayout_yolo
layout_res
=
[]
doclayout_yolo_res
=
self
.
layout_model
.
predict
(
image
,
imgsz
=
1024
,
conf
=
0.15
,
iou
=
0.45
,
verbose
=
True
,
device
=
self
.
device
)[
0
]
for
xyxy
,
conf
,
cla
in
zip
(
doclayout_yolo_res
.
boxes
.
xyxy
.
cpu
(),
doclayout_yolo_res
.
boxes
.
conf
.
cpu
(),
doclayout_yolo_res
.
boxes
.
cls
.
cpu
()):
xmin
,
ymin
,
xmax
,
ymax
=
[
int
(
p
.
item
())
for
p
in
xyxy
]
new_item
=
{
'category_id'
:
int
(
cla
.
item
()),
'poly'
:
[
xmin
,
ymin
,
xmax
,
ymin
,
xmax
,
ymax
,
xmin
,
ymax
],
'score'
:
round
(
float
(
conf
.
item
()),
3
),
}
layout_res
.
append
(
new_item
)
layout_cost
=
round
(
time
.
time
()
-
layout_start
,
2
)
layout_cost
=
round
(
time
.
time
()
-
layout_start
,
2
)
logger
.
info
(
f
"layout detection time: {layout_cost}"
)
logger
.
info
(
f
"layout detection time: {layout_cost}"
)
...
@@ -291,7 +331,7 @@ class CustomPEKModel:
...
@@ -291,7 +331,7 @@ class CustomPEKModel:
if
self
.
apply_formula
:
if
self
.
apply_formula
:
# 公式检测
# 公式检测
mfd_start
=
time
.
time
()
mfd_start
=
time
.
time
()
mfd_res
=
self
.
mfd_model
.
predict
(
image
,
imgsz
=
1888
,
conf
=
0.25
,
iou
=
0.45
,
verbose
=
True
)[
0
]
mfd_res
=
self
.
mfd_model
.
predict
(
image
,
imgsz
=
1888
,
conf
=
0.25
,
iou
=
0.45
,
verbose
=
True
,
device
=
self
.
device
)[
0
]
logger
.
info
(
f
"mfd time: {round(time.time() - mfd_start, 2)}"
)
logger
.
info
(
f
"mfd time: {round(time.time() - mfd_start, 2)}"
)
for
xyxy
,
conf
,
cla
in
zip
(
mfd_res
.
boxes
.
xyxy
.
cpu
(),
mfd_res
.
boxes
.
conf
.
cpu
(),
mfd_res
.
boxes
.
cls
.
cpu
()):
for
xyxy
,
conf
,
cla
in
zip
(
mfd_res
.
boxes
.
xyxy
.
cpu
(),
mfd_res
.
boxes
.
conf
.
cpu
(),
mfd_res
.
boxes
.
cls
.
cpu
()):
xmin
,
ymin
,
xmax
,
ymax
=
[
int
(
p
.
item
())
for
p
in
xyxy
]
xmin
,
ymin
,
xmax
,
ymax
=
[
int
(
p
.
item
())
for
p
in
xyxy
]
...
@@ -303,7 +343,6 @@ class CustomPEKModel:
...
@@ -303,7 +343,6 @@ class CustomPEKModel:
}
}
layout_res
.
append
(
new_item
)
layout_res
.
append
(
new_item
)
latex_filling_list
.
append
(
new_item
)
latex_filling_list
.
append
(
new_item
)
# bbox_img = get_croped_image(pil_img, [xmin, ymin, xmax, ymax])
bbox_img
=
pil_img
.
crop
((
xmin
,
ymin
,
xmax
,
ymax
))
bbox_img
=
pil_img
.
crop
((
xmin
,
ymin
,
xmax
,
ymax
))
mf_image_list
.
append
(
bbox_img
)
mf_image_list
.
append
(
bbox_img
)
...
@@ -405,7 +444,7 @@ class CustomPEKModel:
...
@@ -405,7 +444,7 @@ class CustomPEKModel:
# logger.info("------------------table recognition processing begins-----------------")
# logger.info("------------------table recognition processing begins-----------------")
latex_code
=
None
latex_code
=
None
html_code
=
None
html_code
=
None
if
self
.
table_model_
type
==
STRUCT_EQTABLE
:
if
self
.
table_model_
name
==
MODEL_NAME
.
STRUCT_EQTABLE
:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
latex_code
=
self
.
table_model
.
image2latex
(
new_image
)[
0
]
latex_code
=
self
.
table_model
.
image2latex
(
new_image
)[
0
]
else
:
else
:
...
...
magic_pdf/model/ppTableModel.py
View file @
c1ba9dcb
...
@@ -52,11 +52,11 @@ class ppTableModel(object):
...
@@ -52,11 +52,11 @@ class ppTableModel(object):
rec_model_dir
=
os
.
path
.
join
(
model_dir
,
REC_MODEL_DIR
)
rec_model_dir
=
os
.
path
.
join
(
model_dir
,
REC_MODEL_DIR
)
rec_char_dict_path
=
os
.
path
.
join
(
model_dir
,
REC_CHAR_DICT
)
rec_char_dict_path
=
os
.
path
.
join
(
model_dir
,
REC_CHAR_DICT
)
device
=
kwargs
.
get
(
"device"
,
"cpu"
)
device
=
kwargs
.
get
(
"device"
,
"cpu"
)
use_gpu
=
True
if
device
==
"cuda"
else
False
use_gpu
=
True
if
device
.
startswith
(
"cuda"
)
else
False
config
=
{
config
=
{
"use_gpu"
:
use_gpu
,
"use_gpu"
:
use_gpu
,
"table_max_len"
:
kwargs
.
get
(
"table_max_len"
,
TABLE_MAX_LEN
),
"table_max_len"
:
kwargs
.
get
(
"table_max_len"
,
TABLE_MAX_LEN
),
"table_algorithm"
:
TABLE_MASTER
,
"table_algorithm"
:
"TableMaster"
,
"table_model_dir"
:
table_model_dir
,
"table_model_dir"
:
table_model_dir
,
"table_char_dict_path"
:
table_char_dict_path
,
"table_char_dict_path"
:
table_char_dict_path
,
"det_model_dir"
:
det_model_dir
,
"det_model_dir"
:
det_model_dir
,
...
...
magic_pdf/pipe/AbsPipe.py
View file @
c1ba9dcb
...
@@ -17,7 +17,7 @@ class AbsPipe(ABC):
...
@@ -17,7 +17,7 @@ class AbsPipe(ABC):
PIP_TXT
=
"txt"
PIP_TXT
=
"txt"
def
__init__
(
self
,
pdf_bytes
:
bytes
,
model_list
:
list
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
def
__init__
(
self
,
pdf_bytes
:
bytes
,
model_list
:
list
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
):
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
self
.
pdf_bytes
=
pdf_bytes
self
.
pdf_bytes
=
pdf_bytes
self
.
model_list
=
model_list
self
.
model_list
=
model_list
self
.
image_writer
=
image_writer
self
.
image_writer
=
image_writer
...
@@ -26,6 +26,9 @@ class AbsPipe(ABC):
...
@@ -26,6 +26,9 @@ class AbsPipe(ABC):
self
.
start_page_id
=
start_page_id
self
.
start_page_id
=
start_page_id
self
.
end_page_id
=
end_page_id
self
.
end_page_id
=
end_page_id
self
.
lang
=
lang
self
.
lang
=
lang
self
.
layout_model
=
layout_model
self
.
formula_enable
=
formula_enable
self
.
table_enable
=
table_enable
def
get_compress_pdf_mid_data
(
self
):
def
get_compress_pdf_mid_data
(
self
):
return
JsonCompressor
.
compress_json
(
self
.
pdf_mid_data
)
return
JsonCompressor
.
compress_json
(
self
.
pdf_mid_data
)
...
...
magic_pdf/pipe/OCRPipe.py
View file @
c1ba9dcb
...
@@ -10,8 +10,10 @@ from magic_pdf.user_api import parse_ocr_pdf
...
@@ -10,8 +10,10 @@ from magic_pdf.user_api import parse_ocr_pdf
class
OCRPipe
(
AbsPipe
):
class
OCRPipe
(
AbsPipe
):
def
__init__
(
self
,
pdf_bytes
:
bytes
,
model_list
:
list
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
def
__init__
(
self
,
pdf_bytes
:
bytes
,
model_list
:
list
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
):
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
super
()
.
__init__
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
)
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
super
()
.
__init__
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
def
pipe_classify
(
self
):
def
pipe_classify
(
self
):
pass
pass
...
@@ -19,12 +21,14 @@ class OCRPipe(AbsPipe):
...
@@ -19,12 +21,14 @@ class OCRPipe(AbsPipe):
def
pipe_analyze
(
self
):
def
pipe_analyze
(
self
):
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
True
,
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
True
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_parse
(
self
):
def
pipe_parse
(
self
):
self
.
pdf_mid_data
=
parse_ocr_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
self
.
pdf_mid_data
=
parse_ocr_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_mk_uni_format
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
):
def
pipe_mk_uni_format
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
):
result
=
super
()
.
pipe_mk_uni_format
(
img_parent_path
,
drop_mode
)
result
=
super
()
.
pipe_mk_uni_format
(
img_parent_path
,
drop_mode
)
...
...
magic_pdf/pipe/TXTPipe.py
View file @
c1ba9dcb
...
@@ -11,8 +11,10 @@ from magic_pdf.user_api import parse_txt_pdf
...
@@ -11,8 +11,10 @@ from magic_pdf.user_api import parse_txt_pdf
class
TXTPipe
(
AbsPipe
):
class
TXTPipe
(
AbsPipe
):
def
__init__
(
self
,
pdf_bytes
:
bytes
,
model_list
:
list
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
def
__init__
(
self
,
pdf_bytes
:
bytes
,
model_list
:
list
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
):
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
super
()
.
__init__
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
)
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
super
()
.
__init__
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
def
pipe_classify
(
self
):
def
pipe_classify
(
self
):
pass
pass
...
@@ -20,12 +22,14 @@ class TXTPipe(AbsPipe):
...
@@ -20,12 +22,14 @@ class TXTPipe(AbsPipe):
def
pipe_analyze
(
self
):
def
pipe_analyze
(
self
):
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
False
,
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
False
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_parse
(
self
):
def
pipe_parse
(
self
):
self
.
pdf_mid_data
=
parse_txt_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
self
.
pdf_mid_data
=
parse_txt_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_mk_uni_format
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
):
def
pipe_mk_uni_format
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
):
result
=
super
()
.
pipe_mk_uni_format
(
img_parent_path
,
drop_mode
)
result
=
super
()
.
pipe_mk_uni_format
(
img_parent_path
,
drop_mode
)
...
...
magic_pdf/pipe/UNIPipe.py
View file @
c1ba9dcb
...
@@ -14,9 +14,11 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
...
@@ -14,9 +14,11 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
class
UNIPipe
(
AbsPipe
):
class
UNIPipe
(
AbsPipe
):
def
__init__
(
self
,
pdf_bytes
:
bytes
,
jso_useful_key
:
dict
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
def
__init__
(
self
,
pdf_bytes
:
bytes
,
jso_useful_key
:
dict
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
):
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
self
.
pdf_type
=
jso_useful_key
[
"_pdf_type"
]
self
.
pdf_type
=
jso_useful_key
[
"_pdf_type"
]
super
()
.
__init__
(
pdf_bytes
,
jso_useful_key
[
"model_list"
],
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
)
super
()
.
__init__
(
pdf_bytes
,
jso_useful_key
[
"model_list"
],
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
if
len
(
self
.
model_list
)
==
0
:
if
len
(
self
.
model_list
)
==
0
:
self
.
input_model_is_empty
=
True
self
.
input_model_is_empty
=
True
else
:
else
:
...
@@ -29,18 +31,21 @@ class UNIPipe(AbsPipe):
...
@@ -29,18 +31,21 @@ class UNIPipe(AbsPipe):
if
self
.
pdf_type
==
self
.
PIP_TXT
:
if
self
.
pdf_type
==
self
.
PIP_TXT
:
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
False
,
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
False
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
elif
self
.
pdf_type
==
self
.
PIP_OCR
:
elif
self
.
pdf_type
==
self
.
PIP_OCR
:
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
True
,
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
True
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_parse
(
self
):
def
pipe_parse
(
self
):
if
self
.
pdf_type
==
self
.
PIP_TXT
:
if
self
.
pdf_type
==
self
.
PIP_TXT
:
self
.
pdf_mid_data
=
parse_union_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
self
.
pdf_mid_data
=
parse_union_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
input_model_is_empty
=
self
.
input_model_is_empty
,
is_debug
=
self
.
is_debug
,
input_model_is_empty
=
self
.
input_model_is_empty
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
elif
self
.
pdf_type
==
self
.
PIP_OCR
:
elif
self
.
pdf_type
==
self
.
PIP_OCR
:
self
.
pdf_mid_data
=
parse_ocr_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
self
.
pdf_mid_data
=
parse_ocr_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
is_debug
=
self
.
is_debug
,
...
...
magic_pdf/pre_proc/ocr_detect_all_bboxes.py
View file @
c1ba9dcb
from
loguru
import
logger
from
loguru
import
logger
from
magic_pdf.libs.boxbase
import
get_minbox_if_overlap_by_ratio
,
calculate_overlap_area_in_bbox1_area_ratio
,
\
from
magic_pdf.libs.boxbase
import
get_minbox_if_overlap_by_ratio
,
calculate_overlap_area_in_bbox1_area_ratio
,
\
calculate_iou
calculate_iou
,
calculate_vertical_projection_overlap_ratio
from
magic_pdf.libs.drop_tag
import
DropTag
from
magic_pdf.libs.drop_tag
import
DropTag
from
magic_pdf.libs.ocr_content_type
import
BlockType
from
magic_pdf.libs.ocr_content_type
import
BlockType
from
magic_pdf.pre_proc.remove_bbox_overlap
import
remove_overlap_between_bbox_for_block
from
magic_pdf.pre_proc.remove_bbox_overlap
import
remove_overlap_between_bbox_for_block
...
@@ -97,12 +97,20 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b
...
@@ -97,12 +97,20 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b
# 通过后续大框套小框逻辑删除
# 通过后续大框套小框逻辑删除
'''discarded_blocks中只保留宽度超过1/3页面宽度的,高度超过10的,处于页面下半50
%
区域的(限定footnote)'''
'''discarded_blocks中只保留宽度超过1/3页面宽度的,高度超过10的,处于页面下半50
%
区域的(限定footnote)'''
footnote_blocks
=
[]
for
discarded
in
discarded_blocks
:
for
discarded
in
discarded_blocks
:
x0
,
y0
,
x1
,
y1
=
discarded
[
'bbox'
]
x0
,
y0
,
x1
,
y1
=
discarded
[
'bbox'
]
all_discarded_blocks
.
append
([
x0
,
y0
,
x1
,
y1
,
None
,
None
,
None
,
BlockType
.
Discarded
,
None
,
None
,
None
,
None
,
discarded
[
"score"
]])
all_discarded_blocks
.
append
([
x0
,
y0
,
x1
,
y1
,
None
,
None
,
None
,
BlockType
.
Discarded
,
None
,
None
,
None
,
None
,
discarded
[
"score"
]])
# 将footnote加入到all_bboxes中,用来计算layout
# 将footnote加入到all_bboxes中,用来计算layout
# if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
if
(
x1
-
x0
)
>
(
page_w
/
3
)
and
(
y1
-
y0
)
>
10
and
y0
>
(
page_h
/
2
):
# all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Footnote, None, None, None, None, discarded["score"]])
footnote_blocks
.
append
([
x0
,
y0
,
x1
,
y1
])
'''移除在footnote下面的任何框'''
need_remove_blocks
=
find_blocks_under_footnote
(
all_bboxes
,
footnote_blocks
)
if
len
(
need_remove_blocks
)
>
0
:
for
block
in
need_remove_blocks
:
all_bboxes
.
remove
(
block
)
all_discarded_blocks
.
append
(
block
)
'''经过以上处理后,还存在大框套小框的情况,则删除小框'''
'''经过以上处理后,还存在大框套小框的情况,则删除小框'''
all_bboxes
=
remove_overlaps_min_blocks
(
all_bboxes
)
all_bboxes
=
remove_overlaps_min_blocks
(
all_bboxes
)
...
@@ -113,6 +121,20 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b
...
@@ -113,6 +121,20 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b
return
all_bboxes
,
all_discarded_blocks
return
all_bboxes
,
all_discarded_blocks
def
find_blocks_under_footnote
(
all_bboxes
,
footnote_blocks
):
need_remove_blocks
=
[]
for
block
in
all_bboxes
:
block_x0
,
block_y0
,
block_x1
,
block_y1
=
block
[:
4
]
for
footnote_bbox
in
footnote_blocks
:
footnote_x0
,
footnote_y0
,
footnote_x1
,
footnote_y1
=
footnote_bbox
# 如果footnote的纵向投影覆盖了block的纵向投影的80%且block的y0大于等于footnote的y1
if
block_y0
>=
footnote_y1
and
calculate_vertical_projection_overlap_ratio
((
block_x0
,
block_y0
,
block_x1
,
block_y1
),
footnote_bbox
)
>=
0.8
:
if
block
not
in
need_remove_blocks
:
need_remove_blocks
.
append
(
block
)
break
return
need_remove_blocks
def
fix_interline_equation_overlap_text_blocks_with_hi_iou
(
all_bboxes
):
def
fix_interline_equation_overlap_text_blocks_with_hi_iou
(
all_bboxes
):
# 先提取所有text和interline block
# 先提取所有text和interline block
text_blocks
=
[]
text_blocks
=
[]
...
...
magic_pdf/resources/model_config/model_configs.yaml
View file @
c1ba9dcb
config
:
device
:
cpu
layout
:
True
formula
:
True
table_config
:
model
:
TableMaster
is_table_recog_enable
:
False
max_time
:
400
weights
:
weights
:
layout
:
Layout/model_final.pth
layoutlmv3
:
Layout/LayoutLMv3/model_final.pth
mfd
:
MFD/weights.pt
doclayout_yolo
:
Layout/YOLO/doclayout_yolo_ft.pt
mfr
:
MFR/unimernet_small
yolo_v8_mfd
:
MFD/YOLO/yolo_v8_ft.pt
unimernet_small
:
MFR/unimernet_small
struct_eqtable
:
TabRec/StructEqTable
struct_eqtable
:
TabRec/StructEqTable
TableMaster
:
TabRec/TableMaster
tablemaster
:
TabRec/TableMaster
\ No newline at end of file
\ No newline at end of file
magic_pdf/tools/common.py
View file @
c1ba9dcb
...
@@ -46,10 +46,12 @@ def do_parse(
...
@@ -46,10 +46,12 @@ def do_parse(
start_page_id
=
0
,
start_page_id
=
0
,
end_page_id
=
None
,
end_page_id
=
None
,
lang
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
,
):
):
if
debug_able
:
if
debug_able
:
logger
.
warning
(
'debug mode is on'
)
logger
.
warning
(
'debug mode is on'
)
# f_dump_content_list = True
f_draw_model_bbox
=
True
f_draw_model_bbox
=
True
f_draw_line_sort_bbox
=
True
f_draw_line_sort_bbox
=
True
...
@@ -64,13 +66,16 @@ def do_parse(
...
@@ -64,13 +66,16 @@ def do_parse(
if
parse_method
==
'auto'
:
if
parse_method
==
'auto'
:
jso_useful_key
=
{
'_pdf_type'
:
''
,
'model_list'
:
model_list
}
jso_useful_key
=
{
'_pdf_type'
:
''
,
'model_list'
:
model_list
}
pipe
=
UNIPipe
(
pdf_bytes
,
jso_useful_key
,
image_writer
,
is_debug
=
True
,
pipe
=
UNIPipe
(
pdf_bytes
,
jso_useful_key
,
image_writer
,
is_debug
=
True
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
)
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
elif
parse_method
==
'txt'
:
elif
parse_method
==
'txt'
:
pipe
=
TXTPipe
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
=
True
,
pipe
=
TXTPipe
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
=
True
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
)
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
elif
parse_method
==
'ocr'
:
elif
parse_method
==
'ocr'
:
pipe
=
OCRPipe
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
=
True
,
pipe
=
OCRPipe
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
=
True
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
)
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
else
:
else
:
logger
.
error
(
'unknown parse method'
)
logger
.
error
(
'unknown parse method'
)
exit
(
1
)
exit
(
1
)
...
...
magic_pdf/user_api.py
View file @
c1ba9dcb
...
@@ -101,11 +101,19 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
...
@@ -101,11 +101,19 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
if
pdf_info_dict
is
None
or
pdf_info_dict
.
get
(
"_need_drop"
,
False
):
if
pdf_info_dict
is
None
or
pdf_info_dict
.
get
(
"_need_drop"
,
False
):
logger
.
warning
(
f
"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr"
)
logger
.
warning
(
f
"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr"
)
if
input_model_is_empty
:
if
input_model_is_empty
:
pdf_models
=
doc_analyze
(
pdf_bytes
,
layout_model
=
kwargs
.
get
(
"layout_model"
,
None
)
ocr
=
True
,
formula_enable
=
kwargs
.
get
(
"formula_enable"
,
None
)
start_page_id
=
start_page_id
,
table_enable
=
kwargs
.
get
(
"table_enable"
,
None
)
end_page_id
=
end_page_id
,
pdf_models
=
doc_analyze
(
lang
=
lang
)
pdf_bytes
,
ocr
=
True
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
,
)
pdf_info_dict
=
parse_pdf
(
parse_pdf_by_ocr
)
pdf_info_dict
=
parse_pdf
(
parse_pdf_by_ocr
)
if
pdf_info_dict
is
None
:
if
pdf_info_dict
is
None
:
raise
Exception
(
"Both parse_pdf_by_txt and parse_pdf_by_ocr failed."
)
raise
Exception
(
"Both parse_pdf_by_txt and parse_pdf_by_ocr failed."
)
...
...
old_docs/download_models.py
View file @
c1ba9dcb
...
@@ -5,16 +5,21 @@ import requests
...
@@ -5,16 +5,21 @@ import requests
from
modelscope
import
snapshot_download
from
modelscope
import
snapshot_download
def
download_json
(
url
):
# 下载JSON文件
response
=
requests
.
get
(
url
)
response
.
raise_for_status
()
# 检查请求是否成功
return
response
.
json
()
def
download_and_modify_json
(
url
,
local_filename
,
modifications
):
def
download_and_modify_json
(
url
,
local_filename
,
modifications
):
if
os
.
path
.
exists
(
local_filename
):
if
os
.
path
.
exists
(
local_filename
):
data
=
json
.
load
(
open
(
local_filename
))
data
=
json
.
load
(
open
(
local_filename
))
config_version
=
data
.
get
(
'config_version'
,
'0.0.0'
)
if
config_version
<
'1.0.0'
:
data
=
download_json
(
url
)
else
:
else
:
# 下载JSON文件
data
=
download_json
(
url
)
response
=
requests
.
get
(
url
)
response
.
raise_for_status
()
# 检查请求是否成功
# 解析JSON内容
data
=
response
.
json
()
# 修改内容
# 修改内容
for
key
,
value
in
modifications
.
items
():
for
key
,
value
in
modifications
.
items
():
...
@@ -26,13 +31,21 @@ def download_and_modify_json(url, local_filename, modifications):
...
@@ -26,13 +31,21 @@ def download_and_modify_json(url, local_filename, modifications):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
model_dir
=
snapshot_download
(
'opendatalab/PDF-Extract-Kit'
)
mineru_patterns
=
[
"models/Layout/LayoutLMv3/*"
,
"models/Layout/YOLO/*"
,
"models/MFD/YOLO/*"
,
"models/MFR/unimernet_small/*"
,
"models/TabRec/TableMaster/*"
,
"models/TabRec/StructEqTable/*"
,
]
model_dir
=
snapshot_download
(
'opendatalab/PDF-Extract-Kit-1.0'
,
allow_patterns
=
mineru_patterns
)
layoutreader_model_dir
=
snapshot_download
(
'ppaanngggg/layoutreader'
)
layoutreader_model_dir
=
snapshot_download
(
'ppaanngggg/layoutreader'
)
model_dir
=
model_dir
+
'/models'
model_dir
=
model_dir
+
'/models'
print
(
f
'model_dir is: {model_dir}'
)
print
(
f
'model_dir is: {model_dir}'
)
print
(
f
'layoutreader_model_dir is: {layoutreader_model_dir}'
)
print
(
f
'layoutreader_model_dir is: {layoutreader_model_dir}'
)
json_url
=
'https://gitee.com/myhloli/MinerU/raw/
master
/magic-pdf.template.json'
json_url
=
'https://gitee.com/myhloli/MinerU/raw/
dev
/magic-pdf.template.json'
config_file_name
=
'magic-pdf.json'
config_file_name
=
'magic-pdf.json'
home_dir
=
os
.
path
.
expanduser
(
'~'
)
home_dir
=
os
.
path
.
expanduser
(
'~'
)
config_file
=
os
.
path
.
join
(
home_dir
,
config_file_name
)
config_file
=
os
.
path
.
join
(
home_dir
,
config_file_name
)
...
...
old_docs/download_models_hf.py
View file @
c1ba9dcb
...
@@ -5,16 +5,21 @@ import requests
...
@@ -5,16 +5,21 @@ import requests
from
huggingface_hub
import
snapshot_download
from
huggingface_hub
import
snapshot_download
def
download_json
(
url
):
# 下载JSON文件
response
=
requests
.
get
(
url
)
response
.
raise_for_status
()
# 检查请求是否成功
return
response
.
json
()
def
download_and_modify_json
(
url
,
local_filename
,
modifications
):
def
download_and_modify_json
(
url
,
local_filename
,
modifications
):
if
os
.
path
.
exists
(
local_filename
):
if
os
.
path
.
exists
(
local_filename
):
data
=
json
.
load
(
open
(
local_filename
))
data
=
json
.
load
(
open
(
local_filename
))
config_version
=
data
.
get
(
'config_version'
,
'0.0.0'
)
if
config_version
<
'1.0.0'
:
data
=
download_json
(
url
)
else
:
else
:
# 下载JSON文件
data
=
download_json
(
url
)
response
=
requests
.
get
(
url
)
response
.
raise_for_status
()
# 检查请求是否成功
# 解析JSON内容
data
=
response
.
json
()
# 修改内容
# 修改内容
for
key
,
value
in
modifications
.
items
():
for
key
,
value
in
modifications
.
items
():
...
@@ -26,13 +31,28 @@ def download_and_modify_json(url, local_filename, modifications):
...
@@ -26,13 +31,28 @@ def download_and_modify_json(url, local_filename, modifications):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
model_dir
=
snapshot_download
(
'opendatalab/PDF-Extract-Kit'
)
layoutreader_model_dir
=
snapshot_download
(
'hantian/layoutreader'
)
mineru_patterns
=
[
"models/Layout/LayoutLMv3/*"
,
"models/Layout/YOLO/*"
,
"models/MFD/YOLO/*"
,
"models/MFR/unimernet_small/*"
,
"models/TabRec/TableMaster/*"
,
"models/TabRec/StructEqTable/*"
,
]
model_dir
=
snapshot_download
(
'opendatalab/PDF-Extract-Kit-1.0'
,
allow_patterns
=
mineru_patterns
)
layoutreader_pattern
=
[
"*.json"
,
"*.safetensors"
,
]
layoutreader_model_dir
=
snapshot_download
(
'hantian/layoutreader'
,
allow_patterns
=
layoutreader_pattern
)
model_dir
=
model_dir
+
'/models'
model_dir
=
model_dir
+
'/models'
print
(
f
'model_dir is: {model_dir}'
)
print
(
f
'model_dir is: {model_dir}'
)
print
(
f
'layoutreader_model_dir is: {layoutreader_model_dir}'
)
print
(
f
'layoutreader_model_dir is: {layoutreader_model_dir}'
)
json_url
=
'https://github.com/opendatalab/MinerU/raw/
master
/magic-pdf.template.json'
json_url
=
'https://github.com/opendatalab/MinerU/raw/
dev
/magic-pdf.template.json'
config_file_name
=
'magic-pdf.json'
config_file_name
=
'magic-pdf.json'
home_dir
=
os
.
path
.
expanduser
(
'~'
)
home_dir
=
os
.
path
.
expanduser
(
'~'
)
config_file
=
os
.
path
.
join
(
home_dir
,
config_file_name
)
config_file
=
os
.
path
.
join
(
home_dir
,
config_file_name
)
...
...
projects/gradio_app/app.py
View file @
c1ba9dcb
...
@@ -23,7 +23,7 @@ def read_fn(path):
...
@@ -23,7 +23,7 @@ def read_fn(path):
return
disk_rw
.
read
(
os
.
path
.
basename
(
path
),
AbsReaderWriter
.
MODE_BIN
)
return
disk_rw
.
read
(
os
.
path
.
basename
(
path
),
AbsReaderWriter
.
MODE_BIN
)
def
parse_pdf
(
doc_path
,
output_dir
,
end_page_id
,
is_ocr
):
def
parse_pdf
(
doc_path
,
output_dir
,
end_page_id
,
is_ocr
,
layout_mode
,
formula_enable
,
table_enable
,
language
):
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
try
:
try
:
...
@@ -42,6 +42,10 @@ def parse_pdf(doc_path, output_dir, end_page_id, is_ocr):
...
@@ -42,6 +42,10 @@ def parse_pdf(doc_path, output_dir, end_page_id, is_ocr):
parse_method
,
parse_method
,
False
,
False
,
end_page_id
=
end_page_id
,
end_page_id
=
end_page_id
,
layout_model
=
layout_mode
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
,
lang
=
language
,
)
)
return
local_md_dir
,
file_name
return
local_md_dir
,
file_name
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -93,9 +97,10 @@ def replace_image_with_base64(markdown_text, image_dir_path):
...
@@ -93,9 +97,10 @@ def replace_image_with_base64(markdown_text, image_dir_path):
return
re
.
sub
(
pattern
,
replace
,
markdown_text
)
return
re
.
sub
(
pattern
,
replace
,
markdown_text
)
def
to_markdown
(
file_path
,
end_pages
,
is_ocr
):
def
to_markdown
(
file_path
,
end_pages
,
is_ocr
,
layout_mode
,
formula_enable
,
table_enable
,
language
):
# 获取识别的md文件以及压缩包文件路径
# 获取识别的md文件以及压缩包文件路径
local_md_dir
,
file_name
=
parse_pdf
(
file_path
,
'./output'
,
end_pages
-
1
,
is_ocr
)
local_md_dir
,
file_name
=
parse_pdf
(
file_path
,
'./output'
,
end_pages
-
1
,
is_ocr
,
layout_mode
,
formula_enable
,
table_enable
,
language
)
archive_zip_path
=
os
.
path
.
join
(
"./output"
,
compute_sha256
(
local_md_dir
)
+
".zip"
)
archive_zip_path
=
os
.
path
.
join
(
"./output"
,
compute_sha256
(
local_md_dir
)
+
".zip"
)
zip_archive_success
=
compress_directory_to_zip
(
local_md_dir
,
archive_zip_path
)
zip_archive_success
=
compress_directory_to_zip
(
local_md_dir
,
archive_zip_path
)
if
zip_archive_success
==
0
:
if
zip_archive_success
==
0
:
...
@@ -138,6 +143,27 @@ with open("header.html", "r") as file:
...
@@ -138,6 +143,27 @@ with open("header.html", "r") as file:
header
=
file
.
read
()
header
=
file
.
read
()
latin_lang
=
[
'af'
,
'az'
,
'bs'
,
'cs'
,
'cy'
,
'da'
,
'de'
,
'es'
,
'et'
,
'fr'
,
'ga'
,
'hr'
,
'hu'
,
'id'
,
'is'
,
'it'
,
'ku'
,
'la'
,
'lt'
,
'lv'
,
'mi'
,
'ms'
,
'mt'
,
'nl'
,
'no'
,
'oc'
,
'pi'
,
'pl'
,
'pt'
,
'ro'
,
'rs_latin'
,
'sk'
,
'sl'
,
'sq'
,
'sv'
,
'sw'
,
'tl'
,
'tr'
,
'uz'
,
'vi'
,
'french'
,
'german'
]
arabic_lang
=
[
'ar'
,
'fa'
,
'ug'
,
'ur'
]
cyrillic_lang
=
[
'ru'
,
'rs_cyrillic'
,
'be'
,
'bg'
,
'uk'
,
'mn'
,
'abq'
,
'ady'
,
'kbd'
,
'ava'
,
'dar'
,
'inh'
,
'che'
,
'lbe'
,
'lez'
,
'tab'
]
devanagari_lang
=
[
'hi'
,
'mr'
,
'ne'
,
'bh'
,
'mai'
,
'ang'
,
'bho'
,
'mah'
,
'sck'
,
'new'
,
'gom'
,
'sa'
,
'bgc'
]
other_lang
=
[
'ch'
,
'en'
,
'korean'
,
'japan'
,
'chinese_cht'
,
'ta'
,
'te'
,
'ka'
]
all_lang
=
[
""
]
all_lang
.
extend
([
*
other_lang
,
*
latin_lang
,
*
arabic_lang
,
*
cyrillic_lang
,
*
devanagari_lang
])
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
with
gr
.
Blocks
()
as
demo
:
with
gr
.
Blocks
()
as
demo
:
gr
.
HTML
(
header
)
gr
.
HTML
(
header
)
...
@@ -145,8 +171,14 @@ if __name__ == "__main__":
...
@@ -145,8 +171,14 @@ if __name__ == "__main__":
with
gr
.
Column
(
variant
=
'panel'
,
scale
=
5
):
with
gr
.
Column
(
variant
=
'panel'
,
scale
=
5
):
pdf_show
=
gr
.
Markdown
()
pdf_show
=
gr
.
Markdown
()
max_pages
=
gr
.
Slider
(
1
,
10
,
5
,
step
=
1
,
label
=
"Max convert pages"
)
max_pages
=
gr
.
Slider
(
1
,
10
,
5
,
step
=
1
,
label
=
"Max convert pages"
)
with
gr
.
Row
()
as
bu_flow
:
with
gr
.
Row
():
is_ocr
=
gr
.
Checkbox
(
label
=
"Force enable OCR"
)
layout_mode
=
gr
.
Dropdown
([
"layoutlmv3"
,
"doclayout_yolo"
],
label
=
"Layout model"
,
value
=
"layoutlmv3"
)
language
=
gr
.
Dropdown
(
all_lang
,
label
=
"Language"
,
value
=
""
)
with
gr
.
Row
():
formula_enable
=
gr
.
Checkbox
(
label
=
"Enable formula recognition"
,
value
=
True
)
is_ocr
=
gr
.
Checkbox
(
label
=
"Force enable OCR"
,
value
=
False
)
table_enable
=
gr
.
Checkbox
(
label
=
"Enable table recognition(test)"
,
value
=
False
)
with
gr
.
Row
():
change_bu
=
gr
.
Button
(
"Convert"
)
change_bu
=
gr
.
Button
(
"Convert"
)
clear_bu
=
gr
.
ClearButton
([
pdf_show
],
value
=
"Clear"
)
clear_bu
=
gr
.
ClearButton
([
pdf_show
],
value
=
"Clear"
)
pdf_show
=
PDF
(
label
=
"Please upload pdf"
,
interactive
=
True
,
height
=
800
)
pdf_show
=
PDF
(
label
=
"Please upload pdf"
,
interactive
=
True
,
height
=
800
)
...
@@ -166,7 +198,8 @@ if __name__ == "__main__":
...
@@ -166,7 +198,8 @@ if __name__ == "__main__":
latex_delimiters
=
latex_delimiters
,
line_breaks
=
True
)
latex_delimiters
=
latex_delimiters
,
line_breaks
=
True
)
with
gr
.
Tab
(
"Markdown text"
):
with
gr
.
Tab
(
"Markdown text"
):
md_text
=
gr
.
TextArea
(
lines
=
45
,
show_copy_button
=
True
)
md_text
=
gr
.
TextArea
(
lines
=
45
,
show_copy_button
=
True
)
change_bu
.
click
(
fn
=
to_markdown
,
inputs
=
[
pdf_show
,
max_pages
,
is_ocr
],
outputs
=
[
md
,
md_text
,
output_file
,
pdf_show
])
change_bu
.
click
(
fn
=
to_markdown
,
inputs
=
[
pdf_show
,
max_pages
,
is_ocr
,
layout_mode
,
formula_enable
,
table_enable
,
language
],
outputs
=
[
md
,
md_text
,
output_file
,
pdf_show
])
clear_bu
.
add
([
md
,
pdf_show
,
md_text
,
output_file
,
is_ocr
])
clear_bu
.
add
([
md
,
pdf_show
,
md_text
,
output_file
,
is_ocr
])
demo
.
launch
()
demo
.
launch
(
server_name
=
"0.0.0.0"
)
\ No newline at end of file
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment