# world_model.py
"""
WorldModel - main manager.

Core features:
- Manages all Entity objects
- Manages the EnvironmentState
- Manages the SystemState
- Provides snapshot export
- Supports change-hook callbacks
"""
  10. from __future__ import annotations
  11. from typing import Callable, Any
  12. import time
  13. import threading
  14. from .entity import Entity
  15. from .environment import EnvironmentState
  16. from .system_state import SystemState
  17. ChangeHook = Callable[[str, dict, Any], None]
  18. """
  19. Hook 回调签名:
  20. def hook(change_type: str, data: dict, extra: Any) -> None
  21. Args:
  22. change_type: 变化类型 (entity_update, entity_remove, env_update, sys_update)
  23. data: 变化的数据
  24. extra: 额外参数
  25. """
  26. class WorldModel:
  27. """
  28. 世界模型主管理器
  29. 线程安全,支持并发访问
  30. """
  31. def __init__(self):
  32. self._entities: dict[str, Entity] = {}
  33. self._environment = EnvironmentState()
  34. self._system_state = SystemState()
  35. self._hooks: list[ChangeHook] = []
  36. self._lock = threading.RLock()
  37. # ========== Entity 管理 ==========
  38. def update_entity(
  39. self,
  40. entity_id: str,
  41. entity_type: str,
  42. state: dict | None = None,
  43. metadata: dict | None = None,
  44. ttl: float | None = None,
  45. ) -> dict:
  46. """
  47. 更新或创建实体
  48. Args:
  49. entity_id: 实体 ID
  50. entity_type: 实体类型
  51. state: 状态字典
  52. metadata: 元数据
  53. ttl: 生存时间
  54. Returns:
  55. 变化字典
  56. """
  57. with self._lock:
  58. current_time = time.time()
  59. if entity_id in self._entities:
  60. entity = self._entities[entity_id]
  61. diff = entity.update_state(state or {}, current_time)
  62. if metadata:
  63. entity.update_metadata(metadata)
  64. if ttl is not None:
  65. entity.ttl = ttl
  66. if diff:
  67. self._trigger_hook('entity_update', {
  68. 'entity_id': entity_id,
  69. 'type': entity_type,
  70. 'diff': diff,
  71. 'state': entity.state,
  72. })
  73. return diff
  74. else:
  75. entity = Entity(
  76. entity_id=entity_id,
  77. entity_type=entity_type,
  78. state=state or {},
  79. metadata=metadata,
  80. ttl=ttl,
  81. last_update=current_time,
  82. )
  83. self._entities[entity_id] = entity
  84. self._trigger_hook('entity_create', {
  85. 'entity_id': entity_id,
  86. 'type': entity_type,
  87. 'state': state,
  88. })
  89. return {'_created': True}
  90. def remove_entity(self, entity_id: str) -> bool:
  91. """移除实体"""
  92. with self._lock:
  93. if entity_id in self._entities:
  94. del self._entities[entity_id]
  95. self._trigger_hook('entity_remove', {'entity_id': entity_id})
  96. return True
  97. return False
  98. def get_entity(self, entity_id: str) -> Entity | None:
  99. """获取实体"""
  100. with self._lock:
  101. return self._entities.get(entity_id)
  102. def get_entities_by_type(self, entity_type: str) -> list[Entity]:
  103. """获取指定类型的所有实体"""
  104. with self._lock:
  105. return [e for e in self._entities.values() if e.entity_type == entity_type]
  106. def get_all_entities(self) -> dict[str, Entity]:
  107. """获取所有实体 (返回副本)"""
  108. with self._lock:
  109. return dict(self._entities)
  110. def cleanup_expired(self) -> int:
  111. """清理过期实体,返回清理数量"""
  112. with self._lock:
  113. current_time = time.time()
  114. expired_ids = [
  115. eid for eid, entity in self._entities.items()
  116. if entity.is_expired(current_time)
  117. ]
  118. for eid in expired_ids:
  119. del self._entities[eid]
  120. self._trigger_hook('entity_expire', {'entity_id': eid})
  121. return len(expired_ids)
  122. # ========== Environment 管理 ==========
  123. def update_environment(self, key: str, value: Any) -> None:
  124. """更新单个环境数据"""
  125. with self._lock:
  126. old_value = self._environment.get(key)
  127. self._environment.set(key, value)
  128. self._trigger_hook('env_update', {
  129. 'key': key,
  130. 'old': old_value,
  131. 'new': value,
  132. })
  133. def update_environment_batch(self, data: dict) -> dict:
  134. """批量更新环境数据"""
  135. with self._lock:
  136. diff = self._environment.update(data)
  137. if diff:
  138. self._trigger_hook('env_batch_update', diff)
  139. return diff
  140. def get_environment(self) -> EnvironmentState:
  141. """获取环境状态 (返回副本)"""
  142. with self._lock:
  143. return self._environment
  144. # ========== SystemState 管理 ==========
  145. def update_system(self, key: str, value: Any) -> None:
  146. """更新单个系统数据"""
  147. with self._lock:
  148. self._system_state.update({key: value})
  149. self._trigger_hook('sys_update', {'key': key, 'value': value})
  150. def update_system_batch(self, data: dict) -> dict:
  151. """批量更新系统数据"""
  152. with self._lock:
  153. diff = self._system_state.update(data)
  154. if diff:
  155. self._trigger_hook('sys_batch_update', diff)
  156. return diff
  157. def get_system_state(self) -> SystemState:
  158. """获取系统状态 (返回副本)"""
  159. with self._lock:
  160. return self._system_state
  161. # ========== Hook 管理 ==========
  162. def add_hook(self, hook: ChangeHook) -> None:
  163. """添加变化回调"""
  164. with self._lock:
  165. self._hooks.append(hook)
  166. def remove_hook(self, hook: ChangeHook) -> None:
  167. """移除回调"""
  168. with self._lock:
  169. if hook in self._hooks:
  170. self._hooks.remove(hook)
  171. def _trigger_hook(self, change_type: str, data: dict) -> None:
  172. """触发所有 Hook"""
  173. for hook in self._hooks:
  174. try:
  175. hook(change_type, data, self)
  176. except Exception:
  177. pass
  178. # ========== Snapshot 导出 ==========
  179. def snapshot(self, mode: str = "minimal") -> dict:
  180. """
  181. 生成世界状态快照
  182. Args:
  183. mode: 模式 "minimal" (仅关键信息) | "full" (完整信息)
  184. Returns:
  185. 快照字典
  186. """
  187. with self._lock:
  188. current_time = time.time()
  189. if mode == "minimal":
  190. return self._snapshot_minimal(current_time)
  191. else:
  192. return self._snapshot_full(current_time)
  193. def _snapshot_minimal(self, current_time: float) -> dict:
  194. """Minimal 快照 - 仅关键信息"""
  195. entities = {}
  196. entity_types = {}
  197. for eid, entity in self._entities.items():
  198. if entity.is_expired(current_time):
  199. continue
  200. entities[eid] = {
  201. 'id': eid,
  202. 'type': entity.entity_type,
  203. 'state': entity.state,
  204. }
  205. etype = entity.entity_type
  206. entity_types[etype] = entity_types.get(etype, 0) + 1
  207. return {
  208. 'timestamp': current_time,
  209. 'mode': 'minimal',
  210. 'entities': entities,
  211. 'environment': {},
  212. 'system': self._system_state.to_dict(),
  213. 'stats': {
  214. 'total_entities': len(entities),
  215. 'entity_types': entity_types,
  216. },
  217. }
  218. def _snapshot_full(self, current_time: float) -> dict:
  219. """Full 快照 - 完整信息"""
  220. entities = {}
  221. entity_types = {}
  222. for eid, entity in self._entities.items():
  223. if entity.is_expired(current_time):
  224. continue
  225. entities[eid] = entity.to_dict()
  226. etype = entity.entity_type
  227. entity_types[etype] = entity_types.get(etype, 0) + 1
  228. return {
  229. 'timestamp': current_time,
  230. 'mode': 'full',
  231. 'entities': entities,
  232. 'environment': self._environment.to_dict(),
  233. 'system': self._system_state.to_dict(),
  234. 'stats': {
  235. 'total_entities': len(entities),
  236. 'entity_types': entity_types,
  237. 'total_entities_raw': len(self._entities),
  238. 'expired_count': sum(1 for e in self._entities.values() if e.is_expired(current_time)),
  239. },
  240. }
  241. def reset(self) -> None:
  242. """重置所有状态"""
  243. with self._lock:
  244. self._entities.clear()
  245. self._environment = EnvironmentState()
  246. self._system_state = SystemState()
  247. self._trigger_hook('reset', {})