Module livekit.agents.stt.fallback_adapter
Classes
class AvailabilityChangedEvent (stt: STT, available: bool)
-
Expand source code
@dataclass
class AvailabilityChangedEvent:
    """Payload of the ``"stt_availability_changed"`` event.

    ``FallbackAdapter`` emits this event when one of its wrapped STT
    instances is marked unavailable after a failure, and again when a
    background recovery attempt for that instance succeeds.
    """

    # The wrapped STT instance whose availability changed.
    stt: STT
    # True when the instance recovered, False when it was marked unavailable.
    available: bool
AvailabilityChangedEvent(stt: 'STT', available: 'bool')
Class variables
var available : bool
var stt : STT
class FallbackAdapter (stt: list[STT],
*,
attempt_timeout: float = 10.0,
max_retry_per_stt: int = 1,
retry_interval: float = 5)-
Expand source code
class FallbackAdapter(
    STT[Literal["stt_availability_changed"]],
):
    """STT that delegates to an ordered list of STT instances with fallback.

    Instances are tried in the order given. When one times out or fails it is
    marked unavailable, an ``"stt_availability_changed"`` event is emitted
    with an ``AvailabilityChangedEvent`` payload, and the next instance is
    tried. Unavailable instances are probed in background tasks and marked
    available again once a probe succeeds.
    """

    def __init__(
        self,
        stt: list[STT],
        *,
        attempt_timeout: float = 10.0,
        max_retry_per_stt: int = 1,
        retry_interval: float = 5,
    ) -> None:
        """Create the adapter.

        Args:
            stt: STT instances to use, in order of preference. Must not be
                empty.
            attempt_timeout: Timeout applied to each attempt on a single STT.
            max_retry_per_stt: ``max_retry`` passed to each STT through its
                connection options.
            retry_interval: ``retry_interval`` passed to each STT through its
                connection options.

        Raises:
            ValueError: If ``stt`` is empty.
        """
        if not stt:
            raise ValueError("At least one STT instance must be provided.")

        # The adapter only advertises a capability when every wrapped STT has it.
        super().__init__(
            capabilities=STTCapabilities(
                streaming=all(t.capabilities.streaming for t in stt),
                interim_results=all(t.capabilities.interim_results for t in stt),
            )
        )

        # Keep a private copy: ``_status`` is indexed in parallel with this
        # list, so a caller mutating its own list afterwards must not be able
        # to desynchronise the two.
        self._stt_instances = list(stt)
        self._attempt_timeout = attempt_timeout
        self._max_retry_per_stt = max_retry_per_stt
        self._retry_interval = retry_interval

        # One status entry per STT, at the same index as in ``_stt_instances``.
        self._status: list[_STTStatus] = [
            _STTStatus(
                available=True,
                recovering_synthesize_task=None,
                recovering_stream_task=None,
            )
            for _ in self._stt_instances
        ]

    async def _try_recognize(
        self,
        *,
        stt: STT,
        buffer: utils.AudioBuffer,
        language: str | None = None,
        conn_options: APIConnectOptions,
        recovering: bool = False,
    ) -> SpeechEvent:
        """Run one recognition attempt on ``stt``, logging any failure.

        The adapter's retry count, timeout and retry interval override the
        corresponding fields of ``conn_options``. Every exception is logged
        here and then re-raised; ``recovering`` only selects the log wording.
        """
        try:
            return await stt.recognize(
                buffer,
                language=language,
                conn_options=dataclasses.replace(
                    conn_options,
                    max_retry=self._max_retry_per_stt,
                    timeout=self._attempt_timeout,
                    retry_interval=self._retry_interval,
                ),
            )
        except asyncio.TimeoutError:
            logger.warning(
                f"{stt.label} recovery timed out"
                if recovering
                else f"{stt.label} timed out, switching to next STT",
                extra={"streamed": False},
            )
            raise
        except APIError as e:
            logger.warning(
                f"{stt.label} recovery failed"
                if recovering
                else f"{stt.label} failed, switching to next STT",
                exc_info=e,
                extra={"streamed": False},
            )
            raise
        except Exception:
            logger.exception(
                f"{stt.label} recovery unexpected error"
                if recovering
                else f"{stt.label} unexpected error, switching to next STT",
                extra={"streamed": False},
            )
            raise

    def _try_recovery(
        self,
        *,
        stt: STT,
        buffer: utils.AudioBuffer,
        language: str | None,
        conn_options: APIConnectOptions,
    ) -> None:
        """Start a background recognition probe for ``stt`` if none is running.

        When the probe succeeds the STT is marked available again and
        ``"stt_availability_changed"`` is emitted.
        """
        stt_status = self._status[self._stt_instances.index(stt)]
        previous = stt_status.recovering_synthesize_task
        if previous is not None and not previous.done():
            # A probe for this STT is still in flight; do not start another.
            return

        async def _recover_stt_task(stt: STT) -> None:
            try:
                await self._try_recognize(
                    stt=stt,
                    buffer=buffer,
                    language=language,
                    conn_options=conn_options,
                    recovering=True,
                )
            except Exception:
                # Already logged inside _try_recognize; the STT stays unavailable.
                return

            stt_status.available = True
            logger.info(f"{stt.label} recovered")
            self.emit(
                "stt_availability_changed",
                AvailabilityChangedEvent(stt=stt, available=True),
            )

        stt_status.recovering_synthesize_task = asyncio.create_task(
            _recover_stt_task(stt)
        )

    async def _recognize_impl(
        self,
        buffer: utils.AudioBuffer,
        *,
        language: str | None,
        conn_options: APIConnectOptions,
    ) -> SpeechEvent:
        """Recognize ``buffer`` with the first STT that succeeds.

        Only STTs currently marked available are tried, unless all of them
        are unavailable, in which case every STT is retried.

        Raises:
            APIConnectionError: If no STT produced a result.
        """
        start_time = time.time()

        all_failed = all(not stt_status.available for stt_status in self._status)
        if all_failed:
            logger.error("all STTs are unavailable, retrying..")

        for i, stt in enumerate(self._stt_instances):
            stt_status = self._status[i]
            if stt_status.available or all_failed:
                try:
                    return await self._try_recognize(
                        stt=stt,
                        buffer=buffer,
                        language=language,
                        conn_options=conn_options,
                        recovering=False,
                    )
                except Exception:  # exceptions already logged inside _try_recognize
                    if stt_status.available:
                        stt_status.available = False
                        self.emit(
                            "stt_availability_changed",
                            AvailabilityChangedEvent(stt=stt, available=False),
                        )

            # Reached for every STT that was skipped or has just failed.
            self._try_recovery(
                stt=stt, buffer=buffer, language=language, conn_options=conn_options
            )

        raise APIConnectionError(
            "all STTs failed (%s) after %s seconds"
            % (
                [stt.label for stt in self._stt_instances],
                time.time() - start_time,
            )
        )

    async def recognize(
        self,
        buffer: AudioBuffer,
        *,
        language: str | None = None,
        conn_options: APIConnectOptions = DEFAULT_FALLBACK_API_CONNECT_OPTIONS,
    ) -> SpeechEvent:
        """Recognize speech in ``buffer``.

        Delegates to the base-class ``recognize``; the only difference is
        that ``conn_options`` defaults to
        ``DEFAULT_FALLBACK_API_CONNECT_OPTIONS``.
        """
        return await super().recognize(
            buffer, language=language, conn_options=conn_options
        )

    def stream(
        self,
        *,
        language: str | None = None,
        conn_options: APIConnectOptions = DEFAULT_FALLBACK_API_CONNECT_OPTIONS,
    ) -> RecognizeStream:
        """Return a ``FallbackRecognizeStream`` bound to this adapter."""
        return FallbackRecognizeStream(
            stt=self, language=language, conn_options=conn_options
        )

    async def aclose(self) -> None:
        """Cancel any recovery tasks that were started for the wrapped STTs."""
        for stt_status in self._status:
            if stt_status.recovering_synthesize_task is not None:
                await aio.gracefully_cancel(stt_status.recovering_synthesize_task)

            if stt_status.recovering_stream_task is not None:
                await aio.gracefully_cancel(stt_status.recovering_stream_task)
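STT implementation that wraps an ordered list of STT instances and falls back to the next one when the current instance times out or fails. A failed instance is marked unavailable and probed in the background; an `stt_availability_changed` event carrying an `AvailabilityChangedEvent` is emitted whenever an instance becomes unavailable or recovers.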
Ancestors
- STT
- abc.ABC
- EventEmitter
- typing.Generic
Methods
async def recognize(self,
buffer: AudioBuffer,
*,
language: str | None = None,
conn_options: APIConnectOptions = APIConnectOptions(max_retry=0, retry_interval=5.0, timeout=10.0)) ‑> SpeechEvent-
Expand source code
async def recognize(
    self,
    buffer: AudioBuffer,
    *,
    language: str | None = None,
    conn_options: APIConnectOptions = DEFAULT_FALLBACK_API_CONNECT_OPTIONS,
) -> SpeechEvent:
    """Recognize speech in ``buffer``.

    Delegates to the base-class ``recognize``; the only difference is that
    ``conn_options`` defaults to ``DEFAULT_FALLBACK_API_CONNECT_OPTIONS``.
    """
    return await super().recognize(
        buffer, language=language, conn_options=conn_options
    )
def stream(self,
*,
language: str | None = None,
conn_options: APIConnectOptions = APIConnectOptions(max_retry=0, retry_interval=5.0, timeout=10.0)) ‑> RecognizeStream-
Expand source code
def stream(
    self,
    *,
    language: str | None = None,
    conn_options: APIConnectOptions = DEFAULT_FALLBACK_API_CONNECT_OPTIONS,
) -> RecognizeStream:
    """Return a ``FallbackRecognizeStream`` bound to this adapter.

    ``language`` and ``conn_options`` are passed through unchanged to the
    stream's constructor.
    """
    return FallbackRecognizeStream(
        stt=self, language=language, conn_options=conn_options
    )
Inherited members
class FallbackRecognizeStream (*,
stt: FallbackAdapter,
language: str | None,
conn_options: APIConnectOptions)-
Expand source code
class FallbackRecognizeStream(RecognizeStream):
    """Streaming counterpart of ``FallbackAdapter``.

    Input pushed to this stream is forwarded to the stream of the STT
    currently in use, and also to the streams of any STTs being probed for
    recovery. When the active stream fails, the next STT is tried.
    """

    def __init__(
        self,
        *,
        stt: FallbackAdapter,
        language: str | None,
        conn_options: APIConnectOptions,
    ):
        """Bind the stream to its owning ``FallbackAdapter``.

        Args:
            stt: The adapter that owns the STT instances and their status.
            language: Language passed to every underlying ``stt.stream()``.
            conn_options: Base connection options; retry and timeout fields
                are overridden per underlying stream.
        """
        super().__init__(stt=stt, conn_options=conn_options, sample_rate=None)
        self._language = language
        self._fallback_adapter = stt
        # Streams of unavailable STTs that are currently being probed.
        self._recovering_streams: list[RecognizeStream] = []

    async def _run(self) -> None:
        """Stream from the first STT that works, falling back on failure.

        Raises:
            APIConnectionError: If every candidate STT stream failed.
        """
        start_time = time.time()
        adapter = self._fallback_adapter

        all_failed = all(not stt_status.available for stt_status in adapter._status)
        if all_failed:
            logger.error("all STTs are unavailable, retrying..")

        main_stream: RecognizeStream | None = None
        forward_input_task: asyncio.Task | None = None

        async def _forward_input_task() -> None:
            # Reads ``main_stream`` from the enclosing scope on every item, so
            # input follows whichever STT stream is active at that moment.
            with contextlib.suppress(RuntimeError):  # stream might be closed
                async for data in self._input_ch:
                    for stream in self._recovering_streams:
                        if isinstance(data, rtc.AudioFrame):
                            stream.push_frame(data)
                        elif isinstance(data, self._FlushSentinel):
                            stream.flush()

                    if main_stream is not None:
                        if isinstance(data, rtc.AudioFrame):
                            main_stream.push_frame(data)
                        elif isinstance(data, self._FlushSentinel):
                            main_stream.flush()

                if main_stream is not None:
                    main_stream.end_input()

        for i, stt in enumerate(adapter._stt_instances):
            stt_status = adapter._status[i]
            if stt_status.available or all_failed:
                try:
                    main_stream = stt.stream(
                        language=self._language,
                        conn_options=dataclasses.replace(
                            self._conn_options,
                            max_retry=adapter._max_retry_per_stt,
                            timeout=adapter._attempt_timeout,
                            retry_interval=adapter._retry_interval,
                        ),
                    )

                    if forward_input_task is None or forward_input_task.done():
                        forward_input_task = asyncio.create_task(_forward_input_task())

                    try:
                        async with main_stream:
                            async for ev in main_stream:
                                self._event_ch.send_nowait(ev)
                    except asyncio.TimeoutError:
                        logger.warning(
                            f"{stt.label} timed out, switching to next STT",
                            extra={"streamed": True},
                        )
                        raise
                    except APIError as e:
                        logger.warning(
                            f"{stt.label} failed, switching to next STT",
                            exc_info=e,
                            extra={"streamed": True},
                        )
                        raise
                    except Exception:
                        logger.exception(
                            f"{stt.label} unexpected error, switching to next STT",
                            extra={"streamed": True},
                        )
                        raise

                    # The active stream finished without error: nothing to fall back from.
                    return

                except Exception:
                    if stt_status.available:
                        stt_status.available = False
                        adapter.emit(
                            "stt_availability_changed",
                            AvailabilityChangedEvent(stt=stt, available=False),
                        )

            # Reached for every STT that was skipped or has just failed.
            self._try_recovery(stt)

        if forward_input_task is not None:
            await aio.gracefully_cancel(forward_input_task)

        await asyncio.gather(*[stream.aclose() for stream in self._recovering_streams])

        raise APIConnectionError(
            "all STTs failed (%s) after %s seconds"
            % (
                [stt.label for stt in adapter._stt_instances],
                time.time() - start_time,
            )
        )

    def _try_recovery(self, stt: STT) -> None:
        """Start a background stream probe for ``stt`` if none is running.

        The probe opens a new stream on ``stt`` and waits for a final
        transcript with non-empty text. On success the STT is marked
        available and ``"stt_availability_changed"`` is emitted.
        """
        adapter = self._fallback_adapter
        stt_status = adapter._status[adapter._stt_instances.index(stt)]
        previous = stt_status.recovering_stream_task
        if previous is not None and not previous.done():
            # A probe for this STT is still in flight; do not start another.
            return

        stream = stt.stream(
            language=self._language,
            conn_options=dataclasses.replace(
                self._conn_options,
                max_retry=0,
                timeout=adapter._attempt_timeout,
            ),
        )
        # Registering the stream makes the input forwarder feed it as well.
        self._recovering_streams.append(stream)

        async def _recover_stt_task() -> None:
            try:
                got_transcript = False
                async with stream:
                    async for ev in stream:
                        # Compare against the single enum member with equality,
                        # not a membership test.
                        if ev.type != SpeechEventType.FINAL_TRANSCRIPT:
                            continue
                        if not ev.alternatives or not ev.alternatives[0].text:
                            continue

                        got_transcript = True
                        break

                if not got_transcript:
                    return

                stt_status.available = True
                logger.info(f"stt.FallbackAdapter, {stt.label} recovered")
                adapter.emit(
                    "stt_availability_changed",
                    AvailabilityChangedEvent(stt=stt, available=True),
                )

            except asyncio.TimeoutError:
                logger.warning(
                    f"{stream._stt.label} recovery timed out",
                    extra={"streamed": True},
                )
            except APIError as e:
                logger.warning(
                    f"{stream._stt.label} recovery failed",
                    exc_info=e,
                    extra={"streamed": True},
                )
            except Exception:
                logger.exception(
                    f"{stream._stt.label} recovery unexpected error",
                    extra={"streamed": True},
                )
                raise

        stt_status.recovering_stream_task = task = asyncio.create_task(
            _recover_stt_task()
        )
        # Stop forwarding input to the probe stream once the probe is over.
        task.add_done_callback(lambda _: self._recovering_streams.remove(stream))
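Stream returned by `FallbackAdapter.stream()`. It forwards incoming audio to the stream of the first usable STT, switches to the next STT when the active stream times out or fails, and raises `APIConnectionError` once every STT has failed.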
Args: sample_rate : int or None, optional The desired sample rate for the audio input. If specified, the audio input will be automatically resampled to match the given sample rate before being processed for Speech-to-Text. If not provided (None), the input will retain its original sample rate.
Ancestors
- RecognizeStream
- abc.ABC
Inherited members