

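Calling the OpenAI API from Spring Boot with streaming responses

In an earlier post, "Calling the OpenAI API in a Spring Boot project and the problems I ran into", I called the OpenAI API from Spring Boot. There, the response came back all at once, so a long answer could mean a long wait. This post changes the call to use a streaming response instead.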

[Image: OpenAI API documentation]

According to the OpenAI API documentation, adding "stream": true to the request body is all it takes to get a streaming response.

[Image: the stream parameter in the OpenAI API documentation]
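
To make the placement of the flag concrete, here is a minimal sketch of a request carrying "stream": true. It uses the JDK's java.net.http classes rather than the WebClient used in the rest of this post, purely so that it has no dependencies. The class name StreamRequestSketch and the hand-rolled escaping are assumptions made for illustration; a real client should build the body with a JSON library.

```java
import java.net.URI;
import java.net.http.HttpRequest;

/** Illustrative only: shows where "stream": true sits in the request body. */
final class StreamRequestSketch {

    /** Builds a one-message chat request body with streaming switched on. */
    static String body(String model, String userText) {
        // Escapes only backslashes and double quotes; not a complete JSON encoder.
        String escaped = userText.replace("\\", "\\\\").replace("\"", "\\\"");
        return "{\"model\":\"" + model + "\","
                + "\"messages\":[{\"role\":\"user\",\"content\":\"" + escaped + "\"}],"
                + "\"stream\":true}";
    }

    /** Wraps the body in a POST request to the chat completions endpoint. */
    static HttpRequest request(String apiKey, String model, String userText) {
        return HttpRequest.newBuilder(URI.create("https://api.openai.com/v1/chat/completions"))
                .header("Authorization", "Bearer " + apiKey)
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(body(model, userText)))
                .build();
    }

    public static void main(String[] args) {
        // Prints the JSON body only; nothing is sent over the network here.
        System.out.println(body("gpt-3.5-turbo", "Hello"));
    }
}
```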

The documentation also says that when the returned value is

    data: [DONE]

the response has ended.
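
As a rough illustration of that wire format, the sketch below takes raw event-stream lines, keeps the payload of each "data:" line, and stops at the [DONE] marker. The class SseLineParser is hypothetical and exists only for this example. The WebClient code later in this post compares each value against the bare [DONE] string, so there the "data:" prefix is not handled by hand.

```java
import java.util.ArrayList;
import java.util.List;

/** Illustrative only: extracts payloads from raw SSE lines up to the [DONE] marker. */
final class SseLineParser {
    private static final String PREFIX = "data:";
    private static final String DONE = "[DONE]";

    /** Returns the payload of every "data:" line that comes before "data: [DONE]". */
    static List<String> payloads(List<String> rawLines) {
        List<String> out = new ArrayList<>();
        for (String line : rawLines) {
            if (!line.startsWith(PREFIX)) {
                continue; // blank separator lines and other fields are skipped
            }
            String payload = line.substring(PREFIX.length()).trim();
            if (DONE.equals(payload)) {
                break; // end of the stream
            }
            out.add(payload);
        }
        return out;
    }

    public static void main(String[] args) {
        List<String> lines = List.of(
                "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}",
                "",
                "data: [DONE]");
        System.out.println(payloads(lines));
    }
}
```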

Writing the code

As before, setters, getters, and constructors are omitted to keep the post short.

Configuration

properties

    openai.key=your-key
    openai.chatgtp.model=gpt-3.5-turbo
    openai.gpt4.model=gpt-4-turbo-preview
    openai.chatgtp.api.url=/v1/chat/completions

pom file

We add WebSocket and WebFlux to the project. The RestTemplate used previously is not well suited to asynchronous streaming requests, so we switch to WebFlux.

    <!-- WebSocket dependency -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-websocket</artifactId>
    </dependency>
    <!-- Client for asynchronous streaming responses -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-webflux</artifactId>
    </dependency>

Request body class

    public class ChatRequest {
        // Model to use
        private String model;
        // Conversation history
        private List<ChatMessage> messages;
        // Ask the API for a streaming response
        private Boolean stream = Boolean.TRUE;

        @Override
        public String toString() {
            try {
                // Serialize to JSON so the result can be used directly as the request body
                return ConstValuePool.OBJECT_MAPPER.writeValueAsString(this);
            } catch (JsonProcessingException e) {
                throw new RuntimeException(e);
            }
        }
    }

Message class used in the request body

    public class ChatMessage {
        // Role
        private String role;
        // Message content
        private String content;
    }

Response classes

For the response classes, first look at an example of what the API returns. The content field in the JSON below carries the text of this chunk.

    {
      "id": "chatcmpl-8uk7ofAZnSJhsHlsQ9mSYwFInuSFq",
      "object": "chat.completion.chunk",
      "created": 1708534364,
      "model": "gpt-3.5-turbo-0125",
      "system_fingerprint": "fp_cbdb91ce3f",
      "choices": [
        {
          "index": 0,
          "delta": {
            "content": "吗"
          },
          "logprobs": null,
          "finish_reason": null
        }
      ]
    }

Based on this JSON format, we build the response classes as follows.

1) Main response class

    public class ChatResponse {
        private String id;
        private String object;
        private Long created;
        private String model;
        private String system_fingerprint;
        // List of choices returned by GPT
        private List<Choice> choices;

        public static class Choice {
            private int index;
            private Delta delta;
            private Object logprobs;
            private Object finish_reason;
        }
    }

2) Delta class

    public class Delta {
        private String role;
        private String content;
    }
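
Each chunk carries only a fragment of the answer in delta.content, so the full text has to be assembled on the receiving side. The sketch below shows that accumulation in isolation. It assumes, as is typically the case, that some chunks carry an empty or missing content (for example a first chunk that only sets role, or a final chunk that only sets finish_reason). DeltaAccumulator is a hypothetical helper, not part of the code in this post.

```java
import java.util.Arrays;

/** Illustrative only: assembles the full reply from successive delta.content values. */
final class DeltaAccumulator {
    private final StringBuilder text = new StringBuilder();

    /** Appends one chunk's content; null means the chunk carried no text. */
    void accept(String deltaContent) {
        if (deltaContent != null) {
            text.append(deltaContent);
        }
    }

    /** Returns everything received so far. */
    String text() {
        return text.toString();
    }

    public static void main(String[] args) {
        DeltaAccumulator acc = new DeltaAccumulator();
        // Hypothetical sequence of delta.content values from successive chunks
        for (String content : Arrays.asList("", "Hel", "lo", null)) {
            acc.accept(content);
        }
        System.out.println(acc.text());
    }
}
```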

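Constant pool class

    public class ConstValuePool {
        // WebClient that reaches OpenAI through a proxy
        public static WebClient PROXY_OPENAI_CLIENT = null;
        // Other constants referenced in this post (OBJECT_MAPPER, CHAT_DIALOGUES,
        // CHAT_MESSAGE_DIALOGUES, CHAT_MAX_MESSAGE) are omitted here
    }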

Client class

As before, the client is created in a hook method.

    @Component
    public class ApiCodeLoadAware implements EnvironmentAware, ApplicationContextAware {
        Environment environment;

        @Override
        public void setEnvironment(Environment environment) {
            this.environment = environment;
        }

        @Override
        public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
            // Client for ChatGPT and GPT-4.
            // HttpClient here is reactor.netty.http.client.HttpClient.
            HttpClient httpClient = HttpClient.create().proxy(clientProxy ->
                    clientProxy.type(ProxyProvider.Proxy.HTTP) // proxy type
                            .host("127.0.0.1") // proxy host
                            .port(7890)); // proxy port
            ConstValuePool.PROXY_OPENAI_CLIENT = WebClient.builder()
                    .clientConnector(new ReactorClientHttpConnector(httpClient))
                    .baseUrl("https://api.openai.com")
                    .defaultHeader("Authorization", "Bearer " + environment.getProperty("openai.key"))
                    .build();
        }
    }

WebSocket back-end configuration

For details on WebSocket, see my earlier post, "Using WebSocket to push messages from the server to the client".

1) WebSocket configuration class

    @Configuration
    public class WebsocketConfig {
        @Bean
        public ServerEndpointExporter getServerEndpointExporter() {
            return new ServerEndpointExporter();
        }
    }

2) WebSocket class

The id parameter identifies which WebSocket connection a message should be pushed to. It can be given to the user through login or a similar mechanism.

    @Component
    @ServerEndpoint("/aiWebsocket/{id}")
    public class AiWebsocketService {
        private final Logger logger = LoggerFactory.getLogger(AiWebsocketService.class);
        private Session session;
        // Key under which this connection is registered
        private String id;
        // Holds all open WebSocket connections, keyed by id
        private static Map<String, AiWebsocketService> aiWebSocketServicesMap = new ConcurrentHashMap<>();

        // Called automatically when a WebSocket connection is opened
        @OnOpen
        public void onOpen(Session session, @PathParam("id") String id) {
            this.session = session;
            this.id = id;
            aiWebSocketServicesMap.put(id, this);
            logger.debug("New WebSocket connection, total connections: " + aiWebSocketServicesMap.size());
        }

        // Called automatically when a WebSocket connection is closed
        @OnClose
        public void onClose() {
            // Remove by key; remove(this) would look up the instance as a key and remove nothing
            aiWebSocketServicesMap.remove(this.id);
            logger.debug("Connection closed, total connections: " + aiWebSocketServicesMap.size());
        }

        // Called automatically when a message arrives over the WebSocket
        @OnMessage
        public void onMessage(String message) {
            logger.debug("this:" + message);
        }

        // Sends a message over the WebSocket registered under the given id
        public void sendMessage(String message, String id) {
            AiWebsocketService aiWebsocketService = aiWebSocketServicesMap.get(id);
            if (aiWebsocketService == null) {
                return;
            }
            try {
                aiWebsocketService.session.getBasicRemote().sendText(message);
            } catch (IOException e) {
                logger.debug(this + " failed to send message: " + e.getClass() + ":" + e.getMessage());
            }
        }
    }

AI message service class

    @Component
    public class ChatGptModelService implements AiModelService {
        private static final Logger logger = LoggerFactory.getLogger(ChatGptModelService.class);

        @Value("${openai.chatgtp.api.url}")
        private String uri;
        @Value("${openai.chatgtp.model}")
        private String model;
        @Resource
        private AiWebsocketService aiWebsocketService;

        @Override
        public String answer(String prompt, HttpServletRequest request) throws InterruptedException {
            HttpSession session = request.getSession();
            String identity = AiIdentityFlagUtil.getAiIdentity(request);
            // Load the conversation history: chatMessages drives the multi-turn conversation,
            // chatDialogues is what the page displays
            List<ChatMessage> chatMessages = (List<ChatMessage>) session.getAttribute(ConstValuePool.CHAT_MESSAGE_DIALOGUES);
            List<AiDialogue> chatDialogues = (List<AiDialogue>) session.getAttribute(ConstValuePool.CHAT_DIALOGUES);
            if (chatMessages == null) {
                chatMessages = new ArrayList<>();
                chatMessages.add(ChatMessage.createSystemDialogue("You are a helpful assistant."));
                chatDialogues = new ArrayList<>();
                session.setAttribute(ConstValuePool.CHAT_DIALOGUES, chatDialogues);
                session.setAttribute(ConstValuePool.CHAT_MESSAGE_DIALOGUES, chatMessages);
            }
            chatMessages.add(new ChatMessage("user", prompt));
            chatDialogues.add(AiDialogue.createUserDialogue(prompt));
            ChatRequest chatRequest = new ChatRequest(this.model, chatMessages);
            logger.debug("Request being sent: {}", chatRequest);
            Flux<String> chatResponseFlux = ConstValuePool.PROXY_OPENAI_CLIENT
                    .post()
                    .uri(uri)
                    .contentType(MediaType.APPLICATION_JSON)
                    .bodyValue(chatRequest.toString())
                    .retrieve()
                    .bodyToFlux(String.class); // read each event as a String so the end marker is easy to spot
            StringBuilder resultBuilder = new StringBuilder();
            // Records a failure raised on the reactive thread so the calling thread can rethrow it;
            // an exception thrown inside the error callback would never reach the caller
            AtomicReference<Throwable> failure = new AtomicReference<>();
            // Semaphore used to block the calling thread until the stream finishes
            Semaphore semaphore = new Semaphore(0);
            chatResponseFlux.subscribe(
                    value -> {
                        logger.debug("Received: {}", value);
                        if ("[DONE]".equals(value)) {
                            return;
                        }
                        try {
                            ChatResponse chatResponse = ConstValuePool.OBJECT_MAPPER.readValue(value, ChatResponse.class);
                            List<ChatResponse.Choice> choices = chatResponse.getChoices();
                            if (choices == null || choices.isEmpty()) {
                                return; // nothing to append in this chunk
                            }
                            ChatResponse.Choice choice = choices.get(choices.size() - 1);
                            Delta delta = choice.getDelta();
                            String res = delta.getContent();
                            if (res != null) {
                                resultBuilder.append(res);
                                // Push the text accumulated so far, not just the new fragment
                                aiWebsocketService.sendMessage(resultBuilder.toString(), identity);
                            }
                        } catch (JsonProcessingException e) {
                            throw new AiException("chatgpt failed while parsing a chunk", e);
                        }
                    }, // on data: append to the result and push it to the front end
                    error -> {
                        failure.set(error);
                        semaphore.release();
                    }, // on failure: record the error and release the semaphore
                    semaphore::release // on completion: release the semaphore
            );
            semaphore.acquire();
            if (failure.get() != null) {
                throw new AiException("chatgpt call failed", failure.get());
            }
            String resString = resultBuilder.toString();
            logger.debug(resString);
            chatDialogues.add(AiDialogue.createAssistantDialogue(resString));
            chatMessages.add(ChatMessage.createAssistantDialogue(resString));
            // Drop the oldest history once the conversation gets long, to avoid burning tokens
            while (chatMessages.size() > ConstValuePool.CHAT_MAX_MESSAGE) {
                chatMessages.remove(0);
            }
            return "";
        }
    }
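
One thing to be aware of in the trimming loop at the end of answer: because it always removes index 0, the first entry to go is the system message added when the session started. If you would rather keep that prompt, a variant along the following lines is one option. This is only a sketch: HistoryTrimmer and its nested Msg class are hypothetical stand-ins for ChatMessage, written so the example runs on its own.

```java
import java.util.ArrayList;
import java.util.List;

/** Illustrative only: trims history while keeping system messages. */
final class HistoryTrimmer {

    /** Minimal stand-in for the post's ChatMessage (role + content). */
    static final class Msg {
        final String role;
        final String content;

        Msg(String role, String content) {
            this.role = role;
            this.content = content;
        }
    }

    /** Removes the oldest non-system messages until at most max remain. */
    static void trim(List<Msg> history, int max) {
        int i = 0;
        while (history.size() > max && i < history.size()) {
            if ("system".equals(history.get(i).role)) {
                i++; // keep the system prompt and look at the next-oldest message
            } else {
                history.remove(i);
            }
        }
    }

    public static void main(String[] args) {
        List<Msg> history = new ArrayList<>();
        history.add(new Msg("system", "You are a helpful assistant."));
        history.add(new Msg("user", "q1"));
        history.add(new Msg("assistant", "a1"));
        history.add(new Msg("user", "q2"));
        trim(history, 3);
        for (Msg m : history) {
            System.out.println(m.role + ": " + m.content);
        }
    }
}
```

Like the original loop, this trims one message at a time, so a question and its answer can end up separated. Removing messages in user/assistant pairs would avoid that.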

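Front end

My front-end code isn't very good, so I won't show it here.

Results

It works:

[Image: OpenAI API streaming call result 1]

[Image: OpenAI API streaming call result 2]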


Reposted from: https://blog.csdn.net/qq_56460466/article/details/136235175
Copyright belongs to the original author, 写做四月一日的四月一日. If this infringes your rights, please contact us for removal.
