- Implement Autonomous Workflow Engine with dynamic task decomposition - Add Multi-Agent Communication Protocol with message routing - Create Enhanced Reasoning Chains (CoT, ToT, Multi-Step, Parallel, Hybrid) - Add comprehensive REST API endpoints for all Week 5 features - Include 26/26 passing tests with full coverage - Add complete documentation and API guides - Update development plan to mark Week 5 as completed Features: - Dynamic task decomposition and parallel execution - Agent registration, messaging, and coordination - 5 reasoning methods with validation and learning - Robust error handling and monitoring - Multi-tenant support and security - Production-ready architecture Files added/modified: - app/services/autonomous_workflow_engine.py - app/services/agent_communication.py - app/services/enhanced_reasoning.py - app/api/v1/endpoints/week5_features.py - tests/test_week5_features.py - docs/week5_api_documentation.md - docs/week5_readme.md - WEEK5_COMPLETION_SUMMARY.md - DEVELOPMENT_PLAN.md (updated) All tests passing: 26/26
896 lines
32 KiB
Python
896 lines
32 KiB
Python
"""
|
|
Enhanced Reasoning Chains - Week 5 Implementation
|
|
Advanced Tree of Thoughts, Chain of Thought, and Multi-Step reasoning with validation and learning.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import Any, Dict, List, Optional, Tuple, Set, Union
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
import uuid
|
|
from datetime import datetime
|
|
import json
|
|
import math
|
|
from collections import defaultdict
|
|
|
|
from app.services.llm_service import llm_service
|
|
from app.core.cache import cache_service
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ReasoningMethod(Enum):
|
|
"""Reasoning methods available."""
|
|
CHAIN_OF_THOUGHT = "chain_of_thought"
|
|
TREE_OF_THOUGHTS = "tree_of_thoughts"
|
|
MULTI_STEP = "multi_step"
|
|
PARALLEL = "parallel"
|
|
HYBRID = "hybrid"
|
|
|
|
|
|
class ThoughtType(Enum):
|
|
"""Types of thoughts in reasoning chains."""
|
|
OBSERVATION = "observation"
|
|
HYPOTHESIS = "hypothesis"
|
|
ANALYSIS = "analysis"
|
|
CONCLUSION = "conclusion"
|
|
VALIDATION = "validation"
|
|
SYNTHESIS = "synthesis"
|
|
|
|
|
|
@dataclass
|
|
class Thought:
|
|
"""Represents a single thought in reasoning."""
|
|
id: str
|
|
content: str
|
|
thought_type: ThoughtType
|
|
confidence: float = 0.0
|
|
parent_id: Optional[str] = None
|
|
children: List[str] = field(default_factory=list)
|
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
created_at: datetime = field(default_factory=datetime.utcnow)
|
|
validation_status: str = "pending"
|
|
|
|
|
|
@dataclass
|
|
class ReasoningChain:
|
|
"""Represents a chain of reasoning steps."""
|
|
id: str
|
|
method: ReasoningMethod
|
|
thoughts: List[Thought] = field(default_factory=list)
|
|
confidence: float = 0.0
|
|
validation_score: float = 0.0
|
|
execution_time: float = 0.0
|
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
created_at: datetime = field(default_factory=datetime.utcnow)
|
|
|
|
|
|
@dataclass
|
|
class ReasoningResult:
|
|
"""Result of reasoning process."""
|
|
chain_id: str
|
|
method: ReasoningMethod
|
|
final_answer: str
|
|
confidence: float
|
|
reasoning_steps: List[Dict[str, Any]]
|
|
validation_metrics: Dict[str, float]
|
|
execution_time: float
|
|
metadata: Dict[str, Any]
|
|
|
|
|
|
class ThoughtTree:
|
|
"""Tree structure for Tree of Thoughts reasoning."""
|
|
|
|
def __init__(self, root_thought: Thought):
|
|
self.root = root_thought
|
|
self.thoughts: Dict[str, Thought] = {root_thought.id: root_thought}
|
|
self.max_depth = 5
|
|
self.max_breadth = 10
|
|
|
|
def add_thought(self, thought: Thought, parent_id: Optional[str] = None) -> None:
|
|
"""Add a thought to the tree."""
|
|
self.thoughts[thought.id] = thought
|
|
|
|
if parent_id:
|
|
thought.parent_id = parent_id
|
|
if parent_id in self.thoughts:
|
|
self.thoughts[parent_id].children.append(thought.id)
|
|
|
|
def get_thoughts_at_depth(self, depth: int) -> List[Thought]:
|
|
"""Get all thoughts at a specific depth."""
|
|
if depth == 0:
|
|
return [self.root]
|
|
|
|
thoughts = []
|
|
for thought in self.thoughts.values():
|
|
if self._get_thought_depth(thought) == depth:
|
|
thoughts.append(thought)
|
|
|
|
return thoughts
|
|
|
|
def _get_thought_depth(self, thought: Thought) -> int:
|
|
"""Get the depth of a thought in the tree."""
|
|
if thought.id == self.root.id:
|
|
return 0
|
|
|
|
if thought.parent_id is None:
|
|
return 0
|
|
|
|
parent = self.thoughts.get(thought.parent_id)
|
|
if parent:
|
|
return self._get_thought_depth(parent) + 1
|
|
|
|
return 0
|
|
|
|
def get_best_path(self) -> List[Thought]:
|
|
"""Get the best reasoning path based on confidence scores."""
|
|
best_path = []
|
|
current_thought = self.root
|
|
|
|
while current_thought:
|
|
best_path.append(current_thought)
|
|
|
|
if not current_thought.children:
|
|
break
|
|
|
|
# Find child with highest confidence
|
|
best_child_id = max(
|
|
current_thought.children,
|
|
key=lambda child_id: self.thoughts[child_id].confidence
|
|
)
|
|
current_thought = self.thoughts[best_child_id]
|
|
|
|
return best_path
|
|
|
|
|
|
class ReasoningValidator:
|
|
"""Validates reasoning chains and thoughts."""
|
|
|
|
def __init__(self):
|
|
self.validation_rules = {
|
|
"logical_consistency": self._validate_logical_consistency,
|
|
"factual_accuracy": self._validate_factual_accuracy,
|
|
"completeness": self._validate_completeness,
|
|
"coherence": self._validate_coherence
|
|
}
|
|
|
|
async def validate_thought(self, thought: Thought, context: Dict[str, Any]) -> Dict[str, float]:
|
|
"""Validate a single thought."""
|
|
validation_scores = {}
|
|
|
|
for rule_name, rule_func in self.validation_rules.items():
|
|
try:
|
|
score = await rule_func(thought, context)
|
|
validation_scores[rule_name] = score
|
|
except Exception as e:
|
|
logger.error(f"Validation rule {rule_name} failed: {e}")
|
|
validation_scores[rule_name] = 0.0
|
|
|
|
return validation_scores
|
|
|
|
async def validate_chain(self, chain: ReasoningChain, context: Dict[str, Any]) -> Dict[str, float]:
|
|
"""Validate an entire reasoning chain."""
|
|
chain_validation = {}
|
|
|
|
# Validate individual thoughts
|
|
thought_validations = []
|
|
for thought in chain.thoughts:
|
|
validation = await self.validate_thought(thought, context)
|
|
thought_validations.append(validation)
|
|
thought.validation_status = "validated"
|
|
|
|
# Aggregate validation scores
|
|
if thought_validations:
|
|
for rule_name in self.validation_rules.keys():
|
|
scores = [v.get(rule_name, 0.0) for v in thought_validations]
|
|
chain_validation[rule_name] = sum(scores) / len(scores)
|
|
|
|
# Overall chain validation
|
|
chain_validation["overall"] = sum(chain_validation.values()) / len(chain_validation)
|
|
|
|
return chain_validation
|
|
|
|
async def _validate_logical_consistency(self, thought: Thought, context: Dict[str, Any]) -> float:
|
|
"""Validate logical consistency of a thought."""
|
|
prompt = f"""
|
|
Analyze the logical consistency of the following thought:
|
|
|
|
Thought: {thought.content}
|
|
Context: {context.get('query', '')}
|
|
|
|
Rate the logical consistency from 0.0 to 1.0, where:
|
|
0.0 = Completely illogical or contradictory
|
|
1.0 = Perfectly logical and consistent
|
|
|
|
Provide only the numerical score:
|
|
"""
|
|
|
|
try:
|
|
response = await llm_service.generate_text(
|
|
prompt=prompt,
|
|
tenant_id=context.get('tenant_id', 'default'),
|
|
task="validation",
|
|
max_tokens=10
|
|
)
|
|
|
|
# Extract score from response
|
|
score_text = response.get('text', '0.5').strip()
|
|
try:
|
|
score = float(score_text)
|
|
return max(0.0, min(1.0, score))
|
|
except ValueError:
|
|
return 0.5
|
|
|
|
except Exception as e:
|
|
logger.error(f"Logical consistency validation failed: {e}")
|
|
return 0.5
|
|
|
|
async def _validate_factual_accuracy(self, thought: Thought, context: Dict[str, Any]) -> float:
|
|
"""Validate factual accuracy of a thought."""
|
|
prompt = f"""
|
|
Assess the factual accuracy of the following thought based on the provided context:
|
|
|
|
Thought: {thought.content}
|
|
Context: {context.get('context_data', '')}
|
|
|
|
Rate the factual accuracy from 0.0 to 1.0, where:
|
|
0.0 = Completely inaccurate or false
|
|
1.0 = Completely accurate and factual
|
|
|
|
Provide only the numerical score:
|
|
"""
|
|
|
|
try:
|
|
response = await llm_service.generate_text(
|
|
prompt=prompt,
|
|
tenant_id=context.get('tenant_id', 'default'),
|
|
task="validation",
|
|
max_tokens=10
|
|
)
|
|
|
|
score_text = response.get('text', '0.5').strip()
|
|
try:
|
|
score = float(score_text)
|
|
return max(0.0, min(1.0, score))
|
|
except ValueError:
|
|
return 0.5
|
|
|
|
except Exception as e:
|
|
logger.error(f"Factual accuracy validation failed: {e}")
|
|
return 0.5
|
|
|
|
async def _validate_completeness(self, thought: Thought, context: Dict[str, Any]) -> float:
|
|
"""Validate completeness of a thought."""
|
|
# Simple heuristic-based validation
|
|
content_length = len(thought.content)
|
|
has_numbers = any(char.isdigit() for char in thought.content)
|
|
has_keywords = any(keyword in thought.content.lower() for keyword in ['because', 'therefore', 'however', 'although'])
|
|
|
|
score = 0.0
|
|
if content_length > 50:
|
|
score += 0.3
|
|
if has_numbers:
|
|
score += 0.2
|
|
if has_keywords:
|
|
score += 0.2
|
|
if thought.confidence > 0.7:
|
|
score += 0.3
|
|
|
|
return min(1.0, score)
|
|
|
|
async def _validate_coherence(self, thought: Thought, context: Dict[str, Any]) -> float:
|
|
"""Validate coherence of a thought."""
|
|
# Simple coherence check
|
|
sentences = thought.content.split('.')
|
|
if len(sentences) <= 1:
|
|
return 0.8 # Single sentence is usually coherent
|
|
|
|
# Check for logical connectors
|
|
connectors = ['and', 'but', 'or', 'because', 'therefore', 'however', 'although', 'while']
|
|
has_connectors = any(connector in thought.content.lower() for connector in connectors)
|
|
|
|
return 0.9 if has_connectors else 0.6
|
|
|
|
|
|
class EnhancedReasoningEngine:
|
|
"""Main engine for enhanced reasoning capabilities."""
|
|
|
|
def __init__(self):
|
|
self.validator = ReasoningValidator()
|
|
self.reasoning_history: List[ReasoningChain] = []
|
|
self.learning_data: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
|
|
|
|
async def reason(
|
|
self,
|
|
query: str,
|
|
context: Dict[str, Any],
|
|
method: Union[ReasoningMethod, str] = ReasoningMethod.CHAIN_OF_THOUGHT,
|
|
max_steps: int = 10
|
|
) -> ReasoningResult:
|
|
"""Perform reasoning using the specified method."""
|
|
start_time = datetime.utcnow()
|
|
|
|
# Handle string method input
|
|
if isinstance(method, str):
|
|
try:
|
|
method = ReasoningMethod(method)
|
|
except ValueError:
|
|
raise ValueError(f"Unknown reasoning method: {method}")
|
|
|
|
try:
|
|
if method == ReasoningMethod.CHAIN_OF_THOUGHT:
|
|
chain = await self._chain_of_thought_reasoning(query, context, max_steps)
|
|
elif method == ReasoningMethod.TREE_OF_THOUGHTS:
|
|
chain = await self._tree_of_thoughts_reasoning(query, context, max_steps)
|
|
elif method == ReasoningMethod.MULTI_STEP:
|
|
chain = await self._multi_step_reasoning(query, context, max_steps)
|
|
elif method == ReasoningMethod.PARALLEL:
|
|
chain = await self._parallel_reasoning(query, context, max_steps)
|
|
elif method == ReasoningMethod.HYBRID:
|
|
chain = await self._hybrid_reasoning(query, context, max_steps)
|
|
else:
|
|
raise ValueError(f"Unknown reasoning method: {method}")
|
|
|
|
# Validate the reasoning chain
|
|
validation_metrics = await self.validator.validate_chain(chain, context)
|
|
chain.validation_score = validation_metrics.get("overall", 0.0)
|
|
|
|
# Calculate execution time
|
|
execution_time = (datetime.utcnow() - start_time).total_seconds()
|
|
chain.execution_time = execution_time
|
|
|
|
# Store in history
|
|
self.reasoning_history.append(chain)
|
|
|
|
# Extract final answer
|
|
final_answer = self._extract_final_answer(chain)
|
|
|
|
# Create result
|
|
result = ReasoningResult(
|
|
chain_id=chain.id,
|
|
method=method,
|
|
final_answer=final_answer,
|
|
confidence=chain.confidence,
|
|
reasoning_steps=[self._thought_to_dict(t) for t in chain.thoughts],
|
|
validation_metrics=validation_metrics,
|
|
execution_time=execution_time,
|
|
metadata=chain.metadata
|
|
)
|
|
|
|
# Learn from this reasoning session
|
|
await self._learn_from_reasoning(chain, result, context)
|
|
|
|
return result
|
|
|
|
except Exception as e:
|
|
logger.error(f"Reasoning failed: {e}")
|
|
# Return fallback result
|
|
return ReasoningResult(
|
|
chain_id=str(uuid.uuid4()),
|
|
method=method,
|
|
final_answer=f"Reasoning failed: {str(e)}",
|
|
confidence=0.0,
|
|
reasoning_steps=[],
|
|
validation_metrics={},
|
|
execution_time=(datetime.utcnow() - start_time).total_seconds(),
|
|
metadata={"error": str(e)}
|
|
)
|
|
|
|
async def _chain_of_thought_reasoning(
|
|
self,
|
|
query: str,
|
|
context: Dict[str, Any],
|
|
max_steps: int
|
|
) -> ReasoningChain:
|
|
"""Perform Chain of Thought reasoning."""
|
|
chain = ReasoningChain(
|
|
id=str(uuid.uuid4()),
|
|
method=ReasoningMethod.CHAIN_OF_THOUGHT
|
|
)
|
|
|
|
current_thought = Thought(
|
|
id=str(uuid.uuid4()),
|
|
content=f"Starting analysis of: {query}",
|
|
thought_type=ThoughtType.OBSERVATION,
|
|
confidence=1.0
|
|
)
|
|
chain.thoughts.append(current_thought)
|
|
|
|
for step in range(max_steps):
|
|
# Generate next thought
|
|
next_thought_content = await self._generate_next_thought(
|
|
query, context, chain.thoughts, "chain_of_thought"
|
|
)
|
|
|
|
if not next_thought_content or "conclusion" in next_thought_content.lower():
|
|
break
|
|
|
|
# Create new thought
|
|
thought_type = self._determine_thought_type(next_thought_content, step)
|
|
confidence = await self._estimate_confidence(next_thought_content, context)
|
|
|
|
next_thought = Thought(
|
|
id=str(uuid.uuid4()),
|
|
content=next_thought_content,
|
|
thought_type=thought_type,
|
|
confidence=confidence,
|
|
parent_id=current_thought.id
|
|
)
|
|
|
|
chain.thoughts.append(next_thought)
|
|
current_thought = next_thought
|
|
|
|
# Calculate overall confidence
|
|
chain.confidence = sum(t.confidence for t in chain.thoughts) / len(chain.thoughts)
|
|
|
|
return chain
|
|
|
|
async def _tree_of_thoughts_reasoning(
|
|
self,
|
|
query: str,
|
|
context: Dict[str, Any],
|
|
max_steps: int
|
|
) -> ReasoningChain:
|
|
"""Perform Tree of Thoughts reasoning."""
|
|
# Create root thought
|
|
root_thought = Thought(
|
|
id=str(uuid.uuid4()),
|
|
content=f"Analyzing: {query}",
|
|
thought_type=ThoughtType.OBSERVATION,
|
|
confidence=1.0
|
|
)
|
|
|
|
tree = ThoughtTree(root_thought)
|
|
chain = ReasoningChain(
|
|
id=str(uuid.uuid4()),
|
|
method=ReasoningMethod.TREE_OF_THOUGHTS
|
|
)
|
|
|
|
# Expand tree
|
|
for depth in range(tree.max_depth):
|
|
current_thoughts = tree.get_thoughts_at_depth(depth)
|
|
|
|
for thought in current_thoughts:
|
|
if depth < tree.max_depth - 1:
|
|
# Generate multiple child thoughts
|
|
child_thoughts = await self._generate_child_thoughts(
|
|
query, context, thought, tree.max_breadth
|
|
)
|
|
|
|
for child_content in child_thoughts:
|
|
child_thought = Thought(
|
|
id=str(uuid.uuid4()),
|
|
content=child_content,
|
|
thought_type=self._determine_thought_type(child_content, depth + 1),
|
|
confidence=await self._estimate_confidence(child_content, context),
|
|
parent_id=thought.id
|
|
)
|
|
tree.add_thought(child_thought, thought.id)
|
|
|
|
# Evaluate and prune if needed
|
|
if depth > 0:
|
|
await self._evaluate_and_prune_tree(tree, depth)
|
|
|
|
# Get best path
|
|
best_path = tree.get_best_path()
|
|
chain.thoughts = best_path
|
|
chain.confidence = sum(t.confidence for t in best_path) / len(best_path)
|
|
|
|
return chain
|
|
|
|
async def _multi_step_reasoning(
|
|
self,
|
|
query: str,
|
|
context: Dict[str, Any],
|
|
max_steps: int
|
|
) -> ReasoningChain:
|
|
"""Perform Multi-Step reasoning with validation at each step."""
|
|
chain = ReasoningChain(
|
|
id=str(uuid.uuid4()),
|
|
method=ReasoningMethod.MULTI_STEP
|
|
)
|
|
|
|
current_thought = Thought(
|
|
id=str(uuid.uuid4()),
|
|
content=f"Starting multi-step analysis of: {query}",
|
|
thought_type=ThoughtType.OBSERVATION,
|
|
confidence=1.0
|
|
)
|
|
chain.thoughts.append(current_thought)
|
|
|
|
for step in range(max_steps):
|
|
# Generate next step
|
|
next_thought_content = await self._generate_next_thought(
|
|
query, context, chain.thoughts, "multi_step"
|
|
)
|
|
|
|
if not next_thought_content:
|
|
break
|
|
|
|
# Create thought
|
|
thought_type = self._determine_thought_type(next_thought_content, step)
|
|
confidence = await self._estimate_confidence(next_thought_content, context)
|
|
|
|
next_thought = Thought(
|
|
id=str(uuid.uuid4()),
|
|
content=next_thought_content,
|
|
thought_type=thought_type,
|
|
confidence=confidence,
|
|
parent_id=current_thought.id
|
|
)
|
|
|
|
# Validate this step
|
|
validation = await self.validator.validate_thought(next_thought, context)
|
|
if validation.get("overall", 0.0) < 0.3: # Low validation score
|
|
logger.warning(f"Step {step} failed validation, stopping")
|
|
break
|
|
|
|
chain.thoughts.append(next_thought)
|
|
current_thought = next_thought
|
|
|
|
# Calculate overall confidence
|
|
chain.confidence = sum(t.confidence for t in chain.thoughts) / len(chain.thoughts)
|
|
|
|
return chain
|
|
|
|
async def _parallel_reasoning(
|
|
self,
|
|
query: str,
|
|
context: Dict[str, Any],
|
|
max_steps: int
|
|
) -> ReasoningChain:
|
|
"""Perform parallel reasoning with multiple approaches."""
|
|
chain = ReasoningChain(
|
|
id=str(uuid.uuid4()),
|
|
method=ReasoningMethod.PARALLEL
|
|
)
|
|
|
|
# Generate multiple parallel thoughts
|
|
parallel_prompts = [
|
|
f"Analyze {query} from a logical perspective",
|
|
f"Analyze {query} from a creative perspective",
|
|
f"Analyze {query} from a critical perspective",
|
|
f"Analyze {query} from a practical perspective"
|
|
]
|
|
|
|
parallel_tasks = []
|
|
for prompt in parallel_prompts:
|
|
task = self._generate_parallel_thought(prompt, context)
|
|
parallel_tasks.append(task)
|
|
|
|
# Execute in parallel
|
|
parallel_results = await asyncio.gather(*parallel_tasks, return_exceptions=True)
|
|
|
|
# Create thoughts from results
|
|
for i, result in enumerate(parallel_results):
|
|
if isinstance(result, Exception):
|
|
logger.error(f"Parallel reasoning task {i} failed: {result}")
|
|
continue
|
|
|
|
thought = Thought(
|
|
id=str(uuid.uuid4()),
|
|
content=result,
|
|
thought_type=ThoughtType.ANALYSIS,
|
|
confidence=await self._estimate_confidence(result, context)
|
|
)
|
|
chain.thoughts.append(thought)
|
|
|
|
# Synthesize parallel results
|
|
synthesis = await self._synthesize_parallel_results(chain.thoughts, query, context)
|
|
synthesis_thought = Thought(
|
|
id=str(uuid.uuid4()),
|
|
content=synthesis,
|
|
thought_type=ThoughtType.SYNTHESIS,
|
|
confidence=await self._estimate_confidence(synthesis, context)
|
|
)
|
|
chain.thoughts.append(synthesis_thought)
|
|
|
|
# Calculate overall confidence
|
|
chain.confidence = sum(t.confidence for t in chain.thoughts) / len(chain.thoughts)
|
|
|
|
return chain
|
|
|
|
async def _hybrid_reasoning(
|
|
self,
|
|
query: str,
|
|
context: Dict[str, Any],
|
|
max_steps: int
|
|
) -> ReasoningChain:
|
|
"""Perform hybrid reasoning combining multiple methods."""
|
|
# Start with Chain of Thought
|
|
cot_chain = await self._chain_of_thought_reasoning(query, context, max_steps // 2)
|
|
|
|
# Add Tree of Thoughts exploration
|
|
tot_chain = await self._tree_of_thoughts_reasoning(query, context, max_steps // 2)
|
|
|
|
# Combine results
|
|
hybrid_chain = ReasoningChain(
|
|
id=str(uuid.uuid4()),
|
|
method=ReasoningMethod.HYBRID
|
|
)
|
|
|
|
# Add CoT thoughts
|
|
hybrid_chain.thoughts.extend(cot_chain.thoughts)
|
|
|
|
# Add ToT thoughts (avoiding duplicates)
|
|
for tot_thought in tot_chain.thoughts:
|
|
if not any(t.content == tot_thought.content for t in hybrid_chain.thoughts):
|
|
hybrid_chain.thoughts.append(tot_thought)
|
|
|
|
# Synthesize hybrid results
|
|
synthesis = await self._synthesize_hybrid_results(hybrid_chain.thoughts, query, context)
|
|
synthesis_thought = Thought(
|
|
id=str(uuid.uuid4()),
|
|
content=synthesis,
|
|
thought_type=ThoughtType.SYNTHESIS,
|
|
confidence=await self._estimate_confidence(synthesis, context)
|
|
)
|
|
hybrid_chain.thoughts.append(synthesis_thought)
|
|
|
|
# Calculate overall confidence
|
|
hybrid_chain.confidence = sum(t.confidence for t in hybrid_chain.thoughts) / len(hybrid_chain.thoughts)
|
|
|
|
return hybrid_chain
|
|
|
|
async def _generate_next_thought(
|
|
self,
|
|
query: str,
|
|
context: Dict[str, Any],
|
|
previous_thoughts: List[Thought],
|
|
method: str
|
|
) -> str:
|
|
"""Generate the next thought in the reasoning chain."""
|
|
thought_history = "\n".join([f"Step {i+1}: {t.content}" for i, t in enumerate(previous_thoughts)])
|
|
|
|
prompt = f"""
|
|
Continue the reasoning process for the following query:
|
|
|
|
Query: {query}
|
|
Context: {context.get('context_data', '')}
|
|
|
|
Previous thoughts:
|
|
{thought_history}
|
|
|
|
Generate the next logical step in the reasoning process. Be specific and analytical.
|
|
If you reach a conclusion, indicate it clearly.
|
|
|
|
Next step:
|
|
"""
|
|
|
|
try:
|
|
response = await llm_service.generate_text(
|
|
prompt=prompt,
|
|
tenant_id=context.get('tenant_id', 'default'),
|
|
task="reasoning",
|
|
max_tokens=200
|
|
)
|
|
|
|
return response.get('text', '').strip()
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to generate next thought: {e}")
|
|
return ""
|
|
|
|
async def _generate_child_thoughts(
|
|
self,
|
|
query: str,
|
|
context: Dict[str, Any],
|
|
parent_thought: Thought,
|
|
max_children: int
|
|
) -> List[str]:
|
|
"""Generate child thoughts for Tree of Thoughts."""
|
|
prompt = f"""
|
|
For the following query and parent thought, generate {max_children} different approaches or perspectives:
|
|
|
|
Query: {query}
|
|
Parent thought: {parent_thought.content}
|
|
|
|
Generate {max_children} different reasoning paths or perspectives. Each should be distinct and valuable.
|
|
|
|
Responses:
|
|
"""
|
|
|
|
try:
|
|
response = await llm_service.generate_text(
|
|
prompt=prompt,
|
|
tenant_id=context.get('tenant_id', 'default'),
|
|
task="reasoning",
|
|
max_tokens=400
|
|
)
|
|
|
|
# Parse multiple thoughts from response
|
|
content = response.get('text', '')
|
|
thoughts = [t.strip() for t in content.split('\n') if t.strip()]
|
|
|
|
return thoughts[:max_children]
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to generate child thoughts: {e}")
|
|
return []
|
|
|
|
async def _estimate_confidence(self, content: str, context: Dict[str, Any]) -> float:
|
|
"""Estimate confidence in a thought."""
|
|
# Simple heuristic-based confidence estimation
|
|
confidence = 0.5 # Base confidence
|
|
|
|
# Factors that increase confidence
|
|
if len(content) > 100:
|
|
confidence += 0.1
|
|
if any(word in content.lower() for word in ['because', 'therefore', 'evidence', 'data']):
|
|
confidence += 0.1
|
|
if any(char.isdigit() for char in content):
|
|
confidence += 0.1
|
|
if content.endswith('.') or content.endswith('!'):
|
|
confidence += 0.05
|
|
|
|
return min(1.0, confidence)
|
|
|
|
def _determine_thought_type(self, content: str, step: int) -> ThoughtType:
|
|
"""Determine the type of a thought based on content and step."""
|
|
content_lower = content.lower()
|
|
|
|
if step == 0:
|
|
return ThoughtType.OBSERVATION
|
|
elif any(word in content_lower for word in ['conclude', 'therefore', 'thus', 'result']):
|
|
return ThoughtType.CONCLUSION
|
|
elif any(word in content_lower for word in ['because', 'since', 'as', 'due to']):
|
|
return ThoughtType.ANALYSIS
|
|
elif any(word in content_lower for word in ['if', 'suppose', 'assume', 'hypothesis']):
|
|
return ThoughtType.HYPOTHESIS
|
|
elif any(word in content_lower for word in ['validate', 'check', 'verify', 'confirm']):
|
|
return ThoughtType.VALIDATION
|
|
else:
|
|
return ThoughtType.ANALYSIS
|
|
|
|
async def _evaluate_and_prune_tree(self, tree: ThoughtTree, depth: int) -> None:
|
|
"""Evaluate and prune the tree at a given depth."""
|
|
thoughts_at_depth = tree.get_thoughts_at_depth(depth)
|
|
|
|
# Sort by confidence and keep top thoughts
|
|
thoughts_at_depth.sort(key=lambda t: t.confidence, reverse=True)
|
|
|
|
# Keep only top thoughts (simple pruning)
|
|
for thought in thoughts_at_depth[tree.max_breadth:]:
|
|
if thought.id in tree.thoughts:
|
|
del tree.thoughts[thought.id]
|
|
|
|
async def _generate_parallel_thought(self, prompt: str, context: Dict[str, Any]) -> str:
|
|
"""Generate a thought for parallel reasoning."""
|
|
try:
|
|
response = await llm_service.generate_text(
|
|
prompt=prompt,
|
|
tenant_id=context.get('tenant_id', 'default'),
|
|
task="reasoning",
|
|
max_tokens=150
|
|
)
|
|
|
|
return response.get('text', '').strip()
|
|
|
|
except Exception as e:
|
|
logger.error(f"Parallel thought generation failed: {e}")
|
|
return ""
|
|
|
|
async def _synthesize_parallel_results(
|
|
self,
|
|
thoughts: List[Thought],
|
|
query: str,
|
|
context: Dict[str, Any]
|
|
) -> str:
|
|
"""Synthesize results from parallel reasoning."""
|
|
thought_contents = "\n".join([f"- {t.content}" for t in thoughts])
|
|
|
|
prompt = f"""
|
|
Synthesize the following parallel analyses into a coherent conclusion:
|
|
|
|
Query: {query}
|
|
|
|
Parallel analyses:
|
|
{thought_contents}
|
|
|
|
Provide a synthesized conclusion that combines the best insights from all perspectives:
|
|
"""
|
|
|
|
try:
|
|
response = await llm_service.generate_text(
|
|
prompt=prompt,
|
|
tenant_id=context.get('tenant_id', 'default'),
|
|
task="synthesis",
|
|
max_tokens=200
|
|
)
|
|
|
|
return response.get('text', '').strip()
|
|
|
|
except Exception as e:
|
|
logger.error(f"Parallel synthesis failed: {e}")
|
|
return "Synthesis failed due to error."
|
|
|
|
async def _synthesize_hybrid_results(
|
|
self,
|
|
thoughts: List[Thought],
|
|
query: str,
|
|
context: Dict[str, Any]
|
|
) -> str:
|
|
"""Synthesize results from hybrid reasoning."""
|
|
return await self._synthesize_parallel_results(thoughts, query, context)
|
|
|
|
def _extract_final_answer(self, chain: ReasoningChain) -> str:
|
|
"""Extract the final answer from a reasoning chain."""
|
|
if not chain.thoughts:
|
|
return "No reasoning steps completed."
|
|
|
|
# Look for conclusion thoughts
|
|
conclusions = [t for t in chain.thoughts if t.thought_type == ThoughtType.CONCLUSION]
|
|
|
|
if conclusions:
|
|
# Return the highest confidence conclusion
|
|
best_conclusion = max(conclusions, key=lambda t: t.confidence)
|
|
return best_conclusion.content
|
|
|
|
# If no conclusions, return the last thought
|
|
return chain.thoughts[-1].content
|
|
|
|
def _thought_to_dict(self, thought: Thought) -> Dict[str, Any]:
|
|
"""Convert a thought to dictionary format."""
|
|
return {
|
|
"id": thought.id,
|
|
"content": thought.content,
|
|
"type": thought.thought_type.value,
|
|
"confidence": thought.confidence,
|
|
"parent_id": thought.parent_id,
|
|
"validation_status": thought.validation_status,
|
|
"created_at": thought.created_at.isoformat()
|
|
}
|
|
|
|
async def _learn_from_reasoning(
|
|
self,
|
|
chain: ReasoningChain,
|
|
result: ReasoningResult,
|
|
context: Dict[str, Any]
|
|
) -> None:
|
|
"""Learn from the reasoning process to improve future reasoning."""
|
|
learning_data = {
|
|
"chain_id": chain.id,
|
|
"method": chain.method.value,
|
|
"query": context.get('query', ''),
|
|
"confidence": result.confidence,
|
|
"validation_score": result.validation_metrics.get("overall", 0.0),
|
|
"execution_time": result.execution_time,
|
|
"thought_count": len(chain.thoughts),
|
|
"timestamp": datetime.utcnow().isoformat()
|
|
}
|
|
|
|
method_key = chain.method.value
|
|
self.learning_data[method_key].append(learning_data)
|
|
|
|
# Keep only recent learning data
|
|
if len(self.learning_data[method_key]) > 1000:
|
|
self.learning_data[method_key] = self.learning_data[method_key][-500:]
|
|
|
|
async def get_reasoning_stats(self) -> Dict[str, Any]:
|
|
"""Get statistics about reasoning performance."""
|
|
stats = {}
|
|
|
|
for method in ReasoningMethod:
|
|
method_data = self.learning_data[method.value]
|
|
if method_data:
|
|
avg_confidence = sum(d['confidence'] for d in method_data) / len(method_data)
|
|
avg_validation = sum(d['validation_score'] for d in method_data) / len(method_data)
|
|
avg_time = sum(d['execution_time'] for d in method_data) / len(method_data)
|
|
|
|
stats[method.value] = {
|
|
"total_uses": len(method_data),
|
|
"avg_confidence": avg_confidence,
|
|
"avg_validation_score": avg_validation,
|
|
"avg_execution_time": avg_time
|
|
}
|
|
|
|
return stats
|
|
|
|
|
|
# Global enhanced reasoning engine instance
|
|
enhanced_reasoning_engine = EnhancedReasoningEngine()
|