A Look at Tool Calling in Spring AI

This article takes a look at Tool Calling in Spring AI.

ToolCallback

org/springframework/ai/tool/ToolCallback.java

public interface ToolCallback extends FunctionCallback {

    /**
     * Definition used by the AI model to determine when and how to call the tool.
     */
    ToolDefinition getToolDefinition();

    /**
     * Metadata providing additional information on how to handle the tool.
     */
    default ToolMetadata getToolMetadata() {
        return ToolMetadata.builder().build();
    }

    /**
     * Execute tool with the given input and return the result to send back to the AI
     * model.
     */
    String call(String toolInput);

    /**
     * Execute tool with the given input and context, and return the result to send back
     * to the AI model.
     */
    default String call(String toolInput, @Nullable ToolContext tooContext) {
        if (tooContext != null && !tooContext.getContext().isEmpty()) {
            throw new UnsupportedOperationException("Tool context is not supported!");
        }
        return call(toolInput);
    }

    @Override
    @Deprecated // Call getToolDefinition().name() instead
    default String getName() {
        return getToolDefinition().name();
    }

    @Override
    @Deprecated // Call getToolDefinition().description() instead
    default String getDescription() {
        return getToolDefinition().description();
    }

    @Override
    @Deprecated // Call getToolDefinition().inputTypeSchema() instead
    default String getInputTypeSchema() {
        return getToolDefinition().inputSchema();
    }

}

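ToolCallback extends the FunctionCallback interface, although FunctionCallback is slated for deprecation. ToolCallback mainly defines the getToolDefinition, getToolMetadata, and call methods. Its default two-argument call throws UnsupportedOperationException whenever a non-empty ToolContext is passed, so a tool that needs context must override it. There are two basic implementations: MethodToolCallback and FunctionToolCallback.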
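To make the behaviour of the default `call(toolInput, toolContext)` concrete, here is a JDK-only model of that contract. `SimpleToolCallback`, `SimpleToolContext` and `UpperCaseTool` are hypothetical stand-ins, not Spring AI types: a callback that only implements the one-argument `call` accepts a null or empty context but rejects a non-empty one, mirroring the interface's default method.

```java
import java.util.Map;

// Hypothetical stand-in for Spring AI's ToolContext: an immutable map.
record SimpleToolContext(Map<String, Object> context) {
    SimpleToolContext {
        context = Map.copyOf(context);
    }
}

// Hypothetical, simplified mirror of the ToolCallback contract.
interface SimpleToolCallback {

    String name();

    // Tools that do not need context only implement this method.
    String call(String toolInput);

    // Mirrors the default method: a non-empty context is rejected unless overridden.
    default String call(String toolInput, SimpleToolContext toolContext) {
        if (toolContext != null && !toolContext.context().isEmpty()) {
            throw new UnsupportedOperationException("Tool context is not supported!");
        }
        return call(toolInput);
    }
}

// A context-free tool: echoes its input in upper case.
class UpperCaseTool implements SimpleToolCallback {

    @Override
    public String name() {
        return "upperCase";
    }

    @Override
    public String call(String toolInput) {
        return toolInput.toUpperCase();
    }
}
```

A tool that actually reads the context would override the two-argument method instead of relying on the default.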

MethodToolCallback

org/springframework/ai/tool/method/MethodToolCallback.java

public class MethodToolCallback implements ToolCallback {

    private static final Logger logger = LoggerFactory.getLogger(MethodToolCallback.class);

    private static final ToolCallResultConverter DEFAULT_RESULT_CONVERTER = new DefaultToolCallResultConverter();

    private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build();

    private final ToolDefinition toolDefinition;

    private final ToolMetadata toolMetadata;

    private final Method toolMethod;

    @Nullable
    private final Object toolObject;

    private final ToolCallResultConverter toolCallResultConverter;

    public MethodToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Method toolMethod,
            @Nullable Object toolObject, @Nullable ToolCallResultConverter toolCallResultConverter) {
        Assert.notNull(toolDefinition, "toolDefinition cannot be null");
        Assert.notNull(toolMethod, "toolMethod cannot be null");
        Assert.isTrue(Modifier.isStatic(toolMethod.getModifiers()) || toolObject != null,
                "toolObject cannot be null for non-static methods");
        this.toolDefinition = toolDefinition;
        this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULT_TOOL_METADATA;
        this.toolMethod = toolMethod;
        this.toolObject = toolObject;
        this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter
                : DEFAULT_RESULT_CONVERTER;
    }

    @Override
    public ToolDefinition getToolDefinition() {
        return toolDefinition;
    }

    @Override
    public ToolMetadata getToolMetadata() {
        return toolMetadata;
    }

    @Override
    public String call(String toolInput) {
        return call(toolInput, null);
    }

    @Override
    public String call(String toolInput, @Nullable ToolContext toolContext) {
        Assert.hasText(toolInput, "toolInput cannot be null or empty");

        logger.debug("Starting execution of tool: {}", toolDefinition.name());

        validateToolContextSupport(toolContext);

        Map<String, Object> toolArguments = extractToolArguments(toolInput);

        Object[] methodArguments = buildMethodArguments(toolArguments, toolContext);

        Object result = callMethod(methodArguments);

        logger.debug("Successful execution of tool: {}", toolDefinition.name());

        Type returnType = toolMethod.getGenericReturnType();

        return toolCallResultConverter.convert(result, returnType);
    }

    @Nullable
    private Object callMethod(Object[] methodArguments) {
        if (isObjectNotPublic() || isMethodNotPublic()) {
            toolMethod.setAccessible(true);
        }

        Object result;
        try {
            result = toolMethod.invoke(toolObject, methodArguments);
        }
        catch (IllegalAccessException ex) {
            throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex);
        }
        catch (InvocationTargetException ex) {
            throw new ToolExecutionException(toolDefinition, ex.getCause());
        }
        return result;
    }

    //......
}   

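MethodToolCallback implements the ToolCallback interface. Its call method builds the arguments with buildMethodArguments, invokes the method through callMethod, and finally converts the return value with toolCallResultConverter.convert. callMethod performs the invocation via reflection, calling setAccessible(true) first when the tool object or method is not public.
The following types are currently not supported as parameter or return types: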

  • Optional
  • Asynchronous types (e.g. CompletableFuture, Future)
  • Reactive types (e.g. Flow, Mono, Flux)
  • Functional types (e.g. Function, Supplier, Consumer).
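The reflective part of callMethod can be reproduced with plain JDK reflection. The sketch below is a simplification under stated assumptions: `ReflectiveInvoker`, `ToolInvocationException` and `GreetingTools` are hypothetical names, arguments are passed positionally instead of being parsed from JSON, and no result conversion happens. It keeps the two details visible in the source: `setAccessible(true)` for non-public targets, and unwrapping `InvocationTargetException` so the tool's own exception becomes the cause.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;

// Hypothetical analogue of ToolExecutionException: carries the tool name and the real cause.
class ToolInvocationException extends RuntimeException {
    ToolInvocationException(String toolName, Throwable cause) {
        super("Tool '" + toolName + "' failed: " + cause.getMessage(), cause);
    }
}

final class ReflectiveInvoker {

    private ReflectiveInvoker() {
    }

    // Looks up a declared method, turning the checked exception into an unchecked one.
    static Method findMethod(Class<?> type, String name, Class<?>... parameterTypes) {
        try {
            return type.getDeclaredMethod(name, parameterTypes);
        }
        catch (NoSuchMethodException ex) {
            throw new IllegalArgumentException("No such method: " + name, ex);
        }
    }

    // Invokes toolMethod on toolObject (null for static methods), mirroring callMethod.
    static Object invoke(Method toolMethod, Object toolObject, Object... args) {
        boolean objectNotPublic = toolObject != null
                && !Modifier.isPublic(toolObject.getClass().getModifiers());
        if (objectNotPublic || !Modifier.isPublic(toolMethod.getModifiers())) {
            toolMethod.setAccessible(true);
        }
        try {
            return toolMethod.invoke(toolObject, args);
        }
        catch (IllegalAccessException ex) {
            throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex);
        }
        catch (InvocationTargetException ex) {
            // Reflection wraps the tool's own exception; unwrap it as the cause.
            throw new ToolInvocationException(toolMethod.getName(), ex.getCause());
        }
    }
}

// A package-private tool class, like DateTimeTools in the examples below.
class GreetingTools {
    String greet(String name) {
        return "Hello, " + name;
    }

    String fail(String reason) {
        throw new IllegalArgumentException(reason);
    }
}
```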

FunctionToolCallback

org/springframework/ai/tool/function/FunctionToolCallback.java

public class FunctionToolCallback<I, O> implements ToolCallback {

    private static final Logger logger = LoggerFactory.getLogger(FunctionToolCallback.class);

    private static final ToolCallResultConverter DEFAULT_RESULT_CONVERTER = new DefaultToolCallResultConverter();

    private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build();

    private final ToolDefinition toolDefinition;

    private final ToolMetadata toolMetadata;

    private final Type toolInputType;

    private final BiFunction<I, ToolContext, O> toolFunction;

    private final ToolCallResultConverter toolCallResultConverter;

    public FunctionToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Type toolInputType,
            BiFunction<I, ToolContext, O> toolFunction, @Nullable ToolCallResultConverter toolCallResultConverter) {
        Assert.notNull(toolDefinition, "toolDefinition cannot be null");
        Assert.notNull(toolInputType, "toolInputType cannot be null");
        Assert.notNull(toolFunction, "toolFunction cannot be null");
        this.toolDefinition = toolDefinition;
        this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULT_TOOL_METADATA;
        this.toolFunction = toolFunction;
        this.toolInputType = toolInputType;
        this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter
                : DEFAULT_RESULT_CONVERTER;
    }

    @Override
    public ToolDefinition getToolDefinition() {
        return toolDefinition;
    }

    @Override
    public ToolMetadata getToolMetadata() {
        return toolMetadata;
    }

    @Override
    public String call(String toolInput) {
        return call(toolInput, null);
    }

    @Override
    public String call(String toolInput, @Nullable ToolContext toolContext) {
        Assert.hasText(toolInput, "toolInput cannot be null or empty");

        logger.debug("Starting execution of tool: {}", toolDefinition.name());

        I request = JsonParser.fromJson(toolInput, toolInputType);
        O response = toolFunction.apply(request, toolContext);

        logger.debug("Successful execution of tool: {}", toolDefinition.name());

        return toolCallResultConverter.convert(response, null);
    }

    @Override
    public String toString() {
        return "FunctionToolCallback{" + "toolDefinition=" + toolDefinition + ", toolMetadata=" + toolMetadata + '}';
    }

    //......
}   

FunctionToolCallback implements the ToolCallback interface. Its call method converts the request with JsonParser.fromJson(toolInput, toolInputType), obtains the response via toolFunction.apply(request, toolContext), and finally converts the result with toolCallResultConverter.convert(response, null).
The following types are currently not supported as input or output types:

  • Primitive types
  • Optional
  • Collection types (e.g. List, Map, Array, Set)
  • Asynchronous types (e.g. CompletableFuture, Future)
  • Reactive types (e.g. Flow, Mono, Flux).

Example

import java.time.LocalDateTime;
import org.springframework.context.i18n.LocaleContextHolder;

class DateTimeTools {

    String getCurrentDateTime() {
        return LocalDateTime.now().atZone(LocaleContextHolder.getTimeZone().toZoneId()).toString();
    }

}

MethodToolCallback

Method method = ReflectionUtils.findMethod(DateTimeTools.class, "getCurrentDateTime");
ToolCallback toolCallback = MethodToolCallback.builder()
    .toolDefinition(ToolDefinition.builder(method)
            .description("Get the current date and time in the user's timezone")
            .build())
    .toolMethod(method)
    .toolObject(new DateTimeTools())
    .build();

Alternatively, annotate the method with @Tool:

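import java.time.LocalDateTime;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.context.i18n.LocaleContextHolder;

class DateTimeTools {

    @Tool(description = "Get the current date and time in the user's timezone")
    String getCurrentDateTime() {
        return LocalDateTime.now().atZone(LocaleContextHolder.getTimeZone().toZoneId()).toString();
    }

}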

Or generate the callbacks from an object's @Tool methods with ToolCallbacks.from:

ToolCallback[] dateTimeTools = ToolCallbacks.from(new DateTimeTools());

FunctionToolCallback

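import java.util.function.Function;

enum Unit { C, F }
record WeatherRequest(String location, Unit unit) {}
record WeatherResponse(double temp, Unit unit) {}

public class WeatherService implements Function<WeatherRequest, WeatherResponse> {
    public WeatherResponse apply(WeatherRequest request) {
        return new WeatherResponse(30.0, Unit.C);
    }
}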

ToolCallback toolCallback = FunctionToolCallback
    .builder("currentWeather", new WeatherService())
    .description("Get the weather in location")
    .inputType(WeatherRequest.class)
    .build();

ChatClient.create(chatModel)
    .prompt("What's the weather like in Copenhagen?")
    .tools(toolCallback)
    .call()
    .content();    

Or set it on the chat options:

ChatOptions chatOptions = ToolCallingChatOptions.builder()
    .toolCallbacks(toolCallback)
    .build();
Prompt prompt = new Prompt("What's the weather like in Copenhagen?", chatOptions);
chatModel.call(prompt);

Or register the function as a Spring bean and refer to it by bean name:

@Configuration(proxyBeanMethods = false)
class WeatherTools {

    WeatherService weatherService = new WeatherService();

    @Bean
    @Description("Get the weather in location")
    Function<WeatherRequest, WeatherResponse> currentWeather() {
        return weatherService;
    }

}

ChatClient.create(chatModel)
    .prompt("What's the weather like in Copenhagen?")
    .tools("currentWeather")
    .call()
    .content();

Tool Specification

ToolDefinition

org/springframework/ai/tool/definition/ToolDefinition.java

public interface ToolDefinition {

    /**
     * The tool name. Unique within the tool set provided to a model.
     */
    String name();

    /**
     * The tool description, used by the AI model to determine what the tool does.
     */
    String description();

    /**
     * The schema of the parameters used to call the tool.
     */
    String inputSchema();

    /**
     * Create a default {@link ToolDefinition} builder.
     */
    static DefaultToolDefinition.Builder builder() {
        return DefaultToolDefinition.builder();
    }

    /**
     * Create a default {@link ToolDefinition} builder from a {@link Method}.
     */
    static DefaultToolDefinition.Builder builder(Method method) {
        Assert.notNull(method, "method cannot be null");
        return DefaultToolDefinition.builder()
            .name(ToolUtils.getToolName(method))
            .description(ToolUtils.getToolDescription(method))
            .inputSchema(JsonSchemaGenerator.generateForMethodInput(method));
    }

    /**
     * Create a default {@link ToolDefinition} instance from a {@link Method}.
     */
    static ToolDefinition from(Method method) {
        return ToolDefinition.builder(method).build();
    }

}

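ToolDefinition defines the name, description, and inputSchema properties. It provides builder methods, including one that builds a DefaultToolDefinition from a Method, deriving the name, description, and input JSON schema from it.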

Example

Method method = ReflectionUtils.findMethod(DateTimeTools.class, "getCurrentDateTime");
ToolDefinition toolDefinition = ToolDefinition.builder(method)
    .name("currentDateTime")
    .description("Get the current date and time in the user's timezone")
    .inputSchema(JsonSchemaGenerator.generateForMethodInput(method))
    .build();
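`builder(Method)` derives all three fields from the method itself. The JDK-only sketch below imitates that for the name and description, using a hypothetical `@ToolInfo` annotation in place of `@Tool`. Falling back to the method name when no explicit name is given is an assumption about `ToolUtils.getToolName`, as is reusing the name when the description is empty; the schema is a fixed placeholder because real schema generation is the job of JsonSchemaGenerator.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

// Hypothetical annotation standing in for Spring AI's @Tool.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@interface ToolInfo {
    String name() default "";

    String description() default "";
}

// Minimal mirror of the name/description/inputSchema triple.
record SimpleToolDefinition(String name, String description, String inputSchema) {

    // Derives a definition from a method, preferring annotation values when present.
    static SimpleToolDefinition from(Method method) {
        ToolInfo info = method.getAnnotation(ToolInfo.class);
        String name = (info != null && !info.name().isBlank()) ? info.name() : method.getName();
        String description = (info != null && !info.description().isBlank())
                ? info.description() : name; // assumed fallback: reuse the name
        return new SimpleToolDefinition(name, description, "{\"type\":\"object\"}");
    }
}

class ClockTools {

    @ToolInfo(description = "Get the current date and time")
    String getCurrentDateTime() {
        return java.time.LocalDateTime.now().toString();
    }

    @ToolInfo(name = "currentZone", description = "Get the default time zone")
    String zone() {
        return java.time.ZoneId.systemDefault().getId();
    }
}
```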

JSON Schema

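Spring AI provides JsonSchemaGenerator to generate the JSON schema for the input parameters of a given method or function. Parameter descriptions can be supplied with any of the following annotations: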

@ToolParam(description = "…") from Spring AI
@JsonClassDescription("…") from Jackson
@JsonPropertyDescription("…") from Jackson
@Schema(description = "…") from Swagger

Example

import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.context.i18n.LocaleContextHolder;

class DateTimeTools {

    @Tool(description = "Set a user alarm for the given time")
    void setAlarm(@ToolParam(description = "Time in ISO-8601 format") String time) {
        LocalDateTime alarmTime = LocalDateTime.parse(time, DateTimeFormatter.ISO_DATE_TIME);
        System.out.println("Alarm set for " + alarmTime);
    }

}

Whether a parameter is required can be controlled with the following annotations:

@ToolParam(required = false) from Spring AI
@JsonProperty(required = false) from Jackson
@Schema(required = false) from Swagger
@Nullable from Spring Framework.

Example:

class CustomerTools {

    @Tool(description = "Update customer information")
    void updateCustomerInfo(Long id, String name, @ToolParam(required = false) String email) {
        System.out.println("Updated info for customer with id: " + id);
    }

}
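To see how parameter annotations can turn into a schema, here is a deliberately tiny generator. It is not JsonSchemaGenerator: `@Arg` is a hypothetical annotation that also carries the parameter name (so the sketch does not depend on compiling with `-parameters`), only a handful of Java types are mapped, and descriptions are not escaped. It illustrates the idea that an optional parameter still appears under `properties` but is left out of the `required` array.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.List;

// Hypothetical parameter annotation, loosely modelled on @ToolParam.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
@interface Arg {
    String name();

    String description() default "";

    boolean required() default true;
}

final class TinySchema {

    private TinySchema() {
    }

    // Maps a few Java types to JSON Schema type names.
    private static String jsonType(Class<?> type) {
        if (type == String.class) return "string";
        if (type == Long.class || type == long.class || type == Integer.class || type == int.class) return "integer";
        if (type == Boolean.class || type == boolean.class) return "boolean";
        return "object";
    }

    // Builds {"type":"object","properties":{...},"required":[...]} for the method's parameters.
    static String forMethod(Method method) {
        StringBuilder props = new StringBuilder();
        List<String> required = new ArrayList<>();
        for (Parameter p : method.getParameters()) {
            Arg arg = p.getAnnotation(Arg.class);
            String name = arg != null ? arg.name() : p.getName();
            if (props.length() > 0) props.append(',');
            props.append('"').append(name).append("\":{\"type\":\"").append(jsonType(p.getType())).append('"');
            if (arg != null && !arg.description().isEmpty()) {
                props.append(",\"description\":\"").append(arg.description()).append('"');
            }
            props.append('}');
            if (arg == null || arg.required()) required.add('"' + name + '"');
        }
        return "{\"type\":\"object\",\"properties\":{" + props + "},\"required\":[" + String.join(",", required) + "]}";
    }
}

class CustomerToolsSketch {
    void updateCustomerInfo(@Arg(name = "id") Long id,
                            @Arg(name = "name") String name,
                            @Arg(name = "email", description = "Contact email", required = false) String email) {
    }
}
```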

Result Conversion

Spring AI provides ToolCallResultConverter to convert the data returned by a tool call before it is sent back to the AI model.
org/springframework/ai/tool/execution/ToolCallResultConverter.java

@FunctionalInterface
public interface ToolCallResultConverter {

    /**
     * Given an Object returned by a tool, convert it to a String compatible with the
     * given class type.
     */
    String convert(@Nullable Object result, @Nullable Type returnType);

}

It has a default implementation, DefaultToolCallResultConverter:

public final class DefaultToolCallResultConverter implements ToolCallResultConverter {

    private static final Logger logger = LoggerFactory.getLogger(DefaultToolCallResultConverter.class);

    @Override
    public String convert(@Nullable Object result, @Nullable Type returnType) {
        if (returnType == Void.TYPE) {
            logger.debug("The tool has no return type. Converting to conventional response.");
            return "Done";
        }
        else {
            logger.debug("Converting tool result to JSON.");
            return JsonParser.toJson(result);
        }
    }

}

DefaultToolCallResultConverter returns "Done" when the return type is void; otherwise it serializes the result to a JSON string with JsonParser.toJson(result).
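The branching in DefaultToolCallResultConverter is easy to reproduce. In the sketch below, `ResultConverterSketch` and its `toJson` are hypothetical: a minimal serializer covering only null, strings, numbers, booleans and maps, whereas Spring AI delegates to Jackson via `JsonParser.toJson`. One consequence visible in the source code: FunctionToolCallback passes `null` as the return type, which never equals `Void.TYPE`, so function-based tools always take the JSON branch.

```java
import java.lang.reflect.Type;
import java.util.Map;
import java.util.stream.Collectors;

final class ResultConverterSketch {

    private ResultConverterSketch() {
    }

    // Mirrors the default converter: void methods yield "Done", everything else becomes JSON.
    static String convert(Object result, Type returnType) {
        if (returnType == Void.TYPE) {
            return "Done";
        }
        return toJson(result);
    }

    // Hypothetical minimal serializer; a real implementation would use Jackson.
    static String toJson(Object value) {
        if (value == null) {
            return "null";
        }
        if (value instanceof String s) {
            return '"' + s.replace("\\", "\\\\").replace("\"", "\\\"") + '"';
        }
        if (value instanceof Number || value instanceof Boolean) {
            return value.toString();
        }
        if (value instanceof Map<?, ?> map) {
            return map.entrySet().stream()
                    .map(e -> toJson(String.valueOf(e.getKey())) + ":" + toJson(e.getValue()))
                    .collect(Collectors.joining(",", "{", "}"));
        }
        // Fallback for other objects: quote their toString().
        return toJson(value.toString());
    }
}
```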

You can also specify your own converter, for example:

class CustomerTools {

    @Tool(description = "Retrieve customer information", resultConverter = CustomToolCallResultConverter.class)
    Customer getCustomerInfo(Long id) {
        return customerRepository.findById(id);
    }

}
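The example above references `CustomToolCallResultConverter` without defining it. A hypothetical version might render results as a short sentence instead of JSON. To keep the sketch runnable without Spring AI on the classpath, it implements a local `ResultConverter` interface with the same `convert(Object, Type)` shape as ToolCallResultConverter; in a real project you would implement the Spring AI interface directly. The `Customer` record is also hypothetical. Because the annotation takes a class rather than an instance, the converter presumably needs a no-argument constructor (an assumption).

```java
import java.lang.reflect.Type;

// Local copy of the ToolCallResultConverter shape, so the sketch compiles on its own.
@FunctionalInterface
interface ResultConverter {
    String convert(Object result, Type returnType);
}

// Hypothetical domain type returned by getCustomerInfo.
record Customer(Long id, String name, String email) {
}

// Renders customers as a short sentence instead of JSON.
class CustomToolCallResultConverter implements ResultConverter {

    @Override
    public String convert(Object result, Type returnType) {
        if (result == null) {
            return "No customer found.";
        }
        if (result instanceof Customer c) {
            return "Customer " + c.id() + ": " + c.name() + " <" + c.email() + ">";
        }
        return String.valueOf(result);
    }
}
```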

Tool Context

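Spring AI provides ToolContext for passing additional contextual information to a tool. This lets developers supply extra, user-provided data that can be used during tool execution alongside the tool arguments passed by the AI model. An example: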

class CustomerTools {

    @Tool(description = "Retrieve customer information")
    Customer getCustomerInfo(Long id, ToolContext toolContext) {
        return customerRepository.findById(id, toolContext.getContext().get("tenantId"));
    }

}

With ChatClient:

ChatModel chatModel = ...

String response = ChatClient.create(chatModel)
        .prompt("Tell me more about the customer with ID 42")
        .tools(new CustomerTools())
        .toolContext(Map.of("tenantId", "acme"))
        .call()
        .content();

System.out.println(response);

With ChatModel:

ChatModel chatModel = ...
ToolCallback[] customerTools = ToolCallbacks.from(new CustomerTools());
ChatOptions chatOptions = ToolCallingChatOptions.builder()
    .toolCallbacks(customerTools)
    .toolContext(Map.of("tenantId", "acme"))
    .build();
Prompt prompt = new Prompt("Tell me more about the customer with ID 42", chatOptions);
chatModel.call(prompt);
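Conceptually, the context is an extra, read-only map that travels alongside the model-generated arguments and is supplied by the application rather than the model. The JDK-only sketch below shows a tenant-scoped lookup; `TenantContext`, `CustomerDirectory` and the in-memory data are hypothetical, chosen only to mirror the `tenantId` example above.

```java
import java.util.Map;

// Hypothetical read-only context supplied by the application, not by the model.
record TenantContext(Map<String, Object> values) {
    TenantContext {
        values = Map.copyOf(values);
    }

    Object get(String key) {
        return values.get(key);
    }
}

// In-memory stand-in for a tenant-aware customer repository.
class CustomerDirectory {

    private final Map<String, Map<Long, String>> customersByTenant = Map.of(
            "acme", Map.of(42L, "Ada Lovelace"),
            "globex", Map.of(42L, "Grace Hopper"));

    // The model supplies id; the application supplies tenantId through the context.
    String getCustomerInfo(Long id, TenantContext toolContext) {
        Object tenantId = toolContext.get("tenantId");
        Map<Long, String> customers = customersByTenant.getOrDefault(String.valueOf(tenantId), Map.of());
        return customers.getOrDefault(id, "unknown customer");
    }
}
```

The same customer id resolves to different records depending on the tenant, which is exactly the kind of data you would not want the model to choose.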

Return Direct

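Spring AI provides the returnDirect parameter. When it is set to true, the result of the tool call is returned directly to the caller instead of being passed back through the model. By default, the result is sent to the AI model, which processes it before responding to the user.
For example: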

class CustomerTools {

    @Tool(description = "Retrieve customer information", returnDirect = true)
    Customer getCustomerInfo(Long id) {
        return customerRepository.findById(id);
    }

}

Or set it through ToolMetadata:

ToolMetadata toolMetadata = ToolMetadata.builder()
    .returnDirect(true)
    .build();

ToolCallingManager

org/springframework/ai/model/tool/ToolCallingManager.java

public interface ToolCallingManager {

    /**
     * Resolve the tool definitions from the model's tool calling options.
     */
    List<ToolDefinition> resolveToolDefinitions(ToolCallingChatOptions chatOptions);

    /**
     * Execute the tool calls requested by the model.
     */
    ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse);

    /**
     * Create a default {@link ToolCallingManager} builder.
     */
    static DefaultToolCallingManager.Builder builder() {
        return DefaultToolCallingManager.builder();
    }

}

ToolCallingManager defines the resolveToolDefinitions and executeToolCalls methods; its default implementation is DefaultToolCallingManager.

DefaultToolCallingManager

org/springframework/ai/model/tool/DefaultToolCallingManager.java

public class DefaultToolCallingManager implements ToolCallingManager {

    private static final Logger logger = LoggerFactory.getLogger(DefaultToolCallingManager.class);

    // @formatter:off

    private static final ObservationRegistry DEFAULT_OBSERVATION_REGISTRY
            = ObservationRegistry.NOOP;

    private static final ToolCallbackResolver DEFAULT_TOOL_CALLBACK_RESOLVER
            = new DelegatingToolCallbackResolver(List.of());

    private static final ToolExecutionExceptionProcessor DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR
            = DefaultToolExecutionExceptionProcessor.builder().build();

    // @formatter:on

    private final ObservationRegistry observationRegistry;

    private final ToolCallbackResolver toolCallbackResolver;

    private final ToolExecutionExceptionProcessor toolExecutionExceptionProcessor;

    public DefaultToolCallingManager(ObservationRegistry observationRegistry, ToolCallbackResolver toolCallbackResolver,
            ToolExecutionExceptionProcessor toolExecutionExceptionProcessor) {
        Assert.notNull(observationRegistry, "observationRegistry cannot be null");
        Assert.notNull(toolCallbackResolver, "toolCallbackResolver cannot be null");
        Assert.notNull(toolExecutionExceptionProcessor, "toolCallExceptionConverter cannot be null");

        this.observationRegistry = observationRegistry;
        this.toolCallbackResolver = toolCallbackResolver;
        this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor;
    }

    @Override
    public List<ToolDefinition> resolveToolDefinitions(ToolCallingChatOptions chatOptions) {
        Assert.notNull(chatOptions, "chatOptions cannot be null");

        List<FunctionCallback> toolCallbacks = new ArrayList<>(chatOptions.getToolCallbacks());
        for (String toolName : chatOptions.getToolNames()) {
            // Skip the tool if it is already present in the request toolCallbacks.
            // That might happen if a tool is defined in the options
            // both as a ToolCallback and as a tool name.
            if (chatOptions.getToolCallbacks().stream().anyMatch(tool -> tool.getName().equals(toolName))) {
                continue;
            }
            FunctionCallback toolCallback = toolCallbackResolver.resolve(toolName);
            if (toolCallback == null) {
                throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
            }
            toolCallbacks.add(toolCallback);
        }

        return toolCallbacks.stream().map(functionCallback -> {
            if (functionCallback instanceof ToolCallback toolCallback) {
                return toolCallback.getToolDefinition();
            }
            else {
                return ToolDefinition.builder()
                    .name(functionCallback.getName())
                    .description(functionCallback.getDescription())
                    .inputSchema(functionCallback.getInputTypeSchema())
                    .build();
            }
        }).toList();
    }

    @Override
    public ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse) {
        Assert.notNull(prompt, "prompt cannot be null");
        Assert.notNull(chatResponse, "chatResponse cannot be null");

        Optional<Generation> toolCallGeneration = chatResponse.getResults()
            .stream()
            .filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls()))
            .findFirst();

        if (toolCallGeneration.isEmpty()) {
            throw new IllegalStateException("No tool call requested by the chat model");
        }

        AssistantMessage assistantMessage = toolCallGeneration.get().getOutput();

        ToolContext toolContext = buildToolContext(prompt, assistantMessage);

        InternalToolExecutionResult internalToolExecutionResult = executeToolCall(prompt, assistantMessage,
                toolContext);

        List<Message> conversationHistory = buildConversationHistoryAfterToolExecution(prompt.getInstructions(),
                assistantMessage, internalToolExecutionResult.toolResponseMessage());

        return ToolExecutionResult.builder()
            .conversationHistory(conversationHistory)
            .returnDirect(internalToolExecutionResult.returnDirect())
            .build();
    }

    //......

    /**
     * Execute the tool call and return the response message. To ensure backward
     * compatibility, both {@link ToolCallback} and {@link FunctionCallback} are
     * supported.
     */
    private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMessage assistantMessage,
            ToolContext toolContext) {
        List<FunctionCallback> toolCallbacks = List.of();
        if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
            toolCallbacks = toolCallingChatOptions.getToolCallbacks();
        }
        else if (prompt.getOptions() instanceof FunctionCallingOptions functionOptions) {
            toolCallbacks = functionOptions.getFunctionCallbacks();
        }

        List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();

        Boolean returnDirect = null;

        for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {

            logger.debug("Executing tool call: {}", toolCall.name());

            String toolName = toolCall.name();
            String toolInputArguments = toolCall.arguments();

            FunctionCallback toolCallback = toolCallbacks.stream()
                .filter(tool -> toolName.equals(tool.getName()))
                .findFirst()
                .orElseGet(() -> toolCallbackResolver.resolve(toolName));

            if (toolCallback == null) {
                throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
            }

            if (returnDirect == null && toolCallback instanceof ToolCallback callback) {
                returnDirect = callback.getToolMetadata().returnDirect();
            }
            else if (toolCallback instanceof ToolCallback callback) {
                returnDirect = returnDirect && callback.getToolMetadata().returnDirect();
            }
            else if (returnDirect == null) {
                // This is a temporary solution to ensure backward compatibility with
                // FunctionCallback.
                // TODO: remove this block when FunctionCallback is removed.
                returnDirect = false;
            }

            String toolResult;
            try {
                toolResult = toolCallback.call(toolInputArguments, toolContext);
            }
            catch (ToolExecutionException ex) {
                toolResult = toolExecutionExceptionProcessor.process(ex);
            }

            toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolName, toolResult));
        }

        return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()), returnDirect);
    }

    private List<Message> buildConversationHistoryAfterToolExecution(List<Message> previousMessages,
            AssistantMessage assistantMessage, ToolResponseMessage toolResponseMessage) {
        List<Message> messages = new ArrayList<>(previousMessages);
        messages.add(assistantMessage);
        messages.add(toolResponseMessage);
        return messages;
    }   
}   

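DefaultToolCallingManager's resolveToolDefinitions method takes the callbacks from chatOptions.getToolCallbacks(), resolves each name in chatOptions.getToolNames() through toolCallbackResolver (skipping names already present as callbacks), and maps every callback to a ToolDefinition. The executeToolCalls method first picks the first generation whose assistantMessage contains tool calls, builds the toolContext, runs executeToolCall to obtain the execution result, and then builds the conversationHistory from it.
executeToolCall iterates over assistantMessage.getToolCalls(). For each call it looks up the callback by name among the prompt options' callbacks, falling back to toolCallbackResolver.resolve(toolName), and obtains the result via toolCallback.call(toolInputArguments, toolContext). If a ToolExecutionException is thrown, toolExecutionExceptionProcessor.process(ex) handles it as a fallback. The returnDirect flag of the overall result is true only if every tool called in that turn has returnDirect set.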
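Two details of executeToolCall are easy to miss: callbacks registered in the request options take precedence over the resolver, and returnDirect is true only if every tool called in that turn asked for it. The sketch below isolates that logic with hypothetical types (`ToolEntry`, `ToolCallPlanner`, and a `Function<String, ToolEntry>` standing in for ToolCallbackResolver).

```java
import java.util.List;
import java.util.function.Function;

// Hypothetical pairing of a tool name with its returnDirect metadata.
record ToolEntry(String name, boolean returnDirect) {
}

final class ToolCallPlanner {

    private ToolCallPlanner() {
    }

    // Options-provided callbacks are checked first; the resolver is only a fallback.
    static ToolEntry lookup(String toolName, List<ToolEntry> optionTools, Function<String, ToolEntry> resolver) {
        ToolEntry entry = optionTools.stream()
                .filter(tool -> tool.name().equals(toolName))
                .findFirst()
                .orElseGet(() -> resolver.apply(toolName));
        if (entry == null) {
            throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
        }
        return entry;
    }

    // Logical AND over all tools requested in one assistant message
    // (assumes at least one tool call, as executeToolCalls only runs when tools were requested).
    static boolean returnDirect(List<String> requestedTools, List<ToolEntry> optionTools,
            Function<String, ToolEntry> resolver) {
        boolean result = true;
        for (String name : requestedTools) {
            result &= lookup(name, optionTools, resolver).returnDirect();
        }
        return result;
    }
}
```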

ToolExecutionExceptionProcessor

org/springframework/ai/tool/execution/ToolExecutionExceptionProcessor.java

@FunctionalInterface
public interface ToolExecutionExceptionProcessor {

    /**
     * Convert an exception thrown by a tool to a String that can be sent back to the AI
     * model or throw an exception to be handled by the caller.
     */
    String process(ToolExecutionException exception);

}

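ToolExecutionExceptionProcessor defines a single process method, which either converts the exception into a String sent back to the AI model or throws an exception for the caller to handle.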

DefaultToolExecutionExceptionProcessor

public class DefaultToolExecutionExceptionProcessor implements ToolExecutionExceptionProcessor {

    private final static Logger logger = LoggerFactory.getLogger(DefaultToolExecutionExceptionProcessor.class);

    private static final boolean DEFAULT_ALWAYS_THROW = false;

    private final boolean alwaysThrow;

    public DefaultToolExecutionExceptionProcessor(boolean alwaysThrow) {
        this.alwaysThrow = alwaysThrow;
    }

    @Override
    public String process(ToolExecutionException exception) {
        Assert.notNull(exception, "exception cannot be null");
        if (alwaysThrow) {
            throw exception;
        }
        logger.debug("Exception thrown by tool: {}. Message: {}", exception.getToolDefinition().name(),
                exception.getMessage());
        return exception.getMessage();
    }

    //......
}   

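When alwaysThrow is true (it defaults to false), DefaultToolExecutionExceptionProcessor rethrows the exception; otherwise it returns the exception message, which is sent back to the model as the tool result.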
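The alwaysThrow switch decides whether a failing tool aborts the whole request or merely produces an error message the model can react to. Here is a JDK-only sketch of the same decision; `ToolFailure` and `ExceptionProcessorSketch` are hypothetical, and `runTool` folds together the wrapping done in MethodToolCallback and the try/catch in executeToolCall.

```java
import java.util.function.Supplier;

// Hypothetical stand-in for ToolExecutionException.
class ToolFailure extends RuntimeException {

    private final String toolName;

    ToolFailure(String toolName, Throwable cause) {
        super(cause.getMessage(), cause);
        this.toolName = toolName;
    }

    String toolName() {
        return toolName;
    }
}

final class ExceptionProcessorSketch {

    private final boolean alwaysThrow;

    ExceptionProcessorSketch(boolean alwaysThrow) {
        this.alwaysThrow = alwaysThrow;
    }

    // Either rethrow, or turn the failure into text that is sent back to the model.
    String process(ToolFailure failure) {
        if (alwaysThrow) {
            throw failure;
        }
        return failure.getMessage();
    }

    // Wraps any runtime failure of the tool, then lets process() decide what happens.
    String runTool(String toolName, Supplier<String> tool) {
        try {
            return tool.get();
        }
        catch (RuntimeException ex) {
            return process(new ToolFailure(toolName, ex));
        }
    }
}
```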

User-Controlled Tool Execution

ToolCallingChatOptions provides the internalToolExecutionEnabled property. Setting it to false lets you control the tool-execution process yourself (you can also implement a ToolExecutionEligibilityPredicate for this). For example:

ChatModel chatModel = ...
ToolCallingManager toolCallingManager = ToolCallingManager.builder().build();

ChatOptions chatOptions = ToolCallingChatOptions.builder()
    .toolCallbacks(ToolCallbacks.from(new CustomerTools()))
    .internalToolExecutionEnabled(false)
    .build();
Prompt prompt = new Prompt("Tell me more about the customer with ID 42", chatOptions);

ChatResponse chatResponse = chatModel.call(prompt);

while (chatResponse.hasToolCalls()) {
    ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse);

    prompt = new Prompt(toolExecutionResult.conversationHistory(), chatOptions);

    chatResponse = chatModel.call(prompt);
}

System.out.println(chatResponse.getResult().getOutput().getText());

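Here the application itself runs the tools through toolCallingManager.executeToolCalls and passes the resulting conversation history back to the chatModel, repeating until the model no longer requests tool calls.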
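The control flow of that loop does not depend on Spring AI itself, so it can be exercised against a scripted fake. In the sketch below, `Reply`, `FakeModel`, `ManualToolLoop` and the string-based history are hypothetical simplifications of ChatResponse, ChatModel and the message list; the loop re-sends the growing conversation until the model stops requesting tools. The `maxRounds` guard is not in the original example; it is a design choice so a misbehaving model cannot loop forever.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical model reply: either a tool request or final text.
record Reply(String toolName, String toolArgs, String text) {
    boolean hasToolCalls() {
        return toolName != null;
    }
}

// Scripted fake: asks for the "customer" tool once, then answers using its result.
class FakeModel {
    Reply call(List<String> history) {
        String last = history.get(history.size() - 1);
        if (last.startsWith("tool:")) {
            return new Reply(null, null, "The customer is " + last.substring("tool:".length()));
        }
        return new Reply("customer", "42", null);
    }
}

final class ManualToolLoop {

    private ManualToolLoop() {
    }

    static String run(FakeModel model, Map<String, Function<String, String>> tools, String userMessage, int maxRounds) {
        List<String> history = new ArrayList<>(List.of("user:" + userMessage));
        Reply reply = model.call(history);
        int rounds = 0;
        while (reply.hasToolCalls()) {
            if (++rounds > maxRounds) {
                throw new IllegalStateException("Too many tool rounds");
            }
            // Execute the requested tool ourselves, as executeToolCalls would.
            String result = tools.get(reply.toolName()).apply(reply.toolArgs());
            history.add("assistant:call " + reply.toolName());
            history.add("tool:" + result);
            // Send the whole conversation back, like new Prompt(conversationHistory, chatOptions).
            reply = model.call(history);
        }
        return reply.text();
    }
}
```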

ToolCallbackResolver

spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/ToolCallbackResolver.java

public interface ToolCallbackResolver {

    /**
     * Resolve the {@link FunctionCallback} for the given tool name.
     */
    @Nullable
    FunctionCallback resolve(String toolName);

}

ToolCallbackResolver defines the resolve method, which returns the FunctionCallback for a given toolName. It has three implementations: StaticToolCallbackResolver, SpringBeanToolCallbackResolver, and DelegatingToolCallbackResolver.

StaticToolCallbackResolver

spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolver.java

public class StaticToolCallbackResolver implements ToolCallbackResolver {

    private static final Logger logger = LoggerFactory.getLogger(StaticToolCallbackResolver.class);

    private final Map<String, FunctionCallback> toolCallbacks = new HashMap<>();

    public StaticToolCallbackResolver(List<FunctionCallback> toolCallbacks) {
        Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
        Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");

        toolCallbacks.forEach(callback -> {
            if (callback instanceof ToolCallback toolCallback) {
                this.toolCallbacks.put(toolCallback.getToolDefinition().name(), toolCallback);
            }
            this.toolCallbacks.put(callback.getName(), callback);
        });
    }

    @Override
    public FunctionCallback resolve(String toolName) {
        Assert.hasText(toolName, "toolName cannot be null or empty");
        logger.debug("ToolCallback resolution attempt from static registry");
        return toolCallbacks.get(toolName);
    }

}

StaticToolCallbackResolver looks tools up among the List<FunctionCallback> passed to its constructor, which it indexes by name into a map.

SpringBeanToolCallbackResolver

spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolver.java

public class SpringBeanToolCallbackResolver implements ToolCallbackResolver {

    private static final Logger logger = LoggerFactory.getLogger(SpringBeanToolCallbackResolver.class);

    private static final Map<String, ToolCallback> toolCallbacksCache = new HashMap<>();

    private static final SchemaType DEFAULT_SCHEMA_TYPE = SchemaType.JSON_SCHEMA;

    private final GenericApplicationContext applicationContext;

    private final SchemaType schemaType;

    public SpringBeanToolCallbackResolver(GenericApplicationContext applicationContext,
            @Nullable SchemaType schemaType) {
        Assert.notNull(applicationContext, "applicationContext cannot be null");

        this.applicationContext = applicationContext;
        this.schemaType = schemaType != null ? schemaType : DEFAULT_SCHEMA_TYPE;
    }

    @Override
    public ToolCallback resolve(String toolName) {
        Assert.hasText(toolName, "toolName cannot be null or empty");

        logger.debug("ToolCallback resolution attempt from Spring application context");

        ToolCallback resolvedToolCallback = toolCallbacksCache.get(toolName);

        if (resolvedToolCallback != null) {
            return resolvedToolCallback;
        }

        ResolvableType toolType = TypeResolverHelper.resolveBeanType(applicationContext, toolName);
        ResolvableType toolInputType = (ResolvableType.forType(Supplier.class).isAssignableFrom(toolType))
                ? ResolvableType.forType(Void.class) : TypeResolverHelper.getFunctionArgumentType(toolType, 0);

        String toolDescription = resolveToolDescription(toolName, toolInputType.toClass());
        Object bean = applicationContext.getBean(toolName);

        resolvedToolCallback = buildToolCallback(toolName, toolType, toolInputType, toolDescription, bean);

        toolCallbacksCache.put(toolName, resolvedToolCallback);

        return resolvedToolCallback;
    }

    //......
}   

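SpringBeanToolCallbackResolver uses the GenericApplicationContext to look up a bean named toolName in the Spring container, builds a ToolCallback from it, and stores the result in toolCallbacksCache so that later lookups skip the container.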
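Note that toolCallbacksCache is a static HashMap, so it is shared by every resolver instance in the JVM and a plain HashMap gives no guarantees under concurrent access. If you write a resolver of your own, a per-instance `ConcurrentHashMap` with `computeIfAbsent` gives the same look-up-once behaviour in a thread-safe way. This is a sketch of that alternative, not of the Spring AI class; `CachingResolver` and its `Function<String, T>` loader (standing in for the application-context lookup) are hypothetical.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

// Per-instance, thread-safe memoization of tool lookups.
final class CachingResolver<T> {

    private final Map<String, T> cache = new ConcurrentHashMap<>();
    private final Function<String, T> loader; // stands in for the application-context lookup
    private final AtomicInteger loads = new AtomicInteger();

    CachingResolver(Function<String, T> loader) {
        this.loader = loader;
    }

    // Returns null when the loader returns null; nothing is cached then, so the next call retries.
    T resolve(String toolName) {
        return cache.computeIfAbsent(toolName, name -> {
            loads.incrementAndGet();
            return loader.apply(name);
        });
    }

    // Number of times the loader actually ran.
    int loadCount() {
        return loads.get();
    }
}
```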

DelegatingToolCallbackResolver

spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolver.java

public class DelegatingToolCallbackResolver implements ToolCallbackResolver {

    private final List<ToolCallbackResolver> toolCallbackResolvers;

    public DelegatingToolCallbackResolver(List<ToolCallbackResolver> toolCallbackResolvers) {
        Assert.notNull(toolCallbackResolvers, "toolCallbackResolvers cannot be null");
        Assert.noNullElements(toolCallbackResolvers, "toolCallbackResolvers cannot contain null elements");
        this.toolCallbackResolvers = toolCallbackResolvers;
    }

    @Override
    @Nullable
    public FunctionCallback resolve(String toolName) {
        Assert.hasText(toolName, "toolName cannot be null or empty");

        for (ToolCallbackResolver toolCallbackResolver : toolCallbackResolvers) {
            FunctionCallback toolCallback = toolCallbackResolver.resolve(toolName);
            if (toolCallback != null) {
                return toolCallback;
            }
        }
        return null;
    }

}

DelegatingToolCallbackResolver delegates resolve to the toolCallbackResolvers passed to its constructor, in order, and returns the first non-null result.
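Because the delegating resolver returns the first non-null result, the order of its delegate list is a priority order. Below is a minimal JDK-only sketch of a static map resolver chained in front of a fallback; `NameResolver`, `StaticResolver` and `DelegatingResolver` are hypothetical simplifications of the classes above, with a plain `String` standing in for the callback.

```java
import java.util.List;
import java.util.Map;

// Hypothetical simplification of ToolCallbackResolver: name in, callback (here a String) out.
@FunctionalInterface
interface NameResolver {
    String resolve(String toolName);
}

// Like StaticToolCallbackResolver: a fixed map built up front.
final class StaticResolver implements NameResolver {

    private final Map<String, String> callbacks;

    StaticResolver(Map<String, String> callbacks) {
        this.callbacks = Map.copyOf(callbacks);
    }

    @Override
    public String resolve(String toolName) {
        return callbacks.get(toolName);
    }
}

// Like DelegatingToolCallbackResolver: the first non-null answer wins.
final class DelegatingResolver implements NameResolver {

    private final List<NameResolver> delegates;

    DelegatingResolver(List<NameResolver> delegates) {
        this.delegates = List.copyOf(delegates);
    }

    @Override
    public String resolve(String toolName) {
        for (NameResolver delegate : delegates) {
            String callback = delegate.resolve(toolName);
            if (callback != null) {
                return callback;
            }
        }
        return null;
    }
}
```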

Summary

Spring AI implements Tool Calling through ToolCallback, which extends the FunctionCallback interface (slated for deprecation). ToolCallback mainly defines the getToolDefinition, getToolMetadata, and call methods, and has two basic implementations: MethodToolCallback and FunctionToolCallback.

Tool specification covers Tool Callback, Tool Definition, JSON Schema, Result Conversion, Tool Context, and Return Direct.
Tool execution covers Framework-Controlled Tool Execution, User-Controlled Tool Execution, and Exception Handling.

doc

© Copyright belongs to the author. For reprints or content collaboration, please contact the author.