| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368 |
import functools
import os
import queue
import threading
import time
import wave
from typing import Optional

import pyaudio
import rclpy
import webrtcvad
from ament_index_python.packages import get_package_share_directory
from playsound import playsound
from rclpy.node import Node
from std_msgs.msg import String, UInt16, Bool

from utils import large_model_interface
from utils.large_model_interface import rec_wav_music_en
from utils.mic_serial import kws_mic
def measure_execution_time(func):
    """Decorator that reports how long a method call took.

    The wrapped callable must take the instance as its first positional
    argument. When that instance exposes ``get_logger`` (as ROS 2 nodes do),
    the duration is written through the logger it returns; otherwise it is
    printed to stdout.

    Args:
        func: The method to time.

    Returns:
        A wrapper with the same signature and return value as ``func``.
        Nothing is reported when ``func`` raises; the exception propagates.
    """
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # perf_counter() is monotonic, so the measured interval cannot be
        # skewed by wall-clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        result = func(self, *args, **kwargs)
        execution_time = time.perf_counter() - start_time

        message = f"[性能统计] {func.__name__} 函数执行时间: {execution_time:.4f} 秒"
        if hasattr(self, 'get_logger'):
            self.get_logger().info(message)
        else:
            print(message)
        return result
    return wrapper
class ASRNode(Node):
    """ROS 2 node that turns a keyword wake-up into published ASR text.

    For every wake-up request found in ``audio_request_queue`` the node
    publishes a wake-up flag, plays an acknowledgement sound, records the
    microphone until WebRTC VAD reports sustained silence, runs speech
    recognition on the recording and publishes the text on the ``asr`` topic.
    """

    # Inputs accepted by WebRTC VAD (see the py-webrtcvad README).
    _VAD_SAMPLE_RATES = (8000, 16000, 32000, 48000)
    _VAD_FRAME_DURATIONS_MS = (10, 20, 30)

    def __init__(self):
        """Build the node; the order matters because later steps read
        attributes created by earlier ones."""
        super().__init__("asr_node")
        # Parameters and shared state (everything below depends on these).
        self.init_param_config()
        # Keyword-spotting serial link.
        self.kws_init()
        # Speech recognition backend.
        self.asr_mdoel_init()
        # Language-dependent prompt selection.
        self.language_init()
        # Paths of the system prompt sounds.
        self.system_sound_init()
        # Publishers.
        self.init_ros_comunication()
        self.get_logger().info("asr_node Initialization completed")

    def init_ros_comunication(self):
        """Create every publisher used by the node."""
        # Buzzer control (1 = on, 0 = off as used by listen_for_speech).
        self.pub_beep = self.create_publisher(UInt16, "beep", 10)
        # Recognised text.
        self.asr_pub = self.create_publisher(String, "asr", 5)
        # Wake-up notification.
        self.wakeup_pub = self.create_publisher(Bool, "wakeup", 5)
        # True while the microphone is being recorded.
        self.record_status_pub = self.create_publisher(Bool, "record_status", 5)

    def init_param_config(self):
        """Declare and read the node parameters and create shared state."""
        # Recording written by listen_for_speech and read by ASR_conversion.
        self.user_speechdir = os.path.join(
            get_package_share_directory("largemodel"),
            "resources_file",
            "user_speech.wav",
        )

        self.declare_parameter("VAD_MODE", 1)
        self.declare_parameter("sample_rate", 16000)
        self.declare_parameter("frame_duration_ms", 30)
        self.declare_parameter("language", "en")
        self.declare_parameter("use_oline_asr", False)
        self.declare_parameter("mic_serial_port", "/dev/mic")
        self.declare_parameter("mic_index", 0)
        self.declare_parameter("regional_setting", "China")

        def value_of(name):
            return self.get_parameter(name).get_parameter_value()

        self.VAD_MODE = value_of("VAD_MODE").integer_value
        self.sample_rate = value_of("sample_rate").integer_value
        self.frame_duration_ms = value_of("frame_duration_ms").integer_value
        self.language = value_of("language").string_value
        self.use_oline_asr = value_of("use_oline_asr").bool_value
        self.mic_serial_port = value_of("mic_serial_port").string_value
        self.mic_index = value_of("mic_index").integer_value
        self.regional_setting = value_of("regional_setting").string_value

        # WebRTC VAD rejects other values at recording time, inside the worker
        # thread; warn here so the misconfiguration is visible at start-up.
        if self.sample_rate not in self._VAD_SAMPLE_RATES:
            self.get_logger().warn(
                f"sample_rate {self.sample_rate} is not supported by WebRTC VAD "
                f"(expected one of {self._VAD_SAMPLE_RATES})"
            )
        if self.frame_duration_ms not in self._VAD_FRAME_DURATIONS_MS:
            self.get_logger().warn(
                f"frame_duration_ms {self.frame_duration_ms} is not supported by "
                f"WebRTC VAD (expected one of {self._VAD_FRAME_DURATIONS_MS})"
            )

        # Despite its name this is the number of samples per VAD frame; it is
        # passed to PyAudio as a frame count, not a byte count.
        self.frame_bytes = int(self.sample_rate * self.frame_duration_ms / 1000)

        # Large-model interface; the logger is handed over for debug output.
        self.modelinterface = large_model_interface.model_interface(
            logger=self.get_logger()
        )

        self.vad = webrtcvad.Vad()
        self.vad.set_mode(self.VAD_MODE)

        # Worker that handles the current wake-up, and the flag used to
        # interrupt it when a newer wake-up arrives.
        self.current_thread: Optional[threading.Thread] = None
        self.stop_event = threading.Event()

    def _drain_wakeup_requests(self) -> bool:
        """Empty the wake-up queue without blocking.

        Returns:
            True when at least one request was pending. Several pending
            requests collapse into one so a repeated trigger is handled once.
        """
        pending = False
        while True:
            try:
                self.audio_request_queue.get_nowait()
            except queue.Empty:
                return pending
            pending = True

    def main_loop(self):
        """Run until rclpy shuts down, starting one worker per wake-up."""
        while rclpy.ok():
            if self._drain_wakeup_requests():
                self.wakeup_pub.publish(Bool(data=True))
                self.get_logger().info("I'm here")
                # Acknowledge the user (blocks until the sound has played).
                playsound(self.audio_dict[self.first_response])

                # A newer wake-up interrupts the previous, still running one.
                if self.current_thread and self.current_thread.is_alive():
                    self.stop_event.set()
                    self.current_thread.join()
                    self.stop_event.clear()

                self.current_thread = threading.Thread(target=self.kws_handler)
                self.current_thread.daemon = True
                self.current_thread.start()

            rclpy.spin_once(self, timeout_sec=0.1)
            time.sleep(0.1)

    def kws_handler(self) -> None:
        """Record one utterance, recognise it and publish the result."""
        if self.stop_event.is_set():
            return
        if not self.listen_for_speech(self.mic_index):
            return

        asr_text = self.ASR_conversion(self.user_speechdir)
        # ASR_conversion reports every failure, including text that is too
        # short, with the sentinel string "error".
        if asr_text == "error":
            self.get_logger().warn(
                "I still don't understand what you mean. Please try again"
            )
            playsound(self.audio_dict[self.error_response])
            return

        self.get_logger().info(asr_text)
        self.get_logger().info("😀okay, let me think for a moment...")
        self.asr_pub_result(asr_text)

    def system_sound_init(self):
        """Build ``audio_dict``: prompt name -> path of its mp3 file."""
        pkg_path = get_package_share_directory("largemodel")
        prompt_names = (
            "longwan-women-1",
            "longwan-women-2",
            "longxiaochun-women-1",
            "longxiaochun-women-2",
        )
        self.audio_dict = {
            name: os.path.join(pkg_path, "resources_file", f"{name}.mp3")
            for name in prompt_names
        }

    def asr_mdoel_init(self):
        """Prepare the ASR backend selected by ``regional_setting``.

        An unknown ``regional_setting`` never returns: the error is logged
        once per second until the process is stopped.
        """
        if self.regional_setting == "international":
            self.get_logger().info("The online asr model :XUN-FEI ASR is loaded")
        elif self.regional_setting == "China":
            if self.use_oline_asr:
                self.get_logger().info(
                    f"The online asr model :{self.modelinterface.init_oline_asr(self.language)} is loaded"
                )
            else:
                # Local SenseVoiceSmall model.
                self.modelinterface.init_local_asr_model()
                self.get_logger().info("The asr model :SenseVoiceSmall is loaded")
        else:
            while True:
                # Logged at error level, like the other configuration failures.
                self.get_logger().error(
                    'Please check the regional_setting parameter in yahboom.yaml file, it should be either "China" or "international".'
                )
                time.sleep(1)

    def language_init(self):
        """Select the acknowledgement and error prompts for ``language``.

        An unsupported language never returns: the error is logged every
        three seconds until the process is stopped.
        """
        prompts_by_language = {
            "zh": ("longwan-women-1", "longwan-women-2"),
            "en": ("longxiaochun-women-1", "longxiaochun-women-2"),
        }
        prompts = prompts_by_language.get(self.language)
        if prompts is not None:
            self.first_response, self.error_response = prompts
            return

        while True:
            self.get_logger().error(
                "language setting error,please check your language setting"
            )
            time.sleep(3)

    def kws_init(self):
        """Open the keyword-spotting serial port and start its reader thread.

        When the port cannot be opened this never returns: the error is
        logged once per second until the process is stopped.
        """
        self.port_name = self.mic_serial_port
        # Filled by the serial reader, consumed by main_loop.
        self.audio_request_queue = queue.Queue()
        self.serial_port = kws_mic(
            port=self.port_name, kwsquence=self.audio_request_queue, baudrate=115200
        )
        self.serial_port.open()

        if not self.serial_port.ser or not self.serial_port.ser.is_open:
            while True:
                time.sleep(1)
                self.get_logger().error(
                    "Failed to open kws serial port.Please check whether the hardware wiring or the voice module is normal?"
                )

        receive_thread = threading.Thread(target=self.serial_port.receive_data)
        receive_thread.daemon = True
        receive_thread.start()

    def asr_pub_result(self, asr_result: str) -> None:
        """Publish ``asr_result`` on the ``asr`` topic."""
        self.asr_pub.publish(String(data=asr_result))

    # Uncomment to log how long each conversion takes:
    # @measure_execution_time
    def ASR_conversion(self, input_file: str) -> str:
        """Run speech recognition and return the text.

        Args:
            input_file: Path of the recording. It is not passed to the
                "international" backend, which is called without arguments.

        Returns:
            The recognised text, or the sentinel ``"error"`` when the backend
            fails or the text is not longer than four characters.
        """
        if self.regional_setting == "international":
            res = rec_wav_music_en()
            return res if res is not None else "error"

        if self.use_oline_asr:
            result = self.modelinterface.oline_asr(input_file)
        else:
            result = self.modelinterface.SenseVoiceSmall_ASR(input_file)

        # result is indexed as (status, text-or-error-message).
        if result[0] == "ok" and len(result[1]) > 4:
            return result[1]
        self.get_logger().error(f"ASR Error:{result[1]}")
        return "error"

    def listen_for_speech(self, mic_index: int = 0) -> bool:
        """Record one utterance from the microphone into ``user_speechdir``.

        Recording starts with the first frame WebRTC VAD classifies as speech
        and ends after ``MAX_SILENCE_FRAMES`` consecutive non-speech frames.
        There is no overall timeout: without speech the method keeps
        listening until ``stop_event`` is set.

        Args:
            mic_index: PyAudio input device index; ``0`` selects the default
                input device.

        Returns:
            True when a recording was written, False when interrupted through
            ``stop_event``.
        """
        # With the default 30 ms frames this is 900 ms of trailing silence.
        MAX_SILENCE_FRAMES = 30

        audio_buffer = []
        silence_counter = 0
        speaking = False
        frame_counter = 0

        stream_kwargs = {
            "format": pyaudio.paInt16,
            "channels": 1,
            "rate": self.sample_rate,
            "input": True,
            "frames_per_buffer": self.frame_bytes,
        }
        if mic_index != 0:
            stream_kwargs["input_device_index"] = mic_index

        self.record_status_pub.publish(Bool(data=True))
        p = None
        stream = None
        try:
            p = pyaudio.PyAudio()
            # Read now: the PyAudio instance is terminated before the file is
            # written.
            sample_width = p.get_sample_size(pyaudio.paInt16)

            # Short beep to tell the user to start speaking.
            self.pub_beep.publish(UInt16(data=1))
            time.sleep(0.5)
            self.pub_beep.publish(UInt16(data=0))

            stream = p.open(**stream_kwargs)
            while True:
                if self.stop_event.is_set():
                    return False

                frame = stream.read(self.frame_bytes, exception_on_overflow=False)
                is_speech = self.vad.is_speech(frame, self.sample_rate)

                if is_speech:
                    speaking = True
                    audio_buffer.append(frame)
                    silence_counter = 0
                elif speaking:
                    # Keep the pauses inside the utterance as well.
                    silence_counter += 1
                    audio_buffer.append(frame)
                    if silence_counter >= MAX_SILENCE_FRAMES:
                        break

                frame_counter += 1
                if frame_counter % 2 == 0:
                    self.get_logger().info("1" if is_speech else "-")
        finally:
            # Either handle may be missing when set-up failed half-way;
            # guarding keeps the original exception from being masked.
            if stream is not None:
                stream.stop_stream()
                stream.close()
            if p is not None:
                p.terminate()
            self.record_status_pub.publish(Bool(data=False))

        if not (speaking and audio_buffer):
            return False

        # Drop the trailing silence that ended the recording.
        if len(audio_buffer) > MAX_SILENCE_FRAMES:
            clean_buffer = audio_buffer[:-MAX_SILENCE_FRAMES]
        else:
            clean_buffer = audio_buffer

        with wave.open(self.user_speechdir, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(sample_width)
            wf.setframerate(self.sample_rate)
            wf.writeframes(b"".join(clean_buffer))
        return True
-
-
def main(args=None):
    """Start rclpy, run the ASR node until interrupted, then clean up.

    Args:
        args: Command-line arguments forwarded to ``rclpy.init``.
    """
    rclpy.init(args=args)
    sense_voice_node = None
    try:
        # Built inside the try block: the constructor can block in its retry
        # loops, and Ctrl-C or an exception there must still reach the
        # clean-up below.
        sense_voice_node = ASRNode()
        sense_voice_node.main_loop()
    except KeyboardInterrupt:
        pass
    finally:
        if sense_voice_node is not None:
            sense_voice_node.destroy_node()
        # The context may already have been shut down by signal handling;
        # shutting it down a second time can raise.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == "__main__":
    main()
|