asr.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
  1. import rclpy
  2. import os
  3. import time
  4. from rclpy.node import Node
  5. import pyaudio
  6. from playsound import playsound
  7. import wave
  8. import threading
  9. import webrtcvad
  10. import queue
  11. from std_msgs.msg import String, UInt16, Bool
  12. from utils.mic_serial import kws_mic
  13. from utils import large_model_interface
  14. from utils.large_model_interface import rec_wav_music_en
  15. from ament_index_python.packages import get_package_share_directory
  16. import functools
  17. def measure_execution_time(func):
  18. """
  19. 装饰器:测量函数执行时间并使用 ROS 日志打印结果
  20. """
  21. @functools.wraps(func)
  22. def wrapper(self, *args, **kwargs):
  23. start_time = time.time()
  24. result = func(self, *args, **kwargs)
  25. end_time = time.time()
  26. execution_time = end_time - start_time
  27. # 使用 ROS 日志系统记录执行时间
  28. if hasattr(self, 'get_logger'):
  29. self.get_logger().info(f"[性能统计] {func.__name__} 函数执行时间: {execution_time:.4f} 秒")
  30. else:
  31. print(f"[性能统计] {func.__name__} 函数执行时间: {execution_time:.4f} 秒")
  32. return result
  33. return wrapper
  34. class ASRNode(Node):
  35. def __init__(self):
  36. super().__init__("asr_node")
  37. # 初始化参数、变量 / Initialize parameters and variables
  38. self.init_param_config()
  39. # 初始化语音唤醒 / Initialize keyword spotting (KWS)
  40. self.kws_init()
  41. # 初始化ASR模型 / Initialize ASR model
  42. self.asr_mdoel_init()
  43. # 初始化语言设置 / Initialize language settings
  44. self.language_init()
  45. # 初始化系统声音 / Initialize system sound functionality
  46. self.system_sound_init()
  47. # 初始化ROS通信 / Initialize ROS communication
  48. self.init_ros_comunication()
  49. # 打印初始化信息 / Log initialization completion
  50. self.get_logger().info("asr_node Initialization completed")
  51. def init_ros_comunication(self):
  52. # 创建蜂鸣器发布者 / Create a publisher for the buzzer
  53. self.pub_beep = self.create_publisher(UInt16, "beep", 10)
  54. # 创建ASR发布者,发布转换完成的消息 / Create an ASR publisher to publish conversion results
  55. self.asr_pub = self.create_publisher(String, "asr", 5)
  56. # 创建唤醒信息发布者 / Create a publisher for wake-up signals
  57. self.wakeup_pub = self.create_publisher(Bool, "wakeup", 5)
  58. #创建发布录音状态发布者 / Create a publisher for recording status
  59. self.record_status_pub=self.create_publisher(Bool, "record_status", 5)
  60. # 创建 ASR 控制话题订阅者 / Create ASR control topic subscriber
  61. self.asr_control_sub = self.create_subscription(
  62. String, "/asr/control", self.asr_control_callback, 10
  63. )
  64. def init_param_config(self):
  65. self.user_speechdir = os.path.join(
  66. get_package_share_directory("largemodel"),
  67. "resources_file",
  68. "user_speech.wav",
  69. )
  70. # 参数声明 / Declare parameters
  71. self.declare_parameter("VAD_MODE", 1)
  72. self.declare_parameter("sample_rate", 16000)
  73. self.declare_parameter("frame_duration_ms", 30)
  74. self.declare_parameter("language", "en")
  75. self.declare_parameter("use_oline_asr", False)
  76. self.declare_parameter("mic_serial_port", "/dev/mic")
  77. self.declare_parameter("mic_index", 0)
  78. self.declare_parameter("regional_setting", "China")
  79. # 获取服务器参数 / Get server parameters
  80. self.VAD_MODE = (
  81. self.get_parameter("VAD_MODE").get_parameter_value().integer_value
  82. )
  83. self.sample_rate = (
  84. self.get_parameter("sample_rate").get_parameter_value().integer_value
  85. )
  86. self.frame_duration_ms = (
  87. self.get_parameter("frame_duration_ms").get_parameter_value().integer_value
  88. )
  89. self.language = (
  90. self.get_parameter("language").get_parameter_value().string_value
  91. )
  92. self.use_oline_asr = (
  93. self.get_parameter("use_oline_asr").get_parameter_value().bool_value
  94. )
  95. self.mic_serial_port = (
  96. self.get_parameter("mic_serial_port").get_parameter_value().string_value
  97. )
  98. self.mic_index = (
  99. self.get_parameter("mic_index").get_parameter_value().integer_value
  100. )
  101. self.regional_setting = (
  102. self.get_parameter("regional_setting").get_parameter_value().string_value
  103. )
  104. self.frame_bytes = int(
  105. self.sample_rate * self.frame_duration_ms / 1000
  106. ) # 音频帧大小 / Audio frame size
  107. # 大模型接口实例端 / Instance of the large model interface
  108. # 传入 logger 用于调试日志
  109. self.modelinterface = large_model_interface.model_interface(logger=self.get_logger())
  110. # 初始化 WebRTC VAD / Initialize WebRTC VAD
  111. self.vad = webrtcvad.Vad()
  112. self.vad.set_mode(self.VAD_MODE)
  113. self.current_thread = None # 唤醒处理线程 / Thread for handling wake-up events
  114. self.stop_event = threading.Event()
  115. def main_loop(self):
  116. while rclpy.ok():
  117. while (
  118. self.audio_request_queue.qsize() > 1
  119. ): # 只处理最近的一次唤醒请求,防止重复唤醒 / Process only the most recent wake-up request to prevent duplicates
  120. self.audio_request_queue.get()
  121. if not self.audio_request_queue.empty():
  122. self.audio_request_queue.get()
  123. self.wakeup_pub.publish(
  124. Bool(data=True)
  125. ) # 发布唤醒信号 / Publish wake-up signal
  126. self.get_logger().info("I'm here")
  127. playsound(
  128. self.audio_dict[self.first_response]
  129. ) # 应答用户(阻塞,等提示音播完)/ Respond to the user (blocking)
  130. if (
  131. self.current_thread and self.current_thread.is_alive()
  132. ): # 打断上次的唤醒处理线程 / Interrupt the previous wake-up handling thread
  133. self.stop_event.set() # 通知旧线程停止(不等待)
  134. self.stop_event = threading.Event() # 创建新的 stop_event
  135. self.current_thread = threading.Thread(target=self.kws_handler)
  136. self.current_thread.daemon = True
  137. self.current_thread.start()
  138. rclpy.spin_once(self, timeout_sec=0.1)
  139. def kws_handler(self, play_error_response=True) -> None:
  140. if self.stop_event.is_set():
  141. return
  142. # 清空 buffer 中已有的旧帧,确保从"当前时刻"开始录音
  143. while not self.audio_buffer.empty():
  144. try:
  145. self.audio_buffer.get_nowait()
  146. except queue.Empty:
  147. break
  148. if self.listen_for_speech(self.mic_index):
  149. asr_text = self.ASR_conversion(
  150. self.user_speechdir
  151. ) # 进行 ASR 转换 / Perform ASR conversion
  152. if (
  153. asr_text == "error"
  154. ): # 检查 ASR 结果长度是否小于4个字符 / Check if ASR result length is less than 4 characters
  155. self.get_logger().warn(
  156. "I still don't understand what you mean. Please try again"
  157. )
  158. if play_error_response:
  159. playsound(
  160. self.audio_dict[self.error_response]
  161. ) # 错误响应 / Error response
  162. else:
  163. self.get_logger().info(asr_text)
  164. self.get_logger().info("😀okay, let me think for a moment...")
  165. self.asr_pub_result(asr_text) # 发布 ASR结果 / Publish ASR result
  166. else:
  167. return
  168. def asr_control_callback(self, msg):
  169. """
  170. 处理 /asr/control 控制指令
  171. """
  172. command = msg.data.strip()
  173. if command in ["continue_listen", "start_listen", "listen_once"]:
  174. self.get_logger().info(f"[多轮对话] 收到 ASR 控制指令: {command}")
  175. # 停止旧线程
  176. if self.current_thread and self.current_thread.is_alive():
  177. self.stop_event.set()
  178. time.sleep(0.05)
  179. # 清空 buffer
  180. while not self.audio_buffer.empty():
  181. try:
  182. self.audio_buffer.get_nowait()
  183. except queue.Empty:
  184. break
  185. # 创建新线程直接开始录音(不播放唤醒音)
  186. self.stop_event = threading.Event()
  187. self.current_thread = threading.Thread(
  188. target=self.kws_handler,
  189. kwargs={"play_error_response": False}
  190. )
  191. self.current_thread.daemon = True
  192. self.current_thread.start()
  193. return
  194. if command == "wake_listen":
  195. self.get_logger().info("[ASR控制] 收到带唤醒音的监听指令")
  196. # 停止旧线程
  197. if self.current_thread and self.current_thread.is_alive():
  198. self.stop_event.set()
  199. time.sleep(0.05)
  200. # 清空 buffer
  201. while not self.audio_buffer.empty():
  202. try:
  203. self.audio_buffer.get_nowait()
  204. except queue.Empty:
  205. break
  206. # 发布唤醒信号
  207. self.wakeup_pub.publish(Bool(data=True))
  208. self.get_logger().info("I'm here")
  209. playsound(self.audio_dict[self.first_response])
  210. # 创建新线程
  211. self.stop_event = threading.Event()
  212. self.current_thread = threading.Thread(
  213. target=self.kws_handler,
  214. kwargs={"play_error_response": True}
  215. )
  216. self.current_thread.daemon = True
  217. self.current_thread.start()
  218. return
  219. if command in ["stop_listen", "cancel"]:
  220. self.get_logger().info(f"[ASR控制] 收到停止监听指令: {command}")
  221. self.stop_event.set()
  222. return
  223. self.get_logger().warn(f"[ASR控制] 未知指令: {command}")
  224. def system_sound_init(
  225. self,
  226. ): # 初始化系统声音相关的功能 / Initialize system sound functionality
  227. pkg_path = get_package_share_directory("largemodel")
  228. self.audio_dict = {} # 系统声音字典 / Dictionary of system sounds
  229. self.audio_dict["longwan-women-1"] = os.path.join(
  230. pkg_path, "resources_file", "longwan-women-1.mp3"
  231. )
  232. self.audio_dict["longwan-women-2"] = os.path.join(
  233. pkg_path, "resources_file", "longwan-women-2.mp3"
  234. )
  235. self.audio_dict["longxiaochun-women-1"] = os.path.join(
  236. pkg_path, "resources_file", "longxiaochun-women-1.mp3"
  237. )
  238. self.audio_dict["longxiaochun-women-2"] = os.path.join(
  239. pkg_path, "resources_file", "longxiaochun-women-2.mp3"
  240. )
  241. def asr_mdoel_init(self): # 初始化asr模型 / Initialize ASR model
  242. if self.regional_setting == "international":
  243. self.get_logger().info(
  244. f"The online asr model :XUN-FEI ASR is loaded"
  245. )
  246. elif self.regional_setting == "China":
  247. if self.use_oline_asr:
  248. self.get_logger().info(
  249. f"The online asr model :{self.modelinterface.init_oline_asr(self.language)} is loaded"
  250. )
  251. else:
  252. # -------- SenseVoiceSmall 语音识别 --模型加载----- / Load SenseVoiceSmall online ASR model
  253. self.modelinterface.init_local_asr_model()
  254. self.get_logger().info("The asr model :SenseVoiceSmall is loaded")
  255. else:
  256. while True:
  257. self.get_logger().info('Please check the regional_setting parameter in yahboom.yaml file, it should be either "China" or "international".')
  258. time.sleep(1)
  259. def language_init(self):
  260. if self.language == "zh":
  261. self.first_response = "longwan-women-1"
  262. self.error_response = "longwan-women-2"
  263. elif self.language == "en":
  264. self.first_response = "longxiaochun-women-1"
  265. self.error_response = "longxiaochun-women-2"
  266. else:
  267. while True:
  268. self.get_logger().error(
  269. "language setting error,please check your language setting"
  270. ) # 语言设置错误,请检查语言设置 / Language setting error, please check your language setting
  271. time.sleep(3)
  272. def kws_init(
  273. self,
  274. ): # 初始化关键词唤醒相关的内容 / Initialize keyword spotting (KWS) related content
  275. self.port_name = self.mic_serial_port
  276. self.audio_request_queue = (
  277. queue.Queue()
  278. ) # 用于传递音频请求 / Queue for passing audio requests
  279. self.serial_port = kws_mic(
  280. port=self.port_name, kwsquence=self.audio_request_queue, baudrate=115200
  281. )
  282. self.serial_port.open()
  283. if not self.serial_port.ser or not self.serial_port.ser.is_open:
  284. while True:
  285. time.sleep(1)
  286. self.get_logger().error(
  287. "Failed to open kws serial port.Please check whether the hardware wiring or the voice module is normal?"
  288. ) # 未能打开kws串口 / Failed to open KWS serial port
  289. receive_thread = threading.Thread(target=self.serial_port.receive_data)
  290. receive_thread.daemon = True
  291. receive_thread.start()
  292. # 初始化常驻音频读取线程 / Initialize persistent audio capture thread
  293. self.audio_buffer = queue.Queue(
  294. maxsize=100
  295. ) # 环形 buffer,保留约 100 帧(~3秒)
  296. self.audio_capture_running = True
  297. self.audio_capture_thread = threading.Thread(target=self._audio_capture_loop)
  298. self.audio_capture_thread.daemon = True
  299. self.audio_capture_thread.start()
  300. self.get_logger().info("Audio capture thread started (persistent mode)")
  301. def _audio_capture_loop(self):
  302. """常驻音频读取线程,持续从麦克风读取音频帧到 buffer"""
  303. p = pyaudio.PyAudio()
  304. stream_kwargs = {
  305. "format": pyaudio.paInt16,
  306. "channels": 1,
  307. "rate": self.sample_rate,
  308. "input": True,
  309. "frames_per_buffer": self.frame_bytes,
  310. }
  311. if self.mic_index >= 0:
  312. stream_kwargs["input_device_index"] = self.mic_index
  313. stream = p.open(**stream_kwargs)
  314. self.get_logger().info("Audio stream opened (persistent mode)")
  315. try:
  316. while self.audio_capture_running:
  317. frame = stream.read(self.frame_bytes, exception_on_overflow=False)
  318. if self.audio_buffer.full():
  319. try:
  320. self.audio_buffer.get_nowait() # 丢弃最旧的帧
  321. except queue.Empty:
  322. pass
  323. self.audio_buffer.put(frame)
  324. finally:
  325. stream.stop_stream()
  326. stream.close()
  327. p.terminate()
  328. self.get_logger().info("Audio capture thread stopped")
  329. def asr_pub_result(self, asr_result: str) -> None:
  330. msg = String(data=asr_result)
  331. self.asr_pub.publish(msg)
  332. # @measure_execution_time
  333. def ASR_conversion(self, input_file: str) -> str:
  334. if self.regional_setting == "international":
  335. res=rec_wav_music_en()
  336. if res is not None:
  337. return res
  338. else:
  339. return "error"
  340. else:
  341. if self.use_oline_asr:
  342. result = self.modelinterface.oline_asr(input_file)
  343. if result[0] == "ok" and len(result[1]) > 4:
  344. return result[1]
  345. else:
  346. self.get_logger().error(f"ASR Error:{result[1]}") # ASR错误 / ASR error
  347. return "error"
  348. else:
  349. result = self.modelinterface.SenseVoiceSmall_ASR(input_file)
  350. if result[0] == "ok" and len(result[1]) > 4:
  351. return result[1]
  352. else:
  353. self.get_logger().error(f"ASR Error:{result[1]}") # ASR错误 / ASR error
  354. return "error"
  355. def listen_for_speech(self, mic_index=0):
  356. self.record_status_pub.publish(Bool(data=True))
  357. audio_buffer = []
  358. silence_counter = 0
  359. MAX_SILENCE_FRAMES = 30 # 30帧*30ms=900ms静音后停止 / Stop after 900ms of silence (30 frames * 30ms)
  360. speaking = False # 语音活动标志 / Flag indicating speech activity
  361. frame_counter = 0 # 计数器 / Frame counter
  362. empty_count = 0 # 连续空帧计数 / Consecutive empty frame count
  363. MAX_EMPTY_FRAMES = 200 # 约6秒无音频则退出 / Exit after ~6s of no audio
  364. WAIT_SPEECH_TIMEOUT_SEC = 8.0 # 等待用户开口超时 / Wait for user to start speaking timeout
  365. wait_start_time = time.time()
  366. # 通过蜂鸣器提示用户讲话 / Prompt the user to speak via the buzzer
  367. self.pub_beep.publish(UInt16(data=1))
  368. self.pub_beep.publish(UInt16(data=0))
  369. while True:
  370. if self.stop_event.is_set():
  371. self.record_status_pub.publish(Bool(data=False))
  372. return False
  373. try:
  374. frame = self.audio_buffer.get(timeout=0.5)
  375. empty_count = 0
  376. except queue.Empty:
  377. empty_count += 1
  378. if empty_count >= MAX_EMPTY_FRAMES:
  379. self.get_logger().warn("No audio input, exiting recording")
  380. self.record_status_pub.publish(Bool(data=False))
  381. return False
  382. continue
  383. is_speech = self.vad.is_speech(frame, self.sample_rate)
  384. # 等待用户开口超时检测
  385. if not speaking and time.time() - wait_start_time > WAIT_SPEECH_TIMEOUT_SEC:
  386. self.get_logger().warn("No speech detected within timeout, exiting recording")
  387. self.record_status_pub.publish(Bool(data=False))
  388. return False
  389. if is_speech:
  390. speaking = True
  391. audio_buffer.append(frame)
  392. silence_counter = 0
  393. else:
  394. if speaking:
  395. silence_counter += 1
  396. audio_buffer.append(frame)
  397. if silence_counter >= MAX_SILENCE_FRAMES:
  398. break
  399. frame_counter += 1
  400. if frame_counter % 2 == 0:
  401. self.get_logger().info("1" if is_speech else "-")
  402. self.record_status_pub.publish(Bool(data=False))
  403. if speaking and len(audio_buffer) > 0:
  404. clean_buffer = (
  405. audio_buffer[:-MAX_SILENCE_FRAMES]
  406. if len(audio_buffer) > MAX_SILENCE_FRAMES
  407. else audio_buffer
  408. )
  409. with wave.open(self.user_speechdir, "wb") as wf:
  410. wf.setnchannels(1)
  411. wf.setsampwidth(2) # 16-bit = 2 bytes
  412. wf.setframerate(self.sample_rate)
  413. wf.writeframes(b"".join(clean_buffer))
  414. return True
  415. return False
  416. def main(args=None):
  417. rclpy.init(args=args)
  418. sense_voice_node = ASRNode()
  419. try:
  420. sense_voice_node.main_loop()
  421. except KeyboardInterrupt:
  422. pass
  423. finally:
  424. # 停止常驻音频读取线程 / Stop the persistent audio capture thread
  425. sense_voice_node.audio_capture_running = False
  426. if sense_voice_node.audio_capture_thread.is_alive():
  427. sense_voice_node.audio_capture_thread.join(timeout=2)
  428. sense_voice_node.destroy_node()
  429. rclpy.shutdown()
  430. if __name__ == "__main__":
  431. main()