Embed Jina Embeddings V4

Source: `examples/offline_inference/embed_jina_embeddings_v4.py`

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example of offline inference with Jina Embeddings V4 multimodal model.

This example demonstrates:
1. Text-only embeddings
2. Image-only embeddings
3. Cross-modal embeddings (text-to-image similarity)

The model supports both text and vision inputs through a unified architecture.
"""

import torch

from vllm import LLM
from vllm.config import PoolerConfig
from vllm.inputs.data import TextPrompt
from vllm.multimodal.utils import fetch_image


def get_embeddings(
    outputs,
    *,
    vision_start_token_id: int = 151652,
    vision_end_token_id: int = 151653,
):
    """Mean-pool and L2-normalize per-token embeddings from ``LLM.encode``.

    Each element of *outputs* must expose ``prompt_token_ids`` (a sequence of
    ints supporting ``in`` and ``.index``). It must also expose
    ``outputs.data``, a tensor indexed along dim 0 by prompt-token position.
    That layout presumably comes from ``pooling_type="ALL"``; confirm if the
    pooler config changes.

    For prompts that contain *vision_start_token_id*, only the rows from that
    token through the first *vision_end_token_id* at or after it (inclusive)
    are pooled. The surrounding text tokens therefore do not dilute the image
    embedding. Text-only prompts pool every row.

    Args:
        outputs: Iterable of pooling outputs returned by ``LLM.encode``.
        vision_start_token_id: Token id marking the start of an image span.
            The default is the value this example uses for Jina Embeddings V4.
        vision_end_token_id: Token id marking the end of an image span.

    Returns:
        A list with one float32, L2-normalized pooled vector per output.

    Raises:
        ValueError: If a prompt contains the vision start token but no vision
            end token at or after it.
    """
    embeddings = []
    for output in outputs:
        token_ids = output.prompt_token_ids
        # No clone needed: mean() below always allocates a fresh tensor, so
        # the returned embedding never aliases the caller's output tensor.
        token_embeddings = output.outputs.data.detach()

        if vision_start_token_id in token_ids:
            img_start_pos = token_ids.index(vision_start_token_id)
            try:
                # Search from the start marker onward. Otherwise a stray end
                # token earlier in the prompt yields an empty slice (NaN mean).
                img_end_pos = token_ids.index(vision_end_token_id, img_start_pos)
            except ValueError:
                raise ValueError(
                    f"prompt contains vision start token {vision_start_token_id} "
                    f"but no vision end token {vision_end_token_id} after it"
                ) from None
            token_embeddings = token_embeddings[img_start_pos : img_end_pos + 1]

        # Accumulate in float32 so a float16 model does not lose precision.
        pooled_output = token_embeddings.mean(dim=0, dtype=torch.float32)
        embeddings.append(torch.nn.functional.normalize(pooled_output, dim=-1))
    return embeddings


def _image_prompt(image) -> TextPrompt:
    """Wrap *image* in the chat-style prompt this example uses for images.

    NOTE(review): the ``<|vision_start|>``/``<|vision_end|>`` markers are
    presumably what ``get_embeddings`` detects (by token id) to pool only the
    image span. Confirm against the tokenizer if the model changes.
    """
    return TextPrompt(
        prompt="<|im_start|>user\n<|vision_start|><|image_pad|>"
        "<|vision_end|>Describe the image.<|im_end|>\n",
        multi_modal_data={"image": image},
    )


def main():
    """Run the text-only, image-only, and cross-modal embedding examples."""
    # Keep per-token embeddings unnormalized; get_embeddings chooses which
    # tokens to pool and normalizes the pooled vector itself.
    model = LLM(
        model="jinaai/jina-embeddings-v4-vllm-retrieval",
        task="embed",
        override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False),
        dtype="float16",
    )

    # Example 1: Text-only embeddings
    print("=== Text Embeddings ===")
    query = "Overview of climate change impacts on coastal cities"
    query_prompt = TextPrompt(prompt=f"Query: {query}")

    # Implicit concatenation (not a triple-quoted block) keeps the source
    # indentation and line breaks out of the text that gets embedded.
    passage = (
        "The impacts of climate change on coastal cities are significant and "
        "multifaceted. Rising sea levels threaten infrastructure, while "
        "increased storm intensity poses risks to populations and economies."
    )
    passage_prompt = TextPrompt(prompt=f"Passage: {passage}")

    text_outputs = model.encode([query_prompt, passage_prompt])
    text_embeddings = get_embeddings(text_outputs)

    # get_embeddings returns L2-normalized vectors, so the dot product is
    # the cosine similarity.
    similarity = torch.dot(text_embeddings[0], text_embeddings[1]).item()
    print(f"Query: {query[:50]}...")
    print(f"Passage: {passage[:50]}...")
    print(f"Similarity: {similarity:.4f}\n")

    # Example 2: Image embeddings
    print("=== Image Embeddings ===")
    image1_url = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png"
    image2_url = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"

    image1_prompt = _image_prompt(fetch_image(image1_url))
    image2_prompt = _image_prompt(fetch_image(image2_url))

    image_outputs = model.encode([image1_prompt, image2_prompt])
    image_embeddings = get_embeddings(image_outputs)

    similarity = torch.dot(image_embeddings[0], image_embeddings[1]).item()
    print(f"Image 1: {image1_url.split('/')[-1]}")
    print(f"Image 2: {image2_url.split('/')[-1]}")
    print(f"Similarity: {similarity:.4f}\n")

    # Example 3: Cross-modal similarity (text vs image)
    print("=== Cross-modal Similarity ===")
    query = "scientific paper with markdown formatting"
    query_prompt = TextPrompt(prompt=f"Query: {query}")

    # Only the new text query needs encoding. Image 2 was already embedded in
    # Example 2, so reuse that vector instead of encoding the image again.
    query_embedding = get_embeddings(model.encode([query_prompt]))[0]

    similarity = torch.dot(query_embedding, image_embeddings[1]).item()
    print(f"Text query: {query}")
    print(f"Image: {image2_url.split('/')[-1]}")
    print(f"Cross-modal similarity: {similarity:.4f}")


# Run the examples only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()