diff --git a/CHANGELOG.md b/CHANGELOG.md index a38724ece..627e89591 100755 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,9 +2,10 @@ # 🚀Changelog ------------------------------------------------------------------------------------------------------------- -# 5.8.39(2025-05-13) +# 5.8.39(2025-05-19) ### 🐣新特性 +* 【ai 】 增加SSE流式返回函数参数callback,豆包、grok新增文生图接口,豆包生成视频支持使用model ### 🐞Bug修复 * 【core 】 修复`NumberUtil`isNumber方法以L结尾没有小数点判断问题(issue#3938@Github) * 【core 】 修复`CharsequenceUtil`toLowerCase方法拼写错误(issue#3941@Github) diff --git a/hutool-ai/pom.xml b/hutool-ai/pom.xml index 5efa669bc..fb1cc9a87 100644 --- a/hutool-ai/pom.xml +++ b/hutool-ai/pom.xml @@ -40,6 +40,11 @@ ${project.parent.version} compile + + com.fasterxml.jackson.core + jackson-databind + 2.13.5 + diff --git a/hutool-ai/src/main/java/cn/hutool/ai/Models.java b/hutool-ai/src/main/java/cn/hutool/ai/Models.java index dd3ac7364..e11a7f2a7 100644 --- a/hutool-ai/src/main/java/cn/hutool/ai/Models.java +++ b/hutool-ai/src/main/java/cn/hutool/ai/Models.java @@ -123,7 +123,12 @@ public class Models { DOUBAO_VISION_LITE_32K("doubao-vision-lite-32k-241015"), DOUBAO_EMBEDDING_LARGE("doubao-embedding-large-text-240915"), DOUBAO_EMBEDDING_TEXT_240715("doubao-embedding-text-240715"), - DOUBAO_EMBEDDING_VISION("doubao-embedding-vision-241215"); + DOUBAO_EMBEDDING_VISION("doubao-embedding-vision-241215"), + DOUBAO_SEEDREAM_3_0_T2I("doubao-seedream-3-0-t2i-250415"), + Doubao_Seedance_1_0_lite_t2v("doubao-seedance-1-0-lite-t2v-250428"), + Doubao_Seedance_1_0_lite_i2v("doubao-seedance-1-0-lite-i2v-250428"), + Wan2_1_14B_t2v("wan2-1-14b-t2v-250225"), + Wan2_1_14B_i2v("wan2-1-14b-i2v-250225"); private final String model; @@ -138,6 +143,23 @@ public class Models { // Grok的模型 public enum Grok { + GROK_3_BETA_LATEST("grok-3-beta"), + GROK_3_BETA("grok-3-beta"), + GROK_3("grok-3-beta"), + GROK_3_MINI_FAST_LATEST("grok-3-mini-fast-beta"), + GROK_3_MINI_FAST_BETA("grok-3-mini-fast-beta"), + GROK_3_MINI_FAST("grok-3-mini-fast-beta"), + GROK_3_FAST_LATEST("grok-3-fast-beta"), + GROK_3_FAST_BETA("grok-3-fast-beta"), + GROK_3_FAST("grok-3-fast-beta"), + GROK_3_MINI_LATEST("grok-3-mini-beta"), + GROK_3_MINI_BETA("grok-3-mini-beta"), + GROK_3_MINI("grok-3-mini-beta"), + GROK_2_IMAGE_LATEST("grok-2-image-1212"), + GROK_2_IMAGE("grok-2-image-1212"), + GROK_2_IMAGE_1212("grok-2-image-1212"), + grok_2_latest("grok-2-1212"), + GROK_2("grok-2-1212"), GROK_2_1212("grok-2-1212"), GROK_2_VISION_1212("grok-2-vision-1212"), GROK_BETA("grok-beta"), diff --git a/hutool-ai/src/main/java/cn/hutool/ai/core/AIService.java b/hutool-ai/src/main/java/cn/hutool/ai/core/AIService.java index 8bff076f8..3805f0e4a 100644 --- a/hutool-ai/src/main/java/cn/hutool/ai/core/AIService.java +++ b/hutool-ai/src/main/java/cn/hutool/ai/core/AIService.java @@ -16,7 +16,9 @@ package cn.hutool.ai.core; +import java.util.ArrayList; import java.util.List; +import java.util.function.Consumer; /** * 模型公共的API功能,特有的功能在model.xx.XXService下定义 @@ -33,7 +35,25 @@ public interface AIService { * @return AI回答 * @since 5.8.38 */ - String chat(String prompt); + default String chat(String prompt){ + final List messages = new ArrayList<>(); + messages.add(new Message("system", "You are a helpful assistant")); + messages.add(new Message("user", prompt)); + return chat(messages); + } + + /** + * 对话-SSE流式输出 + * @param prompt user题词 + * @param callback 流式数据回调函数 + * @since 5.8.39 + */ + default void chat(String prompt, final Consumer callback){ + final List messages = new ArrayList<>(); + messages.add(new Message("system", "You are a helpful assistant")); + messages.add(new Message("user", prompt)); + chat(messages, callback); + } /** * 对话 @@ -44,4 +64,12 @@ public interface AIService { */ String chat(final List messages); + + /** + * 对话-SSE流式输出 + * @param messages 由目前为止的对话组成的消息列表,可以设置role,content。详细参考官方文档 + * @param callback 流式数据回调函数 + * @since 5.8.39 + */ + void chat(final List messages, final Consumer callback); } diff --git a/hutool-ai/src/main/java/cn/hutool/ai/core/BaseAIService.java b/hutool-ai/src/main/java/cn/hutool/ai/core/BaseAIService.java index 218db8012..a1944394a 100644 --- a/hutool-ai/src/main/java/cn/hutool/ai/core/BaseAIService.java +++ b/hutool-ai/src/main/java/cn/hutool/ai/core/BaseAIService.java @@ -18,8 +18,15 @@ package cn.hutool.ai.core; import cn.hutool.ai.AIException; import cn.hutool.http.*; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.URL; import java.util.Map; +import java.util.function.Consumer; /** * 基础AIService,包含基公共参数和公共方法 @@ -102,4 +109,50 @@ public class BaseAIService { throw new AIException("Failed to send POST request:" + e.getMessage(), e); } } + + /** + * 支持流式返回的 POST 请求 + * + * @param endpoint 请求地址 + * @param paramMap 请求参数 + * @param callback 流式数据回调函数 + */ + protected void sendPostStream(final String endpoint, final Map paramMap, Consumer callback) { + HttpURLConnection connection = null; + try { + // 创建连接 + URL apiUrl = new URL(config.getApiUrl() + endpoint); + connection = (HttpURLConnection) apiUrl.openConnection(); + connection.setRequestMethod(Method.POST.name()); + connection.setRequestProperty(Header.CONTENT_TYPE.getValue(), "application/json"); + connection.setRequestProperty(Header.AUTHORIZATION.getValue(), "Bearer " + config.getApiKey()); + connection.setDoOutput(true); + //5分钟 + connection.setReadTimeout(300000); + //3分钟 + connection.setConnectTimeout(180000); + // 发送请求体 + try (OutputStream os = connection.getOutputStream()) { + String jsonInputString = new ObjectMapper().writeValueAsString(paramMap); + os.write(jsonInputString.getBytes()); + os.flush(); + } + + // 读取流式响应 + try (BufferedReader reader = new BufferedReader(new InputStreamReader(connection.getInputStream()))) { + String line; + while ((line = reader.readLine()) != null) { + // 调用回调函数处理每一行数据 + callback.accept(line); + } + } + } catch (Exception e) { + callback.accept("{\"error\": \"" + e.getMessage() + "\"}"); + } finally { + // 关闭连接 + if (connection != null) { + connection.disconnect(); + } + } + } } diff --git a/hutool-ai/src/main/java/cn/hutool/ai/model/deepseek/DeepSeekService.java b/hutool-ai/src/main/java/cn/hutool/ai/model/deepseek/DeepSeekService.java index b601d5934..d537a078d 100644 --- a/hutool-ai/src/main/java/cn/hutool/ai/model/deepseek/DeepSeekService.java +++ b/hutool-ai/src/main/java/cn/hutool/ai/model/deepseek/DeepSeekService.java @@ -17,6 +17,7 @@ package cn.hutool.ai.model.deepseek; import cn.hutool.ai.core.AIService; +import java.util.function.Consumer; /** * deepSeek支持的扩展接口 @@ -35,6 +36,14 @@ public interface DeepSeekService extends AIService { */ String beta(String prompt); + /** + * 模型beta功能-SSE流式输出 + * @param prompt 题词 + * @param callback 流式数据回调函数 + * @since 5.8.39 + */ + void beta(String prompt, final Consumer callback); + /** * 列出所有模型列表 * diff --git a/hutool-ai/src/main/java/cn/hutool/ai/model/deepseek/DeepSeekServiceImpl.java b/hutool-ai/src/main/java/cn/hutool/ai/model/deepseek/DeepSeekServiceImpl.java index 54c38f9dc..624b97aab 100644 --- a/hutool-ai/src/main/java/cn/hutool/ai/model/deepseek/DeepSeekServiceImpl.java +++ b/hutool-ai/src/main/java/cn/hutool/ai/model/deepseek/DeepSeekServiceImpl.java @@ -19,6 +19,7 @@ package cn.hutool.ai.model.deepseek; import cn.hutool.ai.core.AIConfig; import cn.hutool.ai.core.BaseAIService; import cn.hutool.ai.core.Message; +import cn.hutool.core.thread.ThreadUtil; import cn.hutool.http.HttpResponse; import cn.hutool.json.JSONUtil; @@ -26,6 +27,7 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Consumer; /** * DeepSeek服务,AI具体功能的实现 @@ -54,15 +56,6 @@ public class DeepSeekServiceImpl extends BaseAIService implements DeepSeekServic super(config); } - @Override - public String chat(final String prompt) { - // 定义消息结构 - final List messages = new ArrayList<>(); - messages.add(new Message("system", "You are a helpful assistant")); - messages.add(new Message("user", prompt)); - return chat(messages); - } - @Override public String chat(final List messages) { final String paramJson = buildChatRequestBody(messages); @@ -70,6 +63,12 @@ public class DeepSeekServiceImpl extends BaseAIService implements DeepSeekServic return response.body(); } + @Override + public void chat(final List messages, final Consumer callback) { + Map paramMap = buildChatStreamRequestBody(messages); + ThreadUtil.newThread(() -> sendPostStream(CHAT_ENDPOINT, paramMap, callback::accept), "deepseek-chat-sse").start(); + } + @Override public String beta(final String prompt) { final String paramJson = buildBetaRequestBody(prompt); @@ -77,6 +76,12 @@ public class DeepSeekServiceImpl extends BaseAIService implements DeepSeekServic return response.body(); } + @Override + public void beta(final String prompt, final Consumer callback) { + Map paramMap = buildBetaStreamRequestBody(prompt); + ThreadUtil.newThread(() -> sendPostStream(BETA_ENDPOINT, paramMap, callback::accept), "deepseek-beta-sse").start(); + } + @Override public String models() { final HttpResponse response = sendGet(MODELS_ENDPOINT); @@ -101,6 +106,19 @@ public class DeepSeekServiceImpl extends BaseAIService implements DeepSeekServic return JSONUtil.toJsonStr(paramMap); } + // 构建chatStream请求体 + private Map buildChatStreamRequestBody(final List messages) { + //使用JSON工具 + final Map paramMap = new HashMap<>(); + paramMap.put("stream", true); + paramMap.put("model", config.getModel()); + paramMap.put("messages", messages); + //合并其他参数 + paramMap.putAll(config.getAdditionalConfigMap()); + + return paramMap; + } + // 构建beta请求体 private String buildBetaRequestBody(final String prompt) { // 定义消息结构 @@ -108,10 +126,23 @@ public class DeepSeekServiceImpl extends BaseAIService implements DeepSeekServic final Map paramMap = new HashMap<>(); paramMap.put("model", config.getModel()); paramMap.put("prompt", prompt); -// //合并其他参数 + //合并其他参数 paramMap.putAll(config.getAdditionalConfigMap()); return JSONUtil.toJsonStr(paramMap); } + // 构建betaStream请求体 + private Map buildBetaStreamRequestBody(final String prompt) { + //使用JSON工具 + final Map paramMap = new HashMap<>(); + paramMap.put("stream", true); + paramMap.put("model", config.getModel()); + paramMap.put("prompt", prompt); + //合并其他参数 + paramMap.putAll(config.getAdditionalConfigMap()); + + return paramMap; + } + } diff --git a/hutool-ai/src/main/java/cn/hutool/ai/model/doubao/DoubaoService.java b/hutool-ai/src/main/java/cn/hutool/ai/model/doubao/DoubaoService.java index 14d087115..e94ae258f 100644 --- a/hutool-ai/src/main/java/cn/hutool/ai/model/doubao/DoubaoService.java +++ b/hutool-ai/src/main/java/cn/hutool/ai/model/doubao/DoubaoService.java @@ -19,7 +19,9 @@ package cn.hutool.ai.model.doubao; import cn.hutool.ai.core.AIService; import cn.hutool.ai.core.Message; +import java.util.ArrayList; import java.util.List; +import java.util.function.Consumer; /** * doubao支持的扩展接口 @@ -29,6 +31,30 @@ import java.util.List; */ public interface DoubaoService extends AIService { + /** + * 图像理解:模型会依据传入的图片信息以及问题,给出回复。 + * + * @param prompt 提问 + * @param images 传入的图片列表地址/或者图片Base64编码图片列表(URI形式) + * @return AI回答 + * @since 5.8.38 + */ + default String chatVision(String prompt, final List images) { + return chatVision(prompt, images, DoubaoCommon.DoubaoVision.AUTO.getDetail()); + } + + /** + * 图像理解-SSE流式输出 + * + * @param prompt 提问 + * @param images 图片列表/或者图片Base64编码图片列表(URI形式) + * @param callback 流式数据回调函数 + * @since 5.8.39 + */ + default void chatVision(String prompt, final List images, final Consumer callback) { + chatVision(prompt, images, DoubaoCommon.DoubaoVision.AUTO.getDetail(), callback); + } + /** * 图像理解:模型会依据传入的图片信息以及问题,给出回复。 * @@ -41,16 +67,15 @@ public interface DoubaoService extends AIService { String chatVision(String prompt, final List images, String detail); /** - * 图像理解:模型会依据传入的图片信息以及问题,给出回复。 + * 图像理解-SSE流式输出 * * @param prompt 提问 * @param images 传入的图片列表地址/或者图片Base64编码图片列表(URI形式) - * @return AI回答 - * @since 5.8.38 + * @param detail 手动设置图片的质量,取值范围high、low、auto,默认为auto + * @param callback 流式数据回调函数 + * @since 5.8.39 */ - default String chatVision(String prompt, final List images) { - return chatVision(prompt, images, DoubaoCommon.DoubaoVision.AUTO.getDetail()); - } + void chatVision(String prompt, final List images, String detail, final Consumer callback); /** * 创建视频生成任务 @@ -66,7 +91,7 @@ public interface DoubaoService extends AIService { /** * 创建视频生成任务 - * 注意:调用该方法时,配置config中的model为您创建的推理接入点(Endpoint)ID。详细参考官方文档 + * 注意:调用该方法时,配置config中的model为生成视频的模型或者您创建的推理接入点(Endpoint)ID。详细参考官方文档 * * @param text 文本提示词 * @param image 图片/或者图片Base64编码图片(URI形式) @@ -114,6 +139,15 @@ public interface DoubaoService extends AIService { */ String botsChat(final List messages); + /** + * 应用(Bot)-SSE流式输出 config中model设置为您创建的应用ID + * + * @param messages 由对话组成的消息列表。如系统人设,背景信息等,用户自定义的信息 + * @param callback 流式数据回调函数 + * @since 5.8.39 + */ + void botsChat(final List messages, final Consumer callback); + /** * 分词:可以将文本转换为模型可理解的 token id,并返回文本的 tokens 数量、token id、 token 在原始文本中的偏移量等信息 * @@ -132,7 +166,12 @@ public interface DoubaoService extends AIService { * @return AI回答 * @since 5.8.38 */ - String batchChat(String prompt); + default String batchChat(String prompt){ + final List messages = new ArrayList<>(); + messages.add(new Message("system", "You are a helpful assistant")); + messages.add(new Message("user", prompt)); + return batchChat(messages); + } /** * 批量推理 Chat @@ -179,7 +218,26 @@ public interface DoubaoService extends AIService { * @return AI的回答 * @since 5.8.38 */ - String chatContext(String prompt, String contextId); + default String chatContext(String prompt, String contextId){ + final List messages = new ArrayList<>(); + messages.add(new Message("user", prompt)); + return chatContext(messages, contextId); + } + + /** + * 上下文缓存对话-SSE流式输出 + * 注意:配置config中的model可以为您创建的推理接入点(Endpoint)ID,也可以是支持chat的model + * + * @param prompt 对话的内容题词 + * @param contextId 创建上下文缓存后获取的缓存id + * @param callback 流式数据回调函数 + * @since 5.8.39 + */ + default void chatContext(String prompt, String contextId, final Consumer callback){ + final List messages = new ArrayList<>(); + messages.add(new Message("user", prompt)); + chatContext(messages, contextId, callback); + } /** * 上下文缓存对话: 向大模型发起带上下文缓存的请求 @@ -192,4 +250,25 @@ public interface DoubaoService extends AIService { */ String chatContext(final List messages, String contextId); + /** + * 上下文缓存对话-SSE流式输出 + * 注意:配置config中的model可以为您创建的推理接入点(Endpoint)ID,也可以是支持chat的model + * + * @param messages 对话的信息 不支持最后一个元素的role设置为assistant。如使用session 缓存(mode设置为session)传入最新一轮对话的信息,无需传入历史信息 + * @param contextId 创建上下文缓存后获取的缓存id + * @param callback 流式数据回调函数 + * @since 5.8.39 + */ + void chatContext(final List messages, String contextId, final Consumer callback); + + /** + * 文生图 + * 请设置config中model为支持图片功能的模型,目前支持Doubao-Seedream-3.0-t2i + * + * @param prompt 题词 + * @return 包含生成图片的url + * @since 5.8.39 + */ + String imagesGenerations(String prompt); + } diff --git a/hutool-ai/src/main/java/cn/hutool/ai/model/doubao/DoubaoServiceImpl.java b/hutool-ai/src/main/java/cn/hutool/ai/model/doubao/DoubaoServiceImpl.java index 88a24d218..c1836fec4 100644 --- a/hutool-ai/src/main/java/cn/hutool/ai/model/doubao/DoubaoServiceImpl.java +++ b/hutool-ai/src/main/java/cn/hutool/ai/model/doubao/DoubaoServiceImpl.java @@ -19,6 +19,7 @@ package cn.hutool.ai.model.doubao; import cn.hutool.ai.core.AIConfig; import cn.hutool.ai.core.BaseAIService; import cn.hutool.ai.core.Message; +import cn.hutool.core.thread.ThreadUtil; import cn.hutool.core.util.StrUtil; import cn.hutool.http.HttpResponse; import cn.hutool.json.JSONUtil; @@ -27,6 +28,7 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Consumer; /** * Doubao服务,AI具体功能的实现 @@ -54,21 +56,14 @@ public class DoubaoServiceImpl extends BaseAIService implements DoubaoService { private final String CHAT_CONTEXT = "/context/chat/completions"; //创建视频生成任务 private final String CREATE_VIDEO = "/contents/generations/tasks"; + //文生图 + private final String IMAGES_GENERATIONS = "/images/generations"; public DoubaoServiceImpl(final AIConfig config) { //初始化doubao客户端 super(config); } - @Override - public String chat(String prompt) { - // 定义消息结构 - final List messages = new ArrayList<>(); - messages.add(new Message("system", "You are a helpful assistant")); - messages.add(new Message("user", prompt)); - return chat(messages); - } - @Override public String chat(final List messages) { String paramJson = buildChatRequestBody(messages); @@ -76,6 +71,12 @@ public class DoubaoServiceImpl extends BaseAIService implements DoubaoService { return response.body(); } + @Override + public void chat(final List messages, final Consumer callback) { + Map paramMap = buildChatStreamRequestBody(messages); + ThreadUtil.newThread(() -> sendPostStream(CHAT_ENDPOINT, paramMap, callback::accept), "doubao-chat-sse").start(); + } + @Override public String chatVision(String prompt, final List images, String detail) { String paramJson = buildChatVisionRequestBody(prompt, images, detail); @@ -83,6 +84,12 @@ public class DoubaoServiceImpl extends BaseAIService implements DoubaoService { return response.body(); } + @Override + public void chatVision(String prompt, List images, String detail, Consumer callback) { + Map paramMap = buildChatVisionStreamRequestBody(prompt, images, detail); + ThreadUtil.newThread(() -> sendPostStream(CHAT_ENDPOINT, paramMap, callback::accept), "doubao-chatVision-sse").start(); + } + @Override public String videoTasks(String text, String image, final List videoParams) { String paramJson = buildGenerationsTasksRequestBody(text, image, videoParams); @@ -118,6 +125,12 @@ public class DoubaoServiceImpl extends BaseAIService implements DoubaoService { return response.body(); } + @Override + public void botsChat(List messages, Consumer callback) { + Map paramMap = buildBotsChatStreamRequestBody(messages); + ThreadUtil.newThread(() -> sendPostStream(BOTS_CHAT, paramMap, callback::accept), "doubao-botsChat-sse").start(); + } + @Override public String tokenization(String[] text) { String paramJson = buildTokenizationRequestBody(text); @@ -125,14 +138,6 @@ public class DoubaoServiceImpl extends BaseAIService implements DoubaoService { return response.body(); } - @Override - public String batchChat(String prompt) { - // 定义消息结构 - final List messages = new ArrayList<>(); - messages.add(new Message("system", "You are a helpful assistant")); - messages.add(new Message("user", prompt)); - return batchChat(messages); - } @Override public String batchChat(final List messages) { @@ -148,14 +153,6 @@ public class DoubaoServiceImpl extends BaseAIService implements DoubaoService { return response.body(); } - @Override - public String chatContext(String prompt, String contextId) { - // 定义消息结构 - final List messages = new ArrayList<>(); - messages.add(new Message("user", prompt)); - return chatContext(messages, contextId); - } - @Override public String chatContext(final List messages, String contextId) { String paramJson = buildChatContentRequestBody(messages, contextId); @@ -163,6 +160,19 @@ public class DoubaoServiceImpl extends BaseAIService implements DoubaoService { return response.body(); } + @Override + public void chatContext(final List messages, String contextId, final Consumer callback) { + Map paramMap = buildChatContentStreamRequestBody(messages, contextId); + ThreadUtil.newThread(() -> sendPostStream(CHAT_CONTEXT, paramMap, callback::accept), "doubao-chatContext-sse").start(); + } + + @Override + public String imagesGenerations(String prompt) { + String paramJson = buildImagesGenerationsRequestBody(prompt); + final HttpResponse response = sendPost(IMAGES_GENERATIONS, paramJson); + return response.body(); + } + // 构建chat请求体 private String buildChatRequestBody(final List messages) { //使用JSON工具 @@ -175,6 +185,19 @@ public class DoubaoServiceImpl extends BaseAIService implements DoubaoService { return JSONUtil.toJsonStr(paramMap); } + // 构建chatStream请求体 + private Map buildChatStreamRequestBody(final List messages) { + //使用JSON工具 + final Map paramMap = new HashMap<>(); + paramMap.put("stream", true); + paramMap.put("model", config.getModel()); + paramMap.put("messages", messages); + //合并其他参数 + paramMap.putAll(config.getAdditionalConfigMap()); + + return paramMap; + } + //构建chatVision请求体 private String buildChatVisionRequestBody(String prompt, final List images, String detail) { // 定义消息结构 @@ -206,6 +229,37 @@ public class DoubaoServiceImpl extends BaseAIService implements DoubaoService { return JSONUtil.toJsonStr(paramMap); } + private Map buildChatVisionStreamRequestBody(String prompt, final List images, String detail) { + // 定义消息结构 + final List messages = new ArrayList<>(); + final List content = new ArrayList<>(); + + final Map contentMap = new HashMap<>(); + contentMap.put("type", "text"); + contentMap.put("text", prompt); + content.add(contentMap); + for (String img : images) { + HashMap imgUrlMap = new HashMap<>(); + imgUrlMap.put("type", "image_url"); + HashMap urlMap = new HashMap<>(); + urlMap.put("url", img); + urlMap.put("detail", detail); + imgUrlMap.put("image_url", urlMap); + content.add(imgUrlMap); + } + + messages.add(new Message("user", content)); + + //使用JSON工具 + final Map paramMap = new HashMap<>(); + paramMap.put("stream", true); + paramMap.put("model", config.getModel()); + paramMap.put("messages", messages); + //合并其他参数 + paramMap.putAll(config.getAdditionalConfigMap()); + return paramMap; + } + //构建文本向量化请求体 private String buildEmbeddingTextRequestBody(String[] input) { //使用JSON工具 @@ -253,6 +307,10 @@ public class DoubaoServiceImpl extends BaseAIService implements DoubaoService { return buildChatRequestBody(messages); } + private Map buildBotsChatStreamRequestBody(final List messages) { + return buildChatStreamRequestBody(messages); + } + //构建分词请求体 private String buildTokenizationRequestBody(String[] text) { final Map paramMap = new HashMap<>(); @@ -266,6 +324,10 @@ public class DoubaoServiceImpl extends BaseAIService implements DoubaoService { return buildChatRequestBody(messages); } + private Map buildBatchChatStreamRequestBody(final List messages) { + return buildChatStreamRequestBody(messages); + } + //构建创建上下文缓存请求体 private String buildCreateContextRequest(final List messages, String mode) { final Map paramMap = new HashMap<>(); @@ -291,6 +353,19 @@ public class DoubaoServiceImpl extends BaseAIService implements DoubaoService { return JSONUtil.toJsonStr(paramMap); } + private Map buildChatContentStreamRequestBody(final List messages, String contextId) { + //使用JSON工具 + final Map paramMap = new HashMap<>(); + paramMap.put("stream", true); + paramMap.put("model", config.getModel()); + paramMap.put("messages", messages); + paramMap.put("context_id", contextId); + //合并其他参数 + paramMap.putAll(config.getAdditionalConfigMap()); + + return paramMap; + } + //构建创建视频任务请求体 private String buildGenerationsTasksRequestBody(String text, String image, final List videoParams) { //使用JSON工具 @@ -306,7 +381,7 @@ public class DoubaoServiceImpl extends BaseAIService implements DoubaoService { content.add(textMap); } //添加图片参数 - if (!StrUtil.isNotBlank(image)) { + if (!StrUtil.isBlank(image)) { final Map imgUrlMap = new HashMap<>(); imgUrlMap.put("type", "image_url"); final Map urlMap = new HashMap<>(); @@ -347,6 +422,16 @@ public class DoubaoServiceImpl extends BaseAIService implements DoubaoService { paramMap.put("content", content); //合并其他参数 paramMap.putAll(config.getAdditionalConfigMap()); + return JSONUtil.toJsonStr(paramMap); + } + + //构建文生图请求体 + private String buildImagesGenerationsRequestBody(String prompt) { + final Map paramMap = new HashMap<>(); + paramMap.put("model", config.getModel()); + paramMap.put("prompt", prompt); + //合并其他参数 + paramMap.putAll(config.getAdditionalConfigMap()); return JSONUtil.toJsonStr(paramMap); } diff --git a/hutool-ai/src/main/java/cn/hutool/ai/model/grok/GrokService.java b/hutool-ai/src/main/java/cn/hutool/ai/model/grok/GrokService.java index 391c07a41..4658f766b 100644 --- a/hutool-ai/src/main/java/cn/hutool/ai/model/grok/GrokService.java +++ b/hutool-ai/src/main/java/cn/hutool/ai/model/grok/GrokService.java @@ -17,8 +17,11 @@ package cn.hutool.ai.model.grok; import cn.hutool.ai.core.AIService; +import cn.hutool.ai.core.Message; +import java.util.ArrayList; import java.util.List; +import java.util.function.Consumer; /** * grok支持的扩展接口 @@ -36,7 +39,48 @@ public interface GrokService extends AIService { * @return AI回答 * @since 5.8.38 */ - String message(String prompt, int maxToken); + default String message(String prompt, int maxToken){ + // 定义消息结构 + final List messages = new ArrayList<>(); + messages.add(new Message("system", "You are a helpful assistant")); + messages.add(new Message("user", prompt)); + return message(messages, maxToken); + } + + /** + * 创建消息回复-SSE流式输出 + * + * @param prompt 题词 + * @param maxToken 最大token + * @param callback 流式数据回调函数 + * @since 5.8.39 + */ + default void message(String prompt, int maxToken, final Consumer callback){ + final List messages = new ArrayList<>(); + messages.add(new Message("system", "You are a helpful assistant")); + messages.add(new Message("user", prompt)); + message(messages, maxToken, callback); + } + + /** + * 创建消息回复 + * + * @param messages messages 由对话组成的消息列表。如系统人设,背景信息等,用户自定义的信息 + * @param maxToken 最大token + * @return AI回答 + * @since 5.8.39 + */ + String message(List messages, int maxToken); + + /** + * 创建消息回复-SSE流式输出 + * + * @param messages messages 由对话组成的消息列表。如系统人设,背景信息等,用户自定义的信息 + * @param maxToken 最大token + * @param callback 流式数据回调函数 + * @since 5.8.39 + */ + void message(List messages, int maxToken, final Consumer callback); /** * 图像理解:模型会依据传入的图片信息以及问题,给出回复。 @@ -49,6 +93,17 @@ public interface GrokService extends AIService { */ String chatVision(String prompt, final List images, String detail); + /** + * 图像理解-SSE流式输出 + * + * @param prompt 题词 + * @param images 图片列表/或者图片Base64编码图片列表(URI形式) + * @param detail 手动设置图片的质量,取值范围high、low、auto,默认为auto + * @param callback 流式数据回调函数 + * @since 5.8.39 + */ + void chatVision(String prompt, final List images, String detail,final Consumer callback); + /** * 图像理解:模型会依据传入的图片信息以及问题,给出回复。 * @@ -61,6 +116,18 @@ public interface GrokService extends AIService { return chatVision(prompt, images, GrokCommon.GrokVision.AUTO.getDetail()); } + /** + * 图像理解:模型会依据传入的图片信息以及问题,给出回复。 + * + * @param prompt 题词 + * @param images 传入|的图片列表地址/或者图片Base64编码图片列表(URI形式) + * @param callback 流式数据回调函数 + * @since 5.8.39 + */ + default void chatVision(String prompt, final List images, final Consumer callback){ + chatVision(prompt, images, GrokCommon.GrokVision.AUTO.getDetail(), callback); + } + /** * 列出所有model列表 * @@ -112,4 +179,14 @@ public interface GrokService extends AIService { * @since 5.8.38 */ String deferredCompletion(String requestId); + + /** + * 文生图 + * 请设置config中model为支持图片功能的模型,目前支持GROK_2_IMAGE + * + * @param prompt 题词 + * @return 包含生成图片的url + * @since 5.8.39 + */ + String imagesGenerations(String prompt); } diff --git a/hutool-ai/src/main/java/cn/hutool/ai/model/grok/GrokServiceImpl.java b/hutool-ai/src/main/java/cn/hutool/ai/model/grok/GrokServiceImpl.java index a87debce7..1c53932e5 100644 --- a/hutool-ai/src/main/java/cn/hutool/ai/model/grok/GrokServiceImpl.java +++ b/hutool-ai/src/main/java/cn/hutool/ai/model/grok/GrokServiceImpl.java @@ -19,6 +19,7 @@ package cn.hutool.ai.model.grok; import cn.hutool.ai.core.AIConfig; import cn.hutool.ai.core.BaseAIService; import cn.hutool.ai.core.Message; +import cn.hutool.core.thread.ThreadUtil; import cn.hutool.http.HttpResponse; import cn.hutool.json.JSONUtil; @@ -26,6 +27,7 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Consumer; /** * Grok服务,AI具体功能的实现 @@ -47,21 +49,14 @@ public class GrokServiceImpl extends BaseAIService implements GrokService { private final String TOKENIZE_TEXT = "/tokenize-text"; //获取延迟对话 private final String DEFERRED_COMPLETION = "/chat/deferred-completion"; + //文生图 + private final String IMAGES_GENERATIONS = "/images/generations"; public GrokServiceImpl(final AIConfig config) { //初始化grok客户端 super(config); } - @Override - public String chat(String prompt) { - // 定义消息结构 - final List messages = new ArrayList<>(); - messages.add(new Message("system", "You are a helpful assistant")); - messages.add(new Message("user", prompt)); - return chat(messages); - } - @Override public String chat(final List messages) { String paramJson = buildChatRequestBody(messages); @@ -70,16 +65,24 @@ public class GrokServiceImpl extends BaseAIService implements GrokService { } @Override - public String message(String prompt, int maxToken) { - // 定义消息结构 - final List messages = new ArrayList<>(); - messages.add(new Message("system", "You are a helpful assistant")); - messages.add(new Message("user", prompt)); + public void chat(List messages,Consumer callback) { + Map paramMap = buildChatStreamRequestBody(messages); + ThreadUtil.newThread(() -> sendPostStream(CHAT_ENDPOINT, paramMap, callback::accept), "grok-chat-sse").start(); + } + + @Override + public String message(final List messages, int maxToken) { String paramJson = buildMessageRequestBody(messages, maxToken); final HttpResponse response = sendPost(MESSAGES, paramJson); return response.body(); } + @Override + public void message(List messages, int maxToken, final Consumer callback) { + Map paramMap = buildMessageStreamRequestBody(messages, maxToken); + ThreadUtil.newThread(() -> sendPostStream(MESSAGES, paramMap, callback::accept), "grok-message-sse").start(); + } + @Override public String chatVision(String prompt, final List images, String detail) { String paramJson = buildChatVisionRequestBody(prompt, images, detail); @@ -87,6 +90,12 @@ public class GrokServiceImpl extends BaseAIService implements GrokService { return response.body(); } + @Override + public void chatVision(String prompt, List images, String detail, Consumer callback) { + Map paramMap = buildChatVisionStreamRequestBody(prompt, images, detail); + ThreadUtil.newThread(() -> sendPostStream(CHAT_ENDPOINT, paramMap, callback::accept), "grok-chatVision-sse").start(); + } + @Override public String models() { final HttpResponse response = sendGet(MODELS_ENDPOINT); @@ -124,6 +133,13 @@ public class GrokServiceImpl extends BaseAIService implements GrokService { return response.body(); } + @Override + public String imagesGenerations(String prompt) { + String paramJson = buildImagesGenerationsRequestBody(prompt); + final HttpResponse response = sendPost(IMAGES_GENERATIONS, paramJson); + return response.body(); + } + // 构建chat请求体 private String buildChatRequestBody(final List messages) { //使用JSON工具 @@ -136,6 +152,18 @@ public class GrokServiceImpl extends BaseAIService implements GrokService { return JSONUtil.toJsonStr(paramMap); } + private Map buildChatStreamRequestBody(final List messages) { + //使用JSON工具 + final Map paramMap = new HashMap<>(); + paramMap.put("stream", true); + paramMap.put("model", config.getModel()); + paramMap.put("messages", messages); + //合并其他参数 + paramMap.putAll(config.getAdditionalConfigMap()); + + return paramMap; + } + //构建chatVision请求体 private String buildChatVisionRequestBody(String prompt, final List images, String detail) { // 定义消息结构 @@ -167,6 +195,37 @@ public class GrokServiceImpl extends BaseAIService implements GrokService { return JSONUtil.toJsonStr(paramMap); } + private Map buildChatVisionStreamRequestBody(String prompt, final List images, String detail) { + // 定义消息结构 + final List messages = new ArrayList<>(); + final List content = new ArrayList<>(); + + final Map contentMap = new HashMap<>(); + contentMap.put("type", "text"); + contentMap.put("text", prompt); + content.add(contentMap); + for (String img : images) { + HashMap imgUrlMap = new HashMap<>(); + imgUrlMap.put("type", "image_url"); + HashMap urlMap = new HashMap<>(); + urlMap.put("url", img); + urlMap.put("detail", detail); + imgUrlMap.put("image_url", urlMap); + content.add(imgUrlMap); + } + + messages.add(new Message("user", content)); + + //使用JSON工具 + final Map paramMap = new HashMap<>(); + paramMap.put("stream", true); + paramMap.put("model", config.getModel()); + paramMap.put("messages", messages); + //合并其他参数 + paramMap.putAll(config.getAdditionalConfigMap()); + return paramMap; + } + //构建消息回复请求体 private String buildMessageRequestBody(final List messages, int maxToken) { final Map paramMap = new HashMap<>(); @@ -179,6 +238,18 @@ public class GrokServiceImpl extends BaseAIService implements GrokService { return JSONUtil.toJsonStr(paramMap); } + private Map buildMessageStreamRequestBody(final List messages, int maxToken) { + final Map paramMap = new HashMap<>(); + paramMap.put("stream", true); + paramMap.put("model", config.getModel()); + paramMap.put("messages", messages); + paramMap.put("max_tokens", maxToken); + //合并其他参数 + paramMap.putAll(config.getAdditionalConfigMap()); + + return paramMap; + } + //构建分词请求体 private String buildTokenizeRequestBody(String text) { //使用JSON工具 @@ -190,4 +261,15 @@ public class GrokServiceImpl extends BaseAIService implements GrokService { return JSONUtil.toJsonStr(paramMap); } + + //构建文生图请求体 + private String buildImagesGenerationsRequestBody(String prompt) { + final Map paramMap = new HashMap<>(); + paramMap.put("model", config.getModel()); + paramMap.put("prompt", prompt); + //合并其他参数 + paramMap.putAll(config.getAdditionalConfigMap()); + + return JSONUtil.toJsonStr(paramMap); + } } diff --git a/hutool-ai/src/main/java/cn/hutool/ai/model/openai/OpenaiService.java b/hutool-ai/src/main/java/cn/hutool/ai/model/openai/OpenaiService.java index de8ff5869..4e7edad7f 100644 --- a/hutool-ai/src/main/java/cn/hutool/ai/model/openai/OpenaiService.java +++ b/hutool-ai/src/main/java/cn/hutool/ai/model/openai/OpenaiService.java @@ -21,7 +21,9 @@ import cn.hutool.ai.core.Message; import java.io.File; import java.io.InputStream; +import java.util.ArrayList; import java.util.List; +import java.util.function.Consumer; /** * openai支持的扩展接口 @@ -42,6 +44,18 @@ public interface OpenaiService extends AIService { */ String chatVision(String prompt, final List images, String detail); + /** + * 图像理解-SSE流式输出 + * + * @param prompt 题词 + * @param images 图片列表/或者图片Base64编码图片列表(URI形式) + * @param detail 手动设置图片的质量,取值范围high、low、auto,默认为auto + * @param callback 流式数据回调函数 + * @since 5.8.39 + */ + void chatVision(String prompt, final List images, String detail,final Consumer callback); + + /** * 图像理解:模型会依据传入的图片信息以及问题,给出回复。 * @@ -54,6 +68,18 @@ public interface OpenaiService extends AIService { return chatVision(prompt, images, OpenaiCommon.OpenaiVision.AUTO.getDetail()); } + /** + * 图像理解-SSE流式输出 + * + * @param prompt 题词 + * @param images 传入的图片列表地址/或者图片Base64编码图片列表(URI形式) + * @param callback 流式数据回调函数 + * @since 5.8.39 + */ + default void chatVision(String prompt, final List images, final Consumer callback){ + chatVision(prompt, images, OpenaiCommon.OpenaiVision.AUTO.getDetail(), callback); + } + /** * 文生图 请设置config中model为支持图片功能的模型 DALL·E系列 * @@ -166,7 +192,28 @@ public interface OpenaiService extends AIService { * @return AI回答 * @since 5.8.38 */ - String chatReasoning(String prompt, String reasoningEffort); + default String chatReasoning(String prompt, String reasoningEffort){ + final List messages = new ArrayList<>(); + messages.add(new Message("system", "You are a helpful assistant")); + messages.add(new Message("user", prompt)); + return chatReasoning(messages, reasoningEffort); + } + + /** + * 推理chat-SSE流式输出 + * 支持o3-mini和o1 + * + * @param prompt 对话题词 + * @param reasoningEffort 推理程度 + * @param callback 流式数据回调函数 + * @since 5.8.39 + */ + default void chatReasoning(String prompt, String reasoningEffort, final Consumer callback){ + final List messages = new ArrayList<>(); + messages.add(new Message("system", "You are a helpful assistant")); + messages.add(new Message("user", prompt)); + chatReasoning(messages, reasoningEffort, callback); + } /** * 推理chat @@ -180,6 +227,18 @@ public interface OpenaiService extends AIService { return chatReasoning(prompt, OpenaiCommon.OpenaiReasoning.MEDIUM.getEffort()); } + /** + * 推理chat-SSE流式输出 + * 支持o3-mini和o1 + * + * @param prompt 对话题词 + * @param callback 流式数据回调函数 + * @since 5.8.39 + */ + default void chatReasoning(String prompt, final Consumer callback) { + chatReasoning(prompt, OpenaiCommon.OpenaiReasoning.MEDIUM.getEffort(), callback); + } + /** * 推理chat * 支持o3-mini和o1 @@ -191,6 +250,17 @@ public interface OpenaiService extends AIService { */ String chatReasoning(final List messages, String reasoningEffort); + /** + * 推理chat-SSE流式输出 + * 支持o3-mini和o1 + * + * @param messages 消息列表 + * @param reasoningEffort 推理程度 + * @param callback 流式数据回调函数 + * @since 5.8.39 + */ + void chatReasoning(final List messages, String reasoningEffort, final Consumer callback); + /** * 推理chat * 支持o3-mini和o1 @@ -203,4 +273,16 @@ public interface OpenaiService extends AIService { return chatReasoning(messages, OpenaiCommon.OpenaiReasoning.MEDIUM.getEffort()); } + /** + * 推理chat-SSE流式输出 + * 支持o3-mini和o1 + * + * @param messages 消息列表 + * @param callback 流式数据回调函数 + * @since 5.8.39 + */ + default void chatReasoning(final List messages, final Consumer callback) { + chatReasoning(messages, OpenaiCommon.OpenaiReasoning.MEDIUM.getEffort(), callback); + } + } diff --git a/hutool-ai/src/main/java/cn/hutool/ai/model/openai/OpenaiServiceImpl.java b/hutool-ai/src/main/java/cn/hutool/ai/model/openai/OpenaiServiceImpl.java index 02c835b57..67f7d2555 100644 --- a/hutool-ai/src/main/java/cn/hutool/ai/model/openai/OpenaiServiceImpl.java +++ b/hutool-ai/src/main/java/cn/hutool/ai/model/openai/OpenaiServiceImpl.java @@ -19,6 +19,7 @@ package cn.hutool.ai.model.openai; import cn.hutool.ai.core.AIConfig; import cn.hutool.ai.core.BaseAIService; import cn.hutool.ai.core.Message; +import cn.hutool.core.thread.ThreadUtil; import cn.hutool.core.util.StrUtil; import cn.hutool.http.HttpResponse; import cn.hutool.json.JSONUtil; @@ -29,6 +30,7 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Consumer; /** * openai服务,AI具体功能的实现 @@ -60,15 +62,6 @@ public class OpenaiServiceImpl extends BaseAIService implements OpenaiService { super(config); } - @Override - public String chat(String prompt) { - // 定义消息结构 - final List messages = new ArrayList<>(); - messages.add(new Message("system", "You are a helpful assistant")); - messages.add(new Message("user", prompt)); - return chat(messages); - } - @Override public String chat(final List messages) { String paramJson = buildChatRequestBody(messages); @@ -76,6 +69,12 @@ public class OpenaiServiceImpl extends BaseAIService implements OpenaiService { return response.body(); } + @Override + public void chat(List messages,Consumer callback) { + Map paramMap = buildChatStreamRequestBody(messages); + ThreadUtil.newThread(() -> sendPostStream(CHAT_ENDPOINT, paramMap, callback::accept), "openai-chat-sse").start(); + } + @Override public String chatVision(String prompt, final List images, String detail) { String paramJson = buildChatVisionRequestBody(prompt, images, detail); @@ -83,6 +82,12 @@ public class OpenaiServiceImpl extends BaseAIService implements OpenaiService { return response.body(); } + @Override + public void chatVision(String prompt, List images, String detail, Consumer callback) { + Map paramMap = buildChatVisionStreamRequestBody(prompt, images, detail); + ThreadUtil.newThread(() -> sendPostStream(CHAT_ENDPOINT, paramMap, callback::accept), "openai-chatVision-sse").start(); + } + @Override public String imagesGenerations(String prompt) { String paramJson = buildImagesGenerationsRequestBody(prompt); @@ -132,15 +137,6 @@ public class OpenaiServiceImpl extends BaseAIService implements OpenaiService { return response.body(); } - @Override - public String chatReasoning(String prompt, String reasoningEffort) { - // 定义消息结构 - final List messages = new ArrayList<>(); - messages.add(new Message("system", "You are a helpful assistant")); - messages.add(new Message("user", prompt)); - return chat(messages); - } - @Override public String chatReasoning(final List messages, String reasoningEffort) { String paramJson = buildChatReasoningRequestBody(messages, reasoningEffort); @@ -148,6 +144,12 @@ public class OpenaiServiceImpl extends BaseAIService implements OpenaiService { return response.body(); } + @Override + public void chatReasoning(List messages, String reasoningEffort, Consumer callback) { + Map paramMap = buildChatReasoningStreamRequestBody(messages, reasoningEffort); + ThreadUtil.newThread(() -> sendPostStream(CHAT_ENDPOINT, paramMap, callback::accept), "openai-chatReasoning-sse").start(); + } + // 构建chat请求体 private String buildChatRequestBody(final List messages) { //使用JSON工具 @@ -160,6 +162,18 @@ public class OpenaiServiceImpl extends BaseAIService implements OpenaiService { return JSONUtil.toJsonStr(paramMap); } + private Map buildChatStreamRequestBody(final List messages) { + //使用JSON工具 + final Map paramMap = new HashMap<>(); + paramMap.put("stream", true); + paramMap.put("model", config.getModel()); + paramMap.put("messages", messages); + //合并其他参数 + paramMap.putAll(config.getAdditionalConfigMap()); + + return paramMap; + } + //构建chatVision请求体 private String buildChatVisionRequestBody(String prompt, final List images, String detail) { // 定义消息结构 @@ -191,6 +205,37 @@ public class OpenaiServiceImpl extends BaseAIService implements OpenaiService { return JSONUtil.toJsonStr(paramMap); } + private Map buildChatVisionStreamRequestBody(String prompt, final List images, String detail) { + // 定义消息结构 + final List messages = new ArrayList<>(); + final List content = new ArrayList<>(); + + final Map contentMap = new HashMap<>(); + contentMap.put("type", "text"); + contentMap.put("text", prompt); + content.add(contentMap); + for (String img : images) { + HashMap imgUrlMap = new HashMap<>(); + imgUrlMap.put("type", "image_url"); + HashMap urlMap = new HashMap<>(); + urlMap.put("url", img); + urlMap.put("detail", detail); + imgUrlMap.put("image_url", urlMap); + content.add(imgUrlMap); + } + + messages.add(new Message("user", content)); + + //使用JSON工具 + final Map paramMap = new HashMap<>(); + paramMap.put("stream", true); + paramMap.put("model", config.getModel()); + paramMap.put("messages", messages); + //合并其他参数 + paramMap.putAll(config.getAdditionalConfigMap()); + return paramMap; + } + //构建文生图请求体 private String buildImagesGenerationsRequestBody(String prompt) { final Map paramMap = new HashMap<>(); @@ -305,4 +350,16 @@ public class OpenaiServiceImpl extends BaseAIService implements OpenaiService { return JSONUtil.toJsonStr(paramMap); } + private Map buildChatReasoningStreamRequestBody(final List messages, String reasoningEffort) { + final Map paramMap = new HashMap<>(); + paramMap.put("stream", true); + paramMap.put("model", config.getModel()); + paramMap.put("messages", messages); + paramMap.put("reasoning_effort", reasoningEffort); + //合并其他参数 + paramMap.putAll(config.getAdditionalConfigMap()); + + return paramMap; + } + } diff --git a/hutool-ai/src/test/java/cn/hutool/ai/model/deepseek/DeepSeekServiceTest.java b/hutool-ai/src/test/java/cn/hutool/ai/model/deepseek/DeepSeekServiceTest.java index 7dd486c08..d658265d3 100644 --- a/hutool-ai/src/test/java/cn/hutool/ai/model/deepseek/DeepSeekServiceTest.java +++ b/hutool-ai/src/test/java/cn/hutool/ai/model/deepseek/DeepSeekServiceTest.java @@ -20,11 +20,13 @@ import cn.hutool.ai.AIServiceFactory; import cn.hutool.ai.ModelName; import cn.hutool.ai.core.AIConfigBuilder; import cn.hutool.ai.core.Message; +import cn.hutool.core.thread.ThreadUtil; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -40,6 +42,29 @@ class DeepSeekServiceTest { assertNotNull(chat); } + @Test + @Disabled + void chatStream() { + String prompt = "写一个疯狂星期四广告词"; + // 使用AtomicBoolean作为结束标志 + AtomicBoolean isDone = new AtomicBoolean(false); + + deepSeekService.chat(prompt, data -> { + assertNotNull(data); + if (data.equals("data: [DONE]")) { + // 设置结束标志 + isDone.set(true); + } else if (data.contains("\"error\"")) { + isDone.set(true); + } + + }); + // 轮询检查结束标志 + while (!isDone.get()) { + ThreadUtil.sleep(100); + } + } + @Test @Disabled void testChat(){ @@ -54,7 +79,31 @@ class DeepSeekServiceTest { @Disabled void beta() { final String beta = deepSeekService.beta("写一个疯狂星期四广告词"); - System.out.println(beta); + assertNotNull(beta); + + } + + @Test + @Disabled + void betaStream() { + String beta = "写一个疯狂星期四广告词"; + // 使用AtomicBoolean作为结束标志 + AtomicBoolean isDone = new AtomicBoolean(false); + + deepSeekService.beta(beta, data -> { + assertNotNull(data); + if (data.equals("data: [DONE]")) { + // 设置结束标志 + isDone.set(true); + } else if (data.contains("\"error\"")) { + isDone.set(true); + } + + }); + // 轮询检查结束标志 + while (!isDone.get()) { + ThreadUtil.sleep(100); + } } @Test diff --git a/hutool-ai/src/test/java/cn/hutool/ai/model/doubao/DoubaoServiceTest.java b/hutool-ai/src/test/java/cn/hutool/ai/model/doubao/DoubaoServiceTest.java index 0aeb9b9b1..90195ee02 100644 --- a/hutool-ai/src/test/java/cn/hutool/ai/model/doubao/DoubaoServiceTest.java +++ b/hutool-ai/src/test/java/cn/hutool/ai/model/doubao/DoubaoServiceTest.java @@ -22,6 +22,7 @@ import cn.hutool.ai.Models; import cn.hutool.ai.core.AIConfigBuilder; import cn.hutool.ai.core.Message; import cn.hutool.core.img.ImgUtil; +import cn.hutool.core.thread.ThreadUtil; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; @@ -29,6 +30,7 @@ import java.awt.*; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -44,6 +46,29 @@ class DoubaoServiceTest { assertNotNull(chat); } + @Test + @Disabled + void chatStream() { + String prompt = "写一个疯狂星期四广告词"; + // 使用AtomicBoolean作为结束标志 + AtomicBoolean isDone = new AtomicBoolean(false); + + doubaoService.chat(prompt, data -> { + assertNotNull(data); + if (data.equals("data: [DONE]")) { + // 设置结束标志 + isDone.set(true); + } else if (data.contains("\"error\"")) { + isDone.set(true); + } + + }); + // 轮询检查结束标志 + while (!isDone.get()) { + ThreadUtil.sleep(100); + } + } + @Test @Disabled void testChat(){ @@ -73,14 +98,41 @@ class DoubaoServiceTest { assertNotNull(chatVision); } + @Test + @Disabled + void testChatVisionStream() { + final DoubaoService doubaoService = AIServiceFactory.getAIService(new AIConfigBuilder(ModelName.DOUBAO.getValue()) + .setApiKey(key).setModel(Models.Doubao.DOUBAO_1_5_VISION_PRO_32K.getModel()).build(), DoubaoService.class); + + String prompt = "图片上有些什么?"; + List images = Arrays.asList("https://img2.baidu.com/it/u=862000265,4064861820&fm=253&fmt=auto&app=138&f=JPEG?w=800&h=1544"); + + // 使用AtomicBoolean作为结束标志 + AtomicBoolean isDone = new AtomicBoolean(false); + doubaoService.chatVision(prompt,images, data -> { + assertNotNull(data); + if (data.equals("data: [DONE]")) { + // 设置结束标志 + isDone.set(true); + } else if (data.contains("\"error\"")) { + isDone.set(true); + } + + }); + // 轮询检查结束标志 + while (!isDone.get()) { + ThreadUtil.sleep(100); + } + } + @Test @Disabled void videoTasks() { final DoubaoService doubaoService = AIServiceFactory.getAIService(new AIConfigBuilder(ModelName.DOUBAO.getValue()) - .setApiKey(key).setModel("your Endpoint ID").build(), DoubaoService.class); + .setApiKey(key).setModel(Models.Doubao.Doubao_Seedance_1_0_lite_i2v.getModel()).build(), DoubaoService.class); final String videoTasks = doubaoService.videoTasks("生成一段动画视频,主角是大耳朵图图,一个活泼可爱的小男孩。视频中图图在公园里玩耍," + "画面采用明亮温暖的卡通风格,色彩鲜艳,动作流畅。背景音乐轻快活泼,带有冒险感,音效包括鸟叫声、欢笑声和山洞回声。", "https://img2.baidu.com/it/u=862000265,4064861820&fm=253&fmt=auto&app=138&f=JPEG?w=800&h=1544"); - assertNotNull(videoTasks);//cgt-20250306170051-6r9gk + assertNotNull(videoTasks); } @Test @@ -123,6 +175,33 @@ class DoubaoServiceTest { assertNotNull(botsChat); } + @Test + @Disabled + void botsChatStream() { + final DoubaoService doubaoService = AIServiceFactory.getAIService(new AIConfigBuilder(ModelName.DOUBAO.getValue()) + .setApiKey(key).setModel("your bots id").build(), DoubaoService.class); + final ArrayList messages = new ArrayList<>(); + messages.add(new Message("system","你是什么都可以")); + messages.add(new Message("user","你想做些什么")); + + // 使用AtomicBoolean作为结束标志 + AtomicBoolean isDone = new AtomicBoolean(false); + doubaoService.botsChat(messages, data -> { + assertNotNull(data); + if (data.equals("data: [DONE]")) { + // 设置结束标志 + isDone.set(true); + } else if (data.contains("\"error\"")) { + isDone.set(true); + } + + }); + // 轮询检查结束标志 + while (!isDone.get()) { + ThreadUtil.sleep(100); + } + } + @Test @Disabled void tokenization() { @@ -166,7 +245,7 @@ class DoubaoServiceTest { @Disabled void testCreateContext() { final DoubaoService doubaoService = AIServiceFactory.getAIService(new AIConfigBuilder(ModelName.DOUBAO.getValue()) - .setApiKey(key).setModel("your Endpoint ID").build(), DoubaoService.class); + .setApiKey(key).setModel("ep-20250305100610-bvbpc").build(), DoubaoService.class); final List messages = new ArrayList<>(); messages.add(new Message("system","你是个抽象大师,你真的很抽象")); final String context = doubaoService.createContext(messages,DoubaoCommon.DoubaoContext.COMMON_PREFIX.getMode()); @@ -178,8 +257,8 @@ class DoubaoServiceTest { void chatContext() { //ctx-20250307092153-cvslm final DoubaoService doubaoService = AIServiceFactory.getAIService(new AIConfigBuilder(ModelName.DOUBAO.getValue()) - .setApiKey(key).setModel("eyour Endpoint ID").build(), DoubaoService.class); - final String chatContext = doubaoService.chatContext("你是谁?", "ctx-20250307092153-cvslm"); + .setApiKey(key).setModel("your Endpoint ID").build(), DoubaoService.class); + final String chatContext = doubaoService.chatContext("你是谁?", "your contextId"); assertNotNull(chatContext); } @@ -190,7 +269,43 @@ class DoubaoServiceTest { .setApiKey(key).setModel("your Endpoint ID").build(), DoubaoService.class); final List messages = new ArrayList<>(); messages.add(new Message("user","你怎么看待意大利面拌水泥?")); - final String chatContext = doubaoService.chatContext(messages, "ctx-20250307092153-cvslm"); + final String chatContext = doubaoService.chatContext(messages, "your contextId"); assertNotNull(chatContext); } + + @Test + @Disabled + void testChatContextStream() { + final DoubaoService doubaoService = AIServiceFactory.getAIService(new AIConfigBuilder(ModelName.DOUBAO.getValue()) + .setApiKey(key).setModel("your Endpoint ID").build(), DoubaoService.class); + final List messages = new ArrayList<>(); + messages.add(new Message("user","你怎么看待意大利面拌水泥?")); + String contextId = "your contextId"; + + // 使用AtomicBoolean作为结束标志 + AtomicBoolean isDone = new AtomicBoolean(false); + doubaoService.chatContext(messages,contextId, data -> { + assertNotNull(data); + if (data.equals("data: [DONE]")) { + // 设置结束标志 + isDone.set(true); + } else if (data.contains("\"error\"")) { + isDone.set(true); + } + + }); + // 轮询检查结束标志 + while (!isDone.get()) { + ThreadUtil.sleep(100); + } + } + + @Test + @Disabled + void imagesGenerations() { + final DoubaoService doubaoService = AIServiceFactory.getAIService(new AIConfigBuilder(ModelName.DOUBAO.getValue()) + .setApiKey(key).setModel(Models.Doubao.DOUBAO_SEEDREAM_3_0_T2I.getModel()).build(), DoubaoService.class); + final String imagesGenerations = doubaoService.imagesGenerations("一位年轻的宇航员站在未来感十足的太空站内,透过巨大的弧形落地窗凝望浩瀚宇宙。窗外,璀璨的星河与五彩斑斓的星云交织,远处隐约可见未知星球的轮廓,仿佛在召唤着探索的脚步。宇航服上的呼吸灯与透明显示屏上的星图交相辉映,象征着人类科技与宇宙奥秘的碰撞。画面深邃而神秘,充满对未知的渴望与无限可能的想象。"); + assertNotNull(imagesGenerations); + } } diff --git a/hutool-ai/src/test/java/cn/hutool/ai/model/grok/GrokServiceTest.java b/hutool-ai/src/test/java/cn/hutool/ai/model/grok/GrokServiceTest.java index bd2aac911..c6df69a25 100644 --- a/hutool-ai/src/test/java/cn/hutool/ai/model/grok/GrokServiceTest.java +++ b/hutool-ai/src/test/java/cn/hutool/ai/model/grok/GrokServiceTest.java @@ -22,6 +22,7 @@ import cn.hutool.ai.Models; import cn.hutool.ai.core.AIConfigBuilder; import cn.hutool.ai.core.Message; import cn.hutool.core.img.ImgUtil; +import cn.hutool.core.thread.ThreadUtil; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; @@ -29,6 +30,7 @@ import java.awt.*; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -45,6 +47,29 @@ class GrokServiceTest { assertNotNull(chat); } + @Test + @Disabled + void chatStream() { + String prompt = "写一个疯狂星期四广告词"; + // 使用AtomicBoolean作为结束标志 + AtomicBoolean isDone = new AtomicBoolean(false); + + grokService.chat(prompt, data -> { + assertNotNull(data); + if (data.equals("data: [DONE]")) { + // 设置结束标志 + isDone.set(true); + } else if (data.contains("\"error\"")) { + isDone.set(true); + } + + }); + // 轮询检查结束标志 + while (!isDone.get()) { + ThreadUtil.sleep(100); + } + } + @Test @Disabled void testChat(){ @@ -62,6 +87,29 @@ class GrokServiceTest { assertNotNull(message); } + @Test + @Disabled + void messageStream() { + String prompt = "给我一个KFC的广告词"; + + // 使用AtomicBoolean作为结束标志 + AtomicBoolean isDone = new AtomicBoolean(false); + grokService.message(prompt, 4096, data -> { + assertNotNull(data); + if (data.equals("data: [DONE]")) { + // 设置结束标志 + isDone.set(true); + } else if (data.contains("\"error\"")) { + isDone.set(true); + } + + }); + // 轮询检查结束标志 + while (!isDone.get()) { + ThreadUtil.sleep(100); + } + } + @Test @Disabled void chatVision() { @@ -71,6 +119,31 @@ class GrokServiceTest { assertNotNull(chatVision); } + @Test + @Disabled + void testChatVisionStream() { + final GrokService grokService = AIServiceFactory.getAIService(new AIConfigBuilder(ModelName.GROK.getValue()).setModel(Models.Grok.GROK_2_VISION_1212.getModel()).setApiKey(key).build(), GrokService.class); + String prompt = "图片上有些什么?"; + List images = Arrays.asList("https://img2.baidu.com/it/u=862000265,4064861820&fm=253&fmt=auto&app=138&f=JPEG?w=800&h=1544"); + + // 使用AtomicBoolean作为结束标志 + AtomicBoolean isDone = new AtomicBoolean(false); + grokService.chatVision(prompt,images, data -> { + assertNotNull(data); + if (data.equals("data: [DONE]")) { + // 设置结束标志 + isDone.set(true); + } else if (data.contains("\"error\"")) { + isDone.set(true); + } + + }); + // 轮询检查结束标志 + while (!isDone.get()) { + ThreadUtil.sleep(100); + } + } + @Test @Disabled void testChatVision() { @@ -120,4 +193,13 @@ class GrokServiceTest { final String deferred = grokService.deferredCompletion(key); assertNotNull(deferred); } + + @Test + @Disabled + void imagesGenerations() { + final GrokService grokService = AIServiceFactory.getAIService(new AIConfigBuilder(ModelName.GROK.getValue()) + .setApiKey(key).setModel(Models.Grok.GROK_2_IMAGE.getModel()).build(), GrokService.class); + final String imagesGenerations = grokService.imagesGenerations("一位年轻的宇航员站在未来感十足的太空站内,透过巨大的弧形落地窗凝望浩瀚宇宙。窗外,璀璨的星河与五彩斑斓的星云交织,远处隐约可见未知星球的轮廓,仿佛在召唤着探索的脚步。宇航服上的呼吸灯与透明显示屏上的星图交相辉映,象征着人类科技与宇宙奥秘的碰撞。画面深邃而神秘,充满对未知的渴望与无限可能的想象。"); + assertNotNull(imagesGenerations); + } } diff --git a/hutool-ai/src/test/java/cn/hutool/ai/model/openai/OpenaiServiceTest.java b/hutool-ai/src/test/java/cn/hutool/ai/model/openai/OpenaiServiceTest.java index 9e5e8f98f..7948ab495 100644 --- a/hutool-ai/src/test/java/cn/hutool/ai/model/openai/OpenaiServiceTest.java +++ b/hutool-ai/src/test/java/cn/hutool/ai/model/openai/OpenaiServiceTest.java @@ -22,6 +22,7 @@ import cn.hutool.ai.Models; import cn.hutool.ai.core.AIConfigBuilder; import cn.hutool.ai.core.Message; import cn.hutool.core.io.FileUtil; +import cn.hutool.core.thread.ThreadUtil; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; @@ -35,6 +36,7 @@ import java.nio.file.Paths; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -51,6 +53,29 @@ class OpenaiServiceTest { assertNotNull(chat); } + @Test + @Disabled + void chatStream() { + String prompt = "写一个疯狂星期四广告词"; + // 使用AtomicBoolean作为结束标志 + AtomicBoolean isDone = new AtomicBoolean(false); + + openaiService.chat(prompt, data -> { + assertNotNull(data); + if (data.equals("data: [DONE]")) { + // 设置结束标志 + isDone.set(true); + } else if (data.contains("\"error\"")) { + isDone.set(true); + } + + }); + // 轮询检查结束标志 + while (!isDone.get()) { + ThreadUtil.sleep(100); + } + } + @Test @Disabled void testChat(){ @@ -70,6 +95,32 @@ class OpenaiServiceTest { assertNotNull(chatVision); } + @Test + @Disabled + void testChatVisionStream() { + final OpenaiService openaiService = AIServiceFactory.getAIService(new AIConfigBuilder(ModelName.OPENAI.getValue()) + .setApiKey(key).setModel(Models.Openai.GPT_4O_MINI.getModel()).build(), OpenaiService.class); + String prompt = "图片上有些什么?"; + List images = Arrays.asList("https://img2.baidu.com/it/u=862000265,4064861820&fm=253&fmt=auto&app=138&f=JPEG?w=800&h=1544\",\"https://img2.baidu.com/it/u=1682510685,1244554634&fm=253&fmt=auto&app=138&f=JPEG?w=803&h=800"); + + // 使用AtomicBoolean作为结束标志 + AtomicBoolean isDone = new AtomicBoolean(false); + openaiService.chatVision(prompt,images, data -> { + assertNotNull(data); + if (data.equals("data: [DONE]")) { + // 设置结束标志 + isDone.set(true); + } else if (data.contains("\"error\"")) { + isDone.set(true); + } + + }); + // 轮询检查结束标志 + while (!isDone.get()) { + ThreadUtil.sleep(100); + } + } + @Test @Disabled void imagesGenerations() { @@ -132,7 +183,6 @@ class OpenaiServiceTest { .setApiKey(key).setModel(Models.Openai.WHISPER_1.getModel()).build(), OpenaiService.class); final File file = FileUtil.file("your filePath"); final String speechToText = openaiService.speechToText(file); - System.out.println(speechToText); assertNotNull(speechToText); } @@ -165,4 +215,31 @@ class OpenaiServiceTest { final String chatReasoning = openaiService.chatReasoning(messages, OpenaiCommon.OpenaiReasoning.HIGH.getEffort()); assertNotNull(chatReasoning); } + + @Test + @Disabled + void chatReasoningStream() { + final OpenaiService openaiService = AIServiceFactory.getAIService(new AIConfigBuilder(ModelName.OPENAI.getValue()) + .setApiKey(key).setModel(Models.Openai.O3_MINI.getModel()).build(), OpenaiService.class); + final List messages = new ArrayList<>(); + messages.add(new Message("system","你是现代抽象家")); + messages.add(new Message("user","给我一个KFC疯狂星期四的文案")); + + // 使用AtomicBoolean作为结束标志 + AtomicBoolean isDone = new AtomicBoolean(false); + openaiService.chatReasoning(messages,OpenaiCommon.OpenaiReasoning.HIGH.getEffort(), data -> { + assertNotNull(data); + if (data.equals("data: [DONE]")) { + // 设置结束标志 + isDone.set(true); + } else if (data.contains("\"error\"")) { + isDone.set(true); + } + + }); + // 轮询检查结束标志 + while (!isDone.get()) { + ThreadUtil.sleep(100); + } + } }