# asr.py.back (15 KB)
  1. import rclpy
  2. import os
  3. import time
  4. from rclpy.node import Node
  5. import pyaudio
  6. from playsound import playsound
  7. import wave
  8. import threading
  9. import webrtcvad
  10. import queue
  11. from std_msgs.msg import String, UInt16, Bool
  12. from utils.mic_serial import kws_mic
  13. from utils import large_model_interface
  14. from utils.large_model_interface import rec_wav_music_en
  15. from ament_index_python.packages import get_package_share_directory
  16. import functools
  17. def measure_execution_time(func):
  18. """
  19. 装饰器:测量函数执行时间并使用 ROS 日志打印结果
  20. """
  21. @functools.wraps(func)
  22. def wrapper(self, *args, **kwargs):
  23. start_time = time.time()
  24. result = func(self, *args, **kwargs)
  25. end_time = time.time()
  26. execution_time = end_time - start_time
  27. # 使用 ROS 日志系统记录执行时间
  28. if hasattr(self, 'get_logger'):
  29. self.get_logger().info(f"[性能统计] {func.__name__} 函数执行时间: {execution_time:.4f} 秒")
  30. else:
  31. print(f"[性能统计] {func.__name__} 函数执行时间: {execution_time:.4f} 秒")
  32. return result
  33. return wrapper
  34. class ASRNode(Node):
  35. def __init__(self):
  36. super().__init__("asr_node")
  37. # 初始化参数、变量 / Initialize parameters and variables
  38. self.init_param_config()
  39. # 初始化语音唤醒 / Initialize keyword spotting (KWS)
  40. self.kws_init()
  41. # 初始化ASR模型 / Initialize ASR model
  42. self.asr_mdoel_init()
  43. # 初始化语言设置 / Initialize language settings
  44. self.language_init()
  45. # 初始化系统声音 / Initialize system sound functionality
  46. self.system_sound_init()
  47. # 初始化ROS通信 / Initialize ROS communication
  48. self.init_ros_comunication()
  49. # 打印初始化信息 / Log initialization completion
  50. self.get_logger().info("asr_node Initialization completed")
  51. def init_ros_comunication(self):
  52. # 创建蜂鸣器发布者 / Create a publisher for the buzzer
  53. self.pub_beep = self.create_publisher(UInt16, "beep", 10)
  54. # 创建ASR发布者,发布转换完成的消息 / Create an ASR publisher to publish conversion results
  55. self.asr_pub = self.create_publisher(String, "asr", 5)
  56. # 创建唤醒信息发布者 / Create a publisher for wake-up signals
  57. self.wakeup_pub = self.create_publisher(Bool, "wakeup", 5)
  58. #创建发布录音状态发布者 / Create a publisher for recording status
  59. self.record_status_pub=self.create_publisher(Bool, "record_status", 5)
  60. def init_param_config(self):
  61. self.user_speechdir = os.path.join(
  62. get_package_share_directory("largemodel"),
  63. "resources_file",
  64. "user_speech.wav",
  65. )
  66. # 参数声明 / Declare parameters
  67. self.declare_parameter("VAD_MODE", 1)
  68. self.declare_parameter("sample_rate", 16000)
  69. self.declare_parameter("frame_duration_ms", 30)
  70. self.declare_parameter("language", "en")
  71. self.declare_parameter("use_oline_asr", False)
  72. self.declare_parameter("mic_serial_port", "/dev/mic")
  73. self.declare_parameter("mic_index", 0)
  74. self.declare_parameter("regional_setting", "China")
  75. # 获取服务器参数 / Get server parameters
  76. self.VAD_MODE = (
  77. self.get_parameter("VAD_MODE").get_parameter_value().integer_value
  78. )
  79. self.sample_rate = (
  80. self.get_parameter("sample_rate").get_parameter_value().integer_value
  81. )
  82. self.frame_duration_ms = (
  83. self.get_parameter("frame_duration_ms").get_parameter_value().integer_value
  84. )
  85. self.language = (
  86. self.get_parameter("language").get_parameter_value().string_value
  87. )
  88. self.use_oline_asr = (
  89. self.get_parameter("use_oline_asr").get_parameter_value().bool_value
  90. )
  91. self.mic_serial_port = (
  92. self.get_parameter("mic_serial_port").get_parameter_value().string_value
  93. )
  94. self.mic_index = (
  95. self.get_parameter("mic_index").get_parameter_value().integer_value
  96. )
  97. self.regional_setting = (
  98. self.get_parameter("regional_setting").get_parameter_value().string_value
  99. )
  100. self.frame_bytes = int(
  101. self.sample_rate * self.frame_duration_ms / 1000
  102. ) # 音频帧大小 / Audio frame size
  103. # 大模型接口实例端 / Instance of the large model interface
  104. # 传入 logger 用于调试日志
  105. self.modelinterface = large_model_interface.model_interface(logger=self.get_logger())
  106. # 初始化 WebRTC VAD / Initialize WebRTC VAD
  107. self.vad = webrtcvad.Vad()
  108. self.vad.set_mode(self.VAD_MODE)
  109. self.current_thread = None # 唤醒处理线程 / Thread for handling wake-up events
  110. self.stop_event = threading.Event()
  111. def main_loop(self):
  112. while rclpy.ok():
  113. while (
  114. self.audio_request_queue.qsize() > 1
  115. ): # 只处理最近的一次唤醒请求,防止重复唤醒 / Process only the most recent wake-up request to prevent duplicates
  116. self.audio_request_queue.get()
  117. if not self.audio_request_queue.empty():
  118. self.audio_request_queue.get()
  119. self.wakeup_pub.publish(
  120. Bool(data=True)
  121. ) # 发布唤醒信号 / Publish wake-up signal
  122. self.get_logger().info("I'm here")
  123. playsound(
  124. self.audio_dict[self.first_response]
  125. ) # 应答用户 / Respond to the user
  126. if (
  127. self.current_thread and self.current_thread.is_alive()
  128. ): # 打断上次的唤醒处理线程 / Interrupt the previous wake-up handling thread
  129. self.stop_event.set()
  130. self.current_thread.join() # 等待当前线程结束 / Wait for the current thread to finish
  131. self.stop_event.clear() # 清除事件 / Clear the event
  132. self.current_thread = threading.Thread(target=self.kws_handler)
  133. self.current_thread.daemon = True
  134. self.current_thread.start()
  135. rclpy.spin_once(self, timeout_sec=0.1)
  136. time.sleep(0.1)
  137. def kws_handler(self) -> None:
  138. if self.stop_event.is_set():
  139. return
  140. if self.listen_for_speech(self.mic_index):
  141. asr_text = self.ASR_conversion(
  142. self.user_speechdir
  143. ) # 进行 ASR 转换 / Perform ASR conversion
  144. if (
  145. asr_text == "error"
  146. ): # 检查 ASR 结果长度是否小于4个字符 / Check if ASR result length is less than 4 characters
  147. self.get_logger().warn(
  148. "I still don't understand what you mean. Please try again"
  149. )
  150. playsound(
  151. self.audio_dict[self.error_response]
  152. ) # 错误响应 / Error response
  153. else:
  154. self.get_logger().info(asr_text)
  155. self.get_logger().info("😀okay, let me think for a moment...")
  156. self.asr_pub_result(asr_text) # 发布 ASR结果 / Publish ASR result
  157. else:
  158. return
  159. def system_sound_init(
  160. self,
  161. ): # 初始化系统声音相关的功能 / Initialize system sound functionality
  162. pkg_path = get_package_share_directory("largemodel")
  163. self.audio_dict = {} # 系统声音字典 / Dictionary of system sounds
  164. self.audio_dict["longwan-women-1"] = os.path.join(
  165. pkg_path, "resources_file", "longwan-women-1.mp3"
  166. )
  167. self.audio_dict["longwan-women-2"] = os.path.join(
  168. pkg_path, "resources_file", "longwan-women-2.mp3"
  169. )
  170. self.audio_dict["longxiaochun-women-1"] = os.path.join(
  171. pkg_path, "resources_file", "longxiaochun-women-1.mp3"
  172. )
  173. self.audio_dict["longxiaochun-women-2"] = os.path.join(
  174. pkg_path, "resources_file", "longxiaochun-women-2.mp3"
  175. )
  176. def asr_mdoel_init(self): # 初始化asr模型 / Initialize ASR model
  177. if self.regional_setting == "international":
  178. self.get_logger().info(
  179. f"The online asr model :XUN-FEI ASR is loaded"
  180. )
  181. elif self.regional_setting == "China":
  182. if self.use_oline_asr:
  183. self.get_logger().info(
  184. f"The online asr model :{self.modelinterface.init_oline_asr(self.language)} is loaded"
  185. )
  186. else:
  187. # -------- SenseVoiceSmall 语音识别 --模型加载----- / Load SenseVoiceSmall online ASR model
  188. self.modelinterface.init_local_asr_model()
  189. self.get_logger().info("The asr model :SenseVoiceSmall is loaded")
  190. else:
  191. while True:
  192. self.get_logger().info('Please check the regional_setting parameter in yahboom.yaml file, it should be either "China" or "international".')
  193. time.sleep(1)
  194. def language_init(self):
  195. if self.language == "zh":
  196. self.first_response = "longwan-women-1"
  197. self.error_response = "longwan-women-2"
  198. elif self.language == "en":
  199. self.first_response = "longxiaochun-women-1"
  200. self.error_response = "longxiaochun-women-2"
  201. else:
  202. while True:
  203. self.get_logger().error(
  204. "language setting error,please check your language setting"
  205. ) # 语言设置错误,请检查语言设置 / Language setting error, please check your language setting
  206. time.sleep(3)
  207. def kws_init(
  208. self,
  209. ): # 初始化关键词唤醒相关的内容 / Initialize keyword spotting (KWS) related content
  210. self.port_name = self.mic_serial_port
  211. self.audio_request_queue = (
  212. queue.Queue()
  213. ) # 用于传递音频请求 / Queue for passing audio requests
  214. self.serial_port = kws_mic(
  215. port=self.port_name, kwsquence=self.audio_request_queue, baudrate=115200
  216. )
  217. self.serial_port.open()
  218. if not self.serial_port.ser or not self.serial_port.ser.is_open:
  219. while True:
  220. time.sleep(1)
  221. self.get_logger().error(
  222. "Failed to open kws serial port.Please check whether the hardware wiring or the voice module is normal?"
  223. ) # 未能打开kws串口 / Failed to open KWS serial port
  224. receive_thread = threading.Thread(target=self.serial_port.receive_data)
  225. receive_thread.daemon = True
  226. receive_thread.start()
  227. def asr_pub_result(self, asr_result: str) -> None:
  228. msg = String(data=asr_result)
  229. self.asr_pub.publish(msg)
  230. # @measure_execution_time
  231. def ASR_conversion(self, input_file: str) -> str:
  232. if self.regional_setting == "international":
  233. res=rec_wav_music_en()
  234. if res is not None:
  235. return res
  236. else:
  237. return "error"
  238. else:
  239. if self.use_oline_asr:
  240. result = self.modelinterface.oline_asr(input_file)
  241. if result[0] == "ok" and len(result[1]) > 4:
  242. return result[1]
  243. else:
  244. self.get_logger().error(f"ASR Error:{result[1]}") # ASR错误 / ASR error
  245. return "error"
  246. else:
  247. result = self.modelinterface.SenseVoiceSmall_ASR(input_file)
  248. if result[0] == "ok" and len(result[1]) > 4:
  249. return result[1]
  250. else:
  251. self.get_logger().error(f"ASR Error:{result[1]}") # ASR错误 / ASR error
  252. return "error"
  253. def listen_for_speech(self, mic_index=0):
  254. self.record_status_pub.publish(Bool(data=True))
  255. p = pyaudio.PyAudio()
  256. audio_buffer = []
  257. silence_counter = 0
  258. MAX_SILENCE_FRAMES = 30 # 30帧*30ms=900ms静音后停止 / Stop after 900ms of silence (30 frames * 30ms)
  259. speaking = False # 语音活动标志 / Flag indicating speech activity
  260. frame_counter = 0 # 计数器 / Frame counter
  261. stream_kwargs = {
  262. "format": pyaudio.paInt16,
  263. "channels": 1,
  264. "rate": self.sample_rate,
  265. "input": True,
  266. "frames_per_buffer": self.frame_bytes,
  267. }
  268. if mic_index != 0:
  269. stream_kwargs["input_device_index"] = mic_index
  270. # 通过蜂鸣器提示用户讲话 / Prompt the user to speak via the buzzer
  271. self.pub_beep.publish(UInt16(data=1))
  272. time.sleep(0.5)
  273. self.pub_beep.publish(UInt16(data=0))
  274. try:
  275. # 打开音频流 / Open audio stream
  276. stream = p.open(**stream_kwargs)
  277. while True:
  278. if self.stop_event.is_set():
  279. return False
  280. frame = stream.read(
  281. self.frame_bytes, exception_on_overflow=False
  282. ) # 读取音频数据 / Read audio data
  283. is_speech = self.vad.is_speech(
  284. frame, self.sample_rate
  285. ) # VAD检测 / VAD detection
  286. if is_speech:
  287. # 检测到语音活动 / Detected speech activity
  288. speaking = True
  289. audio_buffer.append(frame)
  290. silence_counter = 0
  291. else:
  292. if speaking:
  293. # 在语音活动后检测静音 / Detect silence after speech activity
  294. silence_counter += 1
  295. audio_buffer.append(
  296. frame
  297. ) # 持续记录缓冲 / Continue recording buffer
  298. # 静音持续时间达标时结束录音 / End recording when silence duration meets the threshold
  299. if silence_counter >= MAX_SILENCE_FRAMES:
  300. break
  301. frame_counter += 1
  302. if frame_counter % 2 == 0:
  303. self.get_logger().info("1" if is_speech else "-")
  304. finally:
  305. stream.stop_stream()
  306. stream.close()
  307. p.terminate()
  308. self.record_status_pub.publish(Bool(data=False))
  309. # 保存有效录音(去除尾部静音) / Save valid recording (remove trailing silence)
  310. if speaking and len(audio_buffer) > 0:
  311. # 裁剪最后静音部分 / Trim the last silent part
  312. clean_buffer = (
  313. audio_buffer[:-MAX_SILENCE_FRAMES]
  314. if len(audio_buffer) > MAX_SILENCE_FRAMES
  315. else audio_buffer
  316. )
  317. with wave.open(self.user_speechdir, "wb") as wf:
  318. wf.setnchannels(1)
  319. wf.setsampwidth(p.get_sample_size(pyaudio.paInt16))
  320. wf.setframerate(self.sample_rate)
  321. wf.writeframes(b"".join(clean_buffer))
  322. return True
  323. def main(args=None):
  324. rclpy.init(args=args)
  325. sense_voice_node = ASRNode()
  326. try:
  327. sense_voice_node.main_loop()
  328. except KeyboardInterrupt:
  329. pass
  330. finally:
  331. sense_voice_node.destroy_node()
  332. rclpy.shutdown()
  333. if __name__ == "__main__":
  334. main()