0


从AI推理性能优化角度看LLaMA的模型结构和源码

本篇文章讲讲LLaMA的结构,已经有很多文章已经对LLaMA在一些结构上任务表现上做了一些解析,本文主要从优化的角度、实现kernel的角度解析一下LLaMA,读者事先对transformer的结构有基本认识最好。本文首发于我的公众号“AI不止算法”,文章链接在此

LLaMA简单介绍

几个月前,FB开源了LLAMA,LLAMA1包括三个参数量的模型7B、13B、65B, 证明了完全可以通过公开数据集来训练最先进的模型,而无需使用专有和不可获取的数据集,同时LLaMA-13B 在大多数benchmark优于 GPT-3,尽管大小只有后者的1/10。在更大规模上,LLaMA-65B 参数模型也与可以与Chinchilla或PaLM-540B相竞争,这是之前bloom、OPT等没有做到的。本文不谈LLaMA的预训练数据多么多么怎么样,也不谈LLaMA在各个任务上的表现如何,重点从性能优化的角度谈谈LLaMA的模型结构。
图片

模型结构

LLaMA主体结构依然是transformer组成,和其它LLM不同的是:

  • 使用RMSNorm(即Root Mean square Layer Normalization)对每个Transformer子层的input进行Pre Norm
  • 使用激活函数SwiGLU
  • 使用RoPE进行相对位置编码
  • 使用了AdamW优化器,并使用cosine learning rate schedule (AdamW和Adam的区别我不是特别清楚,先放着不讲)

RMSNorm为layerNorm的变体,在分子分母都省去了Mean,同时少了beta参数,虽然不用再计算variance了,但我觉得Welford依然是Normlization类算子性能的最优解

在这里插入图片描述

# RMSNormclassRMSNorm(torch.nn.Module):def__init__(self, dim:int, eps:float=1e-6):super().__init__()
        self.eps = eps # ε
        self.weight = nn.Parameter(torch.ones(dim))def_norm(self, x):# RMSNormreturn x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True)+ self.eps)defforward(self, x):
        output = self._norm(x.float()).type_as(x)return output * self.weight

激活函数SwiGLU整合了Swish和GLU这两个函数,网上的文章对这一块讲的似懂非懂,不如直接看CUDA源码,我去翻了一下SwiGLU的实现,得出SwiGLU可以理解为SiLU和mul的fused kernel,前者为x * sigmoid(x),本质上来看依然是一个element wise kernel

在这里插入图片描述
对于RoPE,这是一个新鲜的玩意,我们要做的就是实现这样一个rotary_embedding kernel , 它作用与QK矩阵上,在QK的batch GEMM之前,采用绝对位置编码来达到相对位置编码的效果,绝对位置编码的优点是计算简单高效,缺点是一般效果不如相对位置编码。相对位置编码的优点是效果较好,缺点是计算效率不如绝对位置编码。在相对位置编码中,注意力权重的结果仅仅和参与注意力计算的token向量的相对位置有关,不和绝对位置直接关联。这符合NLP领域在序列长度方向上具有平移不变性的特点,所以相对位置编码一般效果会优于绝对位置编码。

RoPE公式推导我个人有点看不下去,直接看公式吧,将旋转位置编码过程由GEMM简化成两次向量的哈达玛积求和,这也是一个element wise kernel,要把x给索引好,送给cos和sin相乘

在这里插入图片描述
python源代码,还是比较straightforward
在这里插入图片描述

LLaMA Attention

和普遍的attention结构没有太大区别,除了把上面的那些新增结构RMS norm,RoPE给添加到各个transformer layer开头和QK之后。想谈论的是Tensor Parallel 版本的attention,这里对qkv的weight采用了列切分,output linear采用了行切分,这循序了NV megatron的张量并行切分思想,有助于最小化多卡通讯开销。
在这里插入图片描述

LLaMA MLP

同理对于MLP,也采用了linear的列切分行切分版本,同时把SwiGLU给加了进去
在这里插入图片描述

LLaMA TransformerLayer

对于每个layer,把attention和MLP叠起来就完事
在这里插入图片描述

Llama generate

transformerlayer出来后的经过LMhead(其实就是个linear)+ softmax得到probs,然后就开始sample,可以topP,可以贪心,可以beam search,主要就看怎么设计了,在这份代码里,采用了topP或贪心,最后再detokenize,吐出token到构造的buffer tokens = torch.full((bsz,total_len), self.tokenizer.pad_id).cuda().long()

classLLaMA:def__init__(self, model: Transformer, tokenizer: Tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    defgenerate(
        self,
        prompts: List[str],
        max_gen_len:int,
        temperature:float=0.8,
        top_p:float=0.95,)-> List[str]:
        bsz =len(prompts)
        params = self.model.params
        assert bsz <= params.max_batch_size,(bsz, params.max_batch_size)

        prompt_tokens =[self.tokenizer.encode(x, bos=True, eos=False)for x in prompts]

        min_prompt_size =min([len(t)for t in prompt_tokens])
        max_prompt_size =max([len(t)for t in prompt_tokens])

        total_len =min(params.max_seq_len, max_gen_len + max_prompt_size)

        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()for k, t inenumerate(prompt_tokens):

            tokens[k,:len(t)]= torch.tensor(t).long()
        input_text_mask = tokens != self.tokenizer.pad_id
        start_pos = min_prompt_size
        prev_pos =0# start generatefor cur_pos inrange(start_pos, total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)if temperature >0:
                probs = torch.softmax(logits / temperature, dim=-1)# sample by top P
                next_token = sample_top_p(probs, top_p)else:# greedy search
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)# only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos]= next_token
            prev_pos = cur_pos
        # detokenize
        decoded =[]for i, t inenumerate(tokens.tolist()):# cut to max gen len
            t = t[:len(prompt_tokens[i])+ max_gen_len]# cut to eos tok if anytry:
                t = t[: t.index(self.tokenizer.eos_id)]except ValueError:pass
            decoded.append(self.tokenizer.decode(t))return decoded

# sample the one which is the cum prob < pdefsample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask]=0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))# extract a sample

    next_token = torch.multinomial(probs_sort, num_samples=1)# find next_token's id

        next_token = torch.gather(probs_idx,-1, next_token)return next_token

LLaMA 2

LLaMA2在1的基础上又做了一些改进,在模型结构上引入了GQA来降低KV cache的显存占用,以此来增大batch size,获得更高的吞吐量,后面单独开篇文章讲讲MQA和GQA

另外

1、attention mask的构造上面也有一些要点:

_make_causal_mask用于构造下三角这种mask结构以实现语言模型的单向注意力。

_expand_mask用于将mask信息展开成和attention矩阵相同的张量结构。

2、对优化器AdamW的具体实现不是很了解,后续补补课再来聊聊

3、LLM的inference本身并不像general的inference engine或者framework那么有太大的复杂度,主要还是实现那几个kernel,整体我个人感觉在性能优化的角度,还是不会带来太大的额外工作量,多数kernel都可以reuse已有实现


本文转载自: https://blog.csdn.net/weixin_43568400/article/details/136313606
版权归原作者 不归牛顿管的熊猫 所有, 如有侵权,请联系我们删除。

“从AI推理性能优化角度看LLaMA的模型结构和源码”的评论:

还没有评论