Commit b66dda38 authored by Shuimo's avatar Shuimo

Improved script to read json in compressed packages

parent 015e2bdd
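The change reads both JSON-lines files directly from a ZIP archive instead of from loose files on disk. Below is a minimal sketch of that pattern, assuming a placeholder archive name and member path (the real paths come from the script's command-line arguments):

import json
import zipfile
from io import TextIOWrapper

# Placeholder archive and member names for illustration only.
zip_path = "dataset.zip"
member = "standard.json"

with zipfile.ZipFile(zip_path, "r") as z:
    # z.open() returns a binary stream; TextIOWrapper decodes it as UTF-8 text,
    # so each line can be parsed as one JSON object (JSON-lines format).
    with z.open(member) as raw:
        records = [json.loads(line) for line in TextIOWrapper(raw, encoding="utf-8")]

print(f"loaded {len(records)} records")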
@@ -9,7 +9,8 @@ from sklearn import metrics
from datetime import datetime
import boto3
from botocore.exceptions import NoCredentialsError, ClientError
from io import TextIOWrapper
import zipfile
@@ -429,28 +430,46 @@ def handle_multi_deletion(test_page, test_page_tag, test_page_bbox, standard_pag
def check_files_exist(standard_file, test_file):
def check_json_files_in_zip_exist(zip_file_path, standard_json_path_in_zip, test_json_path_in_zip):
    """
    Check whether the files exist
    Check whether the specified JSON files exist in the ZIP archive
    """
    if not os.path.isfile(standard_file) or not os.path.isfile(test_file):
        raise FileNotFoundError("One or both of the required JSON files are missing.")
    with zipfile.ZipFile(zip_file_path, 'r') as z:
        # Get the list of all files in the ZIP archive
        all_files_in_zip = z.namelist()
        # Check that both the standard file and the test file are present in the ZIP archive
        if standard_json_path_in_zip not in all_files_in_zip or test_json_path_in_zip not in all_files_in_zip:
            raise FileNotFoundError("One or both of the required JSON files are missing from the ZIP archive.")
def read_json_files(standard_file, test_file):
def read_json_files_from_streams(standard_file_stream, test_file_stream):
    """
    Read the contents of the JSON files
    Read the JSON content from file streams
    """
    with open(standard_file, 'r', encoding='utf-8') as sf:
        pdf_json_standard = [json.loads(line) for line in sf]
    with open(test_file, 'r', encoding='utf-8') as tf:
        pdf_json_test = [json.loads(line) for line in tf]
    pdf_json_standard = [json.loads(line) for line in standard_file_stream]
    pdf_json_test = [json.loads(line) for line in test_file_stream]
    json_standard_origin = pd.DataFrame(pdf_json_standard)
    json_test = pd.DataFrame(pdf_json_test)
    json_test_origin = pd.DataFrame(pdf_json_test)
    return json_standard_origin, json_test_origin
def read_json_files_from_zip(zip_file_path, standard_json_path_in_zip, test_json_path_in_zip):
    """
    Read the two JSON files from the ZIP archive and return them as DataFrames
    """
    with zipfile.ZipFile(zip_file_path, 'r') as z:
        with z.open(standard_json_path_in_zip) as standard_file_stream, \
             z.open(test_json_path_in_zip) as test_file_stream:
            standard_file_text_stream = TextIOWrapper(standard_file_stream, encoding='utf-8')
            test_file_text_stream = TextIOWrapper(test_file_stream, encoding='utf-8')
            json_standard_origin, json_test_origin = read_json_files_from_streams(
                standard_file_text_stream, test_file_text_stream
            )
    return json_standard_origin, json_test
    return json_standard_origin, json_test_origin
def merge_json_data(json_test_df, json_standard_df):
@@ -634,23 +653,24 @@ def generate_output_filename(base_path):
def main(standard_file, test_file, base_output_path, s3_bucket_name=None, s3_file_name=None, AWS_ACCESS_KEY=None, AWS_SECRET_KEY=None, END_POINT_URL=None):
def main(standard_file, test_file, zip_file, base_output_path, s3_bucket_name=None, s3_file_name=None, AWS_ACCESS_KEY=None, AWS_SECRET_KEY=None, END_POINT_URL=None):
    """
    Main function that runs the whole evaluation pipeline.
    Parameters:
    - standard_file: Path to the standard file.
    - test_file: Path to the test file.
    - zip_file: Path to the ZIP archive.
    - base_output_path: Base path and filename prefix for the result files.
    - s3_bucket_name: S3 bucket name (optional).
    - s3_file_name: File name on S3 (optional).
    - AWS_ACCESS_KEY, AWS_SECRET_KEY, END_POINT_URL: AWS access credentials and endpoint URL (optional).
    """
    # Check that the files exist
    check_files_exist(standard_file, test_file)
    check_json_files_in_zip_exist(zip_file, standard_file, test_file)
    # Read the JSON file contents
    json_standard_origin, json_test_origin = read_json_files(standard_file, test_file)
    json_standard_origin, json_test_origin = read_json_files_from_zip(zip_file, standard_file, test_file)
    # Merge the JSON data
    inner_merge, standard_exist, test_exist = merge_json_data(json_test_origin, json_standard_origin)
@@ -668,6 +688,7 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Main function that runs the whole evaluation pipeline.")
    parser.add_argument('standard_file', type=str, help='Path to the standard file.')
    parser.add_argument('test_file', type=str, help='Path to the test file.')
    parser.add_argument('zip_file', type=str, help='Path to the ZIP archive.')
    parser.add_argument('base_output_path', type=str, help='Base path and filename prefix for the result files.')
    parser.add_argument('--s3_bucket_name', type=str, help='S3 bucket name.', default=None)
    parser.add_argument('--s3_file_name', type=str, help='File name on S3.', default=None)
@@ -677,5 +698,5 @@ if __name__ == "__main__":
    args = parser.parse_args()
    main(args.standard_file, args.test_file, args.base_output_path, args.s3_bucket_name, args.s3_file_name, args.AWS_ACCESS_KEY, args.AWS_SECRET_KEY, args.END_POINT_URL)
    main(args.standard_file, args.test_file, args.zip_file, args.base_output_path, args.s3_bucket_name, args.s3_file_name, args.AWS_ACCESS_KEY, args.AWS_SECRET_KEY, args.END_POINT_URL)
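With the new positional zip_file argument, standard_file and test_file are interpreted as member paths inside the archive rather than as files on disk. A hypothetical invocation follows; the script name and file paths are placeholders, while the positional order and option flags match the parser shown above:

import subprocess
import sys

# Hypothetical script and file names; positional order follows the parser above:
# standard_file, test_file, zip_file, base_output_path.
subprocess.run([
    sys.executable, "evaluate.py",    # placeholder script name
    "standard/ocr_dataset.json",      # standard_file: member path inside the ZIP
    "test/ocr_dataset.json",          # test_file: member path inside the ZIP
    "dataset.zip",                    # zip_file: the archive on disk
    "./output/result",                # base_output_path: output path and filename prefix
    "--s3_bucket_name", "my-bucket",  # optional S3 upload settings
    "--s3_file_name", "result.json",
], check=False)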