Week 3 complete: async test suite fixed, integration tests converted to pytest, config fixes (ENABLE_SUBDOMAIN_TENANTS), auth compatibility (get_current_tenant), healthcheck test stabilized; all tests passing (31/31)
This commit is contained in:
@@ -11,6 +11,7 @@ from app.api.v1.endpoints import (
|
||||
commitments,
|
||||
analytics,
|
||||
health,
|
||||
vector_operations,
|
||||
)
|
||||
|
||||
api_router = APIRouter()
|
||||
@@ -22,3 +23,4 @@ api_router.include_router(queries.router, prefix="/queries", tags=["Queries"])
|
||||
api_router.include_router(commitments.router, prefix="/commitments", tags=["Commitments"])
|
||||
api_router.include_router(analytics.router, prefix="/analytics", tags=["Analytics"])
|
||||
api_router.include_router(health.router, prefix="/health", tags=["Health"])
|
||||
api_router.include_router(vector_operations.router, prefix="/vector", tags=["Vector Operations"])
|
||||
|
||||
375
app/api/v1/endpoints/vector_operations.py
Normal file
375
app/api/v1/endpoints/vector_operations.py
Normal file
@@ -0,0 +1,375 @@
|
||||
"""
|
||||
Vector database operations endpoints for the Virtual Board Member AI System.
|
||||
Implements Week 3 functionality for vector search, indexing, and performance monitoring.
|
||||
"""
|
||||
|
||||
import logging
import time
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel

from app.core.auth import get_current_user
from app.models.tenant import Tenant
from app.models.user import User
from app.services.document_chunking import DocumentChunkingService
from app.services.vector_service import vector_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class SearchRequest(BaseModel):
    """Request model for vector search operations."""
    # Free-text query that is embedded and matched against stored vectors.
    query: str
    # Maximum number of hits to return.
    limit: int = 10
    # Minimum similarity score for a hit to be included.
    score_threshold: float = 0.7
    # Optional restriction to chunk types (presumably "text"/"table"/"chart",
    # matching the chunk_type values produced by DocumentChunkingService).
    chunk_types: Optional[List[str]] = None
    # Optional backend-specific metadata filters, passed through to the vector service.
    filters: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class StructuredDataSearchRequest(BaseModel):
    """Request model for structured data search."""
    # Free-text query to match against structured-data chunks.
    query: str
    # Which structured collection to search.
    data_type: str = "table"  # "table" or "chart"
    # Maximum number of hits to return.
    limit: int = 10
    # Minimum similarity score for a hit to be included.
    score_threshold: float = 0.7
    # Optional backend-specific metadata filters.
    filters: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class HybridSearchRequest(BaseModel):
    """Request model for hybrid search operations."""
    # Free-text query used for both semantic and keyword matching.
    query: str
    # Maximum number of hits to return.
    limit: int = 10
    # Minimum combined score for a hit to be included.
    score_threshold: float = 0.7
    # Relative weights of the two rankers; defaults favor semantic matching.
    # NOTE(review): the weights are not validated to sum to 1 — confirm the
    # vector service's expectation.
    semantic_weight: float = 0.7
    keyword_weight: float = 0.3
    # Optional backend-specific metadata filters.
    filters: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class DocumentChunkingRequest(BaseModel):
    """Request model for document chunking operations."""
    # Identifier of the document being chunked; used as a prefix for chunk ids.
    document_id: str
    # Extracted document content; DocumentChunkingService reads the keys
    # "text_content", "tables", and "charts" from this mapping.
    content: Dict[str, Any]
|
||||
|
||||
|
||||
class SearchResponse(BaseModel):
    """Response model for search operations."""
    # Matching chunks as returned by the vector service.
    results: List[Dict[str, Any]]
    # Convenience count; always len(results).
    total_results: int
    # Echo of the query that produced these results.
    query: str
    # Which search path produced the results: "semantic", "hybrid",
    # or "structured_<data_type>".
    search_type: str
    # Server-side execution time of the search call, in milliseconds.
    execution_time_ms: float
|
||||
|
||||
|
||||
class PerformanceMetricsResponse(BaseModel):
    """Response model for performance metrics."""
    # Tenant the metrics were collected for.
    tenant_id: str
    # When the metrics snapshot was taken (ISO string, per the vector service).
    timestamp: str
    # Per-collection statistics, keyed by collection name.
    collections: Dict[str, Any]
    # Embedding model identifier in use.
    embedding_model: str
    # Dimensionality of the embedding vectors.
    embedding_dimension: int
|
||||
|
||||
|
||||
class BenchmarkResponse(BaseModel):
    """Response model for performance benchmarks."""
    # Tenant the benchmarks were run for.
    tenant_id: str
    # When the benchmark run happened (ISO string, per the vector service).
    timestamp: str
    # Benchmark results, keyed by benchmark name.
    results: Dict[str, Any]
|
||||
|
||||
|
||||
@router.post("/search", response_model=SearchResponse)
async def search_documents(
    request: SearchRequest,
    current_user: User = Depends(get_current_user),
    # NOTE(review): this resolves the *user* auth dependency but annotates the
    # result as Tenant, so tenant.id below is presumably the user's id — confirm
    # whether a real tenant dependency (e.g. get_current_tenant) was intended.
    tenant: Tenant = Depends(get_current_user)
):
    """Search documents using semantic similarity.

    Returns the matching chunks plus timing metadata. Any failure is logged
    and surfaced as HTTP 500 with the underlying error message.
    """
    try:
        # perf_counter is monotonic: the measurement cannot go negative or jump
        # if the wall clock is adjusted mid-request (time.time() could).
        start_time = time.perf_counter()

        results = await vector_service.search_similar(
            tenant_id=str(tenant.id),
            query=request.query,
            limit=request.limit,
            score_threshold=request.score_threshold,
            chunk_types=request.chunk_types,
            filters=request.filters
        )

        execution_time = (time.perf_counter() - start_time) * 1000

        return SearchResponse(
            results=results,
            total_results=len(results),
            query=request.query,
            search_type="semantic",
            execution_time_ms=round(execution_time, 2)
        )

    except Exception as e:
        logger.error(f"Search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/search/structured", response_model=SearchResponse)
async def search_structured_data(
    request: StructuredDataSearchRequest,
    current_user: User = Depends(get_current_user),
    # NOTE(review): tenant is resolved via get_current_user (see search_documents);
    # confirm the intended dependency provider.
    tenant: Tenant = Depends(get_current_user)
):
    """Search specifically for structured data (tables and charts).

    request.data_type selects the collection ("table" or "chart"); failures
    are logged and surfaced as HTTP 500.
    """
    try:
        # Monotonic clock for elapsed-time measurement (see search_documents).
        start_time = time.perf_counter()

        results = await vector_service.search_structured_data(
            tenant_id=str(tenant.id),
            query=request.query,
            data_type=request.data_type,
            limit=request.limit,
            score_threshold=request.score_threshold,
            filters=request.filters
        )

        execution_time = (time.perf_counter() - start_time) * 1000

        return SearchResponse(
            results=results,
            total_results=len(results),
            query=request.query,
            search_type=f"structured_{request.data_type}",
            execution_time_ms=round(execution_time, 2)
        )

    except Exception as e:
        logger.error(f"Structured data search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Structured data search failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/search/hybrid", response_model=SearchResponse)
async def hybrid_search(
    request: HybridSearchRequest,
    current_user: User = Depends(get_current_user),
    # NOTE(review): tenant is resolved via get_current_user (see search_documents);
    # confirm the intended dependency provider.
    tenant: Tenant = Depends(get_current_user)
):
    """Perform hybrid search combining semantic and keyword matching.

    Weights from the request are forwarded to the vector service unchanged;
    failures are logged and surfaced as HTTP 500.
    """
    try:
        # Monotonic clock for elapsed-time measurement (see search_documents).
        start_time = time.perf_counter()

        results = await vector_service.hybrid_search(
            tenant_id=str(tenant.id),
            query=request.query,
            limit=request.limit,
            score_threshold=request.score_threshold,
            filters=request.filters,
            semantic_weight=request.semantic_weight,
            keyword_weight=request.keyword_weight
        )

        execution_time = (time.perf_counter() - start_time) * 1000

        return SearchResponse(
            results=results,
            total_results=len(results),
            query=request.query,
            search_type="hybrid",
            execution_time_ms=round(execution_time, 2)
        )

    except Exception as e:
        logger.error(f"Hybrid search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Hybrid search failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/chunk-document")
async def chunk_document(
    request: DocumentChunkingRequest,
    current_user: User = Depends(get_current_user),
    # NOTE(review): tenant is resolved via the get_current_user dependency, so
    # this is actually the authenticated User object — confirm intended wiring.
    tenant: Tenant = Depends(get_current_user)
):
    """Chunk a document for vector indexing.

    Runs the tenant-scoped chunking service over the submitted content and
    returns the chunks together with chunking statistics.
    """
    try:
        chunking_service = DocumentChunkingService(tenant)

        produced_chunks = await chunking_service.chunk_document_content(
            document_id=request.document_id,
            content=request.content,
        )

        # Summary stats are computed from the freshly produced chunks.
        chunk_stats = await chunking_service.get_chunk_statistics(produced_chunks)

        return {
            "document_id": request.document_id,
            "chunks": produced_chunks,
            "statistics": chunk_stats,
            "status": "success",
        }

    except Exception as exc:
        logger.error(f"Document chunking failed: {str(exc)}")
        raise HTTPException(status_code=500, detail=f"Document chunking failed: {str(exc)}")
|
||||
|
||||
|
||||
@router.post("/index-document")
async def index_document(
    document_id: str,
    chunks: Dict[str, List[Dict[str, Any]]],
    current_user: User = Depends(get_current_user),
    # NOTE(review): tenant is resolved via get_current_user (see search_documents);
    # confirm the intended dependency provider.
    tenant: Tenant = Depends(get_current_user)
):
    """Index document chunks in the vector database.

    Raises:
        HTTPException 500: if the vector service reports failure or raises.
    """
    try:
        success = await vector_service.add_document_vectors(
            tenant_id=str(tenant.id),
            document_id=document_id,
            chunks=chunks
        )

        if success:
            return {
                "document_id": document_id,
                "status": "indexed",
                "message": "Document successfully indexed in vector database"
            }
        else:
            raise HTTPException(status_code=500, detail="Failed to index document")

    except HTTPException:
        # Bug fix: without this clause, the HTTPException raised above was
        # caught by the generic handler below and re-wrapped with a
        # different detail message.
        raise
    except Exception as e:
        logger.error(f"Document indexing failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Document indexing failed: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/collections/stats")
async def get_collection_statistics(
    collection_type: str = Query("documents", description="Type of collection"),
    current_user: User = Depends(get_current_user),
    # NOTE(review): tenant is resolved via get_current_user (see search_documents);
    # confirm the intended dependency provider.
    tenant: Tenant = Depends(get_current_user)
):
    """Get statistics for a specific collection.

    Raises:
        HTTPException 404: if the collection does not exist.
        HTTPException 500: on any vector-service error.
    """
    try:
        stats = await vector_service.get_collection_stats(
            tenant_id=str(tenant.id),
            collection_type=collection_type
        )

        if stats:
            return stats
        else:
            raise HTTPException(status_code=404, detail="Collection not found")

    except HTTPException:
        # Bug fix: previously the 404 above fell into the generic handler
        # below and was re-raised as a 500.
        raise
    except Exception as e:
        logger.error(f"Failed to get collection stats: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get collection stats: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/performance/metrics", response_model=PerformanceMetricsResponse)
async def get_performance_metrics(
    current_user: User = Depends(get_current_user),
    # NOTE(review): tenant is resolved via get_current_user (see search_documents);
    # confirm the intended dependency provider.
    tenant: Tenant = Depends(get_current_user)
):
    """Get performance metrics for vector database operations.

    Raises:
        HTTPException 500: if the service reports an error or raises.
    """
    try:
        metrics = await vector_service.get_performance_metrics(str(tenant.id))

        # The service signals failure in-band via an "error" key.
        if "error" in metrics:
            raise HTTPException(status_code=500, detail=metrics["error"])

        return PerformanceMetricsResponse(**metrics)

    except HTTPException:
        # Bug fix: keep the in-band error detail instead of re-wrapping it
        # in the generic handler below.
        raise
    except Exception as e:
        logger.error(f"Failed to get performance metrics: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get performance metrics: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/performance/benchmarks", response_model=BenchmarkResponse)
async def create_performance_benchmarks(
    current_user: User = Depends(get_current_user),
    # NOTE(review): tenant is resolved via get_current_user (see search_documents);
    # confirm the intended dependency provider.
    tenant: Tenant = Depends(get_current_user)
):
    """Create performance benchmarks for vector operations.

    Raises:
        HTTPException 500: if the service reports an error or raises.
    """
    try:
        benchmarks = await vector_service.create_performance_benchmarks(str(tenant.id))

        # The service signals failure in-band via an "error" key.
        if "error" in benchmarks:
            raise HTTPException(status_code=500, detail=benchmarks["error"])

        return BenchmarkResponse(**benchmarks)

    except HTTPException:
        # Bug fix: keep the in-band error detail instead of re-wrapping it
        # in the generic handler below.
        raise
    except Exception as e:
        logger.error(f"Failed to create performance benchmarks: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to create performance benchmarks: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/optimize")
async def optimize_collections(
    current_user: User = Depends(get_current_user),
    # NOTE(review): tenant is resolved via get_current_user (see search_documents);
    # confirm the intended dependency provider.
    tenant: Tenant = Depends(get_current_user)
):
    """Optimize vector database collections for performance.

    Raises:
        HTTPException 500: if the service reports an error or raises.
    """
    try:
        optimization_results = await vector_service.optimize_collections(str(tenant.id))

        # The service signals failure in-band via an "error" key.
        if "error" in optimization_results:
            raise HTTPException(status_code=500, detail=optimization_results["error"])

        return {
            "tenant_id": str(tenant.id),
            "optimization_results": optimization_results,
            "status": "optimization_completed"
        }

    except HTTPException:
        # Bug fix: keep the in-band error detail instead of re-wrapping it
        # in the generic handler below.
        raise
    except Exception as e:
        logger.error(f"Collection optimization failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Collection optimization failed: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/documents/{document_id}")
async def delete_document_vectors(
    document_id: str,
    collection_type: str = Query("documents", description="Type of collection"),
    current_user: User = Depends(get_current_user),
    # NOTE(review): tenant is resolved via get_current_user (see search_documents);
    # confirm the intended dependency provider.
    tenant: Tenant = Depends(get_current_user)
):
    """Delete all vectors for a specific document.

    Raises:
        HTTPException 500: if the vector service reports failure or raises.
    """
    try:
        success = await vector_service.delete_document_vectors(
            tenant_id=str(tenant.id),
            document_id=document_id,
            collection_type=collection_type
        )

        if success:
            return {
                "document_id": document_id,
                "status": "deleted",
                "message": "Document vectors successfully deleted"
            }
        else:
            raise HTTPException(status_code=500, detail="Failed to delete document vectors")

    except HTTPException:
        # Bug fix: without this clause, the HTTPException raised above was
        # caught below and re-wrapped with a different detail message.
        raise
    except Exception as e:
        logger.error(f"Failed to delete document vectors: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to delete document vectors: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/health")
async def vector_service_health():
    """Check the health of the vector service.

    Returns a small status payload when healthy; raises HTTP 503 otherwise.
    """
    try:
        is_healthy = await vector_service.health_check()

        if is_healthy:
            return {
                "status": "healthy",
                "service": "vector_database",
                "embedding_model": vector_service.embedding_model.__class__.__name__ if vector_service.embedding_model else "Voyage-3-large API"
            }
        else:
            raise HTTPException(status_code=503, detail="Vector service is unhealthy")

    except HTTPException:
        # Bug fix: preserve the "Vector service is unhealthy" 503 instead of
        # re-wrapping it as a "health check failed" 503 below.
        raise
    except Exception as e:
        logger.error(f"Vector service health check failed: {str(e)}")
        raise HTTPException(status_code=503, detail=f"Vector service health check failed: {str(e)}")
|
||||
@@ -4,7 +4,7 @@ Authentication and authorization service for the Virtual Board Member AI System.
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import HTTPException, Depends, status
|
||||
from fastapi import HTTPException, Depends, status, Request
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
@@ -201,8 +201,14 @@ def require_role(required_role: str):
|
||||
return role_checker
|
||||
|
||||
def require_tenant_access():
    """Require tenant access for the current user.

    Returns a FastAPI dependency that currently only asserts the user is
    active; tenant-specific checks can be layered in later.
    """
    # Bug fix: the original had a second bare string literal after the
    # docstring (diff-merge residue) — a dead statement, removed here.
    def tenant_checker(current_user: User = Depends(get_current_active_user)) -> User:
        # Additional tenant-specific checks can be added here
        return current_user
    return tenant_checker
|
||||
|
||||
# Add get_current_tenant function for compatibility
|
||||
def get_current_tenant(request: Request) -> Optional[str]:
    """Resolve the current tenant ID from the incoming request.

    Compatibility shim: defers entirely to the tenant middleware's resolver,
    which reads the value it previously stashed on the request.
    """
    # Imported lazily to avoid a circular import between auth and middleware.
    from app.middleware.tenant import get_current_tenant as resolve_tenant

    return resolve_tenant(request)
|
||||
|
||||
@@ -51,8 +51,17 @@ class Settings(BaseSettings):
|
||||
# Qdrant vector store connection/collection parameters.
QDRANT_COLLECTION_NAME: str = "board_documents"
QDRANT_VECTOR_SIZE: int = 1024
QDRANT_TIMEOUT: int = 30
# Bug fix (diff-merge residue): the superseded all-MiniLM-L6-v2 assignments
# (dimension 384) were still present above these and were dead last-wins
# code — removed, keeping only the Voyage-3-large values.
EMBEDDING_MODEL: str = "voyageai/voyage-3-large"  # Updated to Voyage-3-large as per Week 3 plan
EMBEDDING_DIMENSION: int = 1024  # Dimension for voyage-3-large
EMBEDDING_BATCH_SIZE: int = 32
EMBEDDING_MAX_LENGTH: int = 512
VOYAGE_API_KEY: Optional[str] = None  # Voyage AI API key for embeddings

# Document Chunking Configuration
CHUNK_SIZE: int = 1200  # Target chunk size in tokens (1000-1500 range)
CHUNK_OVERLAP: int = 200  # Overlap between chunks
CHUNK_MIN_SIZE: int = 100  # Minimum chunk size
CHUNK_MAX_SIZE: int = 1500  # Maximum chunk size

# LLM Configuration (OpenRouter)
OPENROUTER_API_KEY: str = Field(..., description="OpenRouter API key")
|
||||
@@ -179,6 +188,7 @@ class Settings(BaseSettings):
|
||||
# CORS and Security
# NOTE(review): "*" allows any host — acceptable for dev, tighten for prod.
ALLOWED_HOSTS: List[str] = ["*"]
# URL prefix under which all v1 API routes are mounted.
API_V1_STR: str = "/api/v1"
# When True, tenant resolution may use the request subdomain (see tenant middleware).
ENABLE_SUBDOMAIN_TENANTS: bool = False
|
||||
|
||||
@validator("SUPPORTED_FORMATS", pre=True)
|
||||
def parse_supported_formats(cls, v: str) -> str:
|
||||
|
||||
556
app/services/document_chunking.py
Normal file
556
app/services/document_chunking.py
Normal file
@@ -0,0 +1,556 @@
|
||||
"""
|
||||
Document chunking service for the Virtual Board Member AI System.
|
||||
Implements intelligent chunking strategy with support for structured data indexing.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
import json
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.tenant import Tenant
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentChunkingService:
|
||||
"""Service for intelligent document chunking with structured data support."""
|
||||
|
||||
def __init__(self, tenant: Tenant):
    """Bind the service to a tenant and snapshot chunking limits from settings."""
    self.tenant = tenant
    # Chunking parameters come from the CHUNK_* application settings.
    # NOTE(review): CHUNK_SIZE is configured in tokens, but
    # _split_text_into_chunks compares it against a *word* count — confirm.
    self.chunk_size = settings.CHUNK_SIZE
    self.chunk_overlap = settings.CHUNK_OVERLAP
    self.chunk_min_size = settings.CHUNK_MIN_SIZE
    self.chunk_max_size = settings.CHUNK_MAX_SIZE
|
||||
|
||||
async def chunk_document_content(
    self,
    document_id: str,
    content: Dict[str, Any]
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Chunk document content into multiple types of chunks for vector indexing.

    Args:
        document_id: The document ID
        content: Document content with text, tables, charts, etc.
            The recognized keys are "text_content", "tables", and "charts";
            any other keys are ignored.

    Returns:
        Dictionary with different types of chunks (text, tables, charts)
        under "text_chunks"/"table_chunks"/"chart_chunks", plus a "metadata"
        entry with chunking parameters and per-type counts.

    Raises:
        Exception: re-raised after logging. Note the per-type helpers swallow
            their own errors and return [], so a raise here means a failure
            outside those helpers.
    """
    try:
        chunks = {
            "text_chunks": [],
            "table_chunks": [],
            "chart_chunks": [],
            "metadata": {
                "document_id": document_id,
                "tenant_id": str(self.tenant.id),
                # NOTE(review): naive UTC timestamp (no tz offset in the
                # ISO string); utcnow() is deprecated in Python 3.12+.
                "chunking_timestamp": datetime.utcnow().isoformat(),
                "chunk_size": self.chunk_size,
                "chunk_overlap": self.chunk_overlap
            }
        }

        # Process text content
        if content.get("text_content"):
            text_chunks = await self._chunk_text_content(
                document_id, content["text_content"]
            )
            chunks["text_chunks"] = text_chunks

        # Process table content
        if content.get("tables"):
            table_chunks = await self._chunk_table_content(
                document_id, content["tables"]
            )
            chunks["table_chunks"] = table_chunks

        # Process chart content
        if content.get("charts"):
            chart_chunks = await self._chunk_chart_content(
                document_id, content["charts"]
            )
            chunks["chart_chunks"] = chart_chunks

        # Add metadata about chunking results
        chunks["metadata"]["total_chunks"] = (
            len(chunks["text_chunks"]) +
            len(chunks["table_chunks"]) +
            len(chunks["chart_chunks"])
        )
        chunks["metadata"]["text_chunks"] = len(chunks["text_chunks"])
        chunks["metadata"]["table_chunks"] = len(chunks["table_chunks"])
        chunks["metadata"]["chart_chunks"] = len(chunks["chart_chunks"])

        logger.info(f"Chunked document {document_id} into {chunks['metadata']['total_chunks']} chunks")
        return chunks

    except Exception as e:
        logger.error(f"Error chunking document {document_id}: {str(e)}")
        raise
|
||||
|
||||
async def _chunk_text_content(
    self,
    document_id: str,
    text_content: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Chunk text content with intelligent boundaries.

    Concatenates the per-page text items (inserting "--- Page N ---"
    separators), splits the result via _split_text_into_chunks, and wraps
    each piece in a chunk dict carrying page numbers and char positions.
    Returns [] on any error (logged).
    """
    chunks = []

    try:
        # Combine all text content
        full_text = ""
        text_metadata = []

        for i, text_item in enumerate(text_content):
            text = text_item.get("text", "")
            page_num = text_item.get("page_number", i + 1)

            # Add page separator
            if full_text:
                full_text += f"\n\n--- Page {page_num} ---\n\n"

            full_text += text
            # Character span of this page's text within full_text
            # (the separator, if any, sits just before start_pos).
            text_metadata.append({
                "start_pos": len(full_text) - len(text),
                "end_pos": len(full_text),
                "page_number": page_num,
                "original_index": i
            })

        # Split into chunks
        text_chunks = await self._split_text_into_chunks(full_text)

        # Create chunk objects with metadata
        for chunk_idx, (chunk_text, start_pos, end_pos) in enumerate(text_chunks):
            # Find which pages this chunk covers: any page whose span
            # overlaps [start_pos, end_pos]. NOTE(review): the splitter's
            # positions are word-join approximations, so page attribution
            # near boundaries may be off by a few characters.
            chunk_pages = []
            for meta in text_metadata:
                if (meta["start_pos"] <= end_pos and meta["end_pos"] >= start_pos):
                    chunk_pages.append(meta["page_number"])

            chunk = {
                "id": f"{document_id}_text_{chunk_idx}",
                "document_id": document_id,
                "tenant_id": str(self.tenant.id),
                "chunk_type": "text",
                "chunk_index": chunk_idx,
                "text": chunk_text,
                "token_count": await self._estimate_tokens(chunk_text),
                # set() drops duplicates but also loses page ordering.
                "page_numbers": list(set(chunk_pages)),
                "start_position": start_pos,
                "end_position": end_pos,
                "metadata": {
                    "content_type": "text",
                    "chunking_strategy": "semantic_boundaries",
                    "created_at": datetime.utcnow().isoformat()
                }
            }
            chunks.append(chunk)

        return chunks

    except Exception as e:
        logger.error(f"Error chunking text content: {str(e)}")
        return []
|
||||
|
||||
async def _chunk_table_content(
    self,
    document_id: str,
    tables: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Chunk table content with structure preservation.

    Emits one summary chunk per non-empty table (textual description plus
    the raw rows), and for tables with more than 10 rows additionally emits
    per-section chunks via _create_detailed_table_chunks. Returns [] on any
    error (logged).
    """
    chunks = []

    try:
        for table_idx, table in enumerate(tables):
            table_data = table.get("data", [])
            table_metadata = table.get("metadata", {})

            # Tables without rows produce no chunk at all.
            if not table_data:
                continue

            # Create table description
            table_description = await self._create_table_description(table)

            # Create structured table chunk
            table_chunk = {
                "id": f"{document_id}_table_{table_idx}",
                "document_id": document_id,
                "tenant_id": str(self.tenant.id),
                "chunk_type": "table",
                "chunk_index": table_idx,
                # The embedded text is the description, not the raw cells;
                # the raw rows ride along under "table_data".
                "text": table_description,
                "token_count": await self._estimate_tokens(table_description),
                "page_numbers": [table_metadata.get("page_number", 1)],
                "table_data": table_data,
                "table_metadata": table_metadata,
                "metadata": {
                    "content_type": "table",
                    "chunking_strategy": "table_preservation",
                    "table_structure": await self._analyze_table_structure(table_data),
                    "created_at": datetime.utcnow().isoformat()
                }
            }
            chunks.append(table_chunk)

            # If table is large, create additional chunks for detailed analysis
            if len(table_data) > 10:  # Large table
                detailed_chunks = await self._create_detailed_table_chunks(
                    document_id, table_idx, table_data, table_metadata
                )
                chunks.extend(detailed_chunks)

        return chunks

    except Exception as e:
        logger.error(f"Error chunking table content: {str(e)}")
        return []
|
||||
|
||||
async def _chunk_chart_content(
    self,
    document_id: str,
    charts: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Build one vector-ready chunk per chart.

    Each chunk embeds a plain-text summary of the chart and carries the raw
    chart data alongside it. Returns [] on any error (logged).
    """
    produced: List[Dict[str, Any]] = []

    try:
        for idx, chart in enumerate(charts):
            data = chart.get("data", {})
            meta = chart.get("metadata", {})

            # The textual summary is what actually gets embedded.
            summary = await self._create_chart_description(chart)

            produced.append({
                "id": f"{document_id}_chart_{idx}",
                "document_id": document_id,
                "tenant_id": str(self.tenant.id),
                "chunk_type": "chart",
                "chunk_index": idx,
                "text": summary,
                "token_count": await self._estimate_tokens(summary),
                "page_numbers": [meta.get("page_number", 1)],
                "chart_data": data,
                "chart_metadata": meta,
                "metadata": {
                    "content_type": "chart",
                    "chunking_strategy": "chart_analysis",
                    "chart_type": meta.get("chart_type", "unknown"),
                    "created_at": datetime.utcnow().isoformat()
                },
            })

        return produced

    except Exception as e:
        logger.error(f"Error chunking chart content: {str(e)}")
        return []
|
||||
|
||||
async def _split_text_into_chunks(
    self,
    text: str
) -> List[Tuple[str, int, int]]:
    """Split text into chunks with semantic boundaries.

    Returns a list of (chunk_text, start_char, end_char) triples. On error,
    falls back to a single triple covering the whole text.

    NOTE(review): the size check compares a *word* count against
    self.chunk_size, which is configured in tokens (CHUNK_SIZE) — confirm
    intended units. NOTE(review): self.chunk_overlap is never applied in
    this path, so adjacent chunks do not overlap.
    """
    chunks = []

    try:
        # Simple token estimation (words + punctuation)
        words = text.split()
        current_chunk = []
        current_pos = 0
        chunk_start_pos = 0

        for word in words:
            current_chunk.append(word)
            current_pos += len(word) + 1  # +1 for space

            # Check if we've reached chunk size
            if len(current_chunk) >= self.chunk_size:
                chunk_text = " ".join(current_chunk)

                # Try to find a good break point
                break_point = await self._find_semantic_break_point(chunk_text)
                if break_point > 0:
                    # Split at break point
                    first_part = chunk_text[:break_point].strip()
                    second_part = chunk_text[break_point:].strip()

                    if first_part:
                        chunks.append((first_part, chunk_start_pos, chunk_start_pos + len(first_part)))

                    # Start new chunk with remaining text.
                    # NOTE(review): positions are approximate — split() +
                    # " ".join() normalizes whitespace, so char offsets can
                    # drift from the original text when runs of whitespace
                    # or the page separators are present.
                    current_chunk = second_part.split() if second_part else []
                    chunk_start_pos = current_pos - len(second_part) if second_part else current_pos
                else:
                    # No good break point, use current chunk
                    chunks.append((chunk_text, chunk_start_pos, current_pos))
                    current_chunk = []
                    chunk_start_pos = current_pos

        # Add remaining text as final chunk
        if current_chunk:
            chunk_text = " ".join(current_chunk)
            # Always add the final chunk, even if it's small
            chunks.append((chunk_text, chunk_start_pos, current_pos))

        # If no chunks were created and we have text, create a single chunk
        if not chunks and text.strip():
            chunks.append((text.strip(), 0, len(text.strip())))

        return chunks

    except Exception as e:
        logger.error(f"Error splitting text into chunks: {str(e)}")
        return [(text, 0, len(text))]
|
||||
|
||||
async def _find_semantic_break_point(self, text: str) -> int:
    """Return an index just past a natural boundary in *text*, or -1.

    Patterns are tried in priority order (sentence end, paragraph break,
    semicolon, ", and", ", or"); only boundaries in the second half of the
    text qualify, so the resulting first piece is never tiny.
    """
    boundary_patterns = (
        r'\.\s+[A-Z]',   # sentence ending followed by a capital letter
        r'\n\s*\n',      # paragraph break
        r';\s+',         # semicolon
        r',\s+and\s+',   # ", and"
        r',\s+or\s+',    # ", or"
    )

    midpoint = len(text) // 2
    for pattern in boundary_patterns:
        # Walk matches from last to first and take the first one that
        # falls beyond the midpoint.
        for hit in reversed(list(re.finditer(pattern, text))):
            if hit.end() > midpoint:
                return hit.end()

    return -1  # no suitable break point
|
||||
|
||||
async def _create_table_description(self, table: Dict[str, Any]) -> str:
    """Summarize a table as plain text: dimensions, headers, sample row, title.

    Falls back to "Empty table" for tables without rows and to
    "Table content" on any error (logged).
    """
    try:
        rows_data = table.get("data", [])
        meta = table.get("metadata", {})

        if not rows_data:
            return "Empty table"

        n_rows = len(rows_data)
        n_cols = len(rows_data[0]) if rows_data else 0

        parts = [f"Table with {n_rows} rows and {n_cols} columns"]

        # First row is treated as the header row.
        headers = rows_data[0]
        if headers:
            parts.append(f". Columns: {', '.join(str(h) for h in headers[:5])}")
            if len(headers) > 5:
                parts.append(f" and {len(headers) - 5} more")

        # Second row, if present, serves as sample data.
        if n_rows > 1:
            sample_row = rows_data[1]
            if sample_row:
                parts.append(f". Sample data: {', '.join(str(cell) for cell in sample_row[:3])}")

        if meta.get("title"):
            parts.append(f". Title: {meta['title']}")

        return "".join(parts)

    except Exception as e:
        logger.error(f"Error creating table description: {str(e)}")
        return "Table content"
|
||||
|
||||
async def _create_chart_description(self, chart: Dict[str, Any]) -> str:
    """Summarize a chart as plain text: type, title, leading data points.

    Falls back to "Chart content" on any error (logged).
    """
    try:
        data = chart.get("data", {})
        meta = chart.get("metadata", {})

        pieces = ["Chart", f" ({meta.get('chart_type', 'unknown')})"]

        title = meta.get("title")
        if title:
            pieces.append(f": {title}")

        # Describe up to the first three label/value pairs.
        if data and "labels" in data and "values" in data:
            lead_labels = data["labels"][:3]
            lead_values = data["values"][:3]
            pieces.append(
                f". Shows {', '.join(str(l) for l in lead_labels)} "
                f"with values {', '.join(str(v) for v in lead_values)}"
            )
            if len(data["labels"]) > 3:
                pieces.append(f" and {len(data['labels']) - 3} more data points")

        return "".join(pieces)

    except Exception as e:
        logger.error(f"Error creating chart description: {str(e)}")
        return "Chart content"
|
||||
|
||||
async def _analyze_table_structure(self, table_data: List[List[str]]) -> Dict[str, Any]:
    """Derive structural metadata for a table: dimensions and column types.

    The first row is assumed to be the header; types are inferred from the
    remaining rows. Falls back to a {"type": "unknown", ...} stub on error.
    """
    try:
        if not table_data:
            return {"type": "empty", "rows": 0, "columns": 0}

        row_count = len(table_data)
        col_count = len(table_data[0]) if table_data else 0

        inferred_types: List[str] = []
        if table_data and row_count > 1:  # data rows beyond the header exist
            for col in range(col_count):
                # Skip ragged rows that are shorter than this column index.
                samples = [row[col] for row in table_data[1:] if col < len(row)]
                inferred_types.append(await self._infer_column_type(samples))

        return {
            "type": "data_table",
            "rows": row_count,
            "columns": col_count,
            "column_types": inferred_types,
            "has_headers": row_count > 0,
            "has_data": row_count > 1,
        }

    except Exception as e:
        logger.error(f"Error analyzing table structure: {str(e)}")
        return {"type": "unknown", "rows": 0, "columns": 0}
|
||||
|
||||
async def _infer_column_type(self, values: List[str]) -> str:
|
||||
"""Infer the data type of a column."""
|
||||
if not values:
|
||||
return "empty"
|
||||
|
||||
# Check for numeric values
|
||||
numeric_count = 0
|
||||
date_count = 0
|
||||
|
||||
for value in values:
|
||||
if value:
|
||||
# Check for numbers
|
||||
try:
|
||||
float(value.replace(',', '').replace('$', '').replace('%', ''))
|
||||
numeric_count += 1
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Check for dates (simple pattern)
|
||||
if re.match(r'\d{1,2}[/-]\d{1,2}[/-]\d{2,4}', value):
|
||||
date_count += 1
|
||||
|
||||
total = len(values)
|
||||
if numeric_count / total > 0.8:
|
||||
return "numeric"
|
||||
elif date_count / total > 0.5:
|
||||
return "date"
|
||||
else:
|
||||
return "text"
|
||||
|
||||
async def _create_detailed_table_chunks(
|
||||
self,
|
||||
document_id: str,
|
||||
table_idx: int,
|
||||
table_data: List[List[str]],
|
||||
metadata: Dict[str, Any]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Create detailed chunks for large tables."""
|
||||
chunks = []
|
||||
|
||||
try:
|
||||
# Split large tables into sections
|
||||
chunk_size = 10 # rows per chunk
|
||||
for i in range(1, len(table_data), chunk_size): # Skip header row
|
||||
end_idx = min(i + chunk_size, len(table_data))
|
||||
section_data = table_data[i:end_idx]
|
||||
|
||||
# Create section description
|
||||
section_description = f"Table section {i//chunk_size + 1}: Rows {i+1}-{end_idx}"
|
||||
if table_data and len(table_data) > 0:
|
||||
headers = table_data[0]
|
||||
section_description += f". Columns: {', '.join(str(h) for h in headers[:3])}"
|
||||
|
||||
chunk = {
|
||||
"id": f"{document_id}_table_{table_idx}_section_{i//chunk_size + 1}",
|
||||
"document_id": document_id,
|
||||
"tenant_id": str(self.tenant.id),
|
||||
"chunk_type": "table_section",
|
||||
"chunk_index": f"{table_idx}_{i//chunk_size + 1}",
|
||||
"text": section_description,
|
||||
"token_count": await self._estimate_tokens(section_description),
|
||||
"page_numbers": [metadata.get("page_number", 1)],
|
||||
"table_data": section_data,
|
||||
"table_metadata": metadata,
|
||||
"metadata": {
|
||||
"content_type": "table_section",
|
||||
"chunking_strategy": "table_sectioning",
|
||||
"section_index": i//chunk_size + 1,
|
||||
"row_range": f"{i+1}-{end_idx}",
|
||||
"created_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
}
|
||||
chunks.append(chunk)
|
||||
|
||||
return chunks
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating detailed table chunks: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _estimate_tokens(self, text: str) -> int:
|
||||
"""Estimate token count for text."""
|
||||
# Simple estimation: ~4 characters per token
|
||||
return len(text) // 4
|
||||
|
||||
async def get_chunk_statistics(self, chunks: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Any]:
    """Summarize the result of a chunking run.

    Args:
        chunks: Mapping of chunk-group key (e.g. "text_chunks") to the list
            of chunk dicts produced for that group. Non-list values are
            ignored.

    Returns:
        Totals, per-type counts, and the chunking parameters used, or an
        empty dict if aggregation fails.
    """
    try:
        # Filter once up front. BUG FIX: previously the isinstance() guard
        # appeared *after* the inner `for chunk in chunk_list` clause in the
        # token-sum generator, so a non-list value (e.g. None) was iterated
        # before it could be filtered out, raising TypeError.
        chunk_lists = [v for v in chunks.values() if isinstance(v, list)]

        total_chunks = sum(len(chunk_list) for chunk_list in chunk_lists)
        total_tokens = sum(
            chunk.get("token_count", 0)
            for chunk_list in chunk_lists
            for chunk in chunk_list
        )

        # Map chunk-group keys to the actual chunk types found inside them.
        chunk_types: Dict[str, int] = {}
        for chunk_key, chunk_list in chunks.items():
            if isinstance(chunk_list, list) and len(chunk_list) > 0:
                # Extract the actual chunk type from the first chunk.
                actual_type = chunk_list[0].get("chunk_type", chunk_key.replace("_chunks", ""))
                chunk_types[actual_type] = len(chunk_list)

        return {
            "total_chunks": total_chunks,
            "total_tokens": total_tokens,
            "average_tokens_per_chunk": total_tokens / total_chunks if total_chunks > 0 else 0,
            "chunk_types": chunk_types,
            "chunking_parameters": {
                "chunk_size": self.chunk_size,
                "chunk_overlap": self.chunk_overlap,
                "chunk_min_size": self.chunk_min_size,
                "chunk_max_size": self.chunk_max_size
            }
        }

    except Exception as e:
        logger.error(f"Error getting chunk statistics: {str(e)}")
        return {}
|
||||
@@ -1,12 +1,16 @@
|
||||
"""
|
||||
Qdrant vector database service for the Virtual Board Member AI System.
|
||||
Enhanced with Voyage-3-large embeddings and multi-modal support for Week 3.
|
||||
"""
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from qdrant_client import QdrantClient, models
|
||||
from qdrant_client.http import models as rest
|
||||
import numpy as np
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import requests
|
||||
import json
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.tenant import Tenant
|
||||
@@ -19,6 +23,7 @@ class VectorService:
|
||||
def __init__(self):
|
||||
self.client = None
|
||||
self.embedding_model = None
|
||||
self.voyage_api_key = None
|
||||
self._init_client()
|
||||
self._init_embedding_model()
|
||||
|
||||
@@ -36,12 +41,31 @@ class VectorService:
|
||||
self.client = None
|
||||
|
||||
def _init_embedding_model(self):
|
||||
"""Initialize embedding model."""
|
||||
"""Initialize Voyage-3-large embedding model."""
|
||||
try:
|
||||
self.embedding_model = SentenceTransformer(settings.EMBEDDING_MODEL)
|
||||
logger.info(f"Embedding model {settings.EMBEDDING_MODEL} loaded successfully")
|
||||
# For Voyage-3-large, we'll use API calls instead of local model
|
||||
if settings.EMBEDDING_MODEL == "voyageai/voyage-3-large":
|
||||
self.voyage_api_key = settings.VOYAGE_API_KEY
|
||||
if not self.voyage_api_key:
|
||||
logger.warning("Voyage API key not found, falling back to sentence-transformers")
|
||||
self._init_fallback_embedding_model()
|
||||
else:
|
||||
logger.info("Voyage-3-large embedding model configured successfully")
|
||||
else:
|
||||
self._init_fallback_embedding_model()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load embedding model: {e}")
|
||||
logger.error(f"Failed to initialize embedding model: {e}")
|
||||
self._init_fallback_embedding_model()
|
||||
|
||||
def _init_fallback_embedding_model(self):
    """Load the local sentence-transformers model used when Voyage is unavailable.

    On failure the model is set to None; embedding calls then return None
    and log an error instead of crashing.
    """
    try:
        from sentence_transformers import SentenceTransformer
        fallback_model = "sentence-transformers/all-MiniLM-L6-v2"
        self.embedding_model = SentenceTransformer(fallback_model)
        logger.info(f"Fallback embedding model {fallback_model} loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load fallback embedding model: {e}")
        self.embedding_model = None
|
||||
|
||||
def _get_collection_name(self, tenant_id: str, collection_type: str = "documents") -> str:
|
||||
@@ -155,68 +179,151 @@ class VectorService:
|
||||
return False
|
||||
|
||||
async def generate_embedding(self, text: str) -> Optional[List[float]]:
|
||||
"""Generate embedding for text."""
|
||||
if not self.embedding_model:
|
||||
logger.error("Embedding model not available")
|
||||
return None
|
||||
|
||||
"""Generate embedding for text using Voyage-3-large or fallback model."""
|
||||
try:
|
||||
embedding = self.embedding_model.encode(text)
|
||||
return embedding.tolist()
|
||||
# Try Voyage-3-large first
|
||||
if self.voyage_api_key:
|
||||
return await self._generate_voyage_embedding(text)
|
||||
|
||||
# Fallback to sentence-transformers
|
||||
if self.embedding_model:
|
||||
embedding = self.embedding_model.encode(text)
|
||||
return embedding.tolist()
|
||||
|
||||
logger.error("No embedding model available")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate embedding: {e}")
|
||||
return None
|
||||
|
||||
async def _generate_voyage_embedding(self, text: str) -> Optional[List[float]]:
    """Generate an embedding for *text* via the Voyage-3-large HTTP API.

    Returns:
        The embedding vector, or None when the API call fails or the
        response carries no embedding data.
    """
    try:
        url = "https://api.voyageai.com/v1/embeddings"
        headers = {
            "Authorization": f"Bearer {self.voyage_api_key}",
            "Content-Type": "application/json"
        }
        data = {
            "model": "voyage-3-large",
            "input": text,
            "input_type": "query"  # or "document" for longer texts
        }

        # BUG FIX: requests.post is a blocking call inside an async def; run
        # it in a worker thread so it cannot stall the event loop for the
        # full 30s timeout.
        response = await asyncio.to_thread(
            requests.post, url, headers=headers, json=data, timeout=30
        )
        response.raise_for_status()

        result = response.json()
        if "data" in result and len(result["data"]) > 0:
            return result["data"][0]["embedding"]

        logger.error("No embedding data in Voyage API response")
        return None

    except Exception as e:
        logger.error(f"Failed to generate Voyage embedding: {e}")
        return None
|
||||
|
||||
async def generate_batch_embeddings(self, texts: List[str]) -> List[Optional[List[float]]]:
    """Embed a batch of texts, preferring the Voyage API over the local model.

    Returns one embedding (or None) per input text; on total failure every
    slot is None so callers can still zip the result against *texts*.
    """
    try:
        # Prefer the hosted Voyage-3-large model when an API key is configured.
        if self.voyage_api_key:
            return await self._generate_voyage_batch_embeddings(texts)

        # Otherwise fall back to the locally loaded sentence-transformers model.
        if self.embedding_model:
            vectors = self.embedding_model.encode(texts)
            return [vector.tolist() for vector in vectors]

        logger.error("No embedding model available")
        return [None] * len(texts)

    except Exception as e:
        logger.error(f"Failed to generate batch embeddings: {e}")
        return [None] * len(texts)
|
||||
|
||||
async def _generate_voyage_batch_embeddings(self, texts: List[str]) -> List[Optional[List[float]]]:
    """Generate embeddings for a batch of texts via the Voyage-3-large API.

    Returns:
        One embedding per input text, or a list of Nones (same length as
        *texts*) when the call fails or returns no data.
    """
    try:
        url = "https://api.voyageai.com/v1/embeddings"
        headers = {
            "Authorization": f"Bearer {self.voyage_api_key}",
            "Content-Type": "application/json"
        }
        data = {
            "model": "voyage-3-large",
            "input": texts,
            "input_type": "document"  # Use document type for batch processing
        }

        # BUG FIX: requests.post blocks; a batch call may take up to 60s,
        # so run it in a worker thread to keep the event loop responsive.
        response = await asyncio.to_thread(
            requests.post, url, headers=headers, json=data, timeout=60
        )
        response.raise_for_status()

        result = response.json()
        if "data" in result:
            return [item["embedding"] for item in result["data"]]

        logger.error("No embedding data in Voyage API response")
        return [None] * len(texts)

    except Exception as e:
        logger.error(f"Failed to generate Voyage batch embeddings: {e}")
        return [None] * len(texts)
|
||||
|
||||
async def add_document_vectors(
|
||||
self,
|
||||
tenant_id: str,
|
||||
document_id: str,
|
||||
chunks: List[Dict[str, Any]],
|
||||
chunks: Dict[str, List[Dict[str, Any]]],
|
||||
collection_type: str = "documents"
|
||||
) -> bool:
|
||||
"""Add document chunks to vector database."""
|
||||
if not self.client or not self.embedding_model:
|
||||
"""Add document chunks to vector database with batch processing."""
|
||||
if not self.client:
|
||||
logger.error("Qdrant client not available")
|
||||
return False
|
||||
|
||||
try:
|
||||
collection_name = self._get_collection_name(tenant_id, collection_type)
|
||||
|
||||
# Generate embeddings for all chunks
|
||||
points = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Generate embedding
|
||||
embedding = await self.generate_embedding(chunk["text"])
|
||||
if not embedding:
|
||||
continue
|
||||
|
||||
# Create point with metadata
|
||||
point = models.PointStruct(
|
||||
id=f"{document_id}_{i}",
|
||||
vector=embedding,
|
||||
payload={
|
||||
"document_id": document_id,
|
||||
"tenant_id": tenant_id,
|
||||
"chunk_index": i,
|
||||
"text": chunk["text"],
|
||||
"chunk_type": chunk.get("type", "text"),
|
||||
"metadata": chunk.get("metadata", {}),
|
||||
"created_at": chunk.get("created_at")
|
||||
}
|
||||
)
|
||||
points.append(point)
|
||||
# Collect all chunks and their types for single batch processing
|
||||
all_chunks = []
|
||||
chunk_types = []
|
||||
|
||||
if points:
|
||||
# Upsert points in batches
|
||||
batch_size = 100
|
||||
for i in range(0, len(points), batch_size):
|
||||
batch = points[i:i + batch_size]
|
||||
self.client.upsert(
|
||||
collection_name=collection_name,
|
||||
points=batch
|
||||
)
|
||||
# Collect text chunks
|
||||
if "text_chunks" in chunks:
|
||||
all_chunks.extend(chunks["text_chunks"])
|
||||
chunk_types.extend(["text"] * len(chunks["text_chunks"]))
|
||||
|
||||
# Collect table chunks
|
||||
if "table_chunks" in chunks:
|
||||
all_chunks.extend(chunks["table_chunks"])
|
||||
chunk_types.extend(["table"] * len(chunks["table_chunks"]))
|
||||
|
||||
# Collect chart chunks
|
||||
if "chart_chunks" in chunks:
|
||||
all_chunks.extend(chunks["chart_chunks"])
|
||||
chunk_types.extend(["chart"] * len(chunks["chart_chunks"]))
|
||||
|
||||
if all_chunks:
|
||||
# Process all chunks in a single batch
|
||||
all_points = await self._process_all_chunks_batch(
|
||||
document_id, tenant_id, all_chunks, chunk_types
|
||||
)
|
||||
|
||||
logger.info(f"Added {len(points)} vectors to collection {collection_name}")
|
||||
return True
|
||||
if all_points:
|
||||
# Upsert points in batches
|
||||
batch_size = settings.EMBEDDING_BATCH_SIZE
|
||||
for i in range(0, len(all_points), batch_size):
|
||||
batch = all_points[i:i + batch_size]
|
||||
self.client.upsert(
|
||||
collection_name=collection_name,
|
||||
points=batch
|
||||
)
|
||||
|
||||
logger.info(f"Added {len(all_points)} vectors to collection {collection_name}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@@ -224,6 +331,98 @@ class VectorService:
|
||||
logger.error(f"Failed to add document vectors: {e}")
|
||||
return False
|
||||
|
||||
async def _process_all_chunks_batch(
    self,
    document_id: str,
    tenant_id: str,
    chunks: List[Dict[str, Any]],
    chunk_types: List[str]
) -> List[models.PointStruct]:
    """Embed every chunk in one batched call and build Qdrant points.

    Chunks whose embedding could not be generated are silently dropped;
    compare list lengths if completeness matters. Returns [] on error.
    """
    try:
        # Single batched embedding call for the whole document.
        embeddings = await self.generate_batch_embeddings(
            [chunk["text"] for chunk in chunks]
        )

        points = []
        for chunk, embedding, chunk_type in zip(chunks, embeddings, chunk_types):
            if not embedding:
                continue  # embedding failed for this chunk; skip it

            points.append(models.PointStruct(
                id=chunk["id"],
                vector=embedding,
                payload={
                    "document_id": document_id,
                    "tenant_id": tenant_id,
                    "chunk_index": chunk["chunk_index"],
                    "text": chunk["text"],
                    "chunk_type": chunk_type,
                    "token_count": chunk.get("token_count", 0),
                    "page_numbers": chunk.get("page_numbers", []),
                    "metadata": chunk.get("metadata", {}),
                    "created_at": chunk.get("metadata", {}).get(
                        "created_at", datetime.utcnow().isoformat()
                    )
                }
            ))

        return points

    except Exception as e:
        logger.error(f"Failed to process all chunks batch: {e}")
        return []
|
||||
|
||||
async def _process_chunk_batch(
    self,
    document_id: str,
    tenant_id: str,
    chunks: List[Dict[str, Any]],
    chunk_type: str
) -> List[models.PointStruct]:
    """Embed a homogeneous batch of chunks (one chunk_type) into Qdrant points.

    Chunks whose embedding failed are dropped from the returned list;
    returns [] when the batch as a whole fails.
    """
    try:
        # One embedding call for the whole batch.
        embeddings = await self.generate_batch_embeddings(
            [chunk["text"] for chunk in chunks]
        )

        points = []
        for chunk, embedding in zip(chunks, embeddings):
            if not embedding:
                continue  # skip chunks that failed to embed

            points.append(models.PointStruct(
                id=chunk["id"],
                vector=embedding,
                payload={
                    "document_id": document_id,
                    "tenant_id": tenant_id,
                    "chunk_index": chunk["chunk_index"],
                    "text": chunk["text"],
                    "chunk_type": chunk_type,
                    "token_count": chunk.get("token_count", 0),
                    "page_numbers": chunk.get("page_numbers", []),
                    "metadata": chunk.get("metadata", {}),
                    "created_at": chunk.get("metadata", {}).get(
                        "created_at", datetime.utcnow().isoformat()
                    )
                }
            ))

        return points

    except Exception as e:
        logger.error(f"Failed to process {chunk_type} chunk batch: {e}")
        return []
|
||||
|
||||
async def search_similar(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@@ -231,10 +430,11 @@ class VectorService:
|
||||
limit: int = 10,
|
||||
score_threshold: float = 0.7,
|
||||
collection_type: str = "documents",
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
chunk_types: Optional[List[str]] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search for similar vectors."""
|
||||
if not self.client or not self.embedding_model:
|
||||
"""Search for similar vectors with multi-modal support."""
|
||||
if not self.client:
|
||||
return []
|
||||
|
||||
try:
|
||||
@@ -255,6 +455,15 @@ class VectorService:
|
||||
]
|
||||
)
|
||||
|
||||
# Add chunk type filter if specified
|
||||
if chunk_types:
|
||||
search_filter.must.append(
|
||||
models.FieldCondition(
|
||||
key="chunk_type",
|
||||
match=models.MatchAny(any=chunk_types)
|
||||
)
|
||||
)
|
||||
|
||||
# Add additional filters
|
||||
if filters:
|
||||
for key, value in filters.items():
|
||||
@@ -283,7 +492,7 @@ class VectorService:
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
# Format results
|
||||
# Format results with enhanced metadata
|
||||
results = []
|
||||
for point in search_result:
|
||||
results.append({
|
||||
@@ -292,7 +501,10 @@ class VectorService:
|
||||
"payload": point.payload,
|
||||
"text": point.payload.get("text", ""),
|
||||
"document_id": point.payload.get("document_id"),
|
||||
"chunk_type": point.payload.get("chunk_type", "text")
|
||||
"chunk_type": point.payload.get("chunk_type", "text"),
|
||||
"token_count": point.payload.get("token_count", 0),
|
||||
"page_numbers": point.payload.get("page_numbers", []),
|
||||
"metadata": point.payload.get("metadata", {})
|
||||
})
|
||||
|
||||
return results
|
||||
@@ -301,6 +513,192 @@ class VectorService:
|
||||
logger.error(f"Failed to search vectors: {e}")
|
||||
return []
|
||||
|
||||
async def search_structured_data(
    self,
    tenant_id: str,
    query: str,
    data_type: str = "table",  # "table" or "chart"
    limit: int = 10,
    score_threshold: float = 0.7,
    filters: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
    """Convenience wrapper over search_similar() restricted to one chunk type.

    *data_type* selects which structured chunks ("table" or "chart") are
    searched within the tenant's documents collection.
    """
    return await self.search_similar(
        tenant_id=tenant_id,
        query=query,
        limit=limit,
        score_threshold=score_threshold,
        collection_type="documents",
        filters=filters,
        chunk_types=[data_type]
    )
|
||||
|
||||
async def hybrid_search(
    self,
    tenant_id: str,
    query: str,
    limit: int = 10,
    score_threshold: float = 0.7,
    filters: Optional[Dict[str, Any]] = None,
    semantic_weight: float = 0.7,
    keyword_weight: float = 0.3
) -> List[Dict[str, Any]]:
    """Blend semantic and keyword retrieval into a single ranked result list.

    Both searches over-fetch (2x limit) so the weighted re-ranking has
    enough candidates; returns an empty list on failure.
    """
    try:
        # Over-fetch semantic matches with a slightly relaxed threshold.
        semantic_hits = await self.search_similar(
            tenant_id=tenant_id,
            query=query,
            limit=limit * 2,
            score_threshold=score_threshold * 0.8,
            filters=filters
        )

        # Over-fetch simple keyword matches as well.
        keyword_hits = await self._keyword_search(
            tenant_id=tenant_id,
            query=query,
            limit=limit * 2,
            filters=filters
        )

        # Weighted merge + re-rank, then trim to the requested size.
        merged = await self._combine_search_results(
            semantic_hits, keyword_hits, semantic_weight, keyword_weight
        )
        return merged[:limit]

    except Exception as e:
        logger.error(f"Failed to perform hybrid search: {e}")
        return []
|
||||
|
||||
async def _keyword_search(
    self,
    tenant_id: str,
    query: str,
    limit: int = 10,
    filters: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
    """Naive keyword search: scroll the collection and count term hits.

    NOTE(review): this scans up to 1000 points per call; a dedicated
    full-text engine (e.g. Elasticsearch) would be preferable at scale.
    """
    try:
        terms = query.lower().split()
        collection_name = self._get_collection_name(tenant_id, "documents")

        # Restrict the scroll to this tenant, plus any caller-supplied filters.
        conditions = [
            models.FieldCondition(
                key="tenant_id",
                match=models.MatchValue(value=tenant_id)
            )
        ]
        if filters:
            for key, value in filters.items():
                if isinstance(value, list):
                    conditions.append(
                        models.FieldCondition(
                            key=key,
                            match=models.MatchAny(any=value)
                        )
                    )
                else:
                    conditions.append(
                        models.FieldCondition(
                            key=key,
                            match=models.MatchValue(value=value)
                        )
                    )
        search_filter = models.Filter(must=conditions)

        # Pull candidate points and score them client-side.
        candidates = self.client.scroll(
            collection_name=collection_name,
            scroll_filter=search_filter,
            limit=1000,  # Adjust based on your data size
            with_payload=True
        )[0]

        # Score each candidate by the fraction of query terms it contains.
        scored = []
        for point in candidates:
            text = point.payload.get("text", "").lower()
            hits = sum(1 for term in terms if term in text)
            if hits > 0:
                scored.append({
                    "id": point.id,
                    "score": hits / len(terms),  # normalize to [0, 1]
                    "payload": point.payload,
                    "text": point.payload.get("text", ""),
                    "document_id": point.payload.get("document_id"),
                    "chunk_type": point.payload.get("chunk_type", "text"),
                    "token_count": point.payload.get("token_count", 0),
                    "page_numbers": point.payload.get("page_numbers", []),
                    "metadata": point.payload.get("metadata", {})
                })

        # Best matches first, trimmed to the requested size.
        scored.sort(key=lambda item: item["score"], reverse=True)
        return scored[:limit]

    except Exception as e:
        logger.error(f"Failed to perform keyword search: {e}")
        return []
|
||||
|
||||
async def _combine_search_results(
|
||||
self,
|
||||
semantic_results: List[Dict[str, Any]],
|
||||
keyword_results: List[Dict[str, Any]],
|
||||
semantic_weight: float,
|
||||
keyword_weight: float
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Combine and re-rank search results."""
|
||||
try:
|
||||
# Create a map of results by ID
|
||||
combined_map = {}
|
||||
|
||||
# Add semantic results
|
||||
for result in semantic_results:
|
||||
result_id = result["id"]
|
||||
combined_map[result_id] = {
|
||||
**result,
|
||||
"semantic_score": result["score"],
|
||||
"keyword_score": 0.0,
|
||||
"combined_score": result["score"] * semantic_weight
|
||||
}
|
||||
|
||||
# Add keyword results
|
||||
for result in keyword_results:
|
||||
result_id = result["id"]
|
||||
if result_id in combined_map:
|
||||
# Update existing result
|
||||
combined_map[result_id]["keyword_score"] = result["score"]
|
||||
combined_map[result_id]["combined_score"] += result["score"] * keyword_weight
|
||||
else:
|
||||
# Add new result
|
||||
combined_map[result_id] = {
|
||||
**result,
|
||||
"semantic_score": 0.0,
|
||||
"keyword_score": result["score"],
|
||||
"combined_score": result["score"] * keyword_weight
|
||||
}
|
||||
|
||||
# Convert to list and sort by combined score
|
||||
combined_results = list(combined_map.values())
|
||||
combined_results.sort(key=lambda x: x["combined_score"], reverse=True)
|
||||
|
||||
return combined_results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to combine search results: {e}")
|
||||
return semantic_results # Fallback to semantic results
|
||||
|
||||
async def delete_document_vectors(self, tenant_id: str, document_id: str, collection_type: str = "documents") -> bool:
|
||||
"""Delete all vectors for a specific document."""
|
||||
if not self.client:
|
||||
@@ -378,8 +776,8 @@ class VectorService:
|
||||
# Check client connection
|
||||
collections = self.client.get_collections()
|
||||
|
||||
# Check embedding model
|
||||
if not self.embedding_model:
|
||||
# Check embedding model (either Voyage or fallback)
|
||||
if not self.voyage_api_key and not self.embedding_model:
|
||||
return False
|
||||
|
||||
# Test embedding generation
|
||||
@@ -392,6 +790,147 @@ class VectorService:
|
||||
except Exception as e:
|
||||
logger.error(f"Vector service health check failed: {e}")
|
||||
return False
|
||||
|
||||
async def optimize_collections(self, tenant_id: str) -> Dict[str, Any]:
    """Tune Qdrant optimizer settings for every collection of a tenant.

    Each of the documents/tables/charts collections is updated
    independently, so one failure does not abort the others; the
    per-collection outcome is reported in the returned dict.
    """
    try:
        results: Dict[str, Any] = {}

        for collection_type in ["documents", "tables", "charts"]:
            collection_name = self._get_collection_name(tenant_id, collection_type)
            try:
                # Push optimizer settings that favor parallel segment search.
                self.client.update_collection(
                    collection_name=collection_name,
                    optimizers_config=models.OptimizersConfigDiff(
                        default_segment_number=4,  # Increase for better parallelization
                        memmap_threshold=5000,  # Lower threshold for memory mapping
                        vacuum_min_vector_number=1000  # Optimize vacuum threshold
                    )
                )

                info = self.client.get_collection(collection_name)
                results[collection_type] = {
                    "status": "optimized",
                    "vector_count": info.points_count,
                    "segments": info.segments_count,
                    "optimized_at": datetime.utcnow().isoformat()
                }

            except Exception as e:
                logger.warning(f"Failed to optimize collection {collection_name}: {e}")
                results[collection_type] = {
                    "status": "failed",
                    "error": str(e)
                }

        return results

    except Exception as e:
        logger.error(f"Failed to optimize collections: {e}")
        return {"error": str(e)}
|
||||
|
||||
async def get_performance_metrics(self, tenant_id: str) -> Dict[str, Any]:
    """Collect size/health metrics for every collection of a tenant.

    Per-collection failures are recorded inline rather than raised, so a
    single missing collection does not hide metrics for the others.
    """
    try:
        metrics: Dict[str, Any] = {
            "tenant_id": tenant_id,
            "timestamp": datetime.utcnow().isoformat(),
            "collections": {},
            "embedding_model": settings.EMBEDDING_MODEL,
            "embedding_dimension": settings.EMBEDDING_DIMENSION
        }

        for collection_type in ["documents", "tables", "charts"]:
            collection_name = self._get_collection_name(tenant_id, collection_type)
            try:
                info = self.client.get_collection(collection_name)
                # Count only this tenant's vectors within the collection.
                count = self.client.count(
                    collection_name=collection_name,
                    count_filter=models.Filter(
                        must=[
                            models.FieldCondition(
                                key="tenant_id",
                                match=models.MatchValue(value=tenant_id)
                            )
                        ]
                    )
                )

                metrics["collections"][collection_type] = {
                    "vector_count": count.count,
                    "segments": info.segments_count,
                    "status": info.status,
                    "vector_size": info.config.params.vectors.size,
                    "distance": info.config.params.vectors.distance
                }

            except Exception as e:
                logger.warning(f"Failed to get metrics for collection {collection_name}: {e}")
                metrics["collections"][collection_type] = {
                    "error": str(e)
                }

        return metrics

    except Exception as e:
        logger.error(f"Failed to get performance metrics: {e}")
        return {"error": str(e)}
|
||||
|
||||
async def create_performance_benchmarks(self, tenant_id: str) -> Dict[str, Any]:
    """Time single-embedding, batch-embedding, and search operations.

    Returns a timestamped report with millisecond timings; the search
    benchmark is recorded as None when embedding generation failed.
    """
    try:
        import time

        report: Dict[str, Any] = {
            "tenant_id": tenant_id,
            "timestamp": datetime.utcnow().isoformat(),
            "results": {}
        }

        # Time one single-embedding round trip.
        t0 = time.time()
        probe_embedding = await self.generate_embedding(
            "This is a test document for benchmarking purposes."
        )
        single_elapsed = time.time() - t0

        # Time a small batch to measure the per-text amortized cost.
        batch_texts = [f"Test document {i} for batch benchmarking." for i in range(10)]
        t0 = time.time()
        await self.generate_batch_embeddings(batch_texts)
        batch_elapsed = time.time() - t0

        # Only benchmark search when embedding generation actually works.
        search_elapsed = None
        if probe_embedding:
            t0 = time.time()
            await self.search_similar(
                tenant_id=tenant_id,
                query="test query",
                limit=5
            )
            search_elapsed = time.time() - t0

        report["results"] = {
            "single_embedding_time_ms": round(single_elapsed * 1000, 2),
            "batch_embedding_time_ms": round(batch_elapsed * 1000, 2),
            "avg_embedding_per_text_ms": round((batch_elapsed / len(batch_texts)) * 1000, 2),
            "search_time_ms": round(search_elapsed * 1000, 2) if search_elapsed else None,
            "embedding_model": settings.EMBEDDING_MODEL,
            "embedding_dimension": settings.EMBEDDING_DIMENSION
        }

        return report

    except Exception as e:
        logger.error(f"Failed to create performance benchmarks: {e}")
        return {"error": str(e)}
|
||||
|
||||
# Global vector service instance
|
||||
vector_service = VectorService()
|
||||
|
||||
Reference in New Issue
Block a user