world_node.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. #!/usr/bin/env python3
  2. """
  3. world_node.py - ROS2 World Model 节点
  4. 功能:
  5. - 订阅配置的 Topic,将数据写入 WorldModel
  6. - 定时发布世界状态 Snapshot
  7. - YAML 配置驱动
  8. 使用:
  9. ros2 run world world_node --ros-args --params-file config/config.yaml
  10. """
  11. import json
  12. import yaml
  13. import os
  14. from typing import Optional
  15. import rclpy
  16. from rclpy.node import Node
  17. from rclpy.callback_groups import MutuallyExclusiveCallbackGroup
  18. from std_msgs.msg import String
  19. from ament_index_python.packages import get_package_share_directory
  20. from world.world_model import WorldModel
  21. class TopicSubscription:
  22. """Topic 订阅配置封装"""
  23. def __init__(
  24. self,
  25. topic: str,
  26. update_type: str,
  27. entity_type: Optional[str] = None,
  28. ttl: Optional[float] = None,
  29. mapping: Optional[dict] = None,
  30. ):
  31. self.topic = topic
  32. self.update_type = update_type
  33. self.entity_type = entity_type
  34. self.ttl = ttl
  35. self.mapping = mapping or {}
  36. def extract_entity_id(self, data: dict) -> Optional[str]:
  37. """从数据中提取 entity_id"""
  38. field = self.mapping.get("entity_id")
  39. if not field:
  40. return None
  41. return self._get_nested(data, field)
  42. def extract_state(self, data: dict) -> dict:
  43. """提取 state 字段"""
  44. state_mapping = self.mapping.get("state", {})
  45. result = {}
  46. for target_key, source_key in state_mapping.items():
  47. value = self._get_nested(data, source_key)
  48. if value is not None:
  49. result[target_key] = value
  50. return result
  51. def extract_metadata(self, data: dict) -> Optional[dict]:
  52. """提取 metadata"""
  53. meta_mapping = self.mapping.get("metadata", {})
  54. if not meta_mapping:
  55. return None
  56. result = {}
  57. for target_key, source_key in meta_mapping.items():
  58. value = self._get_nested(data, source_key)
  59. if value is not None:
  60. result[target_key] = value
  61. return result if result else None
  62. @staticmethod
  63. def _get_nested(data: dict, path: str):
  64. """从嵌套字典中按路径获取值"""
  65. keys = path.split(".")
  66. current = data
  67. for key in keys:
  68. if isinstance(current, dict) and key in current:
  69. current = current[key]
  70. else:
  71. return None
  72. return current
  73. class WorldNode(Node):
  74. """ROS2 World Model 节点"""
  75. def __init__(self) -> None:
  76. super().__init__("world_node")
  77. self.declare_parameter("publish_snapshot_topic", "/world/snapshot")
  78. self.declare_parameter("snapshot_rate", 1.0)
  79. self.declare_parameter("snapshot_mode", "minimal")
  80. self.publish_snapshot_topic = self.get_parameter("publish_snapshot_topic").value
  81. self.snapshot_rate = self.get_parameter("snapshot_rate").value
  82. self.snapshot_mode = self.get_parameter("snapshot_mode").value
  83. self._world_model = WorldModel()
  84. self.get_logger().info("WorldModel initialized")
  85. self._subscription_configs: list[TopicSubscription] = []
  86. self._load_subscriptions()
  87. self._create_subscriptions()
  88. self._snapshot_publisher = self.create_publisher(
  89. String,
  90. self.publish_snapshot_topic,
  91. qos_profile=10,
  92. )
  93. timer_period = 1.0 / self.snapshot_rate if self.snapshot_rate > 0 else 1.0
  94. self._timer = self.create_timer(timer_period, self._publish_snapshot)
  95. self.get_logger().info(
  96. f"WorldNode started, snapshot rate: {self.snapshot_rate} Hz"
  97. )
  98. def _load_subscriptions(self) -> None:
  99. """从 YAML 加载订阅配置"""
  100. config_path = os.environ.get(
  101. 'WORLD_CONFIG_PATH',
  102. os.path.join(get_package_share_directory('world'), 'config', 'config.yaml'),
  103. )
  104. try:
  105. config_path = os.path.normpath(config_path)
  106. with open(config_path, 'r') as f:
  107. config_data = yaml.safe_load(f)
  108. params = config_data.get('/**', {}).get('ros__parameters', {})
  109. raw_subs = params.get('subscriptions', [])
  110. for item in raw_subs:
  111. if isinstance(item, str):
  112. self._subscription_configs.append(item)
  113. elif isinstance(item, bytes):
  114. self._subscription_configs.append(item.decode('utf-8'))
  115. self.get_logger().info(f"Loaded {len(self._subscription_configs)} subscription configs")
  116. except Exception as e:
  117. self.get_logger().warn(f"Failed to load config: {e}")
  118. def _create_subscriptions(self) -> None:
  119. """根据配置创建订阅"""
  120. callback_group = MutuallyExclusiveCallbackGroup()
  121. for config_str in self._subscription_configs:
  122. try:
  123. config = json.loads(config_str)
  124. except json.JSONDecodeError as e:
  125. self.get_logger().error(f"Failed to parse config: {e}")
  126. continue
  127. topic = config.get("topic")
  128. update_type = config.get("update_type", "entity")
  129. if not topic:
  130. continue
  131. sub_config = TopicSubscription(
  132. topic=topic,
  133. update_type=update_type,
  134. entity_type=config.get("entity_type"),
  135. ttl=config.get("ttl"),
  136. mapping=config.get("mapping", {}),
  137. )
  138. self.create_subscription(
  139. String,
  140. topic,
  141. lambda msg, sc=sub_config: self._topic_callback(msg, sc),
  142. qos_profile=10,
  143. callback_group=callback_group,
  144. )
  145. self.get_logger().info(f"Subscribed: {topic}")
  146. self.get_logger().info(f"Created {len(self._subscription_configs)} subscriptions")
  147. def _topic_callback(self, msg: String, sub_config: TopicSubscription) -> None:
  148. """处理接收到的消息"""
  149. try:
  150. data = json.loads(msg.data)
  151. self._process_message(data, sub_config)
  152. except json.JSONDecodeError as e:
  153. self.get_logger().warn(f"JSON parse failed [{sub_config.topic}]: {e}")
  154. except Exception as e:
  155. self.get_logger().error(f"Process message failed [{sub_config.topic}]: {e}")
  156. def _process_message(self, data: dict, sub_config: TopicSubscription) -> None:
  157. """根据 update_type 处理消息"""
  158. update_type = sub_config.update_type
  159. if update_type == "entity":
  160. self._update_entity(data, sub_config)
  161. elif update_type == "environment":
  162. self._update_environment(data, sub_config)
  163. elif update_type == "system":
  164. self._update_system(data, sub_config)
  165. def _update_entity(self, data: dict, sub_config: TopicSubscription) -> None:
  166. """更新 Entity"""
  167. entity_id = sub_config.extract_entity_id(data)
  168. if not entity_id:
  169. return
  170. entity_type = sub_config.entity_type or "unknown"
  171. state = sub_config.extract_state(data)
  172. metadata = sub_config.extract_metadata(data)
  173. self._world_model.update_entity(
  174. entity_id=entity_id,
  175. entity_type=entity_type,
  176. state=state,
  177. metadata=metadata,
  178. ttl=sub_config.ttl,
  179. )
  180. def _update_environment(self, data: dict, sub_config: TopicSubscription) -> None:
  181. """更新环境状态"""
  182. state = sub_config.extract_state(data)
  183. if state:
  184. self._world_model.update_environment_batch(state)
  185. def _update_system(self, data: dict, sub_config: TopicSubscription) -> None:
  186. """更新系统状态"""
  187. state = sub_config.extract_state(data)
  188. if state:
  189. self._world_model.update_system_batch(state)
  190. def _publish_snapshot(self) -> None:
  191. """定时发布 Snapshot"""
  192. try:
  193. snapshot = self._world_model.snapshot(mode=self.snapshot_mode)
  194. msg = String()
  195. msg.data = json.dumps(snapshot, ensure_ascii=False)
  196. self._snapshot_publisher.publish(msg)
  197. except Exception as e:
  198. self.get_logger().error(f"Publish snapshot failed: {e}")
  199. def destroy_node(self) -> None:
  200. self.get_logger().info("WorldNode shutting down")
  201. super().destroy_node()
  202. def main(args=None):
  203. rclpy.init(args=args)
  204. try:
  205. node = WorldNode()
  206. rclpy.spin(node)
  207. finally:
  208. if rclpy.ok():
  209. rclpy.shutdown()
  210. if __name__ == '__main__':
  211. main()