| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578 |
- """
- llm_client.py - LLM Planner 客户端适配层
- 该模块负责封装 LLM 调用,为 Planner 提供统一的 LLM 接口。
- 职责边界:
- - 接收 prompt,调用 LLM,返回 Plan JSON dict
- - 不执行 capability
- - 不依赖 ROS2
- - 不做 Prompt 构造(由 prompt_manager 负责)
- 架构说明:
- 本模块是纯 Python 接口层,不直接连接 OmniNode ROS2 节点。
- 真实的 OmniNode 通信应由 planner_node.py 完成:
- - planner_node.py: 负责与 OmniNode ROS2 节点通信,获取 world_snapshot
- - prompt_manager.py: 负责构造 prompt
- - llm_client.py: 仅负责 LLM 结果处理和抽象接口层(占位)
- 设计原则:
- - 可替换接口:当前为占位实现,未来可替换为 Omni HTTP/SDK
- - 异常安全:LLM 失败时返回安全的 fallback plan
- - JSON 输出:输出必须是能被 Plan.from_dict() 解析的 dict
- 使用方式:
- from llm_client import LLMPlannerClient
- client = LLMPlannerClient()
- plan_json = client.generate_plan_json(
- user_intent="降温",
- world_snapshot={"temperature": 32},
- available_tools=["adjust_fan", "speak"],
- domain_rules={"high_risk_actions": ["turn_off"]},
- )
- """
- from __future__ import annotations
- import json
- import re
- import time
- import uuid
- from typing import Any
- try:
- import yaml
- YAML_AVAILABLE = True
- except ImportError:
- YAML_AVAILABLE = False
- # =============================================================================
- # 异常定义
- # =============================================================================
class LLMClientError(Exception):
    """Root of the LLM client exception hierarchy."""


class LLMConnectionError(LLMClientError):
    """LLM connection failure."""


class LLMResponseParseError(LLMClientError):
    """LLM response could not be parsed (or failed plan validation)."""


class LLMTimeoutError(LLMClientError):
    """LLM call timed out."""
- # =============================================================================
- # LLM Planner 客户端
- # =============================================================================
class LLMPlannerClient:
    """LLM Planner client.

    Wraps the LLM call behind a single interface for the Planner. The model
    call is currently a placeholder (``_mock_llm_call``) intended to be
    replaced by an Omni HTTP/SDK backend.

    Attributes:
        api_endpoint: LLM API endpoint.
        model_name: Model name.
        timeout: Timeout in seconds. Stored only; the placeholder call does
            not use it yet.
        max_retries: Maximum number of retries. Stored only; the placeholder
            call does not use it yet.

    Example:
        >>> client = LLMPlannerClient()
        >>> plan_json = client.generate_plan_json(
        ...     user_intent="降温",
        ...     world_snapshot={"temperature": 32},
        ...     available_tools=["adjust_fan"],
        ... )
    """

    # Placeholder endpoint; expected to be injected from configuration later.
    DEFAULT_API_ENDPOINT = "http://localhost:8000/omni/generate"

    # Markdown code fence, optionally tagged "json" in any letter case
    # (```json, ```JSON, ...); group 1 captures the fence body.
    _CODE_FENCE_PATTERN = re.compile(r"```(?:json)?\s*([\s\S]*?)\s*```", re.IGNORECASE)

    # Number of characters of an unparseable response quoted in the error.
    _ERROR_SNIPPET_LEN = 500

    # Top-level keys a plan must contain to pass _validate_plan_json.
    _REQUIRED_PLAN_FIELDS = ("goal", "steps")

    def __init__(
        self,
        api_endpoint: str | None = None,
        model_name: str = "omni-planner-v1",
        timeout: float = 30.0,
        max_retries: int = 3,
    ):
        """Initialise the client.

        Args:
            api_endpoint: API endpoint; a falsy value selects
                ``DEFAULT_API_ENDPOINT``.
            model_name: Model name.
            timeout: Timeout in seconds.
            max_retries: Maximum number of retries.
        """
        self.api_endpoint = api_endpoint or self.DEFAULT_API_ENDPOINT
        self.model_name = model_name
        self.timeout = timeout
        self.max_retries = max_retries

    def generate_plan_json(
        self,
        user_intent: str,
        world_snapshot: dict[str, Any],
        available_tools: list[str],
        domain_rules: dict[str, Any] | None = None,
        planner_config: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Generate a Plan JSON dict from the user intent and world state.

        Main entry point. It does not propagate failures: any error raised
        while calling the model, extracting its JSON or validating it is
        turned into a safe fallback plan.

        Args:
            user_intent: Description of what the user wants.
            world_snapshot: Snapshot of the world state.
            available_tools: Names of the capabilities the plan may use.
            domain_rules: Domain rules; ``None`` means an empty dict.
            planner_config: Planner configuration. Accepted for interface
                compatibility; not used by the current implementation.

        Returns:
            A Plan JSON dict meant for ``Plan.from_dict()``. Either the
            model's validated plan, or a fallback plan whose ``source`` is
            ``"llm_fallback"``. The fallback reason is ``"llm_error"`` for
            ``LLMClientError`` subclasses and ``"unexpected_error: ..."``
            otherwise.
        """
        domain_rules = domain_rules or {}

        try:
            response_text = self._call_model(
                user_intent, world_snapshot, available_tools, domain_rules
            )
            plan_json = self._extract_json_from_response(response_text)
            self._validate_plan_json(plan_json)
        except LLMClientError:
            # Known LLM-side failure (connection, timeout, parse/validation).
            return self._build_safe_fallback_plan_json(
                user_intent=user_intent,
                reason="llm_error",
            )
        except Exception as e:
            # Deliberate boundary: the planner must always get a plan back.
            return self._build_safe_fallback_plan_json(
                user_intent=user_intent,
                reason=f"unexpected_error: {e}",
            )
        return plan_json

    def _call_model(
        self,
        user_intent: str,
        world_snapshot: dict[str, Any],
        available_tools: list[str],
        domain_rules: dict[str, Any],
    ) -> str:
        """Build the prompt and call the (placeholder) model.

        The prompt comes from ``prompt_manager.PlannerPromptManager`` when
        that module is importable, otherwise from ``_generate_simple_prompt``.
        The model call itself is ``_mock_llm_call`` until a real Omni
        HTTP/SDK (or other LLM API) backend replaces it.

        Args:
            user_intent: User intent.
            world_snapshot: World state.
            available_tools: Available capabilities.
            domain_rules: Domain rules.

        Returns:
            The raw text returned by the model.
        """
        try:
            from prompt_manager import PlannerPromptManager
        except ImportError:
            # prompt_manager is not available: fall back to a minimal prompt.
            prompt = self._generate_simple_prompt(
                user_intent, world_snapshot, available_tools, domain_rules
            )
        else:
            prompt = PlannerPromptManager().build_plan_prompt(
                user_intent=user_intent,
                world_snapshot=world_snapshot,
                available_tools=available_tools,
                domain_rules=domain_rules,
            )

        return self._mock_llm_call(prompt)

    def _mock_llm_call(self, prompt: str) -> str:
        """Placeholder model call for tests and development.

        Contains no domain logic; it only reacts to test markers embedded
        in the prompt:

        - ``__TEST_INVALID_JSON__``: returns truncated, unparseable JSON.
        - ``__TEST_ASK_USER__``: returns an ASK_USER plan.
        - otherwise: returns a generic valid plan.

        Args:
            prompt: Input prompt.

        Returns:
            Mock response text (JSON).
        """
        if "__TEST_INVALID_JSON__" in prompt:
            return self._generate_invalid_json_response()
        if "__TEST_ASK_USER__" in prompt:
            return self._generate_ask_user_response()
        return self._generate_valid_generic_response()

    def _generate_valid_generic_response(self) -> str:
        """Return a generic, valid single-step (query_world) Plan JSON string."""
        plan = {
            "plan_id": f"plan_{uuid.uuid4().hex[:8]}",
            "goal": "处理用户请求",
            "reasoning": "根据用户意图生成执行计划",
            "risk_level": "low",
            "requires_confirmation": False,
            "confirmation_message": None,
            "steps": [
                {
                    "step_id": 1,
                    "action": "query",
                    "tool_call_type": "query_world",
                    "parameters": {},
                    "preconditions": {},
                    "fallback": None,
                    "status": "pending",
                    "description": "查询当前状态",
                    "requires_confirmation": False,
                    "confirmation_message": None,
                    "metadata": {},
                }
            ],
            "status": "created",
            "source": "llm",
            "metadata": {},
        }
        return json.dumps(plan, ensure_ascii=False)

    def _generate_ask_user_response(self) -> str:
        """Return a single-step ask_user Plan JSON string."""
        plan = {
            "plan_id": f"plan_{uuid.uuid4().hex[:8]}",
            "goal": "需要用户确认",
            "reasoning": "该请求需要用户进一步确认",
            "risk_level": "medium",
            "requires_confirmation": False,
            "confirmation_message": "请确认您的意图",
            "steps": [
                {
                    "step_id": 1,
                    "action": "ask_user",
                    "tool_call_type": "ask_user",
                    "parameters": {"question": "请确认您的具体需求"},
                    "preconditions": {},
                    "fallback": None,
                    "status": "pending",
                    "description": "询问用户确认",
                    "requires_confirmation": False,
                    "confirmation_message": None,
                    "metadata": {},
                }
            ],
            "status": "created",
            "source": "llm",
            "metadata": {},
        }
        return json.dumps(plan, ensure_ascii=False)

    def _generate_invalid_json_response(self) -> str:
        """Return truncated JSON (missing closing brace) to exercise the fallback."""
        return '{"plan_id": "broken", "goal": "incomplete json"'

    def _generate_simple_prompt(
        self,
        user_intent: str,
        world_snapshot: dict[str, Any],
        available_tools: list[str],
        domain_rules: dict[str, Any],
    ) -> str:
        """Build a minimal prompt, used when prompt_manager is unavailable."""
        return "\n".join(
            [
                f"用户意图: {user_intent}",
                f"世界状态: {json.dumps(world_snapshot, ensure_ascii=False)}",
                f"可用能力: {', '.join(available_tools)}",
                f"领域规则: {json.dumps(domain_rules, ensure_ascii=False)}",
                "请生成标准 Plan JSON。",
            ]
        )

    @staticmethod
    def _iter_embedded_json_objects(text: str):
        """Yield the top-level JSON objects embedded in ``text``, left to right.

        A full JSON value is decoded with ``json.JSONDecoder.raw_decode``
        starting at each ``{``. On success the scan resumes after the decoded
        object, so objects nested inside an earlier match are not yielded
        again. On failure the scan moves to the next ``{``, which skips stray
        brace groups in the surrounding prose.
        """
        decoder = json.JSONDecoder()
        pos = text.find("{")
        while pos != -1:
            try:
                obj, end = decoder.raw_decode(text, pos)
            except json.JSONDecodeError:
                pos = text.find("{", pos + 1)
                continue
            yield obj
            pos = text.find("{", end)

    def _extract_json_from_response(self, response_text: str) -> dict[str, Any]:
        """Extract a JSON object from the model response.

        Strategies, tried in order:

        1. The whole (stripped) response is a JSON object.
        2. A JSON object inside a markdown code fence (tag ``json`` optional,
           any case). The first fence that parses to an object wins.
        3. JSON objects embedded in surrounding prose. The last top-level
           object that parses wins.

        JSON values that are not objects (lists, strings, numbers) are never
        returned.

        Args:
            response_text: Model response text.

        Returns:
            The extracted JSON object.

        Raises:
            LLMResponseParseError: The response is not a string, or no JSON
                object could be extracted.
        """
        if not isinstance(response_text, str):
            raise LLMResponseParseError(
                f"模型响应不是字符串: {type(response_text).__name__}"
            )
        text = response_text.strip()

        # 1. The whole response is JSON.
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(parsed, dict):
                return parsed

        # 2. JSON inside markdown code fences.
        for fence_body in self._CODE_FENCE_PATTERN.findall(text):
            try:
                parsed = json.loads(fence_body.strip())
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed

        # 3. JSON objects embedded in prose; prefer the last one.
        candidates = list(self._iter_embedded_json_objects(text))
        if candidates:
            return candidates[-1]

        snippet = text[: self._ERROR_SNIPPET_LEN]
        if len(text) > self._ERROR_SNIPPET_LEN:
            snippet += "..."
        raise LLMResponseParseError(f"无法从响应中提取 JSON:\n{snippet}")

    def _validate_plan_json(self, plan_json: dict[str, Any]) -> None:
        """Check the basic structure of a Plan JSON dict.

        The checks are:

        - the plan is a dict;
        - it contains ``goal`` and ``steps``;
        - ``steps`` is a list;
        - every step is a dict that has an ``action`` key.

        Args:
            plan_json: Plan JSON dict.

        Raises:
            LLMResponseParseError: Validation failed.
        """
        if not isinstance(plan_json, dict):
            raise LLMResponseParseError(
                f"Plan JSON 必须是对象,实际为 {type(plan_json).__name__}"
            )

        for field in self._REQUIRED_PLAN_FIELDS:
            if field not in plan_json:
                raise LLMResponseParseError(f"缺少必需字段: {field}")

        steps = plan_json["steps"]
        if not isinstance(steps, list):
            raise LLMResponseParseError("steps 必须是列表")

        for i, step in enumerate(steps):
            # Without this check a string step such as "action" would pass
            # the membership test below.
            if not isinstance(step, dict):
                raise LLMResponseParseError(f"步骤 {i} 必须是对象")
            if "action" not in step:
                raise LLMResponseParseError(f"步骤 {i} 缺少 action 字段")

    def _build_safe_fallback_plan_json(
        self,
        user_intent: str,
        reason: str = "",
    ) -> dict[str, Any]:
        """Build a safe fallback Plan JSON for when the LLM path fails.

        The plan has a single ``ask_user`` step that asks the user to
        rephrase. It performs no other action.

        Args:
            user_intent: The user's original intent.
            reason: Why the fallback was used. It is recorded in the plan
                ``reasoning``, the plan ``metadata`` and the step ``metadata``.

        Returns:
            The fallback Plan JSON dict (``source`` is ``"llm_fallback"``).
        """
        return {
            "plan_id": f"plan_fallback_{uuid.uuid4().hex[:8]}",
            "goal": f"无法处理的请求: {user_intent}",
            "reasoning": f"LLM 处理失败或响应无效 ({reason}),返回安全 fallback",
            "risk_level": "low",
            "requires_confirmation": False,
            "confirmation_message": None,
            "steps": [
                {
                    "step_id": 1,
                    "action": "ask_user",
                    "tool_call_type": "ask_user",
                    "parameters": {
                        "question": f"抱歉,我无法理解或处理您的请求: {user_intent}。请重新描述。"
                    },
                    "preconditions": {},
                    "fallback": None,
                    "status": "pending",
                    "description": "询问用户澄清意图",
                    "requires_confirmation": False,
                    "confirmation_message": None,
                    "metadata": {"fallback_reason": reason},
                }
            ],
            "status": "created",
            "source": "llm_fallback",
            "metadata": {"fallback": True, "reason": reason},
        }
- # =============================================================================
- # 同步调用接口(供 Planner 使用)
- # =============================================================================
def call_llm_for_plan(
    user_intent: str,
    world_snapshot: dict[str, Any],
    available_tools: list[str],
    domain_rules: dict[str, Any] | None = None,
    planner_config: dict[str, Any] | None = None,
    **kwargs,
) -> dict[str, Any]:
    """Generate a Plan JSON with a throwaway ``LLMPlannerClient``.

    Convenience wrapper: builds a fresh client from ``kwargs`` and delegates
    to its ``generate_plan_json``.

    Args:
        user_intent: User intent.
        world_snapshot: World state.
        available_tools: Available capabilities.
        domain_rules: Domain rules.
        planner_config: Planner configuration.
        **kwargs: Forwarded unchanged to the ``LLMPlannerClient`` constructor.

    Returns:
        The Plan JSON dict produced by ``generate_plan_json``.
    """
    request = {
        "user_intent": user_intent,
        "world_snapshot": world_snapshot,
        "available_tools": available_tools,
        "domain_rules": domain_rules,
        "planner_config": planner_config,
    }
    return LLMPlannerClient(**kwargs).generate_plan_json(**request)
- # =============================================================================
- # 主程序入口(测试示例)
- # =============================================================================
if __name__ == "__main__":
    # Manual smoke run of the placeholder client: each scenario prints a few
    # fields of the generated or extracted plan.
    banner = "=" * 70
    rule = "-" * 40

    def _first_step_field(plan: dict[str, Any], key: str) -> Any:
        """Return ``key`` of the plan's first step, or 'N/A' if there are no steps."""
        if not plan.get("steps"):
            return "N/A"
        return plan["steps"][0][key]

    print(banner)
    print("LLM Planner Client 测试")
    print(banner)

    client = LLMPlannerClient()

    def _report_extraction(text: str) -> None:
        """Print the plan_id extracted from ``text``, or the parse error."""
        try:
            extracted = client._extract_json_from_response(text)
        except LLMResponseParseError as exc:
            print(f"提取失败: {exc}")
        else:
            print(f"成功提取 JSON: plan_id={extracted.get('plan_id')}")

    # Scenario 1: generic request -> the mock returns a valid plan.
    print("\n[场景 1] 通用请求")
    print(rule)
    result = client.generate_plan_json(
        user_intent="打开风扇降温",
        world_snapshot={"temperature": 32, "humidity": 70},
        available_tools=["adjust_fan", "speak", "query"],
        domain_rules={"high_risk_actions": ["turn_off"]},
    )
    print(f"Goal: {result.get('goal')}")
    print(f"Steps: {len(result.get('steps', []))}")
    print(f"Action: {_first_step_field(result, 'action')}")

    # Scenario 2: the __TEST_ASK_USER__ marker selects the ASK_USER response.
    print("\n[场景 2] 测试 ASK_USER 响应")
    print(rule)
    result = client.generate_plan_json(
        user_intent="确认操作 __TEST_ASK_USER__",
        world_snapshot={"temperature": 28},
        available_tools=["feed", "speak", "query"],
    )
    print(f"Goal: {result.get('goal')}")
    print(f"Tool Type: {_first_step_field(result, 'tool_call_type')}")

    # Scenario 3: the __TEST_INVALID_JSON__ marker forces the fallback plan.
    print("\n[场景 3] 测试 Invalid JSON fallback")
    print(rule)
    result = client.generate_plan_json(
        user_intent="测试 __TEST_INVALID_JSON__",
        world_snapshot={"temperature": 30},
        available_tools=["adjust_fan", "speak"],
    )
    print(f"Goal: {result.get('goal')}")
    print(f"Source: {result.get('source')}")
    print(f"Action: {_first_step_field(result, 'action')}")

    # Scenario 4: JSON surrounded by prose.
    print("\n[场景 4] JSON 提取测试")
    print(rule)
    prose_wrapped = '根据您的请求,我生成以下计划:{"plan_id": "plan_test123", "goal": "测试计划", "reasoning": "这是一个测试", "risk_level": "low", "requires_confirmation": false, "confirmation_message": null, "steps": [{"step_id": 1, "action": "query", "tool_call_type": "query_world", "parameters": {}, "preconditions": {}, "fallback": null, "status": "pending", "description": "查询状态", "requires_confirmation": false, "confirmation_message": null, "metadata": {}}], "status": "created", "source": "llm", "metadata": {}}这是标准格式的 Plan。'
    _report_extraction(prose_wrapped)

    # Scenario 5: JSON inside a markdown code fence.
    print("\n[场景 5] Markdown 代码块提取测试")
    print(rule)
    fenced = '''这是响应:
```json
{
"plan_id": "plan_markdown",
"goal": "测试 markdown",
"reasoning": "测试",
"risk_level": "low",
"requires_confirmation": false,
"confirmation_message": null,
"steps": [
{
"step_id": 1,
"action": "query",
"tool_call_type": "query_world",
"parameters": {},
"preconditions": {},
"fallback": null,
"status": "pending",
"description": "查询",
"requires_confirmation": false,
"confirmation_message": null,
"metadata": {}
}
],
"status": "created",
"source": "llm",
"metadata": {}
}
```'''
    _report_extraction(fenced)

    print("\n" + banner)
    print("LLM Planner Client 测试完成")
    print(banner)
|