Qin Kaijie / pdf-miner / Commits / c8b06ad5

Commit c8b06ad5 (Unverified), authored Apr 10, 2024 by myhloli, committed by GitHub on Apr 10, 2024

Merge branch 'master' into master

Parents: 88f5b932 2783bb39

Showing 9 changed files with 448 additions and 48 deletions (+448, -48)
.github/workflows/benchmark.yml    +14  -4
.github/workflows/update_base.yml  +8   -6
demo/ocr_demo.py                   +1   -0
magic_pdf/io/AbsReaderWriter.py    +24  -10
magic_pdf/io/DiskReaderWriter.py   +49  -0
magic_pdf/io/S3ReaderWriter.py     +66  -12
magic_pdf/para/para_split.py       +22  -3
tools/base_data.json               +87  -0
tools/ocr_badcase.py               +177 -13
.github/workflows/benchmark.yml

...
@@ -9,7 +9,7 @@ on:
   paths-ignore:
     - "cmds/**"
     - "**.md"
   workflow_dispatch:
 jobs:
   pdf-test:
     runs-on: pdf
...
@@ -18,14 +18,16 @@ jobs:
       fail-fast: true
     steps:
+      - name: config-net
+        run: |
+          export http_proxy=http://bigdata_open_proxy:H89k5qwQRDYfz@10.140.90.20:10811
+          export https_proxy=http://bigdata_open_proxy:H89k5qwQRDYfz@10.140.90.20:10811
       - name: PDF benchmark
         uses: actions/checkout@v3
         with:
           fetch-depth: 2
       - name: check-requirements
         run: |
+          export http_proxy=http://bigdata_open_proxy:H89k5qwQRDYfz@10.140.90.20:10811
+          export https_proxy=http://bigdata_open_proxy:H89k5qwQRDYfz@10.140.90.20:10811
           changed_files=$(git diff --name-only -r HEAD~1 HEAD)
           echo $changed_files
           if [[ $changed_files =~ "requirements.txt" ]]; then
...
@@ -36,4 +38,12 @@ jobs:
       - name: benchmark
         run: |
           echo "start test"
-          cd tools && python ocr_badcase.py pdf_json_label_0306.json ocr_dataset.json json_files.zip output.json
+          cd tools && python ocr_badcase.py pdf_json_label_0306.json ocr_dataset.json json_files.zip badcase.json overall.json base_data.json
+  notify_to_feishu:
+    if: ${{ always() && !cancelled() && contains(needs.*.result, 'failure') && (github.ref_name == 'master') }}
+    needs: [pdf-test]
+    runs-on: [pdf]
+    steps:
+      - name: notify
+        run: |
+          curl -X POST -H "Content-Type: application/json" -d '{"msg_type":"post","content":{"post":{"zh_cn":{"title":"'${{ github.repository }}' GitHubAction Failed","content":[[{"tag":"text","text":""},{"tag":"a","text":"Please click here for details ","href":"https://github.com/'${{ github.repository }}'/actions/runs/'${GITHUB_RUN_ID}'"},{"tag":"at","user_id":"'${{ secrets.USER_ID }}'"}]]}}}}' ${{ secrets.WEBHOOK_URL }}
.github/workflows/update_base.yml

 # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
 # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
-name: PDF
+name: update-base
 on:
-  release:
-    types: [published]
+  push:
+    tags:
+      - '*released'
+  workflow_dispatch:
 jobs:
   pdf-test:
     runs-on: pdf
...
@@ -15,6 +16,7 @@ jobs:
     steps:
       - name: update-base
         uses: actions/checkout@v3
+      - name: start-update
         run: |
-          echo "start test"
+          python update_base.py
demo/ocr_demo.py

...
@@ -116,6 +116,7 @@ if __name__ == '__main__':
     pdf_path = r"/home/cxu/workspace/Magic-PDF/ocr_demo/j.1540-627x.2006.00176.x.pdf"
     json_file_path = r"/home/cxu/workspace/Magic-PDF/ocr_demo/j.1540-627x.2006.00176.x.json"
     # ocr_local_parse(pdf_path, json_file_path)
     book_name = "数学新星网/edu_00001236"
     ocr_online_parse(book_name)
...
magic_pdf/io/AbsReaderWriter.py

 from abc import ABC, abstractmethod


 class AbsReaderWriter(ABC):
     """
     Abstract class that supports both binary and text reads and writes.
-    TODO
     """
-    @abstractmethod
-    def read(self, path: str):
-        pass
-
-    @abstractmethod
-    def write(self, path: str, content: str):
-        pass
+    MODE_TXT = "text"
+    MODE_BIN = "binary"
+
+    def __init__(self, parent_path):
+        # Initialization code can go here if needed.
+        self.parent_path = parent_path  # For a local directory this is the parent directory; for s3, writes go under this path.
+
+    @abstractmethod
+    def read(self, path: str, mode="text"):
+        """
+        For both local and s3 paths: if path is absolute, it is not joined with parent_path; if it is relative, it is joined with parent_path.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def write(self, content: str, path: str, mode=MODE_TXT):
+        """
+        For both local and s3 paths: if path is absolute, it is not joined with parent_path; if it is relative, it is joined with parent_path.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def read_jsonl(self, path: str, byte_start=0, byte_end=None, encoding='utf-8'):
+        """
+        For both local and s3 paths: if path is absolute, it is not joined with parent_path; if it is relative, it is joined with parent_path.
+        """
+        raise NotImplementedError
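The contract above leaves path resolution to subclasses: absolute paths are used as-is, relative ones are joined with parent_path. A minimal sketch of a conforming subclass (hypothetical, not part of this commit; InMemoryReaderWriter and _resolve are illustrative names):

    import os

    class InMemoryReaderWriter(AbsReaderWriter):
        # Hypothetical subclass: keeps data in a dict instead of on disk or s3.
        def __init__(self, parent_path):
            super().__init__(parent_path)
            self._store = {}

        def _resolve(self, path: str) -> str:
            # Absolute paths are used as-is; relative ones are joined with parent_path.
            return path if os.path.isabs(path) else os.path.join(self.parent_path, path)

        def read(self, path: str, mode="text"):
            return self._store[self._resolve(path)]

        def write(self, content: str, path: str, mode=AbsReaderWriter.MODE_TXT):
            self._store[self._resolve(path)] = content

        def read_jsonl(self, path: str, byte_start=0, byte_end=None, encoding='utf-8'):
            # Byte offsets are ignored in this sketch.
            return self.read(path).splitlines()

    rw = InMemoryReaderWriter("/data")
    rw.write("line1\nline2", "out.txt")
    assert rw.read("/data/out.txt") == "line1\nline2"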
magic_pdf/io/DiskReaderWriter.py

import os
from magic_pdf.io.AbsReaderWriter import AbsReaderWriter
from loguru import logger


class DiskReaderWriter(AbsReaderWriter):
    def __init__(self, parent_path, encoding='utf-8'):
        self.path = parent_path
        self.encoding = encoding

    def read(self, mode="text"):
        if not os.path.exists(self.path):
            logger.error(f"File {self.path} does not exist")
            raise Exception(f"File {self.path} does not exist")
        if mode == "text":
            with open(self.path, 'r', encoding=self.encoding) as f:
                return f.read()
        elif mode == "binary":
            with open(self.path, 'rb') as f:
                return f.read()
        else:
            raise ValueError("Invalid mode. Use 'text' or 'binary'.")

    def write(self, data, mode="text"):
        if mode == "text":
            with open(self.path, 'w', encoding=self.encoding) as f:
                f.write(data)
                logger.info(f"Content successfully written to {self.path}")
        elif mode == "binary":
            with open(self.path, 'wb') as f:
                f.write(data)
                logger.info(f"Content successfully written to {self.path}")
        else:
            raise ValueError("Invalid mode. Use 'text' or 'binary'.")


# Usage example
if __name__ == "__main__":
    file_path = "example.txt"
    drw = DiskReaderWriter(file_path)
    # Write content to the file
    drw.write(b"Hello, World!", mode="binary")
    # Read content back from the file
    content = drw.read()
    if content:
        logger.info(f"Content read from {file_path}: {content}")
magic_pdf/io/S3ReaderWriter.py

-from magic_pdf.io import AbsReaderWriter
+from magic_pdf.io.AbsReaderWriter import AbsReaderWriter
+from magic_pdf.libs.commons import parse_aws_param, parse_bucket_key
+import boto3
+from loguru import logger
+from boto3.s3.transfer import TransferConfig
+from botocore.config import Config

-class DiskReaderWriter(AbsReaderWriter):
-    def __init__(self, parent_path, encoding='utf-8'):
-        self.path = parent_path
-        self.encoding = encoding
-
-    def read(self):
-        with open(self.path, 'rb') as f:
-            return f.read()
-
-    def write(self, data):
-        with open(self.path, 'wb') as f:
-            f.write(data)
+class S3ReaderWriter(AbsReaderWriter):
+    def __init__(self, ak: str, sk: str, endpoint_url: str, addressing_style: str):
+        self.client = self._get_client(ak, sk, endpoint_url, addressing_style)
+
+    def _get_client(self, ak: str, sk: str, endpoint_url: str, addressing_style: str):
+        s3_client = boto3.client(
+            service_name="s3",
+            aws_access_key_id=ak,
+            aws_secret_access_key=sk,
+            endpoint_url=endpoint_url,
+            config=Config(s3={"addressing_style": addressing_style},
+                          retries={'max_attempts': 5, 'mode': 'standard'}),
+        )
+        return s3_client
+
+    def read(self, s3_path, mode="text", encoding="utf-8"):
+        bucket_name, bucket_key = parse_bucket_key(s3_path)
+        res = self.client.get_object(Bucket=bucket_name, Key=bucket_key)
+        body = res["Body"].read()
+        if mode == 'text':
+            data = body.decode(encoding)  # Decode bytes to text
+        elif mode == 'binary':
+            data = body
+        else:
+            raise ValueError("Invalid mode. Use 'text' or 'binary'.")
+        return data
+
+    def write(self, data, s3_path, mode="text", encoding="utf-8"):
+        if mode == 'text':
+            body = data.encode(encoding)  # Encode text data as bytes
+        elif mode == 'binary':
+            body = data
+        else:
+            raise ValueError("Invalid mode. Use 'text' or 'binary'.")
+        bucket_name, bucket_key = parse_bucket_key(s3_path)
+        self.client.put_object(Body=body, Bucket=bucket_name, Key=bucket_key)
+        logger.info(f"Content written to {s3_path}")
+
+if __name__ == "__main__":
+    # Configure the connection info
+    ak = ""
+    sk = ""
+    endpoint_url = ""
+    addressing_style = ""
+    # Create an S3ReaderWriter object
+    s3_reader_writer = S3ReaderWriter(ak, sk, endpoint_url, addressing_style)
+    # Write text data to S3
+    text_data = "This is some text data"
+    s3_reader_writer.write(data=text_data, s3_path="s3://bucket_name/ebook/test/test.json", mode='text')
+    # Read text data from S3
+    text_data_read = s3_reader_writer.read(s3_path="s3://bucket_name/ebook/test/test.json", mode='text')
+    logger.info(f"Read text data from S3: {text_data_read}")
+    # Write binary data to S3
+    binary_data = b"This is some binary data"
+    s3_reader_writer.write(data=binary_data, s3_path="s3://bucket_name/ebook/test/test2.json", mode='binary')
+    # Read binary data from S3
+    binary_data_read = s3_reader_writer.read(s3_path="s3://bucket_name/ebook/test/test2.json", mode='binary')
+    logger.info(f"Read binary data from S3: {binary_data_read}")
\ No newline at end of file
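parse_bucket_key is imported from magic_pdf.libs.commons and is not shown in this diff. A plausible sketch of the behavior the calls above rely on, assuming it simply splits an s3:// URL into bucket and key (an assumption, not the library's actual code):

    def parse_bucket_key_sketch(s3_path: str):
        # Hypothetical stand-in for magic_pdf.libs.commons.parse_bucket_key:
        # "s3://bucket_name/ebook/test/test.json" -> ("bucket_name", "ebook/test/test.json")
        path = s3_path.removeprefix("s3://")
        bucket, _, key = path.partition("/")
        return bucket, key

    assert parse_bucket_key_sketch("s3://bucket_name/ebook/test/test.json") == ("bucket_name", "ebook/test/test.json")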
magic_pdf/para/para_split.py

...
@@ -183,11 +183,31 @@ def __valign_lines(blocks, layout_bboxes):
     return new_layout_bboxes

+def __align_text_in_layout(blocks, layout_bboxes):
+    """
+    Lines produced by OCR sometimes carry blank space at the start or end, so the text
+    has to be aligned; anything extending past the layout is clipped at its left and right edges.
+    """
+    for layout in layout_bboxes:
+        lb = layout['layout_bbox']
+        blocks_in_layoutbox = [b for b in blocks if is_in_layout(b['bbox'], lb)]
+        if len(blocks_in_layoutbox) == 0:
+            continue
+        for block in blocks_in_layoutbox:
+            for line in block['lines']:
+                x0, x1 = line['bbox'][0], line['bbox'][2]
+                if x0 < lb[0]:
+                    line['bbox'][0] = lb[0]
+                if x1 > lb[2]:
+                    line['bbox'][2] = lb[2]
+
 def __common_pre_proc(blocks, layout_bboxes):
     """
     Language-agnostic pre-processing of the text.
     """
     #__add_line_period(blocks, layout_bboxes)
+    __align_text_in_layout(blocks, layout_bboxes)
     aligned_layout_bboxes = __valign_lines(blocks, layout_bboxes)
     return aligned_layout_bboxes
...
@@ -233,7 +253,6 @@ def __split_para_in_layoutbox(lines_group, new_layout_bbox, lang="en", char_avg_
     layout_paras = []
     right_tail_distance = 1.5 * char_avg_len
     for lines in lines_group:
         paras = []
         total_lines = len(lines)
...
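The effect of the new __align_text_in_layout pass can be seen on a toy line (a standalone sketch with made-up coordinates, not part of the commit): a line bbox that overhangs its layout box is clipped to the layout's left and right edges.

    layout_bbox = [100, 0, 500, 800]   # x0, y0, x1, y1 of the layout box
    line_bbox = [90, 10, 520, 30]      # an OCR line overhanging both sides
    # The same clipping rule as __align_text_in_layout:
    if line_bbox[0] < layout_bbox[0]:
        line_bbox[0] = layout_bbox[0]
    if line_bbox[2] > layout_bbox[2]:
        line_bbox[2] = layout_bbox[2]
    assert line_bbox == [100, 10, 500, 30]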
tools/base_data.json (new file, mode 100644)

{
    "accuracy": 1.0,
    "precision": 1.0,
    "recall": 1.0,
    "f1_score": 1.0,
    "pdf间的平均编辑距离": 133.10256410256412,
    "pdf间的平均bleu": 0.28838311595434046,
    "分段准确率": 0.07220216606498195,
    "行内公式准确率": {
        "accuracy": 0.004835727492533068,
        "precision": 0.008790072388831437,
        "recall": 0.010634970284641852,
        "f1_score": 0.009624911535739562
    },
    "行内公式编辑距离": 1.6176470588235294,
    "行内公式bleu": 0.17154724654721457,
    "行间公式准确率": {
        "accuracy": 0.08490566037735849,
        "precision": 0.1836734693877551,
        "recall": 0.13636363636363635,
        "f1_score": 0.1565217391304348
    },
    "行间公式编辑距离": 113.22222222222223,
    "行间公式bleu": 0.2531053359913409,
    "丢弃文本准确率": {
        "accuracy": 0.00035398230088495576,
        "precision": 0.0006389776357827476,
        "recall": 0.0007930214115781126,
        "f1_score": 0.0007077140835102619
    },
    "丢弃文本标签准确率": {
        "color_background_header_txt_block": {"precision": 0.0, "recall": 0.0, "f1-score": 0.0, "support": 41.0},
        "header": {"precision": 0.0, "recall": 0.0, "f1-score": 0.0, "support": 4.0},
        "footnote": {"precision": 1.0, "recall": 0.009708737864077669, "f1-score": 0.019230769230769232, "support": 103.0},
        "on-table": {"precision": 0.0, "recall": 0.0, "f1-score": 0.0, "support": 665.0},
        "rotate": {"precision": 0.0, "recall": 0.0, "f1-score": 0.0, "support": 63.0},
        "on-image": {"precision": 0.0, "recall": 0.0, "f1-score": 0.0, "support": 380.0},
        "micro avg": {"precision": 1.0, "recall": 0.0007961783439490446, "f1-score": 0.0015910898965791568, "support": 1256.0}
    },
    "丢弃图片准确率": {"accuracy": 0.0, "precision": 0.0, "recall": 0.0, "f1_score": 0.0},
    "丢弃表格准确率": {"accuracy": 0.0, "precision": 0.0, "recall": 0.0, "f1_score": 0.0}
}
\ No newline at end of file
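These baseline figures are what compare_edit_distance in tools/ocr_badcase.py (further down in this commit) reads: a benchmark run passes only if its "pdf间的平均编辑距离" comes in below the value recorded here.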
tools/ocr_badcase.py

...
@@ -413,6 +413,8 @@ def bbox_match_indicator_dropped_text_block(test_dropped_text_bboxs, standard_dr
     # Compute and return the tag-matching metrics
     text_block_tag_report = classification_report(y_true=standard_tag, y_pred=test_tag, labels=list(set(standard_tag) - {'None'}), output_dict=True, zero_division=0)
+    del text_block_tag_report["macro avg"]
+    del text_block_tag_report["weighted avg"]
     return text_block_report, text_block_tag_report
...
@@ -500,6 +502,142 @@ def merge_json_data(json_test_df, json_standard_df):
     return inner_merge, standard_exist, test_exist

+def consolidate_data(test_data, standard_data, key_path):
+    """
+    Consolidates data from test and standard datasets based on the provided key path.
+    :param test_data: Dictionary containing the test dataset.
+    :param standard_data: Dictionary containing the standard dataset.
+    :param key_path: List of keys leading to the desired data within the dictionaries.
+    :return: List containing all items from both test and standard data at the specified key path.
+    """
+    # Initialize empty lists to hold the consolidated data
+    overall_data_standard = []
+    overall_data_test = []
+
+    # Helper function to navigate through the dictionaries along the key path
+    def extract_data(source_data, keys):
+        for key in keys[:-1]:
+            source_data = source_data.get(key, {})
+        return source_data.get(keys[-1], [])
+
+    # Extract and extend the overall data lists with items from both datasets
+    for data in extract_data(standard_data, key_path):
+        # Each single_table_tags entry is assumed to already be a list; add its elements to the overall list directly
+        overall_data_standard.extend(data)
+    for data in extract_data(test_data, key_path):
+        overall_data_test.extend(data)
+
+    return overall_data_standard, overall_data_test
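A tiny usage sketch for consolidate_data (hypothetical data, not part of the commit): extract_data walks the key path into each dict, and the per-page lists are then flattened into one overall list per dataset.

    test = {"equations_bboxs": {"inline": [[[0, 0, 1, 1]], [[2, 2, 3, 3]]]}}  # two pages
    standard = {"equations_bboxs": {"inline": [[[0, 0, 1, 1]]]}}              # one page
    std_all, test_all = consolidate_data(test, standard, ["equations_bboxs", "inline"])
    assert std_all == [[0, 0, 1, 1]]
    assert test_all == [[0, 0, 1, 1], [2, 2, 3, 3]]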
+def overall_calculate_metrics(inner_merge, json_test, json_standard, standard_exist, test_exist):
+    process_data_standard = process_equations_and_blocks(json_standard, is_standard=True)
+    process_data_test = process_equations_and_blocks(json_test, is_standard=False)
+
+    overall_report = {}
+    overall_report['accuracy'] = metrics.accuracy_score(standard_exist, test_exist)
+    overall_report['precision'] = metrics.precision_score(standard_exist, test_exist)
+    overall_report['recall'] = metrics.recall_score(standard_exist, test_exist)
+    overall_report['f1_score'] = metrics.f1_score(standard_exist, test_exist)
+
+    test_para_text = np.asarray(process_data_test['para_texts'], dtype=object)[inner_merge['pass_label'] == 'yes']
+    standard_para_text = np.asarray(process_data_standard['para_texts'], dtype=object)[inner_merge['pass_label'] == 'yes']
+    ids_yes = inner_merge['id'][inner_merge['pass_label'] == 'yes'].tolist()
+
+    pdf_dis = {}
+    pdf_bleu = {}
+    # Compute edit distance and BLEU score for rows whose pass_label is 'yes'
+    for idx, (a, b, id) in enumerate(zip(test_para_text, standard_para_text, ids_yes)):
+        a1 = ''.join(a)
+        b1 = ''.join(b)
+        pdf_dis[id] = Levenshtein_Distance(a, b)
+        pdf_bleu[id] = sentence_bleu([a1], b1)
+    overall_report['pdf间的平均编辑距离'] = np.mean(list(pdf_dis.values()))
+    overall_report['pdf间的平均bleu'] = np.mean(list(pdf_bleu.values()))
+
+    # Consolidate inline equation bboxes and texts
+    overall_equations_bboxs_inline_standard, overall_equations_bboxs_inline_test = consolidate_data(process_data_test, process_data_standard, ["equations_bboxs", "inline"])
+    overall_equations_texts_inline_standard, overall_equations_texts_inline_test = consolidate_data(process_data_test, process_data_standard, ["equations_texts", "inline"])
+    # Consolidate interline equation bboxes and texts
+    overall_equations_bboxs_interline_standard, overall_equations_bboxs_interline_test = consolidate_data(process_data_test, process_data_standard, ["equations_bboxs", "interline"])
+    overall_equations_texts_interline_standard, overall_equations_texts_interline_test = consolidate_data(process_data_test, process_data_standard, ["equations_texts", "interline"])
+    # Consolidate dropped text, image and table bboxes, and dropped-text tags
+    overall_dropped_bboxs_text_standard, overall_dropped_bboxs_text_test = consolidate_data(process_data_test, process_data_standard, ["dropped_bboxs", "text"])
+    overall_dropped_tags_text_standard, overall_dropped_tags_text_test = consolidate_data(process_data_test, process_data_standard, ["dropped_tags", "text"])
+    overall_dropped_bboxs_image_standard, overall_dropped_bboxs_image_test = consolidate_data(process_data_test, process_data_standard, ["dropped_bboxs", "image"])
+    overall_dropped_bboxs_table_standard, overall_dropped_bboxs_table_test = consolidate_data(process_data_test, process_data_standard, ["dropped_bboxs", "table"])
+
+    para_nums_test = process_data_test['para_nums']
+    para_nums_standard = process_data_standard['para_nums']
+    overall_para_nums_standard = [item for sublist in para_nums_standard for item in (sublist if isinstance(sublist, list) else [sublist])]
+    overall_para_nums_test = [item for sublist in para_nums_test for item in (sublist if isinstance(sublist, list) else [sublist])]
+    test_para_num = np.array(overall_para_nums_test)
+    standard_para_num = np.array(overall_para_nums_standard)
+    acc_para = np.mean(test_para_num == standard_para_num)
+    overall_report['分段准确率'] = acc_para
+
+    # Inline-equation accuracy, edit distance and BLEU
+    overall_report['行内公式准确率'] = bbox_match_indicator_general(overall_equations_bboxs_inline_test, overall_equations_bboxs_inline_standard)
+    overall_report['行内公式编辑距离'], overall_report['行内公式bleu'] = equations_indicator(overall_equations_bboxs_inline_test, overall_equations_bboxs_inline_standard, overall_equations_texts_inline_test, overall_equations_texts_inline_standard)
+    # Interline-equation accuracy, edit distance and BLEU
+    overall_report['行间公式准确率'] = bbox_match_indicator_general(overall_equations_bboxs_interline_test, overall_equations_bboxs_interline_standard)
+    overall_report['行间公式编辑距离'], overall_report['行间公式bleu'] = equations_indicator(overall_equations_bboxs_interline_test, overall_equations_bboxs_interline_standard, overall_equations_texts_interline_test, overall_equations_texts_interline_standard)
+    # Dropped-text accuracy and dropped-text tag accuracy
+    overall_report['丢弃文本准确率'], overall_report['丢弃文本标签准确率'] = bbox_match_indicator_dropped_text_block(overall_dropped_bboxs_text_test, overall_dropped_bboxs_text_standard, overall_dropped_tags_text_standard, overall_dropped_tags_text_test)
+    # Dropped-image accuracy
+    overall_report['丢弃图片准确率'] = bbox_match_indicator_general(overall_dropped_bboxs_image_test, overall_dropped_bboxs_image_standard)
+    # Dropped-table accuracy
+    overall_report['丢弃表格准确率'] = bbox_match_indicator_general(overall_dropped_bboxs_table_test, overall_dropped_bboxs_table_standard)
+
+    return overall_report

 def calculate_metrics(inner_merge, json_test, json_standard, json_standard_origin):
...
@@ -602,21 +740,27 @@ def calculate_metrics(inner_merge, json_test, json_standard, json_standard_origi
     return result_dict

-def save_results(result_dict, output_path):
+def save_results(result_dict, overall_report_dict, badcase_path, overall_path):
     """
     Save the result dictionaries as JSON files at the given paths.

     Parameters:
     - result_dict: Dictionary containing the computed results.
-    - output_path: Path, including the file name, where the results are saved.
+    - overall_path: Path, including the file name, where the results are saved.
     """
     # Open the target file for writing
-    with open(output_path, 'w', encoding='utf-8') as f:
+    with open(badcase_path, 'w', encoding='utf-8') as f:
         # Serialize the result dictionary to JSON and write it to the file
         json.dump(result_dict, f, ensure_ascii=False, indent=4)
-    print(f"Results saved to file: {output_path}")
+    print(f"Results saved to file: {badcase_path}")
+    with open(overall_path, 'w', encoding='utf-8') as f:
+        # Serialize the overall report dictionary to JSON and write it to the file
+        json.dump(overall_report_dict, f, ensure_ascii=False, indent=4)
+    print(f"Results saved to file: {overall_path}")
 def upload_to_s3(file_path, bucket_name, s3_file_name, AWS_ACCESS_KEY, AWS_SECRET_KEY, END_POINT_URL):
     """
...
@@ -634,7 +778,7 @@ def upload_to_s3(file_path, bucket_name, s3_file_name,AWS_ACCESS_KEY,AWS_SECRET_
     except ClientError as e:
         print(f"Error while uploading the file: {e}")

-def generate_output_filename(base_path):
+def generate_filename(badcase_path, overall_path):
     """
     Generate output file names carrying the current timestamp.
...
@@ -647,13 +791,24 @@ def generate_output_filename(base_path):
     # Get the current time formatted as a string
     current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
     # Build and return the full output file names
-    return f"{base_path}_{current_time}.json"
+    return f"{badcase_path}_{current_time}.json", f"{overall_path}_{current_time}.json"
+def compare_edit_distance(json_file, overall_report):
+    with open(json_file, 'r', encoding='utf-8') as f:
+        json_data = json.load(f)
+    json_edit_distance = json_data['pdf间的平均编辑距离']
+    if overall_report['pdf间的平均编辑距离'] >= json_edit_distance:
+        return 0
+    else:
+        return 1
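compare_edit_distance returns 1 only when the run's average edit distance is strictly below the baseline stored in base_data.json, and main (below) asserts that the result is 1, so the benchmark job fails whenever the metric does not improve on the baseline.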
-def main(standard_file, test_file, zip_file, base_output_path, s3_bucket_name=None, s3_file_name=None, AWS_ACCESS_KEY=None, AWS_SECRET_KEY=None, END_POINT_URL=None):
+def main(standard_file, test_file, zip_file, badcase_path, overall_path, base_data_path, s3_bucket_name=None, s3_file_name=None, AWS_ACCESS_KEY=None, AWS_SECRET_KEY=None, END_POINT_URL=None):
     """
     Main function: runs the whole evaluation pipeline.
...
@@ -661,7 +816,8 @@ def main(standard_file, test_file, zip_file, base_output_path, s3_bucket_name=No
     - standard_file: Path to the standard file.
     - test_file: Path to the test file.
     - zip_file: Path to the zip archive.
-    - base_output_path: Base path and file-name prefix for the result files.
+    - badcase_path: Base path and file-name prefix for the badcase file.
+    - overall_path: Base path and file-name prefix for the overall file.
     - s3_bucket_name: S3 bucket name (optional).
     - s3_file_name: File name on S3 (optional).
     - AWS_ACCESS_KEY, AWS_SECRET_KEY, END_POINT_URL: AWS credentials and endpoint URL (optional).
...
@@ -675,21 +831,29 @@ def main(standard_file, test_file, zip_file, base_output_path, s3_bucket_name=No
     # Merge the JSON data
     inner_merge, standard_exist, test_exist = merge_json_data(json_test_origin, json_standard_origin)

+    # Compute the overall metrics
+    overall_report_dict = overall_calculate_metrics(inner_merge, inner_merge['test_mid_json'], inner_merge['standard_mid_json'], standard_exist, test_exist)
     # Compute the metrics
     result_dict = calculate_metrics(inner_merge, inner_merge['test_mid_json'], inner_merge['standard_mid_json'], json_standard_origin)

     # Generate timestamped output file names
-    output_file = generate_output_filename(base_output_path)
+    badcase_file, overall_file = generate_filename(badcase_path, overall_path)

     # Save the results to JSON files
-    save_results(result_dict, output_file)
+    save_results(result_dict, overall_report_dict, badcase_file, overall_file)
+
+    result = compare_edit_distance(base_data_path, overall_report_dict)
+    print(result)
+    assert result == 1
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Main function: runs the whole evaluation pipeline.")
     parser.add_argument('standard_file', type=str, help='Path to the standard file.')
     parser.add_argument('test_file', type=str, help='Path to the test file.')
     parser.add_argument('zip_file', type=str, help='Path to the zip archive.')
-    parser.add_argument('base_output_path', type=str, help='Base path and file-name prefix for the result files.')
+    parser.add_argument('badcase_path', type=str, help='Base path and file-name prefix for the badcase file.')
+    parser.add_argument('overall_path', type=str, help='Base path and file-name prefix for the overall file.')
+    parser.add_argument('base_data_path', type=str, help='Base path and file-name prefix for the baseline file.')
     parser.add_argument('--s3_bucket_name', type=str, help='S3 bucket name.', default=None)
     parser.add_argument('--s3_file_name', type=str, help='File name on S3.', default=None)
     parser.add_argument('--AWS_ACCESS_KEY', type=str, help='AWS access key.', default=None)
...
@@ -698,5 +862,5 @@ if __name__ == "__main__":
     args = parser.parse_args()
-    main(args.standard_file, args.test_file, args.zip_file, args.base_output_path, args.s3_bucket_name, args.s3_file_name, args.AWS_ACCESS_KEY, args.AWS_SECRET_KEY, args.END_POINT_URL)
+    main(args.standard_file, args.test_file, args.zip_file, args.badcase_path, args.overall_path, args.base_data_path, args.s3_bucket_name, args.s3_file_name, args.AWS_ACCESS_KEY, args.AWS_SECRET_KEY, args.END_POINT_URL)