Commit a3358878 authored by liukaiwen

feat: merge formula update

parent 763688c0
from typing import Optional, Tuple, Union
import torch
from torch import nn
import os
from unimernet.common.config import Config
import unimernet.tasks as tasks
import argparse
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
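# Patches that replace the dynamic-shape decode path of the UniMERNet MBart decoder
# with fixed-shape buffers indexed by `kvlen`, so single-token decode steps can be
# captured and replayed as CUDA graphs (see GraphRunner below).


# Drop-in replacement for MBartLearnedPositionalEmbedding that accepts the past
# key/value length as a tensor offset (the shared `kvlen` buffer), so position
# lookups update through buffer writes rather than host-side integers.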
class PatchedMBartLearnedPositionalEmbedding(nn.Module):
def __init__(self, origin: nn.Module):
super().__init__()
self.offset = origin.offset
self.embedding = nn.Embedding(origin.num_embeddings, origin.embedding_dim)
self.embedding.weight.data = origin.weight.data
def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
"""`input_ids' shape is expected to be [bsz x seqlen]."""
bsz, seq_len = input_ids.shape[:2]
positions = torch.arange(0, seq_len, dtype=torch.long, device=self.embedding.weight.device)
positions += past_key_values_length
positions = positions.expand(bsz, -1)
return self.embedding(positions + self.offset)
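# Wraps CustomMBartDecoder: prefill calls (no cache, or a cache shorter than the
# attention mask) are forwarded to the original module, while single-token decode
# steps run a static-shape copy of the original forward that reads positions from
# the shared `kvlen` tensor.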
class PatchedMBartDecoder(nn.Module):
def __init__(self, origin: nn.Module, kvlen: torch.LongTensor):
super().__init__()
self.origin = origin
self.kvlen = kvlen
self.config = origin.config
self.embed_tokens = origin.embed_tokens
self.embed_scale = origin.embed_scale
self._use_flash_attention_2 = origin._use_flash_attention_2
self.embed_positions = origin.embed_positions
self.counting_context_weight = getattr(origin, 'counting_context_weight', None)
self.layernorm_embedding = origin.layernorm_embedding
self.layers = origin.layers
self.layer_norm = origin.layer_norm
self.patched_embed_positions = PatchedMBartLearnedPositionalEmbedding(self.embed_positions)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
count_pred: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
run_origin = False
if past_key_values is None:
run_origin = True
elif past_key_values[0][0].size(-2) < attention_mask.size(-1):
run_origin = True
if run_origin:
return self.origin(
input_ids=input_ids,
attention_mask=attention_mask,
count_pred=count_pred,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input = input_ids
input_shape = input.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
input = inputs_embeds[:, :, -1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if self._use_flash_attention_2:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)
# embed positions
positions = self.patched_embed_positions(input, self.kvlen)
hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
# TODO: add counting context weight to hidden_states
if count_pred is not None:
count_context_weight = self.counting_context_weight(count_pred)
hidden_states = hidden_states + 0.5 * count_context_weight.unsqueeze(1)
hidden_states = self.layernorm_embedding(hidden_states)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
if attn_mask.size()[0] != len(self.layers):
raise ValueError(
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {attn_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
hidden_states = self.layer_norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
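# Same attention computation as MBartAttention, but in the cached self-attention
# branch the current token's key/value is written in place into the pre-allocated
# cache at index `kvlen` instead of concatenating, so cache shapes stay constant
# across decode steps.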
class PatchedMBartAttention(nn.Module):
def __init__(self, origin: nn.Module, kvlen: torch.LongTensor):
super().__init__()
self.embed_dim = origin.embed_dim
self.num_heads = origin.num_heads
self.dropout = origin.dropout
self.head_dim = origin.head_dim
self.config = origin.config
self.scaling = origin.scaling
self.is_decoder = origin.is_decoder
self.is_causal = origin.is_causal
self.k_proj = origin.k_proj
self.v_proj = origin.v_proj
self.q_proj = origin.q_proj
self.out_proj = origin.out_proj
self.kvlen = kvlen
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if past_key_value[0].size(-2) < attention_mask.size(-1):
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
past_key_value[0][:, :, self.kvlen[None]] = key_states
past_key_value[1][:, :, self.kvlen[None]] = value_states
key_states = past_key_value[0]
value_states = past_key_value[1]
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.reshape(*proj_shape)
value_states = value_states.reshape(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = attn_weights
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
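# Counterpart of PatchedMBartAttention for UniMERNet's MBartSqueezeAttention, where
# queries and keys use `squeeze_head_dim` while values keep the full `head_dim`.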
class PatchedMBartSqueezeAttention(nn.Module):
def __init__(self, origin: nn.Module, kvlen: torch.LongTensor):
super().__init__()
self.embed_dim = origin.embed_dim
self.num_heads = origin.num_heads
self.dropout = origin.dropout
self.head_dim = origin.head_dim
self.squeeze_head_dim = origin.squeeze_head_dim
self.config = origin.config
self.scaling = origin.scaling
self.is_decoder = origin.is_decoder
self.q_proj = origin.q_proj
self.k_proj = origin.k_proj
self.v_proj = origin.v_proj
self.out_proj = origin.out_proj
self.kvlen = kvlen
def _shape_qk(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.squeeze_head_dim).transpose(1, 2).contiguous()
def _shape_v(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape_qk(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape_v(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)
if past_key_value[0].size(-2) < attention_mask.size(-1):
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
past_key_value[0][:, :, self.kvlen[None]] = key_states
past_key_value[1][:, :, self.kvlen[None]] = value_states
key_states = past_key_value[0]
value_states = past_key_value[1]
else:
# self_attention
key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.squeeze_head_dim)
value_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape_qk(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.reshape(*proj_shape)
value_states = value_states.reshape(*value_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
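# Recursively swap MBartAttention / MBartSqueezeAttention submodules for their
# patched versions, then wrap the top-level CustomMBartDecoder itself.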
def patch_model(model: nn.Module, kvlen: torch.LongTensor):
for name, child in model.named_children():
cls_name = type(child).__name__
if cls_name == 'MBartAttention':
patched_child = PatchedMBartAttention(child, kvlen)
model.register_module(name, patched_child)
elif cls_name == 'MBartSqueezeAttention':
patched_child = PatchedMBartSqueezeAttention(child, kvlen)
model.register_module(name, patched_child)
else:
patch_model(child, kvlen)
cls_name = type(model).__name__
if cls_name == 'CustomMBartDecoder':
model = PatchedMBartDecoder(model, kvlen)
return model
def next_power_of_2(n: int):
"""Return the smallest power of 2 greater than or equal to n."""
n -= 1
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n |= n >> 32
n += 1
return n
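# Bucket (batch_size, kvlens) to powers of two, with floors of 8 and 32, so a
# small, bounded set of captured graphs covers every decode shape.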
def get_graph_key(batch_size: int, kvlens: int):
batch_size = next_power_of_2(batch_size)
kvlens = next_power_of_2(kvlens)
batch_size = max(8, batch_size)
kvlens = max(32, kvlens)
return batch_size, kvlens
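# Owns one captured CUDA graph plus the static input/output buffers it was recorded
# against: __call__ copies fresh inputs into those buffers, replays the graph, and
# slices the outputs back down to the live batch size and kv length.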
class GraphRunnerImpl:
def __init__(self, model: nn.Module, graph: torch.cuda.CUDAGraph, input_buffers: dict, output_buffers: dict):
self.model = model
self.graph = graph
self.input_buffers = input_buffers
self.output_buffers = output_buffers
@staticmethod
def extract_input_buffers(input_buffers: dict, batch_size: int, kvlens: int):
input_ids = input_buffers['input_ids'][:batch_size]
attention_mask = input_buffers['attention_mask'][:batch_size, :kvlens]
encoder_hidden_states = input_buffers['encoder_hidden_states'][:batch_size]
kvlen = input_buffers['kvlen']
past_key_values = []
for past_key_value in input_buffers['past_key_values']:
k0 = past_key_value[0][:batch_size, :, :kvlens]
v0 = past_key_value[1][:batch_size, :, :kvlens]
k1 = past_key_value[2][:batch_size]
v1 = past_key_value[3][:batch_size]
past_key_values.append((k0, v0, k1, v1))
input_buffers = dict(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
past_key_values=past_key_values,
kvlen=kvlen,
)
return input_buffers
@staticmethod
def fill_input_buffers(
input_buffer: dict,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
):
batch_size = input_ids.size(0)
kvlens = attention_mask.size(1)
input_buffer['input_ids'][:batch_size] = input_ids
if input_buffer['attention_mask'].data_ptr() != attention_mask.data_ptr():
input_buffer['attention_mask'].fill_(0)
input_buffer['attention_mask'][:batch_size, :kvlens] = attention_mask
input_buffer['encoder_hidden_states'][:batch_size] = encoder_hidden_states
if past_key_values is not None:
for buf_kv, kv in zip(input_buffer['past_key_values'], past_key_values):
idx = 0
if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
buf_kv[idx].fill_(0)
buf_kv[idx][:batch_size, :, :kvlens-1] = kv[idx]
idx = 1
if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
buf_kv[idx].fill_(0)
buf_kv[idx][:batch_size, :, :kvlens-1] = kv[idx]
idx = 2
if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
buf_kv[idx].fill_(0)
buf_kv[idx][:batch_size] = kv[idx]
idx = 3
if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
buf_kv[idx].fill_(0)
buf_kv[idx][:batch_size] = kv[idx]
input_buffer['kvlen'].fill_(kvlens - 1)
@classmethod
@torch.inference_mode()
def capture(cls,
model: nn.Module,
input_buffers: dict,
pool,
warmup: bool = False,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
count_pred: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,):
batch_size = input_ids.size(0)
kvlens = attention_mask.size(1)
graph_key = get_graph_key(batch_size, kvlens)
batch_size = graph_key[0]
kvlens = graph_key[1]
input_buffers = cls.extract_input_buffers(input_buffers,
batch_size=batch_size,
kvlens=kvlens)
cls.fill_input_buffers(input_buffers,
input_ids,
attention_mask,
encoder_hidden_states,
past_key_values)
input_ids = input_buffers['input_ids']
attention_mask = input_buffers['attention_mask']
encoder_hidden_states = input_buffers['encoder_hidden_states']
past_key_values = input_buffers['past_key_values']
if warmup:
# warmup
model(
input_ids=input_ids,
attention_mask=attention_mask,
count_pred=count_pred,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph,
pool=pool):
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
count_pred=count_pred,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)
output_buffers = dict(
last_hidden_state=outputs['last_hidden_state'],
past_key_values=outputs['past_key_values'],
)
return GraphRunnerImpl(model, graph, input_buffers, output_buffers)
def __call__(self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
count_pred: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
batch_size = input_ids.size(0)
kvlens = attention_mask.size(1)
self.fill_input_buffers(self.input_buffers,
input_ids,
attention_mask,
encoder_hidden_states,
past_key_values)
self.graph.replay()
last_hidden_state = self.output_buffers['last_hidden_state'][:batch_size]
past_key_values = []
for past_key_value in self.output_buffers['past_key_values']:
k0 = past_key_value[0][:batch_size, :, :kvlens]
v0 = past_key_value[1][:batch_size, :, :kvlens]
k1 = past_key_value[2][:batch_size]
v1 = past_key_value[3][:batch_size]
past_key_values.append((k0, v0, k1, v1))
outputs = BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=last_hidden_state,
past_key_values=past_key_values,
)
return outputs
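# nn.Module facade over the patched decoder: allocates the shared static buffers on
# first use, lazily captures one graph per (batch, kvlen) bucket, and falls back to
# eager execution for prefill, missing caches, or shapes at or beyond the configured
# maximums.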
class GraphRunner(nn.Module):
def __init__(self, model: nn.Module, max_batchs: int, max_kvlens: int, dtype: torch.dtype = torch.float16, device: torch.device = 'cuda'):
super().__init__()
self.kvlen = torch.tensor(0, dtype=torch.long, device=device)
model = patch_model(model.to(dtype), self.kvlen)
self.model = model
self.max_batchs = max_batchs
self.max_kvlens = max_kvlens
self.device = device
self.input_buffers = None
self.impl_map = dict()
self.graph_pool_handle = torch.cuda.graph_pool_handle()
self.warmed_up = False
def create_buffers(self, encoder_kvlens: int, dtype: torch.dtype):
max_batchs = self.max_batchs
max_kvlens = self.max_kvlens
device = self.device
config = self.model.config
d_model = config.d_model
decoder_layers = config.decoder_layers
num_heads = config.decoder_attention_heads
head_dim = d_model // num_heads
self_attn = self.model.layers[0].self_attn
qk_head_dim = getattr(self_attn, 'squeeze_head_dim', head_dim)
input_ids = torch.ones((max_batchs, 1), dtype=torch.int64, device=device)
attention_mask = torch.zeros((max_batchs, max_kvlens), dtype=torch.int64, device=device)
encoder_hidden_states = torch.zeros((max_batchs, encoder_kvlens, d_model), dtype=dtype, device=device)
past_key_values = []
for _ in range(decoder_layers):
k0 = torch.zeros((max_batchs, num_heads, max_kvlens, qk_head_dim), dtype=dtype, device=device)
v0 = torch.zeros((max_batchs, num_heads, max_kvlens, head_dim), dtype=dtype, device=device)
k1 = torch.zeros((max_batchs, num_heads, encoder_kvlens, qk_head_dim), dtype=dtype, device=device)
v1 = torch.zeros((max_batchs, num_heads, encoder_kvlens, head_dim), dtype=dtype, device=device)
past_key_values.append((k0, v0, k1, v1))
self.input_buffers = dict(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
past_key_values=past_key_values,
kvlen=self.kvlen
)
@torch.inference_mode()
def forward(self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
count_pred: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
batch_size, qlens = input_ids.size()
kvlens = attention_mask.size(1)
eager_mode = False
if qlens != 1:
eager_mode = True
if past_key_values is None:
eager_mode = True
else:
for past_key_value in past_key_values:
if past_key_value is None:
eager_mode = True
break
if batch_size >= self.max_batchs or kvlens >= self.max_kvlens:
eager_mode = True
if eager_mode:
return self.model(
input_ids=input_ids,
attention_mask=attention_mask,
count_pred=count_pred,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,)
# create buffer if not exists.
if self.input_buffers is None:
encoder_kvlens = encoder_hidden_states.size(1)
self.create_buffers(encoder_kvlens=encoder_kvlens, dtype=encoder_hidden_states.dtype)
graph_key = get_graph_key(batch_size, kvlens)
if graph_key not in self.impl_map:
warmup = False
if not self.warmed_up:
warmup = True
self.warmed_up = True
impl = GraphRunnerImpl.capture(
self.model,
self.input_buffers,
self.graph_pool_handle,
warmup=warmup,
input_ids=input_ids,
attention_mask=attention_mask,
count_pred=count_pred,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
self.impl_map[graph_key] = impl
impl = self.impl_map[graph_key]
ret = impl(
input_ids=input_ids,
attention_mask=attention_mask,
count_pred=count_pred,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
return ret
\ No newline at end of file
@@ -5,6 +5,7 @@ import time
from magic_pdf.libs.Constants import *
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.model.model_list import AtomicModel
from .mfr_cudagraph import GraphRunner
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
@@ -67,6 +68,11 @@ def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
model = task.build_model(cfg)
model.to(_device_)
model.eval()
model = model.to(_device_)
if 'cuda' in _device_:
decoder_runner = GraphRunner(model.model.model.decoder.model.decoder, max_batchs=128, max_kvlens=256,
device=_device_)
model.model.model.decoder.model.decoder = decoder_runner
vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
mfr_transform = transforms.Compose([vis_processor, ])
return [model, mfr_transform]