Ver Fonte

在大模型逻辑增加了sessionid 防止串线

hwt há 10 horas atrás
pai
commit
bcc043c920
1 ficheiros alterados com 78 adições e 17 exclusões
  1. 78 17
      brain/PlannerNode2/largemodel/largemodel/model_service.py

+ 78 - 17
brain/PlannerNode2/largemodel/largemodel/model_service.py

@@ -3,7 +3,7 @@ import json
 import rclpy
 import rclpy
 from rclpy.node import Node
 from rclpy.node import Node
 from interfaces.action import Rot
 from interfaces.action import Rot
-from std_msgs.msg import String
+from std_msgs.msg import String, Bool
 from utils import large_model_interface
 from utils import large_model_interface
 from rclpy.action import ActionClient
 from rclpy.action import ActionClient
 from ament_index_python.packages import get_package_share_directory
 from ament_index_python.packages import get_package_share_directory
@@ -83,6 +83,7 @@ class LargeModelService(Node):
         )
         )
 
 
         self.conversation_id = None  # 会话id
         self.conversation_id = None  # 会话id
+        self.current_session_id = 0  # 当前会话ID,用于防串线
         self.map_mapping = ""  # 地图映射(从订阅获取)
         self.map_mapping = ""  # 地图映射(从订阅获取)
         self.config_data = {}  # 配置数据(从订阅获取)
         self.config_data = {}  # 配置数据(从订阅获取)
         self.map_points = []  # 地图导航点列表(从订阅获取)
         self.map_points = []  # 地图导航点列表(从订阅获取)
@@ -136,6 +137,10 @@ class LargeModelService(Node):
         self.config_sub = self.create_subscription(
         self.config_sub = self.create_subscription(
             String, "/ai/config", self.config_callback, 10
             String, "/ai/config", self.config_callback, 10
         )
         )
+        # 订阅 wakeup 信号 / Subscribe wakeup signal
+        self.wakeup_sub = self.create_subscription(
+            Bool, "wakeup", self.wakeup_callback, 5
+        )
         self.environment_sub = None  # 环境数据订阅者(需要动态更新)
         self.environment_sub = None  # 环境数据订阅者(需要动态更新)
         self.environment_topic = "/ai/env"  # 默认值(与 environment_node publish_topic 一致)
         self.environment_topic = "/ai/env"  # 默认值(与 environment_node publish_topic 一致)
 
 
@@ -144,22 +149,36 @@ class LargeModelService(Node):
 
 
     def seewhat_callback(self, msg):
     def seewhat_callback(self, msg):
         if msg.data == "seewhat":
         if msg.data == "seewhat":
+            session_id = self.current_session_id
             if (
             if (
                 self.regional_setting == "China"
                 self.regional_setting == "China"
             ):  # 在线模型推理方式:决策层推理+执行层监督 / Online model inference method: Decision layer reasoning + Execution layer supervision
             ):  # 在线模型推理方式:决策层推理+执行层监督 / Online model inference method: Decision layer reasoning + Execution layer supervision
-                self.dual_large_model_mode(type="image")
+                self.dual_large_model_mode(type="image", session_id=session_id)
             else:
             else:
-                self.dual_large_model_international_model(type="image")
+                self.dual_large_model_international_model(type="image", session_id=session_id)
+
+    def wakeup_callback(self, msg):
+        """
+        订阅 wakeup 信号回调
+        每次收到唤醒信号,session_id + 1,确保旧 LLM 结果被丢弃
+        """
+        if msg.data:
+            self.current_session_id += 1
+            self.get_logger().warn(
+                f"[Session] 用户重新唤醒,切换 session_id={self.current_session_id}"
+            )
 
 
     def asr_callback(self, msg):
     def asr_callback(self, msg):
+        session_id = self.current_session_id
         if (
         if (
             self.regional_setting == "China"
             self.regional_setting == "China"
         ):  # 在线模型推理方式:决策层推理+执行层监督 / Online model inference method: Decision layer reasoning + Execution layer supervision
         ):  # 在线模型推理方式:决策层推理+执行层监督 / Online model inference method: Decision layer reasoning + Execution layer supervision
-            self.dual_large_model_mode(type="text", prompt=msg.data)
+            self.dual_large_model_mode(type="text", prompt=msg.data, session_id=session_id)
         else:
         else:
-            self.dual_large_model_international_model(type="text", prompt=msg.data)
+            self.dual_large_model_international_model(type="text", prompt=msg.data, session_id=session_id)
 
 
     def actionstatus_callback(self, msg):
     def actionstatus_callback(self, msg):
+        session_id = self.current_session_id
         if (
         if (
             msg.data == "finish"
             msg.data == "finish"
         ):  # 如果收到的是finish则表示当前指令执行完成,开启新的指令执行周期 / If "finish" is received, it means the current instruction has been executed and a new instruction cycle begins
         ):  # 如果收到的是finish则表示当前指令执行完成,开启新的指令执行周期 / If "finish" is received, it means the current instruction has been executed and a new instruction cycle begins
@@ -171,14 +190,14 @@ class LargeModelService(Node):
             # ask_user 超时,触发空推理
             # ask_user 超时,触发空推理
             self.get_logger().warn("[多轮对话] ask_user 超时,触发空推理")
             self.get_logger().warn("[多轮对话] ask_user 超时,触发空推理")
             if self.regional_setting == "China":
             if self.regional_setting == "China":
-                self.dual_large_model_mode(type="text", prompt="ask_user_timeout")
+                self.dual_large_model_mode(type="text", prompt="ask_user_timeout", session_id=session_id)
             else:
             else:
-                self.dual_large_model_international_model(type="text", prompt="ask_user_timeout")
+                self.dual_large_model_international_model(type="text", prompt="ask_user_timeout", session_id=session_id)
         else:  # 向指令执行层大模型反馈动作执行结果 / Feedback action execution results to the large model in the command execution layer
         else:  # 向指令执行层大模型反馈动作执行结果 / Feedback action execution results to the large model in the command execution layer
             if self.regional_setting == "China":
             if self.regional_setting == "China":
-                self.dual_large_model_mode(type="text", prompt=msg.data)
+                self.dual_large_model_mode(type="text", prompt=msg.data, session_id=session_id)
             else:
             else:
-                self.dual_large_model_international_model(type="text", prompt=msg.data)
+                self.dual_large_model_international_model(type="text", prompt=msg.data, session_id=session_id)
 
 
     def config_callback(self, msg):
     def config_callback(self, msg):
         """
         """
@@ -270,7 +289,7 @@ class LargeModelService(Node):
             self.get_logger().warn(f'解析环境数据失败: {e}')
             self.get_logger().warn(f'解析环境数据失败: {e}')
 
 
     # @measure_execution_time
     # @measure_execution_time
-    def dual_large_model_mode(self, type, prompt=""):
+    def dual_large_model_mode(self, type, prompt="", session_id=None):
         """
         """
         此函数实现了双模型推理模式,即先由文本生成模型进行任务规划,然后由多模态大模型生成动作列表
         此函数实现了双模型推理模式,即先由文本生成模型进行任务规划,然后由多模态大模型生成动作列表
         This function implements the dual model inference mode, where the text generation model first plans the task, and then the multimodal large model generates the action list.
         This function implements the dual model inference mode, where the text generation model first plans the task, and then the multimodal large model generates the action list.
@@ -294,6 +313,13 @@ class LargeModelService(Node):
                 prompt
                 prompt
             )  # 调用决策层大模型进行任务规划 / Call the decision layer large model for task planning
             )  # 调用决策层大模型进行任务规划 / Call the decision layer large model for task planning
 
 
+            # 检查 session 是否过期
+            if session_id is not None and session_id != self.current_session_id:
+                self.get_logger().warn(
+                    f"[Session] 丢弃过期决策层结果: result_session={session_id}, current_session={self.current_session_id}"
+                )
+                return
+
             if not execute_instructions == "error":
             if not execute_instructions == "error":
 
 
                 prompt_desidon = (
                 prompt_desidon = (
@@ -316,6 +342,7 @@ class LargeModelService(Node):
                 self.instruction_process(
                 self.instruction_process(
                     type="text",
                     type="text",
                     prompt=prompt_desidon,
                     prompt=prompt_desidon,
+                    session_id=session_id,
                 )  # 传递决策层模型规划好的执行步骤给执行层模型 / Pass the planned execution steps from the decision layer model to the execution layer model
                 )  # 传递决策层模型规划好的执行步骤给执行层模型 / Pass the planned execution steps from the decision layer model to the execution layer model
 
 
                 self.new_order_cycle = (
                 self.new_order_cycle = (
@@ -327,10 +354,11 @@ class LargeModelService(Node):
                 )  # 模型推理失败,请检查模型配额和账户是否正常!!!
                 )  # 模型推理失败,请检查模型配额和账户是否正常!!!
         else:
         else:
             self.instruction_process(
             self.instruction_process(
-                prompt, type
+                prompt, type,
+                session_id=session_id,
             )  # 调用执行层大模型生成成动作列表并执行 / Call the execution layer large model to generate an action list and execute
             )  # 调用执行层大模型生成成动作列表并执行 / Call the execution layer large model to generate an action list and execute
 
 
-    def instruction_process(self, prompt, type, conversation_id=None):
+    def instruction_process(self, prompt, type, conversation_id=None, session_id=None):
         """
         """
         根据输入信息的类型(文字/图片),构建不同的请求体进行推理,并返回结果)
         根据输入信息的类型(文字/图片),构建不同的请求体进行推理,并返回结果)
         Based on the type of input information (text/image), construct different request bodies for inference and return the result.
         Based on the type of input information (text/image), construct different request bodies for inference and return the result.
@@ -343,6 +371,14 @@ class LargeModelService(Node):
                 raw_content = self.model_client.multimodalinfer(
                 raw_content = self.model_client.multimodalinfer(
                     prompt_seewhat, image_path=self.image_save_path
                     prompt_seewhat, image_path=self.image_save_path
                 )
                 )
+
+            # 检查 session 是否过期
+            if session_id is not None and session_id != self.current_session_id:
+                self.get_logger().warn(
+                    f"[Session] 丢弃过期执行层结果(国内): result_session={session_id}, current_session={self.current_session_id}"
+                )
+                return
+
             json_str = self.extract_json_content(raw_content)
             json_str = self.extract_json_content(raw_content)
 
 
         elif self.regional_setting == "international":  # 国际版
         elif self.regional_setting == "international":  # 国际版
@@ -358,6 +394,7 @@ class LargeModelService(Node):
                     self.conversation_id = result[2]
                     self.conversation_id = result[2]
                 else:
                 else:
                     self.get_logger().info(f"ERROR:{result[1]}")
                     self.get_logger().info(f"ERROR:{result[1]}")
+                    json_str = None
             elif type == "image":
             elif type == "image":
                 result = self.model_client.TaskExecution(
                 result = self.model_client.TaskExecution(
                     input=prompt_seewhat,
                     input=prompt_seewhat,
@@ -371,6 +408,14 @@ class LargeModelService(Node):
                     self.conversation_id = result[2]
                     self.conversation_id = result[2]
                 else:
                 else:
                     self.get_logger().info(f"ERROR:{result[1]}")
                     self.get_logger().info(f"ERROR:{result[1]}")
+                    json_str = None
+
+            # 检查 session 是否过期
+            if session_id is not None and session_id != self.current_session_id:
+                self.get_logger().warn(
+                    f"[Session] 丢弃过期执行层结果(国际): result_session={session_id}, current_session={self.current_session_id}"
+                )
+                return
 
 
         if json_str is not None:
         if json_str is not None:
             # 解析JSON字符串,分离"action"、"response"字段 / Parse JSON string, separate "action" and "response" fields
             # 解析JSON字符串,分离"action"、"response"字段 / Parse JSON string, separate "action" and "response" fields
@@ -392,10 +437,11 @@ class LargeModelService(Node):
             )
             )
 
 
         self.send_action_service(
         self.send_action_service(
-            action_list, llm_response
+            action_list, llm_response,
+            session_id=session_id,
         )  # 异步发送动作列表、回复内容给ActionServer / Asynchronously send action list and response content to ActionServer
         )  # 异步发送动作列表、回复内容给ActionServer / Asynchronously send action list and response content to ActionServer
 
 
-    def dual_large_model_international_model(self, type, prompt=""):
+    def dual_large_model_international_model(self, type, prompt="", session_id=None):
         """
         """
         此函数适用于国际版双模型推理模式,使用dify作为中间件
         此函数适用于国际版双模型推理模式,使用dify作为中间件
         /this function is suitable for international model inference mode, using dify as the middleware
         /this function is suitable for international model inference mode, using dify as the middleware
@@ -406,6 +452,13 @@ class LargeModelService(Node):
             self.conversation_id = None
             self.conversation_id = None
             result = self.model_client.TaskDecision(prompt)
             result = self.model_client.TaskDecision(prompt)
 
 
+            # 检查 session 是否过期
+            if session_id is not None and session_id != self.current_session_id:
+                self.get_logger().warn(
+                    f"[Session] 丢弃过期决策层结果(国际): result_session={session_id}, current_session={self.current_session_id}"
+                )
+                return
+
             if result[0]:
             if result[0]:
                 prompt_desidon = (
                 prompt_desidon = (
                     self.prompt_dict[self.language]
                     self.prompt_dict[self.language]
@@ -422,7 +475,7 @@ class LargeModelService(Node):
                     .get("prompt_1")
                     .get("prompt_1")
                     .format(prompt=prompt, execute_instructions=result[1])
                     .format(prompt=prompt, execute_instructions=result[1])
                 )
                 )
-                self.instruction_process(type="text", prompt=prompt_desion)
+                self.instruction_process(type="text", prompt=prompt_desion, session_id=session_id)
 
 
                 self.new_order_cycle = (
                 self.new_order_cycle = (
                     False  # 重置指令周期标志位 / Reset the instruction cycle flag
                     False  # 重置指令周期标志位 / Reset the instruction cycle flag
@@ -434,10 +487,18 @@ class LargeModelService(Node):
 
 
         else:
         else:
             self.instruction_process(
             self.instruction_process(
-                prompt, type, conversation_id=self.conversation_id
+                prompt, type, conversation_id=self.conversation_id,
+                session_id=session_id,
             )  # 调用执行层大模型生成成动作列表并执行 / Call the execution layer large model to generate an action list and execute
             )  # 调用执行层大模型生成成动作列表并执行 / Call the execution layer large model to generate an action list and execute
 
 
-    def send_action_service(self, actions, text):
+    def send_action_service(self, actions, text, session_id=None):
+        # 检查 session 是否过期
+        if session_id is not None and session_id != self.current_session_id:
+            self.get_logger().warn(
+                f"[Session] 拒绝发送过期 action: result_session={session_id}, current_session={self.current_session_id}"
+            )
+            return
+
         goal_msg = Rot.Goal()  # 创建目标消息对象 / Create goal message object
         goal_msg = Rot.Goal()  # 创建目标消息对象 / Create goal message object
         goal_msg.actions = actions  # 设置目标消息中的动作列表 / Set the action list in the goal message
         goal_msg.actions = actions  # 设置目标消息中的动作列表 / Set the action list in the goal message
         goal_msg.llm_response = text
         goal_msg.llm_response = text