Initial commit to git.yoin
This commit is contained in:
225
uni-agent/adapters/a2a.py
Normal file
225
uni-agent/adapters/a2a.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""
|
||||
A2A (Agent-to-Agent) 适配器
|
||||
Google 提出的 Agent 间协作协议
|
||||
|
||||
参考: https://github.com/google/a2a
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .base import ProtocolAdapter, Connection, AgentInfo
|
||||
|
||||
|
||||
class A2AAdapter(ProtocolAdapter):
|
||||
"""A2A 协议适配器"""
|
||||
|
||||
protocol_name = "a2a"
|
||||
|
||||
def __init__(self, config_dir: Optional[Path] = None):
|
||||
self.config_dir = config_dir or Path(__file__).parent.parent / "config"
|
||||
self._agent_cards: Dict[str, dict] = {}
|
||||
|
||||
async def _fetch_agent_card(self, endpoint: str) -> dict:
|
||||
"""获取 Agent Card"""
|
||||
if endpoint in self._agent_cards:
|
||||
return self._agent_cards[endpoint]
|
||||
|
||||
agent_json_url = endpoint.rstrip("/")
|
||||
if not agent_json_url.endswith("agent.json"):
|
||||
agent_json_url = f"{agent_json_url}/.well-known/agent.json"
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(agent_json_url, timeout=aiohttp.ClientTimeout(total=15)) as resp:
|
||||
if resp.status == 200:
|
||||
card = await resp.json()
|
||||
self._agent_cards[endpoint] = card
|
||||
return card
|
||||
raise Exception(f"获取 Agent Card 失败: HTTP {resp.status}")
|
||||
|
||||
async def connect(self, agent_config: dict) -> Connection:
|
||||
"""建立连接"""
|
||||
endpoint = agent_config.get("endpoint")
|
||||
if not endpoint:
|
||||
raise ValueError("A2A Agent 配置必须包含 endpoint")
|
||||
|
||||
agent_card = await self._fetch_agent_card(endpoint)
|
||||
|
||||
rpc_url = None
|
||||
if "url" in agent_card:
|
||||
rpc_url = agent_card["url"]
|
||||
elif "capabilities" in agent_card:
|
||||
caps = agent_card.get("capabilities", {})
|
||||
if "streaming" in caps:
|
||||
rpc_url = caps.get("streaming", {}).get("streamingUrl")
|
||||
|
||||
if not rpc_url:
|
||||
rpc_url = endpoint.rstrip("/") + "/rpc"
|
||||
|
||||
return Connection(
|
||||
agent_id=agent_config.get("id", ""),
|
||||
protocol=self.protocol_name,
|
||||
endpoint=rpc_url,
|
||||
session=None,
|
||||
metadata={
|
||||
"agent_card": agent_card,
|
||||
"original_endpoint": endpoint,
|
||||
"auth": agent_config.get("auth", {}),
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
connection: Connection,
|
||||
method: str,
|
||||
params: dict,
|
||||
timeout: float = 30.0
|
||||
) -> dict:
|
||||
"""调用 A2A Agent 方法"""
|
||||
rpc_url = connection.endpoint
|
||||
auth_config = connection.metadata.get("auth", {})
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
if auth_config.get("type") == "api_key":
|
||||
headers["Authorization"] = f"Bearer {auth_config.get('api_key', '')}"
|
||||
elif auth_config.get("type") == "oauth2":
|
||||
token = await self._get_oauth_token(auth_config)
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
if method == "tasks/send":
|
||||
payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": task_id,
|
||||
"method": "tasks/send",
|
||||
"params": {
|
||||
"id": task_id,
|
||||
"message": params.get("message", {}),
|
||||
}
|
||||
}
|
||||
elif method == "tasks/get":
|
||||
payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": task_id,
|
||||
"method": "tasks/get",
|
||||
"params": {
|
||||
"id": params.get("task_id", task_id),
|
||||
}
|
||||
}
|
||||
else:
|
||||
payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": task_id,
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
rpc_url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
result = await resp.json()
|
||||
return {
|
||||
"success": True,
|
||||
"result": result.get("result", result),
|
||||
"task_id": task_id,
|
||||
}
|
||||
else:
|
||||
error_text = await resp.text()
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"HTTP {resp.status}: {error_text}",
|
||||
}
|
||||
|
||||
async def _get_oauth_token(self, auth_config: dict) -> str:
|
||||
"""获取 OAuth2 令牌"""
|
||||
token_url = auth_config.get("token_url")
|
||||
client_id = auth_config.get("client_id")
|
||||
client_secret = auth_config.get("client_secret")
|
||||
|
||||
if not all([token_url, client_id, client_secret]):
|
||||
raise ValueError("OAuth2 配置不完整")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
token_url,
|
||||
data={
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
}
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
result = await resp.json()
|
||||
return result.get("access_token", "")
|
||||
raise Exception(f"获取 OAuth2 令牌失败: HTTP {resp.status}")
|
||||
|
||||
async def discover(self, capability: str = "") -> List[AgentInfo]:
|
||||
"""发现 Agent"""
|
||||
agents_file = self.config_dir / "agents.yaml"
|
||||
if not agents_file.exists():
|
||||
return []
|
||||
|
||||
import yaml
|
||||
with open(agents_file) as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
agents = []
|
||||
for agent in config.get("agents", []):
|
||||
if agent.get("protocol") != "a2a":
|
||||
continue
|
||||
|
||||
if capability and capability.lower() not in agent.get("id", "").lower():
|
||||
continue
|
||||
|
||||
agents.append(AgentInfo(
|
||||
id=f"{agent['id']}@a2a",
|
||||
protocol="a2a",
|
||||
name=agent.get("name", agent["id"]),
|
||||
endpoint=agent.get("endpoint", ""),
|
||||
metadata=agent
|
||||
))
|
||||
|
||||
return agents
|
||||
|
||||
async def close(self, connection: Connection):
|
||||
"""关闭连接"""
|
||||
pass
|
||||
|
||||
async def get_methods(self, connection: Connection) -> List[dict]:
|
||||
"""获取 Agent 支持的方法(从 Agent Card 的 skills)"""
|
||||
agent_card = connection.metadata.get("agent_card", {})
|
||||
skills = agent_card.get("skills", [])
|
||||
|
||||
methods = []
|
||||
for skill in skills:
|
||||
methods.append({
|
||||
"name": skill.get("id", skill.get("name", "unknown")),
|
||||
"description": skill.get("description", ""),
|
||||
"inputSchema": skill.get("inputSchema", {}),
|
||||
"outputSchema": skill.get("outputSchema", {}),
|
||||
})
|
||||
|
||||
methods.extend([
|
||||
{"name": "tasks/send", "description": "发送任务消息"},
|
||||
{"name": "tasks/get", "description": "获取任务状态"},
|
||||
{"name": "tasks/cancel", "description": "取消任务"},
|
||||
])
|
||||
|
||||
return methods
|
||||
|
||||
def validate_config(self, agent_config: dict) -> bool:
|
||||
"""验证配置"""
|
||||
return "endpoint" in agent_config
|
||||
Reference in New Issue
Block a user