Commit 3cbcf2de authored by myhloli's avatar myhloli

feat(draw_bbox): add layout sorting visualization

Implement a new function `draw_layout_sort_bbox` in `draw_bbox.py` to visualize the
layout sorting results using the `LayoutLMv3ForTokenClassification` model. This function
predicts the order of layout elements and draws them in the sorted sequence on the PDF pages.
parent 24c143fe
import time
from magic_pdf.libs.commons import fitz # PyMuPDF from magic_pdf.libs.commons import fitz # PyMuPDF
from magic_pdf.libs.Constants import CROSS_PAGE from magic_pdf.libs.Constants import CROSS_PAGE
from magic_pdf.libs.ocr_content_type import BlockType, CategoryId, ContentType from magic_pdf.libs.ocr_content_type import BlockType, CategoryId, ContentType
...@@ -211,9 +213,9 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -211,9 +213,9 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
# 构造其余useful_list # 构造其余useful_list
for block in page['para_blocks']: for block in page['para_blocks']:
if block['type'] in [ if block['type'] in [
BlockType.Text, BlockType.Text,
BlockType.Title, BlockType.Title,
BlockType.InterlineEquation, BlockType.InterlineEquation,
]: ]:
for line in block['lines']: for line in block['lines']:
for span in line['spans']: for span in line['spans']:
...@@ -244,7 +246,7 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -244,7 +246,7 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
pdf_docs.save(f'{out_path}/{filename}_spans.pdf') pdf_docs.save(f'{out_path}/{filename}_spans.pdf')
def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename): def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
dropped_bbox_list = [] dropped_bbox_list = []
tables_body_list, tables_caption_list, tables_footnote_list = [], [], [] tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], [] imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], []
...@@ -279,7 +281,7 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename): ...@@ -279,7 +281,7 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
elif layout_det['category_id'] == CategoryId.ImageCaption: elif layout_det['category_id'] == CategoryId.ImageCaption:
imgs_caption.append(bbox) imgs_caption.append(bbox)
elif layout_det[ elif layout_det[
'category_id'] == CategoryId.InterlineEquation_YOLO: 'category_id'] == CategoryId.InterlineEquation_YOLO:
interequations.append(bbox) interequations.append(bbox)
elif layout_det['category_id'] == CategoryId.Abandon: elif layout_det['category_id'] == CategoryId.Abandon:
page_dropped_list.append(bbox) page_dropped_list.append(bbox)
...@@ -316,3 +318,55 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename): ...@@ -316,3 +318,55 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
# Save the PDF # Save the PDF
pdf_docs.save(f'{out_path}/{filename}_model.pdf') pdf_docs.save(f'{out_path}/{filename}_model.pdf')
from typing import List
def do_predict(boxes: List[List[int]]) -> List[int]:
    """Predict the reading order of layout boxes with the layoutreader model.

    Args:
        boxes: One ``[left, top, right, bottom]`` list of ints per layout
            element. The caller in this file rescales them to a 0-1000 grid
            before calling.

    Returns:
        The list produced by ``parse_logits`` for these boxes: one predicted
        order index per input box. An empty input yields an empty list.
    """
    # Nothing to order: skip the heavy imports and the model entirely.
    if not boxes:
        return []

    import torch
    from transformers import LayoutLMv3ForTokenClassification
    from magic_pdf.v3.helpers import prepare_inputs, boxes2inputs, parse_logits

    # Load the checkpoint once and keep it on the function object. This
    # function is called once per page, so reloading the weights on every
    # call would repeat the same expensive work for each page.
    model = getattr(do_predict, "_model", None)
    if model is None:
        model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
        do_predict._model = model

    inputs = prepare_inputs(boxes2inputs(boxes), model)
    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits.cpu().squeeze(0)
    return parse_logits(logits, len(boxes))
def draw_layout_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
    """Save a copy of the PDF with each page's blocks numbered in predicted order.

    For every page in ``pdf_info`` the bounding boxes of ``page['para_blocks']``
    are rescaled to a 0-1000 grid, ordered with ``do_predict`` and collected in
    that order. The boxes are then drawn on the matching PDF page with
    ``draw_bbox_with_number`` and the document is written to
    ``{out_path}/{filename}_layout_sort.pdf``.

    Args:
        pdf_info: Iterable of page dicts providing ``'para_blocks'`` (each block
            has a ``'bbox'`` of four numbers) and ``'page_size'`` (width, height).
        pdf_bytes: Raw bytes of the source PDF.
        out_path: Directory the annotated PDF is written to.
        filename: Base name used for the output file.

    Raises:
        ValueError: If a box is inverted (right < left or bottom < top).
    """
    from loguru import logger

    layout_bbox_list = []
    for page in pdf_info:
        page_layout_list = [block['bbox'] for block in page['para_blocks']]

        # Sort with layoutreader. Boxes are rescaled to a 0-1000 grid first
        # (presumably the coordinate range the model expects - TODO confirm).
        page_size = page['page_size']
        x_scale = 1000.0 / page_size[0]
        y_scale = 1000.0 / page_size[1]
        logger.info(f"Scale: {x_scale}, {y_scale}, Boxes len: {len(page_layout_list)}")

        boxes = []
        for left, top, right, bottom in page_layout_list:
            scaled = (
                round(left * x_scale),
                round(top * y_scale),
                round(right * x_scale),
                round(bottom * y_scale),
            )
            # Clamp instead of failing: a block that pokes slightly outside the
            # page (or rounds up past 1000) should not abort a debug rendering.
            left, top, right, bottom = (min(max(value, 0), 1000) for value in scaled)
            # Explicit exception rather than assert, so the check survives -O.
            if right < left or bottom < top:
                raise ValueError(
                    f"Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}"
                )
            boxes.append([left, top, right, bottom])

        if boxes:
            logger.info("layoutreader start")
            start = time.time()
            orders = do_predict(boxes)
            logger.debug(f"layoutreader orders: {orders}")
            logger.info(f"layoutreader end, cost time: {time.time() - start}")
        else:
            # A page without blocks has nothing to sort; skip the model call.
            orders = []

        layout_bbox_list.append([page_layout_list[i] for i in orders])

    pdf_docs = fitz.open('pdf', pdf_bytes)
    for i, page in enumerate(pdf_docs):
        draw_bbox_with_number(i, layout_bbox_list, page, [102, 102, 255], False)

    pdf_docs.save(f'{out_path}/{filename}_layout_sort.pdf')
...@@ -7,7 +7,7 @@ from loguru import logger ...@@ -7,7 +7,7 @@ from loguru import logger
import magic_pdf.model as model_config import magic_pdf.model as model_config
from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_span_bbox, from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_span_bbox,
drow_model_bbox) draw_model_bbox, draw_layout_sort_bbox)
from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
from magic_pdf.pipe.OCRPipe import OCRPipe from magic_pdf.pipe.OCRPipe import OCRPipe
from magic_pdf.pipe.TXTPipe import TXTPipe from magic_pdf.pipe.TXTPipe import TXTPipe
...@@ -90,7 +90,9 @@ def do_parse( ...@@ -90,7 +90,9 @@ def do_parse(
if f_draw_span_bbox: if f_draw_span_bbox:
draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name) draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
if f_draw_model_bbox: if f_draw_model_bbox:
drow_model_bbox(copy.deepcopy(orig_model_list), pdf_bytes, local_md_dir, pdf_file_name) draw_model_bbox(copy.deepcopy(orig_model_list), pdf_bytes, local_md_dir, pdf_file_name)
draw_layout_sort_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
md_content = pipe.pipe_mk_markdown(image_dir, md_content = pipe.pipe_mk_markdown(image_dir,
drop_mode=DropMode.NONE, drop_mode=DropMode.NONE,
......
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
\ No newline at end of file
import gzip
import json
import torch
import typer
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from tqdm import tqdm
from transformers import LayoutLMv3ForTokenClassification
from helpers import (
DataCollator,
check_duplicate,
MAX_LEN,
parse_logits,
prepare_inputs,
)
# Typer application object; the `main` command below is registered on it.
app = typer.Typer()
# Shared NLTK smoothing helper; its `method2` is passed to `sentence_bleu` below.
chen_cherry = SmoothingFunction()
@app.command()
def main(
    input_file: str = typer.Argument(..., help="input file"),
    model_path: str = typer.Argument(..., help="model path"),
    batch_size: int = typer.Option(16, help="batch size"),
):
    """Evaluate a reading-order model on a gzipped JSON-lines file.

    Each line of the input must be a JSON object with ``source_boxes``,
    ``target_index``, ``source_texts`` and ``target_texts``. The command prints
    the sample count and two averaged BLEU scores (scaled to 0-100): one over
    the predicted order indices and one over the re-ordered text tokens.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = (
        LayoutLMv3ForTokenClassification.from_pretrained(model_path, num_labels=MAX_LEN)
        .bfloat16()
        .to(device)
        .eval()
    )
    data_collator = DataCollator()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # JSON is UTF-8 by specification; do not depend on the platform default.
    with gzip.open(input_file, "rt", encoding="utf-8") as f:
        datasets = [json.loads(line) for line in tqdm(f)]

    # Sort by length so each batch holds similarly sized samples (less padding).
    datasets.sort(key=lambda x: len(x["source_boxes"]), reverse=True)

    total = 0
    total_out_idx = 0.0
    total_out_token = 0.0
    for start in tqdm(range(0, len(datasets), batch_size)):
        batch = datasets[start : start + batch_size]
        model_inputs = data_collator(batch)
        model_inputs = prepare_inputs(model_inputs, model)
        # forward
        with torch.no_grad():
            model_outputs = model(**model_inputs)
        logits = model_outputs.logits.cpu()

        for data, logit in zip(batch, logits):
            target_index = data["target_index"][:MAX_LEN]
            pred_index = parse_logits(logit, len(target_index))
            assert len(pred_index) == len(target_index)
            assert not check_duplicate(pred_index)

            target_texts = data["target_texts"][:MAX_LEN]
            source_texts = data["source_texts"][:MAX_LEN]
            pred_texts = [source_texts[idx] for idx in pred_index]

            total += 1
            # Predictions are shifted by one before being compared with
            # target_index (the target indices appear to be 1-based).
            total_out_idx += sentence_bleu(
                [target_index],
                [idx + 1 for idx in pred_index],
                smoothing_function=chen_cherry.method2,
            )
            total_out_token += sentence_bleu(
                [" ".join(target_texts).split()],
                " ".join(pred_texts).split(),
                smoothing_function=chen_cherry.method2,
            )

    print("total: ", total)
    if total == 0:
        # Empty input: there is nothing to average, and dividing by zero
        # below would raise ZeroDivisionError.
        return
    print("out_idx: ", round(100 * total_out_idx / total, 1))
    print("out_token: ", round(100 * total_out_token / total, 1))
if __name__ == "__main__":
app()
from collections import defaultdict
from typing import List, Dict
import torch
from transformers import LayoutLMv3ForTokenClassification
# Maximum number of boxes kept per sample; longer inputs are truncated to this.
# NOTE(review): presumably 512 positions minus the CLS and EOS tokens - confirm.
MAX_LEN = 510
# Token id placed at the start of every sequence.
CLS_TOKEN_ID = 0
# Token id used for every box position (boxes are fed without text tokens).
UNK_TOKEN_ID = 3
# Token id placed at the end of every sequence; also used to pad input_ids.
EOS_TOKEN_ID = 2
class DataCollator:
    """Turn raw reading-order samples into padded tensors for the model."""

    def __call__(self, features: List[dict]) -> Dict[str, torch.Tensor]:
        """Collate ``features`` into one batch.

        Each feature provides ``source_boxes`` and ``target_index``. Both are
        truncated to ``MAX_LEN``, wrapped with a leading CLS and trailing EOS
        position, and right-padded to the longest sequence in the batch.

        Returns a dict with ``bbox``, ``attention_mask``, ``labels`` and
        ``input_ids`` tensors. Label positions that must be ignored hold -100;
        the remaining positive labels are shifted down by one.
        """
        zero_box = [0, 0, 0, 0]
        bbox_rows, label_rows, id_rows, mask_rows = [], [], [], []

        # Truncate each sample and wrap it with the CLS / EOS positions.
        for feature in features:
            boxes = feature["source_boxes"][:MAX_LEN]
            targets = feature["target_index"][:MAX_LEN]
            count = len(boxes)
            assert count == len(targets)

            bbox_rows.append([zero_box] + boxes + [zero_box])
            label_rows.append([-100] + targets + [-100])
            id_rows.append([CLS_TOKEN_ID] + [UNK_TOKEN_ID] * count + [EOS_TOKEN_ID])
            mask_rows.append([1] * (count + 2))

        # Right-pad every row to the longest sequence in the batch.
        width = max(len(row) for row in bbox_rows)

        def pad(rows, filler):
            return [row + [filler] * (width - len(row)) for row in rows]

        batch = {
            "bbox": torch.tensor(pad(bbox_rows, zero_box)),
            "attention_mask": torch.tensor(pad(mask_rows, 0)),
            "labels": torch.tensor(pad(label_rows, -100)),
            "input_ids": torch.tensor(pad(id_rows, EOS_TOKEN_ID)),
        }

        labels = batch["labels"]
        # Targets beyond MAX_LEN cannot be predicted; mark them as ignored.
        labels.masked_fill_(labels > MAX_LEN, -100)
        # Original labels are 1-indexed: shift every positive label down by one.
        labels -= (labels > 0).to(labels.dtype)
        return batch
def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
    """Build single-sample model inputs from a list of boxes.

    The boxes are wrapped with an all-zero box at each end (the CLS and EOS
    positions). The returned dict holds ``bbox``, ``attention_mask`` and
    ``input_ids`` tensors, each with a leading batch dimension of size one.
    """
    edge = [[0, 0, 0, 0]]
    sample = {"bbox": edge + boxes + edge}
    count = len(boxes)
    sample["attention_mask"] = [1] * (count + 2)
    sample["input_ids"] = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * count + [EOS_TOKEN_ID]
    # Wrap each sequence in a list to add the batch dimension.
    return {name: torch.tensor([values]) for name, values in sample.items()}
def prepare_inputs(
    inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification
) -> Dict[str, torch.Tensor]:
    """Return a copy of ``inputs`` with every tensor moved to the model's device.

    Floating-point tensors are additionally cast to ``model.dtype``; tensors of
    any other dtype keep their dtype.
    """

    def place(tensor: torch.Tensor) -> torch.Tensor:
        on_device = tensor.to(model.device)
        if not torch.is_floating_point(on_device):
            return on_device
        return on_device.to(model.dtype)

    return {name: place(tensor) for name, tensor in inputs.items()}
def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
    """
    Turn model logits into one distinct order per input element.

    :param logits: logits from model; row 0 is skipped and only the first
        ``length`` columns of rows ``1..length`` are used
    :param length: input length
    :return: orders, one per input element
    """
    scores = logits[1 : length + 1, :length]
    # Per row, candidate orders sorted by ascending score, so pop() always
    # yields the best candidate that row has not tried yet.
    candidates = scores.argsort(descending=False).tolist()
    assigned = [row.pop() for row in candidates]

    while True:
        # Group element indices by the order they currently claim.
        claimants = defaultdict(list)
        for idx, order in enumerate(assigned):
            claimants[order].append(idx)

        contested = [
            (order, idxes) for order, idxes in claimants.items() if len(idxes) > 1
        ]
        if not contested:
            return assigned

        for order, idxes in contested:
            # The highest-scoring claimant keeps the order; every other
            # claimant falls back to its next-best candidate.
            ranked = sorted(idxes, key=lambda idx: scores[idx, order], reverse=True)
            for loser in ranked[1:]:
                assigned[loser] = candidates[loser].pop()
def check_duplicate(a: List[int]) -> bool:
    """Return True when ``a`` contains at least one repeated value."""
    distinct = set(a)
    return len(distinct) < len(a)
import os
from dataclasses import dataclass, field
from datasets import load_dataset, Dataset
from loguru import logger
from transformers import (
TrainingArguments,
HfArgumentParser,
LayoutLMv3ForTokenClassification,
set_seed,
)
from transformers.trainer import Trainer
from helpers import DataCollator, MAX_LEN
@dataclass
class Arguments(TrainingArguments):
    """``TrainingArguments`` extended with the model and dataset locations."""

    # Checkpoint to start training from.
    model_dir: str = field(
        default=None,
        metadata=dict(help="Path to model, based on `microsoft/layoutlmv3-base`"),
    )
    # Directory passed to `load_train_and_dev_dataset`.
    dataset_dir: str = field(default=None, metadata=dict(help="Path to dataset"))
def load_train_and_dev_dataset(path: str) -> (Dataset, Dataset):
    """Load the train and dev splits stored under ``path``.

    Reads ``train.jsonl.gz`` and ``dev.jsonl.gz`` from ``path`` with the
    ``json`` loader and returns them as ``(train, dev)``.
    """
    train_file = os.path.join(path, "train.jsonl.gz")
    dev_file = os.path.join(path, "dev.jsonl.gz")
    splits = load_dataset("json", data_files={"train": train_file, "dev": dev_file})
    return splits["train"], splits["dev"]
def main():
    """Parse the command line, load the data and model, and run training."""
    args: Arguments = HfArgumentParser((Arguments,)).parse_args_into_dataclasses()[0]
    set_seed(args.seed)

    train_dataset, dev_dataset = load_train_and_dev_dataset(args.dataset_dir)
    train_size, dev_size = len(train_dataset), len(dev_dataset)
    logger.info(
        "Train dataset size: {}, Dev dataset size: {}".format(train_size, dev_size)
    )

    # One output label per possible position; visual embeddings are disabled.
    model = LayoutLMv3ForTokenClassification.from_pretrained(
        args.model_dir, num_labels=MAX_LEN, visual_embed=False
    )
    Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=dev_dataset,
        data_collator=DataCollator(),
    ).train()
if __name__ == "__main__":
main()
#!/usr/bin/env bash
# Launch training of the reading-order model (train.py) through DeepSpeed.
# Any extra command-line arguments are appended to the deepspeed call via "$@".

# Print each command before running it, and abort on the first failure.
set -x
set -e
# DIR is the parent of the directory that contains this script.
DIR="$( cd "$( dirname "$0" )" && cd .. && pwd )"
# Checkpoints go to a directory stamped with the current date and hour.
OUTPUT_DIR="${DIR}/checkpoint/v3/$(date +%F-%H)"
# Dataset directory handed to train.py as --dataset_dir.
DATA_DIR="${DIR}/ReadingBank/"
mkdir -p "${OUTPUT_DIR}"
# NOTE: train.py and ds_config.json are given as relative paths, so they are
# resolved against the current working directory, not against DIR.
deepspeed train.py \
--model_dir 'microsoft/layoutlmv3-large' \
--dataset_dir "${DATA_DIR}" \
--dataloader_num_workers 1 \
--deepspeed ds_config.json \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 64 \
--do_train \
--do_eval \
--logging_steps 100 \
--bf16 \
--seed 42 \
--num_train_epochs 10 \
--learning_rate 5e-5 \
--warmup_steps 1000 \
--save_strategy epoch \
--evaluation_strategy epoch \
--remove_unused_columns False \
--output_dir "${OUTPUT_DIR}" \
--overwrite_output_dir \
"$@"
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment