Bläddra i källkod

在大模型逻辑增加了sessionid 防止串线

hwt 10 timmar sedan
förälder
incheckning
bcc043c920
1 ändrade filer med 78 tillägg och 17 borttagningar
  1. 78 17
      brain/PlannerNode2/largemodel/largemodel/model_service.py

+ 78 - 17
brain/PlannerNode2/largemodel/largemodel/model_service.py

@@ -3,7 +3,7 @@ import json
 import rclpy
 from rclpy.node import Node
 from interfaces.action import Rot
-from std_msgs.msg import String
+from std_msgs.msg import String, Bool
 from utils import large_model_interface
 from rclpy.action import ActionClient
 from ament_index_python.packages import get_package_share_directory
@@ -83,6 +83,7 @@ class LargeModelService(Node):
         )
 
         self.conversation_id = None  # 会话id
+        self.current_session_id = 0  # 当前会话ID,用于防串线
         self.map_mapping = ""  # 地图映射(从订阅获取)
         self.config_data = {}  # 配置数据(从订阅获取)
         self.map_points = []  # 地图导航点列表(从订阅获取)
@@ -136,6 +137,10 @@ class LargeModelService(Node):
         self.config_sub = self.create_subscription(
             String, "/ai/config", self.config_callback, 10
         )
+        # 订阅 wakeup 信号 / Subscribe wakeup signal
+        self.wakeup_sub = self.create_subscription(
+            Bool, "wakeup", self.wakeup_callback, 5
+        )
         self.environment_sub = None  # 环境数据订阅者(需要动态更新)
         self.environment_topic = "/ai/env"  # 默认值(与 environment_node publish_topic 一致)
 
@@ -144,22 +149,36 @@ class LargeModelService(Node):
 
     def seewhat_callback(self, msg):
         if msg.data == "seewhat":
+            session_id = self.current_session_id
             if (
                 self.regional_setting == "China"
             ):  # 在线模型推理方式:决策层推理+执行层监督 / Online model inference method: Decision layer reasoning + Execution layer supervision
-                self.dual_large_model_mode(type="image")
+                self.dual_large_model_mode(type="image", session_id=session_id)
             else:
-                self.dual_large_model_international_model(type="image")
+                self.dual_large_model_international_model(type="image", session_id=session_id)
+
+    def wakeup_callback(self, msg):
+        """
+        订阅 wakeup 信号回调
+        每次收到唤醒信号,session_id + 1,确保旧 LLM 结果被丢弃
+        """
+        if msg.data:
+            self.current_session_id += 1
+            self.get_logger().warn(
+                f"[Session] 用户重新唤醒,切换 session_id={self.current_session_id}"
+            )
 
     def asr_callback(self, msg):
+        session_id = self.current_session_id
         if (
             self.regional_setting == "China"
         ):  # 在线模型推理方式:决策层推理+执行层监督 / Online model inference method: Decision layer reasoning + Execution layer supervision
-            self.dual_large_model_mode(type="text", prompt=msg.data)
+            self.dual_large_model_mode(type="text", prompt=msg.data, session_id=session_id)
         else:
-            self.dual_large_model_international_model(type="text", prompt=msg.data)
+            self.dual_large_model_international_model(type="text", prompt=msg.data, session_id=session_id)
 
     def actionstatus_callback(self, msg):
+        session_id = self.current_session_id
         if (
             msg.data == "finish"
         ):  # 如果收到的是finish则表示当前指令执行完成,开启新的指令执行周期 / If "finish" is received, it means the current instruction has been executed and a new instruction cycle begins
@@ -171,14 +190,14 @@ class LargeModelService(Node):
             # ask_user 超时,触发空推理
             self.get_logger().warn("[多轮对话] ask_user 超时,触发空推理")
             if self.regional_setting == "China":
-                self.dual_large_model_mode(type="text", prompt="ask_user_timeout")
+                self.dual_large_model_mode(type="text", prompt="ask_user_timeout", session_id=session_id)
             else:
-                self.dual_large_model_international_model(type="text", prompt="ask_user_timeout")
+                self.dual_large_model_international_model(type="text", prompt="ask_user_timeout", session_id=session_id)
         else:  # 向指令执行层大模型反馈动作执行结果 / Feedback action execution results to the large model in the command execution layer
             if self.regional_setting == "China":
-                self.dual_large_model_mode(type="text", prompt=msg.data)
+                self.dual_large_model_mode(type="text", prompt=msg.data, session_id=session_id)
             else:
-                self.dual_large_model_international_model(type="text", prompt=msg.data)
+                self.dual_large_model_international_model(type="text", prompt=msg.data, session_id=session_id)
 
     def config_callback(self, msg):
         """
@@ -270,7 +289,7 @@ class LargeModelService(Node):
             self.get_logger().warn(f'解析环境数据失败: {e}')
 
     # @measure_execution_time
-    def dual_large_model_mode(self, type, prompt=""):
+    def dual_large_model_mode(self, type, prompt="", session_id=None):
         """
         此函数实现了双模型推理模式,即先由文本生成模型进行任务规划,然后由多模态大模型生成动作列表
         This function implements the dual model inference mode, where the text generation model first plans the task, and then the multimodal large model generates the action list.
@@ -294,6 +313,13 @@ class LargeModelService(Node):
                 prompt
             )  # 调用决策层大模型进行任务规划 / Call the decision layer large model for task planning
 
+            # 检查 session 是否过期
+            if session_id is not None and session_id != self.current_session_id:
+                self.get_logger().warn(
+                    f"[Session] 丢弃过期决策层结果: result_session={session_id}, current_session={self.current_session_id}"
+                )
+                return
+
             if not execute_instructions == "error":
 
                 prompt_desidon = (
@@ -316,6 +342,7 @@ class LargeModelService(Node):
                 self.instruction_process(
                     type="text",
                     prompt=prompt_desidon,
+                    session_id=session_id,
                 )  # 传递决策层模型规划好的执行步骤给执行层模型 / Pass the planned execution steps from the decision layer model to the execution layer model
 
                 self.new_order_cycle = (
@@ -327,10 +354,11 @@ class LargeModelService(Node):
                 )  # 模型推理失败,请检查模型配额和账户是否正常!!!
         else:
             self.instruction_process(
-                prompt, type
+                prompt, type,
+                session_id=session_id,
             )  # 调用执行层大模型生成成动作列表并执行 / Call the execution layer large model to generate an action list and execute
 
-    def instruction_process(self, prompt, type, conversation_id=None):
+    def instruction_process(self, prompt, type, conversation_id=None, session_id=None):
         """
         根据输入信息的类型(文字/图片),构建不同的请求体进行推理,并返回结果)
         Based on the type of input information (text/image), construct different request bodies for inference and return the result.
@@ -343,6 +371,14 @@ class LargeModelService(Node):
                 raw_content = self.model_client.multimodalinfer(
                     prompt_seewhat, image_path=self.image_save_path
                 )
+
+            # 检查 session 是否过期
+            if session_id is not None and session_id != self.current_session_id:
+                self.get_logger().warn(
+                    f"[Session] 丢弃过期执行层结果(国内): result_session={session_id}, current_session={self.current_session_id}"
+                )
+                return
+
             json_str = self.extract_json_content(raw_content)
 
         elif self.regional_setting == "international":  # 国际版
@@ -358,6 +394,7 @@ class LargeModelService(Node):
                     self.conversation_id = result[2]
                 else:
                     self.get_logger().info(f"ERROR:{result[1]}")
+                    json_str = None
             elif type == "image":
                 result = self.model_client.TaskExecution(
                     input=prompt_seewhat,
@@ -371,6 +408,14 @@ class LargeModelService(Node):
                     self.conversation_id = result[2]
                 else:
                     self.get_logger().info(f"ERROR:{result[1]}")
+                    json_str = None
+
+            # 检查 session 是否过期
+            if session_id is not None and session_id != self.current_session_id:
+                self.get_logger().warn(
+                    f"[Session] 丢弃过期执行层结果(国际): result_session={session_id}, current_session={self.current_session_id}"
+                )
+                return
 
         if json_str is not None:
             # 解析JSON字符串,分离"action"、"response"字段 / Parse JSON string, separate "action" and "response" fields
@@ -392,10 +437,11 @@ class LargeModelService(Node):
             )
 
         self.send_action_service(
-            action_list, llm_response
+            action_list, llm_response,
+            session_id=session_id,
         )  # 异步发送动作列表、回复内容给ActionServer / Asynchronously send action list and response content to ActionServer
 
-    def dual_large_model_international_model(self, type, prompt=""):
+    def dual_large_model_international_model(self, type, prompt="", session_id=None):
         """
         此函数适用于国际版双模型推理模式,使用dify作为中间件
         /this function is suitable for international model inference mode, using dify as the middleware
@@ -406,6 +452,13 @@ class LargeModelService(Node):
             self.conversation_id = None
             result = self.model_client.TaskDecision(prompt)
 
+            # 检查 session 是否过期
+            if session_id is not None and session_id != self.current_session_id:
+                self.get_logger().warn(
+                    f"[Session] 丢弃过期决策层结果(国际): result_session={session_id}, current_session={self.current_session_id}"
+                )
+                return
+
             if result[0]:
                 prompt_desidon = (
                     self.prompt_dict[self.language]
@@ -422,7 +475,7 @@ class LargeModelService(Node):
                     .get("prompt_1")
                     .format(prompt=prompt, execute_instructions=result[1])
                 )
-                self.instruction_process(type="text", prompt=prompt_desion)
+                self.instruction_process(type="text", prompt=prompt_desion, session_id=session_id)
 
                 self.new_order_cycle = (
                     False  # 重置指令周期标志位 / Reset the instruction cycle flag
@@ -434,10 +487,18 @@ class LargeModelService(Node):
 
         else:
             self.instruction_process(
-                prompt, type, conversation_id=self.conversation_id
+                prompt, type, conversation_id=self.conversation_id,
+                session_id=session_id,
             )  # 调用执行层大模型生成成动作列表并执行 / Call the execution layer large model to generate an action list and execute
 
-    def send_action_service(self, actions, text):
+    def send_action_service(self, actions, text, session_id=None):
+        # 检查 session 是否过期
+        if session_id is not None and session_id != self.current_session_id:
+            self.get_logger().warn(
+                f"[Session] 拒绝发送过期 action: result_session={session_id}, current_session={self.current_session_id}"
+            )
+            return
+
         goal_msg = Rot.Goal()  # 创建目标消息对象 / Create goal message object
         goal_msg.actions = actions  # 设置目标消息中的动作列表 / Set the action list in the goal message
         goal_msg.llm_response = text