Unverified commit b75ee676 authored by drunkpig, committed by GitHub

Merge pull request #11 from magicpdf/dev-xm

fix logic
parents b4fb6a68 4b87a571
import json
import os
import sys
from pathlib import Path
......@@ -6,8 +7,8 @@ import click
from loguru import logger
from magic_pdf.libs.commons import join_path, read_file
from magic_pdf.dict2md.mkcontent import mk_mm_markdown
from magic_pdf.pipeline import parse_pdf_by_model
from magic_pdf.dict2md.mkcontent import mk_mm_markdown, mk_universal_format
from magic_pdf.pdf_parse_by_txt import parse_pdf_by_txt
......@@ -24,7 +25,7 @@ def main(s3_pdf_path: str, s3_pdf_profile: str, pdf_model_path: str, pdf_model_p
pdf_bytes = read_file(s3_pdf_path, s3_pdf_profile)
try:
paras_dict = parse_pdf_by_model(
paras_dict = parse_pdf_by_txt(
pdf_bytes, pdf_model_path, save_path, book_name, pdf_model_profile, start_page_num, debug_mode=debug_mode
)
parent_dir = os.path.dirname(text_content_save_path)
......@@ -32,7 +33,8 @@ def main(s3_pdf_path: str, s3_pdf_profile: str, pdf_model_path: str, pdf_model_p
os.makedirs(parent_dir)
if not paras_dict.get('need_drop'):
markdown_content = mk_mm_markdown(paras_dict)
content_list = mk_universal_format(paras_dict)
markdown_content = mk_mm_markdown(content_list)
else:
markdown_content = paras_dict['drop_reason']
......@@ -70,8 +72,8 @@ def main_shell(pdf_file_path: str, save_path: str):
@click.command()
@click.option("--pdf-dir", help="s3上pdf文件的路径")
@click.option("--model-dir", help="s3上pdf文件的路径")
@click.option("--pdf-dir", help="本地pdf文件的路径")
@click.option("--model-dir", help="本地模型文件的路径")
@click.option("--start-page-num", default=0, help="从第几页开始解析")
def main_shell2(pdf_dir: str, model_dir: str,start_page_num: int):
# 先扫描所有的pdf目录里的文件名字
......@@ -86,8 +88,10 @@ def main_shell2(pdf_dir: str, model_dir: str,start_page_num: int):
for pdf_file in pdf_file_names:
pdf_file_path = os.path.join(pdf_dir, pdf_file)
model_file_path = os.path.join(model_dir, pdf_file)
main(pdf_file_path, None, model_file_path, None, start_page_num)
model_file_path = os.path.join(model_dir, pdf_file).rstrip(".pdf") + ".json"
with open(model_file_path, "r") as json_file:
model_list = json.load(json_file)
main(pdf_file_path, None, model_list, None, start_page_num)
......
from magic_pdf.libs.commons import s3_image_save_path, join_path
from magic_pdf.libs.language import detect_lang
from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
from magic_pdf.libs.ocr_content_type import ContentType
import wordninja
......@@ -72,7 +73,7 @@ def ocr_mk_mm_markdown_with_para(pdf_info_dict: dict):
markdown = []
for _, page_info in pdf_info_dict.items():
paras_of_layout = page_info.get("para_blocks")
page_markdown = ocr_mk_mm_markdown_with_para_core(paras_of_layout, "mm")
page_markdown = ocr_mk_markdown_with_para_core(paras_of_layout, "mm")
markdown.extend(page_markdown)
return '\n\n'.join(markdown)
......@@ -81,7 +82,7 @@ def ocr_mk_nlp_markdown_with_para(pdf_info_dict: dict):
markdown = []
for _, page_info in pdf_info_dict.items():
paras_of_layout = page_info.get("para_blocks")
page_markdown = ocr_mk_mm_markdown_with_para_core(paras_of_layout, "nlp")
page_markdown = ocr_mk_markdown_with_para_core(paras_of_layout, "nlp")
markdown.extend(page_markdown)
return '\n\n'.join(markdown)
......@@ -91,7 +92,7 @@ def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: dict):
paras_of_layout = page_info.get("para_blocks")
if not paras_of_layout:
continue
page_markdown = ocr_mk_mm_markdown_with_para_core(paras_of_layout, "mm")
page_markdown = ocr_mk_markdown_with_para_core(paras_of_layout, "mm")
markdown_with_para_and_pagination.append({
'page_no': page_no,
'md_content': '\n\n'.join(page_markdown)
......@@ -99,7 +100,7 @@ def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: dict):
return markdown_with_para_and_pagination
def ocr_mk_mm_markdown_with_para_core(paras_of_layout, mode):
def ocr_mk_markdown_with_para_core(paras_of_layout, mode):
page_markdown = []
for paras in paras_of_layout:
for para in paras:
......@@ -108,19 +109,28 @@ def ocr_mk_mm_markdown_with_para_core(paras_of_layout, mode):
for span in line['spans']:
span_type = span.get('type')
content = ''
language = ''
if span_type == ContentType.Text:
content = ocr_escape_special_markdown_char(split_long_words(span['content']))
content = span['content']
language = detect_lang(content)
if language == 'en': # 只对英文长词进行分词处理,中文分词会丢失文本
content = ocr_escape_special_markdown_char(split_long_words(content))
else:
content = ocr_escape_special_markdown_char(content)
elif span_type == ContentType.InlineEquation:
content = f"${ocr_escape_special_markdown_char(span['content'])}$"
content = f"${span['content']}$"
elif span_type == ContentType.InterlineEquation:
content = f"\n$$\n{ocr_escape_special_markdown_char(span['content'])}\n$$\n"
content = f"\n$$\n{span['content']}\n$$\n"
elif span_type in [ContentType.Image, ContentType.Table]:
if mode == 'mm':
content = f"\n![]({join_path(s3_image_save_path, span['image_path'])})\n"
elif mode == 'nlp':
pass
if content != '':
para_text += content + ' '
if language == 'en': # 英文语境下 content间需要空格分隔
para_text += content + ' '
else: # 中文语境下,content间不需要空格分隔
para_text += content
if para_text.strip() == '':
continue
else:
......@@ -137,13 +147,23 @@ def para_to_standard_format(para):
inline_equation_num = 0
for line in para:
for span in line['spans']:
language = ''
span_type = span.get('type')
if span_type == ContentType.Text:
content = ocr_escape_special_markdown_char(split_long_words(span['content']))
content = span['content']
language = detect_lang(content)
if language == 'en': # 只对英文长词进行分词处理,中文分词会丢失文本
content = ocr_escape_special_markdown_char(split_long_words(content))
else:
content = ocr_escape_special_markdown_char(content)
elif span_type == ContentType.InlineEquation:
content = f"${ocr_escape_special_markdown_char(span['content'])}$"
content = f"${span['content']}$"
inline_equation_num += 1
para_text += content + ' '
if language == 'en': # 英文语境下 content间需要空格分隔
para_text += content + ' '
else: # 中文语境下,content间不需要空格分隔
para_text += content
para_content = {
'type': 'text',
'text': para_text,
......@@ -186,14 +206,14 @@ def line_to_standard_format(line):
return content
else:
if span['type'] == ContentType.InterlineEquation:
interline_equation = ocr_escape_special_markdown_char(span['content']) # 转义特殊符号
interline_equation = span['content']
content = {
'type': 'equation',
'latex': f"$$\n{interline_equation}\n$$"
}
return content
elif span['type'] == ContentType.InlineEquation:
inline_equation = ocr_escape_special_markdown_char(span['content']) # 转义特殊符号
inline_equation = span['content']
line_text += f"${inline_equation}$"
inline_equation_num += 1
elif span['type'] == ContentType.Text:
......
"""
根据bucket的名字返回对应的s3 AK, SK,endpoint三元组
"""
import json
import os
from loguru import logger
def get_s3_config(bucket_name: str):
    """Return the (access_key, secret_key, endpoint) triple for *bucket_name*.

    Credentials are read from ``~/magic-pdf.json``, which must contain a
    ``bucket_info`` mapping of ``bucket_name -> [ak, sk, endpoint]``.

    Raises:
        Exception: if the config file is missing, the bucket is not listed,
            or any of the three credential fields is ``None``.
    """
    home_dir = os.path.expanduser("~")
    config_file = os.path.join(home_dir, "magic-pdf.json")
    if not os.path.exists(config_file):
        raise Exception("magic-pdf.json not found")
    with open(config_file, "r") as f:
        config = json.load(f)
    bucket_info = config.get("bucket_info")
    if bucket_name not in bucket_info:
        raise Exception("bucket_name not found in magic-pdf.json")
    access_key, secret_key, storage_endpoint = bucket_info[bucket_name]
    if access_key is None or secret_key is None or storage_endpoint is None:
        raise Exception("ak, sk or endpoint not found in magic-pdf.json")
    # NOTE: do not log the credentials themselves (previously a commented-out
    # logger.info leaked ak/sk to logs).
    return access_key, secret_key, storage_endpoint
if __name__ == '__main__':
    # Manual smoke test: resolve credentials for the "llm-raw" bucket.
    # Requires a populated ~/magic-pdf.json; raises otherwise.
    ak, sk, endpoint = get_s3_config("llm-raw")
......@@ -70,7 +70,7 @@ paraMergeException_msg = ParaMergeException().message
def parse_pdf_by_model(
def parse_pdf_by_txt(
pdf_bytes,
pdf_model_output,
save_path,
......
......@@ -13,7 +13,7 @@ from magic_pdf.libs.commons import (
from magic_pdf.libs.drop_reason import DropReason
from magic_pdf.libs.json_compressor import JsonCompressor
from magic_pdf.dict2md.mkcontent import mk_universal_format
from magic_pdf.pdf_parse_by_model import parse_pdf_by_model
from magic_pdf.pdf_parse_by_txt import parse_pdf_by_txt
from magic_pdf.filter.pdf_classify_by_type import classify
from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
from loguru import logger
......@@ -130,6 +130,7 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict:
classify_time = int(time.time() - start_time) # 计算执行时间
if is_text_pdf:
pdf_meta["is_text_pdf"] = is_text_pdf
jso["_pdf_type"] = "TXT"
jso["pdf_meta"] = pdf_meta
jso["classify_time"] = classify_time
# print(json.dumps(pdf_meta, ensure_ascii=False))
......@@ -144,10 +145,11 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict:
else:
# 先不drop
pdf_meta["is_text_pdf"] = is_text_pdf
jso["_pdf_type"] = "OCR"
jso["pdf_meta"] = pdf_meta
jso["classify_time"] = classify_time
jso["need_drop"] = True
jso["drop_reason"] = DropReason.NOT_IS_TEXT_PDF
# jso["need_drop"] = True
# jso["drop_reason"] = DropReason.NOT_IS_TEXT_PDF
extra_info = {"classify_rules": []}
for condition, result in results.items():
if not result:
......@@ -310,7 +312,7 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
f"book_name is:{book_name},start_time is:{formatted_time(start_time)}",
file=sys.stderr,
)
pdf_info_dict = parse_pdf_by_model(
pdf_info_dict = parse_pdf_by_txt(
pdf_bytes,
model_output_json_list,
save_path,
......
from loguru import logger
import json
import os
from magic_pdf.config import s3_buckets, s3_clusters, s3_users
def get_bucket_configs_dict(buckets, clusters, users):
    """Build a ``{bucket_name: [ak, sk, endpoint]}`` map from raw s3 config tables.

    Args:
        buckets: mapping of bucket_name -> (cluster_name, user_name).
        clusters: mapping of cluster_name -> endpoint groups; each cluster is
            expected to expose an "outside" list of endpoints.
        users: mapping of user_name -> {"ak": ..., "sk": ...}.

    Returns:
        dict mapping each bucket name to ``[access_key, secret_key, endpoint]``,
        using the first "outside" endpoint of the bucket's cluster.
    """
    bucket_configs = {}
    # Idiomatic dict iteration with tuple unpacking instead of indexing
    # the (key, value) pair manually.
    for bucket_name, (cluster, user) in buckets.items():
        # Only the external ("outside") endpoints are usable here; take the first.
        endpoint = clusters[cluster]["outside"][0]
        user_config = users[user]
        bucket_configs[bucket_name] = [user_config["ak"], user_config["sk"], endpoint]
    return bucket_configs
def write_json_to_home(my_dict):
    """Serialize *my_dict* as pretty-printed JSON into ``~/magic-pdf.json``."""
    target = os.path.join(os.path.expanduser("~"), "magic-pdf.json")
    serialized = json.dumps(my_dict, indent=4, ensure_ascii=False)
    with open(target, "w") as fh:
        fh.write(serialized)
    # Tell the operator where the config landed.
    print(f"Dictionary converted to JSON and saved to {target}")
if __name__ == '__main__':
    # Build the full bucket -> [ak, sk, endpoint] map from the static project
    # config tables and persist it to ~/magic-pdf.json, the file that
    # get_s3_config() reads at runtime.
    bucket_configs_dict = get_bucket_configs_dict(s3_buckets, s3_clusters, s3_users)
    logger.info(bucket_configs_dict)
    config_dict = {
        "bucket_info": bucket_configs_dict,
        "temp-output-dir": "/tmp"
    }
    write_json_to_home(config_dict)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.