# Source listing: virtual_board_member/app/services/document_chunking.py
# Snapshot 2025-08-08 17:17:56 -04:00 — 557 lines, 22 KiB, Python
"""
Document chunking service for the Virtual Board Member AI System.
Implements intelligent chunking strategy with support for structured data indexing.
"""
import logging
import re
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
import uuid
import json
from app.core.config import settings
from app.models.tenant import Tenant
logger = logging.getLogger(__name__)
class DocumentChunkingService:
    """Service for intelligent document chunking with structured data support."""

    def __init__(self, tenant: Tenant):
        """
        Create a chunker bound to *tenant*.

        The four chunking parameters are read from the global settings object
        at construction time.
        """
        cfg = settings
        self.tenant = tenant
        self.chunk_size = cfg.CHUNK_SIZE
        self.chunk_overlap = cfg.CHUNK_OVERLAP
        self.chunk_min_size = cfg.CHUNK_MIN_SIZE
        self.chunk_max_size = cfg.CHUNK_MAX_SIZE
async def chunk_document_content(
self,
document_id: str,
content: Dict[str, Any]
) -> Dict[str, List[Dict[str, Any]]]:
"""
Chunk document content into multiple types of chunks for vector indexing.
Args:
document_id: The document ID
content: Document content with text, tables, charts, etc.
Returns:
Dictionary with different types of chunks (text, tables, charts)
"""
try:
chunks = {
"text_chunks": [],
"table_chunks": [],
"chart_chunks": [],
"metadata": {
"document_id": document_id,
"tenant_id": str(self.tenant.id),
"chunking_timestamp": datetime.utcnow().isoformat(),
"chunk_size": self.chunk_size,
"chunk_overlap": self.chunk_overlap
}
}
# Process text content
if content.get("text_content"):
text_chunks = await self._chunk_text_content(
document_id, content["text_content"]
)
chunks["text_chunks"] = text_chunks
# Process table content
if content.get("tables"):
table_chunks = await self._chunk_table_content(
document_id, content["tables"]
)
chunks["table_chunks"] = table_chunks
# Process chart content
if content.get("charts"):
chart_chunks = await self._chunk_chart_content(
document_id, content["charts"]
)
chunks["chart_chunks"] = chart_chunks
# Add metadata about chunking results
chunks["metadata"]["total_chunks"] = (
len(chunks["text_chunks"]) +
len(chunks["table_chunks"]) +
len(chunks["chart_chunks"])
)
chunks["metadata"]["text_chunks"] = len(chunks["text_chunks"])
chunks["metadata"]["table_chunks"] = len(chunks["table_chunks"])
chunks["metadata"]["chart_chunks"] = len(chunks["chart_chunks"])
logger.info(f"Chunked document {document_id} into {chunks['metadata']['total_chunks']} chunks")
return chunks
except Exception as e:
logger.error(f"Error chunking document {document_id}: {str(e)}")
raise
async def _chunk_text_content(
self,
document_id: str,
text_content: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Chunk text content with intelligent boundaries."""
chunks = []
try:
# Combine all text content
full_text = ""
text_metadata = []
for i, text_item in enumerate(text_content):
text = text_item.get("text", "")
page_num = text_item.get("page_number", i + 1)
# Add page separator
if full_text:
full_text += f"\n\n--- Page {page_num} ---\n\n"
full_text += text
text_metadata.append({
"start_pos": len(full_text) - len(text),
"end_pos": len(full_text),
"page_number": page_num,
"original_index": i
})
# Split into chunks
text_chunks = await self._split_text_into_chunks(full_text)
# Create chunk objects with metadata
for chunk_idx, (chunk_text, start_pos, end_pos) in enumerate(text_chunks):
# Find which pages this chunk covers
chunk_pages = []
for meta in text_metadata:
if (meta["start_pos"] <= end_pos and meta["end_pos"] >= start_pos):
chunk_pages.append(meta["page_number"])
chunk = {
"id": f"{document_id}_text_{chunk_idx}",
"document_id": document_id,
"tenant_id": str(self.tenant.id),
"chunk_type": "text",
"chunk_index": chunk_idx,
"text": chunk_text,
"token_count": await self._estimate_tokens(chunk_text),
"page_numbers": list(set(chunk_pages)),
"start_position": start_pos,
"end_position": end_pos,
"metadata": {
"content_type": "text",
"chunking_strategy": "semantic_boundaries",
"created_at": datetime.utcnow().isoformat()
}
}
chunks.append(chunk)
return chunks
except Exception as e:
logger.error(f"Error chunking text content: {str(e)}")
return []
async def _chunk_table_content(
self,
document_id: str,
tables: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Chunk table content with structure preservation."""
chunks = []
try:
for table_idx, table in enumerate(tables):
table_data = table.get("data", [])
table_metadata = table.get("metadata", {})
if not table_data:
continue
# Create table description
table_description = await self._create_table_description(table)
# Create structured table chunk
table_chunk = {
"id": f"{document_id}_table_{table_idx}",
"document_id": document_id,
"tenant_id": str(self.tenant.id),
"chunk_type": "table",
"chunk_index": table_idx,
"text": table_description,
"token_count": await self._estimate_tokens(table_description),
"page_numbers": [table_metadata.get("page_number", 1)],
"table_data": table_data,
"table_metadata": table_metadata,
"metadata": {
"content_type": "table",
"chunking_strategy": "table_preservation",
"table_structure": await self._analyze_table_structure(table_data),
"created_at": datetime.utcnow().isoformat()
}
}
chunks.append(table_chunk)
# If table is large, create additional chunks for detailed analysis
if len(table_data) > 10: # Large table
detailed_chunks = await self._create_detailed_table_chunks(
document_id, table_idx, table_data, table_metadata
)
chunks.extend(detailed_chunks)
return chunks
except Exception as e:
logger.error(f"Error chunking table content: {str(e)}")
return []
async def _chunk_chart_content(
self,
document_id: str,
charts: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Chunk chart content with visual analysis."""
chunks = []
try:
for chart_idx, chart in enumerate(charts):
chart_data = chart.get("data", {})
chart_metadata = chart.get("metadata", {})
# Create chart description
chart_description = await self._create_chart_description(chart)
# Create structured chart chunk
chart_chunk = {
"id": f"{document_id}_chart_{chart_idx}",
"document_id": document_id,
"tenant_id": str(self.tenant.id),
"chunk_type": "chart",
"chunk_index": chart_idx,
"text": chart_description,
"token_count": await self._estimate_tokens(chart_description),
"page_numbers": [chart_metadata.get("page_number", 1)],
"chart_data": chart_data,
"chart_metadata": chart_metadata,
"metadata": {
"content_type": "chart",
"chunking_strategy": "chart_analysis",
"chart_type": chart_metadata.get("chart_type", "unknown"),
"created_at": datetime.utcnow().isoformat()
}
}
chunks.append(chart_chunk)
return chunks
except Exception as e:
logger.error(f"Error chunking chart content: {str(e)}")
return []
    async def _split_text_into_chunks(
        self,
        text: str
    ) -> List[Tuple[str, int, int]]:
        """
        Split *text* into chunks, preferring semantic break points.

        Words are accumulated until ``self.chunk_size`` is reached (note: the
        comparison is against the number of buffered words, not characters or
        tokens), then the buffered text is split at a semantic boundary when
        one can be found in its second half.

        Returns:
            List of ``(chunk_text, start_pos, end_pos)`` tuples. Positions are
            offsets into the space-joined word stream (each word counted as
            ``len(word) + 1``), so they approximate — and may differ from —
            offsets into the original *text* when it contains other whitespace.

        NOTE(review): ``self.chunk_overlap`` is configured on this service but
        never applied here — consecutive chunks do not overlap. Confirm
        whether overlap is intended.
        """
        chunks = []
        try:
            # Simple token estimation (words + punctuation)
            words = text.split()
            current_chunk = []   # words buffered for the chunk being built
            current_pos = 0      # running end offset in the joined word stream
            chunk_start_pos = 0  # start offset of the chunk being built
            for word in words:
                current_chunk.append(word)
                current_pos += len(word) + 1  # +1 for the joining space
                # Check if we've reached chunk size
                if len(current_chunk) >= self.chunk_size:
                    chunk_text = " ".join(current_chunk)
                    # Try to find a good break point
                    break_point = await self._find_semantic_break_point(chunk_text)
                    if break_point > 0:
                        # Split at break point
                        first_part = chunk_text[:break_point].strip()
                        second_part = chunk_text[break_point:].strip()
                        if first_part:
                            chunks.append((first_part, chunk_start_pos, chunk_start_pos + len(first_part)))
                        # Start new chunk with remaining text; its start is
                        # back-computed from the running end offset.
                        current_chunk = second_part.split() if second_part else []
                        chunk_start_pos = current_pos - len(second_part) if second_part else current_pos
                    else:
                        # No good break point, use current chunk as-is
                        chunks.append((chunk_text, chunk_start_pos, current_pos))
                        current_chunk = []
                        chunk_start_pos = current_pos
            # Add remaining text as final chunk
            if current_chunk:
                chunk_text = " ".join(current_chunk)
                # Always add the final chunk, even if it's small
                chunks.append((chunk_text, chunk_start_pos, current_pos))
            # If no chunks were created and we have text, create a single chunk
            if not chunks and text.strip():
                chunks.append((text.strip(), 0, len(text.strip())))
            return chunks
        except Exception as e:
            # Never propagate: fall back to treating the whole text as one chunk.
            logger.error(f"Error splitting text into chunks: {str(e)}")
            return [(text, 0, len(text))]
async def _find_semantic_break_point(self, text: str) -> int:
"""Find a good semantic break point in text."""
# Look for sentence endings, paragraph breaks, etc.
break_patterns = [
r'\.\s+[A-Z]', # Sentence ending followed by capital letter
r'\n\s*\n', # Paragraph break
r';\s+', # Semicolon
r',\s+and\s+', # Comma followed by "and"
r',\s+or\s+', # Comma followed by "or"
]
for pattern in break_patterns:
matches = list(re.finditer(pattern, text))
if matches:
# Use the last match in the second half of the text
for match in reversed(matches):
if match.end() > len(text) // 2:
return match.end()
return -1 # No good break point found
async def _create_table_description(self, table: Dict[str, Any]) -> str:
"""Create a textual description of table content."""
try:
table_data = table.get("data", [])
metadata = table.get("metadata", {})
if not table_data:
return "Empty table"
# Get table dimensions
rows = len(table_data)
cols = len(table_data[0]) if table_data else 0
# Create description
description = f"Table with {rows} rows and {cols} columns"
# Add column headers if available
if table_data and len(table_data) > 0:
headers = table_data[0]
if headers:
description += f". Columns: {', '.join(str(h) for h in headers[:5])}"
if len(headers) > 5:
description += f" and {len(headers) - 5} more"
# Add sample data
if len(table_data) > 1:
sample_row = table_data[1]
if sample_row:
description += f". Sample data: {', '.join(str(cell) for cell in sample_row[:3])}"
# Add metadata
if metadata.get("title"):
description += f". Title: {metadata['title']}"
return description
except Exception as e:
logger.error(f"Error creating table description: {str(e)}")
return "Table content"
async def _create_chart_description(self, chart: Dict[str, Any]) -> str:
"""Create a textual description of chart content."""
try:
chart_data = chart.get("data", {})
metadata = chart.get("metadata", {})
description = "Chart"
# Add chart type
chart_type = metadata.get("chart_type", "unknown")
description += f" ({chart_type})"
# Add title
if metadata.get("title"):
description += f": {metadata['title']}"
# Add data description
if chart_data:
if "labels" in chart_data and "values" in chart_data:
labels = chart_data["labels"][:3] # First 3 labels
values = chart_data["values"][:3] # First 3 values
description += f". Shows {', '.join(str(l) for l in labels)} with values {', '.join(str(v) for v in values)}"
if len(chart_data["labels"]) > 3:
description += f" and {len(chart_data['labels']) - 3} more data points"
return description
except Exception as e:
logger.error(f"Error creating chart description: {str(e)}")
return "Chart content"
async def _analyze_table_structure(self, table_data: List[List[str]]) -> Dict[str, Any]:
"""Analyze table structure for metadata."""
try:
if not table_data:
return {"type": "empty", "rows": 0, "columns": 0}
rows = len(table_data)
cols = len(table_data[0]) if table_data else 0
# Analyze column types
column_types = []
if table_data and len(table_data) > 1: # Has data beyond headers
for col_idx in range(cols):
col_values = [row[col_idx] for row in table_data[1:] if col_idx < len(row)]
col_type = await self._infer_column_type(col_values)
column_types.append(col_type)
return {
"type": "data_table",
"rows": rows,
"columns": cols,
"column_types": column_types,
"has_headers": rows > 0,
"has_data": rows > 1
}
except Exception as e:
logger.error(f"Error analyzing table structure: {str(e)}")
return {"type": "unknown", "rows": 0, "columns": 0}
async def _infer_column_type(self, values: List[str]) -> str:
"""Infer the data type of a column."""
if not values:
return "empty"
# Check for numeric values
numeric_count = 0
date_count = 0
for value in values:
if value:
# Check for numbers
try:
float(value.replace(',', '').replace('$', '').replace('%', ''))
numeric_count += 1
except ValueError:
pass
# Check for dates (simple pattern)
if re.match(r'\d{1,2}[/-]\d{1,2}[/-]\d{2,4}', value):
date_count += 1
total = len(values)
if numeric_count / total > 0.8:
return "numeric"
elif date_count / total > 0.5:
return "date"
else:
return "text"
async def _create_detailed_table_chunks(
self,
document_id: str,
table_idx: int,
table_data: List[List[str]],
metadata: Dict[str, Any]
) -> List[Dict[str, Any]]:
"""Create detailed chunks for large tables."""
chunks = []
try:
# Split large tables into sections
chunk_size = 10 # rows per chunk
for i in range(1, len(table_data), chunk_size): # Skip header row
end_idx = min(i + chunk_size, len(table_data))
section_data = table_data[i:end_idx]
# Create section description
section_description = f"Table section {i//chunk_size + 1}: Rows {i+1}-{end_idx}"
if table_data and len(table_data) > 0:
headers = table_data[0]
section_description += f". Columns: {', '.join(str(h) for h in headers[:3])}"
chunk = {
"id": f"{document_id}_table_{table_idx}_section_{i//chunk_size + 1}",
"document_id": document_id,
"tenant_id": str(self.tenant.id),
"chunk_type": "table_section",
"chunk_index": f"{table_idx}_{i//chunk_size + 1}",
"text": section_description,
"token_count": await self._estimate_tokens(section_description),
"page_numbers": [metadata.get("page_number", 1)],
"table_data": section_data,
"table_metadata": metadata,
"metadata": {
"content_type": "table_section",
"chunking_strategy": "table_sectioning",
"section_index": i//chunk_size + 1,
"row_range": f"{i+1}-{end_idx}",
"created_at": datetime.utcnow().isoformat()
}
}
chunks.append(chunk)
return chunks
except Exception as e:
logger.error(f"Error creating detailed table chunks: {str(e)}")
return []
async def _estimate_tokens(self, text: str) -> int:
"""Estimate token count for text."""
# Simple estimation: ~4 characters per token
return len(text) // 4
async def get_chunk_statistics(self, chunks: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Any]:
"""Get statistics about the chunking process."""
try:
total_chunks = sum(len(chunk_list) for chunk_list in chunks.values() if isinstance(chunk_list, list))
total_tokens = sum(
chunk.get("token_count", 0)
for chunk_list in chunks.values()
for chunk in chunk_list
if isinstance(chunk_list, list)
)
# Map chunk keys to actual chunk types
chunk_types = {}
for chunk_key, chunk_list in chunks.items():
if isinstance(chunk_list, list) and len(chunk_list) > 0:
# Extract the actual chunk type from the first chunk
actual_type = chunk_list[0].get("chunk_type", chunk_key.replace("_chunks", ""))
chunk_types[actual_type] = len(chunk_list)
return {
"total_chunks": total_chunks,
"total_tokens": total_tokens,
"average_tokens_per_chunk": total_tokens / total_chunks if total_chunks > 0 else 0,
"chunk_types": chunk_types,
"chunking_parameters": {
"chunk_size": self.chunk_size,
"chunk_overlap": self.chunk_overlap,
"chunk_min_size": self.chunk_min_size,
"chunk_max_size": self.chunk_max_size
}
}
except Exception as e:
logger.error(f"Error getting chunk statistics: {str(e)}")
return {}