226 lines
7.6 KiB
Python
226 lines
7.6 KiB
Python
"""
A2A (Agent-to-Agent) adapter.

Adapter for the agent-to-agent collaboration protocol proposed by Google.

Reference: https://github.com/google/a2a
"""
|
||
|
||
import json
|
||
import uuid
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
import aiohttp
|
||
|
||
from .base import ProtocolAdapter, Connection, AgentInfo
|
||
|
||
|
||
class A2AAdapter(ProtocolAdapter):
    """Protocol adapter for A2A (Agent-to-Agent) agents.

    The adapter fetches an agent's "Agent Card" over HTTP, derives the
    JSON-RPC URL from it, and sends JSON-RPC 2.0 requests to that URL.
    No persistent HTTP session is kept: every request opens its own
    ``aiohttp.ClientSession``.
    """

    protocol_name = "a2a"

    def __init__(self, config_dir: Optional[Path] = None):
        """Create the adapter.

        Args:
            config_dir: Directory that holds ``agents.yaml``. When omitted
                (or falsy), ``<parent of this file's directory>/config``
                is used.
        """
        self.config_dir = config_dir or Path(__file__).parent.parent / "config"
        # Agent Cards cached by the endpoint string they were requested with.
        self._agent_cards: Dict[str, dict] = {}

    async def _fetch_agent_card(self, endpoint: str) -> dict:
        """Return the Agent Card for ``endpoint``, fetching it on first use.

        If ``endpoint`` does not already end in ``agent.json``, the card is
        requested from ``<endpoint>/.well-known/agent.json``. Successful
        results are cached per endpoint for the lifetime of the adapter.

        Raises:
            Exception: If the server answers with a non-200 status.
        """
        if endpoint in self._agent_cards:
            return self._agent_cards[endpoint]

        agent_json_url = endpoint.rstrip("/")
        if not agent_json_url.endswith("agent.json"):
            agent_json_url = f"{agent_json_url}/.well-known/agent.json"

        async with aiohttp.ClientSession() as session:
            async with session.get(
                agent_json_url,
                timeout=aiohttp.ClientTimeout(total=15),
            ) as resp:
                if resp.status == 200:
                    card = await resp.json()
                    self._agent_cards[endpoint] = card
                    return card
                raise Exception(f"获取 Agent Card 失败: HTTP {resp.status}")

    @staticmethod
    def _resolve_rpc_url(agent_card: dict, endpoint: str) -> str:
        """Pick the JSON-RPC URL advertised by ``agent_card``.

        Order of preference: the card's ``url`` field, then
        ``capabilities.streaming.streamingUrl`` (only consulted when the
        card has no ``url`` key at all), then ``<endpoint>/rpc``.
        """
        rpc_url = None
        if "url" in agent_card:
            rpc_url = agent_card["url"]
        elif "capabilities" in agent_card:
            caps = agent_card.get("capabilities") or {}
            streaming = caps.get("streaming") if isinstance(caps, dict) else None
            # Agent cards may declare ``streaming`` as a plain boolean flag;
            # only a mapping can carry a ``streamingUrl``.
            if isinstance(streaming, dict):
                rpc_url = streaming.get("streamingUrl")

        return rpc_url or endpoint.rstrip("/") + "/rpc"

    async def connect(self, agent_config: dict) -> Connection:
        """Build a ``Connection`` for the agent described by ``agent_config``.

        Args:
            agent_config: Agent configuration. ``endpoint`` is required;
                ``id`` and ``auth`` are optional.

        Returns:
            A ``Connection`` whose ``endpoint`` is the resolved JSON-RPC URL
            and whose ``metadata`` carries the agent card, the original
            endpoint and the auth settings.

        Raises:
            ValueError: If ``endpoint`` is missing or empty.
            Exception: If the Agent Card cannot be fetched.
        """
        endpoint = agent_config.get("endpoint")
        if not endpoint:
            raise ValueError("A2A Agent 配置必须包含 endpoint")

        agent_card = await self._fetch_agent_card(endpoint)
        rpc_url = self._resolve_rpc_url(agent_card, endpoint)

        return Connection(
            agent_id=agent_config.get("id", ""),
            protocol=self.protocol_name,
            endpoint=rpc_url,
            session=None,
            metadata={
                "agent_card": agent_card,
                "original_endpoint": endpoint,
                "auth": agent_config.get("auth", {}),
            },
        )

    @staticmethod
    def _build_payload(method: str, params: dict, task_id: str) -> dict:
        """Build the JSON-RPC 2.0 request envelope for ``method``."""
        if method == "tasks/send":
            rpc_params = {
                "id": task_id,
                "message": params.get("message", {}),
            }
        elif method == "tasks/get":
            rpc_params = {
                "id": params.get("task_id", task_id),
            }
        else:
            # Any other method: forward the caller's params untouched.
            rpc_params = params

        return {
            "jsonrpc": "2.0",
            "id": task_id,
            "method": method,
            "params": rpc_params,
        }

    async def call(
        self,
        connection: Connection,
        method: str,
        params: dict,
        timeout: float = 30.0
    ) -> dict:
        """Send one JSON-RPC request to the agent behind ``connection``.

        Args:
            connection: Connection produced by :meth:`connect`.
            method: JSON-RPC method name. ``tasks/send`` and ``tasks/get``
                get a dedicated params layout; anything else is forwarded
                with ``params`` as given.
            params: Method parameters.
            timeout: Total request timeout in seconds.

        Returns:
            ``{"success": True, "result": ..., "task_id": ...}`` on success.
            On an HTTP error: ``{"success": False, "error": "HTTP ..."}``.
            On a JSON-RPC error delivered with HTTP 200:
            ``{"success": False, "error": "JSON-RPC ...", "rpc_error": ...,
            "task_id": ...}``.

        Raises:
            ValueError: If OAuth2 is configured but incomplete.
            Exception: If the OAuth2 token cannot be obtained.
        """
        rpc_url = connection.endpoint
        # ``auth`` may be stored as None when the config says ``auth: null``.
        auth_config = connection.metadata.get("auth") or {}

        headers = {
            "Content-Type": "application/json",
        }

        auth_type = auth_config.get("type")
        if auth_type == "api_key":
            headers["Authorization"] = f"Bearer {auth_config.get('api_key', '')}"
        elif auth_type == "oauth2":
            token = await self._get_oauth_token(auth_config)
            headers["Authorization"] = f"Bearer {token}"

        # NOTE(review): this id doubles as the JSON-RPC request id and, for
        # ``tasks/send``, as the task id. For other methods the ``task_id``
        # returned below is only the request id - confirm callers expect that.
        task_id = str(uuid.uuid4())
        payload = self._build_payload(method, params, task_id)

        async with aiohttp.ClientSession() as session:
            async with session.post(
                rpc_url,
                json=payload,
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=timeout)
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    return {
                        "success": False,
                        "error": f"HTTP {resp.status}: {error_text}",
                    }

                result = await resp.json()

        if not isinstance(result, dict):
            # Not a JSON-RPC response object; hand the body back as-is.
            return {
                "success": True,
                "result": result,
                "task_id": task_id,
            }

        # JSON-RPC reports failures inside the body, typically with HTTP 200,
        # so the status code alone does not mean the call succeeded.
        rpc_error = result.get("error")
        if rpc_error is not None:
            if isinstance(rpc_error, dict):
                message = (
                    f"JSON-RPC {rpc_error.get('code')}: "
                    f"{rpc_error.get('message', '')}"
                )
            else:
                message = f"JSON-RPC error: {rpc_error}"
            return {
                "success": False,
                "error": message,
                "rpc_error": rpc_error,
                "task_id": task_id,
            }

        return {
            "success": True,
            "result": result.get("result", result),
            "task_id": task_id,
        }

    async def _get_oauth_token(self, auth_config: dict) -> str:
        """Obtain an OAuth2 access token via the client-credentials grant.

        Args:
            auth_config: Must contain ``token_url``, ``client_id`` and
                ``client_secret``.

        Returns:
            The ``access_token`` value from the token endpoint's response.

        Raises:
            ValueError: If any of the three settings is missing or empty.
            Exception: If the token endpoint answers with a non-200 status
                or its response carries no ``access_token``.
        """
        token_url = auth_config.get("token_url")
        client_id = auth_config.get("client_id")
        client_secret = auth_config.get("client_secret")

        if not all([token_url, client_id, client_secret]):
            raise ValueError("OAuth2 配置不完整")

        async with aiohttp.ClientSession() as session:
            async with session.post(
                token_url,
                data={
                    "grant_type": "client_credentials",
                    "client_id": client_id,
                    "client_secret": client_secret,
                },
                # Same bound as the Agent Card fetch, so a stalled token
                # endpoint cannot hold up a call for the library default.
                timeout=aiohttp.ClientTimeout(total=15),
            ) as resp:
                if resp.status != 200:
                    raise Exception(f"获取 OAuth2 令牌失败: HTTP {resp.status}")
                result = await resp.json()

        token = result.get("access_token") if isinstance(result, dict) else None
        if not token:
            # An empty token would only surface later as an opaque auth
            # failure on the RPC call; fail here where the cause is known.
            raise Exception("获取 OAuth2 令牌失败: 响应中缺少 access_token")
        return token

    async def discover(self, capability: str = "") -> List[AgentInfo]:
        """List the A2A agents declared in ``<config_dir>/agents.yaml``.

        Args:
            capability: Optional case-insensitive substring; when given,
                only agents whose ``id`` contains it are returned.

        Returns:
            One ``AgentInfo`` per matching entry whose ``protocol`` is
            ``"a2a"``. Empty when the file is missing or declares no agents.
        """
        agents_file = self.config_dir / "agents.yaml"
        if not agents_file.exists():
            return []

        # Imported lazily so PyYAML is only needed when discovery is used.
        import yaml

        with open(agents_file, encoding="utf-8") as f:
            # safe_load returns None for an empty document.
            config = yaml.safe_load(f) or {}

        wanted = capability.lower()
        agents = []
        for agent in config.get("agents") or []:
            if agent.get("protocol") != "a2a":
                continue

            if wanted and wanted not in agent.get("id", "").lower():
                continue

            agents.append(AgentInfo(
                id=f"{agent['id']}@a2a",
                protocol="a2a",
                name=agent.get("name", agent["id"]),
                endpoint=agent.get("endpoint", ""),
                metadata=agent
            ))

        return agents

    async def close(self, connection: Connection):
        """Close ``connection``.

        Nothing to release: the adapter opens a short-lived HTTP session per
        request and stores none on the connection.
        """

    async def get_methods(self, connection: Connection) -> List[dict]:
        """Return the methods the agent supports.

        The list is built from the ``skills`` of the cached Agent Card,
        followed by the three task methods ``tasks/send``, ``tasks/get``
        and ``tasks/cancel``.
        """
        agent_card = connection.metadata.get("agent_card", {})
        # ``skills`` may be present but null in the card.
        skills = agent_card.get("skills") or []

        methods = [
            {
                "name": skill.get("id", skill.get("name", "unknown")),
                "description": skill.get("description", ""),
                "inputSchema": skill.get("inputSchema", {}),
                "outputSchema": skill.get("outputSchema", {}),
            }
            for skill in skills
        ]

        methods.extend([
            {"name": "tasks/send", "description": "发送任务消息"},
            {"name": "tasks/get", "description": "获取任务状态"},
            {"name": "tasks/cancel", "description": "取消任务"},
        ])

        return methods

    def validate_config(self, agent_config: dict) -> bool:
        """Return True when ``agent_config`` has a non-empty ``endpoint``.

        Uses the same truthiness test as :meth:`connect`, so a config that
        validates here is not rejected there for a blank endpoint.
        """
        return bool(agent_config.get("endpoint"))
|