多模态 RAG 实战:让模型同时检索图片、PDF 和视频片段

1.2k
Category: 
开发交流

一、引言:从"单模态检索"到"跨模态理解"

传统 RAG(Retrieval-Augmented Generation)系统局限于纯文本:将文档切块 → 向量化 → 检索 → 拼接 Prompt → LLM 回答。但在 2026 年的企业知识管理中,大量信息以图片(截图/图表/海报)、PDF(扫描件/表格/合同)、视频(培训录像/会议录播/产品演示) 等形式存在。用户提问"去年Q3的销售图表中哪个区域增长最快?"或"培训视频里关于安全操作的片段在哪里?"时,纯文本 RAG 完全失效。

多模态 RAG(Multimodal RAG, MM-RAG) 的核心突破在于:

  1. 跨模态检索:用文本 Query 检索图片、视频帧、PDF 页面。
  2. 多模态上下文:检索结果以原始图片/视频片段形式送入支持多模态的 VLM 大模型。
  3. 时序理解:视频片段不再是"帧的集合",而是具有时间轴语义的连续信息。

本文将从多模态嵌入、混合检索、VLM 增强生成三个核心环节,提供完整可运行的工程实现。

二、技术背景

2.1 多模态检索的挑战

  • 异构表征:文本、图片、视频帧处于不同特征空间,无法直接比较相似度。
  • 语义鸿沟:文本"红色汽车"与图片中红色汽车的视觉特征需要跨模态对齐。
  • 时序压缩:视频包含大量冗余帧,需提取关键帧或片段级语义。

2.2 2026 年主流方案

方案原理优缺点
CLIP 双塔文本/图片分别编码到同一向量空间速度快,但视频需逐帧处理
多模态 Embedding 模型BridgeTower / AltCLIP / SigLIP对齐效果更好,但计算量大
VLM 重排序先用轻量模型粗筛,再用 VLM 精排精度高,延迟增加
视频片段级嵌入VideoCLIP / FrozenBiLM直接编码视频片段,保留时序

2.3 核心架构

用户 Query
    ↓
Query 编码器(文本/图片/视频)
    ↓
多模态向量库(Milvus / Chroma / Weaviate)
    ↓
混合检索(CLIP 语义 + BM25 关键词 + 时间戳)
    ↓
结果重排序(VLM 相关性判断)
    ↓
多模态上下文拼接(图片/视频帧/文本块)
    ↓
VLM 生成回答

三、环境准备

依赖安装

# 核心依赖
pip install torch torchvision transformers pillow
pip install chromadb sentence-transformers open-clip-torch
pip install opencv-python ffmpeg-python  # 视频处理
pip install pdf2image pytesseract        # PDF 处理
pip install openai  # 可选,VLM 生成

配置文件 config.py

# config.py
from dataclasses import dataclass, field
from typing import List

@dataclass
class MMRAGConfig:
    # 嵌入模型
    CLIP_MODEL: str = "ViT-B-32"
    CLIP_PRETRAINED: str = "laion2b_s34b_b79k"
    TEXT_EMBED_MODEL: str = "BAAI/bge-small-zh-v1.5"
    
    # 向量库
    COLLECTION_NAME: str = "multimodal_knowledge"
    PERSIST_DIR: str = "./chroma_db"
    
    # 检索参数
    TOP_K_TEXT: int = 3
    TOP_K_IMAGE: int = 2
    TOP_K_VIDEO: int = 2
    VIDEO_FRAME_INTERVAL: int = 30  # 每隔多少帧取一关键帧
    
    # VLM 生成
    VLM_MODEL: str = "Qwen/Qwen2-VL-7B-Instruct"
    MAX_NEW_TOKENS: int = 512

CFG = MMRAGConfig()

四、场景一:多模态知识库构建(索引阶段)

4.1 场景描述

将混合格式的企业知识(PDF 文档、产品图片、培训视频)统一索引到向量库,支持后续跨模态检索。

4.2 代码实现 indexer.py

# indexer.py
import os
import cv2
import torch
import numpy as np
from PIL import Image
from pdf2image import convert_from_path
from chromadb import PersistentClient
from chromadb.utils import embedding_functions
import open_clip
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Any, Optional
import json
import hashlib

class MultiModalIndexer:
    def __init__(self, config: MMRAGConfig = CFG):
        self.config = config
        
        # 初始化 CLIP(图片/文本统一编码)
        self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
            config.CLIP_MODEL, pretrained=config.CLIP_PRETRAINED
        )
        self.clip_model.eval()
        self.clip_tokenizer = open_clip.get_tokenizer(config.CLIP_MODEL)
        
        # 初始化文本嵌入模型(备用)
        self.text_embedder = SentenceTransformer(config.TEXT_EMBED_MODEL)
        
        # 初始化 Chroma 向量库
        self.client = PersistentClient(path=config.PERSIST_DIR)
        self.collection = self.client.get_or_create_collection(
            name=config.COLLECTION_NAME,
            embedding_function=embedding_functions.DefaultEmbeddingFunction()  # 我们会手动嵌入
        )

    def embed_image(self, img: Image.Image) -> np.ndarray:
        """使用 CLIP 编码图片"""
        img_tensor = self.clip_preprocess(img).unsqueeze(0)
        with torch.no_grad():
            features = self.clip_model.encode_image(img_tensor)
            features = features / features.norm(dim=-1, keepdim=True)
        return features.squeeze(0).numpy()

    def embed_text(self, text: str) -> np.ndarray:
        """使用 CLIP 编码文本"""
        tokens = self.clip_tokenizer([text])
        with torch.no_grad():
            features = self.clip_model.encode_text(tokens)
            features = features / features.norm(dim=-1, keepdim=True)
        return features.squeeze(0).numpy()

    def index_pdf(self, pdf_path: str, metadata: Optional[Dict] = None):
        """索引 PDF 文档(每页作为独立条目)"""
        pages = convert_from_path(pdf_path, dpi=200)
        doc_id_base = hashlib.md5(pdf_path.encode()).hexdigest()[:8]
        
        for i, page_img in enumerate(pages):
            # 提取页面文本(使用 OCR 或直接使用 PDF 内嵌文本)
            # 简化:使用图片嵌入
            img_embed = self.embed_image(page_img)
            
            # 生成页面摘要(可选:用 VLM 生成)
            page_desc = f"PDF {os.path.basename(pdf_path)}{i+1}页"
            
            doc_id = f"{doc_id_base}_p{i}"
            self.collection.add(
                embeddings=[img_embed.tolist()],
                documents=[page_desc],
                metadatas=[{
                    "source": pdf_path,
                    "page": i+1,
                    "type": "pdf",
                    "timestamp": metadata.get("timestamp", "") if metadata else ""
                }],
                ids=[doc_id]
            )
        print(f"索引 PDF: {pdf_path} ({len(pages)} 页)")

    def index_image(self, img_path: str, caption: str = "", metadata: Optional[Dict] = None):
        """索引单张图片"""
        img = Image.open(img_path).convert("RGB")
        img_embed = self.embed_image(img)
        
        doc_id = hashlib.md5(img_path.encode()).hexdigest()[:16]
        self.collection.add(
            embeddings=[img_embed.tolist()],
            documents=[caption or f"图片 {os.path.basename(img_path)}"],
            metadatas=[{
                "source": img_path,
                "type": "image",
                "caption": caption,
                **(metadata or {})
            }],
            ids=[doc_id]
        )
        print(f"索引图片: {img_path}")

    def index_video(self, video_path: str, metadata: Optional[Dict] = None):
        """索引视频(提取关键帧 + 生成片段描述)"""
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        duration = total_frames / fps
        
        frame_idx = 0
        saved_frames = 0
        doc_id_base = hashlib.md5(video_path.encode()).hexdigest()[:8]
        
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            
            if frame_idx % self.config.VIDEO_FRAME_INTERVAL == 0:
                # 转换 OpenCV BGR → RGB → PIL
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil_img = Image.fromarray(frame_rgb)
                
                img_embed = self.embed_image(pil_img)
                timestamp = frame_idx / fps
                
                doc_id = f"{doc_id_base}_f{saved_frames}"
                self.collection.add(
                    embeddings=[img_embed.tolist()],
                    documents=[f"视频 {os.path.basename(video_path)} 片段@{timestamp:.1f}s"],
                    metadatas=[{
                        "source": video_path,
                        "type": "video_frame",
                        "timestamp": timestamp,
                        "frame_idx": frame_idx,
                        "duration": duration,
                        **(metadata or {})
                    }],
                    ids=[doc_id]
                )
                saved_frames += 1
            
            frame_idx += 1
        
        cap.release()
        print(f"索引视频: {video_path} ({saved_frames} 关键帧, 总时长{duration:.1f}s)")

    def build_knowledge_base(self, data_dir: str):
        """批量构建知识库"""
        for root, dirs, files in os.walk(data_dir):
            for f in files:
                fpath = os.path.join(root, f)
                ext = f.lower().split('.')[-1]
                
                if ext in ['pdf']:
                    self.index_pdf(fpath)
                elif ext in ['jpg', 'jpeg', 'png', 'webp']:
                    self.index_image(fpath)
                elif ext in ['mp4', 'avi', 'mov', 'mkv']:
                    self.index_video(fpath)

# 测试
if __name__ == "__main__":
    indexer = MultiModalIndexer()
    indexer.build_knowledge_base("./knowledge_samples")
    print(f"向量库条目数: {indexer.collection.count()}")

五、场景二:多模态检索与重排序

5.1 场景描述

用户输入文本查询(如"去年第四季度的销售趋势图"),系统需从向量库中检索出最相关的图片、PDF 页面、视频片段,并按相关性排序。

5.2 代码实现 retriever.py

# retriever.py
from typing import List, Dict, Any, Tuple
import numpy as np
from chromadb import PersistentClient
import torch

class MultiModalRetriever:
    def __init__(self, indexer: MultiModalIndexer, config: MMRAGConfig = CFG):
        self.indexer = indexer
        self.config = config
        self.collection = indexer.collection

    def hybrid_search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        混合检索:CLIP 语义检索 + 文本关键词检索
        返回合并排序后的结果
        """
        # 1. CLIP 语义检索
        query_embed = self.indexer.embed_text(query)
        clip_results = self.collection.query(
            query_embeddings=[query_embed.tolist()],
            n_results=top_k * 2  # 多取一些用于融合
        )
        
        # 2. 文本关键词检索(使用 Chroma 内置的文本检索)
        text_results = self.collection.query(
            query_texts=[query],
            n_results=top_k * 2
        )
        
        # 3. 结果融合(加权 Reciprocal Rank Fusion)
        all_results = {}
        
        # CLIP 结果
        for i, (doc, meta, dist) in enumerate(zip(
            clip_results['documents'][0],
            clip_results['metadatas'][0],
            clip_results['distances'][0]
        )):
            doc_id = clip_results['ids'][0][i]
            # CLIP distance 越小越相关,转换为分数
            score = 1.0 / (1.0 + dist)
            all_results[doc_id] = {
                "document": doc,
                "metadata": meta,
                "clip_score": score,
                "combined_score": score,
                "rank_clip": i + 1
            }
        
        # 文本结果
        for i, (doc, meta, dist) in enumerate(zip(
            text_results['documents'][0],
            text_results['metadatas'][0],
            text_results['distances'][0]
        )):
            doc_id = text_results['ids'][0][i]
            score = 1.0 / (1.0 + dist)
            
            if doc_id in all_results:
                # 融合分数
                all_results[doc_id]["combined_score"] = (
                    0.6 * all_results[doc_id]["clip_score"] + 
                    0.4 * score
                )
                all_results[doc_id]["text_score"] = score
            else:
                all_results[doc_id] = {
                    "document": doc,
                    "metadata": meta,
                    "text_score": score,
                    "combined_score": score * 0.4,  # 文本权重较低
                    "rank_clip": 999
                }
        
        # 按综合分数排序
        sorted_results = sorted(
            all_results.values(),
            key=lambda x: x["combined_score"],
            reverse=True
        )
        
        return sorted_results[:top_k]

    def rerank_with_vlm(self, query: str, candidates: List[Dict], 
                         vlm_model=None) -> List[Dict]:
        """
        使用 VLM 对候选结果进行重排序
        返回排序后的结果列表
        """
        if not vlm_model:
            # 如果没有提供 VLM,直接返回原始排序
            return candidates
        
        scored_candidates = []
        for cand in candidates:
            meta = cand["metadata"]
            source = meta.get("source", "")
            doc_type = meta.get("type", "")
            
            # 构建 VLM 判断 Prompt
            relevance_prompt = f"""
            用户查询: {query}
            检索到的内容类型: {doc_type}
            来源: {source}
            描述: {cand['document']}
            
            请判断这个结果与用户查询的相关性(0-10分),只返回数字。
            """
            
            # 调用 VLM(简化:使用模拟评分)
            relevance_score = self._mock_vlm_score(relevance_prompt)
            cand["vlm_score"] = relevance_score
            cand["combined_score"] = 0.3 * cand["combined_score"] + 0.7 * relevance_score
            scored_candidates.append(cand)
        
        scored_candidates.sort(key=lambda x: x["combined_score"], reverse=True)
        return scored_candidates

    def _mock_vlm_score(self, prompt: str) -> float:
        """模拟 VLM 评分(实际应调用 Qwen2-VL 或 GPT-4o)"""
        # 简化实现:基于关键词匹配
        keywords = ["销售", "趋势", "图表", "Q3", "增长", "区域"]
        score = sum(1 for kw in keywords if kw in prompt)
        return min(score * 2, 10)  # 0-10 分

    def retrieve_context(self, query: str, top_k: int = 5) -> List[Dict]:
        """完整的检索流程"""
        candidates = self.hybrid_search(query, top_k=top_k * 2)
        reranked = self.rerank_with_vlm(query, candidates)
        return reranked[:top_k]

# 测试
if __name__ == "__main__":
    indexer = MultiModalIndexer()
    retriever = MultiModalRetriever(indexer)
    
    results = retriever.retrieve_context("去年第四季度销售趋势图表")
    for i, r in enumerate(results):
        print(f"\n结果 {i+1}:")
        print(f"  类型: {r['metadata']['type']}")
        print(f"  来源: {r['metadata']['source']}")
        print(f"  综合分数: {r['combined_score']:.3f}")
        print(f"  VLM 分数: {r.get('vlm_score', 'N/A')}")

六、场景三:多模态上下文增强生成

6.1 场景描述

将检索到的多模态结果(图片、PDF 页面截图、视频帧)与用户 Query 一起送入 VLM,生成包含视觉信息的回答。

6.2 代码实现 generator.py

# generator.py
import base64
import io
from PIL import Image
from typing import List, Dict, Optional
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

class MultiModalGenerator:
    def __init__(self, model_name: str = CFG.VLM_MODEL):
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, device_map="auto"
        ).eval()
        self.processor = AutoProcessor.from_pretrained(model_name)

    def _load_media(self, result: Dict) -> Optional[Image.Image]:
        """根据检索结果加载媒体文件"""
        meta = result["metadata"]
        source = meta.get("source", "")
        doc_type = meta.get("type", "")
        
        try:
            if doc_type in ["image", "pdf"]:
                return Image.open(source).convert("RGB")
            elif doc_type == "video_frame":
                # 从视频中提取指定帧
                import cv2
                cap = cv2.VideoCapture(source)
                timestamp = meta.get("timestamp", 0)
                fps = cap.get(cv2.CAP_PROP_FPS)
                frame_idx = int(timestamp * fps)
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                ret, frame = cap.read()
                cap.release()
                if ret:
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    return Image.fromarray(frame_rgb)
        except Exception as e:
            print(f"加载媒体失败: {e}")
        return None

    def generate(self, query: str, context: List[Dict]) -> str:
        """
        使用多模态上下文生成回答
        context: 检索结果列表
        """
        # 构建多模态消息
        messages = [{
            "role": "user",
            "content": []
        }]
        
        # 添加检索到的媒体
        media_added = 0
        for result in context:
            img = self._load_media(result)
            if img and media_added < 3:  # 最多添加 3 个媒体
                messages[0]["content"].append({
                    "type": "image",
                    "image": img
                })
                media_added += 1
        
        # 添加文本 Query
        context_summary = "\n".join([
            f"[来源 {i+1}] {r['document']} (类型: {r['metadata'].get('type', '未知')})"
            for i, r in enumerate(context)
        ])
        
        full_query = f"""
        基于以下检索到的多模态信息,回答用户问题。
        
        检索到的上下文:
        {context_summary}
        
        用户问题:{query}
        
        要求:
        1. 结合图片/视频帧中的视觉信息回答
        2. 如果信息不足,明确指出
        3. 引用信息来源
        """
        messages[0]["content"].append({
            "type": "text",
            "text": full_query
        })
        
        # 生成回答
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.processor(
            text=[text],
            images=[img for msg in messages for c in msg["content"] 
                   if c["type"] == "image" for img in [c["image"]]],
            return_tensors="pt"
        ).to(self.model.device)
        
        with torch.no_grad():
            gen = self.model.generate(
                **inputs,
                max_new_tokens=CFG.MAX_NEW_TOKENS,
                do_sample=False,
                temperature=0.1
            )
        
        answer = self.processor.batch_decode(gen, skip_special_tokens=True)[0]
        # 提取 assistant 回复部分
        if "assistant" in answer:
            answer = answer.split("assistant")[-1].strip()
        
        return answer

# 完整 Pipeline
class MultiModalRAGPipeline:
    def __init__(self):
        self.indexer = MultiModalIndexer()
        self.retriever = MultiModalRetriever(self.indexer)
        self.generator = MultiModalGenerator()
    
    def query(self, user_query: str, top_k: int = 5) -> Dict:
        """完整的 MM-RAG 查询流程"""
        # 1. 检索
        context = self.retriever.retrieve_context(user_query, top_k=top_k)
        
        # 2. 生成
        answer = self.generator.generate(user_query, context)
        
        return {
            "query": user_query,
            "answer": answer,
            "context": context
        }

# 测试
if __name__ == "__main__":
    pipeline = MultiModalRAGPipeline()
    result = pipeline.query("去年第四季度哪个区域的销售额增长最快?")
    print(f"Query: {result['query']}")
    print(f"Answer: {result['answer']}")
    print(f"Sources: {[r['metadata']['source'] for r in result['context']]}")

七、部署场景与优化

7.1 生产级部署架构

用户请求 → API Gateway → Load Balancer
    ↓
检索集群(Chroma/Milvus + CLIP 编码器)
    ↓
重排序服务(VLM 异步评分)
    ↓
生成集群(Qwen2-VL / GPT-4o)
    ↓
响应缓存(Redis:相同 Query 直接返回)

7.2 性能优化策略

  1. 向量索引加速:使用 IVF(倒排索引)或 HNSW 替代暴力搜索,检索延迟从 100ms 降至 5ms。
  2. CLIP 编码缓存:对高频 Query 缓存文本嵌入,避免重复计算。
  3. 视频预处理:在索引阶段预提取关键帧并生成文本描述,检索时无需重新解码视频。
  4. 分级检索:先用轻量 CLIP 粗筛(top-50),再用 VLM 精排(top-5),平衡精度与延迟。

7.3 疑难解答

Q1:检索结果中图片很多但文本很少?
A:调整 CLIP 与文本检索的权重比例,或在索引时为图片添加更丰富的文本描述(使用 VLM 自动生成 Caption)。

Q2:视频检索精度不高?
A:① 减小帧间隔(从 30 帧改为 10 帧);② 使用 VideoCLIP 等视频专用模型替代逐帧 CLIP;③ 对视频片段生成文本摘要后再索引。

Q3:多模态上下文太长导致 VLM 超 Token 限制?
A:① 限制检索结果数量(top-k ≤ 5);② 对图片进行压缩(降低分辨率);③ 使用 VLM 的"视觉 Token 压缩"功能(如 Qwen2-VL 的 AnyRes 动态分片)。

八、未来展望与技术趋势(2026+)

  1. 端到端多模态检索模型:类似 ColPali 的架构,直接对文档图片做稠密检索,无需显式 OCR 或 CLIP 编码。
  2. 视频级语义索引:不再提取关键帧,而是对整个视频片段编码为单个向量(VideoCLIP / FrozenBiLM 成熟化)。
  3. 多模态 Agent :RAG 系统不仅能检索,还能调用工具(如从视频中截取片段、放大图片细节)来完善回答。
  4. 实时多模态 RAG:结合语音输入+实时视频流,在 AR/VR 场景中即时检索相关知识。

九、总结

多模态 RAG 的核心公式:

MM-RAG = CLIP(跨模态对齐) + Hybrid Search(语义+关键词) + VLM(视觉理解+生成)
组件关键技术选型建议
多模态嵌入CLIP / SigLIP / AltCLIP通用场景用 CLIP ViT-L/14
向量存储Chroma / Milvus / Weaviate百万级用 Milvus,小规模用 Chroma
混合检索RRF 融合 + VLM 重排序精度优先用 VLM 重排,速度优先用 RRF
生成模型Qwen2-VL / GPT-4o / Gemini开源用 Qwen2-VL-7B,闭源用 GPT-4o
视频处理关键帧提取 + VLM CaptionFFmpeg + VLM 自动生成图片描述

关键洞察

多模态 RAG 的成功不在于"检索到最相似的图片",而在于让 VLM "看到"检索到的视觉信息。2026 年的最佳实践是:用 CLIP 做第一轮快速筛选,用 VLM 做第二轮精准重排序,最后让 VLM 同时"看"到图片和"读"到文本,生成真正跨模态的理解。

一句话总结

本文详解多模态RAG(MM-RAG)实战方案,涵盖跨模态嵌入、混合检索(CLIP+BM25+时间戳)、VLM重排序与生成,支持图片/PDF/视频片段的统一索引与文本查询,并提供可运行代码与生产优化建议。

主要内容

  1. 核心突破
    突破传统文本RAG局限,实现用文本Query跨模态检索图片、PDF页面、视频关键帧,并将原始视觉内容送入VLM生成理解性回答。

  2. 关键技术栈

  • 多模态嵌入:CLIP统一编码图文,VideoCLIP/FrozenBiLM处理视频;
  • 混合检索:语义(CLIP)+关键词(BM25)+元数据(时间戳)融合,RRF加权排序;
  • 重排序:VLM对候选结果进行细粒度相关性打分;
  • 生成:Qwen2-VL等VLM联合视觉输入与文本上下文生成答案。
  1. 工程实现要点
  • 知识库构建:PDF每页转图OCR索引、图片直接嵌入、视频按帧间隔提取关键帧并打时间戳;
  • 检索流程:双路查询→分数融合→VLM精排;
  • 上下文拼接:最多3个媒体+结构化文本描述送入VLM;
  • 生产优化:IVF/HNSW加速检索、CLIP嵌入缓存、视频预处理摘要、分级检索。
Tags:
Comments 0
/ 1000
6
0
Favorite