4
0

model_service.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. import os
  2. import json
  3. import rclpy
  4. from rclpy.node import Node
  5. from interfaces.action import Rot
  6. from std_msgs.msg import String
  7. from utils import large_model_interface
  8. from rclpy.action import ActionClient
  9. from ament_index_python.packages import get_package_share_directory
  10. from utils.promot import get_prompt, get_map_mapping, set_map_mapping, set_large_model_config, set_model_paths, set_system_config, set_environment_data
  11. import time
  12. import re
  13. import functools
  14. def measure_execution_time(func):
  15. """
  16. 装饰器:测量函数执行时间并使用 ROS 日志打印结果
  17. """
  18. @functools.wraps(func)
  19. def wrapper(self, *args, **kwargs):
  20. start_time = time.time()
  21. result = func(self, *args, **kwargs)
  22. end_time = time.time()
  23. execution_time = end_time - start_time
  24. # 使用 ROS 日志系统记录执行时间
  25. if hasattr(self, 'get_logger'):
  26. self.get_logger().info(f"[性能统计] {func.__name__} 函数执行时间: {execution_time:.4f} 秒")
  27. else:
  28. print(f"[性能统计] {func.__name__} 函数执行时间: {execution_time:.4f} 秒")
  29. return result
  30. return wrapper
  31. class LargeModelService(Node):
  32. def __init__(self):
  33. super().__init__("LargeModelService")
  34. self.init_param_config() # 初始化参数配置 / Initialize parameter configuration
  35. self.init_largemodel() # 初始化大模型 / Initialize large model
  36. self.init_ros_comunication() # 初始化ROS通信 / Initialize ROS communication
  37. self.init_language() # 初始化语言/Initialize language
  38. self.get_logger().info(
  39. "LargeModelService node Initialization completed..."
  40. ) # 打印日志 / Print log
  41. def init_largemodel(self):
  42. # 创建模型接口客户端 / Create model interface client
  43. # 传入 logger 用于调试日志
  44. self.model_client = large_model_interface.model_interface(logger=self.get_logger())
  45. self.new_order_cycle = True # 新指令周期标志 / New order cycle flag
  46. if self.regional_setting == "China": # 如果是中国地区
  47. self.model_client.init_Multimodal() # 初始化执行层模型,决策层模型无需初始化 / Initialize execution layer model, decision layer model does not need initialization
  48. elif self.regional_setting == "international": # 如果是国际地区
  49. self.model_client.init_dify_client()
  50. else:
  51. while True:
  52. self.get_logger().info()(
  53. 'Please check the regional_setting parameter in yahboom.yaml file, it should be either "China" or "international".'
  54. )
  55. time.sleep(1)
  56. def init_param_config(self):
  57. self.pkg_path = get_package_share_directory("largemodel")
  58. self.image_save_path = os.path.join(
  59. self.pkg_path, "resources_file", "image.png"
  60. )
  61. # 参数声明 / Parameter declaration
  62. self.declare_parameter("language", "zh")
  63. self.declare_parameter("regional_setting", "China")
  64. self.declare_parameter("text_chat_mode", False)
  65. # 获取参数服务器参数 / Get parameters from the parameter server
  66. self.language = (
  67. self.get_parameter("language").get_parameter_value().string_value
  68. )
  69. self.regional_setting = (
  70. self.get_parameter("regional_setting").get_parameter_value().string_value
  71. )
  72. self.text_chat_mode = (
  73. self.get_parameter("text_chat_mode").get_parameter_value().bool_value
  74. )
  75. self.conversation_id = None # 会话id
  76. self.map_mapping = "" # 地图映射(从订阅获取)
  77. self.config_data = {} # 配置数据(从订阅获取)
  78. self.map_points = [] # 地图导航点列表(从订阅获取)
  79. self.environment_data = {} # 环境数据(从订阅获取)
  80. def init_language(self):
  81. self.language_dict = {
  82. "zh": "中文",
  83. "en": "English",
  84. }
  85. language_list = ["zh", "en"]
  86. if self.language not in language_list:
  87. while True:
  88. self.get_logger().info(
  89. "The language setting is incorrect. Please check the action_service'' language setting in the yahboom.yaml file"
  90. )
  91. self.get_logger().info(self.language)
  92. time.sleep(1)
  93. self.prompt_dict = { #
  94. "zh": { # 中文 / Chinese
  95. "prompt_1": "用户:{prompt},决策层AI规划:{execute_instructions}",
  96. "prompt_2": "机器人反馈:执行seewhat()完成",
  97. "prompt_3": "决策层AI规划:{execute_instructions}",
  98. },
  99. "en": { # 英文 / English
  100. "prompt_1": "user:{prompt},Decision making AI planning:{execute_instructions}",
  101. "prompt_2": "Robot feedback: Execute seewhat() completed",
  102. "prompt_3": "Decision making AI planning:{execute_instructions}",
  103. },
  104. }
  105. def init_ros_comunication(self):
  106. # 创建执行动作状态订阅者 / Create action status subscriber
  107. self.actionstatus_sub = self.create_subscription(
  108. String, "actionstatus", self.actionstatus_callback, 1
  109. )
  110. # 创建动作客户端,连接到 'action_service' / Create action client, connect to 'action_service'
  111. self._action_client = ActionClient(self, Rot, "action_service")
  112. # asr话题订阅者 / ASR topic subscriber
  113. self.asrsub = self.create_subscription(String, "asr", self.asr_callback, 1)
  114. # 创建seehat订阅者 / Create seewhat subscriber
  115. self.seewhat_sub = self.create_subscription(
  116. String, "seewhat_handle", self.seewhat_callback, 1
  117. )
  118. # 创建执行动作状态发布者 / Create action status publisher
  119. self.actionstatus_pub = self.create_publisher(String, "actionstatus", 1)
  120. # 创建文字交互发布者 / Create text interaction publisher
  121. self.text_pub = self.create_publisher(String, "text_response", 1)
  122. # 订阅配置节点数据 / Subscribe config node data
  123. self.config_sub = self.create_subscription(
  124. String, "/ai/config", self.config_callback, 10
  125. )
  126. self.environment_sub = None # 环境数据订阅者(需要动态更新)
  127. self.environment_topic = "/ai/env" # 默认值(与 environment_node publish_topic 一致)
  128. # 初始化时就创建环境数据订阅(使用默认 topic)
  129. self.update_environment_subscription()
  130. def seewhat_callback(self, msg):
  131. if msg.data == "seewhat":
  132. if (
  133. self.regional_setting == "China"
  134. ): # 在线模型推理方式:决策层推理+执行层监督 / Online model inference method: Decision layer reasoning + Execution layer supervision
  135. self.dual_large_model_mode(type="image")
  136. else:
  137. self.dual_large_model_international_model(type="image")
  138. def asr_callback(self, msg):
  139. if (
  140. self.regional_setting == "China"
  141. ): # 在线模型推理方式:决策层推理+执行层监督 / Online model inference method: Decision layer reasoning + Execution layer supervision
  142. self.dual_large_model_mode(type="text", prompt=msg.data)
  143. else:
  144. self.dual_large_model_international_model(type="text", prompt=msg.data)
  145. def actionstatus_callback(self, msg):
  146. if (
  147. msg.data == "finish"
  148. ): # 如果收到的是finish则表示当前指令执行完成,开启新的指令执行周期 / If "finish" is received, it means the current instruction has been executed and a new instruction cycle begins
  149. self.new_order_cycle = True
  150. self.get_logger().info(
  151. "The current instruction cycle has ended"
  152. ) # 当前指令周期已结束...
  153. elif msg.data == "ask_user_timeout":
  154. # ask_user 超时,触发空推理
  155. self.get_logger().warn("[多轮对话] ask_user 超时,触发空推理")
  156. if self.regional_setting == "China":
  157. self.dual_large_model_mode(type="text", prompt="ask_user_timeout")
  158. else:
  159. self.dual_large_model_international_model(type="text", prompt="ask_user_timeout")
  160. else: # 向指令执行层大模型反馈动作执行结果 / Feedback action execution results to the large model in the command execution layer
  161. if self.regional_setting == "China":
  162. self.dual_large_model_mode(type="text", prompt=msg.data)
  163. else:
  164. self.dual_large_model_international_model(type="text", prompt=msg.data)
  165. def config_callback(self, msg):
  166. """
  167. 订阅配置数据回调函数
  168. 从 /ai/config topic 接收配置数据
  169. """
  170. try:
  171. config_json = json.loads(msg.data)
  172. self.config_data = config_json.get('config', {})
  173. # 更新大模型配置
  174. large_model_config = self.config_data.get('large_model', {})
  175. if large_model_config:
  176. set_large_model_config(large_model_config)
  177. # 如果模型接口支持动态更新,则调用更新接口
  178. if hasattr(self.model_client, 'update_config'):
  179. self.model_client.update_config(large_model_config)
  180. # 更新模型路径
  181. model_paths = self.config_data.get('model_paths', {})
  182. if model_paths:
  183. set_model_paths(model_paths)
  184. # 更新系统配置
  185. system_config = self.config_data.get('system', {})
  186. if system_config:
  187. set_system_config(system_config)
  188. # 更新 topics 配置
  189. topics_config = self.config_data.get('topics', {})
  190. if topics_config:
  191. environment_node_config = topics_config.get('environment_node', {})
  192. if environment_node_config:
  193. new_topic = environment_node_config.get('environment_topic', '/ai/env')
  194. if new_topic != self.environment_topic or self.environment_sub is None:
  195. self.environment_topic = new_topic
  196. self.update_environment_subscription()
  197. self.get_logger().info(f'[配置] 环境数据订阅 Topic 已更新: {self.environment_topic}')
  198. except Exception as e:
  199. self.get_logger().warn(f'解析配置数据失败: {e}')
  200. def update_environment_subscription(self):
  201. """动态更新环境数据订阅"""
  202. try:
  203. if self.environment_sub:
  204. self.destroy_subscription(self.environment_sub)
  205. self.environment_sub = self.create_subscription(
  206. String, self.environment_topic, self.environment_callback, 10
  207. )
  208. self.get_logger().info(f'[配置] 已订阅环境数据 Topic: {self.environment_topic}')
  209. except Exception as e:
  210. self.get_logger().warn(f'更新环境订阅失败: {e}')
  211. def environment_callback(self, msg):
  212. """
  213. 订阅环境数据回调函数
  214. 从 /ai/environment topic 接收环境数据
  215. """
  216. try:
  217. env_json = json.loads(msg.data)
  218. self.environment_data = env_json
  219. # 更新环境数据缓存(供提示词使用)
  220. set_environment_data(env_json)
  221. # 更新地图映射数据
  222. map_data = env_json.get('map', {})
  223. if map_data:
  224. points = map_data.get('points', [])
  225. if points:
  226. # 将 points 字典转换为地图映射格式
  227. # 使用更清晰的格式,让大模型知道用 id 调用
  228. map_str = "#地图映射\n\n"
  229. for point in points:
  230. point_id = point.get('id', '')
  231. name = point.get('name', '')
  232. if point_id and name:
  233. map_str += f"{point_id} -> {name}\n"
  234. self.map_mapping = map_str
  235. set_map_mapping(map_str)
  236. self.map_points = points
  237. else:
  238. self.get_logger().warn("[环境回调] points 为空")
  239. else:
  240. self.get_logger().warn("[环境回调] map_data 为空")
  241. except Exception as e:
  242. self.get_logger().warn(f'解析环境数据失败: {e}')
  243. # @measure_execution_time
  244. def dual_large_model_mode(self, type, prompt=""):
  245. """
  246. 此函数实现了双模型推理模式,即先由文本生成模型进行任务规划,然后由多模态大模型生成动作列表
  247. This function implements the dual model inference mode, where the text generation model first plans the task, and then the multimodal large model generates the action list.
  248. """
  249. if (
  250. self.new_order_cycle
  251. ): # 判断是否是新任务周期 / Determine if it is a new task cycle
  252. # 获取完整的 prompt 并打印
  253. full_prompt = get_prompt()
  254. self.get_logger().info("=" * 80)
  255. self.get_logger().info("[调试] 发送给决策层大模型的完整 Prompt:")
  256. self.get_logger().info("=" * 80)
  257. self.get_logger().info(full_prompt)
  258. self.get_logger().info("=" * 80)
  259. self.get_logger().info(f"[调试] 用户输入: {prompt}")
  260. self.get_logger().info("=" * 80)
  261. # 判断上一轮对话指令是否完成如果完成就清空历史上下文,开启新的上下文 / Determine if the previous round of dialogue instructions are completed. If completed, clear the historical context and start a new context
  262. self.model_client.init_Multimodal_history(full_prompt) # 初始化执行层上下文历史 / Initialize execution layer context history
  263. execute_instructions = self.model_client.TaskDecision(
  264. prompt
  265. ) # 调用决策层大模型进行任务规划 / Call the decision layer large model for task planning
  266. if not execute_instructions == "error":
  267. prompt_desidon = (
  268. self.prompt_dict[self.language]
  269. .get("prompt_3")
  270. .format(execute_instructions=execute_instructions[1])
  271. ) # 翻译成对应语言的prompt /translate into the corresponding language prompt
  272. if self.text_chat_mode:
  273. msg = String(data=prompt_desidon)
  274. self.text_pub.publish(msg)
  275. else:
  276. self.get_logger().info(prompt_desidon) # 即将执行的任务:...
  277. prompt_desidon = (
  278. self.prompt_dict[self.language]
  279. .get("prompt_1")
  280. .format(prompt=prompt, execute_instructions=execute_instructions[1])
  281. ) # 翻译成对应语言的prompt /translate into the corresponding language prompt
  282. self.instruction_process(
  283. type="text",
  284. prompt=prompt_desidon,
  285. ) # 传递决策层模型规划好的执行步骤给执行层模型 / Pass the planned execution steps from the decision layer model to the execution layer model
  286. self.new_order_cycle = (
  287. False # 重置指令周期标志位 / Reset the instruction cycle flag
  288. )
  289. else:
  290. self.get_logger().info(
  291. "The model service is abnormal. Check the large model account or configuration options"
  292. ) # 模型推理失败,请检查模型配额和账户是否正常!!!
  293. else:
  294. self.instruction_process(
  295. prompt, type
  296. ) # 调用执行层大模型生成成动作列表并执行 / Call the execution layer large model to generate an action list and execute
  297. def instruction_process(self, prompt, type, conversation_id=None):
  298. """
  299. 根据输入信息的类型(文字/图片),构建不同的请求体进行推理,并返回结果)
  300. Based on the type of input information (text/image), construct different request bodies for inference and return the result.
  301. """
  302. prompt_seewhat = self.prompt_dict[self.language].get("prompt_2")
  303. if self.regional_setting == "China": # 国内版
  304. if type == "text":
  305. raw_content = self.model_client.multimodalinfer(prompt)
  306. elif type == "image":
  307. raw_content = self.model_client.multimodalinfer(
  308. prompt_seewhat, image_path=self.image_save_path
  309. )
  310. json_str = self.extract_json_content(raw_content)
  311. elif self.regional_setting == "international": # 国际版
  312. if type == "text":
  313. result = self.model_client.TaskExecution(
  314. input=prompt,
  315. map_mapping=self.map_mapping,
  316. language=self.language_dict[self.language],
  317. conversation_id=conversation_id,
  318. )
  319. if result[0]:
  320. json_str = self.extract_json_content(result[1])
  321. self.conversation_id = result[2]
  322. else:
  323. self.get_logger().info(f"ERROR:{result[1]}")
  324. elif type == "image":
  325. result = self.model_client.TaskExecution(
  326. input=prompt_seewhat,
  327. map_mapping=self.map_mapping,
  328. language=self.language_dict[self.language],
  329. image_path=self.image_save_path,
  330. conversation_id=conversation_id,
  331. )
  332. if result[0]:
  333. json_str = self.extract_json_content(result[1])
  334. self.conversation_id = result[2]
  335. else:
  336. self.get_logger().info(f"ERROR:{result[1]}")
  337. if json_str is not None:
  338. # 解析JSON字符串,分离"action"、"response"字段 / Parse JSON string, separate "action" and "response" fields
  339. action_plan_json = json.loads(json_str)
  340. action_list = action_plan_json.get("action", [])
  341. llm_response = action_plan_json.get("response", "")
  342. else:
  343. self.get_logger().info(
  344. f"LargeScaleModel return: {json_str},The format was unexpected. The output format of the AI model at the execution layer did not meet the requirements"
  345. )
  346. return
  347. if self.text_chat_mode:
  348. msg = String(data=f'"action": {action_list}, "response": {llm_response}')
  349. self.text_pub.publish(msg)
  350. else:
  351. self.get_logger().info(
  352. f'"action": {action_list}, "response": {llm_response}'
  353. )
  354. self.send_action_service(
  355. action_list, llm_response
  356. ) # 异步发送动作列表、回复内容给ActionServer / Asynchronously send action list and response content to ActionServer
  357. def dual_large_model_international_model(self, type, prompt=""):
  358. """
  359. 此函数适用于国际版双模型推理模式,使用dify作为中间件
  360. /this function is suitable for international model inference mode, using dify as the middleware
  361. """
  362. if (
  363. self.new_order_cycle
  364. ): # 判断是否是新任务周期 / Determine if it is a new task cycle
  365. self.conversation_id = None
  366. result = self.model_client.TaskDecision(prompt)
  367. if result[0]:
  368. prompt_desidon = (
  369. self.prompt_dict[self.language]
  370. .get("prompt_3")
  371. .format(execute_instructions=result[1])
  372. ) # 翻译成对应语言的prompt /translate into the corresponding language prompt
  373. if self.text_chat_mode: # 文字交互模式 / Text interaction mode
  374. msg = String(data=prompt_desidon)
  375. self.text_pub.publish(msg)
  376. else: # 语音交互模式 / Voice interaction mode
  377. self.get_logger().info(prompt_desidon)
  378. prompt_desion = (
  379. self.prompt_dict[self.language]
  380. .get("prompt_1")
  381. .format(prompt=prompt, execute_instructions=result[1])
  382. )
  383. self.instruction_process(type="text", prompt=prompt_desion)
  384. self.new_order_cycle = (
  385. False # 重置指令周期标志位 / Reset the instruction cycle flag
  386. )
  387. else:
  388. self.get_logger().info(
  389. "The model service is abnormal. Check the large model account or configuration options"
  390. ) # 模型推理失败,请检查模型配额和账户是否正常!!!
  391. else:
  392. self.instruction_process(
  393. prompt, type, conversation_id=self.conversation_id
  394. ) # 调用执行层大模型生成成动作列表并执行 / Call the execution layer large model to generate an action list and execute
  395. def send_action_service(self, actions, text):
  396. goal_msg = Rot.Goal() # 创建目标消息对象 / Create goal message object
  397. goal_msg.actions = actions # 设置目标消息中的动作列表 / Set the action list in the goal message
  398. goal_msg.llm_response = text
  399. self._send_goal_future = self._action_client.send_goal_async(goal_msg)
  400. # 添加目标发送后的响应回调函数 / Add response callback function after sending the goal
  401. self._send_goal_future.add_done_callback(self.goal_response_callback)
  402. def goal_response_callback(self, future):
  403. goal_handle = future.result() # 获取目标句柄 / Get goal handle
  404. if not goal_handle.accepted:
  405. self.get_logger().info(
  406. "action_client message: action service rejected action list"
  407. ) # 目标被拒绝...
  408. @staticmethod
  409. def extract_json_content(
  410. raw_content,
  411. ): # 解析变量提取json / Extract JSON by parsing variables
  412. try:
  413. # 方法一:分割代码块 / Method 1: Split code blocks
  414. if "```json" in raw_content:
  415. # 分割代码块并取中间部分 / Split code blocks and take the middle part
  416. json_str = raw_content.split("```json")[1].split("```")[0].strip()
  417. elif "```" in raw_content:
  418. # 处理没有指定类型的代码块 / Handle code blocks without specified types
  419. json_str = raw_content.split("```")[1].strip()
  420. else:
  421. # 直接尝试解析 / Try parsing directly
  422. json_str = raw_content
  423. # 方法二:正则表达式提取(备用方案) / Method 2: Regular expression extraction (backup plan)
  424. if not json_str:
  425. match = re.search(r"\{.*\}", raw_content, re.DOTALL)
  426. if match:
  427. json_str = match.group()
  428. return json_str
  429. except Exception as e:
  430. return None
  431. def main(args=None):
  432. rclpy.init(args=args)
  433. model_service = LargeModelService()
  434. rclpy.spin(model_service)
  435. rclpy.shutdown()
  436. if __name__ == "__main__":
  437. main()