Unverified Commit f07c2673 authored by linfeng's avatar linfeng Committed by GitHub

feat: mineru_web (#555)

parent c5474c93
......@@ -3,3 +3,5 @@
## 项目列表
- [llama_index_rag](./llama_index_rag/README.md): 基于 llama_index 构建轻量级 RAG 系统
- [web_api](./web_api/README.md): PDF解析的restful api服务
## 安装
MinerU
```bash
# mineru已安装则跳过此步骤
git clone https://github.com/opendatalab/MinerU.git
cd MinerU
conda create -n MinerU python=3.10
conda activate MinerU
pip install .[full] --extra-index-url https://wheels.myhloli.com
```
第三方软件
```bash
cd projects/web_api
pip install poetry
poetry install
```
接口文档
```
在浏览器打开 mineru-web接口文档.html
```
This diff is collapsed.
This diff is collapsed.
[tool.poetry]
name = "web-api"
version = "0.1.0"
description = ""
authors = ["houlinfeng <m15237195947@163.com>"]
readme = "README.md"
[tool.poetry.dependencies]
python = "^3.10"
flask = "^3.0.3"
flask-restful = "^0.3.10"
flask-cors = "^5.0.0"
flask-sqlalchemy = "^3.1.1"
flask-migrate = "^4.0.7"
flask-jwt-extended = "^4.6.0"
flask-marshmallow = "^1.2.1"
pyyaml = "^6.0.2"
loguru = "^0.7.2"
marshmallow-sqlalchemy = "^1.1.0"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
__all__ = ["common", "api"]
\ No newline at end of file
import os
from .extentions import app, db, migrate, jwt, ma
from common.web_hook import before_request
from common.logger import setup_log
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
print("root_dir", root_dir)
def _register_db(flask_app):
    """
    Bind SQLAlchemy to the given app, import the models, and create tables.

    :param flask_app: Flask application to register the database against
    """
    from common import import_models  # noqa: F401  (side effect: loads model classes)
    db.init_app(flask_app)
    # Bug fix: use the app that was passed in, not the module-level `app`
    # global, so the helper works for any application instance.
    with flask_app.app_context():
        db.create_all()
def create_app(config):
    """
    Create and configure an instance of the Flask application.

    :param config: mapping of configuration values; None means defaults only
    :return: the configured Flask application
    """
    app.static_folder = os.path.join(root_dir, "static")
    if config is None:
        config = {}
    app.config.update(config)
    setup_log(config)
    # Bind the database (and create tables) before wiring Migrate/JWT/Marshmallow.
    _register_db(app)
    migrate.init_app(app=app, db=db)
    jwt.init_app(app=app)
    ma.init_app(app=app)
    # Register the /api/v2 analysis endpoints.
    from .analysis import analysis_blue
    app.register_blueprint(analysis_blue)
    app.before_request(before_request)
    return app
from flask import Blueprint
from ..extentions import Api
from .upload_view import UploadPdfView
from .analysis_view import AnalysisTaskView, AnalysisTaskProgressView
from .img_md_view import ImgView, MdView
from .task_view import TaskView, HistoricalTasksView, DeleteTaskView
analysis_blue = Blueprint('analysis', __name__)
# All analysis endpoints are versioned under /api/v2.
api_v2 = Api(analysis_blue, prefix='/api/v2')
# Upload / fetch source pdf files.
api_v2.add_resource(UploadPdfView, '/analysis/upload_pdf')
# Submit an extraction task and poll its progress.
api_v2.add_resource(AnalysisTaskView, '/extract/task/submit')
api_v2.add_resource(AnalysisTaskProgressView, '/extract/task/progress')
# Serve images / markdown produced by parsing.
api_v2.add_resource(ImgView, '/analysis/pdf_img')
api_v2.add_resource(MdView, '/analysis/pdf_md')
# Task queue, history listing and deletion.
api_v2.add_resource(TaskView, '/extract/taskQueue')
api_v2.add_resource(HistoricalTasksView, '/extract/list')
api_v2.add_resource(DeleteTaskView, '/extract/task')
\ No newline at end of file
This diff is collapsed.
import os
task_state_map = {
0: "running",
1: "finished",
2: "pending",
}
def find_file(file_key, file_dir):
    """
    Locate a file whose name starts with *file_key* under *file_dir*.

    :param file_key: file hash (name prefix) to match
    :param file_dir: directory tree to search
    :return: full path of the first match (os.walk order), or "" if none
    """
    for dir_path, _sub_dirs, file_names in os.walk(file_dir):
        for name in file_names:
            if name.startswith(file_key):
                return os.path.join(dir_path, name)
    return ""
import os
import pkgutil
import numpy as np
import yaml
import argparse
import cv2
from pathlib import Path
from ultralytics import YOLO
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor
from magic_pdf.libs.config_reader import get_local_models_dir, get_device
from torchvision import transforms
from magic_pdf.pre_proc.ocr_span_list_modify import remove_overlaps_low_confidence_spans, remove_overlaps_min_spans
from PIL import Image
from common.ext import singleton_func
from common.custom_response import generate_response
def mfd_model_init(weight):
    """Load the YOLO formula-detection (MFD) model from *weight*."""
    return YOLO(weight)
def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
    """
    Build the UniMERNet formula-recognition (MFR) model.

    :param weight_dir: directory holding pytorch_model.bin and the tokenizer
    :param cfg_path: path to the UniMERNet demo.yaml config file
    :param _device_: device string the model is moved to
    :return: (model, visual processor) tuple
    """
    args = argparse.Namespace(cfg_path=cfg_path, options=None)
    cfg = Config(args)
    # Point the config at the local weights and tokenizer before building.
    cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
    cfg.config.model.model_config.model_name = weight_dir
    cfg.config.model.tokenizer_config.path = weight_dir
    task = tasks.setup_task(cfg)
    model = task.build_model(cfg)
    model = model.to(_device_)
    vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
    return model, vis_processor
@singleton_func
class CustomPEKModel:
    # Singleton holding the formula detection (MFD) and recognition (MFR)
    # models; weights are loaded once on first construction.
    def __init__(self):
        # Local model weights directory (PDF-Extract-Kit/models).
        models_dir = get_local_models_dir()
        self.device = get_device()
        # Locate magic_pdf's bundled resources on disk.
        loader = pkgutil.get_loader("magic_pdf")
        root_dir = Path(loader.path).parent
        # model_config directory
        model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
        # Full path of model_configs.yaml
        config_path = os.path.join(model_config_dir, 'model_configs.yaml')
        with open(config_path, "r", encoding='utf-8') as f:
            configs = yaml.load(f, Loader=yaml.FullLoader)
        # Initialize the formula detection model.
        self.mfd_model = mfd_model_init(str(os.path.join(models_dir, configs["weights"]["mfd"])))
        # Initialize the formula recognition model and its input transform.
        mfr_weight_dir = str(os.path.join(models_dir, configs["weights"]["mfr"]))
        mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
        self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
        self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
def get_all_spans(layout_dets) -> list:
    """
    Normalize detection items to span dicts carrying a 4-value ``bbox``.

    Each item may already have ``bbox`` ([x0, y0, x1, y1]) or only an
    8-value ``poly``; every item gets a ``bbox`` key and exact duplicates
    are dropped (first occurrence wins).
    """
    # Category ids treated as spans when concatenating:
    # 3: image, 5: table, 13: inline_equation, 14: interline_equation, 15: ocr text
    unique_spans = []
    for det in layout_dets:
        if det.get("bbox") is not None:
            # Models (e.g. paddle) that already emit a bbox.
            x0, y0, x1, y1 = det["bbox"]
        else:
            # Models that emit an 8-point polygon instead.
            x0, y0, _, _, x1, y1, _, _ = det["poly"]
        det["bbox"] = [x0, y0, x1, y1]
        if det not in unique_spans:
            unique_spans.append(det)
    return unique_spans
def formula_predict(mfd_model, image):
    """
    Run formula detection on *image* and collect result items.

    :param mfd_model: YOLO-style detector exposing ``predict``
    :param image: input image accepted by the detector
    :return: list of dicts with category_id / poly / score and empty latex
    """
    detection = mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
    boxes = detection.boxes
    results = []
    for box_xyxy, box_conf, box_cls in zip(boxes.xyxy.cpu(), boxes.conf.cpu(), boxes.cls.cpu()):
        x_min, y_min, x_max, y_max = (int(p.item()) for p in box_xyxy)
        results.append({
            'category_id': 13 + int(box_cls.item()),
            'poly': [x_min, y_min, x_max, y_min, x_max, y_max, x_min, y_max],
            'score': round(float(box_conf.item()), 2),
            'latex': '',
        })
    return results
def formula_detection(file_path, upload_dir):
    """
    Formula detection: locate formula regions in a single image.

    :param file_path: path of the image to analyse
    :param upload_dir: directory used for the temporary resized image
    :return: standard JSON response with the detected spans, or an error response
    """
    try:
        image_open = Image.open(file_path)
    except IOError:
        return generate_response(code=400, msg="params is not valid", msgZh="参数类型不是图片,无效参数")
    filename = Path(file_path).name  # NOTE(review): unused; presumably fed the resize-path f-string originally
    # Original image dimensions.
    width, height = image_open.size
    # Convert to RGB, discarding any alpha channel.
    rgb_image = image_open.convert('RGB')
    # Overwrite the source with the converted image.
    rgb_image.save(file_path)
    # Singleton model bundle (weights are loaded once).
    cpm = CustomPEKModel()
    # Formula detection (MFD) model.
    mfd_model = cpm.mfd_model
    image_conv = Image.open(file_path)
    image_array = np.array(image_conv)
    # "PDF page" canvas size the detector is run at.
    pdf_width = 1416
    pdf_height = 1888
    # Resize the image to fit half the canvas, preserving aspect ratio.
    scale = min(pdf_width // 2 / width, pdf_height // 2 / height)  # scale factor
    nw = int(width * scale)
    nh = int(height * scale)
    image_resize = cv2.resize(image_array, (nw, nh), interpolation=cv2.INTER_LINEAR)
    # NOTE(review): "resize_(unknown)" looks like a lost f-string placeholder
    # (resize_{filename}?) introduced by the scrape — confirm against the original.
    resize_image_path = f"{upload_dir}/resize_(unknown)"
    cv2.imwrite(resize_image_path, image_resize)
    # Paste the resized image centred on a white canvas.
    x = (pdf_width - nw) // 2
    y = (pdf_height - nh) // 2
    new_img = Image.new('RGB', (pdf_width, pdf_height), 'white')
    image_scale = Image.open(resize_image_path)
    new_img.paste(image_scale, (x, y))
    # Run formula detection.
    latex_filling_list = formula_predict(mfd_model, new_img)
    os.remove(resize_image_path)
    # Map coordinates on the scaled canvas back to the original image.
    for item in latex_filling_list:
        item_poly = item["poly"]
        item["poly"] = [
            (item_poly[0] - x) / scale,
            (item_poly[1] - y) / scale,
            (item_poly[2] - x) / scale,
            (item_poly[3] - y) / scale,
            (item_poly[4] - x) / scale,
            (item_poly[5] - y) / scale,
            (item_poly[6] - x) / scale,
            (item_poly[7] - y) / scale,
        ]
    if not latex_filling_list:
        return generate_response(code=1001, msg="detection fail", msgZh="公式检测失败,图片过小,无法检测")
    spans = get_all_spans(latex_filling_list)
    '''删除重叠spans中置信度较低的那些'''
    # Drop the lower-confidence member of each overlapping span pair.
    spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
    '''删除重叠spans中较小的那些'''
    # Drop the smaller member of each overlapping span pair.
    spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
    return generate_response(data={
        'layout': spans,
    })
def formula_recognition(file_path, upload_dir):
    """
    Formula recognition: detect formula regions and convert them to LaTeX.

    :param file_path: path of the image to analyse
    :param upload_dir: directory used for the temporary resized image
    :return: standard JSON response with spans and recognition results
    """
    try:
        image_open = Image.open(file_path)
    except IOError:
        return generate_response(code=400, msg="params is not valid", msgZh="参数类型不是图片,无效参数")
    filename = Path(file_path).name  # NOTE(review): unused; presumably fed the resize-path f-string originally
    # Original image dimensions.
    width, height = image_open.size
    # Convert to RGB, discarding any alpha channel.
    rgb_image = image_open.convert('RGB')
    # Overwrite the source with the converted image.
    rgb_image.save(file_path)
    image_conv = Image.open(file_path)
    image_array = np.array(image_conv)
    # "PDF page" canvas size used for inference.
    pdf_width = 1416
    pdf_height = 1888
    # Resize the image to fit half the canvas, preserving aspect ratio.
    scale = min(pdf_width // 2 / width, pdf_height // 2 / height)  # scale factor
    nw = int(width * scale)
    nh = int(height * scale)
    image_resize = cv2.resize(image_array, (nw, nh), interpolation=cv2.INTER_LINEAR)
    # NOTE(review): "resize_(unknown)" looks like a lost f-string placeholder
    # (resize_{filename}?) introduced by the scrape — confirm against the original.
    resize_image_path = f"{upload_dir}/resize_(unknown)"
    cv2.imwrite(resize_image_path, image_resize)
    # Paste the resized image centred on a white canvas.
    x = (pdf_width - nw) // 2
    y = (pdf_height - nh) // 2
    new_img = Image.new('RGB', (pdf_width, pdf_height), 'white')
    image_scale = Image.open(resize_image_path)
    new_img.paste(image_scale, (x, y))
    new_img_array = np.array(new_img)
    # Singleton model bundle.
    cpm = CustomPEKModel()
    # device
    device = cpm.device
    # Formula detection (MFD) model.
    mfd_model = cpm.mfd_model
    # Formula recognition (MFR) model and its input transform.
    mfr_model = cpm.mfr_model
    mfr_transform = cpm.mfr_transform
    # Run recognition.
    # NOTE(review): this recursively calls formula_recognition *itself* with six
    # arguments, which cannot match the (file_path, upload_dir) signature and
    # raises TypeError at runtime. A dedicated MFD+MFR helper (possibly lost in
    # the collapsed diff) was almost certainly intended — confirm and fix.
    latex_filling_list, mfr_res = formula_recognition(mfd_model, new_img_array, mfr_transform, device, mfr_model,
                                                      image_open)
    os.remove(resize_image_path)
    # Map coordinates on the scaled canvas back to the original image.
    for item in latex_filling_list:
        item_poly = item["poly"]
        item["poly"] = [
            (item_poly[0] - x) / scale,
            (item_poly[1] - y) / scale,
            (item_poly[2] - x) / scale,
            (item_poly[3] - y) / scale,
            (item_poly[4] - x) / scale,
            (item_poly[5] - y) / scale,
            (item_poly[6] - x) / scale,
            (item_poly[7] - y) / scale,
        ]
    spans = get_all_spans(latex_filling_list)
    '''删除重叠spans中置信度较低的那些'''
    # Drop the lower-confidence member of each overlapping span pair.
    spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
    '''删除重叠spans中较小的那些'''
    # Drop the smaller member of each overlapping span pair.
    spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
    if not latex_filling_list:
        # Nothing detected: treat the whole image as one interline formula.
        width, height = image_open.size
        latex_filling_list.append({
            'category_id': 14,
            'poly': [0, 0, width, 0, width, height, 0, height],
            'score': 1,
            'latex': mfr_res[0] if mfr_res else "",
        })
    return generate_response(data={
        'layout': spans if spans else latex_filling_list,
        "mfr_res": mfr_res
    })
from pathlib import Path
from flask import request, current_app, send_from_directory
from flask_restful import Resource
class ImgView(Resource):
    def get(self):
        """
        Serve an image produced by pdf parsing.

        Query params: pdf (source pdf name), filename (image file name),
        as_attachment ("true" forces a download).
        """
        params = request.args
        pdf = params.get('pdf')
        filename = params.get('filename')
        as_attachment = str(params.get('as_attachment')).lower() == "true"
        file_stem = Path(pdf).stem
        pdf_analysis_folder = current_app.config['PDF_ANALYSIS_FOLDER']
        image_dir = f"{current_app.static_folder}/{pdf_analysis_folder}/{file_stem}/images"
        return send_from_directory(image_dir, filename, as_attachment=as_attachment)
class MdView(Resource):
    def get(self):
        """
        Serve a markdown file produced by pdf parsing.

        Query params: pdf (source pdf name), filename (markdown file name),
        as_attachment ("true" forces a download).
        """
        params = request.args
        pdf = params.get('pdf')
        filename = params.get('filename')
        as_attachment = str(params.get('as_attachment')).lower() == "true"
        file_stem = Path(pdf).stem
        pdf_analysis_folder = current_app.config['PDF_ANALYSIS_FOLDER']
        pdf_dir = f"{current_app.static_folder}/{pdf_analysis_folder}/{file_stem}"
        return send_from_directory(pdf_dir, filename, as_attachment=as_attachment)
from datetime import datetime
from ..extentions import db
class AnalysisTask(db.Model):
    # One queued/running pdf-analysis job.
    __tablename__ = 'analysis_task'
    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    file_key = db.Column(db.Text, comment="文件唯一哈希")  # unique content hash of the file
    file_name = db.Column(db.Text, comment="文件名称")  # original file name
    task_type = db.Column(db.String(128), comment="任务类型")  # task type
    is_ocr = db.Column(db.Boolean, default=False, comment="是否ocr")  # force the ocr pipeline?
    status = db.Column(db.Integer, default=0, comment="状态")  # 0 running 1 finished 2 pending
    analysis_pdf_id = db.Column(db.Integer, comment="analysis_pdf的id")  # link to AnalysisPdf.id
    create_date = db.Column(db.DateTime(), nullable=False, default=datetime.now)
    update_date = db.Column(db.DateTime(), nullable=False, default=datetime.now, onupdate=datetime.now)
class AnalysisPdf(db.Model):
    # Parse results and status of one pdf file.
    __tablename__ = 'analysis_pdf'
    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    file_name = db.Column(db.Text, comment="文件名称")  # file name
    file_url = db.Column(db.Text, comment="文件原路径")  # original file url
    file_path = db.Column(db.Text, comment="文件路径")  # file path on disk
    status = db.Column(db.Integer, default=3, comment="状态")  # 0 converting 1 finished 2 failed 3 init
    bbox_info = db.Column(db.Text, comment="坐标数据")  # per-page bbox data (json)
    md_link_list = db.Column(db.Text, comment="markdown分页链接")  # per-page markdown links (json)
    full_md_link = db.Column(db.Text, comment="markdown全文链接")  # full-document markdown link
    create_date = db.Column(db.DateTime(), nullable=False, default=datetime.now)
    update_date = db.Column(db.DateTime(), nullable=False, default=datetime.now, onupdate=datetime.now)
\ No newline at end of file
import json
import re
import traceback
from pathlib import Path
from flask import current_app, url_for
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
from magic_pdf.pipe.UNIPipe import UNIPipe
import magic_pdf.model as model_config
from magic_pdf.libs.json_compressor import JsonCompressor
from magic_pdf.dict2md.ocr_mkcontent import ocr_mk_mm_markdown_with_para_and_pagination
from .ext import find_file
from ..extentions import app, db
from .models import AnalysisPdf, AnalysisTask
from common.error_types import ApiException
from loguru import logger
model_config.__use_inside_model__ = True
def analysis_pdf(image_dir, pdf_bytes, is_ocr=False):
    """
    Parse a pdf with magic_pdf's UNIPipe and build markdown + bbox info.

    :param image_dir: directory where extracted images are written
    :param pdf_bytes: raw pdf content
    :param is_ocr: force the ocr pipeline instead of auto-classification
    :return: (md_content json string, bbox_info list)
    :raises Exception: re-raises any pipeline failure after logging it
    """
    try:
        model_json = []  # empty list -> use the built-in model for parsing
        logger.info(f"is_ocr: {is_ocr}")
        if not is_ocr:
            jso_useful_key = {"_pdf_type": "", "model_list": model_json}
            image_writer = DiskReaderWriter(image_dir)
            pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True)
            # Let the pipeline decide between text and ocr parsing.
            pipe.pipe_classify()
        else:
            jso_useful_key = {"_pdf_type": "ocr", "model_list": model_json}
            image_writer = DiskReaderWriter(image_dir)
            pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True)
        # No external model data supplied -> run the built-in analyzer.
        if len(model_json) == 0:
            if model_config.__use_inside_model__:
                pipe.pipe_analyze()
            else:
                logger.error("need model list input")
                exit(1)
        pipe.pipe_parse()
        pdf_mid_data = JsonCompressor.decompress_json(pipe.get_compress_pdf_mid_data())
        pdf_info_list = pdf_mid_data["pdf_info"]
        md_content = json.dumps(ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_list, image_dir),
                                ensure_ascii=False)
        bbox_info = get_bbox_info(pdf_info_list)
        return md_content, bbox_info
    except Exception:
        logger.error(traceback.format_exc())
        # Bug fix: previously the error was swallowed and the function fell
        # through to an implicit `return None`, so callers crashed later on
        # `md_content, bbox_info = analysis_pdf(...)`. Re-raise instead.
        raise
def get_bbox_info(data):
    """
    Extract per-page layout info from parsed pdf pages.

    :param data: iterable of page dicts from magic_pdf's pdf_info
    :return: list of dicts with preproc/discarded blocks, page index and size
    """
    return [
        {
            "preproc_blocks": page.get("preproc_blocks", []),
            "page_idx": page.get("page_idx"),
            "page_size": page.get("page_size"),
            "discarded_blocks": page.get("discarded_blocks", []),
        }
        for page in data
    ]
def analysis_pdf_task(pdf_dir, image_dir, pdf_path, is_ocr, analysis_pdf_id):
    """
    Parse a pdf end-to-end and persist the results, then chain to the next
    pending task.

    :param pdf_dir: directory for this pdf's parse output
    :param image_dir: directory for extracted images
    :param pdf_path: path of the pdf file
    :param is_ocr: whether to force the ocr pipeline
    :param analysis_pdf_id: primary key of the AnalysisPdf row to update
    :return: None; results and status are written to the database
    """
    try:
        logger.info(f"start task: {pdf_path}")
        logger.info(f"image_dir: {image_dir}")
        if not Path(image_dir).exists():
            Path(image_dir).mkdir(parents=True, exist_ok=True)
        with open(pdf_path, 'rb') as file:
            pdf_bytes = file.read()
        md_content, bbox_info = analysis_pdf(image_dir, pdf_bytes, is_ocr)
        img_list = Path(image_dir).glob('*') if Path(image_dir).exists() else []
        pdf_name = Path(pdf_path).name
        with app.app_context():
            # Rewrite image references in the markdown to served URLs.
            for img in img_list:
                img_name = Path(img).name
                # Match the markdown link "(...<img_name>" for this image.
                regex = re.compile(fr'.*\((.*?{img_name})')
                regex_result = regex.search(md_content)
                # NOTE(review): regex_result is None when the image is not
                # referenced in the markdown -> AttributeError below; confirm.
                img_url = url_for('analysis.imgview', filename=img_name, as_attachment=False)
                md_content = md_content.replace(regex_result.group(1), f"{img_url}&pdf={pdf_name}")
        # Concatenate all pages into one markdown document.
        full_md_content = ""
        for item in json.loads(md_content):
            full_md_content += item["md_content"] + "\n"
        full_md_name = "full.md"
        with open(f"{pdf_dir}/{full_md_name}", "w") as file:
            file.write(full_md_content)
        with app.app_context():
            full_md_link = url_for('analysis.mdview', filename=full_md_name, as_attachment=False)
            full_md_link = f"{full_md_link}&pdf={pdf_name}"
        # One markdown file (and served link) per page.
        md_link_list = []
        with app.app_context():
            for n, md in enumerate(json.loads(md_content)):
                md_content = md["md_content"]
                md_name = f"{md.get('page_no', n)}.md"
                with open(f"{pdf_dir}/{md_name}", "w") as file:
                    file.write(md_content)
                md_url = url_for('analysis.mdview', filename=md_name, as_attachment=False)
                md_link_list.append(f"{md_url}&pdf={pdf_name}")
        with app.app_context():
            # Mark the pdf row finished and store the parse artefacts.
            with db.auto_commit():
                analysis_pdf_object = AnalysisPdf.query.filter_by(id=analysis_pdf_id).first()
                analysis_pdf_object.status = 1
                analysis_pdf_object.bbox_info = json.dumps(bbox_info, ensure_ascii=False)
                analysis_pdf_object.md_link_list = json.dumps(md_link_list, ensure_ascii=False)
                analysis_pdf_object.full_md_link = full_md_link
                db.session.add(analysis_pdf_object)
            # Close out the queue entry as well.
            with db.auto_commit():
                analysis_task_object = AnalysisTask.query.filter_by(analysis_pdf_id=analysis_pdf_id).first()
                analysis_task_object.status = 1
                db.session.add(analysis_task_object)
        logger.info(f"finished!")
    except Exception as e:
        logger.error(traceback.format_exc())
        with app.app_context():
            # Flag the pdf row as failed (status 2) ...
            with db.auto_commit():
                analysis_pdf_object = AnalysisPdf.query.filter_by(id=analysis_pdf_id).first()
                analysis_pdf_object.status = 2
                db.session.add(analysis_pdf_object)
            # ... and close out the queue entry.
            # NOTE(review): status=1 marks the task "finished" even on failure;
            # task_state_map has no dedicated failure state — confirm intended.
            with db.auto_commit():
                analysis_task_object = AnalysisTask.query.filter_by(analysis_pdf_id=analysis_pdf_id).first()
                analysis_task_object.status = 1
                db.session.add(analysis_task_object)
        raise ApiException(code=500, msg="PDF parsing failed", msgZH="pdf解析失败")
    finally:
        # Pick up the oldest pending task (status 2), if any, and run it next.
        with app.app_context():
            analysis_task_object = AnalysisTask.query.filter_by(status=2).order_by(
                AnalysisTask.update_date.asc()).first()
            if analysis_task_object:
                pdf_upload_folder = current_app.config['PDF_UPLOAD_FOLDER']
                upload_dir = f"{current_app.static_folder}/{pdf_upload_folder}"
                file_path = find_file(analysis_task_object.file_key, upload_dir)
                file_stem = Path(file_path).stem
                pdf_analysis_folder = current_app.config['PDF_ANALYSIS_FOLDER']
                pdf_dir = f"{current_app.static_folder}/{pdf_analysis_folder}/{file_stem}"
                image_dir = f"{pdf_dir}/images"
                # Move both rows to "running" before recursing into the next task.
                with db.auto_commit():
                    analysis_pdf_object = AnalysisPdf.query.filter_by(id=analysis_task_object.analysis_pdf_id).first()
                    analysis_pdf_object.status = 0
                    db.session.add(analysis_pdf_object)
                with db.auto_commit():
                    analysis_task_object.status = 0
                    db.session.add(analysis_task_object)
                analysis_pdf_task(pdf_dir, image_dir, file_path, analysis_task_object.is_ocr, analysis_task_object.analysis_pdf_id)
            else:
                logger.info(f"all task finished!")
from marshmallow import Schema, fields, validates_schema, validates
from common.error_types import ApiException
from .models import AnalysisTask
class BooleanField(fields.Boolean):
    def _deserialize(self, value, attr, data, **kwargs):
        """Accept only a genuine bool; reject truthy/falsy coercions."""
        if isinstance(value, bool):
            return value
        raise ApiException(code=400, msg="isOcr not a valid boolean", msgZH="isOcr不是有效的布尔值")
class AnalysisViewSchema(Schema):
    # Request schema for submitting an analysis task.
    fileKey = fields.Str(required=True)   # content hash of the uploaded file
    fileName = fields.Str()               # optional display name
    taskType = fields.Str(required=True)  # task kind identifier
    isOcr = BooleanField()                # strict bool; see BooleanField above
    @validates_schema(pass_many=True)
    def validate_passwords(self, data, **kwargs):
        """Reject empty fileKey/taskType values (presence is enforced by required=True)."""
        task_type = data['taskType']
        file_key = data['fileKey']
        if not file_key:
            raise ApiException(code=400, msg="fileKey cannot be empty", msgZH="fileKey不能为空")
        if not task_type:
            raise ApiException(code=400, msg="taskType cannot be empty", msgZH="taskType不能为空")
import json
from flask import url_for, request
from flask_restful import Resource
from sqlalchemy import func
from ..extentions import db
from .models import AnalysisTask, AnalysisPdf
from .ext import task_state_map
from common.custom_response import generate_response
class TaskView(Resource):
    def get(self):
        """
        List the currently running task followed by the pending queue.

        :return: standard response; data is reversed so pending tasks come first
        """
        analysis_task_running = AnalysisTask.query.filter(AnalysisTask.status == 0).first()
        analysis_task_pending = AnalysisTask.query.filter(AnalysisTask.status == 2).order_by(
            AnalysisTask.create_date.asc()).all()
        pending_total = db.session.query(func.count(AnalysisTask.id)).filter(AnalysisTask.status == 2).scalar()
        # Total queued tasks: all pending plus the running one, if any.
        task_nums = pending_total + (1 if analysis_task_running else 0)
        data = []
        # Bug fix: previously the running task was dereferenced without a None
        # check and this endpoint crashed with AttributeError when idle.
        if analysis_task_running:
            data.append({
                "queues": task_nums,  # total number of queued tasks
                "rank": 1,
                "id": analysis_task_running.id,
                "url": url_for('analysis.uploadpdfview', filename=analysis_task_running.file_name, as_attachment=False),
                "fileName": analysis_task_running.file_name,
                "type": analysis_task_running.task_type,
                "state": task_state_map.get(analysis_task_running.status),
            })
        base_rank = 2 if analysis_task_running else 1
        for n, task in enumerate(analysis_task_pending):
            data.append({
                "queues": task_nums,  # total number of queued tasks
                "rank": n + base_rank,
                "id": task.id,
                "url": url_for('analysis.uploadpdfview', filename=task.file_name, as_attachment=False),
                "fileName": task.file_name,
                "type": task.task_type,
                "state": task_state_map.get(task.status),
            })
        data.reverse()
        return generate_response(data=data, total=task_nums)
class HistoricalTasksView(Resource):
    def get(self):
        """
        Paginated task history, newest first.

        Query params: pageNo (default 1), pageSize (default 10).
        """
        params = request.args
        page_no = params.get('pageNo', 1)
        page_size = params.get('pageSize', 10)
        total = db.session.query(func.count(AnalysisTask.id)).scalar()
        paginated = AnalysisTask.query.order_by(AnalysisTask.create_date.desc()).paginate(page=int(page_no),
                                                                                         per_page=int(page_size),
                                                                                         error_out=False)
        task_list = [
            {
                "fileName": task.file_name,
                "id": task.id,
                "type": task.task_type,
                "state": task_state_map.get(task.status),
            }
            for task in paginated
        ]
        return generate_response(data={
            "list": task_list,
            "total": total,
            "pageNo": page_no,
            "pageSize": page_size,
        })
class DeleteTaskView(Resource):
    def delete(self):
        """
        Delete one non-running task and its linked AnalysisPdf row.

        Body: {"id": <task id>}; running tasks (status == 0) cannot be deleted.
        :return: standard response echoing the id, or a 400 business error
        """
        params = json.loads(request.data)
        id = params.get('id')
        analysis_task = AnalysisTask.query.filter(AnalysisTask.id == id, AnalysisTask.status != 0).first()
        if not analysis_task:
            return generate_response(code=400, msg="The ID is incorrect", msgZH="id不正确")
        # Bug fix: filter on this task instance's analysis_pdf_id; previously
        # the query compared against the AnalysisTask *column* object, which
        # could match (and delete) an unrelated AnalysisPdf row.
        analysis_pdf = AnalysisPdf.query.filter(AnalysisPdf.id == analysis_task.analysis_pdf_id).first()
        with db.auto_commit():
            if analysis_pdf:
                # The linked pdf row may already be gone; only delete if present.
                db.session.delete(analysis_pdf)
            db.session.delete(analysis_task)
        return generate_response(data={"id": id})
import json
import traceback
import requests
from flask import request, current_app, url_for, send_from_directory
from flask_restful import Resource
from werkzeug.utils import secure_filename
from pathlib import Path
from common.ext import is_pdf, calculate_file_hash, url_is_pdf
from io import BytesIO
from werkzeug.datastructures import FileStorage
from common.custom_response import generate_response
from loguru import logger
class UploadPdfView(Resource):
    def get(self):
        """
        Serve an uploaded pdf.

        Query params: filename, as_attachment ("true" forces a download).
        """
        params = request.args
        filename = params.get('filename')
        as_attachment = str(params.get('as_attachment')).lower() == "true"
        pdf_upload_folder = current_app.config['PDF_UPLOAD_FOLDER']
        return send_from_directory(f"{current_app.static_folder}/{pdf_upload_folder}", filename,
                                   as_attachment=as_attachment)

    def post(self):
        """
        Upload a pdf, either as multipart file data or via a remote pdfUrl.

        :return: response with the stored file's url and its content hash key
        """
        file_list = request.files.getlist("file")
        if file_list:
            file = file_list[0]
            filename = secure_filename(file.filename)
            if not file or not is_pdf(filename, file):
                return generate_response(code=400, msg="Invalid PDF file", msgZH="PDF文件参数无效")
        else:
            params = json.loads(request.data)
            pdf_url = params.get('pdfUrl')
            try:
                response = requests.get(pdf_url, stream=True)
            # Bug fix: requests raises requests.exceptions.ConnectionError (and
            # friends), which the builtin ConnectionError does NOT catch; use
            # the library's base RequestException instead.
            except requests.exceptions.RequestException:
                logger.error(traceback.format_exc())
                return generate_response(code=400, msg="params is not valid", msgZh="参数错误,pdf链接无法访问")
            if response.status_code != 200:
                return generate_response(code=400, msg="params is not valid", msgZh="参数错误,pdf链接响应状态异常")
            # Wrap the downloaded bytes in a FileStorage so both branches share code.
            file_content = BytesIO(response.content)
            filename = Path(pdf_url).name if ".pdf" in pdf_url else f"{Path(pdf_url).name}.pdf"
            file = FileStorage(
                stream=file_content,
                filename=filename,
                content_type=response.headers.get('Content-Type', 'application/octet-stream')
            )
            # Url downloads may lack an extension, so only the magic bytes are checked.
            if not file or not url_is_pdf(file):
                return generate_response(code=400, msg="Invalid PDF file", msgZH="PDF文件参数无效")
        pdf_upload_folder = current_app.config['PDF_UPLOAD_FOLDER']
        upload_dir = f"{current_app.static_folder}/{pdf_upload_folder}"
        if not Path(upload_dir).exists():
            Path(upload_dir).mkdir(parents=True, exist_ok=True)
        file_key = calculate_file_hash(file)
        # Prefix the stored name with the content hash so re-uploads dedupe
        # and find_file() can locate it by hash. (Restores the f-string
        # placeholder mangled to "_(unknown)" in the scraped source.)
        new_filename = f"{file_key}_{filename}"
        file_path = f"{upload_dir}/{new_filename}"
        # Stream to disk in chunks to cap memory use on large files.
        chunk_size = 8192
        with open(file_path, 'wb') as f:
            while True:
                chunk = file.stream.read(chunk_size)
                if not chunk:
                    break
                f.write(chunk)
        # Public URL for the stored file.
        file_url = url_for('analysis.uploadpdfview', filename=new_filename, as_attachment=False)
        return generate_response(data={
            "url": file_url,
            "file_key": file_key
        })
from flask import Flask, jsonify
from flask_restful import Api as _Api
from flask_cors import CORS
from flask_sqlalchemy import SQLAlchemy as _SQLAlchemy
from flask_migrate import Migrate
from contextlib import contextmanager
from flask_jwt_extended import JWTManager
from flask_marshmallow import Marshmallow
from common.error_types import ApiException
from werkzeug.exceptions import HTTPException
from loguru import logger
class Api(_Api):
    # flask_restful Api with a unified JSON error envelope.
    def handle_error(self, e):
        """
        Convert any raised exception into the project's JSON error response.

        :param e: ApiException (project errors), HTTPException, or anything else
        :return: (json response, http status code)
        """
        if isinstance(e, ApiException):
            code = e.code
            msg = e.msg
            msgZH = e.msgZH
            error_code = e.error_code
        elif isinstance(e, HTTPException):
            code = e.code
            msg = e.description
            msgZH = "服务异常,详细信息请查看日志"
            error_code = e.code
        else:
            # Unknown failure: report as 500 with the exception text.
            code = 500
            msg = str(e)
            error_code = 500
            msgZH = "服务异常,详细信息请查看日志"
        # Log the full traceback via loguru.
        logger.opt(exception=e).error(f"An error occurred: {msg}")
        # e.name is only reached for HTTPException/ApiException (code != 500 path).
        return jsonify({
            "error": "Internal Server Error" if code == 500 else e.name,
            "msg": msg,
            "msgZH": msgZH,
            "code": code,
            "error_code": error_code
        }), code
class SQLAlchemy(_SQLAlchemy):
    @contextmanager
    def auto_commit(self):
        """
        Commit the session after the with-block; roll back and re-raise on error.

        Usage::

            with db.auto_commit():
                db.session.add(obj)
        """
        try:
            yield
            # Bug fix: operate on this instance's session rather than the
            # module-level `db` global, so the manager works on any instance.
            # (The old post-commit flush() was a no-op and has been dropped.)
            self.session.commit()
        except Exception:
            self.session.rollback()
            raise
app = Flask(__name__)
CORS(app, supports_credentials=True)
db = SQLAlchemy()
migrate = Migrate()
jwt = JWTManager()
ma = Marshmallow()
import socket
from api import create_app
from pathlib import Path
import yaml
def get_local_ip():
    """
    Return the LAN IP this host would use to reach the internet.

    Connecting a UDP socket sends no packet; it only makes the OS pick a
    route, whose source address we read back.

    :return: local IP address string
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        sock.connect(('8.8.8.8', 80))  # Google DNS; used only to select a route
        return sock.getsockname()[0]
    finally:
        # Bug fix: close the socket even when connect() raises (it leaked before).
        sock.close()
current_file_path = Path(__file__).resolve()
base_dir = current_file_path.parent
config_path = base_dir / "config/config.yaml"
class ConfigMap(dict):
    """Dict subclass allowing attribute-style access (cfg.key <-> cfg["key"])."""

    def __setattr__(self, name, value):
        self[name] = value

    def __getattr__(self, name):
        # Missing keys raise KeyError, matching plain dict item access.
        return self[name]
with open(str(config_path), mode='r', encoding='utf-8') as fd:
data = yaml.load(fd, Loader=yaml.FullLoader)
_config = data.get(data.get("CurrentConfig", "DevelopmentConfig"))
config = ConfigMap()
for k, v in _config.items():
config[k] = v
config['base_dir'] = base_dir
database = _config.get("database")
if database:
if database.get("type") == "sqlite":
database_uri = f'sqlite:///{base_dir}/{database.get("path")}'
elif database.get("type") == "mysql":
database_uri = f'mysql+pymysql://{database.get("user")}:{database.get("password")}@{database.get("host")}:{database.get("port")}/{database.get("database")}?'
else:
database_uri = ''
config['SQLALCHEMY_DATABASE_URI'] = database_uri
ip_address = get_local_ip()
port = config.get("PORT", 5559)
# 配置 SERVER_NAME
config['SERVER_NAME'] = f'{ip_address}:5559'
# 配置 APPLICATION_ROOT
config['APPLICATION_ROOT'] = '/'
# 配置 PREFERRED_URL_SCHEME
config['PREFERRED_URL_SCHEME'] = 'http'
app = create_app(config)
if __name__ == '__main__':
app.run(host="0.0.0.0", port=port, debug=config.get("DEBUG", False))
from flask import jsonify
class ResponseCode:
    # Business status codes / default message used by generate_response().
    SUCCESS = 200  # operation succeeded
    PARAM_WARING = 400  # invalid parameters (name kept as-is for compatibility)
    MESSAGE = "success"  # default message text
def generate_response(data=None, code=ResponseCode.SUCCESS, msg=ResponseCode.MESSAGE, **kwargs):
    """
    Build the project's standard JSON response envelope.

    :param code: business status code (200 means success)
    :param data: payload placed under "data"
    :param msg: message text; defaults to "success"/"fail" by code
    :param kwargs: extra top-level keys merged into the body
    :return: flask Response carrying the JSON body
    """
    success = code == 200
    if not msg:
        msg = 'success' if success else 'fail'
    res = jsonify(dict(code=code, success=success, data=data, msg=msg, **kwargs))
    # Business errors are still delivered as HTTP 200; clients read `code`.
    res.status_code = 200
    return res
import json
from flask import request
from werkzeug.exceptions import HTTPException
class ApiException(HTTPException):
    """Base class for API errors rendered as JSON."""
    # Class-level defaults; overridden per-instance via __init__.
    code = 500
    msg = 'Sorry, we made a mistake Σ(っ °Д °;)っ'
    msgZH = ""
    error_code = 999
    def __init__(self, msg=None, msgZH=None, code=None, error_code=None, headers=None):
        """
        :param msg: English message
        :param msgZH: Chinese message
        :param code: HTTP status code
        :param error_code: project-internal error code
        :param headers: unused; kept for HTTPException signature compatibility
        """
        if code:
            self.code = code
        if msg:
            self.msg = msg
        if msgZH:
            self.msgZH = msgZH
        if error_code:
            self.error_code = error_code
        super(ApiException, self).__init__(msg, None)
    @staticmethod
    def get_error_url():
        """Return the failing route as 'METHOD /path' (query string stripped)."""
        method = request.method
        full_path = str(request.full_path)
        main_path = full_path.split('?')[0]
        res = method + ' ' + main_path
        return res
    def get_body(self, environ=None, scope=None):
        """JSON body used when werkzeug renders this exception directly."""
        body = dict(
            msg=self.msg,
            error_code=self.error_code,
            request=self.get_error_url()
        )
        text = json.dumps(body)
        return text
    def get_headers(self, environ=None, scope=None):
        """Force a JSON content type on the rendered error response."""
        return [("Content-Type", "application/json")]
\ No newline at end of file
import hashlib
import mimetypes
def is_pdf(filename, file):
    """
    Check whether an uploaded file is a PDF.

    :param filename: original file name
    :param file: readable binary file object; its position is restored
    :return: True only if the extension, MIME type and magic bytes all agree
    """
    # Extension check. NOTE: links such as https://arxiv.org/pdf/2405.08702
    # may lack an extension; url uploads use url_is_pdf() instead.
    if not filename.endswith('.pdf'):
        return False
    # MIME type derived from the file name must agree.
    # (Bug fix: removed a leftover debug print of the MIME type.)
    mime_type, _ = mimetypes.guess_type(filename)
    if mime_type != 'application/pdf':
        return False
    # Magic-byte check: a real PDF starts with "%PDF-".
    file_start = file.read(5)
    file.seek(0)
    return file_start.startswith(b'%PDF-')
def url_is_pdf(file):
    """
    Check PDF magic bytes only (for url downloads that may lack an extension).

    :param file: readable binary file object; its position is restored
    :return: True if the content starts with "%PDF-"
    """
    header = file.read(5)
    file.seek(0)
    return header.startswith(b'%PDF-')
def calculate_file_hash(file, algorithm='sha256'):
    """
    Compute the hash of a file object's full contents.

    :param file: readable binary file object; position is reset to 0 afterwards
    :param algorithm: hashlib algorithm name, e.g. 'sha256', 'md5', 'sha1'
    :return: hex digest string
    """
    hash_func = getattr(hashlib, algorithm)()
    chunk_size = 65536  # read in 64KB chunks
    while chunk := file.read(chunk_size):
        hash_func.update(chunk)
    file.seek(0)
    return hash_func.hexdigest()
def singleton_func(cls):
    """Class decorator caching a single instance per decorated class."""
    _instances = {}

    def _get_instance(*args, **kwargs):
        # First call constructs; later calls reuse (their args are ignored).
        if cls not in _instances:
            _instances[cls] = cls(*args, **kwargs)
        return _instances[cls]

    return _get_instance
from api.analysis.models import *
\ No newline at end of file
import os
from loguru import logger
from pathlib import Path
from datetime import datetime
def setup_log(config):
    """
    Configure loguru file logging under <project>/log.

    :param config: mapping providing LOG_LEVEL
    :return: None
    """
    log_dir = Path(__file__).parent.parent / "log"
    if not log_dir.exists():
        log_dir.mkdir(parents=True, exist_ok=True)
    # One file per day; loguru rotates it at midnight.
    log_name = f'log_{datetime.now().strftime("%Y-%m-%d")}.log'
    log_file_path = os.path.join(log_dir, log_name)
    logger.add(str(log_file_path), rotation='00:00', encoding='utf-8', level=config.get("LOG_LEVEL"), enqueue=True)
def before_request():
    """Pre-request hook placeholder; performs no checks and allows the request."""
    return None
def after_request(response):
    """Attach permissive CORS headers to the outgoing response and return it."""
    cors_headers = (
        ('Access-Control-Allow-Origin', '*'),
        ('Access-Control-Allow-Headers', 'Content-Type,Authorization'),
    )
    for name, value in cors_headers:
        response.headers.add(name, value)
    return response
# 基本配置
BaseConfig: &base
DEBUG: false
PORT: 5559
LOG_LEVEL: "DEBUG"
SQLALCHEMY_TRACK_MODIFICATIONS: true
SQLALCHEMY_DATABASE_URI: ""
PROPAGATE_EXCEPTIONS: true
SECRET_KEY: "#$%^&**$##*(*^%%$**((&"
JWT_SECRET_KEY: "#$%^&**$##*(*^%%$**((&"
JWT_ACCESS_TOKEN_EXPIRES: 3600
PDF_UPLOAD_FOLDER: "upload_pdf"
PDF_ANALYSIS_FOLDER: "analysis_pdf"
# 开发配置
DevelopmentConfig:
<<: *base
database:
type: sqlite
path: config/mineru_web.db
# 生产配置
ProductionConfig:
<<: *base
# 测试配置
TestingConfig:
<<: *base
# 当前使用配置
CurrentConfig: "DevelopmentConfig"
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment