Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
P
pdf-miner
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Qin Kaijie
pdf-miner
Commits
e7ce3051
Commit
e7ce3051
authored
Jul 22, 2024
by
myhloli
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix(magic_pdf): optimize formula area selection for OCR
parent
5f992de4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
207 additions
and
66 deletions
+207
-66
pdf_extract_kit.py
magic_pdf/model/pdf_extract_kit.py
+63
-22
self_modify.py
magic_pdf/model/pek_sub_modules/self_modify.py
+144
-44
No files found.
magic_pdf/model/pdf_extract_kit.py
View file @
e7ce3051
...
...
@@ -168,33 +168,74 @@ class CustomPEKModel:
if
self
.
apply_ocr
:
ocr_start
=
time
.
time
()
pil_img
=
Image
.
fromarray
(
image
)
# 筛选出需要OCR的区域和公式区域
ocr_res_list
=
[]
single_page_mfdetrec_res
=
[]
for
res
in
layout_res
:
if
int
(
res
[
'category_id'
])
in
[
13
,
14
]:
xmin
,
ymin
=
int
(
res
[
'poly'
][
0
]),
int
(
res
[
'poly'
][
1
])
xmax
,
ymax
=
int
(
res
[
'poly'
][
4
]),
int
(
res
[
'poly'
][
5
])
single_page_mfdetrec_res
.
append
({
"bbox"
:
[
xmin
,
ymin
,
xmax
,
ymax
],
"bbox"
:
[
int
(
res
[
'poly'
][
0
]),
int
(
res
[
'poly'
][
1
]),
int
(
res
[
'poly'
][
4
]),
int
(
res
[
'poly'
][
5
])],
})
for
res
in
layout_res
:
if
int
(
res
[
'category_id'
])
in
[
0
,
1
,
2
,
4
,
6
,
7
]:
# 需要进行ocr的类别
xmin
,
ymin
=
int
(
res
[
'poly'
][
0
]),
int
(
res
[
'poly'
][
1
])
xmax
,
ymax
=
int
(
res
[
'poly'
][
4
]),
int
(
res
[
'poly'
][
5
])
crop_box
=
(
xmin
,
ymin
,
xmax
,
ymax
)
cropped_img
=
Image
.
new
(
'RGB'
,
pil_img
.
size
,
'white'
)
cropped_img
.
paste
(
pil_img
.
crop
(
crop_box
),
crop_box
)
cropped_img
=
cv2
.
cvtColor
(
np
.
asarray
(
cropped_img
),
cv2
.
COLOR_RGB2BGR
)
ocr_res
=
self
.
ocr_model
.
ocr
(
cropped_img
,
mfd_res
=
single_page_mfdetrec_res
)[
0
]
if
ocr_res
:
for
box_ocr_res
in
ocr_res
:
p1
,
p2
,
p3
,
p4
=
box_ocr_res
[
0
]
text
,
score
=
box_ocr_res
[
1
]
layout_res
.
append
({
'category_id'
:
15
,
'poly'
:
p1
+
p2
+
p3
+
p4
,
'score'
:
round
(
score
,
2
),
'text'
:
text
,
})
elif
int
(
res
[
'category_id'
])
in
[
0
,
1
,
2
,
4
,
6
,
7
]:
ocr_res_list
.
append
(
res
)
# 对每一个需OCR处理的区域进行处理
for
res
in
ocr_res_list
:
xmin
,
ymin
=
int
(
res
[
'poly'
][
0
]),
int
(
res
[
'poly'
][
1
])
xmax
,
ymax
=
int
(
res
[
'poly'
][
4
]),
int
(
res
[
'poly'
][
5
])
paste_x
=
50
paste_y
=
50
# 创建一个宽高各多50的白色背景
new_width
=
xmax
-
xmin
+
paste_x
*
2
new_height
=
ymax
-
ymin
+
paste_y
*
2
new_image
=
Image
.
new
(
'RGB'
,
(
new_width
,
new_height
),
'white'
)
# 裁剪图像
crop_box
=
(
xmin
,
ymin
,
xmax
,
ymax
)
cropped_img
=
pil_img
.
crop
(
crop_box
)
new_image
.
paste
(
cropped_img
,
(
paste_x
,
paste_y
))
# 调整公式区域坐标
adjusted_mfdetrec_res
=
[]
for
mf_res
in
single_page_mfdetrec_res
:
mf_xmin
,
mf_ymin
,
mf_xmax
,
mf_ymax
=
mf_res
[
"bbox"
]
# 将公式区域坐标调整为相对于裁剪区域的坐标
x0
=
mf_xmin
-
xmin
+
paste_x
y0
=
mf_ymin
-
ymin
+
paste_y
x1
=
mf_xmax
-
xmin
+
paste_x
y1
=
mf_ymax
-
ymin
+
paste_y
if
any
([
x0
<
0
,
y0
<
0
,
x1
<
0
,
y1
<
0
])
or
any
([
x0
>
new_width
,
y0
>
new_height
,
x1
>
new_width
,
y1
>
new_height
]):
continue
else
:
adjusted_mfdetrec_res
.
append
({
"bbox"
:
[
x0
,
y0
,
x1
,
y1
],
})
# OCR识别
ocr_res
=
self
.
ocr_model
.
ocr
(
np
.
array
(
new_image
),
mfd_res
=
adjusted_mfdetrec_res
)[
0
]
# 整合结果
if
ocr_res
:
for
box_ocr_res
in
ocr_res
:
p1
,
p2
,
p3
,
p4
=
box_ocr_res
[
0
]
text
,
score
=
box_ocr_res
[
1
]
# 将坐标转换回原图坐标系
p1
=
[
p1
[
0
]
-
paste_x
+
xmin
,
p1
[
1
]
-
paste_y
+
ymin
]
p2
=
[
p2
[
0
]
-
paste_x
+
xmin
,
p2
[
1
]
-
paste_y
+
ymin
]
p3
=
[
p3
[
0
]
-
paste_x
+
xmin
,
p3
[
1
]
-
paste_y
+
ymin
]
p4
=
[
p4
[
0
]
-
paste_x
+
xmin
,
p4
[
1
]
-
paste_y
+
ymin
]
layout_res
.
append
({
'category_id'
:
15
,
'poly'
:
p1
+
p2
+
p3
+
p4
,
'score'
:
round
(
score
,
2
),
'text'
:
text
,
})
ocr_cost
=
round
(
time
.
time
()
-
ocr_start
,
2
)
logger
.
info
(
f
"ocr cost: {ocr_cost}"
)
...
...
magic_pdf/model/pek_sub_modules/self_modify.py
View file @
e7ce3051
...
...
@@ -10,12 +10,17 @@ from paddleocr import PaddleOCR
from
paddleocr.ppocr.utils.logging
import
get_logger
from
paddleocr.ppocr.utils.utility
import
check_and_read
,
alpha_to_color
,
binarize_img
from
paddleocr.tools.infer.utility
import
draw_ocr_box_txt
,
get_rotate_crop_image
,
get_minarea_rect_crop
from
magic_pdf.libs.boxbase
import
__is_overlaps_y_exceeds_threshold
logger
=
get_logger
()
def img_decode(content: bytes):
    """Decode an in-memory encoded image into an OpenCV array.

    Args:
        content: Raw bytes of an encoded image file.

    Returns:
        Whatever ``cv2.imdecode`` yields for the buffer when called with
        ``cv2.IMREAD_UNCHANGED``.
    """
    # View the bytes as a flat uint8 buffer, which is what imdecode consumes.
    encoded = np.frombuffer(content, dtype=np.uint8)
    decoded = cv2.imdecode(encoded, cv2.IMREAD_UNCHANGED)
    return decoded
def
check_img
(
img
):
if
isinstance
(
img
,
bytes
):
img
=
img_decode
(
img
)
...
...
@@ -51,6 +56,7 @@ def check_img(img):
return
img
def
sorted_boxes
(
dt_boxes
):
"""
Sort text boxes in order from top to bottom, left to right
...
...
@@ -75,49 +81,87 @@ def sorted_boxes(dt_boxes):
return
_boxes
def formula_in_text(mf_bbox, text_bbox):
    """Work out how a formula box splits a text box on the same line.

    Args:
        mf_bbox: Formula box as ``[x1, y1, x2, y2]``.
        text_bbox: Text box as four ``[x, y]`` points. Points 0 and 2 are
            read as opposite corners (presumably top-left and bottom-right;
            confirm against the detector's point order).

    Returns:
        A tuple ``(drop_origin, left_box, right_box)``:

        * ``drop_origin`` is True when the formula overlaps the text box
          horizontally and the original text box should be replaced.
        * ``left_box`` / ``right_box`` are float32 arrays of four points for
          the text left/right of the formula, or None when there is no such
          remainder.

        ``(False, None, None)`` is returned when the boxes are not on the
        same line or the text box has zero height.
    """
    x1, y1, x2, y2 = mf_bbox
    x3, y3 = text_bbox[0]
    x4, y4 = text_bbox[2]

    # A zero-height text box cannot share a line with anything; bail out
    # early instead of dividing by zero below.
    text_height = abs(y4 - y3)
    if text_height == 0:
        return False, None, None

    # Same line: vertical centres differ by less than 20% of the text height.
    centre_offset = abs((y1 + y2) / 2 - (y3 + y4) / 2)
    same_line = centre_offset / text_height < 0.2
    if not same_line:
        return False, None, None

    def _left_piece():
        # Text remaining to the left of the formula, ending 1 unit before it.
        left_x = x1 - 1
        return np.array([
            text_bbox[0],
            [left_x, text_bbox[1][1]],
            [left_x, text_bbox[2][1]],
            text_bbox[3],
        ]).astype('float32')

    def _right_piece():
        # Text remaining to the right of the formula, starting 1 unit after it.
        right_x = x2 + 1
        return np.array([
            [right_x, text_bbox[0][1]],
            text_bbox[1],
            text_bbox[2],
            [right_x, text_bbox[3][1]],
        ]).astype('float32')

    drop_origin = False
    left_box, right_box = None, None
    # The four cases are mutually exclusive, so an elif chain is equivalent
    # to the independent checks.
    if x3 < x1 and x2 < x4:
        # Formula strictly inside the text: keep both sides.
        drop_origin = True
        left_box = _left_piece()
        right_box = _right_piece()
    elif x3 < x1 and x1 <= x4 <= x2:
        # Formula covers the right end of the text: keep the left side.
        drop_origin = True
        left_box = _left_piece()
    elif x1 <= x3 <= x2 and x2 < x4:
        # Formula covers the left end of the text: keep the right side.
        drop_origin = True
        right_box = _right_piece()
    elif x1 <= x3 < x4 <= x2:
        # Text lies entirely within the formula: nothing remains.
        drop_origin = True
    return drop_origin, left_box, right_box
def update_det_boxes(dt_boxes, mfdetrec_res):
    """Split detected text boxes around formula boxes using formula_in_text.

    NOTE(review): a later definition with the same name in this module
    rebinds ``update_det_boxes``; confirm which one is meant to survive.

    Args:
        dt_boxes: List of text boxes (each a sequence of four points). The
            list is modified in place.
        mfdetrec_res: Iterable of dicts with a ``'bbox'`` entry holding the
            formula box as ``[x1, y1, x2, y2]``.

    Returns:
        The same list object as ``dt_boxes``, after splitting.
    """
    new_dt_boxes = dt_boxes
    for mf_box in mfdetrec_res:
        for idx, text_box in enumerate(new_dt_boxes):
            ret, left_box, right_box = formula_in_text(mf_box['bbox'], text_box)
            if not ret:
                continue
            # Replace the text box by whatever remains on either side of the
            # formula. Mutating the list here is safe only because we stop
            # iterating immediately afterwards.
            new_dt_boxes.pop(idx)
            if left_box is not None:
                new_dt_boxes.append(left_box)
            if right_box is not None:
                new_dt_boxes.append(right_box)
            # Only the first text box hit by this formula is processed.
            break
    # Callers assign the result, so hand the updated list back.
    return new_dt_boxes
def bbox_to_points(bbox):
    """Convert a bbox into an array of its four corner points.

    Args:
        bbox: Four values ``[x0, y0, x1, y1]``.

    Returns:
        A float32 array of shape (4, 2) with the points in the order
        ``(x0, y0), (x1, y0), (x1, y1), (x0, y1)``.
    """
    left, top, right, bottom = bbox
    corners = [
        [left, top],
        [right, top],
        [right, bottom],
        [left, bottom],
    ]
    return np.array(corners).astype('float32')
def points_to_bbox(points):
    """Convert an array of four corner points into bbox format.

    Only the first three points are read: x0 and y0 come from point 0,
    x1 from point 1 and y1 from point 2.

    Args:
        points: Sequence of ``[x, y]`` points.

    Returns:
        The list ``[x0, y0, x1, y1]``.
    """
    left, top = points[0]
    right, _unused_y = points[1]
    _unused_x, bottom = points[2]
    bbox = [left, top, right, bottom]
    return bbox
def merge_intervals(intervals):
    """Merge overlapping or touching intervals.

    The input is left untouched: it is neither re-ordered nor are its
    interval objects modified, and tuple intervals are accepted.

    Args:
        intervals: Iterable of ``[start, end]`` pairs, in any order.

    Returns:
        A new list of ``[start, end]`` lists sorted by start. An interval is
        merged into the previous one unless it starts strictly after the
        previous end.
    """
    merged = []
    # sorted() works on a copy, so the caller's list keeps its order.
    for interval in sorted(intervals, key=lambda x: x[0]):
        start, end = interval[0], interval[1]
        if not merged or merged[-1][1] < start:
            # No overlap with the previous interval: open a new one. A fresh
            # list is stored so later extension never touches the input.
            merged.append([start, end])
        else:
            # Overlap: extend the previous interval if this one reaches further.
            merged[-1][1] = max(merged[-1][1], end)
    return merged
def remove_intervals(original, masks):
    """Subtract mask intervals from a range and return what is left.

    Args:
        original: ``[start, end]`` range to cut pieces out of.
        masks: Intervals to remove; they are merged via ``merge_intervals``
            before being applied.

    Returns:
        List of ``[start, end]`` pieces of ``original`` not covered by any
        mask. A piece before a mask ends at ``mask_start - 1`` and the next
        piece begins at ``mask_end + 1``.
    """
    combined_masks = merge_intervals(masks)
    cursor, range_end = original
    remaining = []

    for mask_start, mask_end in combined_masks:
        # Masks lying entirely after or entirely before the uncovered part
        # of the range have nothing to remove.
        if mask_start > range_end or mask_end < cursor:
            continue

        # Keep the stretch between the cursor and the start of this mask.
        if cursor < mask_start:
            remaining.append([cursor, mask_start - 1])

        # Resume just past the mask.
        cursor = max(mask_end + 1, cursor)

    # Whatever follows the last mask is also kept.
    if cursor <= range_end:
        remaining.append([cursor, range_end])

    return remaining
def update_det_boxes(dt_boxes, mfd_res):
    """Cut the horizontal extent of formula boxes out of detected text boxes.

    Each text box is reduced to a bbox via ``points_to_bbox``. The x-ranges
    of every formula for which ``__is_overlaps_y_exceeds_threshold`` returns
    true (imported from ``magic_pdf.libs.boxbase``; presumably a
    vertical-overlap test, confirm there) are then removed from the text
    box's x-range. One new box is emitted per remaining x-range, keeping the
    text box's y extent.

    Args:
        dt_boxes: Iterable of text boxes, each a sequence of four points.
        mfd_res: Iterable of dicts with a ``'bbox'`` entry
            ``[x0, y0, x1, y1]``, or None, which is treated as "no formulas".

    Returns:
        A new list of float32 four-point arrays built with ``bbox_to_points``.
        Text boxes fully covered by formulas contribute nothing.
    """
    # Callers default mfd_res to None; treat that like an empty result
    # instead of failing when iterating over it.
    if mfd_res is None:
        mfd_res = []

    new_dt_boxes = []
    for text_box in dt_boxes:
        text_bbox = points_to_bbox(text_box)

        # X-ranges of all formulas that qualify as overlapping this text box.
        masks_list = [
            [mf_box['bbox'][0], mf_box['bbox'][2]]
            for mf_box in mfd_res
            if __is_overlaps_y_exceeds_threshold(text_bbox, mf_box['bbox'])
        ]

        text_x_range = [text_bbox[0], text_bbox[2]]
        for kept_range in remove_intervals(text_x_range, masks_list):
            new_dt_boxes.append(
                bbox_to_points([kept_range[0], text_bbox[1], kept_range[1], text_bbox[3]])
            )
    return new_dt_boxes
class
ModifiedPaddleOCR
(
PaddleOCR
):
def
ocr
(
self
,
img
,
det
=
True
,
rec
=
True
,
cls
=
True
,
bin
=
False
,
inv
=
False
,
mfd_res
=
None
,
alpha_color
=
(
255
,
255
,
255
)):
"""
...
...
@@ -197,7 +241,7 @@ class ModifiedPaddleOCR(PaddleOCR):
if
not
rec
:
return
cls_res
return
ocr_res
def
__call__
(
self
,
img
,
cls
=
True
,
mfd_res
=
None
):
time_dict
=
{
'det'
:
0
,
'rec'
:
0
,
'cls'
:
0
,
'all'
:
0
}
...
...
@@ -226,7 +270,7 @@ class ModifiedPaddleOCR(PaddleOCR):
dt_boxes
=
update_det_boxes
(
dt_boxes
,
mfd_res
)
aft
=
time
.
time
()
logger
.
debug
(
"split text box by formula, new dt_boxes num : {}, elapsed : {}"
.
format
(
len
(
dt_boxes
),
aft
-
bef
))
len
(
dt_boxes
),
aft
-
bef
))
for
bno
in
range
(
len
(
dt_boxes
)):
tmp_box
=
copy
.
deepcopy
(
dt_boxes
[
bno
])
...
...
@@ -257,4 +301,60 @@ class ModifiedPaddleOCR(PaddleOCR):
filter_rec_res
.
append
(
rec_result
)
end
=
time
.
time
()
time_dict
[
'all'
]
=
end
-
start
return
filter_boxes
,
filter_rec_res
,
time_dict
\ No newline at end of file
return
filter_boxes
,
filter_rec_res
,
time_dict
if __name__ == '__main__':
    # Manual smoke test for remove_intervals. The module-level
    # merge_intervals / remove_intervals are used directly rather than being
    # re-defined here, so the check exercises the code the module really runs.
    original_range = [1, 100]
    masks = [[0, 15], [25, 40], [55, 80]]

    result = remove_intervals(original_range, masks)
    print(result)
    # Expected output: [[16, 24], [41, 54], [81, 100]]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment