llm_client.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578
  1. """
  2. llm_client.py - LLM Planner 客户端适配层
  3. 该模块负责封装 LLM 调用,为 Planner 提供统一的 LLM 接口。
  4. 职责边界:
  5. - 接收 prompt,调用 LLM,返回 Plan JSON dict
  6. - 不执行 capability
  7. - 不依赖 ROS2
  8. - 不做 Prompt 构造(由 prompt_manager 负责)
  9. 架构说明:
  10. 本模块是纯 Python 接口层,不直接连接 OmniNode ROS2 节点。
  11. 真实的 OmniNode 通信应由 planner_node.py 完成:
  12. - planner_node.py: 负责与 OmniNode ROS2 节点通信,获取 world_snapshot
  13. - prompt_manager.py: 负责构造 prompt
  14. - llm_client.py: 仅负责 LLM 结果处理和抽象接口层(占位)
  15. 设计原则:
  16. - 可替换接口:当前为占位实现,未来可替换为 Omni HTTP/SDK
  17. - 异常安全:LLM 失败时返回安全的 fallback plan
  18. - JSON 输出:输出必须是能被 Plan.from_dict() 解析的 dict
  19. 使用方式:
  20. from llm_client import LLMPlannerClient
  21. client = LLMPlannerClient()
  22. plan_json = client.generate_plan_json(
  23. user_intent="降温",
  24. world_snapshot={"temperature": 32},
  25. available_tools=["adjust_fan", "speak"],
  26. domain_rules={"high_risk_actions": ["turn_off"]},
  27. )
  28. """
from __future__ import annotations

import json
import logging
import re
import time
import uuid
from typing import Any

try:
    import yaml

    YAML_AVAILABLE = True
except ImportError:
    YAML_AVAILABLE = False
  40. # =============================================================================
  41. # 异常定义
  42. # =============================================================================
  43. class LLMClientError(Exception):
  44. """LLM 客户端基础异常"""
  45. pass
  46. class LLMConnectionError(LLMClientError):
  47. """LLM 连接异常"""
  48. pass
  49. class LLMResponseParseError(LLMClientError):
  50. """LLM 响应解析异常"""
  51. pass
  52. class LLMTimeoutError(LLMClientError):
  53. """LLM 调用超时异常"""
  54. pass
  55. # =============================================================================
  56. # LLM Planner 客户端
  57. # =============================================================================
  58. class LLMPlannerClient:
  59. """LLM Planner 客户端
  60. 封装 LLM 调用,为 Planner 提供统一的 LLM 接口。
  61. 当前为占位实现,未来可替换为 Omni HTTP/SDK。
  62. 属性:
  63. api_endpoint: LLM API 端点
  64. model_name: 模型名称
  65. timeout: 超时时间(秒)
  66. max_retries: 最大重试次数
  67. 示例:
  68. >>> client = LLMPlannerClient()
  69. >>> plan_json = client.generate_plan_json(
  70. ... user_intent="降温",
  71. ... world_snapshot={"temperature": 32},
  72. ... available_tools=["adjust_fan"],
  73. ... )
  74. """
  75. # LLM 端点配置(占位,未来由配置注入)
  76. DEFAULT_API_ENDPOINT = "http://localhost:8000/omni/generate"
  77. def __init__(
  78. self,
  79. api_endpoint: str | None = None,
  80. model_name: str = "omni-planner-v1",
  81. timeout: float = 30.0,
  82. max_retries: int = 3,
  83. ):
  84. """初始化 LLM Planner 客户端
  85. Args:
  86. api_endpoint: API 端点,None 则使用默认端点
  87. model_name: 模型名称
  88. timeout: 超时时间(秒)
  89. max_retries: 最大重试次数
  90. """
  91. self.api_endpoint = api_endpoint or self.DEFAULT_API_ENDPOINT
  92. self.model_name = model_name
  93. self.timeout = timeout
  94. self.max_retries = max_retries
  95. def generate_plan_json(
  96. self,
  97. user_intent: str,
  98. world_snapshot: dict[str, Any],
  99. available_tools: list[str],
  100. domain_rules: dict[str, Any] | None = None,
  101. planner_config: dict[str, Any] | None = None,
  102. ) -> dict[str, Any]:
  103. """生成 Plan JSON
  104. 主要入口方法。根据用户意图和世界状态生成标准 Plan JSON。
  105. Args:
  106. user_intent: 用户意图描述
  107. world_snapshot: 世界状态快照
  108. available_tools: 可用能力列表
  109. domain_rules: 领域规则,None 则使用默认值
  110. planner_config: Planner 配置,None 则忽略
  111. Returns:
  112. Plan JSON dict,可被 Plan.from_dict() 解析
  113. Raises:
  114. LLMClientError: LLM 调用或解析失败
  115. """
  116. domain_rules = domain_rules or {}
  117. planner_config = planner_config or {}
  118. try:
  119. # 调用模型
  120. response_text = self._call_model(user_intent, world_snapshot, available_tools, domain_rules)
  121. # 提取 JSON
  122. plan_json = self._extract_json_from_response(response_text)
  123. # 基本校验
  124. self._validate_plan_json(plan_json)
  125. return plan_json
  126. except (LLMConnectionError, LLMResponseParseError, LLMTimeoutError):
  127. # LLM 相关异常,返回 fallback
  128. return self._build_safe_fallback_plan_json(
  129. user_intent=user_intent,
  130. reason="llm_error",
  131. )
  132. except Exception as e:
  133. # 其他异常,记录并返回 fallback
  134. return self._build_safe_fallback_plan_json(
  135. user_intent=user_intent,
  136. reason=f"unexpected_error: {e}",
  137. )
  138. def _call_model(
  139. self,
  140. user_intent: str,
  141. world_snapshot: dict[str, Any],
  142. available_tools: list[str],
  143. domain_rules: dict[str, Any],
  144. ) -> str:
  145. """调用 LLM 模型
  146. 统一接口封装。当前为占位实现,未来替换为真实 Omni 调用。
  147. Args:
  148. user_intent: 用户意图
  149. world_snapshot: 世界状态
  150. available_tools: 可用能力
  151. domain_rules: 领域规则
  152. Returns:
  153. 模型返回的文本
  154. """
  155. # 占位实现:使用 mock 响应
  156. # 未来替换为:
  157. # - Omni HTTP 调用
  158. # - Omni SDK 调用
  159. # - 其他 LLM API 调用
  160. # 导入 prompt_manager 以获取 prompt
  161. try:
  162. from prompt_manager import PlannerPromptManager
  163. prompt_manager = PlannerPromptManager()
  164. prompt = prompt_manager.build_plan_prompt(
  165. user_intent=user_intent,
  166. world_snapshot=world_snapshot,
  167. available_tools=available_tools,
  168. domain_rules=domain_rules,
  169. )
  170. except ImportError:
  171. # 如果 prompt_manager 不可用,生成简单 prompt
  172. prompt = self._generate_simple_prompt(
  173. user_intent, world_snapshot, available_tools, domain_rules
  174. )
  175. # 调用实际模型(占位)
  176. return self._mock_llm_call(prompt)
  177. def _mock_llm_call(self, prompt: str) -> str:
  178. """Mock LLM 调用
  179. 占位实现,用于测试和开发。
  180. 未来替换为真实 LLM 调用。
  181. 设计原则:
  182. - 不包含具体业务逻辑(农业/降温/喂食等)
  183. - 只根据测试标记返回不同类型的响应
  184. - 保持接口占位符的角色
  185. Args:
  186. prompt: 输入 prompt
  187. Returns:
  188. Mock 响应文本(JSON 格式)
  189. """
  190. # 通过测试标记区分返回类型
  191. if "__TEST_INVALID_JSON__" in prompt:
  192. return self._generate_invalid_json_response()
  193. elif "__TEST_ASK_USER__" in prompt:
  194. return self._generate_ask_user_response()
  195. else:
  196. return self._generate_valid_generic_response()
  197. def _generate_valid_generic_response(self) -> str:
  198. """生成通用合法 Plan JSON 响应"""
  199. plan = {
  200. "plan_id": f"plan_{uuid.uuid4().hex[:8]}",
  201. "goal": "处理用户请求",
  202. "reasoning": "根据用户意图生成执行计划",
  203. "risk_level": "low",
  204. "requires_confirmation": False,
  205. "confirmation_message": None,
  206. "steps": [
  207. {
  208. "step_id": 1,
  209. "action": "query",
  210. "tool_call_type": "query_world",
  211. "parameters": {},
  212. "preconditions": {},
  213. "fallback": None,
  214. "status": "pending",
  215. "description": "查询当前状态",
  216. "requires_confirmation": False,
  217. "confirmation_message": None,
  218. "metadata": {},
  219. }
  220. ],
  221. "status": "created",
  222. "source": "llm",
  223. "metadata": {},
  224. }
  225. return json.dumps(plan, ensure_ascii=False)
  226. def _generate_ask_user_response(self) -> str:
  227. """生成需要询问用户的 Plan JSON 响应"""
  228. plan = {
  229. "plan_id": f"plan_{uuid.uuid4().hex[:8]}",
  230. "goal": "需要用户确认",
  231. "reasoning": "该请求需要用户进一步确认",
  232. "risk_level": "medium",
  233. "requires_confirmation": False,
  234. "confirmation_message": "请确认您的意图",
  235. "steps": [
  236. {
  237. "step_id": 1,
  238. "action": "ask_user",
  239. "tool_call_type": "ask_user",
  240. "parameters": {"question": "请确认您的具体需求"},
  241. "preconditions": {},
  242. "fallback": None,
  243. "status": "pending",
  244. "description": "询问用户确认",
  245. "requires_confirmation": False,
  246. "confirmation_message": None,
  247. "metadata": {},
  248. }
  249. ],
  250. "status": "created",
  251. "source": "llm",
  252. "metadata": {},
  253. }
  254. return json.dumps(plan, ensure_ascii=False)
  255. def _generate_invalid_json_response(self) -> str:
  256. """生成无效的 JSON 响应(用于测试 fallback)"""
  257. return '{"plan_id": "broken", "goal": "incomplete json"'
  258. def _generate_simple_prompt(
  259. self,
  260. user_intent: str,
  261. world_snapshot: dict[str, Any],
  262. available_tools: list[str],
  263. domain_rules: dict[str, Any],
  264. ) -> str:
  265. """生成简单的 prompt(当 prompt_manager 不可用时)"""
  266. return f"""用户意图: {user_intent}
  267. 世界状态: {json.dumps(world_snapshot, ensure_ascii=False)}
  268. 可用能力: {', '.join(available_tools)}
  269. 领域规则: {json.dumps(domain_rules, ensure_ascii=False)}
  270. 请生成标准 Plan JSON。"""
  271. def _extract_json_from_response(self, response_text: str) -> dict[str, Any]:
  272. """从模型响应中提取 JSON
  273. 支持以下格式:
  274. - 纯 JSON
  275. - JSON 前后有解释文字
  276. - JSON 在 markdown 代码块中
  277. Args:
  278. response_text: 模型响应文本
  279. Returns:
  280. 提取的 JSON dict
  281. Raises:
  282. LLMResponseParseError: 解析失败
  283. """
  284. response_text = response_text.strip()
  285. # 尝试直接解析
  286. try:
  287. return json.loads(response_text)
  288. except json.JSONDecodeError:
  289. pass
  290. # 尝试从 markdown 代码块中提取
  291. json_pattern = r'```(?:json)?\s*([\s\S]*?)\s*```'
  292. matches = re.findall(json_pattern, response_text)
  293. for match in matches:
  294. try:
  295. return json.loads(match.strip())
  296. except json.JSONDecodeError:
  297. continue
  298. # 尝试提取任何 {...} 块
  299. brace_pattern = r'\{[\s\S]*\}'
  300. matches = re.findall(brace_pattern, response_text)
  301. if matches:
  302. # 尝试从后向前找第一个有效 JSON
  303. for i in range(len(matches) - 1, -1, -1):
  304. try:
  305. return json.loads(matches[i])
  306. except json.JSONDecodeError:
  307. continue
  308. raise LLMResponseParseError(
  309. f"无法从响应中提取 JSON:\n{response_text[:500]}..."
  310. )
  311. def _validate_plan_json(self, plan_json: dict[str, Any]) -> None:
  312. """校验 Plan JSON 基本结构
  313. Args:
  314. plan_json: Plan JSON dict
  315. Raises:
  316. LLMResponseParseError: 校验失败
  317. """
  318. required_fields = ["goal", "steps"]
  319. for field in required_fields:
  320. if field not in plan_json:
  321. raise LLMResponseParseError(f"缺少必需字段: {field}")
  322. if not isinstance(plan_json["steps"], list):
  323. raise LLMResponseParseError("steps 必须是列表")
  324. # 校验每个步骤
  325. for i, step in enumerate(plan_json["steps"]):
  326. if "action" not in step:
  327. raise LLMResponseParseError(f"步骤 {i} 缺少 action 字段")
  328. def _build_safe_fallback_plan_json(
  329. self,
  330. user_intent: str,
  331. reason: str = "",
  332. ) -> dict[str, Any]:
  333. """构建安全的 fallback Plan JSON
  334. 当 LLM 调用失败时,返回一个安全的 fallback plan。
  335. 优先使用 ASK_USER 或 SPEAK,不直接执行风险动作。
  336. Args:
  337. user_intent: 用户原始意图
  338. reason: 回退原因
  339. Returns:
  340. 安全的 Plan JSON
  341. """
  342. return {
  343. "plan_id": f"plan_fallback_{uuid.uuid4().hex[:8]}",
  344. "goal": f"无法处理的请求: {user_intent}",
  345. "reasoning": f"LLM 处理失败或响应无效 ({reason}),返回安全 fallback",
  346. "risk_level": "low",
  347. "requires_confirmation": False,
  348. "confirmation_message": None,
  349. "steps": [
  350. {
  351. "step_id": 1,
  352. "action": "ask_user",
  353. "tool_call_type": "ask_user",
  354. "parameters": {
  355. "question": f"抱歉,我无法理解或处理您的请求: {user_intent}。请重新描述。"
  356. },
  357. "preconditions": {},
  358. "fallback": None,
  359. "status": "pending",
  360. "description": "询问用户澄清意图",
  361. "requires_confirmation": False,
  362. "confirmation_message": None,
  363. "metadata": {"fallback_reason": reason},
  364. }
  365. ],
  366. "status": "created",
  367. "source": "llm_fallback",
  368. "metadata": {"fallback": True, "reason": reason},
  369. }
  370. # =============================================================================
  371. # 同步调用接口(供 Planner 使用)
  372. # =============================================================================
  373. def call_llm_for_plan(
  374. user_intent: str,
  375. world_snapshot: dict[str, Any],
  376. available_tools: list[str],
  377. domain_rules: dict[str, Any] | None = None,
  378. planner_config: dict[str, Any] | None = None,
  379. **kwargs,
  380. ) -> dict[str, Any]:
  381. """便捷函数:调用 LLM 生成 Plan JSON
  382. 创建一个临时客户端并调用 generate_plan_json。
  383. Args:
  384. user_intent: 用户意图
  385. world_snapshot: 世界状态
  386. available_tools: 可用能力
  387. domain_rules: 领域规则
  388. planner_config: Planner 配置
  389. **kwargs: 传递给 LLMPlannerClient 的其他参数
  390. Returns:
  391. Plan JSON dict
  392. """
  393. client = LLMPlannerClient(**kwargs)
  394. return client.generate_plan_json(
  395. user_intent=user_intent,
  396. world_snapshot=world_snapshot,
  397. available_tools=available_tools,
  398. domain_rules=domain_rules,
  399. planner_config=planner_config,
  400. )
  401. # =============================================================================
  402. # 主程序入口(测试示例)
  403. # =============================================================================
  404. if __name__ == "__main__":
  405. print("=" * 70)
  406. print("LLM Planner Client 测试")
  407. print("=" * 70)
  408. client = LLMPlannerClient()
  409. # 场景 1:通用请求(默认返回合法 Plan)
  410. print("\n[场景 1] 通用请求")
  411. print("-" * 40)
  412. plan_json = client.generate_plan_json(
  413. user_intent="打开风扇降温",
  414. world_snapshot={"temperature": 32, "humidity": 70},
  415. available_tools=["adjust_fan", "speak", "query"],
  416. domain_rules={"high_risk_actions": ["turn_off"]},
  417. )
  418. print(f"Goal: {plan_json.get('goal')}")
  419. print(f"Steps: {len(plan_json.get('steps', []))}")
  420. print(f"Action: {plan_json['steps'][0]['action'] if plan_json.get('steps') else 'N/A'}")
  421. # 场景 2:测试 ASK_USER 响应
  422. print("\n[场景 2] 测试 ASK_USER 响应")
  423. print("-" * 40)
  424. plan_json = client.generate_plan_json(
  425. user_intent="确认操作 __TEST_ASK_USER__",
  426. world_snapshot={"temperature": 28},
  427. available_tools=["feed", "speak", "query"],
  428. )
  429. print(f"Goal: {plan_json.get('goal')}")
  430. print(f"Tool Type: {plan_json['steps'][0]['tool_call_type'] if plan_json.get('steps') else 'N/A'}")
  431. # 场景 3:测试 Invalid JSON fallback
  432. print("\n[场景 3] 测试 Invalid JSON fallback")
  433. print("-" * 40)
  434. plan_json = client.generate_plan_json(
  435. user_intent="测试 __TEST_INVALID_JSON__",
  436. world_snapshot={"temperature": 30},
  437. available_tools=["adjust_fan", "speak"],
  438. )
  439. print(f"Goal: {plan_json.get('goal')}")
  440. print(f"Source: {plan_json.get('source')}")
  441. print(f"Action: {plan_json['steps'][0]['action'] if plan_json.get('steps') else 'N/A'}")
  442. # 场景 4:测试 JSON 提取
  443. print("\n[场景 4] JSON 提取测试")
  444. print("-" * 40)
  445. test_response = '根据您的请求,我生成以下计划:{"plan_id": "plan_test123", "goal": "测试计划", "reasoning": "这是一个测试", "risk_level": "low", "requires_confirmation": false, "confirmation_message": null, "steps": [{"step_id": 1, "action": "query", "tool_call_type": "query_world", "parameters": {}, "preconditions": {}, "fallback": null, "status": "pending", "description": "查询状态", "requires_confirmation": false, "confirmation_message": null, "metadata": {}}], "status": "created", "source": "llm", "metadata": {}}这是标准格式的 Plan。'
  446. try:
  447. extracted = client._extract_json_from_response(test_response)
  448. print(f"成功提取 JSON: plan_id={extracted.get('plan_id')}")
  449. except LLMResponseParseError as e:
  450. print(f"提取失败: {e}")
  451. # 场景 5:测试 markdown 代码块提取
  452. print("\n[场景 5] Markdown 代码块提取测试")
  453. print("-" * 40)
  454. test_response_markdown = '''这是响应:
  455. ```json
  456. {
  457. "plan_id": "plan_markdown",
  458. "goal": "测试 markdown",
  459. "reasoning": "测试",
  460. "risk_level": "low",
  461. "requires_confirmation": false,
  462. "confirmation_message": null,
  463. "steps": [
  464. {
  465. "step_id": 1,
  466. "action": "query",
  467. "tool_call_type": "query_world",
  468. "parameters": {},
  469. "preconditions": {},
  470. "fallback": null,
  471. "status": "pending",
  472. "description": "查询",
  473. "requires_confirmation": false,
  474. "confirmation_message": null,
  475. "metadata": {}
  476. }
  477. ],
  478. "status": "created",
  479. "source": "llm",
  480. "metadata": {}
  481. }
  482. ```'''
  483. try:
  484. extracted = client._extract_json_from_response(test_response_markdown)
  485. print(f"成功提取 JSON: plan_id={extracted.get('plan_id')}")
  486. except LLMResponseParseError as e:
  487. print(f"提取失败: {e}")
  488. print("\n" + "=" * 70)
  489. print("LLM Planner Client 测试完成")
  490. print("=" * 70)