聊聊Spring AI的EmbeddingModel

本文主要研究一下Spring AI的EmbeddingModel

EmbeddingModel

spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingModel.java

public interface EmbeddingModel extends Model<EmbeddingRequest, EmbeddingResponse> {

    @Override
    EmbeddingResponse call(EmbeddingRequest request);

    /**
     * Embeds the given text into a vector.
     * @param text the text to embed.
     * @return the embedded vector.
     */
    default float[] embed(String text) {
        Assert.notNull(text, "Text must not be null");
        List<float[]> response = this.embed(List.of(text));
        return response.iterator().next();
    }

    /**
     * Embeds the given document's content into a vector.
     * @param document the document to embed.
     * @return the embedded vector.
     */
    float[] embed(Document document);

    /**
     * Embeds a batch of texts into vectors.
     * @param texts list of texts to embed.
     * @return list of embedded vectors.
     */
    default List<float[]> embed(List<String> texts) {
        Assert.notNull(texts, "Texts must not be null");
        return this.call(new EmbeddingRequest(texts, EmbeddingOptionsBuilder.builder().build()))
            .getResults()
            .stream()
            .map(Embedding::getOutput)
            .toList();
    }

    /**
     * Embeds a batch of {@link Document}s into vectors based on a
     * {@link BatchingStrategy}.
     * @param documents list of {@link Document}s.
     * @param options {@link EmbeddingOptions}.
     * @param batchingStrategy {@link BatchingStrategy}.
     * @return a list of float[] that represents the vectors for the incoming
     * {@link Document}s. The returned list is expected to be in the same order of the
     * {@link Document} list.
     */
    default List<float[]> embed(List<Document> documents, EmbeddingOptions options, BatchingStrategy batchingStrategy) {
        Assert.notNull(documents, "Documents must not be null");
        List<float[]> embeddings = new ArrayList<>(documents.size());
        List<List<Document>> batch = batchingStrategy.batch(documents);
        for (List<Document> subBatch : batch) {
            List<String> texts = subBatch.stream().map(Document::getText).toList();
            EmbeddingRequest request = new EmbeddingRequest(texts, options);
            EmbeddingResponse response = this.call(request);
            for (int i = 0; i < subBatch.size(); i++) {
                embeddings.add(response.getResults().get(i).getOutput());
            }
        }
        Assert.isTrue(embeddings.size() == documents.size(),
                "Embeddings must have the same number as that of the documents");
        return embeddings;
    }

    /**
     * Embeds a batch of texts into vectors and returns the {@link EmbeddingResponse}.
     * @param texts list of texts to embed.
     * @return the embedding response.
     */
    default EmbeddingResponse embedForResponse(List<String> texts) {
        Assert.notNull(texts, "Texts must not be null");
        return this.call(new EmbeddingRequest(texts, EmbeddingOptionsBuilder.builder().build()));
    }

    /**
     * Get the number of dimensions of the embedded vectors. Note that by default, this
     * method will call the remote Embedding endpoint to get the dimensions of the
     * embedded vectors. If the dimensions are known ahead of time, it is recommended to
     * override this method.
     * @return the number of dimensions of the embedded vectors.
     */
    default int dimensions() {
        return embed("Test String").length;
    }

}

EmbeddingModel繼承了Model接口,其入參類型為EmbeddingRequest,返回類型為EmbeddingResponse,它定義了call、embed接口,提供了embed、embedForResponse、dimensions的默認實現

EmbeddingRequest

spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingRequest.java

public class EmbeddingRequest implements ModelRequest<List<String>> {

    private final List<String> inputs;

    private final EmbeddingOptions options;

    public EmbeddingRequest(List<String> inputs, EmbeddingOptions options) {
        this.inputs = inputs;
        this.options = options;
    }

    @Override
    public List<String> getInstructions() {
        return this.inputs;
    }

    @Override
    public EmbeddingOptions getOptions() {
        return this.options;
    }

}

EmbeddingRequest實現了ModelRequest接口,其getInstructions返回的是List<String>

EmbeddingResponse

spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponse.java

public class EmbeddingResponse implements ModelResponse<Embedding> {

    /**
     * Embedding data.
     */
    private final List<Embedding> embeddings;

    /**
     * Embedding metadata.
     */
    private final EmbeddingResponseMetadata metadata;

    /**
     * Creates a new {@link EmbeddingResponse} instance with empty metadata.
     * @param embeddings the embedding data.
     */
    public EmbeddingResponse(List<Embedding> embeddings) {
        this(embeddings, new EmbeddingResponseMetadata());
    }

    /**
     * Creates a new {@link EmbeddingResponse} instance.
     * @param embeddings the embedding data.
     * @param metadata the embedding metadata.
     */
    public EmbeddingResponse(List<Embedding> embeddings, EmbeddingResponseMetadata metadata) {
        this.embeddings = embeddings;
        this.metadata = metadata;
    }

    /**
     * @return Get the embedding metadata.
     */
    public EmbeddingResponseMetadata getMetadata() {
        return this.metadata;
    }

    @Override
    public Embedding getResult() {
        Assert.notEmpty(this.embeddings, "No embedding data available.");
        return this.embeddings.get(0);
    }

    /**
     * @return Get the embedding data.
     */
    @Override
    public List<Embedding> getResults() {
        return this.embeddings;
    }

    //......
}   

EmbeddingResponse實現了ModelResponse接口,其result為Embedding類型

AbstractEmbeddingModel

spring-ai-core/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingModel.java

public abstract class AbstractEmbeddingModel implements EmbeddingModel {

    private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = loadKnownModelDimensions();

    /**
     * Default constructor.
     */
    public AbstractEmbeddingModel() {
    }

    /**
     * Cached embedding dimensions.
     */
    protected final AtomicInteger embeddingDimensions = new AtomicInteger(-1);

    /**
     * Return the dimension of the requested embedding generative name. If the generative
     * name is unknown uses the EmbeddingModel to perform a dummy EmbeddingModel#embed and
     * count the response dimensions.
     * @param embeddingModel Fall-back client to determine, empirically the dimensions.
     * @param modelName Embedding generative name to retrieve the dimensions for.
     * @param dummyContent Dummy content to use for the empirical dimension calculation.
     * @return Returns the embedding dimensions for the modelName.
     */
    public static int dimensions(EmbeddingModel embeddingModel, String modelName, String dummyContent) {

        if (KNOWN_EMBEDDING_DIMENSIONS.containsKey(modelName)) {
            // Retrieve the dimension from a pre-configured file.
            return KNOWN_EMBEDDING_DIMENSIONS.get(modelName);
        }
        else {
            // Determine the dimensions empirically.
            // Generate an embedding and count the dimension size;
            return embeddingModel.embed(dummyContent).length;
        }
    }

    private static Map<String, Integer> loadKnownModelDimensions() {
        try {
            Properties properties = new Properties();
            properties.load(new DefaultResourceLoader()
                .getResource("classpath:/embedding/embedding-model-dimensions.properties")
                .getInputStream());
            return properties.entrySet()
                .stream()
                .collect(Collectors.toMap(e -> e.getKey().toString(), e -> Integer.parseInt(e.getValue().toString())));
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public int dimensions() {
        if (this.embeddingDimensions.get() < 0) {
            this.embeddingDimensions.set(dimensions(this, "Test", "Hello World"));
        }
        return this.embeddingDimensions.get();
    }

}

AbstractEmbeddingModel實現了EmbeddingModel接口定義的dimensions方法,它在不同模塊有不同的實現子類,比如spring-ai-openai的OpenAiEmbeddingModel、spring-ai-ollama的OllamaEmbeddingModel、spring-ai-minimax的MiniMaxEmbeddingModel

OllamaEmbeddingAutoConfiguration

org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingAutoConfiguration.java

@AutoConfiguration(after = RestClientAutoConfiguration.class)
@ConditionalOnClass(OllamaEmbeddingModel.class)
@ConditionalOnProperty(name = SpringAIModelProperties.EMBEDDING_MODEL, havingValue = SpringAIModels.OLLAMA,
        matchIfMissing = true)
@EnableConfigurationProperties({ OllamaEmbeddingProperties.class, OllamaInitializationProperties.class })
@ImportAutoConfiguration(classes = { OllamaApiAutoConfiguration.class, RestClientAutoConfiguration.class,
        WebClientAutoConfiguration.class })
public class OllamaEmbeddingAutoConfiguration {

    @Bean
    @ConditionalOnMissingBean
    public OllamaEmbeddingModel ollamaEmbeddingModel(OllamaApi ollamaApi, OllamaEmbeddingProperties properties,
            OllamaInitializationProperties initProperties, ObjectProvider<ObservationRegistry> observationRegistry,
            ObjectProvider<EmbeddingModelObservationConvention> observationConvention) {
        var embeddingModelPullStrategy = initProperties.getEmbedding().isInclude()
                ? initProperties.getPullModelStrategy() : PullModelStrategy.NEVER;

        var embeddingModel = OllamaEmbeddingModel.builder()
            .ollamaApi(ollamaApi)
            .defaultOptions(properties.getOptions())
            .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
            .modelManagementOptions(new ModelManagementOptions(embeddingModelPullStrategy,
                    initProperties.getEmbedding().getAdditionalModels(), initProperties.getTimeout(),
                    initProperties.getMaxRetries()))
            .build();

        observationConvention.ifAvailable(embeddingModel::setObservationConvention);

        return embeddingModel;
    }

}

OllamaEmbeddingAutoConfiguration在spring.ai.model.embeddingollama時啟用,它自動配置了OllamaEmbeddingModel

OllamaEmbeddingProperties

org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingProperties.java

@ConfigurationProperties(OllamaEmbeddingProperties.CONFIG_PREFIX)
public class OllamaEmbeddingProperties {

    public static final String CONFIG_PREFIX = "spring.ai.ollama.embedding";

    /**
     * Client lever Ollama options. Use this property to configure generative temperature,
     * topK and topP and alike parameters. The null values are ignored defaulting to the
     * generative's defaults.
     */
    @NestedConfigurationProperty
    private OllamaOptions options = OllamaOptions.builder().model(OllamaModel.MXBAI_EMBED_LARGE.id()).build();

    public String getModel() {
        return this.options.getModel();
    }

    public void setModel(String model) {
        this.options.setModel(model);
    }

    public OllamaOptions getOptions() {
        return this.options;
    }

}

OllamaEmbeddingProperties主要是提供了OllamaOptions屬性配置,具體可以參考https://github.com/ggerganov/llama.cpp/blob/master/examples/main/README.md

OllamaInitializationProperties

org/springframework/ai/model/ollama/autoconfigure/OllamaInitializationProperties.java

@ConfigurationProperties(OllamaInitializationProperties.CONFIG_PREFIX)
public class OllamaInitializationProperties {

    public static final String CONFIG_PREFIX = "spring.ai.ollama.init";

    /**
     * Chat models initialization settings.
     */
    private final ModelTypeInit chat = new ModelTypeInit();

    /**
     * Embedding models initialization settings.
     */
    private final ModelTypeInit embedding = new ModelTypeInit();

    /**
     * Whether to pull models at startup-time and how.
     */
    private PullModelStrategy pullModelStrategy = PullModelStrategy.NEVER;

    /**
     * How long to wait for a model to be pulled.
     */
    private Duration timeout = Duration.ofMinutes(5);

    /**
     * Maximum number of retries for the model pull operation.
     */
    private int maxRetries = 0;

    public PullModelStrategy getPullModelStrategy() {
        return this.pullModelStrategy;
    }

    public void setPullModelStrategy(PullModelStrategy pullModelStrategy) {
        this.pullModelStrategy = pullModelStrategy;
    }

    public ModelTypeInit getChat() {
        return this.chat;
    }

    public ModelTypeInit getEmbedding() {
        return this.embedding;
    }

    public Duration getTimeout() {
        return this.timeout;
    }

    public void setTimeout(Duration timeout) {
        this.timeout = timeout;
    }

    public int getMaxRetries() {
        return this.maxRetries;
    }

    public void setMaxRetries(int maxRetries) {
        this.maxRetries = maxRetries;
    }

    public static class ModelTypeInit {

        /**
         * Include this type of models in the initialization task.
         */
        private boolean include = true;

        /**
         * Additional models to initialize besides the ones configured via default
         * properties.
         */
        private List<String> additionalModels = List.of();

        public boolean isInclude() {
            return this.include;
        }

        public void setInclude(boolean include) {
            this.include = include;
        }

        public List<String> getAdditionalModels() {
            return this.additionalModels;
        }

        public void setAdditionalModels(List<String> additionalModels) {
            this.additionalModels = additionalModels;
        }

    }

}

OllamaInitializationProperties提供了spring.ai.ollama.init即ollama初始化的相關配置,其中ModelTypeInit可以指定初始化哪些額外的model

示例

pom.xml

<dependency>
   <groupId>org.springframework.ai</groupId>
   <artifactId>spring-ai-starter-model-ollama</artifactId>
</dependency>

配置

spring:
  ai:
    model:
      embedding: ollama
    ollama:
      init:
        timeout: 5m
        max-retries: 0
        embedding:
          include: true
          additional-models: []
      base-url: http://localhost:11434
      embedding:
        enabled: true
        options:
          model: bge-m3:latest
          truncate: true

example

    @Test
    public void testCall() {
        EmbeddingRequest request = new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"),
                OllamaOptions.builder()
                        .model("bge-m3:latest")
                        .truncate(false)
                        .build());
        EmbeddingResponse embeddingResponse = embeddingModel.call(request);
        log.info("resp:{}", JSON.toJSONString(embeddingResponse));
    }

小結

Spring AI定義了EmbeddingModel接口,它繼承了Model接口,其入參類型為EmbeddingRequest,返回類型為EmbeddingResponse,它定義了call、embed接口,提供了embed、embedForResponse、dimensions的默認實現;AbstractEmbeddingModel實現了EmbeddingModel接口定義的dimensions方法,它在不同模塊有不同的實現子類,比如spring-ai-openai的OpenAiEmbeddingModel、spring-ai-ollama的OllamaEmbeddingModel、spring-ai-minimax的MiniMaxEmbeddingModel等;OllamaEmbeddingAutoConfiguration在spring.ai.model.embeddingollama時啟用,它自動配置了OllamaEmbeddingModel。

doc

?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市,隨后出現的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 228,197評論 6 531
  • 序言:濱河連續發生了三起死亡事件,死亡現場離奇詭異,居然都是意外死亡,警方通過查閱死者的電腦和手機,發現死者居然都...
    沈念sama閱讀 98,415評論 3 415
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事。” “怎么了?”我有些...
    開封第一講書人閱讀 176,104評論 0 373
  • 文/不壞的土叔 我叫張陵,是天一觀的道長。 經常有香客問我,道長,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 62,884評論 1 309
  • 正文 為了忘掉前任,我火速辦了婚禮,結果婚禮上,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好,可當我...
    茶點故事閱讀 71,647評論 6 408
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發上,一...
    開封第一講書人閱讀 55,130評論 1 323
  • 那天,我揣著相機與錄音,去河邊找鬼。 笑死,一個胖子當著我的面吹牛,可吹牛的內容都是我干的。 我是一名探鬼主播,決...
    沈念sama閱讀 43,208評論 3 441
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了?” 一聲冷哼從身側響起,我...
    開封第一講書人閱讀 42,366評論 0 288
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后,有當地人在樹林里發現了一具尸體,經...
    沈念sama閱讀 48,887評論 1 334
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 40,737評論 3 354
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發現自己被綠了。 大學時的朋友給我發了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 42,939評論 1 369
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內的尸體忽然破棺而出,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 38,478評論 5 358
  • 正文 年R本政府宣布,位于F島的核電站,受9級特大地震影響,放射性物質發生泄漏。R本人自食惡果不足惜,卻給世界環境...
    茶點故事閱讀 44,174評論 3 347
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 34,586評論 0 26
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春,著一層夾襖步出監牢的瞬間,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 35,827評論 1 283
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個月前我還...
    沈念sama閱讀 51,608評論 3 390
  • 正文 我出身青樓,卻偏偏與公主長得像,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 47,914評論 2 372

推薦閱讀更多精彩內容