Unverified Commit b72d4ebd authored by icecraft's avatar icecraft Committed by GitHub

Feat/support rag (#510)

* Create requirements-docker.txt

* feat: update deps to support rag

* feat: add support to rag, add rag_data_reader api for rag integration

* feat: let user retrieve the filename of the processed file

* feat: add projects demo for rag integrations

---------
Co-authored-by: 's avatarXiaomeng Zhao <moe@myhloli.com>
Co-authored-by: 's avataricecraft <xurui1@pjlab.org.cn>
parent 0f91fcf6
import os
from pathlib import Path
from loguru import logger
from magic_pdf.integrations.rag.type import (ElementRelation, LayoutElements,
Node)
from magic_pdf.integrations.rag.utils import inference
class RagPageReader:
def __init__(self, pagedata: LayoutElements):
self.o = [
Node(
category_type=v.category_type,
text=v.text,
image_path=v.image_path,
anno_id=v.anno_id,
latex=v.latex,
html=v.html,
) for v in pagedata.layout_dets
]
self.pagedata = pagedata
def __iter__(self):
return iter(self.o)
def get_rel_map(self) -> list[ElementRelation]:
return self.pagedata.extra.element_relation
class RagDocumentReader:
def __init__(self, ragdata: list[LayoutElements]):
self.o = [RagPageReader(v) for v in ragdata]
def __iter__(self):
return iter(self.o)
class DataReader:
def __init__(self, path_or_directory: str, method: str, output_dir: str):
self.path_or_directory = path_or_directory
self.method = method
self.output_dir = output_dir
self.pdfs = []
if os.path.isdir(path_or_directory):
for doc_path in Path(path_or_directory).glob('*.pdf'):
self.pdfs.append(doc_path)
else:
assert path_or_directory.endswith('.pdf')
self.pdfs.append(Path(path_or_directory))
def get_documents_count(self) -> int:
"""Returns the number of documents in the directory."""
return len(self.pdfs)
def get_document_result(self, idx: int) -> RagDocumentReader | None:
"""
Args:
idx (int): the index of documents under the
directory path_or_directory
Returns:
RagDocumentReader | None: RagDocumentReader is an iterable object,
more details @RagDocumentReader
"""
if idx >= self.get_documents_count() or idx < 0:
logger.error(f'invalid idx: {idx}')
return None
res = inference(str(self.pdfs[idx]), self.output_dir, self.method)
if res is None:
logger.warning(f'failed to inference pdf {self.pdfs[idx]}')
return None
return RagDocumentReader(res)
def get_document_filename(self, idx: int) -> Path:
"""get the filename of the document."""
return self.pdfs[idx]
from enum import Enum
from pydantic import BaseModel, Field
# rag
class CategoryType(Enum): # py310 not support StrEnum
text = 'text'
title = 'title'
interline_equation = 'interline_equation'
image = 'image'
image_body = 'image_body'
image_caption = 'image_caption'
table = 'table'
table_body = 'table_body'
table_caption = 'table_caption'
table_footnote = 'table_footnote'
class ElementRelType(Enum):
sibling = 'sibling'
class PageInfo(BaseModel):
page_no: int = Field(description='the index of page, start from zero',
ge=0)
height: int = Field(description='the height of page', gt=0)
width: int = Field(description='the width of page', ge=0)
image_path: str | None = Field(description='the image of this page',
default=None)
class ContentObject(BaseModel):
category_type: CategoryType = Field(description='类别')
poly: list[float] = Field(
description=('Coordinates, need to convert back to PDF coordinates,'
' order is top-left, top-right, bottom-right, bottom-left'
' x,y coordinates'))
ignore: bool = Field(description='whether ignore this object',
default=False)
text: str | None = Field(description='text content of the object',
default=None)
image_path: str | None = Field(description='path of embedded image',
default=None)
order: int = Field(description='the order of this object within a page',
default=-1)
anno_id: int = Field(description='unique id', default=-1)
latex: str | None = Field(description='latex result', default=None)
html: str | None = Field(description='html result', default=None)
class ElementRelation(BaseModel):
source_anno_id: int = Field(description='unique id of the source object',
default=-1)
target_anno_id: int = Field(description='unique id of the target object',
default=-1)
relation: ElementRelType = Field(
description='the relation between source and target element')
class LayoutElementsExtra(BaseModel):
element_relation: list[ElementRelation] = Field(
description='the relation between source and target element')
class LayoutElements(BaseModel):
layout_dets: list[ContentObject] = Field(
description='layout element details')
page_info: PageInfo = Field(description='page info')
extra: LayoutElementsExtra = Field(description='extra information')
# iter data format
class Node(BaseModel):
category_type: CategoryType = Field(description='类别')
text: str | None = Field(description='text content of the object',
default=None)
image_path: str | None = Field(description='path of embedded image',
default=None)
anno_id: int = Field(description='unique id', default=-1)
latex: str | None = Field(description='latex result', default=None)
html: str | None = Field(description='html result', default=None)
This diff is collapsed.
...@@ -86,6 +86,7 @@ def jsonl(jsonl, method, output_dir): ...@@ -86,6 +86,7 @@ def jsonl(jsonl, method, output_dir):
pdf_data, pdf_data,
jso['doc_layout_result'], jso['doc_layout_result'],
method, method,
False,
f_dump_content_list=True, f_dump_content_list=True,
f_draw_model_bbox=True, f_draw_model_bbox=True,
) )
...@@ -141,6 +142,7 @@ def pdf(pdf, json_data, output_dir, method): ...@@ -141,6 +142,7 @@ def pdf(pdf, json_data, output_dir, method):
pdf_data, pdf_data,
model_json_list, model_json_list,
method, method,
False,
f_dump_content_list=True, f_dump_content_list=True,
f_draw_model_bbox=True, f_draw_model_bbox=True,
) )
......
## 安装
MinerU
```bash
git clone https://github.com/opendatalab/MinerU.git
cd MinerU
conda create -n MinerU python=3.10
conda activate MinerU
pip install .[full] --extra-index-url https://wheels.myhloli.com
```
第三方软件
```bash
# install
pip install llama-index-vector-stores-elasticsearch==0.2.0
pip install llama-index-embeddings-dashscope==0.2.0
pip install llama-index-core==0.10.68
pip install einops==0.7.0
pip install transformers-stream-generator==0.0.5
pip install accelerate==0.33.0
# uninstall
pip uninstall transformer-engine
```
## 环境配置
```
export DASHSCOPE_API_KEY={some_key}
export ES_USER={some_es_user}
export ES_PASSWORD={some_es_password}
export ES_URL=http://{es_url}:9200
```
DASHSCOPE_API_KEY 的开通参考[文档](https://help.aliyun.com/zh/dashscope/opening-service)
## 使用
### 导入数据
```bash
python data_ingestion.py -p some.pdf # load data from pdf
or
python data_ingestion.py -p /opt/data/some_pdf_directory/ # load data from multiples pdf which under the directory of {some_pdf_directory}
```
### 查询
```bash
python query.py --question '{the_question_you_want_to_ask}'
```
## 示例
````bash
# 启动 es 服务
docker compose up -d
or
docker-compose up -d
# 配置环境变量
export ES_USER=elastic
export ES_PASSWORD=llama_index
export ES_URL=http://127.0.0.1:9200
# 导入数据
python data_ingestion.py example/data/declaration_of_the_rights_of_man_1789.pdf
# 查询问题
python query.py -q 'how about the rights of men'
## outputs
请基于```内的内容回答问题。"
```
I. Men are born, and always continue, free and equal in respect of their rights. Civil distinctions, therefore, can be founded only on public utility.
```
我的问题是:how about the rights of men。
question: how about the rights of men
answer: The statement implies that men are born free and equal in terms of their rights. Civil distinctions should only be based on public utility. However, it does not specify what those rights are. It is up to society and individual countries to determine and protect the specific rights of their citizens.
````
## 开发
`MinerU` 提供了 `RAG` 集成接口,用户可以通过指定输入单个 `pdf` 文件或者某个目录。`MinerU` 会自动解析输入文件并返回可以迭代的接口用于获取数据
### API 接口
```python
from magic_pdf.integrations.rag.type import Node
class RagPageReader:
def get_rel_map(self) -> list[ElementRelation]:
# 获取节点的间的关系
pass
...
class RagDocumentReader:
...
class DataReader:
def __init__(self, path_or_directory: str, method: str, output_dir: str):
pass
def get_documents_count(self) -> int:
"""获取 pdf 文档数量"""
pass
def get_document_result(self, idx: int) -> RagDocumentReader | None:
"""获取某个 pdf 的解析内容"""
pass
def get_document_filename(self, idx: int) -> Path:
"""获取某个 pdf 的具体的路径"""
pass
```
类型定义
```python
class Node(BaseModel):
category_type: CategoryType = Field(description='类别') # 类别
text: str | None = Field(description='文本内容',
default=None)
image_path: str | None = Field(description='图或者表格(表可能用图片形式存储)的存储路径',
default=None)
anno_id: int = Field(description='unique id', default=-1)
latex: str | None = Field(description='公式或表格 latex 解析结果', default=None)
html: str | None = Field(description='表格的 html 解析结果', default=None)
```
表格存储形式可能会是 图片、latex、html 三种形式之一。
anno_id 是该 Node 的在全局唯一ID。后续可以用于匹配该 Node 和其他 Node 的关系。节点的关系可以通过方法 `get_rel_map` 获取。用户可以用 `anno_id` 匹配节点之间的关系,并用于构建具备节点的关系的 rag index。
### 节点类型关系矩阵
| | image_body | table_body |
| -------------- | ---------- | ---------- |
| image_caption | sibling | |
| table_caption | | sibling |
| table_footnote | | sibling |
import os
import click
from llama_index.core.schema import TextNode
from llama_index.embeddings.dashscope import (DashScopeEmbedding,
DashScopeTextEmbeddingModels,
DashScopeTextEmbeddingType)
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from magic_pdf.integrations.rag.api import DataReader
es_vec_store = ElasticsearchStore(
index_name='rag_index',
es_url=os.getenv('ES_URL', 'http://127.0.0.1:9200'),
es_user=os.getenv('ES_USER', 'elastic'),
es_password=os.getenv('ES_PASSWORD', 'llama_index'),
)
# Create embeddings
# text_type=`document` to build index
def embed_node(node):
embedder = DashScopeEmbedding(
model_name=DashScopeTextEmbeddingModels.TEXT_EMBEDDING_V2,
text_type=DashScopeTextEmbeddingType.TEXT_TYPE_DOCUMENT,
)
result_embeddings = embedder.get_text_embedding(node.text)
node.embedding = result_embeddings
return node
@click.command()
@click.option(
'-p',
'--path',
'path',
type=click.Path(exists=True),
required=True,
help='local pdf filepath or directory',
)
def cli(path):
output_dir = '/tmp/magic_pdf/integrations/rag/'
os.makedirs(output_dir, exist_ok=True)
documents = DataReader(path, 'ocr', output_dir)
# build nodes
nodes = []
for idx in range(documents.get_documents_count()):
doc = documents.get_document_result(idx)
if doc is None: # something wrong happens when parse pdf !
continue
for page in iter(
doc): # iterate documents from initial page to last page !
for element in iter(page): # iterate the element from all page !
if element.text is None:
continue
nodes.append(
embed_node(
TextNode(text=element.text,
metadata={'purpose': 'demo'})))
es_vec_store.add(nodes)
if __name__ == '__main__':
cli()
services:
es:
container_name: es
image: docker.elastic.co/elasticsearch/elasticsearch:8.11.3
volumes:
- esdata01:/usr/share/elasticsearch/data
ports:
- 9200:9200
environment:
- node.name=es
- ELASTIC_PASSWORD=llama_index
- bootstrap.memory_lock=false
- discovery.type=single-node
- xpack.security.enabled=true
- xpack.security.http.ssl.enabled=false
- xpack.security.transport.ssl.enabled=false
ulimits:
memlock:
soft: -1
hard: -1
restart: always
volumes:
esdata01:
driver: local
import os
import click
from llama_index.core.vector_stores.types import VectorStoreQuery
from llama_index.embeddings.dashscope import (DashScopeEmbedding,
DashScopeTextEmbeddingModels,
DashScopeTextEmbeddingType)
from llama_index.vector_stores.elasticsearch import (AsyncDenseVectorStrategy,
ElasticsearchStore)
# initialize qwen 7B model
from modelscope import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
es_vector_store = ElasticsearchStore(
index_name='rag_index',
es_url=os.getenv('ES_URL', 'http://127.0.0.1:9200'),
es_user=os.getenv('ES_USER', 'elastic'),
es_password=os.getenv('ES_PASSWORD', 'llama_index'),
retrieval_strategy=AsyncDenseVectorStrategy(),
)
def embed_text(text):
embedder = DashScopeEmbedding(
model_name=DashScopeTextEmbeddingModels.TEXT_EMBEDDING_V2,
text_type=DashScopeTextEmbeddingType.TEXT_TYPE_DOCUMENT,
)
return embedder.get_text_embedding(text)
def search(vector_store: ElasticsearchStore, query: str):
query_vec = VectorStoreQuery(query_embedding=embed_text(query))
result = vector_store.query(query_vec)
return '\n'.join([node.text for node in result.nodes])
@click.command()
@click.option(
'-q',
'--question',
'question',
required=True,
help='ask what you want to know!',
)
def cli(question):
tokenizer = AutoTokenizer.from_pretrained('qwen/Qwen-7B-Chat',
revision='v1.0.5',
trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained('qwen/Qwen-7B-Chat',
revision='v1.0.5',
device_map='auto',
trust_remote_code=True,
fp32=True).eval()
model.generation_config = GenerationConfig.from_pretrained(
'Qwen/Qwen-7B-Chat', revision='v1.0.5', trust_remote_code=True)
# define a prompt template for the vectorDB-enhanced LLM generation
def answer_question(question, context, model):
if context == '':
prompt = question
else:
prompt = f'''请基于```内的内容回答问题。"
```
{context}
```
我的问题是:{question}。
'''
history = None
print(prompt)
response, history = model.chat(tokenizer, prompt, history=None)
return response
answer = answer_question(question, search(es_vector_store, question),
model)
print(f'question: {question}\n'
f'answer: {answer}')
"""
python query.py -q 'how about the rights of men'
"""
if __name__ == '__main__':
cli()
...@@ -15,4 +15,4 @@ paddleocr==2.7.3 ...@@ -15,4 +15,4 @@ paddleocr==2.7.3
paddlepaddle==3.0.0b1 paddlepaddle==3.0.0b1
pypandoc pypandoc
struct-eqtable==0.1.0 struct-eqtable==0.1.0
detectron2 detectron2
\ No newline at end of file
boto3>=1.28.43 boto3>=1.28.43
Brotli>=1.1.0 Brotli>=1.1.0
click>=8.1.7 click>=8.1.7
PyMuPDF>=1.24.9 fast-langdetect==0.2.0
loguru>=0.6.0 loguru>=0.6.0
numpy>=1.21.6,<2.0.0 numpy>=1.21.6,<2.0.0
fast-langdetect==0.2.0
wordninja>=2.0.0
scikit-learn>=1.0.2
pdfminer.six==20231228 pdfminer.six==20231228
pydantic>=2.7.2,<2.8.0
PyMuPDF>=1.24.9
scikit-learn>=1.0.2
wordninja>=2.0.0
# The requirements.txt must ensure that only necessary external dependencies are introduced. If there are new dependencies to add, please contact the project administrator. # The requirements.txt must ensure that only necessary external dependencies are introduced. If there are new dependencies to add, please contact the project administrator.
This diff is collapsed.
import json
import os
import shutil
import tempfile
from magic_pdf.integrations.rag.api import DataReader, RagDocumentReader
from magic_pdf.integrations.rag.type import CategoryType
from magic_pdf.integrations.rag.utils import \
convert_middle_json_to_layout_elements
def test_rag_document_reader():
# setup
unitest_dir = '/tmp/magic_pdf/unittest/integrations/rag'
os.makedirs(unitest_dir, exist_ok=True)
temp_output_dir = tempfile.mkdtemp(dir=unitest_dir)
os.makedirs(temp_output_dir, exist_ok=True)
# test
with open('tests/test_integrations/test_rag/assets/middle.json') as f:
json_data = json.load(f)
res = convert_middle_json_to_layout_elements(json_data, temp_output_dir)
doc = RagDocumentReader(res)
assert len(list(iter(doc))) == 1
page = list(iter(doc))[0]
assert len(list(iter(page))) == 10
assert len(page.get_rel_map()) == 3
item = list(iter(page))[0]
assert item.category_type == CategoryType.text
# teardown
shutil.rmtree(temp_output_dir)
def test_data_reader():
# setup
unitest_dir = '/tmp/magic_pdf/unittest/integrations/rag'
os.makedirs(unitest_dir, exist_ok=True)
temp_output_dir = tempfile.mkdtemp(dir=unitest_dir)
os.makedirs(temp_output_dir, exist_ok=True)
# test
data_reader = DataReader('tests/test_integrations/test_rag/assets', 'ocr',
temp_output_dir)
assert data_reader.get_documents_count() == 2
for idx in range(data_reader.get_documents_count()):
document = data_reader.get_document_result(idx)
assert document is not None
# teardown
shutil.rmtree(temp_output_dir)
import json
import os
import shutil
import tempfile
from magic_pdf.integrations.rag.type import CategoryType
from magic_pdf.integrations.rag.utils import (
convert_middle_json_to_layout_elements, inference)
def test_convert_middle_json_to_layout_elements():
# setup
unitest_dir = '/tmp/magic_pdf/unittest/integrations/rag'
os.makedirs(unitest_dir, exist_ok=True)
temp_output_dir = tempfile.mkdtemp(dir=unitest_dir)
os.makedirs(temp_output_dir, exist_ok=True)
# test
with open('tests/test_integrations/test_rag/assets/middle.json') as f:
json_data = json.load(f)
res = convert_middle_json_to_layout_elements(json_data, temp_output_dir)
assert len(res) == 1
assert len(res[0].layout_dets) == 10
assert res[0].layout_dets[0].anno_id == 0
assert res[0].layout_dets[0].category_type == CategoryType.text
assert len(res[0].extra.element_relation) == 3
# teardown
shutil.rmtree(temp_output_dir)
def test_inference():
asset_dir = 'tests/test_integrations/test_rag/assets'
# setup
unitest_dir = '/tmp/magic_pdf/unittest/integrations/rag'
os.makedirs(unitest_dir, exist_ok=True)
temp_output_dir = tempfile.mkdtemp(dir=unitest_dir)
os.makedirs(temp_output_dir, exist_ok=True)
# test
res = inference(
asset_dir + '/one_page_with_table_image.pdf',
temp_output_dir,
'ocr',
)
assert res is not None
assert len(res) == 1
assert len(res[0].layout_dets) == 10
assert res[0].layout_dets[0].anno_id == 0
assert res[0].layout_dets[0].category_type == CategoryType.text
assert len(res[0].extra.element_relation) == 3
# teardown
shutil.rmtree(temp_output_dir)
...@@ -19,7 +19,12 @@ def test_common_do_parse(method): ...@@ -19,7 +19,12 @@ def test_common_do_parse(method):
# run # run
with open("tests/test_tools/assets/common/cli_test_01.pdf", "rb") as f: with open("tests/test_tools/assets/common/cli_test_01.pdf", "rb") as f:
bits = f.read() bits = f.read()
do_parse(temp_output_dir, filename, bits, [], method, f_dump_content_list=True) do_parse(temp_output_dir,
filename,
bits, [],
method,
False,
f_dump_content_list=True)
# check # check
base_output_dir = os.path.join(temp_output_dir, f"fake/{method}") base_output_dir = os.path.join(temp_output_dir, f"fake/{method}")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment