基于SpringAI搭建系统,依靠线程池\负载均衡等技术进行请求优化,用于解决科研&开发过程中对GPT接口进行批量化接口请求中出现的问题。
github地址:https://github.com/linkcao/springai-wave
大语言模型接口以OpenAI的GPT 3.5为例,JDK版本为17,其他依赖版本可见仓库pom.xml
pom.xml
在处理大量提示文本时,存在以下挑战:
为了解决上述问题,本文提出了一种基于Spring框架的批量化提示访问方案,如下图所示:
其中具体包括以下步骤:
// 线程池初始化 private static final ExecutorService executor = Executors.newFixedThreadPool(10); /** * 多线程请求提示 * @param prompts * @param user * @param task * @return */ @Async public CompletableFuture<Void> processPrompts(List<String> prompts, Users user, Task task) { for (int i = 0; i < prompts.size();i++) { int finalI = i; // 提交任务 executor.submit(() -> processPrompt(prompts.get(finalI), user, finalI)); } // 设置批量任务状态 task.setStatus(TaskStatus.COMPLETED); taskService.setTask(task); return CompletableFuture.completedFuture(null); }
// 线程池初始化
private static final ExecutorService executor = Executors.newFixedThreadPool(10);
/**
* 多线程请求提示
* @param prompts
* @param user
* @param task
* @return
*/
@Async
public CompletableFuture<Void> processPrompts(List<String> prompts, Users user, Task task) {
for (int i = 0; i < prompts.size();i++) {
int finalI = i;
// 提交任务
executor.submit(() -> processPrompt(prompts.get(finalI), user, finalI));
}
// 设置批量任务状态
task.setStatus(TaskStatus.COMPLETED);
taskService.setTask(task);
return CompletableFuture.completedFuture(null);
如上所示,利用了Spring框架的@Async注解和线程池的功能,实现了多线程异步处理提示信息。
线程池
首先,使用了ExecutorService创建了一个固定大小的线程池,以便同时处理多个提示文本。
ExecutorService
然后,通过CompletableFuture来实现异步任务的管理。
CompletableFuture
在处理每个提示文本时,通过executor.submit()方法提交一个任务给线程池,让线程池来处理。
executor.submit()
处理完成后,将批量任务的状态设置为已完成,并更新任务状态。
一个线程任务需要绑定请求的用户以及所在的批量任务,当前任务所分配的key由任务所在队列的下标决定。
/** * 处理单条提示文本 * @param prompt 提示文本 * @param user 用户 * @param index 所在队列下标 */ public void processPrompt(String prompt, Users user, int index) { // 获取Api Key OpenAiApi openAiApi = getApiByIndex(user, index); assert openAiApi != null; ChatClient client = new OpenAiChatClient(openAiApi); // 提示文本请求 String response = client.call(prompt); // 日志记录 log.info("提示信息" + prompt ); log.info("输出" + response ); // 回答保存数据库 saveQuestionAndAnswer(user, prompt, response); }
* 处理单条提示文本
* @param prompt 提示文本
* @param user 用户
* @param index 所在队列下标
public void processPrompt(String prompt, Users user, int index) {
// 获取Api Key
OpenAiApi openAiApi = getApiByIndex(user, index);
assert openAiApi != null;
ChatClient client = new OpenAiChatClient(openAiApi);
// 提示文本请求
String response = client.call(prompt);
// 日志记录
log.info("提示信息" + prompt );
log.info("输出" + response );
// 回答保存数据库
saveQuestionAndAnswer(user, prompt, response);
/** * 采用任务下标分配key的方式进行负载均衡 * @param index 任务下标 * @return OpenAiApi */ private OpenAiApi getApiByIndex(int index){ List<KeyInfo> keyInfoList = keyRepository.findAll(); if (keyInfoList.isEmpty()) { return null; } // 根据任务队列下标分配 Key KeyInfo keyInfo = keyInfoList.get(index % keyInfoList.size()); return new OpenAiApi(keyInfo.getApi(),keyInfo.getKeyValue()); }
* 采用任务下标分配key的方式进行负载均衡
* @param index 任务下标
* @return OpenAiApi
private OpenAiApi getApiByIndex(int index){
List<KeyInfo> keyInfoList = keyRepository.findAll();
if (keyInfoList.isEmpty()) {
return null;
// 根据任务队列下标分配 Key
KeyInfo keyInfo = keyInfoList.get(index % keyInfoList.size());
return new OpenAiApi(keyInfo.getApi(),keyInfo.getKeyValue());
/** * 依靠线程池批量请求GPT * @param promptFile 传入的批量提示文件,每一行为一个提示语句 * @param username 调用的用户 * @return 处理状态 */ @PostMapping("/batch") public String batchPrompt(MultipartFile promptFile, String username){ if (promptFile.isEmpty()) { return "上传的文件为空"; } // 批量请求任务 Task task = new Task(); try { BufferedReader reader = new BufferedReader(new InputStreamReader(promptFile.getInputStream())); List<String> prompts = new ArrayList<>(); String line; while ((line = reader.readLine()) != null) { prompts.add(line); } // 用户信息请求 Users user = userService.findByUsername(username); // 任务状态设置 task.setFileName(promptFile.getName()); task.setStartTime(LocalDateTime.now()); task.setUserId(user.getUserId()); task.setStatus(TaskStatus.PROCESSING); // 线程池处理 chatService.processPrompts(prompts, user, task); return "文件上传成功,已开始批量处理提示"; } catch ( IOException e) { // 处理失败 e.printStackTrace(); task.setStatus(TaskStatus.FAILED); return "上传文件时出错:" + e.getMessage(); } finally { // 任务状态保存 taskService.setTask(task); } }
* 依靠线程池批量请求GPT
* @param promptFile 传入的批量提示文件,每一行为一个提示语句
* @param username 调用的用户
* @return 处理状态
@PostMapping("/batch")
public String batchPrompt(MultipartFile promptFile, String username){
if (promptFile.isEmpty()) {
return "上传的文件为空";
// 批量请求任务
Task task = new Task();
try {
BufferedReader reader = new BufferedReader(new InputStreamReader(promptFile.getInputStream()));
List<String> prompts = new ArrayList<>();
String line;
while ((line = reader.readLine()) != null) {
prompts.add(line);
// 用户信息请求
Users user = userService.findByUsername(username);
// 任务状态设置
task.setFileName(promptFile.getName());
task.setStartTime(LocalDateTime.now());
task.setUserId(user.getUserId());
task.setStatus(TaskStatus.PROCESSING);
// 线程池处理
chatService.processPrompts(prompts, user, task);
return "文件上传成功,已开始批量处理提示";
} catch ( IOException e) {
// 处理失败
e.printStackTrace();
task.setStatus(TaskStatus.FAILED);
return "上传文件时出错:" + e.getMessage();
} finally {
// 任务状态保存
ChatService
processPrompts()
所有信息都与用户ID强绑定,便于管理和查询,ER图如下所示:
批量请求文件
username
localhost:8080/batch
请回答1+2=?请回答8*12=?请回答12*9=?请回答321-12=?请回答12/4=?请回答32%2=?
请回答1+2=?
请回答8*12=?
请回答12*9=?
请回答321-12=?
请回答12/4=?
请回答32%2=?
question_id
user_id
原文链接:https://www.cnblogs.com/linkcxt/p/18187163
本站QQ群:前端 618073944 | Java 606181507 | Python 626812652 | C/C++ 612253063 | 微信 634508462 | 苹果 692586424 | C#/.net 182808419 | PHP 305140648 | 运维 608723728