Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
P
pdf-miner
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Qin Kaijie
pdf-miner
Commits
831db2e0
Commit
831db2e0
authored
Jul 09, 2024
by
myhloli
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update:Complete the parsing logic of PEK
parent
1fac6aa7
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
104 additions
and
2 deletions
+104
-2
pdf_extract_kit.py
magic_pdf/model/pdf_extract_kit.py
+104
-2
No files found.
magic_pdf/model/pdf_extract_kit.py
View file @
831db2e0
import
os
import
time
import
cv2
import
numpy
as
np
import
yaml
from
PIL
import
Image
from
ultralytics
import
YOLO
from
loguru
import
logger
from
magic_pdf.model.pek_sub_modules.layoutlmv3.model_init
import
Layoutlmv3_Predictor
...
...
@@ -9,7 +13,9 @@ import unimernet.tasks as tasks
from
unimernet.processors
import
load_processor
import
argparse
from
torchvision
import
transforms
from
torch.utils.data
import
Dataset
,
DataLoader
from
magic_pdf.model.pek_sub_modules.post_process
import
get_croped_image
,
latex_rm_whitespace
from
magic_pdf.model.pek_sub_modules.self_modify
import
ModifiedPaddleOCR
...
...
@@ -31,6 +37,25 @@ def mfr_model_init(weight_dir, cfg_path, device='cpu'):
return
model
,
vis_processor
class
MathDataset
(
Dataset
):
def
__init__
(
self
,
image_paths
,
transform
=
None
):
self
.
image_paths
=
image_paths
self
.
transform
=
transform
def
__len__
(
self
):
return
len
(
self
.
image_paths
)
def
__getitem__
(
self
,
idx
):
# if not pil image, then convert to pil image
if
isinstance
(
self
.
image_paths
[
idx
],
str
):
raw_image
=
Image
.
open
(
self
.
image_paths
[
idx
])
else
:
raw_image
=
self
.
image_paths
[
idx
]
if
self
.
transform
:
image
=
self
.
transform
(
raw_image
)
return
image
class
CustomPEKModel
:
def
__init__
(
self
,
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
**
kwargs
):
"""
...
...
@@ -82,6 +107,83 @@ class CustomPEKModel:
logger
.
info
(
'DocAnalysis init done!'
)
def
__call__
(
self
,
images
):
# layout检测 + 公式检测
doc_layout_result
=
[]
latex_filling_list
=
[]
mf_image_list
=
[]
for
idx
,
img_dict
in
enumerate
(
images
):
image
=
img_dict
[
"img"
]
img_height
,
img_width
=
img_dict
[
"height"
],
img_dict
[
"width"
]
layout_res
=
self
.
layout_model
(
image
,
ignore_catids
=
[])
# 公式检测
mfd_res
=
self
.
mfd_model
.
predict
(
image
,
imgsz
=
1888
,
conf
=
0.25
,
iou
=
0.45
,
verbose
=
True
)[
0
]
for
xyxy
,
conf
,
cla
in
zip
(
mfd_res
.
boxes
.
xyxy
.
cpu
(),
mfd_res
.
boxes
.
conf
.
cpu
(),
mfd_res
.
boxes
.
cls
.
cpu
()):
xmin
,
ymin
,
xmax
,
ymax
=
[
int
(
p
.
item
())
for
p
in
xyxy
]
new_item
=
{
'category_id'
:
13
+
int
(
cla
.
item
()),
'poly'
:
[
xmin
,
ymin
,
xmax
,
ymin
,
xmax
,
ymax
,
xmin
,
ymax
],
'score'
:
round
(
float
(
conf
.
item
()),
2
),
'latex'
:
''
,
}
layout_res
[
'layout_dets'
]
.
append
(
new_item
)
latex_filling_list
.
append
(
new_item
)
bbox_img
=
get_croped_image
(
Image
.
fromarray
(
image
),
[
xmin
,
ymin
,
xmax
,
ymax
])
mf_image_list
.
append
(
bbox_img
)
layout_res
[
'page_info'
]
=
dict
(
page_no
=
idx
,
height
=
img_height
,
width
=
img_width
)
doc_layout_result
.
append
(
layout_res
)
# 公式识别,因为识别速度较慢,为了提速,把单个pdf的所有公式裁剪完,一起批量做识别。
a
=
time
.
time
()
dataset
=
MathDataset
(
mf_image_list
,
transform
=
self
.
mfr_transform
)
dataloader
=
DataLoader
(
dataset
,
batch_size
=
128
,
num_workers
=
0
)
mfr_res
=
[]
for
imgs
in
dataloader
:
imgs
=
imgs
.
to
(
self
.
device
)
output
=
self
.
mfr_model
.
generate
({
'image'
:
imgs
})
mfr_res
.
extend
(
output
[
'pred_str'
])
for
res
,
latex
in
zip
(
latex_filling_list
,
mfr_res
):
res
[
'latex'
]
=
latex_rm_whitespace
(
latex
)
b
=
time
.
time
()
logger
.
info
(
f
"formula nums: {len(mf_image_list)}, mfr time: {round(b - a, 2)}"
)
if
self
.
apply_ocr
:
# ocr识别
for
idx
,
img_dict
in
enumerate
(
images
):
image
=
img_dict
[
"img"
]
pil_img
=
Image
.
fromarray
(
image
)
single_page_res
=
doc_layout_result
[
idx
][
'layout_dets'
]
single_page_mfdetrec_res
=
[]
for
res
in
single_page_res
:
if
int
(
res
[
'category_id'
])
in
[
13
,
14
]:
xmin
,
ymin
=
int
(
res
[
'poly'
][
0
]),
int
(
res
[
'poly'
][
1
])
xmax
,
ymax
=
int
(
res
[
'poly'
][
4
]),
int
(
res
[
'poly'
][
5
])
single_page_mfdetrec_res
.
append
({
"bbox"
:
[
xmin
,
ymin
,
xmax
,
ymax
],
})
for
res
in
single_page_res
:
if
int
(
res
[
'category_id'
])
in
[
0
,
1
,
2
,
4
,
6
,
7
]:
# 需要进行ocr的类别
xmin
,
ymin
=
int
(
res
[
'poly'
][
0
]),
int
(
res
[
'poly'
][
1
])
xmax
,
ymax
=
int
(
res
[
'poly'
][
4
]),
int
(
res
[
'poly'
][
5
])
crop_box
=
[
xmin
,
ymin
,
xmax
,
ymax
]
cropped_img
=
Image
.
new
(
'RGB'
,
pil_img
.
size
,
'white'
)
cropped_img
.
paste
(
pil_img
.
crop
(
crop_box
),
crop_box
)
cropped_img
=
cv2
.
cvtColor
(
np
.
asarray
(
cropped_img
),
cv2
.
COLOR_RGB2BGR
)
ocr_res
=
self
.
ocr_model
.
ocr
(
cropped_img
,
mfd_res
=
single_page_mfdetrec_res
)[
0
]
if
ocr_res
:
for
box_ocr_res
in
ocr_res
:
p1
,
p2
,
p3
,
p4
=
box_ocr_res
[
0
]
text
,
score
=
box_ocr_res
[
1
]
doc_layout_result
[
idx
][
'layout_dets'
]
.
append
({
'category_id'
:
15
,
'poly'
:
p1
+
p2
+
p3
+
p4
,
'score'
:
round
(
score
,
2
),
'text'
:
text
,
})
def
__call__
(
self
,
image
):
pass
return
doc_layout_result
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment