Unverified Commit b75ee676 authored by drunkpig, committed by GitHub

Merge pull request #11 from magicpdf/dev-xm

fix logic
parents b4fb6a68 4b87a571
+import json
 import os
 import sys
 from pathlib import Path
@@ -6,8 +7,8 @@ import click
 from loguru import logger
 from magic_pdf.libs.commons import join_path, read_file
-from magic_pdf.dict2md.mkcontent import mk_mm_markdown
+from magic_pdf.dict2md.mkcontent import mk_mm_markdown, mk_universal_format
-from magic_pdf.pipeline import parse_pdf_by_model
+from magic_pdf.pdf_parse_by_txt import parse_pdf_by_txt
@@ -24,7 +25,7 @@ def main(s3_pdf_path: str, s3_pdf_profile: str, pdf_model_path: str, pdf_model_p
     pdf_bytes = read_file(s3_pdf_path, s3_pdf_profile)
     try:
-        paras_dict = parse_pdf_by_model(
+        paras_dict = parse_pdf_by_txt(
            pdf_bytes, pdf_model_path, save_path, book_name, pdf_model_profile, start_page_num, debug_mode=debug_mode
        )
        parent_dir = os.path.dirname(text_content_save_path)
@@ -32,7 +33,8 @@ def main(s3_pdf_path: str, s3_pdf_profile: str, pdf_model_path: str, pdf_model_p
            os.makedirs(parent_dir)
        if not paras_dict.get('need_drop'):
-            markdown_content = mk_mm_markdown(paras_dict)
+            content_list = mk_universal_format(paras_dict)
+            markdown_content = mk_mm_markdown(content_list)
        else:
            markdown_content = paras_dict['drop_reason']
@@ -70,8 +72,8 @@ def main_shell(pdf_file_path: str, save_path: str):
 @click.command()
-@click.option("--pdf-dir", help="path of the pdf files on s3")
+@click.option("--pdf-dir", help="path of the local pdf files")
-@click.option("--model-dir", help="path of the pdf files on s3")
+@click.option("--model-dir", help="path of the local model files")
 @click.option("--start-page-num", default=0, help="page number to start parsing from")
 def main_shell2(pdf_dir: str, model_dir: str, start_page_num: int):
     # first, scan the pdf directory for all file names
@@ -86,8 +88,10 @@ def main_shell2(pdf_dir: str, model_dir: str, start_page_num: int):
     for pdf_file in pdf_file_names:
         pdf_file_path = os.path.join(pdf_dir, pdf_file)
-        model_file_path = os.path.join(model_dir, pdf_file)
+        model_file_path = os.path.join(model_dir, pdf_file).rstrip(".pdf") + ".json"
-        main(pdf_file_path, None, model_file_path, None, start_page_num)
+        with open(model_file_path, "r") as json_file:
+            model_list = json.load(json_file)
+        main(pdf_file_path, None, model_list, None, start_page_num)
...
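Two things are worth noting in this hunk. First, markdown generation is now a two-step conversion: mk_universal_format turns the parsed dict into a universal content list, and mk_mm_markdown renders that list. Second, the new model-path derivation relies on str.rstrip(".pdf"), which strips a trailing character set rather than a suffix, so any stem ending in '.', 'p', 'd' or 'f' gets over-trimmed (e.g. "map.pdf" becomes "ma.json"). A minimal suffix-safe sketch, not part of the commit (function name and paths are illustrative):

    import os

    def model_json_path(model_dir: str, pdf_file: str) -> str:
        # Swap the ".pdf" extension for ".json" without rstrip's
        # character-set semantics, which would also eat stems that
        # end in '.', 'p', 'd' or 'f'.
        stem, _ext = os.path.splitext(pdf_file)
        return os.path.join(model_dir, stem + ".json")

    assert model_json_path("models", "map.pdf") == os.path.join("models", "map.json")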
 from magic_pdf.libs.commons import s3_image_save_path, join_path
+from magic_pdf.libs.language import detect_lang
 from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
 from magic_pdf.libs.ocr_content_type import ContentType
 import wordninja
@@ -72,7 +73,7 @@ def ocr_mk_mm_markdown_with_para(pdf_info_dict: dict):
     markdown = []
     for _, page_info in pdf_info_dict.items():
         paras_of_layout = page_info.get("para_blocks")
-        page_markdown = ocr_mk_mm_markdown_with_para_core(paras_of_layout, "mm")
+        page_markdown = ocr_mk_markdown_with_para_core(paras_of_layout, "mm")
         markdown.extend(page_markdown)
     return '\n\n'.join(markdown)
@@ -81,7 +82,7 @@ def ocr_mk_nlp_markdown_with_para(pdf_info_dict: dict):
     markdown = []
     for _, page_info in pdf_info_dict.items():
         paras_of_layout = page_info.get("para_blocks")
-        page_markdown = ocr_mk_mm_markdown_with_para_core(paras_of_layout, "nlp")
+        page_markdown = ocr_mk_markdown_with_para_core(paras_of_layout, "nlp")
         markdown.extend(page_markdown)
     return '\n\n'.join(markdown)
@@ -91,7 +92,7 @@ def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: dict):
         paras_of_layout = page_info.get("para_blocks")
         if not paras_of_layout:
             continue
-        page_markdown = ocr_mk_mm_markdown_with_para_core(paras_of_layout, "mm")
+        page_markdown = ocr_mk_markdown_with_para_core(paras_of_layout, "mm")
         markdown_with_para_and_pagination.append({
             'page_no': page_no,
             'md_content': '\n\n'.join(page_markdown)
@@ -99,7 +100,7 @@ def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: dict):
     return markdown_with_para_and_pagination
-def ocr_mk_mm_markdown_with_para_core(paras_of_layout, mode):
+def ocr_mk_markdown_with_para_core(paras_of_layout, mode):
     page_markdown = []
     for paras in paras_of_layout:
         for para in paras:
@@ -108,19 +109,28 @@ def ocr_mk_markdown_with_para_core(paras_of_layout, mode):
                for span in line['spans']:
                    span_type = span.get('type')
                    content = ''
+                   language = ''
                    if span_type == ContentType.Text:
-                       content = ocr_escape_special_markdown_char(split_long_words(span['content']))
+                       content = span['content']
+                       language = detect_lang(content)
+                       if language == 'en':  # only split long English words; word-splitting Chinese would lose text
+                           content = ocr_escape_special_markdown_char(split_long_words(content))
+                       else:
+                           content = ocr_escape_special_markdown_char(content)
                    elif span_type == ContentType.InlineEquation:
-                       content = f"${ocr_escape_special_markdown_char(span['content'])}$"
+                       content = f"${span['content']}$"
                    elif span_type == ContentType.InterlineEquation:
-                       content = f"\n$$\n{ocr_escape_special_markdown_char(span['content'])}\n$$\n"
+                       content = f"\n$$\n{span['content']}\n$$\n"
                    elif span_type in [ContentType.Image, ContentType.Table]:
                        if mode == 'mm':
                            content = f"\n![]({join_path(s3_image_save_path, span['image_path'])})\n"
                        elif mode == 'nlp':
                            pass
                    if content != '':
-                       para_text += content + ' '
+                       if language == 'en':  # in an English context, contents need space separation
+                           para_text += content + ' '
+                       else:  # in a Chinese context, no space is needed between contents
+                           para_text += content
            if para_text.strip() == '':
                continue
            else:
@@ -137,13 +147,23 @@ def para_to_standard_format(para):
     inline_equation_num = 0
     for line in para:
         for span in line['spans']:
+            language = ''
             span_type = span.get('type')
             if span_type == ContentType.Text:
-                content = ocr_escape_special_markdown_char(split_long_words(span['content']))
+                content = span['content']
+                language = detect_lang(content)
+                if language == 'en':  # only split long English words; word-splitting Chinese would lose text
+                    content = ocr_escape_special_markdown_char(split_long_words(content))
+                else:
+                    content = ocr_escape_special_markdown_char(content)
             elif span_type == ContentType.InlineEquation:
-                content = f"${ocr_escape_special_markdown_char(span['content'])}$"
+                content = f"${span['content']}$"
                 inline_equation_num += 1
-            para_text += content + ' '
+            if language == 'en':  # in an English context, contents need space separation
+                para_text += content + ' '
+            else:  # in a Chinese context, no space is needed between contents
+                para_text += content
     para_content = {
         'type': 'text',
         'text': para_text,
@@ -186,14 +206,14 @@ def line_to_standard_format(line):
            return content
        else:
            if span['type'] == ContentType.InterlineEquation:
-               interline_equation = ocr_escape_special_markdown_char(span['content'])  # escape special characters
+               interline_equation = span['content']
                content = {
                    'type': 'equation',
                    'latex': f"$$\n{interline_equation}\n$$"
                }
                return content
            elif span['type'] == ContentType.InlineEquation:
-               inline_equation = ocr_escape_special_markdown_char(span['content'])  # escape special characters
+               inline_equation = span['content']
                line_text += f"${inline_equation}$"
                inline_equation_num += 1
            elif span['type'] == ContentType.Text:
...
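The intent of the detect_lang branches above is easier to see in isolation: OCR output often glues English words together, so long English tokens are re-split (magic_pdf's split_long_words builds on wordninja), while Chinese text passes through untouched and its spans are concatenated without separators. A rough stand-in sketch of both behaviors (max_word_len and the helper bodies are assumptions, not the library's exact code):

    import wordninja

    def split_long_words(text: str, max_word_len: int = 15) -> str:
        # Re-split overlong tokens that OCR glued together,
        # e.g. "thisisglued" -> "this is glued".
        out = []
        for word in text.split():
            out.extend(wordninja.split(word) if len(word) > max_word_len else [word])
        return ' '.join(out)

    def join_contents(contents, language):
        # English spans read better space-separated; Chinese spans
        # must be concatenated directly or spurious gaps appear.
        sep = ' ' if language == 'en' else ''
        return sep.join(contents)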
""" """
根据bucket的名字返回对应的s3 AK, SK,endpoint三元组 根据bucket的名字返回对应的s3 AK, SK,endpoint三元组
""" """
import json
import os
from loguru import logger
def get_s3_config(bucket_name: str): def get_s3_config(bucket_name: str):
""" """
~/magic-pdf.json 读出来 ~/magic-pdf.json 读出来
""" """
ak , sk, endpoint = "", "", ""
# TODO 请实现这个函数 home_dir = os.path.expanduser("~")
return ak, sk, endpoint
config_file = os.path.join(home_dir, "magic-pdf.json")
if not os.path.exists(config_file):
raise Exception("magic-pdf.json not found")
with open(config_file, "r") as f:
config = json.load(f)
bucket_info = config.get("bucket_info")
if bucket_name not in bucket_info:
raise Exception("bucket_name not found in magic-pdf.json")
access_key, secret_key, storage_endpoint = bucket_info[bucket_name]
if access_key is None or secret_key is None or storage_endpoint is None:
raise Exception("ak, sk or endpoint not found in magic-pdf.json")
# logger.info(f"get_s3_config: ak={access_key}, sk={secret_key}, endpoint={storage_endpoint}")
return access_key, secret_key, storage_endpoint
if __name__ == '__main__':
ak, sk, endpoint = get_s3_config("llm-raw")
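For reference, the shape of ~/magic-pdf.json this implementation expects, matching what the config-writer script at the bottom of this commit emits (credential values are placeholders):

    {
        "bucket_info": {
            "llm-raw": ["<access_key>", "<secret_key>", "<endpoint>"]
        },
        "temp-output-dir": "/tmp"
    }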
@@ -70,7 +70,7 @@ paraMergeException_msg = ParaMergeException().message
-def parse_pdf_by_model(
+def parse_pdf_by_txt(
     pdf_bytes,
     pdf_model_output,
     save_path,
...
@@ -13,7 +13,7 @@ from magic_pdf.libs.commons import (
 from magic_pdf.libs.drop_reason import DropReason
 from magic_pdf.libs.json_compressor import JsonCompressor
 from magic_pdf.dict2md.mkcontent import mk_universal_format
-from magic_pdf.pdf_parse_by_model import parse_pdf_by_model
+from magic_pdf.pdf_parse_by_txt import parse_pdf_by_txt
 from magic_pdf.filter.pdf_classify_by_type import classify
 from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
 from loguru import logger
@@ -130,6 +130,7 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict:
     classify_time = int(time.time() - start_time)  # measure execution time
     if is_text_pdf:
         pdf_meta["is_text_pdf"] = is_text_pdf
+        jso["_pdf_type"] = "TXT"
         jso["pdf_meta"] = pdf_meta
         jso["classify_time"] = classify_time
         # print(json.dumps(pdf_meta, ensure_ascii=False))
@@ -144,10 +145,11 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict:
     else:
         # do not drop it for now
         pdf_meta["is_text_pdf"] = is_text_pdf
+        jso["_pdf_type"] = "OCR"
         jso["pdf_meta"] = pdf_meta
         jso["classify_time"] = classify_time
-        jso["need_drop"] = True
-        jso["drop_reason"] = DropReason.NOT_IS_TEXT_PDF
+        # jso["need_drop"] = True
+        # jso["drop_reason"] = DropReason.NOT_IS_TEXT_PDF
         extra_info = {"classify_rules": []}
         for condition, result in results.items():
             if not result:
@@ -310,7 +312,7 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
         f"book_name is:{book_name},start_time is:{formatted_time(start_time)}",
         file=sys.stderr,
     )
-    pdf_info_dict = parse_pdf_by_txt(
+    pdf_info_dict = parse_pdf_by_txt(
         pdf_bytes,
         model_output_json_list,
         save_path,
...
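With need_drop commented out, non-text PDFs now continue through the pipeline tagged "OCR" instead of being dropped, and the new _pdf_type field records which route each record was assigned. Roughly, a classified record now carries (values illustrative, not from the commit):

    # after classify_by_type, a text PDF record looks like:
    jso = {
        "_pdf_type": "TXT",                  # "OCR" for scanned PDFs
        "pdf_meta": {"is_text_pdf": True},   # plus the other scan metadata
        "classify_time": 2,                  # seconds spent classifying
        # "need_drop"/"drop_reason" are no longer set for OCR PDFs
    }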
+from loguru import logger
+import json
+import os
+
+from magic_pdf.config import s3_buckets, s3_clusters, s3_users
+
+def get_bucket_configs_dict(buckets, clusters, users):
+    bucket_configs = {}
+    for s3_bucket in buckets.items():
+        bucket_name = s3_bucket[0]
+        bucket_config = s3_bucket[1]
+        cluster, user = bucket_config
+        cluster_config = clusters[cluster]
+        endpoint_key = "outside"
+        endpoints = cluster_config[endpoint_key]
+        endpoint = endpoints[0]
+        user_config = users[user]
+        # logger.info(bucket_name)
+        # logger.info(endpoint)
+        # logger.info(user_config)
+        bucket_config = [user_config["ak"], user_config["sk"], endpoint]
+        bucket_configs[bucket_name] = bucket_config
+    return bucket_configs
+
+def write_json_to_home(my_dict):
+    # Convert the dictionary to JSON
+    json_data = json.dumps(my_dict, indent=4, ensure_ascii=False)
+    home_dir = os.path.expanduser("~")
+    # Define the output file path
+    output_file = os.path.join(home_dir, "magic-pdf.json")
+    # Write the JSON data to the output file
+    with open(output_file, "w") as f:
+        f.write(json_data)
+    # Print a success message
+    print(f"Dictionary converted to JSON and saved to {output_file}")
+
+if __name__ == '__main__':
+    bucket_configs_dict = get_bucket_configs_dict(s3_buckets, s3_clusters, s3_users)
+    logger.info(bucket_configs_dict)
+    config_dict = {
+        "bucket_info": bucket_configs_dict,
+        "temp-output-dir": "/tmp"
+    }
+    write_json_to_home(config_dict)
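Together with the get_s3_config implementation above, this script completes a round trip: run it once to materialize the credentials file, after which any caller can resolve a bucket by name. A sketch, assuming both modules are importable ("llm-raw" is the bucket name used in the commit's own __main__ example):

    # one-time setup: write ~/magic-pdf.json from magic_pdf.config
    config = {
        "bucket_info": get_bucket_configs_dict(s3_buckets, s3_clusters, s3_users),
        "temp-output-dir": "/tmp",
    }
    write_json_to_home(config)

    # afterwards, credentials resolve by bucket name
    ak, sk, endpoint = get_s3_config("llm-raw")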