0


开源模型应用落地-业务优化篇(一)

一、前言

通过参与“开源模型应用落地-业务整合系列篇”的学习,我们已经成功建立了基本的业务流程。然而,这只是迈出了万里长征的第一步。现在我们要对整个项目进行优化,以提高效率。我们计划利用线程池来加快处理速度,使用redis来实现排队需求,以及通过多级环境来减轻负载压力。这些优化措施将有助于我们进一步改进项目的性能和效果。

二、术语

2.1. 线程池

是一种用于线程管理的技术,它包含一组预先创建的线程,用于执行任务。线程池维护着一个任务队列,当有任务到达时,线程池中的线程会自动分配任务并执行。

线程池的主要目的是重用线程,避免频繁地创建和销毁线程带来的开销。通过使用线程池,可以在程序初始化时创建一组线程,并将任务提交给线程池进行处理,而不需要为每个任务都创建一个新的线程。这样可以有效地管理系统中的线程数量,控制并发度,提高系统的性能和资源利用率。

线程池通常包含以下几个关键组件:

  1. 任务队列(Task Queue):用于存储待执行的任务,通常是一个队列结构。当有新的任务到达时,会被添加到任务队列中。
  2. 线程池管理器(Thread Pool Manager):负责管理线程池的创建、销毁和线程的调度。它会监视任务队列的状态,并根据需要动态地创建或回收线程。
  3. 工作线程(Worker Threads):线程池中的线程,用于执行任务。它们会从任务队列中获取任务,并执行任务的处理逻辑。

三、前置条件

3.1. 已搭建WebSocket与AI服务调用链路


四、技术实现

4.1. 调整业务逻辑处理类

 对于每次交互的chat对话,都需要经过以下步骤,包括但不限于:
  1. 对用户输入的内容进行自定义违规词检测
  2. 对用户输入的内容进行第三方在线违规词检测
  3. 对用户输入的内容进行组装成Prompt
  4. 对Prompt根据业务进行增强(完善prompt的内容)
  5. 对history进行裁剪或总结(检测history是否操作模型支持的上下文长度,例如qwen-7b支持的上下文长度为8192)

特别是调用第三方在线违规词检测,例如:某某云的内容安全审核服务,是非常耗时,会阻塞正常线程的执行,导致吞吐量的下降。

所以,我们就要对下面这块的处理逻辑进行调整,通过自定义线程池的方式,去处理核心的Chat交互流程

调整后:

4.2. 新增线程处理类

import io.netty.channel.ChannelHandlerContext;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

@Component
@Slf4j
public class TaskUtils{
    private static ExecutorService executorService = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2);
    @Autowired
    private AIChatUtils aiChatUtils;

     public void execute(AITaskReqMessage aiTaskReqMessage) {

        executorService.execute(() -> {
            Long userId = aiTaskReqMessage.getUserId();

            if (null == userId || (long) userId < 10000) {
                log.warn("用户身份标识有误!");
                return;
            }

            ChannelHandlerContext channelHandlerContext = AbstractBusinessLogicHandler.getContextByUserId(userId);

            if (channelHandlerContext != null) {
                try {
                    aiChatUtils.chatStream(aiTaskReqMessage);

                } catch (Throwable exception) {
                    exception.printStackTrace();
                }
            }
        });
    }

    public static void destory(){
        executorService.shutdownNow();
        executorService = null;
    }

}

4.3. 新增线程处理实体类

import lombok.Builder;
import lombok.Getter;
import lombok.Setter;

import java.util.List;

@Builder
@Setter
@Getter
public class AITaskReqMessage {

    private String messageId;
    private Long userId;
    private String contents;
    private List<ChatContext> history;
}

五、测试

在线测试方式:WebSocket在线测试工具

5.1. 建立连接

5.2. 业务初始化

服务端输出:

5.3. 业务对话

服务端输出

5.4. 关闭连接


六、附带说明

6.1. 可以使用jmeter进行websocket压测,以评估各项性能指标是否符合预期(下一篇)

6.2. BusinessHandler完整代码

import com.alibaba.fastjson.JSON;
import io.netty.channel.ChannelHandler;
import lombok.extern.slf4j.Slf4j;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.List;

/**
 * @Description: 处理消息的handler
 */
@Slf4j
@ChannelHandler.Sharable
@Component
public class BusinessHandler extends AbstractBusinessLogicHandler<TextWebSocketFrame> {
    @Autowired
    private TaskUtils taskExecuteUtils;

    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        String channelId = ctx.channel().id().asShortText();
        log.info("add client,channelId:{}", channelId);
    }

    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        String channelId = ctx.channel().id().asShortText();
        log.info("remove client,channelId:{}", channelId);
    }

    @Override
    protected void channelRead0(ChannelHandlerContext channelHandlerContext, TextWebSocketFrame textWebSocketFrame)
            throws Exception {
        // 获取客户端传输过来的消息
        String content = textWebSocketFrame.text();
        // 兼容在线测试
        if (StringUtils.equals(content, "PING")) {
            buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
                    .respTime(String.valueOf(System.currentTimeMillis()))
                    .msgType(String.valueOf(MsgType.HEARTBEAT.getCode()))
                    .contents("心跳测试,很高兴收到你的心跳包")
                    .build());

            return;
        }
        log.info("接收到客户端发送的信息: {}", content);

        Long userIdForReq;
        String msgType = "";
        String contents = "";

        try {
            ApiReqMessage apiReqMessage = JSON.parseObject(content, ApiReqMessage.class);
            msgType = apiReqMessage.getMsgType();
            contents = apiReqMessage.getContents();

            userIdForReq = apiReqMessage.getUserId();
            // 用户身份标识校验
            if (null == userIdForReq || (long) userIdForReq <= 10000) {
                ApiRespMessage apiRespMessage = ApiRespMessage.builder().code(String.valueOf(StatusCode.SYSTEM_ERROR.getCode()))
                        .respTime(String.valueOf(System.currentTimeMillis()))
                        .contents("用户身份标识有误!")
                        .msgType(String.valueOf(MsgType.SYSTEM.getCode()))
                        .build();
                buildResponseAndClose(channelHandlerContext, apiRespMessage);
                return;
            }

            if (StringUtils.equals(msgType, String.valueOf(MsgType.CHAT.getCode()))) {
                // 对用户输入的内容进行自定义违规词检测
                // 对用户输入的内容进行第三方在线违规词检测
                // 对用户输入的内容进行组装成Prompt
                // 对Prompt根据业务进行增强(完善prompt的内容)
                // 对history进行裁剪或总结(检测history是否操作模型支持的上下文长度,例如qwen-7b支持的上下文长度为8192)
                // ...
                String messageId = apiReqMessage.getMessageId();
                List<ChatContext> history = apiReqMessage.getHistory();
                AITaskReqMessage aiTaskReqMessage = AITaskReqMessage.builder().messageId(messageId).userId(userIdForReq).contents(contents).history(history).build();
                taskExecuteUtils.execute(aiTaskReqMessage);

            } else if (StringUtils.equals(msgType, String.valueOf(MsgType.INIT.getCode()))) {
                //一、业务黑名单检测(多次违规,永久锁定)

                //二、账户锁定检测(临时锁定)

                //三、多设备登录检测

                //四、剩余对话次数检测

                //检测通过,绑定用户与channel之间关系
                addChannel(channelHandlerContext, userIdForReq);
                String respMessage = "用户标识: " + userIdForReq + " 登录成功";

                buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
                        .respTime(String.valueOf(System.currentTimeMillis()))
                        .msgType(String.valueOf(MsgType.INIT.getCode()))
                        .contents(respMessage)
                        .build());

            } else if (StringUtils.equals(msgType, String.valueOf(MsgType.HEARTBEAT.getCode()))) {

                buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
                        .respTime(String.valueOf(System.currentTimeMillis()))
                        .msgType(String.valueOf(MsgType.HEARTBEAT.getCode()))
                        .contents("心跳测试,很高兴收到你的心跳包")
                        .build());
            }
            else {
                log.info("用户标识: {}, 消息类型有误,不支持类型: {}", userIdForReq, msgType);
            }

        } catch (Exception e) {
            log.warn("【BusinessHandler】接收到请求内容:{},异常信息:{}", content, e.getMessage(), e);
            // 异常返回
            return;
        }

    }

}

6.3. AIChatUtils完整代码

import com.alibaba.fastjson.JSON;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.List;
import java.util.Objects;

@Slf4j
@Component
public class AIChatUtils {
    @Autowired
    private AIConfig aiConfig;

    private Request buildRequest(Long userId, String prompt) throws Exception {
        //创建一个请求体对象(body)
        MediaType mediaType = MediaType.parse("application/json");
        RequestBody requestBody = RequestBody.create(mediaType, prompt);

        return buildHeader(userId, new Request.Builder().post(requestBody))
                .url(aiConfig.getUrl()).build();
    }

    private Request.Builder buildHeader(Long userId, Request.Builder builder) throws Exception {
        return builder
                .addHeader("Content-Type", "application/json")
                .addHeader("userId", String.valueOf(userId))
                .addHeader("secret",generateSecret(userId))
    }

    /**
     * 生成请求密钥
     *
     * @param userId 用户ID
     * @return
     */
    private String generateSecret(Long userId) throws Exception {
        String key = aiConfig.getServerKey();
        String content = key + userId + key;

        MessageDigest digest = MessageDigest.getInstance("SHA-256");
        byte[] hash = digest.digest(content.getBytes(StandardCharsets.UTF_8));

        StringBuilder hexString = new StringBuilder();
        for (byte b : hash) {
            String hex = Integer.toHexString(0xff & b);
            if (hex.length() == 1) {
                hexString.append('0');
            }
            hexString.append(hex);
        }
        return hexString.toString();
    }

    public String chatStream(AITaskReqMessage aiTaskReqMessage) throws Exception {

        String messageId = aiTaskReqMessage.getMessageId();
        Long userId = aiTaskReqMessage.getUserId();
        String contents = aiTaskReqMessage.getContents();
        List<ChatContext> history = aiTaskReqMessage.getHistory();

        if(StringUtils.isEmpty(contents) || StringUtils.isBlank(contents)){
            log.warn("用户输入内容不能为空!");
            return null;
        }

        //定义请求的参数
        String prompt = JSON.toJSONString(AIChatReqVO.init(contents, history));
        log.info("【AIChatUtils】调用AI聊天,用户({}),prompt:{}",  userId, prompt);

        //创建一个请求对象
        Request request = buildRequest(userId, prompt);

        InputStream is = null;
        try {

            // 从线程池获取http请求并执行
            Response response =OkHttpUtils.getInstance(aiConfig).getOkHttpClient().newCall(request).execute();

            // 响应结果
            StringBuffer resultBuff = new StringBuffer();
            //正常返回
            if (response.code() == 200) {
                //打印返回的字符数据
                is = response.body().byteStream();
                byte[] bytes = new byte[1024];

                int len = is.read(bytes);
                while (len != -1) {
                    ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
                    outputStream.write(bytes, 0, len);
                    outputStream.flush();
                    // 本轮读取到的数据
                    String result = new String(outputStream.toByteArray(), StandardCharsets.UTF_8);
                    resultBuff.append(result);

                    len = is.read(bytes);

                    // 将数据逐个传输给用户
                    AbstractBusinessLogicHandler.pushChatMessageForUser(userId, result);
                }

                // 正常响应
                return resultBuff.toString();
            }
            else {
                String result = response.body().string();
                log.warn("处理异常,异常描述:{}",result);
            }
        } catch (Throwable e) {
            log.error("【AIChatUtils】消息({})调用AI聊天 chatStream 异常,异常消息:{}", messageId, e.getMessage(), e);

        } finally {
            if (!Objects.isNull(is)) {
                try {
                    is.close();
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }
        return null;
    }

}

本文转载自: https://blog.csdn.net/qq839019311/article/details/135851421
版权归原作者 开源技术探险家 所有, 如有侵权,请联系我们删除。

“开源模型应用落地-业务优化篇(一)”的评论:

还没有评论