#!/usr/bin/env python3
"""
world_node.py - ROS2 World Model node.

Features:
- Subscribes to the configured topics and writes their data into a WorldModel.
- Periodically publishes a world-state snapshot.
- Driven by a YAML configuration file.

Usage:
    ros2 run world world_node --ros-args --params-file config/config.yaml

Note: the ``subscriptions`` list is read directly from the YAML file named by
the ``WORLD_CONFIG_PATH`` environment variable, or from
``<share/world>/config/config.yaml`` when that variable is unset. The file
passed via ``--params-file`` only supplies the declared ROS parameters
(``publish_snapshot_topic``, ``snapshot_rate``, ``snapshot_mode``).
"""

import json
import os
from typing import Optional, Union

import yaml

import rclpy
from rclpy.node import Node
from rclpy.callback_groups import MutuallyExclusiveCallbackGroup
from std_msgs.msg import String
from ament_index_python.packages import get_package_share_directory

from world.world_model import WorldModel


class TopicSubscription:
    """Configuration for one subscribed topic and how its messages map into the WorldModel.

    Args:
        topic: ROS topic name to subscribe to.
        update_type: Which WorldModel section messages update
            ("entity", "environment" or "system").
        entity_type: Entity type passed to ``update_entity`` for "entity" updates.
        ttl: Optional value forwarded as ``ttl`` to ``update_entity``.
        mapping: Field mapping with optional keys ``entity_id`` (a dotted
            path) and ``state`` / ``metadata`` (dicts of
            target key -> dotted source path).
    """

    def __init__(
        self,
        topic: str,
        update_type: str,
        entity_type: Optional[str] = None,
        ttl: Optional[float] = None,
        mapping: Optional[dict] = None,
    ):
        self.topic = topic
        self.update_type = update_type
        self.entity_type = entity_type
        self.ttl = ttl
        self.mapping = mapping or {}

    def extract_entity_id(self, data: dict) -> Optional[str]:
        """Return the entity id found at the configured ``entity_id`` path.

        Returns None when no ``entity_id`` path is configured, the path is
        missing from *data*, or it points at a container (dict/list) rather
        than a scalar. Scalar ids (e.g. JSON numbers) are converted with
        ``str()`` so the result always matches the declared ``str`` type.
        """
        field = self.mapping.get("entity_id")
        if not field:
            return None
        value = self._get_nested(data, field)
        if value is None or isinstance(value, (dict, list)):
            return None
        return str(value)

    def extract_state(self, data: dict) -> dict:
        """Return the ``state`` fields mapped from *data* (fields that resolve to None are omitted)."""
        return self._extract_fields(data, "state")

    def extract_metadata(self, data: dict) -> Optional[dict]:
        """Return the ``metadata`` fields mapped from *data*, or None if none were found."""
        return self._extract_fields(data, "metadata") or None

    def _extract_fields(self, data: dict, section: str) -> dict:
        """Build ``{target_key: value}`` for the mapping under *section*, skipping None values."""
        # ``or {}`` also covers an explicit ``null`` in the config; a non-dict
        # section is treated as empty instead of failing on ``.items()``.
        field_mapping = self.mapping.get(section) or {}
        if not isinstance(field_mapping, dict):
            return {}
        result = {}
        for target_key, source_key in field_mapping.items():
            value = self._get_nested(data, source_key)
            if value is not None:
                result[target_key] = value
        return result

    @staticmethod
    def _get_nested(data: dict, path: str):
        """Look up a dotted *path* (e.g. ``"pose.x"``) in nested dicts.

        Returns None if *path* is not a string or any step is missing or
        not a dict.
        """
        if not isinstance(path, str):
            return None
        current = data
        for key in path.split("."):
            if isinstance(current, dict) and key in current:
                current = current[key]
            else:
                return None
        return current


class WorldNode(Node):
    """ROS2 node that feeds configured topics into a WorldModel and publishes snapshots."""

    # update_type values understood by _process_message.
    _UPDATE_TYPES = ("entity", "environment", "system")

    def __init__(self) -> None:
        super().__init__("world_node")

        self.declare_parameter("publish_snapshot_topic", "/world/snapshot")
        self.declare_parameter("snapshot_rate", 1.0)
        self.declare_parameter("snapshot_mode", "minimal")

        self.publish_snapshot_topic = self.get_parameter("publish_snapshot_topic").value
        self.snapshot_rate = self.get_parameter("snapshot_rate").value
        self.snapshot_mode = self.get_parameter("snapshot_mode").value

        self._world_model = WorldModel()
        self.get_logger().info("WorldModel initialized")

        # Raw subscription entries read from the YAML file: JSON strings or dicts.
        self._subscription_configs: list[Union[str, dict]] = []
        self._load_subscriptions()
        self._create_subscriptions()

        self._snapshot_publisher = self.create_publisher(
            String,
            self.publish_snapshot_topic,
            qos_profile=10,
        )

        if self.snapshot_rate > 0:
            timer_period = 1.0 / self.snapshot_rate
        else:
            # A non-positive rate has no valid period; fall back to 1 Hz but say so.
            self.get_logger().warn(
                f"snapshot_rate must be > 0 (got {self.snapshot_rate}), falling back to 1.0 Hz"
            )
            timer_period = 1.0
        self._timer = self.create_timer(timer_period, self._publish_snapshot)

        self.get_logger().info(
            f"WorldNode started, snapshot rate: {self.snapshot_rate} Hz"
        )

    def _resolve_config_path(self) -> str:
        """Return the config path: $WORLD_CONFIG_PATH, else the package's config/config.yaml."""
        # Look up the package share directory only when it is actually needed,
        # so a missing 'world' package cannot break the WORLD_CONFIG_PATH override.
        env_path = os.environ.get('WORLD_CONFIG_PATH')
        if env_path:
            return os.path.normpath(env_path)
        share_dir = get_package_share_directory('world')
        return os.path.normpath(os.path.join(share_dir, 'config', 'config.yaml'))

    def _select_raw_subscriptions(self, config_data) -> list:
        """Return the ``subscriptions`` list from the parameter-style YAML data.

        The ``/**`` wildcard section is checked first. Sections keyed by this
        node's name (``<name>`` / ``/<name>``) are used only if the wildcard
        section has no subscriptions. Returns an empty list if none is found.
        """
        for key in ('/**', self.get_name(), f'/{self.get_name()}'):
            section = config_data.get(key) if isinstance(config_data, dict) else None
            params = section.get('ros__parameters') if isinstance(section, dict) else None
            raw_subs = params.get('subscriptions') if isinstance(params, dict) else None
            if not raw_subs:
                continue
            if isinstance(raw_subs, list):
                return raw_subs
            self.get_logger().warn(f"'subscriptions' under '{key}' must be a list, ignored")
        return []

    def _load_subscriptions(self) -> None:
        """Load raw subscription entries from the YAML config into ``_subscription_configs``.

        Best-effort: a failure to locate, read or parse the file is logged as a
        warning and leaves the node with no subscriptions. Entries that are not
        strings, UTF-8 bytes or dicts are skipped with a warning.
        """
        try:
            config_path = self._resolve_config_path()
            with open(config_path, 'r', encoding='utf-8') as f:
                # safe_load returns None for an empty document.
                config_data = yaml.safe_load(f) or {}
        except Exception as e:
            self.get_logger().warn(f"Failed to load config: {e}")
            return

        for index, item in enumerate(self._select_raw_subscriptions(config_data)):
            if isinstance(item, bytes):
                try:
                    item = item.decode('utf-8')
                except UnicodeDecodeError as e:
                    self.get_logger().warn(f"Skipping subscription #{index}: {e}")
                    continue
            if isinstance(item, (str, dict)):
                self._subscription_configs.append(item)
            else:
                self.get_logger().warn(
                    f"Skipping subscription #{index}: unsupported type {type(item).__name__}"
                )
        self.get_logger().info(f"Loaded {len(self._subscription_configs)} subscription configs")

    def _parse_subscription_config(self, raw: Union[str, dict]) -> Optional[dict]:
        """Turn one raw entry (JSON string or dict) into a config dict; None if invalid."""
        if isinstance(raw, dict):
            return raw
        try:
            config = json.loads(raw)
        except json.JSONDecodeError as e:
            self.get_logger().error(f"Failed to parse config: {e}")
            return None
        if not isinstance(config, dict):
            self.get_logger().error(
                f"Subscription config must be a JSON object, got {type(config).__name__}"
            )
            return None
        return config

    def _create_subscriptions(self) -> None:
        """Create one String subscription per valid config entry.

        Invalid entries (unparseable, missing topic, unknown update_type,
        non-dict mapping) are logged and skipped.
        """
        callback_group = MutuallyExclusiveCallbackGroup()
        created = 0

        for raw in self._subscription_configs:
            config = self._parse_subscription_config(raw)
            if config is None:
                continue

            topic = config.get("topic")
            update_type = config.get("update_type", "entity")
            if not topic or not isinstance(topic, str):
                self.get_logger().warn(f"Subscription config without a valid 'topic' skipped: {config}")
                continue
            if update_type not in self._UPDATE_TYPES:
                self.get_logger().warn(
                    f"Unknown update_type {update_type!r} for {topic}, expected one of {self._UPDATE_TYPES}"
                )
                continue
            mapping = config.get("mapping", {})
            if mapping is not None and not isinstance(mapping, dict):
                self.get_logger().warn(f"'mapping' for {topic} must be an object, skipped")
                continue

            sub_config = TopicSubscription(
                topic=topic,
                update_type=update_type,
                entity_type=config.get("entity_type"),
                ttl=config.get("ttl"),
                mapping=mapping,
            )

            # Bind sub_config as a default argument so each lambda keeps its own
            # config instead of the loop's last value.
            self.create_subscription(
                String,
                topic,
                lambda msg, sc=sub_config: self._topic_callback(msg, sc),
                qos_profile=10,
                callback_group=callback_group,
            )
            created += 1
            self.get_logger().info(f"Subscribed: {topic}")

        self.get_logger().info(f"Created {created} subscriptions")

    def _topic_callback(self, msg: String, sub_config: TopicSubscription) -> None:
        """Decode a JSON message and route it into the WorldModel; failures are logged."""
        try:
            data = json.loads(msg.data)
        except json.JSONDecodeError as e:
            self.get_logger().warn(f"JSON parse failed [{sub_config.topic}]: {e}")
            return
        if not isinstance(data, dict):
            # Field mappings only address JSON objects; anything else carries no usable fields.
            self.get_logger().warn(f"Ignoring non-object JSON payload [{sub_config.topic}]")
            return
        try:
            self._process_message(data, sub_config)
        except Exception as e:
            # Keep one bad message from propagating out of the subscription callback.
            self.get_logger().error(f"Process message failed [{sub_config.topic}]: {e}")

    def _process_message(self, data: dict, sub_config: TopicSubscription) -> None:
        """Dispatch *data* by the subscription's update_type; other types are ignored."""
        update_type = sub_config.update_type
        if update_type == "entity":
            self._update_entity(data, sub_config)
        elif update_type == "environment":
            self._update_environment(data, sub_config)
        elif update_type == "system":
            self._update_system(data, sub_config)

    def _update_entity(self, data: dict, sub_config: TopicSubscription) -> None:
        """Upsert one entity; messages without an entity id are dropped."""
        entity_id = sub_config.extract_entity_id(data)
        if not entity_id:
            return

        entity_type = sub_config.entity_type or "unknown"
        state = sub_config.extract_state(data)
        metadata = sub_config.extract_metadata(data)

        self._world_model.update_entity(
            entity_id=entity_id,
            entity_type=entity_type,
            state=state,
            metadata=metadata,
            ttl=sub_config.ttl,
        )

    def _update_environment(self, data: dict, sub_config: TopicSubscription) -> None:
        """Write mapped state fields into the environment section, if any were found."""
        state = sub_config.extract_state(data)
        if state:
            self._world_model.update_environment_batch(state)

    def _update_system(self, data: dict, sub_config: TopicSubscription) -> None:
        """Write mapped state fields into the system section, if any were found."""
        state = sub_config.extract_state(data)
        if state:
            self._world_model.update_system_batch(state)

    def _publish_snapshot(self) -> None:
        """Timer callback: publish the WorldModel snapshot as a JSON String message."""
        try:
            snapshot = self._world_model.snapshot(mode=self.snapshot_mode)
            msg = String()
            msg.data = json.dumps(snapshot, ensure_ascii=False)
            self._snapshot_publisher.publish(msg)
        except Exception as e:
            self.get_logger().error(f"Publish snapshot failed: {e}")

    def destroy_node(self) -> None:
        """Log shutdown, then release the node's ROS resources."""
        self.get_logger().info("WorldNode shutting down")
        super().destroy_node()


def main(args=None):
    """Entry point: spin a WorldNode until interrupted, then clean up."""
    rclpy.init(args=args)
    node = None
    try:
        node = WorldNode()
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop the node; no traceback needed.
        pass
    finally:
        if node is not None:
            node.destroy_node()
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()