Skip to content

Commit 2b5858e

Browse files
committed
Adapt ChatConfig for Spring AI M5
1 parent f496d71 commit 2b5858e

1 file changed

Lines changed: 73 additions & 7 deletions

File tree

SpringAI/src/main/java/cloud/cleo/squareup/config/ChatConfig.java

Lines changed: 73 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
package cloud.cleo.squareup.config;
22

33
import cloud.cleo.squareup.memory.DynamoDbChatMemoryRepository;
4+
import com.openai.client.OpenAIClient;
5+
import com.openai.client.OpenAIClientAsync;
6+
import io.micrometer.observation.ObservationRegistry;
47
import java.time.Duration;
58
import java.util.Comparator;
69
import java.util.List;
10+
import java.util.Map;
711
import org.springframework.ai.bedrock.converse.BedrockChatOptions;
812
import org.springframework.ai.bedrock.converse.BedrockProxyChatModel;
913
import org.springframework.ai.chat.client.ChatClient;
@@ -20,11 +24,20 @@
2024
import org.springframework.ai.chat.messages.Message;
2125
import org.springframework.ai.chat.messages.MessageType;
2226
import org.springframework.ai.chat.model.ChatModel;
27+
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
28+
import org.springframework.ai.model.openai.autoconfigure.OpenAiAutoConfigurationUtil;
29+
import org.springframework.ai.model.openai.autoconfigure.OpenAiConnectionProperties;
30+
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
31+
import org.springframework.ai.model.tool.ToolCallingManager;
32+
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
33+
import org.springframework.ai.openai.AbstractOpenAiOptions;
2334
import org.springframework.ai.openai.OpenAiChatModel;
2435
import org.springframework.ai.openai.OpenAiChatOptions;
25-
import org.springframework.ai.openai.api.OpenAiApi;
36+
import org.springframework.ai.openai.setup.OpenAiSetup;
2637
import org.springframework.beans.factory.annotation.Qualifier;
2738
import org.springframework.beans.factory.annotation.Value;
39+
import org.springframework.beans.factory.ObjectProvider;
40+
import org.springframework.boot.context.properties.EnableConfigurationProperties;
2841
import org.springframework.context.annotation.Bean;
2942
import org.springframework.context.annotation.Configuration;
3043
import org.springframework.context.annotation.Primary;
@@ -40,6 +53,7 @@
4053
* @author sjensen
4154
*/
4255
@Configuration
56+
@EnableConfigurationProperties(OpenAiConnectionProperties.class)
4357
public class ChatConfig {
4458

4559
/**
@@ -78,8 +92,8 @@ public OpenAiChatOptions openAiChatOptions(@Value("${spring.ai.openai.chat.optio
7892
var builder = OpenAiChatOptions.builder()
7993
.model(model)
8094
.parallelToolCalls(true)
81-
// We want cache to work across different Lambda IPs(AZ) and across regions
82-
.promptCacheKey("cloud-cleo-squareup-spring-ai")
95+
// We want cache to work across different Lambda IPs(AZ) and across regions.
96+
.extraBody(Map.of("prompt_cache_key", "cloud-cleo-squareup-spring-ai"))
8397
.N(1); // We only ever want 1 response
8498

8599
if (model.startsWith("gpt-4")) {
@@ -99,11 +113,63 @@ public OpenAiChatOptions openAiChatOptions(@Value("${spring.ai.openai.chat.optio
99113
}
100114

101115
@Bean(name = "customOpenAiChatModel")
102-
public ChatModel chatModel(OpenAiApi api, OpenAiChatOptions options) {
103-
return OpenAiChatModel.builder()
104-
.openAiApi(api)
105-
.defaultOptions(options)
116+
public ChatModel chatModel(OpenAiConnectionProperties connectionProperties,
117+
OpenAiChatOptions options,
118+
ToolCallingManager toolCallingManager,
119+
ObjectProvider<ObservationRegistry> observationRegistry,
120+
ObjectProvider<ChatModelObservationConvention> observationConvention,
121+
ObjectProvider<ToolExecutionEligibilityPredicate> toolExecutionEligibilityPredicate
122+
) {
123+
AbstractOpenAiOptions resolvedConnectionProperties =
124+
OpenAiAutoConfigurationUtil.resolveConnectionProperties(connectionProperties, options);
125+
126+
OpenAiChatModel chatModel = OpenAiChatModel.builder()
127+
.openAiClient(openAiClient(resolvedConnectionProperties))
128+
.openAiClientAsync(openAiClientAsync(resolvedConnectionProperties))
129+
.options(options)
130+
.toolCallingManager(toolCallingManager)
131+
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
132+
.toolExecutionEligibilityPredicate(toolExecutionEligibilityPredicate.getIfUnique(DefaultToolExecutionEligibilityPredicate::new))
106133
.build();
134+
135+
observationConvention.ifAvailable(chatModel::setObservationConvention);
136+
return chatModel;
137+
}
138+
139+
private static OpenAIClient openAiClient(AbstractOpenAiOptions options) {
140+
return OpenAiSetup.setupSyncClient(
141+
options.getBaseUrl(),
142+
options.getApiKey(),
143+
options.getCredential(),
144+
options.getMicrosoftDeploymentName(),
145+
options.getMicrosoftFoundryServiceVersion(),
146+
options.getOrganizationId(),
147+
options.isMicrosoftFoundry(),
148+
options.isGitHubModels(),
149+
options.getModel(),
150+
options.getTimeout(),
151+
options.getMaxRetries(),
152+
options.getProxy(),
153+
options.getCustomHeaders()
154+
);
155+
}
156+
157+
private static OpenAIClientAsync openAiClientAsync(AbstractOpenAiOptions options) {
158+
return OpenAiSetup.setupAsyncClient(
159+
options.getBaseUrl(),
160+
options.getApiKey(),
161+
options.getCredential(),
162+
options.getMicrosoftDeploymentName(),
163+
options.getMicrosoftFoundryServiceVersion(),
164+
options.getOrganizationId(),
165+
options.isMicrosoftFoundry(),
166+
options.isGitHubModels(),
167+
options.getModel(),
168+
options.getTimeout(),
169+
options.getMaxRetries(),
170+
options.getProxy(),
171+
options.getCustomHeaders()
172+
);
107173
}
108174

109175
@Bean

0 commit comments

Comments
 (0)