一、前言
通过参与“开源模型应用落地-业务整合系列篇”的学习,我们已经成功建立了基本的业务流程。然而,这只是迈出了万里长征的第一步。现在我们要对整个项目进行优化,以提高效率。我们计划利用线程池来加快处理速度,使用redis来实现排队需求,以及通过多级环境来减轻负载压力。这些优化措施将有助于我们进一步改进项目的性能和效果。
二、术语
2.1. 线程池
是一种用于线程管理的技术,它包含一组预先创建的线程,用于执行任务。线程池维护着一个任务队列,当有任务到达时,线程池中的线程会自动分配任务并执行。
线程池的主要目的是重用线程,避免频繁地创建和销毁线程带来的开销。通过使用线程池,可以在程序初始化时创建一组线程,并将任务提交给线程池进行处理,而不需要为每个任务都创建一个新的线程。这样可以有效地管理系统中的线程数量,控制并发度,提高系统的性能和资源利用率。
线程池通常包含以下几个关键组件:
- 任务队列(Task Queue):用于存储待执行的任务,通常是一个队列结构。当有新的任务到达时,会被添加到任务队列中。
- 线程池管理器(Thread Pool Manager):负责管理线程池的创建、销毁和线程的调度。它会监视任务队列的状态,并根据需要动态地创建或回收线程。
- 工作线程(Worker Threads):线程池中的线程,用于执行任务。它们会从任务队列中获取任务,并执行任务的处理逻辑。
三、前置条件
3.1. 已搭建WebSocket与AI服务调用链路
四、技术实现
4.1. 调整业务逻辑处理类
对于每次交互的chat对话,都需要经过以下步骤,包括但不限于:
- 对用户输入的内容进行自定义违规词检测
- 对用户输入的内容进行第三方在线违规词检测
- 对用户输入的内容进行组装成Prompt
- 对Prompt根据业务进行增强(完善prompt的内容)
- 对history进行裁剪或总结(检测history是否操作模型支持的上下文长度,例如qwen-7b支持的上下文长度为8192)
特别是调用第三方在线违规词检测,例如:某某云的内容安全审核服务,是非常耗时,会阻塞正常线程的执行,导致吞吐量的下降。
所以,我们就要对下面这块的处理逻辑进行调整,通过自定义线程池的方式,去处理核心的Chat交互流程
调整后:
4.2. 新增线程处理类
import io.netty.channel.ChannelHandlerContext;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@Component
@Slf4j
public class TaskUtils{
private static ExecutorService executorService = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2);
@Autowired
private AIChatUtils aiChatUtils;
public void execute(AITaskReqMessage aiTaskReqMessage) {
executorService.execute(() -> {
Long userId = aiTaskReqMessage.getUserId();
if (null == userId || (long) userId < 10000) {
log.warn("用户身份标识有误!");
return;
}
ChannelHandlerContext channelHandlerContext = AbstractBusinessLogicHandler.getContextByUserId(userId);
if (channelHandlerContext != null) {
try {
aiChatUtils.chatStream(aiTaskReqMessage);
} catch (Throwable exception) {
exception.printStackTrace();
}
}
});
}
public static void destory(){
executorService.shutdownNow();
executorService = null;
}
}
4.3. 新增线程处理实体类
import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import java.util.List;
@Builder
@Setter
@Getter
public class AITaskReqMessage {
private String messageId;
private Long userId;
private String contents;
private List<ChatContext> history;
}
五、测试
在线测试方式:WebSocket在线测试工具
5.1. 建立连接
5.2. 业务初始化
服务端输出:
5.3. 业务对话
服务端输出
5.4. 关闭连接
六、附带说明
6.1. 可以使用jmeter进行websocket压测,以评估各项性能指标是否符合预期(下一篇)
6.2. BusinessHandler完整代码
import com.alibaba.fastjson.JSON;
import io.netty.channel.ChannelHandler;
import lombok.extern.slf4j.Slf4j;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.List;
/**
* @Description: 处理消息的handler
*/
@Slf4j
@ChannelHandler.Sharable
@Component
public class BusinessHandler extends AbstractBusinessLogicHandler<TextWebSocketFrame> {
@Autowired
private TaskUtils taskExecuteUtils;
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
String channelId = ctx.channel().id().asShortText();
log.info("add client,channelId:{}", channelId);
}
@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
String channelId = ctx.channel().id().asShortText();
log.info("remove client,channelId:{}", channelId);
}
@Override
protected void channelRead0(ChannelHandlerContext channelHandlerContext, TextWebSocketFrame textWebSocketFrame)
throws Exception {
// 获取客户端传输过来的消息
String content = textWebSocketFrame.text();
// 兼容在线测试
if (StringUtils.equals(content, "PING")) {
buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
.respTime(String.valueOf(System.currentTimeMillis()))
.msgType(String.valueOf(MsgType.HEARTBEAT.getCode()))
.contents("心跳测试,很高兴收到你的心跳包")
.build());
return;
}
log.info("接收到客户端发送的信息: {}", content);
Long userIdForReq;
String msgType = "";
String contents = "";
try {
ApiReqMessage apiReqMessage = JSON.parseObject(content, ApiReqMessage.class);
msgType = apiReqMessage.getMsgType();
contents = apiReqMessage.getContents();
userIdForReq = apiReqMessage.getUserId();
// 用户身份标识校验
if (null == userIdForReq || (long) userIdForReq <= 10000) {
ApiRespMessage apiRespMessage = ApiRespMessage.builder().code(String.valueOf(StatusCode.SYSTEM_ERROR.getCode()))
.respTime(String.valueOf(System.currentTimeMillis()))
.contents("用户身份标识有误!")
.msgType(String.valueOf(MsgType.SYSTEM.getCode()))
.build();
buildResponseAndClose(channelHandlerContext, apiRespMessage);
return;
}
if (StringUtils.equals(msgType, String.valueOf(MsgType.CHAT.getCode()))) {
// 对用户输入的内容进行自定义违规词检测
// 对用户输入的内容进行第三方在线违规词检测
// 对用户输入的内容进行组装成Prompt
// 对Prompt根据业务进行增强(完善prompt的内容)
// 对history进行裁剪或总结(检测history是否操作模型支持的上下文长度,例如qwen-7b支持的上下文长度为8192)
// ...
String messageId = apiReqMessage.getMessageId();
List<ChatContext> history = apiReqMessage.getHistory();
AITaskReqMessage aiTaskReqMessage = AITaskReqMessage.builder().messageId(messageId).userId(userIdForReq).contents(contents).history(history).build();
taskExecuteUtils.execute(aiTaskReqMessage);
} else if (StringUtils.equals(msgType, String.valueOf(MsgType.INIT.getCode()))) {
//一、业务黑名单检测(多次违规,永久锁定)
//二、账户锁定检测(临时锁定)
//三、多设备登录检测
//四、剩余对话次数检测
//检测通过,绑定用户与channel之间关系
addChannel(channelHandlerContext, userIdForReq);
String respMessage = "用户标识: " + userIdForReq + " 登录成功";
buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
.respTime(String.valueOf(System.currentTimeMillis()))
.msgType(String.valueOf(MsgType.INIT.getCode()))
.contents(respMessage)
.build());
} else if (StringUtils.equals(msgType, String.valueOf(MsgType.HEARTBEAT.getCode()))) {
buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
.respTime(String.valueOf(System.currentTimeMillis()))
.msgType(String.valueOf(MsgType.HEARTBEAT.getCode()))
.contents("心跳测试,很高兴收到你的心跳包")
.build());
}
else {
log.info("用户标识: {}, 消息类型有误,不支持类型: {}", userIdForReq, msgType);
}
} catch (Exception e) {
log.warn("【BusinessHandler】接收到请求内容:{},异常信息:{}", content, e.getMessage(), e);
// 异常返回
return;
}
}
}
6.3. AIChatUtils完整代码
import com.alibaba.fastjson.JSON;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.List;
import java.util.Objects;
@Slf4j
@Component
public class AIChatUtils {
@Autowired
private AIConfig aiConfig;
private Request buildRequest(Long userId, String prompt) throws Exception {
//创建一个请求体对象(body)
MediaType mediaType = MediaType.parse("application/json");
RequestBody requestBody = RequestBody.create(mediaType, prompt);
return buildHeader(userId, new Request.Builder().post(requestBody))
.url(aiConfig.getUrl()).build();
}
private Request.Builder buildHeader(Long userId, Request.Builder builder) throws Exception {
return builder
.addHeader("Content-Type", "application/json")
.addHeader("userId", String.valueOf(userId))
.addHeader("secret",generateSecret(userId))
}
/**
* 生成请求密钥
*
* @param userId 用户ID
* @return
*/
private String generateSecret(Long userId) throws Exception {
String key = aiConfig.getServerKey();
String content = key + userId + key;
MessageDigest digest = MessageDigest.getInstance("SHA-256");
byte[] hash = digest.digest(content.getBytes(StandardCharsets.UTF_8));
StringBuilder hexString = new StringBuilder();
for (byte b : hash) {
String hex = Integer.toHexString(0xff & b);
if (hex.length() == 1) {
hexString.append('0');
}
hexString.append(hex);
}
return hexString.toString();
}
public String chatStream(AITaskReqMessage aiTaskReqMessage) throws Exception {
String messageId = aiTaskReqMessage.getMessageId();
Long userId = aiTaskReqMessage.getUserId();
String contents = aiTaskReqMessage.getContents();
List<ChatContext> history = aiTaskReqMessage.getHistory();
if(StringUtils.isEmpty(contents) || StringUtils.isBlank(contents)){
log.warn("用户输入内容不能为空!");
return null;
}
//定义请求的参数
String prompt = JSON.toJSONString(AIChatReqVO.init(contents, history));
log.info("【AIChatUtils】调用AI聊天,用户({}),prompt:{}", userId, prompt);
//创建一个请求对象
Request request = buildRequest(userId, prompt);
InputStream is = null;
try {
// 从线程池获取http请求并执行
Response response =OkHttpUtils.getInstance(aiConfig).getOkHttpClient().newCall(request).execute();
// 响应结果
StringBuffer resultBuff = new StringBuffer();
//正常返回
if (response.code() == 200) {
//打印返回的字符数据
is = response.body().byteStream();
byte[] bytes = new byte[1024];
int len = is.read(bytes);
while (len != -1) {
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
outputStream.write(bytes, 0, len);
outputStream.flush();
// 本轮读取到的数据
String result = new String(outputStream.toByteArray(), StandardCharsets.UTF_8);
resultBuff.append(result);
len = is.read(bytes);
// 将数据逐个传输给用户
AbstractBusinessLogicHandler.pushChatMessageForUser(userId, result);
}
// 正常响应
return resultBuff.toString();
}
else {
String result = response.body().string();
log.warn("处理异常,异常描述:{}",result);
}
} catch (Throwable e) {
log.error("【AIChatUtils】消息({})调用AI聊天 chatStream 异常,异常消息:{}", messageId, e.getMessage(), e);
} finally {
if (!Objects.isNull(is)) {
try {
is.close();
} catch (Exception e) {
e.printStackTrace();
}
}
}
return null;
}
}
版权归原作者 开源技术探险家 所有, 如有侵权,请联系我们删除。