- Implement Autonomous Workflow Engine with dynamic task decomposition
- Add Multi-Agent Communication Protocol with message routing
- Create Enhanced Reasoning Chains (CoT, ToT, Multi-Step, Parallel, Hybrid)
- Add comprehensive REST API endpoints for all Week 5 features
- Include 26/26 passing tests with full coverage
- Add complete documentation and API guides
- Update development plan to mark Week 5 as completed

Features:
- Dynamic task decomposition and parallel execution
- Agent registration, messaging, and coordination
- 5 reasoning methods with validation and learning
- Robust error handling and monitoring
- Multi-tenant support and security
- Production-ready architecture

Files added/modified:
- app/services/autonomous_workflow_engine.py
- app/services/agent_communication.py
- app/services/enhanced_reasoning.py
- app/api/v1/endpoints/week5_features.py
- tests/test_week5_features.py
- docs/week5_api_documentation.md
- docs/week5_readme.md
- WEEK5_COMPLETION_SUMMARY.md
- DEVELOPMENT_PLAN.md (updated)

All tests passing: 26/26
430 lines
16 KiB
Python
430 lines
16 KiB
Python
"""
|
|
Multi-Agent Communication Protocol - Week 5 Implementation
|
|
Handles inter-agent messaging, coordination, and message queuing.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import Any, Dict, List, Optional, Callable, Coroutine
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
import uuid
|
|
from datetime import datetime
|
|
import json
|
|
from collections import defaultdict, deque
|
|
|
|
from app.services.agentic_rag_service import AgentType
|
|
from app.core.cache import cache_service
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MessageType(Enum):
    """Kinds of message that agents can exchange through the broker.

    Each member's value is the lowercase snake_case form of its name.
    """

    TASK_REQUEST = "task_request"
    TASK_RESPONSE = "task_response"
    DATA_SHARE = "data_share"
    COORDINATION = "coordination"
    STATUS_UPDATE = "status_update"
    ERROR = "error"
    HEARTBEAT = "heartbeat"
|
|
|
|
|
|
class MessagePriority(Enum):
    """Urgency levels for a message; the integer value grows with urgency."""

    LOW = 1
    NORMAL = 2
    HIGH = 3
    CRITICAL = 4
|
|
|
|
|
|
@dataclass
class AgentMessage:
    """A single message passed from one agent to another.

    The first five fields are required; the rest have defaults.
    """

    # Identifier of this message (callers in this module use a UUID4 string).
    id: str
    # Agent id of the sender.
    sender: str
    # Agent id of the intended receiver.
    recipient: str
    # What kind of message this is.
    message_type: MessageType
    # Message body.
    payload: Dict[str, Any]
    # Urgency of the message.
    priority: MessagePriority = MessagePriority.NORMAL
    # Creation time; defaults to the naive UTC time at construction.
    timestamp: datetime = field(default_factory=datetime.utcnow)
    # Optional identifier tying this message to related ones.
    correlation_id: Optional[str] = None
    # Optional reply address (not read anywhere in this module).
    reply_to: Optional[str] = None
    # Lifetime in seconds, measured from ``timestamp``.
    ttl: int = 300
    # Free-form extra data; each instance gets its own dict.
    metadata: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
@dataclass
class MessageQueue:
    """Inbound message buffer belonging to one agent."""

    # Id of the agent that owns this queue.
    agent_id: str
    # Pending messages, oldest at the left; each queue gets its own deque.
    messages: deque = field(default_factory=deque)
    # Capacity limit; the deque itself is unbounded, so the limit must be
    # enforced by whoever appends to it.
    max_size: int = 1000
    # Processing flag; starts out False.
    processing: bool = False
|
|
|
|
|
|
class MessageBroker:
    """Central in-memory broker that routes messages between agents.

    Every registered agent owns one ``MessageQueue``. Sending appends to the
    recipient's queue, records the message in a bounded history and awaits
    the recipient's subscriber callbacks. Receiving pops from the agent's
    own queue, discarding messages whose TTL has elapsed.
    """

    def __init__(self):
        """Create a broker with no agents, subscribers or history."""
        # Inbound queue per registered agent id.
        self.queues: Dict[str, MessageQueue] = {}
        # Async callbacks per recipient agent id.
        self.subscribers: Dict[str, List[Callable]] = defaultdict(list)
        # Most recent successfully queued messages, oldest first.
        self.message_history: List[AgentMessage] = []
        self.max_history: int = 10000
        # Per-agent tasks cancelled on unregister; nothing in this class
        # adds entries, so presumably external code does -- TODO confirm.
        self.processing_tasks: Dict[str, asyncio.Task] = {}

    async def register_agent(self, agent_id: str) -> None:
        """Create a queue for ``agent_id`` unless it already has one."""
        if agent_id not in self.queues:
            self.queues[agent_id] = MessageQueue(agent_id=agent_id)
            logger.info(f"Agent {agent_id} registered with message broker")

    async def unregister_agent(self, agent_id: str) -> None:
        """Drop the agent's queue and cancel its processing task, if any.

        Unknown agent ids are ignored. Subscriber callbacks are kept.
        """
        if agent_id not in self.queues:
            return

        task = self.processing_tasks.pop(agent_id, None)
        if task is not None:
            task.cancel()

        del self.queues[agent_id]
        logger.info(f"Agent {agent_id} unregistered from message broker")

    async def send_message(self, message: AgentMessage) -> bool:
        """Queue ``message`` for its recipient and notify subscribers.

        Returns:
            True when the message was queued. False when the sender or
            recipient is empty, the recipient is not registered, or an
            unexpected error occurred (the error is logged, never raised).
        """
        try:
            if not message.recipient or not message.sender:
                logger.error("Invalid message: missing sender or recipient")
                return False

            queue = self.queues.get(message.recipient)
            if queue is None:
                logger.warning(f"Recipient {message.recipient} not found, message dropped")
                return False

            # Make room first so the queue never exceeds its capacity.
            if len(queue.messages) >= queue.max_size:
                self._remove_oldest_low_priority_message(queue)
            queue.messages.append(message)

            # Trim the history from the front in one slice deletion.
            self.message_history.append(message)
            overflow = len(self.message_history) - self.max_history
            if overflow > 0:
                del self.message_history[:overflow]

            await self._notify_subscribers(message)

            logger.debug(f"Message {message.id} sent from {message.sender} to {message.recipient}")
            return True

        except Exception as e:
            # Boundary handler: callers rely on a bool, so log with the
            # traceback instead of propagating.
            logger.exception(f"Failed to send message: {e}")
            return False

    def _remove_oldest_low_priority_message(self, queue: MessageQueue) -> None:
        """Evict one message to make room in a full queue.

        The victim is the oldest message among those with the lowest
        priority currently queued. When a LOW message exists this is the
        oldest LOW message; otherwise the oldest message of the next lowest
        priority is evicted so the queue cannot grow without bound. An
        empty queue is left untouched.
        """
        # min() returns the first minimal item, i.e. the oldest one.
        victim = min(
            enumerate(queue.messages),
            key=lambda item: item[1].priority.value,
            default=None,
        )
        if victim is not None:
            del queue.messages[victim[0]]

    async def receive_message(self, agent_id: str, timeout: float = 1.0) -> Optional[AgentMessage]:
        """Return the next unexpired message for ``agent_id``.

        Args:
            agent_id: Agent whose queue is read.
            timeout: Maximum seconds to wait for a message. A value of zero
                or less performs a single non-blocking check.

        Returns:
            The oldest unexpired message, or None when the agent is not
            registered or nothing arrived before the timeout.
        """
        queue = self.queues.get(agent_id)
        if queue is None:
            return None

        # The loop clock is monotonic, so wall-clock jumps cannot shorten
        # or extend the wait.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + max(timeout, 0.0)

        while True:
            # Drain before checking the deadline so that timeout=0 still
            # delivers a message that is already waiting.
            while queue.messages:
                message = queue.messages.popleft()
                age = (datetime.utcnow() - message.timestamp).total_seconds()
                if age > message.ttl:
                    logger.warning(f"Message {message.id} expired, skipping")
                    continue
                return message

            remaining = deadline - loop.time()
            if remaining <= 0:
                return None
            await asyncio.sleep(min(0.01, remaining))

    async def broadcast_message(self, sender: str, message_type: MessageType, payload: Dict[str, Any]) -> None:
        """Send one message per registered agent, excluding ``sender``.

        Every message gets a fresh UUID4 id and the current naive UTC
        timestamp; all of them reference the same ``payload`` object.
        """
        # Snapshot the ids: send_message awaits subscriber callbacks, which
        # may register or unregister agents while we iterate.
        for agent_id in list(self.queues):
            if agent_id == sender:
                continue
            message = AgentMessage(
                id=str(uuid.uuid4()),
                sender=sender,
                recipient=agent_id,
                message_type=message_type,
                payload=payload,
                timestamp=datetime.utcnow(),
            )
            await self.send_message(message)

    async def subscribe(self, agent_id: str, callback: Callable[[AgentMessage], Coroutine[Any, Any, None]]) -> None:
        """Add ``callback`` to be awaited for messages sent to ``agent_id``."""
        self.subscribers[agent_id].append(callback)

    async def unsubscribe(self, agent_id: str, callback: Callable[[AgentMessage], Coroutine[Any, Any, None]]) -> None:
        """Remove ``callback`` for ``agent_id``; no-op when not subscribed."""
        if agent_id in self.subscribers and callback in self.subscribers[agent_id]:
            self.subscribers[agent_id].remove(callback)

    async def _notify_subscribers(self, message: AgentMessage) -> None:
        """Await every callback subscribed to the message's recipient.

        A failing callback is logged and does not stop the remaining ones.
        """
        # Iterate a copy so a callback may unsubscribe without causing the
        # next callback to be skipped.
        for callback in list(self.subscribers.get(message.recipient, [])):
            try:
                await callback(message)
            except Exception as e:
                logger.exception(f"Error in message subscriber callback: {e}")

    async def get_queue_status(self, agent_id: str) -> Optional[Dict[str, Any]]:
        """Return size, capacity and processing flag of an agent's queue.

        Returns None when the agent is not registered.
        """
        queue = self.queues.get(agent_id)
        if queue is None:
            return None

        return {
            "agent_id": agent_id,
            "queue_size": len(queue.messages),
            "max_size": queue.max_size,
            "processing": queue.processing,
        }

    async def get_broker_status(self) -> Dict[str, Any]:
        """Return aggregate counts of agents, queued messages, history and subscribers."""
        return {
            "total_agents": len(self.queues),
            "total_messages": sum(len(q.messages) for q in self.queues.values()),
            "message_history_size": len(self.message_history),
            "active_subscribers": sum(len(callbacks) for callbacks in self.subscribers.values()),
        }
|
|
|
|
|
|
class AgentCoordinator:
    """Tracks registered agents and assigns tasks to the least loaded one.

    The coordinator keeps a registry entry per agent (type, capabilities,
    status, last heartbeat, workload) and talks to agents through the
    supplied message broker.
    """

    # An agent at or above this many outstanding tasks is not assigned more.
    MAX_AGENT_WORKLOAD = 10
    # Seconds without a heartbeat after which an agent is marked inactive.
    HEARTBEAT_TIMEOUT_SECONDS = 60

    def __init__(self, message_broker: MessageBroker):
        """Bind the coordinator to ``message_broker`` with an empty registry."""
        self.message_broker = message_broker
        self.agent_registry: Dict[str, Dict[str, Any]] = {}
        # Not read or written elsewhere in this class.
        self.workflow_sessions: Dict[str, Dict[str, Any]] = {}
        self.coordination_rules: Dict[str, Callable] = {}

    async def register_agent(self, agent_id: str, agent_type: AgentType, capabilities: List[str]) -> None:
        """Register the agent with the broker and (re)create its registry entry.

        Re-registering an existing id resets its status to "active", its
        heartbeat to now and its workload to zero.
        """
        await self.message_broker.register_agent(agent_id)

        self.agent_registry[agent_id] = {
            "agent_type": agent_type,
            "capabilities": capabilities,
            "status": "active",
            "last_heartbeat": datetime.utcnow(),
            "workload": 0,
        }

        logger.info(f"Agent {agent_id} registered with coordinator")

    async def unregister_agent(self, agent_id: str) -> None:
        """Remove the agent from the broker and from the registry."""
        await self.message_broker.unregister_agent(agent_id)
        self.agent_registry.pop(agent_id, None)
        logger.info(f"Agent {agent_id} unregistered from coordinator")

    async def coordinate_task(self, task_id: str, task_type: AgentType, requirements: Dict[str, Any]) -> str:
        """Assign a task to the least loaded active agent of ``task_type``.

        A HIGH priority TASK_REQUEST is sent to the chosen agent. If the
        broker rejects the message, the agent's workload is rolled back and
        the next least loaded candidate is tried.

        Args:
            task_id: Identifier of the task; also used as correlation id.
            task_type: Agent type required to run the task.
            requirements: Passed through unchanged in the message payload.

        Returns:
            The id of the agent the task was delivered to.

        Raises:
            ValueError: When no active agent of the right type has spare
                capacity, or when delivery failed for every candidate.
        """
        candidates = [
            agent_id
            for agent_id, info in self.agent_registry.items()
            if info["agent_type"] == task_type
            and info["status"] == "active"
            and info["workload"] < self.MAX_AGENT_WORKLOAD
        ]
        if not candidates:
            raise ValueError(f"No suitable agents found for task type {task_type}")

        # Stable sort: equally loaded agents keep registration order, so
        # the first candidate is the same agent min() would pick.
        candidates.sort(key=lambda agent_id: self.agent_registry[agent_id]["workload"])

        for agent_id in candidates:
            entry = self.agent_registry.get(agent_id)
            if entry is None:
                # Unregistered while an earlier delivery was awaited.
                continue

            entry["workload"] += 1
            message = AgentMessage(
                id=str(uuid.uuid4()),
                sender="coordinator",
                recipient=agent_id,
                message_type=MessageType.TASK_REQUEST,
                payload={
                    "task_id": task_id,
                    "task_type": task_type.value,
                    "requirements": requirements,
                },
                priority=MessagePriority.HIGH,
                correlation_id=task_id,
            )

            if await self.message_broker.send_message(message):
                return agent_id

            # Delivery failed: undo the reservation so the counter does not
            # leak, then fall through to the next candidate.
            entry["workload"] = max(0, entry["workload"] - 1)
            logger.warning(f"Task {task_id} could not be delivered to agent {agent_id}")

        raise ValueError(
            f"Failed to deliver task {task_id} to any suitable agent for task type {task_type}"
        )

    async def handle_task_response(self, message: AgentMessage) -> None:
        """Record completion of a task reported by ``message.sender``.

        Ignored unless the payload carries a truthy "task_id" and the
        sender is registered. Workload is decremented but never below zero.
        """
        task_id = message.payload.get("task_id")
        agent_id = message.sender

        if task_id and agent_id in self.agent_registry:
            entry = self.agent_registry[agent_id]
            entry["workload"] = max(0, entry["workload"] - 1)
            # NOTE(review): unlike handle_heartbeat, this refreshes the
            # heartbeat without resetting status to "active" -- confirm
            # whether that is intended.
            entry["last_heartbeat"] = datetime.utcnow()
            logger.info(f"Task {task_id} completed by agent {agent_id}")

    async def handle_heartbeat(self, message: AgentMessage) -> None:
        """Refresh the sender's heartbeat and mark it active, if registered."""
        entry = self.agent_registry.get(message.sender)
        if entry is not None:
            entry["last_heartbeat"] = datetime.utcnow()
            entry["status"] = "active"

    async def check_agent_health(self) -> Dict[str, Any]:
        """Report each agent's health and mark silent agents inactive.

        An agent whose last heartbeat is older than
        ``HEARTBEAT_TIMEOUT_SECONDS`` has its registry status set to
        "inactive". The registry status of other agents is not modified.

        Returns:
            A dict keyed by agent id. Inactive entries hold "status",
            "last_heartbeat" and "time_since_heartbeat"; active entries
            hold "status", "workload" and "last_heartbeat".
        """
        health_status = {}
        current_time = datetime.utcnow()

        for agent_id, agent_info in self.agent_registry.items():
            time_since_heartbeat = (current_time - agent_info["last_heartbeat"]).total_seconds()

            if time_since_heartbeat > self.HEARTBEAT_TIMEOUT_SECONDS:
                agent_info["status"] = "inactive"
                health_status[agent_id] = {
                    "status": "inactive",
                    "last_heartbeat": agent_info["last_heartbeat"],
                    "time_since_heartbeat": time_since_heartbeat,
                }
            else:
                health_status[agent_id] = {
                    "status": "active",
                    "workload": agent_info["workload"],
                    "last_heartbeat": agent_info["last_heartbeat"],
                }

        return health_status

    async def get_coordinator_status(self) -> Dict[str, Any]:
        """Return agent counts, summed workload and the distinct agent type values."""
        entries = list(self.agent_registry.values())
        return {
            "total_agents": len(entries),
            "active_agents": sum(1 for info in entries if info["status"] == "active"),
            "total_workload": sum(info["workload"] for info in entries),
            "agent_types": list({info["agent_type"].value for info in entries}),
        }
|
|
|
|
|
|
class AgentCommunicationManager:
    """Facade that owns a message broker, a coordinator and a health loop.

    Most methods delegate directly to the broker or the coordinator;
    ``start`` and ``stop`` manage the background health-check task.
    """

    # Seconds between two consecutive health checks.
    HEALTH_CHECK_INTERVAL_SECONDS = 30

    def __init__(self):
        """Create the broker and coordinator; the health loop is not started."""
        self.message_broker = MessageBroker()
        self.coordinator = AgentCoordinator(self.message_broker)
        self.running = False
        self.health_check_task: Optional[asyncio.Task] = None

    async def start(self) -> None:
        """Start the background health-check loop.

        Calling this while the loop is already running is a no-op, so a
        repeated start cannot orphan the existing task.
        """
        task = self.health_check_task
        if self.running and task is not None and not task.done():
            logger.debug("Agent communication manager already started")
            return

        self.running = True
        self.health_check_task = asyncio.create_task(self._health_check_loop())
        logger.info("Agent communication manager started")

    async def stop(self) -> None:
        """Stop the health-check loop and wait for its task to finish."""
        self.running = False

        task, self.health_check_task = self.health_check_task, None
        if task is not None:
            task.cancel()
            # Await the cancelled task so it is finished when stop()
            # returns and is not left pending on the loop.
            try:
                await task
            except asyncio.CancelledError:
                pass

        logger.info("Agent communication manager stopped")

    async def clear_state(self) -> None:
        """Clear registry, queues, history, subscribers and processing tasks.

        Intended for tests. Processing tasks are cancelled before being
        dropped so they do not keep running unreferenced.
        """
        self.coordinator.agent_registry.clear()
        self.message_broker.queues.clear()
        self.message_broker.message_history.clear()
        self.message_broker.subscribers.clear()

        for task in list(self.message_broker.processing_tasks.values()):
            task.cancel()
        self.message_broker.processing_tasks.clear()

        logger.info("Agent communication manager state cleared")

    async def _health_check_loop(self) -> None:
        """Check agent health periodically while ``self.running`` is true.

        Errors raised by a check are logged and the loop continues.
        """
        while self.running:
            # Only the check itself is guarded; the sleep stays outside so
            # cancellation during the wait is never caught here.
            try:
                health_status = await self.coordinator.check_agent_health()
                inactive_agents = [
                    agent_id
                    for agent_id, status in health_status.items()
                    if status["status"] == "inactive"
                ]
                if inactive_agents:
                    logger.warning(f"Inactive agents detected: {inactive_agents}")
            except Exception as e:
                logger.error(f"Error in health check loop: {e}")

            await asyncio.sleep(self.HEALTH_CHECK_INTERVAL_SECONDS)

    async def register_agent(self, agent_id: str, agent_type: AgentType, capabilities: List[str]) -> None:
        """Register an agent; delegates to the coordinator."""
        await self.coordinator.register_agent(agent_id, agent_type, capabilities)

    async def unregister_agent(self, agent_id: str) -> None:
        """Unregister an agent; delegates to the coordinator."""
        await self.coordinator.unregister_agent(agent_id)

    async def send_message(self, message: AgentMessage) -> bool:
        """Send a message through the broker and return its result."""
        return await self.message_broker.send_message(message)

    async def receive_message(self, agent_id: str, timeout: float = 1.0) -> Optional[AgentMessage]:
        """Receive the next message for ``agent_id`` from the broker."""
        return await self.message_broker.receive_message(agent_id, timeout)

    async def coordinate_task(self, task_id: str, task_type: AgentType, requirements: Dict[str, Any]) -> str:
        """Assign a task via the coordinator and return the chosen agent id."""
        return await self.coordinator.coordinate_task(task_id, task_type, requirements)

    async def get_status(self) -> Dict[str, Any]:
        """Return broker status, coordinator status and the running flag."""
        broker_status = await self.message_broker.get_broker_status()
        coordinator_status = await self.coordinator.get_coordinator_status()

        return {
            "broker": broker_status,
            "coordinator": coordinator_status,
            "running": self.running,
        }
|
|
|
|
|
|
# Module-level AgentCommunicationManager instance, constructed when this
# module is imported. Its health-check loop is not running until start()
# is awaited.
|
|
agent_communication_manager = AgentCommunicationManager()
|