| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260 |
- #!/usr/bin/env python3
- """
- world_node.py - ROS2 World Model 节点
- 功能:
- - 订阅配置的 Topic,将数据写入 WorldModel
- - 定时发布世界状态 Snapshot
- - YAML 配置驱动
- 使用:
- ros2 run world world_node --ros-args --params-file config/config.yaml
- """
- import json
- import yaml
- import os
- from typing import Optional
- import rclpy
- from rclpy.node import Node
- from rclpy.callback_groups import MutuallyExclusiveCallbackGroup
- from std_msgs.msg import String
- from ament_index_python.packages import get_package_share_directory
- from world.world_model import WorldModel
class TopicSubscription:
    """Configuration for one subscribed topic.

    Bundles the topic name, the kind of update it feeds (``update_type``),
    an optional entity type and TTL, and a ``mapping`` that tells the
    extractor methods where to find values inside a decoded message.

    The mapping may contain:
      * ``"entity_id"``: a dotted path to the entity identifier;
      * ``"state"``: ``{target_key: dotted_source_path}``;
      * ``"metadata"``: ``{target_key: dotted_source_path}``.
    """

    def __init__(
        self,
        topic: str,
        update_type: str,
        entity_type: Optional[str] = None,
        ttl: Optional[float] = None,
        mapping: Optional[dict] = None,
    ):
        self.topic = topic
        self.update_type = update_type
        self.entity_type = entity_type
        self.ttl = ttl
        # A missing or empty mapping is stored as a fresh empty dict.
        self.mapping = mapping if mapping else {}

    def extract_entity_id(self, data: dict) -> Optional[str]:
        """Return the value at the mapped ``entity_id`` path, or None.

        None is returned when no ``entity_id`` path is configured or when
        the path cannot be resolved in ``data``.
        """
        id_path = self.mapping.get("entity_id")
        return self._get_nested(data, id_path) if id_path else None

    def extract_state(self, data: dict) -> dict:
        """Return the mapped state fields found in ``data``.

        Fields whose source path resolves to None are omitted; the result
        is an empty dict when nothing is mapped or nothing is found.
        """
        return self._collect(data, self.mapping.get("state", {}))

    def extract_metadata(self, data: dict) -> Optional[dict]:
        """Return the mapped metadata fields found in ``data``, or None.

        None is returned when no metadata mapping is configured or when
        none of the mapped paths resolve to a non-None value.
        """
        field_map = self.mapping.get("metadata", {})
        if not field_map:
            return None
        return self._collect(data, field_map) or None

    def _collect(self, data: dict, field_map: dict) -> dict:
        """Resolve every source path in ``field_map``, dropping None values."""
        resolved = (
            (target_key, self._get_nested(data, source_path))
            for target_key, source_path in field_map.items()
        )
        return {
            target_key: value
            for target_key, value in resolved
            if value is not None
        }

    @staticmethod
    def _get_nested(data: dict, path: str):
        """Walk ``data`` along a dot-separated ``path``.

        Returns the value reached, or None as soon as a step hits a
        non-dict value or a missing key.
        """
        node = data
        for part in path.split("."):
            if not isinstance(node, dict) or part not in node:
                return None
            node = node[part]
        return node
class WorldNode(Node):
    """ROS2 node that feeds subscribed topics into a WorldModel.

    On construction the node:
      * reads the ``publish_snapshot_topic``, ``snapshot_rate`` and
        ``snapshot_mode`` parameters;
      * loads a list of JSON-encoded subscription configs from a YAML file
        (``WORLD_CONFIG_PATH`` or the ``world`` package's
        ``config/config.yaml``);
      * creates one ``std_msgs/String`` subscription per valid config;
      * starts a timer that publishes the world snapshot as a JSON string.
    """

    # update_type values that _process_message knows how to route.
    _KNOWN_UPDATE_TYPES = ("entity", "environment", "system")

    def __init__(self) -> None:
        super().__init__("world_node")

        self.declare_parameter("publish_snapshot_topic", "/world/snapshot")
        self.declare_parameter("snapshot_rate", 1.0)
        self.declare_parameter("snapshot_mode", "minimal")

        self.publish_snapshot_topic = self.get_parameter("publish_snapshot_topic").value
        self.snapshot_rate = self.get_parameter("snapshot_rate").value
        self.snapshot_mode = self.get_parameter("snapshot_mode").value

        self._world_model = WorldModel()
        self.get_logger().info("WorldModel initialized")

        # Raw JSON strings read from the YAML file; parsed in
        # _create_subscriptions.
        self._subscription_configs: list[str] = []
        self._load_subscriptions()
        self._create_subscriptions()

        self._snapshot_publisher = self.create_publisher(
            String,
            self.publish_snapshot_topic,
            qos_profile=10,
        )

        # A non-positive rate cannot be turned into a timer period; fall
        # back to 1 Hz and say so, so the log reflects what actually runs.
        if self.snapshot_rate > 0:
            effective_rate = self.snapshot_rate
        else:
            effective_rate = 1.0
            self.get_logger().warning(
                f"Invalid snapshot_rate {self.snapshot_rate}, falling back to 1.0 Hz"
            )
        self._timer = self.create_timer(1.0 / effective_rate, self._publish_snapshot)

        self.get_logger().info(
            f"WorldNode started, snapshot rate: {effective_rate} Hz"
        )

    def _load_subscriptions(self) -> None:
        """Load the raw subscription config strings from the YAML file.

        The file path comes from the ``WORLD_CONFIG_PATH`` environment
        variable, or, when that is unset or empty, from
        ``<share dir of package 'world'>/config/config.yaml``. Entries are
        read from ``/** -> ros__parameters -> subscriptions``; only string
        (or bytes) entries are kept. Any failure is logged as a warning
        and leaves the node without subscriptions rather than raising.
        """
        try:
            config_path = os.environ.get('WORLD_CONFIG_PATH')
            if not config_path:
                # Resolved lazily and inside the try block so a missing
                # package share directory cannot crash the constructor,
                # least of all when the environment override is set.
                config_path = os.path.join(
                    get_package_share_directory('world'), 'config', 'config.yaml'
                )
            config_path = os.path.normpath(config_path)

            with open(config_path, 'r', encoding='utf-8') as f:
                # safe_load returns None for an empty document.
                config_data = yaml.safe_load(f) or {}

            # Each level may be present but null in YAML, hence the "or".
            params = (config_data.get('/**') or {}).get('ros__parameters') or {}
            raw_subs = params.get('subscriptions') or []

            for item in raw_subs:
                if isinstance(item, str):
                    self._subscription_configs.append(item)
                elif isinstance(item, bytes):
                    self._subscription_configs.append(item.decode('utf-8'))
                else:
                    self.get_logger().warning(
                        f"Ignoring non-string subscription entry of type {type(item).__name__}"
                    )

            self.get_logger().info(f"Loaded {len(self._subscription_configs)} subscription configs")
        except Exception as e:
            # Best effort: the node still runs and publishes snapshots.
            self.get_logger().warning(f"Failed to load config: {e}")

    def _create_subscriptions(self) -> None:
        """Create one subscription per valid entry in ``_subscription_configs``.

        Each entry must be a JSON object with at least a ``topic`` key.
        Invalid entries are logged and skipped. All subscriptions share a
        single MutuallyExclusiveCallbackGroup.
        """
        callback_group = MutuallyExclusiveCallbackGroup()
        created = 0

        for config_str in self._subscription_configs:
            try:
                config = json.loads(config_str)
            except json.JSONDecodeError as e:
                self.get_logger().error(f"Failed to parse config: {e}")
                continue

            # Valid JSON that is not an object (list, number, ...) has no
            # .get(); reject it here instead of crashing node construction.
            if not isinstance(config, dict):
                self.get_logger().error(
                    f"Subscription config must be a JSON object, got {type(config).__name__}"
                )
                continue

            topic = config.get("topic")
            update_type = config.get("update_type", "entity")

            if not topic:
                self.get_logger().warning("Skipping subscription config without a topic")
                continue

            if update_type not in self._KNOWN_UPDATE_TYPES:
                # Still subscribed (as before), but its messages will be
                # ignored by _process_message; make that visible once.
                self.get_logger().warning(
                    f"Unknown update_type '{update_type}' for topic {topic}; messages will be ignored"
                )

            sub_config = TopicSubscription(
                topic=topic,
                update_type=update_type,
                entity_type=config.get("entity_type"),
                ttl=config.get("ttl"),
                mapping=config.get("mapping", {}),
            )

            self.create_subscription(
                String,
                topic,
                # Bind sub_config as a default so each callback keeps its
                # own config instead of the loop's last value.
                lambda msg, sc=sub_config: self._topic_callback(msg, sc),
                qos_profile=10,
                callback_group=callback_group,
            )
            created += 1

            self.get_logger().info(f"Subscribed: {topic}")

        # Report what was actually created, not how many configs were read.
        self.get_logger().info(f"Created {created} subscriptions")

    def _topic_callback(self, msg: String, sub_config: TopicSubscription) -> None:
        """Decode a JSON message and route it to the world model.

        Decoding and processing errors are logged and swallowed so one bad
        message cannot take the node down.
        """
        try:
            data = json.loads(msg.data)
        except json.JSONDecodeError as e:
            self.get_logger().warn(f"JSON parse failed [{sub_config.topic}]: {e}") if False else self.get_logger().warning(f"JSON parse failed [{sub_config.topic}]: {e}")
            return

        if not isinstance(data, dict):
            # The extractors only look inside dicts, so anything else
            # would be dropped silently; say why instead.
            self.get_logger().warning(
                f"Ignoring non-object JSON payload [{sub_config.topic}]"
            )
            return

        try:
            self._process_message(data, sub_config)
        except Exception as e:
            self.get_logger().error(f"Process message failed [{sub_config.topic}]: {e}")

    def _process_message(self, data: dict, sub_config: TopicSubscription) -> None:
        """Dispatch ``data`` according to ``sub_config.update_type``.

        Unknown update types are ignored.
        """
        update_type = sub_config.update_type

        if update_type == "entity":
            self._update_entity(data, sub_config)
        elif update_type == "environment":
            self._update_environment(data, sub_config)
        elif update_type == "system":
            self._update_system(data, sub_config)

    def _update_entity(self, data: dict, sub_config: TopicSubscription) -> None:
        """Write one entity update into the world model.

        The message is ignored when no entity id can be extracted. The
        entity type defaults to ``"unknown"`` when the config has none.
        """
        entity_id = sub_config.extract_entity_id(data)
        # Only a missing or empty id is skipped; a falsy but real id such
        # as 0 must still be recorded.
        if entity_id is None or entity_id == "":
            return

        self._world_model.update_entity(
            entity_id=entity_id,
            entity_type=sub_config.entity_type or "unknown",
            state=sub_config.extract_state(data),
            metadata=sub_config.extract_metadata(data),
            ttl=sub_config.ttl,
        )

    def _update_environment(self, data: dict, sub_config: TopicSubscription) -> None:
        """Write the mapped state fields as an environment batch update."""
        state = sub_config.extract_state(data)
        if state:
            self._world_model.update_environment_batch(state)

    def _update_system(self, data: dict, sub_config: TopicSubscription) -> None:
        """Write the mapped state fields as a system batch update."""
        state = sub_config.extract_state(data)
        if state:
            self._world_model.update_system_batch(state)

    def _publish_snapshot(self) -> None:
        """Timer callback: publish the current snapshot as a JSON string.

        Failures (snapshot creation, serialization, publishing) are logged
        and swallowed so the timer keeps running.
        """
        try:
            snapshot = self._world_model.snapshot(mode=self.snapshot_mode)
            msg = String()
            # ensure_ascii=False keeps non-ASCII text readable on the wire.
            msg.data = json.dumps(snapshot, ensure_ascii=False)
            self._snapshot_publisher.publish(msg)
        except Exception as e:
            self.get_logger().error(f"Publish snapshot failed: {e}")

    def destroy_node(self) -> None:
        """Log the shutdown and delegate to ``Node.destroy_node``."""
        self.get_logger().info("WorldNode shutting down")
        super().destroy_node()
def main(args=None):
    """Entry point: initialise rclpy, spin a WorldNode, then clean up.

    Args:
        args: Command-line arguments forwarded to ``rclpy.init``; None
            lets rclpy use ``sys.argv``.
    """
    rclpy.init(args=args)
    # Stays None if the constructor raises, so cleanup can tell whether
    # there is a node to destroy.
    node = None
    try:
        node = WorldNode()
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the node; exit without a
        # traceback.
        pass
    finally:
        if node is not None:
            # Explicitly destroy the node so WorldNode.destroy_node runs
            # and its timer, publisher and subscriptions are released.
            node.destroy_node()
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()
|