# planner_config_loader.py
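"""planner_config_loader.py - Planner configuration loading module.

Reads the planner-related settings from the unified YAML configuration
file (config.yaml) and converts them into the configuration objects and
runtime parameters that planner.py can use directly.

Scope:
- Reads, validates and converts configuration.
- Contains no planning logic.
- Performs no ROS communication.
- Contains no scenario-specific hard-coding.

Usage:
    from planner_config_loader import load_planner_runtime_config

    runtime_config = load_planner_runtime_config("planner_config.yaml")
    planner = Planner(runtime_config.planner_config)
    plan = planner.generate_plan(
        world_snapshot={...},
        user_intent="降温",
        available_tools=runtime_config.available_tools,
        domain_rules=runtime_config.domain_rules,
        planner_mode=runtime_config.planner_mode,
    )

load_planner_runtime_config() returns a PlannerRuntimeConfig dataclass, so
its fields are read as attributes. Use load_planner_runtime_config_as_dict()
when dict-style access (runtime_config["planner_config"]) is preferred.
"""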
  22. from __future__ import annotations
  23. import os
  24. from dataclasses import dataclass, field
  25. from typing import Any
  26. try:
  27. import yaml
  28. YAML_AVAILABLE = True
  29. except ImportError:
  30. YAML_AVAILABLE = False
  31. from planner import PlannerConfig
  32. from tool_protocol import RiskLevel
# =============================================================================
# Configuration validation error
# =============================================================================
  36. class ConfigValidationError(ValueError):
  37. """配置校验异常"""
  38. pass
# =============================================================================
# Runtime configuration dataclass
# =============================================================================
  42. @dataclass
  43. class PlannerRuntimeConfig:
  44. """Planner 运行时配置
  45. 聚合所有 Planner 运行所需的配置项。
  46. 属性:
  47. planner_mode: 规划模式
  48. planner_config: PlannerConfig 对象
  49. available_tools: 可用能力列表
  50. tool_descriptions: 工具语义描述列表
  51. intent_to_action: 意图到 action 的映射
  52. domain_rules: 领域规则
  53. 示例:
  54. >>> config = PlannerRuntimeConfig(
  55. ... planner_mode="hybrid",
  56. ... planner_config=PlannerConfig(),
  57. ... available_tools=["feed", "speak"],
  58. ... tool_descriptions=[{"name": "feed", "description": "..."}],
  59. ... intent_to_action={"降温": "adjust_fan"},
  60. ... domain_rules={"high_risk_actions": []},
  61. ... )
  62. """
  63. planner_mode: str = "hybrid"
  64. planner_config: PlannerConfig = field(default_factory=PlannerConfig)
  65. available_tools: list[str] = field(default_factory=list)
  66. tool_descriptions: list[dict[str, Any]] = field(default_factory=list)
  67. intent_to_action: dict[str, str] = field(default_factory=dict)
  68. domain_rules: dict[str, Any] = field(default_factory=dict)
  69. def to_dict(self) -> dict[str, Any]:
  70. """转换为字典
  71. Returns:
  72. 包含所有配置项的字典
  73. """
  74. return {
  75. "planner_mode": self.planner_mode,
  76. "planner_config": {
  77. "default_risk_level": self.planner_config.default_risk_level.value,
  78. "require_confirmation_on_medium_risk": self.planner_config.require_confirmation_on_medium_risk,
  79. "require_confirmation_on_high_risk": self.planner_config.require_confirmation_on_high_risk,
  80. "max_plan_steps": self.planner_config.max_plan_steps,
  81. "default_source": self.planner_config.default_source,
  82. },
  83. "available_tools": self.available_tools,
  84. "tool_descriptions": self.tool_descriptions,
  85. "intent_to_action": self.intent_to_action,
  86. "domain_rules": self.domain_rules,
  87. }
# =============================================================================
# Configuration loading functions
# =============================================================================
  91. def load_yaml_config(config_path: str) -> dict[str, Any]:
  92. """加载 YAML 配置文件
  93. Args:
  94. config_path: 配置文件路径
  95. Returns:
  96. 解析后的配置字典
  97. Raises:
  98. FileNotFoundError: 配置文件不存在
  99. ValueError: YAML 格式错误
  100. ImportError: PyYAML 未安装
  101. """
  102. if not os.path.exists(config_path):
  103. raise FileNotFoundError(f"配置文件不存在: {config_path}")
  104. if not YAML_AVAILABLE:
  105. raise ImportError(
  106. "PyYAML 未安装,请运行: pip install pyyaml"
  107. )
  108. try:
  109. with open(config_path, "r", encoding="utf-8") as f:
  110. config = yaml.safe_load(f)
  111. except yaml.YAMLError as e:
  112. raise ValueError(f"YAML 格式错误 in {config_path}: {e}")
  113. if config is None:
  114. return {}
  115. return config
  116. def load_planner_section(config: dict[str, Any]) -> dict[str, Any]:
  117. """提取 planner 配置段
  118. Args:
  119. config: 完整配置字典
  120. Returns:
  121. planner 配置段,缺省时返回空字典
  122. """
  123. return config.get("planner", {})
  124. def build_planner_config(planner_section: dict[str, Any]) -> PlannerConfig:
  125. """构造 PlannerConfig 对象
  126. Args:
  127. planner_section: planner 配置段
  128. Returns:
  129. PlannerConfig 对象
  130. Raises:
  131. ConfigValidationError: 配置值非法
  132. """
  133. # 获取各字段,缺省时使用默认值
  134. risk_level_str = planner_section.get("default_risk_level", "low")
  135. require_medium = planner_section.get("require_confirmation_on_medium_risk", True)
  136. require_high = planner_section.get("require_confirmation_on_high_risk", True)
  137. max_steps = planner_section.get("max_plan_steps", 10)
  138. default_source = planner_section.get("default_source", "hybrid")
  139. # 校验并转换 risk_level
  140. risk_level_str = str(risk_level_str).lower()
  141. if risk_level_str not in ("low", "medium", "high"):
  142. raise ConfigValidationError(
  143. f"无效的 default_risk_level: '{risk_level_str}',有效值为: low/medium/high"
  144. )
  145. risk_level = RiskLevel(risk_level_str)
  146. # 校验 max_plan_steps
  147. if not isinstance(max_steps, int) or max_steps <= 0:
  148. raise ConfigValidationError(
  149. f"无效的 max_plan_steps: {max_steps},必须为大于 0 的整数"
  150. )
  151. # 校验 default_source
  152. valid_sources = ("rule_engine", "llm", "hybrid")
  153. if default_source not in valid_sources:
  154. raise ConfigValidationError(
  155. f"无效的 default_source: '{default_source}',有效值为: {valid_sources}"
  156. )
  157. return PlannerConfig(
  158. default_risk_level=risk_level,
  159. require_confirmation_on_medium_risk=require_medium,
  160. require_confirmation_on_high_risk=require_high,
  161. max_plan_steps=max_steps,
  162. default_source=default_source,
  163. )
  164. def load_available_tools(planner_section: dict[str, Any]) -> list[str]:
  165. """加载可用能力列表
  166. Args:
  167. planner_section: planner 配置段
  168. Returns:
  169. 去重后的工具列表,缺省时返回空列表
  170. """
  171. raw_tools = planner_section.get("available_tools", [])
  172. if not isinstance(raw_tools, list):
  173. return []
  174. # 去重,保持顺序
  175. seen = set()
  176. tools = []
  177. for tool in raw_tools:
  178. if isinstance(tool, str) and tool and tool not in seen:
  179. seen.add(tool)
  180. tools.append(tool)
  181. return tools
  182. def load_tool_descriptions(planner_section: dict[str, Any]) -> list[dict[str, Any]]:
  183. """加载工具语义描述列表
  184. Args:
  185. planner_section: planner 配置段
  186. Returns:
  187. 工具描述列表,每项包含 name, description, tool_call_type, category
  188. 缺失时返回空列表
  189. 配置非法时过滤无效项并打印 warning
  190. """
  191. raw_descriptions = planner_section.get("tool_descriptions", [])
  192. if not isinstance(raw_descriptions, list):
  193. print(f"[PlannerConfigLoader] Warning: tool_descriptions 应为列表类型,忽略")
  194. return []
  195. valid_descriptions = []
  196. for i, item in enumerate(raw_descriptions):
  197. if not isinstance(item, dict):
  198. print(f"[PlannerConfigLoader] Warning: tool_descriptions[{i}] 应为字典,忽略")
  199. continue
  200. # 必须包含 name 和 description
  201. name = item.get("name")
  202. description = item.get("description")
  203. if not name or not isinstance(name, str):
  204. print(f"[PlannerConfigLoader] Warning: tool_descriptions[{i}] 缺少有效 name 字段,忽略")
  205. continue
  206. if not description or not isinstance(description, str):
  207. print(f"[PlannerConfigLoader] Warning: tool_descriptions[{i}] 缺少有效 description 字段,忽略")
  208. continue
  209. # 构建有效描述,补充默认值
  210. valid_item = {
  211. "name": str(name).strip(),
  212. "description": str(description).strip(),
  213. "tool_call_type": str(item.get("tool_call_type", "execute")).strip(),
  214. "category": str(item.get("category", "action")).strip(),
  215. }
  216. valid_descriptions.append(valid_item)
  217. if len(valid_descriptions) < len(raw_descriptions):
  218. print(f"[PlannerConfigLoader] Warning: 过滤了 {len(raw_descriptions) - len(valid_descriptions)} 个无效 tool_descriptions 项")
  219. return valid_descriptions
  220. def load_intent_to_action(planner_section: dict[str, Any]) -> dict[str, str]:
  221. """加载意图到 action 的映射
  222. Args:
  223. planner_section: planner 配置段
  224. Returns:
  225. 关键词到 action 的映射字典
  226. Raises:
  227. ConfigValidationError: key 或 value 为空
  228. """
  229. raw_mapping = planner_section.get("intent_to_action", {})
  230. if not isinstance(raw_mapping, dict):
  231. return {}
  232. result = {}
  233. for key, value in raw_mapping.items():
  234. key_str = str(key).strip()
  235. value_str = str(value).strip()
  236. if not key_str:
  237. raise ConfigValidationError("intent_to_action 的 key 不能为空字符串")
  238. if not value_str:
  239. raise ConfigValidationError(
  240. f"intent_to_action 的 value 不能为空字符串,key='{key_str}'"
  241. )
  242. result[key_str] = value_str
  243. return result
  244. def load_high_risk_actions(planner_section: dict[str, Any]) -> list[str]:
  245. """加载高风险动作列表
  246. Args:
  247. planner_section: planner 配置段
  248. Returns:
  249. 高风险动作列表
  250. """
  251. raw_list = planner_section.get("high_risk_actions", [])
  252. if not isinstance(raw_list, list):
  253. return []
  254. return [str(item) for item in raw_list if item]
  255. def load_medium_risk_actions(planner_section: dict[str, Any]) -> list[str]:
  256. """加载中风险动作列表
  257. Args:
  258. planner_section: planner 配置段
  259. Returns:
  260. 中风险动作列表
  261. """
  262. raw_list = planner_section.get("medium_risk_actions", [])
  263. if not isinstance(raw_list, list):
  264. return []
  265. return [str(item) for item in raw_list if item]
  266. def load_confirmation_rules(planner_section: dict[str, Any]) -> dict[str, bool]:
  267. """加载确认规则配置
  268. Args:
  269. planner_section: planner 配置段
  270. Returns:
  271. 确认规则字典
  272. """
  273. raw_rules = planner_section.get("confirmation_rules", {})
  274. if not isinstance(raw_rules, dict):
  275. raw_rules = {}
  276. return {
  277. "repeated_action_requires_confirmation": bool(
  278. raw_rules.get("repeated_action_requires_confirmation", True)
  279. ),
  280. "unavailable_actuator_requires_confirmation": bool(
  281. raw_rules.get("unavailable_actuator_requires_confirmation", True)
  282. ),
  283. "high_risk_requires_confirmation": bool(
  284. raw_rules.get("high_risk_requires_confirmation", True)
  285. ),
  286. "medium_risk_requires_confirmation": bool(
  287. raw_rules.get("medium_risk_requires_confirmation", True)
  288. ),
  289. }
  290. def load_domain_rules(planner_section: dict[str, Any]) -> dict[str, Any]:
  291. """加载领域规则
  292. Args:
  293. planner_section: planner 配置段
  294. Returns:
  295. 领域规则字典
  296. """
  297. return {
  298. "high_risk_actions": load_high_risk_actions(planner_section),
  299. "medium_risk_actions": load_medium_risk_actions(planner_section),
  300. "confirmation_rules": load_confirmation_rules(planner_section),
  301. }
  302. def validate_planner_mode(mode: str) -> str:
  303. """校验并规范化 planner 模式
  304. Args:
  305. mode: 原始模式字符串
  306. Returns:
  307. 规范化后的模式字符串
  308. Raises:
  309. ConfigValidationError: 模式非法
  310. """
  311. mode_lower = str(mode).lower().strip()
  312. valid_modes = ("rule", "llm", "hybrid")
  313. if mode_lower not in valid_modes:
  314. raise ConfigValidationError(
  315. f"无效的 planner mode: '{mode}',有效值为: {valid_modes}"
  316. )
  317. return mode_lower
  318. def load_planner_runtime_config(config_path: str) -> PlannerRuntimeConfig:
  319. """加载 Planner 运行时配置(统一入口)
  320. 读取 YAML 文件,提取 planner 配置段,
  321. 进行校验和转换,返回可直接使用的配置对象。
  322. Args:
  323. config_path: 配置文件路径
  324. Returns:
  325. PlannerRuntimeConfig 对象
  326. Raises:
  327. FileNotFoundError: 配置文件不存在
  328. ValueError: YAML 格式错误
  329. ConfigValidationError: 配置值非法
  330. ImportError: PyYAML 未安装
  331. """
  332. # 加载并解析 YAML
  333. full_config = load_yaml_config(config_path)
  334. # 提取 planner 配置段
  335. planner_section = load_planner_section(full_config)
  336. # 校验并加载 planner_mode
  337. planner_mode = planner_section.get("mode", "hybrid")
  338. planner_mode = validate_planner_mode(planner_mode)
  339. # 构造 PlannerConfig
  340. planner_config = build_planner_config(planner_section)
  341. # 加载其他配置项
  342. available_tools = load_available_tools(planner_section)
  343. tool_descriptions = load_tool_descriptions(planner_section)
  344. intent_to_action = load_intent_to_action(planner_section)
  345. domain_rules = load_domain_rules(planner_section)
  346. # 构建并返回运行时配置
  347. return PlannerRuntimeConfig(
  348. planner_mode=planner_mode,
  349. planner_config=planner_config,
  350. available_tools=available_tools,
  351. tool_descriptions=tool_descriptions,
  352. intent_to_action=intent_to_action,
  353. domain_rules=domain_rules,
  354. )
  355. def load_planner_runtime_config_as_dict(config_path: str) -> dict[str, Any]:
  356. """加载 Planner 运行时配置(返回字典形式)
  357. 与 load_planner_runtime_config 功能相同,但返回字典格式。
  358. 兼容不需要 dataclass 的场景。
  359. Args:
  360. config_path: 配置文件路径
  361. Returns:
  362. 包含所有配置项的字典
  363. """
  364. runtime_config = load_planner_runtime_config(config_path)
  365. return {
  366. "planner_mode": runtime_config.planner_mode,
  367. "planner_config": runtime_config.planner_config,
  368. "available_tools": runtime_config.available_tools,
  369. "intent_to_action": runtime_config.intent_to_action,
  370. "domain_rules": runtime_config.domain_rules,
  371. }
# =============================================================================
# Script entry point (manual test / demo)
# =============================================================================
  375. if __name__ == "__main__":
  376. import json
  377. import os
  378. print("=" * 70)
  379. print("Planner 配置加载测试")
  380. print("=" * 70)
  381. # 获取当前脚本所在目录
  382. script_dir = os.path.dirname(os.path.abspath(__file__))
  383. config_path = os.path.join(script_dir, "planner_config.yaml")
  384. # 检查配置文件是否存在
  385. if not os.path.exists(config_path):
  386. print(f"\n配置文件不存在: {config_path}")
  387. print("创建示例配置文件...")
  388. # 创建一个临时配置用于测试
  389. from planner import PlannerConfig
  390. from tool_protocol import RiskLevel
  391. runtime_config = PlannerRuntimeConfig(
  392. planner_mode="hybrid",
  393. planner_config=PlannerConfig(),
  394. available_tools=["feed", "adjust_fan", "speak"],
  395. intent_to_action={"降温": "adjust_fan"},
  396. domain_rules={
  397. "high_risk_actions": ["turn_off"],
  398. "medium_risk_actions": ["adjust_fan"],
  399. "confirmation_rules": {},
  400. },
  401. )
  402. else:
  403. # 加载配置
  404. print(f"\n加载配置文件: {config_path}")
  405. runtime_config = load_planner_runtime_config(config_path)
  406. # -------------------------------------------------------------------------
  407. # 打印配置信息
  408. # -------------------------------------------------------------------------
  409. print("\n[1] planner_mode")
  410. print("-" * 40)
  411. print(f" {runtime_config.planner_mode}")
  412. print("\n[2] PlannerConfig")
  413. print("-" * 40)
  414. pc = runtime_config.planner_config
  415. print(f" default_risk_level: {pc.default_risk_level.value}")
  416. print(f" require_confirmation_on_medium_risk: {pc.require_confirmation_on_medium_risk}")
  417. print(f" require_confirmation_on_high_risk: {pc.require_confirmation_on_high_risk}")
  418. print(f" max_plan_steps: {pc.max_plan_steps}")
  419. print(f" default_source: {pc.default_source}")
  420. print("\n[3] available_tools")
  421. print("-" * 40)
  422. tools = runtime_config.available_tools
  423. print(f" 数量: {len(tools)}")
  424. for i, tool in enumerate(tools, 1):
  425. print(f" {i}. {tool}")
  426. print("\n[4] tool_descriptions")
  427. print("-" * 40)
  428. tool_descs = runtime_config.tool_descriptions
  429. print(f" 数量: {len(tool_descs)}")
  430. for i, desc in enumerate(tool_descs, 1):
  431. print(f" {i}. name={desc.get('name')}, type={desc.get('tool_call_type')}, category={desc.get('category')}")
  432. print(f" description: {desc.get('description', '')[:50]}...")
  433. print("\n[5] intent_to_action")
  434. print("-" * 40)
  435. mapping = runtime_config.intent_to_action
  436. print(f" 数量: {len(mapping)}")
  437. for keyword, action in mapping.items():
  438. print(f" '{keyword}' → '{action}'")
  439. print("\n[6] domain_rules")
  440. print("-" * 40)
  441. rules = runtime_config.domain_rules
  442. print(f" high_risk_actions: {rules.get('high_risk_actions', [])}")
  443. print(f" medium_risk_actions: {rules.get('medium_risk_actions', [])}")
  444. print(f" confirmation_rules:")
  445. for key, value in rules.get("confirmation_rules", {}).items():
  446. print(f" {key}: {value}")
  447. # -------------------------------------------------------------------------
  448. # 测试校验功能
  449. # -------------------------------------------------------------------------
  450. print("\n[7] 配置校验测试")
  451. print("-" * 40)
  452. # 测试无效 risk_level
  453. try:
  454. bad_section = {"default_risk_level": "invalid"}
  455. build_planner_config(bad_section)
  456. print(" ❌ 无效 risk_level 未被捕获")
  457. except ConfigValidationError as e:
  458. print(f" ✓ 无效 risk_level 正确捕获: {e}")
  459. # 测试无效 max_plan_steps
  460. try:
  461. bad_section = {"max_plan_steps": 0}
  462. build_planner_config(bad_section)
  463. print(" ❌ 无效 max_plan_steps 未被捕获")
  464. except ConfigValidationError as e:
  465. print(f" ✓ 无效 max_plan_steps 正确捕获: {e}")
  466. # 测试空的 intent key
  467. try:
  468. bad_section = {"intent_to_action": {"": "feed"}}
  469. load_intent_to_action(bad_section)
  470. print(" ❌ 空的 intent key 未被捕获")
  471. except ConfigValidationError as e:
  472. print(f" ✓ 空的 intent key 正确捕获")
  473. # -------------------------------------------------------------------------
  474. # 完整 JSON 输出
  475. # -------------------------------------------------------------------------
  476. print("\n[8] 完整配置 JSON")
  477. print("-" * 40)
  478. print(json.dumps(runtime_config.to_dict(), indent=2, ensure_ascii=False))
  479. print("\n" + "=" * 70)
  480. print("Planner 配置加载测试完成")
  481. print("=" * 70)