"""
Document chunking service for the Virtual Board Member AI System.

Implements intelligent chunking strategy with support for structured data indexing.
"""
|
|
|
|
import logging
|
|
import re
|
|
from typing import List, Dict, Any, Optional, Tuple
|
|
from datetime import datetime
|
|
import uuid
|
|
import json
|
|
|
|
from app.core.config import settings
|
|
from app.models.tenant import Tenant
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DocumentChunkingService:
|
|
"""Service for intelligent document chunking with structured data support."""
|
|
|
|
def __init__(self, tenant: Tenant):
|
|
self.tenant = tenant
|
|
self.chunk_size = settings.CHUNK_SIZE
|
|
self.chunk_overlap = settings.CHUNK_OVERLAP
|
|
self.chunk_min_size = settings.CHUNK_MIN_SIZE
|
|
self.chunk_max_size = settings.CHUNK_MAX_SIZE
|
|
|
|
async def chunk_document_content(
|
|
self,
|
|
document_id: str,
|
|
content: Dict[str, Any]
|
|
) -> Dict[str, List[Dict[str, Any]]]:
|
|
"""
|
|
Chunk document content into multiple types of chunks for vector indexing.
|
|
|
|
Args:
|
|
document_id: The document ID
|
|
content: Document content with text, tables, charts, etc.
|
|
|
|
Returns:
|
|
Dictionary with different types of chunks (text, tables, charts)
|
|
"""
|
|
try:
|
|
chunks = {
|
|
"text_chunks": [],
|
|
"table_chunks": [],
|
|
"chart_chunks": [],
|
|
"metadata": {
|
|
"document_id": document_id,
|
|
"tenant_id": str(self.tenant.id),
|
|
"chunking_timestamp": datetime.utcnow().isoformat(),
|
|
"chunk_size": self.chunk_size,
|
|
"chunk_overlap": self.chunk_overlap
|
|
}
|
|
}
|
|
|
|
# Process text content
|
|
if content.get("text_content"):
|
|
text_chunks = await self._chunk_text_content(
|
|
document_id, content["text_content"]
|
|
)
|
|
chunks["text_chunks"] = text_chunks
|
|
|
|
# Process table content
|
|
if content.get("tables"):
|
|
table_chunks = await self._chunk_table_content(
|
|
document_id, content["tables"]
|
|
)
|
|
chunks["table_chunks"] = table_chunks
|
|
|
|
# Process chart content
|
|
if content.get("charts"):
|
|
chart_chunks = await self._chunk_chart_content(
|
|
document_id, content["charts"]
|
|
)
|
|
chunks["chart_chunks"] = chart_chunks
|
|
|
|
# Add metadata about chunking results
|
|
chunks["metadata"]["total_chunks"] = (
|
|
len(chunks["text_chunks"]) +
|
|
len(chunks["table_chunks"]) +
|
|
len(chunks["chart_chunks"])
|
|
)
|
|
chunks["metadata"]["text_chunks"] = len(chunks["text_chunks"])
|
|
chunks["metadata"]["table_chunks"] = len(chunks["table_chunks"])
|
|
chunks["metadata"]["chart_chunks"] = len(chunks["chart_chunks"])
|
|
|
|
logger.info(f"Chunked document {document_id} into {chunks['metadata']['total_chunks']} chunks")
|
|
return chunks
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error chunking document {document_id}: {str(e)}")
|
|
raise
|
|
|
|
    async def _chunk_text_content(
        self,
        document_id: str,
        text_content: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Chunk text content with intelligent boundaries.

        Args:
            document_id: The document ID the chunks belong to.
            text_content: Page-level text items; each may carry "text" and
                "page_number" keys (page number defaults to 1-based position).

        Returns:
            A list of text chunk dicts ready for vector indexing, or [] if
            chunking fails (best-effort: errors are logged, not raised).
        """
        chunks = []

        try:
            # Combine all text content into one string, recording where each
            # page's text landed so chunks can be mapped back to pages later.
            full_text = ""
            text_metadata = []

            for i, text_item in enumerate(text_content):
                text = text_item.get("text", "")
                page_num = text_item.get("page_number", i + 1)

                # Add page separator (only between pages, not before the first)
                if full_text:
                    full_text += f"\n\n--- Page {page_num} ---\n\n"

                full_text += text
                # start_pos/end_pos are character offsets of this page's text
                # within full_text; the separator is attributed to no page.
                text_metadata.append({
                    "start_pos": len(full_text) - len(text),
                    "end_pos": len(full_text),
                    "page_number": page_num,
                    "original_index": i
                })

            # Split into chunks
            text_chunks = await self._split_text_into_chunks(full_text)

            # Create chunk objects with metadata
            for chunk_idx, (chunk_text, start_pos, end_pos) in enumerate(text_chunks):
                # Find which pages this chunk covers: a page is included when
                # its [start_pos, end_pos] range intersects the chunk's range.
                chunk_pages = []
                for meta in text_metadata:
                    if (meta["start_pos"] <= end_pos and meta["end_pos"] >= start_pos):
                        chunk_pages.append(meta["page_number"])

                chunk = {
                    "id": f"{document_id}_text_{chunk_idx}",
                    "document_id": document_id,
                    "tenant_id": str(self.tenant.id),
                    "chunk_type": "text",
                    "chunk_index": chunk_idx,
                    "text": chunk_text,
                    # ~4-chars-per-token heuristic (see _estimate_tokens)
                    "token_count": await self._estimate_tokens(chunk_text),
                    "page_numbers": list(set(chunk_pages)),
                    # NOTE(review): positions from _split_text_into_chunks are
                    # approximate (whitespace is collapsed) — confirm before
                    # using them for exact highlighting.
                    "start_position": start_pos,
                    "end_position": end_pos,
                    "metadata": {
                        "content_type": "text",
                        "chunking_strategy": "semantic_boundaries",
                        "created_at": datetime.utcnow().isoformat()
                    }
                }
                chunks.append(chunk)

            return chunks

        except Exception as e:
            # Best-effort: a text-chunking failure must not abort the document.
            logger.error(f"Error chunking text content: {str(e)}")
            return []
|
|
|
|
async def _chunk_table_content(
|
|
self,
|
|
document_id: str,
|
|
tables: List[Dict[str, Any]]
|
|
) -> List[Dict[str, Any]]:
|
|
"""Chunk table content with structure preservation."""
|
|
chunks = []
|
|
|
|
try:
|
|
for table_idx, table in enumerate(tables):
|
|
table_data = table.get("data", [])
|
|
table_metadata = table.get("metadata", {})
|
|
|
|
if not table_data:
|
|
continue
|
|
|
|
# Create table description
|
|
table_description = await self._create_table_description(table)
|
|
|
|
# Create structured table chunk
|
|
table_chunk = {
|
|
"id": f"{document_id}_table_{table_idx}",
|
|
"document_id": document_id,
|
|
"tenant_id": str(self.tenant.id),
|
|
"chunk_type": "table",
|
|
"chunk_index": table_idx,
|
|
"text": table_description,
|
|
"token_count": await self._estimate_tokens(table_description),
|
|
"page_numbers": [table_metadata.get("page_number", 1)],
|
|
"table_data": table_data,
|
|
"table_metadata": table_metadata,
|
|
"metadata": {
|
|
"content_type": "table",
|
|
"chunking_strategy": "table_preservation",
|
|
"table_structure": await self._analyze_table_structure(table_data),
|
|
"created_at": datetime.utcnow().isoformat()
|
|
}
|
|
}
|
|
chunks.append(table_chunk)
|
|
|
|
# If table is large, create additional chunks for detailed analysis
|
|
if len(table_data) > 10: # Large table
|
|
detailed_chunks = await self._create_detailed_table_chunks(
|
|
document_id, table_idx, table_data, table_metadata
|
|
)
|
|
chunks.extend(detailed_chunks)
|
|
|
|
return chunks
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error chunking table content: {str(e)}")
|
|
return []
|
|
|
|
async def _chunk_chart_content(
|
|
self,
|
|
document_id: str,
|
|
charts: List[Dict[str, Any]]
|
|
) -> List[Dict[str, Any]]:
|
|
"""Chunk chart content with visual analysis."""
|
|
chunks = []
|
|
|
|
try:
|
|
for chart_idx, chart in enumerate(charts):
|
|
chart_data = chart.get("data", {})
|
|
chart_metadata = chart.get("metadata", {})
|
|
|
|
# Create chart description
|
|
chart_description = await self._create_chart_description(chart)
|
|
|
|
# Create structured chart chunk
|
|
chart_chunk = {
|
|
"id": f"{document_id}_chart_{chart_idx}",
|
|
"document_id": document_id,
|
|
"tenant_id": str(self.tenant.id),
|
|
"chunk_type": "chart",
|
|
"chunk_index": chart_idx,
|
|
"text": chart_description,
|
|
"token_count": await self._estimate_tokens(chart_description),
|
|
"page_numbers": [chart_metadata.get("page_number", 1)],
|
|
"chart_data": chart_data,
|
|
"chart_metadata": chart_metadata,
|
|
"metadata": {
|
|
"content_type": "chart",
|
|
"chunking_strategy": "chart_analysis",
|
|
"chart_type": chart_metadata.get("chart_type", "unknown"),
|
|
"created_at": datetime.utcnow().isoformat()
|
|
}
|
|
}
|
|
chunks.append(chart_chunk)
|
|
|
|
return chunks
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error chunking chart content: {str(e)}")
|
|
return []
|
|
|
|
    async def _split_text_into_chunks(
        self,
        text: str
    ) -> List[Tuple[str, int, int]]:
        """Split text into chunks with semantic boundaries.

        Returns:
            A list of (chunk_text, start_position, end_position) tuples. On
            any internal error, the whole text is returned as a single tuple.
        """
        chunks = []

        try:
            # Simple token estimation (words + punctuation)
            words = text.split()
            current_chunk = []      # words accumulated for the chunk in progress
            current_pos = 0         # running character offset (approximate,
                                    # since split() collapses whitespace)
            chunk_start_pos = 0     # offset where the current chunk began

            for word in words:
                current_chunk.append(word)
                current_pos += len(word) + 1  # +1 for space

                # Check if we've reached chunk size
                # NOTE(review): this compares a *word* count against
                # self.chunk_size, while _estimate_tokens uses ~4 chars/token —
                # confirm which unit CHUNK_SIZE is configured in. The configured
                # chunk_overlap is not applied here either.
                if len(current_chunk) >= self.chunk_size:
                    chunk_text = " ".join(current_chunk)

                    # Try to find a good break point
                    break_point = await self._find_semantic_break_point(chunk_text)
                    if break_point > 0:
                        # Split at break point
                        first_part = chunk_text[:break_point].strip()
                        second_part = chunk_text[break_point:].strip()

                        if first_part:
                            chunks.append((first_part, chunk_start_pos, chunk_start_pos + len(first_part)))

                        # Start new chunk with remaining text; its start offset
                        # is back-computed from the running position.
                        current_chunk = second_part.split() if second_part else []
                        chunk_start_pos = current_pos - len(second_part) if second_part else current_pos
                    else:
                        # No good break point, use current chunk
                        chunks.append((chunk_text, chunk_start_pos, current_pos))
                        current_chunk = []
                        chunk_start_pos = current_pos

            # Add remaining text as final chunk
            if current_chunk:
                chunk_text = " ".join(current_chunk)
                # Always add the final chunk, even if it's small
                chunks.append((chunk_text, chunk_start_pos, current_pos))

            # If no chunks were created and we have text, create a single chunk
            if not chunks and text.strip():
                chunks.append((text.strip(), 0, len(text.strip())))

            return chunks

        except Exception as e:
            # Fallback keeps the pipeline alive with one whole-text chunk.
            logger.error(f"Error splitting text into chunks: {str(e)}")
            return [(text, 0, len(text))]
|
|
|
|
async def _find_semantic_break_point(self, text: str) -> int:
|
|
"""Find a good semantic break point in text."""
|
|
# Look for sentence endings, paragraph breaks, etc.
|
|
break_patterns = [
|
|
r'\.\s+[A-Z]', # Sentence ending followed by capital letter
|
|
r'\n\s*\n', # Paragraph break
|
|
r';\s+', # Semicolon
|
|
r',\s+and\s+', # Comma followed by "and"
|
|
r',\s+or\s+', # Comma followed by "or"
|
|
]
|
|
|
|
for pattern in break_patterns:
|
|
matches = list(re.finditer(pattern, text))
|
|
if matches:
|
|
# Use the last match in the second half of the text
|
|
for match in reversed(matches):
|
|
if match.end() > len(text) // 2:
|
|
return match.end()
|
|
|
|
return -1 # No good break point found
|
|
|
|
async def _create_table_description(self, table: Dict[str, Any]) -> str:
|
|
"""Create a textual description of table content."""
|
|
try:
|
|
table_data = table.get("data", [])
|
|
metadata = table.get("metadata", {})
|
|
|
|
if not table_data:
|
|
return "Empty table"
|
|
|
|
# Get table dimensions
|
|
rows = len(table_data)
|
|
cols = len(table_data[0]) if table_data else 0
|
|
|
|
# Create description
|
|
description = f"Table with {rows} rows and {cols} columns"
|
|
|
|
# Add column headers if available
|
|
if table_data and len(table_data) > 0:
|
|
headers = table_data[0]
|
|
if headers:
|
|
description += f". Columns: {', '.join(str(h) for h in headers[:5])}"
|
|
if len(headers) > 5:
|
|
description += f" and {len(headers) - 5} more"
|
|
|
|
# Add sample data
|
|
if len(table_data) > 1:
|
|
sample_row = table_data[1]
|
|
if sample_row:
|
|
description += f". Sample data: {', '.join(str(cell) for cell in sample_row[:3])}"
|
|
|
|
# Add metadata
|
|
if metadata.get("title"):
|
|
description += f". Title: {metadata['title']}"
|
|
|
|
return description
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error creating table description: {str(e)}")
|
|
return "Table content"
|
|
|
|
async def _create_chart_description(self, chart: Dict[str, Any]) -> str:
|
|
"""Create a textual description of chart content."""
|
|
try:
|
|
chart_data = chart.get("data", {})
|
|
metadata = chart.get("metadata", {})
|
|
|
|
description = "Chart"
|
|
|
|
# Add chart type
|
|
chart_type = metadata.get("chart_type", "unknown")
|
|
description += f" ({chart_type})"
|
|
|
|
# Add title
|
|
if metadata.get("title"):
|
|
description += f": {metadata['title']}"
|
|
|
|
# Add data description
|
|
if chart_data:
|
|
if "labels" in chart_data and "values" in chart_data:
|
|
labels = chart_data["labels"][:3] # First 3 labels
|
|
values = chart_data["values"][:3] # First 3 values
|
|
description += f". Shows {', '.join(str(l) for l in labels)} with values {', '.join(str(v) for v in values)}"
|
|
|
|
if len(chart_data["labels"]) > 3:
|
|
description += f" and {len(chart_data['labels']) - 3} more data points"
|
|
|
|
return description
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error creating chart description: {str(e)}")
|
|
return "Chart content"
|
|
|
|
async def _analyze_table_structure(self, table_data: List[List[str]]) -> Dict[str, Any]:
|
|
"""Analyze table structure for metadata."""
|
|
try:
|
|
if not table_data:
|
|
return {"type": "empty", "rows": 0, "columns": 0}
|
|
|
|
rows = len(table_data)
|
|
cols = len(table_data[0]) if table_data else 0
|
|
|
|
# Analyze column types
|
|
column_types = []
|
|
if table_data and len(table_data) > 1: # Has data beyond headers
|
|
for col_idx in range(cols):
|
|
col_values = [row[col_idx] for row in table_data[1:] if col_idx < len(row)]
|
|
col_type = await self._infer_column_type(col_values)
|
|
column_types.append(col_type)
|
|
|
|
return {
|
|
"type": "data_table",
|
|
"rows": rows,
|
|
"columns": cols,
|
|
"column_types": column_types,
|
|
"has_headers": rows > 0,
|
|
"has_data": rows > 1
|
|
}
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error analyzing table structure: {str(e)}")
|
|
return {"type": "unknown", "rows": 0, "columns": 0}
|
|
|
|
async def _infer_column_type(self, values: List[str]) -> str:
|
|
"""Infer the data type of a column."""
|
|
if not values:
|
|
return "empty"
|
|
|
|
# Check for numeric values
|
|
numeric_count = 0
|
|
date_count = 0
|
|
|
|
for value in values:
|
|
if value:
|
|
# Check for numbers
|
|
try:
|
|
float(value.replace(',', '').replace('$', '').replace('%', ''))
|
|
numeric_count += 1
|
|
except ValueError:
|
|
pass
|
|
|
|
# Check for dates (simple pattern)
|
|
if re.match(r'\d{1,2}[/-]\d{1,2}[/-]\d{2,4}', value):
|
|
date_count += 1
|
|
|
|
total = len(values)
|
|
if numeric_count / total > 0.8:
|
|
return "numeric"
|
|
elif date_count / total > 0.5:
|
|
return "date"
|
|
else:
|
|
return "text"
|
|
|
|
async def _create_detailed_table_chunks(
|
|
self,
|
|
document_id: str,
|
|
table_idx: int,
|
|
table_data: List[List[str]],
|
|
metadata: Dict[str, Any]
|
|
) -> List[Dict[str, Any]]:
|
|
"""Create detailed chunks for large tables."""
|
|
chunks = []
|
|
|
|
try:
|
|
# Split large tables into sections
|
|
chunk_size = 10 # rows per chunk
|
|
for i in range(1, len(table_data), chunk_size): # Skip header row
|
|
end_idx = min(i + chunk_size, len(table_data))
|
|
section_data = table_data[i:end_idx]
|
|
|
|
# Create section description
|
|
section_description = f"Table section {i//chunk_size + 1}: Rows {i+1}-{end_idx}"
|
|
if table_data and len(table_data) > 0:
|
|
headers = table_data[0]
|
|
section_description += f". Columns: {', '.join(str(h) for h in headers[:3])}"
|
|
|
|
chunk = {
|
|
"id": f"{document_id}_table_{table_idx}_section_{i//chunk_size + 1}",
|
|
"document_id": document_id,
|
|
"tenant_id": str(self.tenant.id),
|
|
"chunk_type": "table_section",
|
|
"chunk_index": f"{table_idx}_{i//chunk_size + 1}",
|
|
"text": section_description,
|
|
"token_count": await self._estimate_tokens(section_description),
|
|
"page_numbers": [metadata.get("page_number", 1)],
|
|
"table_data": section_data,
|
|
"table_metadata": metadata,
|
|
"metadata": {
|
|
"content_type": "table_section",
|
|
"chunking_strategy": "table_sectioning",
|
|
"section_index": i//chunk_size + 1,
|
|
"row_range": f"{i+1}-{end_idx}",
|
|
"created_at": datetime.utcnow().isoformat()
|
|
}
|
|
}
|
|
chunks.append(chunk)
|
|
|
|
return chunks
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error creating detailed table chunks: {str(e)}")
|
|
return []
|
|
|
|
async def _estimate_tokens(self, text: str) -> int:
|
|
"""Estimate token count for text."""
|
|
# Simple estimation: ~4 characters per token
|
|
return len(text) // 4
|
|
|
|
async def get_chunk_statistics(self, chunks: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Any]:
|
|
"""Get statistics about the chunking process."""
|
|
try:
|
|
total_chunks = sum(len(chunk_list) for chunk_list in chunks.values() if isinstance(chunk_list, list))
|
|
total_tokens = sum(
|
|
chunk.get("token_count", 0)
|
|
for chunk_list in chunks.values()
|
|
for chunk in chunk_list
|
|
if isinstance(chunk_list, list)
|
|
)
|
|
|
|
# Map chunk keys to actual chunk types
|
|
chunk_types = {}
|
|
for chunk_key, chunk_list in chunks.items():
|
|
if isinstance(chunk_list, list) and len(chunk_list) > 0:
|
|
# Extract the actual chunk type from the first chunk
|
|
actual_type = chunk_list[0].get("chunk_type", chunk_key.replace("_chunks", ""))
|
|
chunk_types[actual_type] = len(chunk_list)
|
|
|
|
return {
|
|
"total_chunks": total_chunks,
|
|
"total_tokens": total_tokens,
|
|
"average_tokens_per_chunk": total_tokens / total_chunks if total_chunks > 0 else 0,
|
|
"chunk_types": chunk_types,
|
|
"chunking_parameters": {
|
|
"chunk_size": self.chunk_size,
|
|
"chunk_overlap": self.chunk_overlap,
|
|
"chunk_min_size": self.chunk_min_size,
|
|
"chunk_max_size": self.chunk_max_size
|
|
}
|
|
}
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error getting chunk statistics: {str(e)}")
|
|
return {}
|