Module livekit.agents.inference.interruption
Classes
class AdaptiveInterruptionDetector (*,
threshold: float = 0.5,
min_interruption_duration: float = 0.05,
max_audio_duration: float = 3,
audio_prefix_duration: float = 1.0,
detection_interval: float = 0.1,
inference_timeout: float = 0.7,
base_url: str | None = None,
api_key: str | None = None,
api_secret: str | None = None,
http_session: aiohttp.ClientSession | None = None)-
Expand source code
class AdaptiveInterruptionDetector( rtc.EventEmitter[ Literal[ "overlapping_speech", "error", "metrics_collected", ] ], ): def __init__( self, *, threshold: float = THRESHOLD, min_interruption_duration: float = MIN_INTERRUPTION_DURATION, max_audio_duration: float = MAX_AUDIO_DURATION, audio_prefix_duration: float = AUDIO_PREFIX_DURATION, detection_interval: float = DETECTION_INTERVAL, inference_timeout: float = REMOTE_INFERENCE_TIMEOUT, base_url: str | None = None, api_key: str | None = None, api_secret: str | None = None, http_session: aiohttp.ClientSession | None = None, ) -> None: """ Initialize a AdaptiveInterruptionDetector instance. Args: threshold (float, optional): The threshold for the interruption detection, defaults to 0.5. min_interruption_duration (float, optional): The minimum duration, in seconds, of the interruption event, defaults to 50ms. max_audio_duration (float, optional): The maximum audio duration, including the audio prefix, in seconds, for the interruption detection, defaults to 3s. audio_prefix_duration (float, optional): The audio prefix duration, in seconds, for the interruption detection, defaults to 0.5s. detection_interval (float, optional): The interval between detections, in seconds, for the interruption detection, defaults to 0.1s. inference_timeout (float, optional): The timeout for the interruption detection, defaults to 1 second. base_url (str, optional): The base URL for the interruption detection, defaults to the shared LIVEKIT_REMOTE_EOT_URL environment variable. api_key (str, optional): The API key for the interruption detection, defaults to the LIVEKIT_INFERENCE_API_KEY environment variable. api_secret (str, optional): The API secret for the interruption detection, defaults to the LIVEKIT_INFERENCE_API_SECRET environment variable. http_session (aiohttp.ClientSession, optional): The HTTP session to use for the interruption detection. 
""" super().__init__() if max_audio_duration > 3.0: raise ValueError("max_audio_duration must be less than or equal to 3.0 seconds") lk_base_url = ( base_url if base_url else os.getenv("LIVEKIT_REMOTE_EOT_URL", get_default_inference_url()) ) lk_api_key: str = api_key if api_key else "" lk_api_secret: str = api_secret if api_secret else "" # use LiveKit credentials if using the inference service (production or staging) is_inference_url = lk_base_url in (DEFAULT_INFERENCE_URL, STAGING_INFERENCE_URL) if is_inference_url: lk_api_key = ( api_key if api_key else os.getenv("LIVEKIT_INFERENCE_API_KEY", os.getenv("LIVEKIT_API_KEY", "")) ) if not lk_api_key: raise ValueError( "api_key is required, either as argument or set LIVEKIT_API_KEY environmental variable" ) lk_api_secret = ( api_secret if api_secret else os.getenv("LIVEKIT_INFERENCE_API_SECRET", os.getenv("LIVEKIT_API_SECRET", "")) ) if not lk_api_secret: raise ValueError( "api_secret is required, either as argument or set LIVEKIT_API_SECRET environmental variable" ) use_proxy = True else: use_proxy = False self._opts = InterruptionOptions( sample_rate=SAMPLE_RATE, threshold=threshold, min_frames=math.ceil(min_interruption_duration * _FRAMES_PER_SECOND), max_audio_duration=max_audio_duration, audio_prefix_duration=audio_prefix_duration, detection_interval=detection_interval, inference_timeout=inference_timeout, base_url=lk_base_url, api_key=lk_api_key, api_secret=lk_api_secret, use_proxy=use_proxy, ) self._label = f"{type(self).__module__}.{type(self).__name__}" self._sample_rate = SAMPLE_RATE self._session = http_session self._streams = weakref.WeakSet[InterruptionHttpStream | InterruptionWebSocketStream]() logger.info( "adaptive interruption detector initialized", extra={ "base_url": self._opts.base_url, "detection_interval": self._opts.detection_interval, "audio_prefix_duration": self._opts.audio_prefix_duration, "max_audio_duration": self._opts.max_audio_duration, "min_frames": self._opts.min_frames, "threshold": 
self._opts.threshold, "inference_timeout": self._opts.inference_timeout, "use_proxy": self._opts.use_proxy, }, ) @property def model(self) -> str: return "adaptive interruption" @property def provider(self) -> str: return "livekit" @property def label(self) -> str: return self._label @property def sample_rate(self) -> int: return self._sample_rate def _emit_error(self, api_error: Exception, recoverable: bool) -> None: self.emit( "error", InterruptionDetectionError( label=self._label, error=api_error, recoverable=recoverable, ), ) def _ensure_session(self) -> aiohttp.ClientSession: if not self._session: self._session = http_context.http_session() return self._session def stream( self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS ) -> InterruptionHttpStream | InterruptionWebSocketStream: try: stream: InterruptionHttpStream | InterruptionWebSocketStream if self._opts.use_proxy: stream = InterruptionWebSocketStream(model=self, conn_options=conn_options) else: stream = InterruptionHttpStream(model=self, conn_options=conn_options) except Exception as e: self._emit_error(e, recoverable=False) raise self._streams.add(stream) return stream def update_options( self, *, threshold: NotGivenOr[float] = NOT_GIVEN, min_interruption_duration: NotGivenOr[float] = NOT_GIVEN, ) -> None: if is_given(threshold): self._opts.threshold = threshold if is_given(min_interruption_duration): self._opts.min_frames = math.ceil(min_interruption_duration * _FRAMES_PER_SECOND) for stream in self._streams: stream.update_options( threshold=threshold, min_interruption_duration=min_interruption_duration )Abstract base class for generic types.
On Python 3.12 and newer, generic classes implicitly inherit from Generic when they declare a parameter list after the class's name::
class Mapping[KT, VT]: def __getitem__(self, key: KT) -> VT: ... # Etc.

On older versions of Python, however, generic classes have to explicitly inherit from Generic.
After a class has been declared to be generic, it can then be used as follows::
def lookup_name[KT, VT](mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default

Initialize an AdaptiveInterruptionDetector instance.
Args
threshold:float, optional- The threshold for the interruption detection, defaults to 0.5.
min_interruption_duration:float, optional- The minimum duration, in seconds, of the interruption event, defaults to 50ms.
max_audio_duration:float, optional- The maximum audio duration, including the audio prefix, in seconds, for the interruption detection, defaults to 3s.
audio_prefix_duration:float, optional- The audio prefix duration, in seconds, for the interruption detection, defaults to 1.0s.
detection_interval:float, optional- The interval between detections, in seconds, for the interruption detection, defaults to 0.1s.
inference_timeout:float, optional- The timeout for the interruption detection, defaults to 0.7 seconds.
base_url:str, optional- The base URL for the interruption detection, defaults to the shared LIVEKIT_REMOTE_EOT_URL environment variable (falling back to the default LiveKit inference URL).
api_key:str, optional- The API key for the interruption detection, defaults to the LIVEKIT_INFERENCE_API_KEY environment variable (falling back to LIVEKIT_API_KEY).
api_secret:str, optional- The API secret for the interruption detection, defaults to the LIVEKIT_INFERENCE_API_SECRET environment variable (falling back to LIVEKIT_API_SECRET).
http_session:aiohttp.ClientSession, optional- The HTTP session to use for the interruption detection.
Ancestors
- EventEmitter
- typing.Generic
Instance variables
prop label : str-
Expand source code
@property def label(self) -> str: return self._label

prop model : str-
Expand source code
@property def model(self) -> str: return "adaptive interruption"

prop provider : str-
Expand source code
@property def provider(self) -> str: return "livekit"

prop sample_rate : int-
Expand source code
@property def sample_rate(self) -> int: return self._sample_rate
Methods
def stream(self,
*,
conn_options: APIConnectOptions = APIConnectOptions(max_retry=3, retry_interval=2.0, timeout=10.0)) ‑> InterruptionHttpStream | InterruptionWebSocketStream-
Expand source code
def stream( self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS ) -> InterruptionHttpStream | InterruptionWebSocketStream: try: stream: InterruptionHttpStream | InterruptionWebSocketStream if self._opts.use_proxy: stream = InterruptionWebSocketStream(model=self, conn_options=conn_options) else: stream = InterruptionHttpStream(model=self, conn_options=conn_options) except Exception as e: self._emit_error(e, recoverable=False) raise self._streams.add(stream) return stream def update_options(self,
*,
threshold: NotGivenOr[float] = NOT_GIVEN,
min_interruption_duration: NotGivenOr[float] = NOT_GIVEN) ‑> None-
Expand source code
def update_options( self, *, threshold: NotGivenOr[float] = NOT_GIVEN, min_interruption_duration: NotGivenOr[float] = NOT_GIVEN, ) -> None: if is_given(threshold): self._opts.threshold = threshold if is_given(min_interruption_duration): self._opts.min_frames = math.ceil(min_interruption_duration * _FRAMES_PER_SECOND) for stream in self._streams: stream.update_options( threshold=threshold, min_interruption_duration=min_interruption_duration )
Inherited members
class InterruptionCacheEntry (*,
created_at: int = <factory>,
speech_input: npt.NDArray[np.int16] | None = None,
total_duration: float | None = None,
prediction_duration: float | None = None,
detection_delay: float | None = None,
probabilities: npt.NDArray[np.float32] | None = None,
is_interruption: bool | None = None)-
Expand source code
@dataclass(slots=True, kw_only=True) class InterruptionCacheEntry: """Typed cache entry for interruption inference results.""" created_at: int = field(default_factory=time.perf_counter_ns) """The timestamp when the cache entry was created, in nanoseconds. Used only for indexing and latency calculation.""" speech_input: npt.NDArray[np.int16] | None = None total_duration: float | None = None prediction_duration: float | None = None detection_delay: float | None = None probabilities: npt.NDArray[np.float32] | None = None is_interruption: bool | None = None def get_total_duration(self, default: float = 0.0) -> float: """RTT (Round Trip Time) time taken to perform the inference, in seconds.""" return self.total_duration if self.total_duration is not None else default def get_prediction_duration(self, default: float = 0.0) -> float: """Time taken to perform the inference from the model side, in seconds.""" return self.prediction_duration if self.prediction_duration is not None else default def get_detection_delay(self, default: float = 0.0) -> float: """Total time from the onset of the speech to the final prediction, in seconds.""" return self.detection_delay if self.detection_delay is not None else default def get_probability(self, default: float = 0.0) -> float: """The conservative estimated probability of the interruption event.""" return ( _estimate_probability(self.probabilities) if self.probabilities is not None else default )Typed cache entry for interruption inference results.
Instance variables
var created_at : int-
Expand source code
@dataclass(slots=True, kw_only=True) class InterruptionCacheEntry: """Typed cache entry for interruption inference results.""" created_at: int = field(default_factory=time.perf_counter_ns) """The timestamp when the cache entry was created, in nanoseconds. Used only for indexing and latency calculation.""" speech_input: npt.NDArray[np.int16] | None = None total_duration: float | None = None prediction_duration: float | None = None detection_delay: float | None = None probabilities: npt.NDArray[np.float32] | None = None is_interruption: bool | None = None def get_total_duration(self, default: float = 0.0) -> float: """RTT (Round Trip Time) time taken to perform the inference, in seconds.""" return self.total_duration if self.total_duration is not None else default def get_prediction_duration(self, default: float = 0.0) -> float: """Time taken to perform the inference from the model side, in seconds.""" return self.prediction_duration if self.prediction_duration is not None else default def get_detection_delay(self, default: float = 0.0) -> float: """Total time from the onset of the speech to the final prediction, in seconds.""" return self.detection_delay if self.detection_delay is not None else default def get_probability(self, default: float = 0.0) -> float: """The conservative estimated probability of the interruption event.""" return ( _estimate_probability(self.probabilities) if self.probabilities is not None else default )The timestamp when the cache entry was created, in nanoseconds. Used only for indexing and latency calculation.
var detection_delay : float | None-
Expand source code
@dataclass(slots=True, kw_only=True) class InterruptionCacheEntry: """Typed cache entry for interruption inference results.""" created_at: int = field(default_factory=time.perf_counter_ns) """The timestamp when the cache entry was created, in nanoseconds. Used only for indexing and latency calculation.""" speech_input: npt.NDArray[np.int16] | None = None total_duration: float | None = None prediction_duration: float | None = None detection_delay: float | None = None probabilities: npt.NDArray[np.float32] | None = None is_interruption: bool | None = None def get_total_duration(self, default: float = 0.0) -> float: """RTT (Round Trip Time) time taken to perform the inference, in seconds.""" return self.total_duration if self.total_duration is not None else default def get_prediction_duration(self, default: float = 0.0) -> float: """Time taken to perform the inference from the model side, in seconds.""" return self.prediction_duration if self.prediction_duration is not None else default def get_detection_delay(self, default: float = 0.0) -> float: """Total time from the onset of the speech to the final prediction, in seconds.""" return self.detection_delay if self.detection_delay is not None else default def get_probability(self, default: float = 0.0) -> float: """The conservative estimated probability of the interruption event.""" return ( _estimate_probability(self.probabilities) if self.probabilities is not None else default ) var is_interruption : bool | None-
Expand source code
@dataclass(slots=True, kw_only=True) class InterruptionCacheEntry: """Typed cache entry for interruption inference results.""" created_at: int = field(default_factory=time.perf_counter_ns) """The timestamp when the cache entry was created, in nanoseconds. Used only for indexing and latency calculation.""" speech_input: npt.NDArray[np.int16] | None = None total_duration: float | None = None prediction_duration: float | None = None detection_delay: float | None = None probabilities: npt.NDArray[np.float32] | None = None is_interruption: bool | None = None def get_total_duration(self, default: float = 0.0) -> float: """RTT (Round Trip Time) time taken to perform the inference, in seconds.""" return self.total_duration if self.total_duration is not None else default def get_prediction_duration(self, default: float = 0.0) -> float: """Time taken to perform the inference from the model side, in seconds.""" return self.prediction_duration if self.prediction_duration is not None else default def get_detection_delay(self, default: float = 0.0) -> float: """Total time from the onset of the speech to the final prediction, in seconds.""" return self.detection_delay if self.detection_delay is not None else default def get_probability(self, default: float = 0.0) -> float: """The conservative estimated probability of the interruption event.""" return ( _estimate_probability(self.probabilities) if self.probabilities is not None else default ) var prediction_duration : float | None-
Expand source code
@dataclass(slots=True, kw_only=True) class InterruptionCacheEntry: """Typed cache entry for interruption inference results.""" created_at: int = field(default_factory=time.perf_counter_ns) """The timestamp when the cache entry was created, in nanoseconds. Used only for indexing and latency calculation.""" speech_input: npt.NDArray[np.int16] | None = None total_duration: float | None = None prediction_duration: float | None = None detection_delay: float | None = None probabilities: npt.NDArray[np.float32] | None = None is_interruption: bool | None = None def get_total_duration(self, default: float = 0.0) -> float: """RTT (Round Trip Time) time taken to perform the inference, in seconds.""" return self.total_duration if self.total_duration is not None else default def get_prediction_duration(self, default: float = 0.0) -> float: """Time taken to perform the inference from the model side, in seconds.""" return self.prediction_duration if self.prediction_duration is not None else default def get_detection_delay(self, default: float = 0.0) -> float: """Total time from the onset of the speech to the final prediction, in seconds.""" return self.detection_delay if self.detection_delay is not None else default def get_probability(self, default: float = 0.0) -> float: """The conservative estimated probability of the interruption event.""" return ( _estimate_probability(self.probabilities) if self.probabilities is not None else default ) var probabilities : numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.float32]] | None-
Expand source code
@dataclass(slots=True, kw_only=True) class InterruptionCacheEntry: """Typed cache entry for interruption inference results.""" created_at: int = field(default_factory=time.perf_counter_ns) """The timestamp when the cache entry was created, in nanoseconds. Used only for indexing and latency calculation.""" speech_input: npt.NDArray[np.int16] | None = None total_duration: float | None = None prediction_duration: float | None = None detection_delay: float | None = None probabilities: npt.NDArray[np.float32] | None = None is_interruption: bool | None = None def get_total_duration(self, default: float = 0.0) -> float: """RTT (Round Trip Time) time taken to perform the inference, in seconds.""" return self.total_duration if self.total_duration is not None else default def get_prediction_duration(self, default: float = 0.0) -> float: """Time taken to perform the inference from the model side, in seconds.""" return self.prediction_duration if self.prediction_duration is not None else default def get_detection_delay(self, default: float = 0.0) -> float: """Total time from the onset of the speech to the final prediction, in seconds.""" return self.detection_delay if self.detection_delay is not None else default def get_probability(self, default: float = 0.0) -> float: """The conservative estimated probability of the interruption event.""" return ( _estimate_probability(self.probabilities) if self.probabilities is not None else default ) var speech_input : numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.int16]] | None-
Expand source code
@dataclass(slots=True, kw_only=True) class InterruptionCacheEntry: """Typed cache entry for interruption inference results.""" created_at: int = field(default_factory=time.perf_counter_ns) """The timestamp when the cache entry was created, in nanoseconds. Used only for indexing and latency calculation.""" speech_input: npt.NDArray[np.int16] | None = None total_duration: float | None = None prediction_duration: float | None = None detection_delay: float | None = None probabilities: npt.NDArray[np.float32] | None = None is_interruption: bool | None = None def get_total_duration(self, default: float = 0.0) -> float: """RTT (Round Trip Time) time taken to perform the inference, in seconds.""" return self.total_duration if self.total_duration is not None else default def get_prediction_duration(self, default: float = 0.0) -> float: """Time taken to perform the inference from the model side, in seconds.""" return self.prediction_duration if self.prediction_duration is not None else default def get_detection_delay(self, default: float = 0.0) -> float: """Total time from the onset of the speech to the final prediction, in seconds.""" return self.detection_delay if self.detection_delay is not None else default def get_probability(self, default: float = 0.0) -> float: """The conservative estimated probability of the interruption event.""" return ( _estimate_probability(self.probabilities) if self.probabilities is not None else default ) var total_duration : float | None-
Expand source code
@dataclass(slots=True, kw_only=True) class InterruptionCacheEntry: """Typed cache entry for interruption inference results.""" created_at: int = field(default_factory=time.perf_counter_ns) """The timestamp when the cache entry was created, in nanoseconds. Used only for indexing and latency calculation.""" speech_input: npt.NDArray[np.int16] | None = None total_duration: float | None = None prediction_duration: float | None = None detection_delay: float | None = None probabilities: npt.NDArray[np.float32] | None = None is_interruption: bool | None = None def get_total_duration(self, default: float = 0.0) -> float: """RTT (Round Trip Time) time taken to perform the inference, in seconds.""" return self.total_duration if self.total_duration is not None else default def get_prediction_duration(self, default: float = 0.0) -> float: """Time taken to perform the inference from the model side, in seconds.""" return self.prediction_duration if self.prediction_duration is not None else default def get_detection_delay(self, default: float = 0.0) -> float: """Total time from the onset of the speech to the final prediction, in seconds.""" return self.detection_delay if self.detection_delay is not None else default def get_probability(self, default: float = 0.0) -> float: """The conservative estimated probability of the interruption event.""" return ( _estimate_probability(self.probabilities) if self.probabilities is not None else default )
Methods
def get_detection_delay(self, default: float = 0.0) ‑> float-
Expand source code
def get_detection_delay(self, default: float = 0.0) -> float: """Total time from the onset of the speech to the final prediction, in seconds.""" return self.detection_delay if self.detection_delay is not None else default

Total time from the onset of the speech to the final prediction, in seconds.
def get_prediction_duration(self, default: float = 0.0) ‑> float-
Expand source code
def get_prediction_duration(self, default: float = 0.0) -> float: """Time taken to perform the inference from the model side, in seconds.""" return self.prediction_duration if self.prediction_duration is not None else default

Time taken to perform the inference from the model side, in seconds.
def get_probability(self, default: float = 0.0) ‑> float-
Expand source code
def get_probability(self, default: float = 0.0) -> float: """The conservative estimated probability of the interruption event.""" return ( _estimate_probability(self.probabilities) if self.probabilities is not None else default )

The conservative estimated probability of the interruption event.
def get_total_duration(self, default: float = 0.0) ‑> float-
Expand source code
def get_total_duration(self, default: float = 0.0) -> float: """RTT (Round Trip Time) time taken to perform the inference, in seconds.""" return self.total_duration if self.total_duration is not None else default

RTT (Round Trip Time) time taken to perform the inference, in seconds.
class InterruptionDetectionError (**data: Any)-
Expand source code
class InterruptionDetectionError(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) type: Literal["interruption_detection_error"] = "interruption_detection_error" timestamp: float = Field(default_factory=time.time) label: str error: Exception = Field(..., exclude=True) recoverable: bool

Usage Documentation
A base class for creating Pydantic models.
Attributes
__class_vars__- The names of the class variables defined on the model.
__private_attributes__- Metadata about the private attributes of the model.
__signature__- The synthesized
__init__[Signature][inspect.Signature] of the model. __pydantic_complete__- Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__- The core schema of the model.
__pydantic_custom_init__- Whether the model has a custom
__init__function. __pydantic_decorators__- Metadata containing the decorators defined on the model.
This replaces
Model.__validators__andModel.__root_validators__from Pydantic V1. __pydantic_generic_metadata__- Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__- Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__- The name of the post-init method for the model, if defined.
__pydantic_root_model__- Whether the model is a [
RootModel][pydantic.root_model.RootModel]. __pydantic_serializer__- The
pydantic-coreSchemaSerializerused to dump instances of the model. __pydantic_validator__- The
pydantic-coreSchemaValidatorused to validate instances of the model. __pydantic_fields__- A dictionary of field names and their corresponding [
FieldInfo][pydantic.fields.FieldInfo] objects. __pydantic_computed_fields__- A dictionary of computed field names and their corresponding [
ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects. __pydantic_extra__- A dictionary containing extra values, if [
extra][pydantic.config.ConfigDict.extra] is set to'allow'. __pydantic_fields_set__- The names of fields explicitly set during instantiation.
__pydantic_private__- Values of private attributes set on the model instance.
Create a new model by parsing and validating input data from keyword arguments.
Raises [
ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.selfis explicitly positional-only to allowselfas a field name.Ancestors
- pydantic.main.BaseModel
Class variables
var error : Exception
var label : str
var model_config
var recoverable : bool
var timestamp : float
var type : Literal['interruption_detection_error']
class InterruptionHttpStream (*,
model: AdaptiveInterruptionDetector,
conn_options: APIConnectOptions)-
Expand source code
class InterruptionHttpStream(InterruptionStreamBase): def __init__( self, *, model: AdaptiveInterruptionDetector, conn_options: APIConnectOptions ) -> None: super().__init__(model=model, conn_options=conn_options) def update_options( self, *, threshold: NotGivenOr[float] = NOT_GIVEN, min_interruption_duration: NotGivenOr[float] = NOT_GIVEN, ) -> None: if is_given(threshold): self._opts.threshold = threshold if is_given(min_interruption_duration): self._opts.min_frames = math.ceil(min_interruption_duration * _FRAMES_PER_SECOND) async def _run(self) -> None: async def _send_task(input_ch: aio.Chan[npt.NDArray[np.int16]]) -> None: async for data in input_ch: if ( overlap_started_at := self._overlap_started_at ) is None or not self._overlap_started: continue # we don't increment the request counter for hosted agents resp: InterruptionResponse = await self.predict(data) created_at = resp.created_at self._cache[created_at] = entry = InterruptionCacheEntry( created_at=created_at, speech_input=data, prediction_duration=resp.prediction_duration, total_duration=(time.perf_counter_ns() - created_at) / 1e9, detection_delay=time.time() - overlap_started_at, probabilities=resp.probabilities, is_interruption=resp.is_bargein, ) if entry.is_interruption and self._overlap_started: logger.debug("user interruption detected") if self._user_speech_span: self._update_user_speech_span(self._user_speech_span, entry) self._user_speech_span = None ev = OverlappingSpeechEvent.from_cache_entry( entry=entry, is_interruption=True, started_at=overlap_started_at, ended_at=time.time(), ) self.send(ev) self._overlap_started = False data_ch = aio.Chan[npt.NDArray[np.int16]]() tasks = [ asyncio.create_task(self._forward_data(data_ch)), asyncio.create_task(_send_task(data_ch)), ] try: await asyncio.gather(*tasks) finally: await aio.cancel_and_wait(*tasks) @log_exceptions(logger=logger) async def predict(self, waveform: np.ndarray) -> InterruptionResponse: created_at = perf_counter_ns() try: async with 
self._session.post( url=f"{self._opts.base_url}/bargein?threshold={self._opts.threshold}&min_frames={int(self._opts.min_frames)}&created_at={int(created_at)}", headers={ "Content-Type": "application/octet-stream", "Authorization": f"Bearer {create_access_token(self._opts.api_key, self._opts.api_secret)}", }, data=waveform.tobytes(), timeout=aiohttp.ClientTimeout(total=self._opts.inference_timeout), ) as resp: try: resp.raise_for_status() data: dict[str, Any] = await resp.json() result = InterruptionResponse.model_validate( data | { "prediction_duration": (time.perf_counter_ns() - created_at) / 1e9, "probabilities": np.array( data.get("probabilities", []), dtype=np.float32 ), } ) logger.trace( "interruption inference done", extra={ "created_at": created_at, "is_interruption": result.is_bargein, "prediction_duration": result.prediction_duration, }, ) return result except Exception as e: msg = await resp.text() status_code = ( e.status if isinstance(e, aiohttp.ClientResponseError) else resp.status ) raise APIStatusError( f"error during interruption prediction: {e}", body=msg, status_code=status_code, retryable=False if status_code == 429 else None, ) from e except asyncio.TimeoutError as e: raise APIStatusError( f"interruption inference timeout: {e}", status_code=408, retryable=False, ) from e except aiohttp.ClientError as e: raise APIConnectionError(f"interruption inference connection error: {e}") from e except APIError as e: raise e except Exception as e: raise APIError(f"error during interruption prediction: {e}") from eHelper class that provides a standard way to create an ABC using inheritance.
Ancestors
- InterruptionStreamBase
- abc.ABC
Methods
async def predict(self, waveform: np.ndarray) ‑> InterruptionResponse-
Expand source code
@log_exceptions(logger=logger) async def predict(self, waveform: np.ndarray) -> InterruptionResponse: created_at = perf_counter_ns() try: async with self._session.post( url=f"{self._opts.base_url}/bargein?threshold={self._opts.threshold}&min_frames={int(self._opts.min_frames)}&created_at={int(created_at)}", headers={ "Content-Type": "application/octet-stream", "Authorization": f"Bearer {create_access_token(self._opts.api_key, self._opts.api_secret)}", }, data=waveform.tobytes(), timeout=aiohttp.ClientTimeout(total=self._opts.inference_timeout), ) as resp: try: resp.raise_for_status() data: dict[str, Any] = await resp.json() result = InterruptionResponse.model_validate( data | { "prediction_duration": (time.perf_counter_ns() - created_at) / 1e9, "probabilities": np.array( data.get("probabilities", []), dtype=np.float32 ), } ) logger.trace( "interruption inference done", extra={ "created_at": created_at, "is_interruption": result.is_bargein, "prediction_duration": result.prediction_duration, }, ) return result except Exception as e: msg = await resp.text() status_code = ( e.status if isinstance(e, aiohttp.ClientResponseError) else resp.status ) raise APIStatusError( f"error during interruption prediction: {e}", body=msg, status_code=status_code, retryable=False if status_code == 429 else None, ) from e except asyncio.TimeoutError as e: raise APIStatusError( f"interruption inference timeout: {e}", status_code=408, retryable=False, ) from e except aiohttp.ClientError as e: raise APIConnectionError(f"interruption inference connection error: {e}") from e except APIError as e: raise e except Exception as e: raise APIError(f"error during interruption prediction: {e}") from e def update_options(self,
*,
threshold: NotGivenOr[float] = NOT_GIVEN,
min_interruption_duration: NotGivenOr[float] = NOT_GIVEN) ‑> None-
Expand source code
def update_options( self, *, threshold: NotGivenOr[float] = NOT_GIVEN, min_interruption_duration: NotGivenOr[float] = NOT_GIVEN, ) -> None: if is_given(threshold): self._opts.threshold = threshold if is_given(min_interruption_duration): self._opts.min_frames = math.ceil(min_interruption_duration * _FRAMES_PER_SECOND)
Inherited members
class InterruptionOptions (*,
sample_rate: int,
threshold: float,
min_frames: int,
max_audio_duration: float,
audio_prefix_duration: float,
detection_interval: float,
inference_timeout: float,
base_url: str,
api_key: str,
api_secret: str,
use_proxy: bool)-
Expand source code
@dataclass(slots=True, kw_only=True) class InterruptionOptions: sample_rate: int """The sample rate of the audio frames, defaults to 16000Hz""" threshold: float """The threshold for the interruption detection, defaults to 0.5""" min_frames: int """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames""" max_audio_duration: float """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds""" audio_prefix_duration: float """The audio prefix duration for the interruption detection, defaults to 1.0 seconds""" detection_interval: float """The interval between detections, defaults to 0.1 seconds""" inference_timeout: float """The timeout for the interruption detection, defaults to 1 second""" base_url: str api_key: str api_secret: str use_proxy: bool """Whether to use the inference instead of the hosted API"""InterruptionOptions(*, sample_rate: 'int', threshold: 'float', min_frames: 'int', max_audio_duration: 'float', audio_prefix_duration: 'float', detection_interval: 'float', inference_timeout: 'float', base_url: 'str', api_key: 'str', api_secret: 'str', use_proxy: 'bool')
Instance variables
var api_key : str-
Expand source code
@dataclass(slots=True, kw_only=True) class InterruptionOptions: sample_rate: int """The sample rate of the audio frames, defaults to 16000Hz""" threshold: float """The threshold for the interruption detection, defaults to 0.5""" min_frames: int """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames""" max_audio_duration: float """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds""" audio_prefix_duration: float """The audio prefix duration for the interruption detection, defaults to 1.0 seconds""" detection_interval: float """The interval between detections, defaults to 0.1 seconds""" inference_timeout: float """The timeout for the interruption detection, defaults to 1 second""" base_url: str api_key: str api_secret: str use_proxy: bool """Whether to use the inference instead of the hosted API""" var api_secret : str-
Expand source code
@dataclass(slots=True, kw_only=True) class InterruptionOptions: sample_rate: int """The sample rate of the audio frames, defaults to 16000Hz""" threshold: float """The threshold for the interruption detection, defaults to 0.5""" min_frames: int """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames""" max_audio_duration: float """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds""" audio_prefix_duration: float """The audio prefix duration for the interruption detection, defaults to 1.0 seconds""" detection_interval: float """The interval between detections, defaults to 0.1 seconds""" inference_timeout: float """The timeout for the interruption detection, defaults to 1 second""" base_url: str api_key: str api_secret: str use_proxy: bool """Whether to use the inference instead of the hosted API""" var audio_prefix_duration : float-
Expand source code
@dataclass(slots=True, kw_only=True) class InterruptionOptions: sample_rate: int """The sample rate of the audio frames, defaults to 16000Hz""" threshold: float """The threshold for the interruption detection, defaults to 0.5""" min_frames: int """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames""" max_audio_duration: float """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds""" audio_prefix_duration: float """The audio prefix duration for the interruption detection, defaults to 1.0 seconds""" detection_interval: float """The interval between detections, defaults to 0.1 seconds""" inference_timeout: float """The timeout for the interruption detection, defaults to 1 second""" base_url: str api_key: str api_secret: str use_proxy: bool """Whether to use the inference instead of the hosted API"""The audio prefix duration for the interruption detection, defaults to 1.0 seconds
var base_url : str-
Expand source code
@dataclass(slots=True, kw_only=True) class InterruptionOptions: sample_rate: int """The sample rate of the audio frames, defaults to 16000Hz""" threshold: float """The threshold for the interruption detection, defaults to 0.5""" min_frames: int """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames""" max_audio_duration: float """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds""" audio_prefix_duration: float """The audio prefix duration for the interruption detection, defaults to 1.0 seconds""" detection_interval: float """The interval between detections, defaults to 0.1 seconds""" inference_timeout: float """The timeout for the interruption detection, defaults to 1 second""" base_url: str api_key: str api_secret: str use_proxy: bool """Whether to use the inference instead of the hosted API""" var detection_interval : float-
Expand source code
@dataclass(slots=True, kw_only=True) class InterruptionOptions: sample_rate: int """The sample rate of the audio frames, defaults to 16000Hz""" threshold: float """The threshold for the interruption detection, defaults to 0.5""" min_frames: int """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames""" max_audio_duration: float """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds""" audio_prefix_duration: float """The audio prefix duration for the interruption detection, defaults to 1.0 seconds""" detection_interval: float """The interval between detections, defaults to 0.1 seconds""" inference_timeout: float """The timeout for the interruption detection, defaults to 1 second""" base_url: str api_key: str api_secret: str use_proxy: bool """Whether to use the inference instead of the hosted API"""The interval between detections, defaults to 0.1 seconds
var inference_timeout : float-
Expand source code
@dataclass(slots=True, kw_only=True) class InterruptionOptions: sample_rate: int """The sample rate of the audio frames, defaults to 16000Hz""" threshold: float """The threshold for the interruption detection, defaults to 0.5""" min_frames: int """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames""" max_audio_duration: float """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds""" audio_prefix_duration: float """The audio prefix duration for the interruption detection, defaults to 1.0 seconds""" detection_interval: float """The interval between detections, defaults to 0.1 seconds""" inference_timeout: float """The timeout for the interruption detection, defaults to 1 second""" base_url: str api_key: str api_secret: str use_proxy: bool """Whether to use the inference instead of the hosted API"""The timeout for the interruption detection, defaults to 1 second
var max_audio_duration : float-
Expand source code
@dataclass(slots=True, kw_only=True) class InterruptionOptions: sample_rate: int """The sample rate of the audio frames, defaults to 16000Hz""" threshold: float """The threshold for the interruption detection, defaults to 0.5""" min_frames: int """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames""" max_audio_duration: float """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds""" audio_prefix_duration: float """The audio prefix duration for the interruption detection, defaults to 1.0 seconds""" detection_interval: float """The interval between detections, defaults to 0.1 seconds""" inference_timeout: float """The timeout for the interruption detection, defaults to 1 second""" base_url: str api_key: str api_secret: str use_proxy: bool """Whether to use the inference instead of the hosted API"""The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds
var min_frames : int-
Expand source code
@dataclass(slots=True, kw_only=True) class InterruptionOptions: sample_rate: int """The sample rate of the audio frames, defaults to 16000Hz""" threshold: float """The threshold for the interruption detection, defaults to 0.5""" min_frames: int """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames""" max_audio_duration: float """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds""" audio_prefix_duration: float """The audio prefix duration for the interruption detection, defaults to 1.0 seconds""" detection_interval: float """The interval between detections, defaults to 0.1 seconds""" inference_timeout: float """The timeout for the interruption detection, defaults to 1 second""" base_url: str api_key: str api_secret: str use_proxy: bool """Whether to use the inference instead of the hosted API"""The minimum number of frames to detect a interruption, defaults to 50ms/2 frames
var sample_rate : int-
Expand source code
@dataclass(slots=True, kw_only=True) class InterruptionOptions: sample_rate: int """The sample rate of the audio frames, defaults to 16000Hz""" threshold: float """The threshold for the interruption detection, defaults to 0.5""" min_frames: int """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames""" max_audio_duration: float """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds""" audio_prefix_duration: float """The audio prefix duration for the interruption detection, defaults to 1.0 seconds""" detection_interval: float """The interval between detections, defaults to 0.1 seconds""" inference_timeout: float """The timeout for the interruption detection, defaults to 1 second""" base_url: str api_key: str api_secret: str use_proxy: bool """Whether to use the inference instead of the hosted API"""The sample rate of the audio frames, defaults to 16000Hz
var threshold : float-
Expand source code
@dataclass(slots=True, kw_only=True) class InterruptionOptions: sample_rate: int """The sample rate of the audio frames, defaults to 16000Hz""" threshold: float """The threshold for the interruption detection, defaults to 0.5""" min_frames: int """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames""" max_audio_duration: float """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds""" audio_prefix_duration: float """The audio prefix duration for the interruption detection, defaults to 1.0 seconds""" detection_interval: float """The interval between detections, defaults to 0.1 seconds""" inference_timeout: float """The timeout for the interruption detection, defaults to 1 second""" base_url: str api_key: str api_secret: str use_proxy: bool """Whether to use the inference instead of the hosted API"""The threshold for the interruption detection, defaults to 0.5
var use_proxy : bool-
Expand source code
@dataclass(slots=True, kw_only=True) class InterruptionOptions: sample_rate: int """The sample rate of the audio frames, defaults to 16000Hz""" threshold: float """The threshold for the interruption detection, defaults to 0.5""" min_frames: int """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames""" max_audio_duration: float """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds""" audio_prefix_duration: float """The audio prefix duration for the interruption detection, defaults to 1.0 seconds""" detection_interval: float """The interval between detections, defaults to 0.1 seconds""" inference_timeout: float """The timeout for the interruption detection, defaults to 1 second""" base_url: str api_key: str api_secret: str use_proxy: bool """Whether to use the inference instead of the hosted API"""Whether to use the inference instead of the hosted API
class InterruptionResponse (**data: Any)-
Expand source code
class InterruptionResponse(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) created_at: int is_bargein: bool prediction_duration: float probabilities: npt.NDArray[np.float32] = Field(..., exclude=True)Usage Documentation
A base class for creating Pydantic models.
Attributes
__class_vars__- The names of the class variables defined on the model.
__private_attributes__- Metadata about the private attributes of the model.
__signature__- The synthesized
__init__[Signature][inspect.Signature] of the model. __pydantic_complete__- Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__- The core schema of the model.
__pydantic_custom_init__- Whether the model has a custom
__init__function. __pydantic_decorators__- Metadata containing the decorators defined on the model.
This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1. __pydantic_generic_metadata__ - Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__- Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__- The name of the post-init method for the model, if defined.
__pydantic_root_model__- Whether the model is a [
RootModel][pydantic.root_model.RootModel]. __pydantic_serializer__- The
pydantic-coreSchemaSerializerused to dump instances of the model. __pydantic_validator__- The
pydantic-coreSchemaValidatorused to validate instances of the model. __pydantic_fields__- A dictionary of field names and their corresponding [
FieldInfo][pydantic.fields.FieldInfo] objects. __pydantic_computed_fields__- A dictionary of computed field names and their corresponding [
ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects. __pydantic_extra__- A dictionary containing extra values, if [
extra][pydantic.config.ConfigDict.extra] is set to'allow'. __pydantic_fields_set__- The names of fields explicitly set during instantiation.
__pydantic_private__- Values of private attributes set on the model instance.
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model. `self` is explicitly positional-only to allow `self` as a field name.
Ancestors
- pydantic.main.BaseModel
Class variables
var created_at : intvar is_bargein : boolvar model_configvar prediction_duration : floatvar probabilities : numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.float32]]
class InterruptionStreamBase (*,
model: AdaptiveInterruptionDetector,
conn_options: APIConnectOptions)-
Expand source code
class InterruptionStreamBase(ABC):
    """Base class for interruption-detection streams.

    Owns the input/output channels, the audio buffer, the retry loop, and
    metrics emission; subclasses implement ``_run`` to perform the actual
    inference transport (e.g. HTTP or WebSocket).
    """

    def __init__(
        self, *, model: AdaptiveInterruptionDetector, conn_options: APIConnectOptions
    ) -> None:
        self._model = model
        self._opts = model._opts
        self._session = model._ensure_session()
        self._input_ch = aio.Chan[InterruptionDataFrameType]()
        self._event_ch = aio.Chan[OverlappingSpeechEvent]()
        # int16 buffer sized to hold max_audio_duration seconds of audio
        self._audio_buffer = AudioArrayBuffer(
            buffer_size=int(self._opts.max_audio_duration * self._opts.sample_rate),
            dtype=np.int16,
            sample_rate=self._opts.sample_rate,
        )
        self._cache = BoundedDict[int, InterruptionCacheEntry](maxsize=10)
        # tee the event stream: one branch for consumers, one for metrics
        self._tee_aiter = aio.itertools.tee(self._event_ch, 2)
        self._event_aiter, monitor_aiter = self._tee_aiter
        self._metrics_task = asyncio.create_task(
            self._metrics_monitor_task(monitor_aiter), name="InterruptionStreamBase._metrics_task"
        )
        self._task = asyncio.create_task(self._main_task())
        self._task.add_done_callback(lambda _: self._event_ch.close())
        self._num_retries = 0
        self._conn_options = conn_options
        self._sample_rate = self._opts.sample_rate
        self._overlap_started_at: float | None = None
        self._user_speech_span: trace.Span | None = None
        self._agent_speech_started: bool = False
        self._overlap_started: bool = False
        self._overlap_count: int = 0
        self._accumulated_samples: int = 0
        self._num_requests = aio.AsyncAtomicCounter(initial=0)
        # samples per inference batch, and the audio-prefix length in samples
        self._batch_size: int = int(self._opts.detection_interval * self._opts.sample_rate)
        self._prefix_size: int = int(self._opts.audio_prefix_duration * self._opts.sample_rate)

    @abstractmethod
    async def _run(self) -> None: ...

    @log_exceptions(logger=logger)
    async def _main_task(self) -> None:
        # Drive ``_run`` with the retry policy from ``conn_options``.
        max_retries = self._conn_options.max_retry
        while self._num_retries <= max_retries:
            try:
                return await self._run()
            except APIError as e:
                if max_retries == 0 or not e.retryable:
                    self._emit_error(e, recoverable=False)
                    raise
                elif self._num_retries == max_retries:
                    self._emit_error(e, recoverable=False)
                    raise APIConnectionError(
                        f"failed to detect interruption after {self._num_retries} attempts",
                    ) from e
                else:
                    self._emit_error(e, recoverable=True)
                    retry_interval = self._conn_options._interval_for_retry(self._num_retries)
                    logger.warning(
                        f"failed to detect interruption, retrying in {retry_interval}s",
                        exc_info=e,
                        extra={
                            "model": self._model._label,
                            "attempt": self._num_retries,
                        },
                    )
                    await asyncio.sleep(retry_interval)

                self._num_retries += 1
            except Exception as e:
                # non-API errors are never retried
                self._emit_error(e, recoverable=False)
                raise

    def _emit_error(self, api_error: Exception, recoverable: bool) -> None:
        # delegate error events to the parent model's emitter
        self._model._emit_error(api_error, recoverable)

    def push_frame(self, frame: InterruptionDataFrameType) -> None:
        """Push some audio frame to be analyzed"""
        self._check_input_not_ended()
        self._check_not_closed()
        self._input_ch.send_nowait(frame)

    def flush(self) -> None:
        """Mark the end of the current segment"""
        self._check_input_not_ended()
        self._check_not_closed()
        self._input_ch.send_nowait(_FlushSentinel())

    def end_input(self) -> None:
        """Mark the end of input, no more audio will be pushed"""
        self.flush()
        self._input_ch.close()

    async def aclose(self) -> None:
        """Close the stream immediately"""
        self._input_ch.close()
        await aio.cancel_and_wait(self._task)
        self._event_ch.close()
        try:
            await self._metrics_task
        finally:
            await self._tee_aiter.aclose()

    async def __anext__(self) -> OverlappingSpeechEvent:
        try:
            val = await self._event_aiter.__anext__()
        except StopAsyncIteration:
            # surface the main task's failure to the consumer, if any
            if not self._task.cancelled() and (exc := self._task.exception()):
                raise exc  # noqa: B904

            raise StopAsyncIteration from None

        return val

    def __aiter__(self) -> AsyncIterator[OverlappingSpeechEvent]:
        return self

    def _check_not_closed(self) -> None:
        if self._event_ch.closed:
            cls = type(self)
            raise RuntimeError(f"{cls.__module__}.{cls.__name__} is closed")

    def _check_input_not_ended(self) -> None:
        if self._input_ch.closed:
            cls = type(self)
            raise RuntimeError(f"{cls.__module__}.{cls.__name__} input ended")

    @staticmethod
    def _update_user_speech_span(
        user_speech_span: trace.Span, entry: InterruptionCacheEntry
    ) -> None:
        # Attach interruption metrics to the user-speech tracing span.
        user_speech_span.set_attribute(
            trace_types.ATTR_IS_INTERRUPTION, str(entry.is_interruption).lower()
        )
        user_speech_span.set_attribute(
            trace_types.ATTR_INTERRUPTION_PROBABILITY, entry.get_probability()
        )
        user_speech_span.set_attribute(
            trace_types.ATTR_INTERRUPTION_TOTAL_DURATION, entry.get_total_duration()
        )
        user_speech_span.set_attribute(
            trace_types.ATTR_INTERRUPTION_PREDICTION_DURATION, entry.get_prediction_duration()
        )
        user_speech_span.set_attribute(
            trace_types.ATTR_INTERRUPTION_DETECTION_DELAY, entry.get_detection_delay()
        )

    async def _forward_data(self, output_ch: aio.Chan[npt.NDArray[np.int16]]) -> None:
        """Preprocess the audio data and forward it to the output channel for inference."""

        async def _reset_state() -> None:
            self._agent_speech_started = False
            self._overlap_started = False
            self._overlap_count = 0
            self._accumulated_samples = 0
            await self._num_requests.set(0)
            self._audio_buffer.reset()
            self._cache.clear()
            self._user_speech_span = None

        async for input_frame in self._input_ch:
            match input_frame:
                case _FlushSentinel():
                    continue
                case _AgentSpeechStartedSentinel() | _AgentSpeechEndedSentinel():
                    # agent speech boundary: start from a clean slate
                    await _reset_state()
                    self._agent_speech_started = isinstance(
                        input_frame, _AgentSpeechStartedSentinel
                    )
                    continue
                case _OverlapSpeechStartedSentinel() if self._agent_speech_started:
                    self._overlap_started_at = input_frame._started_at
                    self._user_speech_span = input_frame._user_speaking_span
                    self._overlap_started = True
                    self._accumulated_samples = 0
                    self._overlap_count += 1
                    # include the audio prefix in the window and
                    # only shift (remove leading silence) when the first overlap speech started
                    # otherwise, keep the existing data
                    if self._overlap_count == 1:
                        shift_size = max(
                            0,
                            len(self._audio_buffer)
                            - (
                                int(input_frame._speech_duration * self._sample_rate)
                                + self._prefix_size
                            ),
                        )
                        self._audio_buffer.shift(shift_size)
                    logger.trace(
                        "overlap speech started, starting interruption inference",
                        extra={
                            "overlap_count": self._overlap_count,
                        },
                    )
                    self._cache.clear()
                    continue
                case _OverlapSpeechEndedSentinel():
                    if self._overlap_started and self._overlap_started_at is not None:
                        logger.trace("overlap speech ended, stopping interruption inference")
                        self._user_speech_span = None
                        # pop a cached entry with a completed prediction
                        # (positive total_duration), if one exists
                        _, last_entry = self._cache.pop_if(
                            lambda entry: (
                                entry.total_duration is not None and entry.total_duration > 0
                            )
                        )
                        if last_entry is None:
                            logger.trace("no request made for overlap speech")
                        ev = OverlappingSpeechEvent.from_cache_entry(
                            entry=last_entry or _EMPTY_CACHE_ENTRY,
                            is_interruption=False,
                            started_at=self._overlap_started_at,
                            ended_at=input_frame._ended_at,
                        )
                        ev.num_requests = await self._num_requests.get_and_reset()
                        self.send(ev)
                        self._overlap_started = False
                        self._accumulated_samples = 0
                        self._overlap_started_at = None
                        # we don't clear the cache here since responses might be in flight
                case rtc.AudioFrame() if self._agent_speech_started:
                    samples_written = self._audio_buffer.push_frame(input_frame)
                    self._accumulated_samples += samples_written
                    # ship a batch once a full detection interval has accumulated
                    if self._accumulated_samples >= self._batch_size and self._overlap_started:
                        output_ch.send_nowait(self._audio_buffer.read())
                        self._accumulated_samples = 0

        output_ch.close()

    def send(self, event: OverlappingSpeechEvent) -> None:
        # emit on both the stream's own channel and the parent model's emitter
        self._event_ch.send_nowait(event)
        self._model.emit(event.type, event)

    @utils.log_exceptions(logger=logger)
    async def _metrics_monitor_task(
        self, event_aiter: AsyncIterable[OverlappingSpeechEvent]
    ) -> None:
        # Convert every emitted event into an InterruptionMetrics sample.
        async for ev in event_aiter:
            metrics = InterruptionMetrics(
                timestamp=time.time(),
                total_duration=ev.total_duration,
                prediction_duration=ev.prediction_duration,
                detection_delay=ev.detection_delay,
                num_interruptions=1 if ev.is_interruption else 0,
                num_backchannels=1 if not ev.is_interruption else 0,
                num_requests=ev.num_requests,
                metadata=Metadata(
                    model_name=self._model.model, model_provider=self._model.provider
                ),
            )
            self._model.emit("metrics_collected", metrics)
Helper class that provides a standard way to create an ABC using inheritance.
Ancestors
- abc.ABC
Subclasses
Methods
async def aclose(self) ‑> None-
Expand source code
async def aclose(self) -> None: """Close the stream immediately""" self._input_ch.close() await aio.cancel_and_wait(self._task) self._event_ch.close() try: await self._metrics_task finally: await self._tee_aiter.aclose()Close the stream immediately
def end_input(self) ‑> None-
Expand source code
def end_input(self) -> None: """Mark the end of input, no more audio will be pushed""" self.flush() self._input_ch.close()Mark the end of input, no more audio will be pushed
def flush(self) ‑> None-
Expand source code
def flush(self) -> None: """Mark the end of the current segment""" self._check_input_not_ended() self._check_not_closed() self._input_ch.send_nowait(_FlushSentinel())Mark the end of the current segment
def push_frame(self, frame: InterruptionDataFrameType) ‑> None-
Expand source code
def push_frame(self, frame: InterruptionDataFrameType) -> None: """Push some audio frame to be analyzed""" self._check_input_not_ended() self._check_not_closed() self._input_ch.send_nowait(frame)Push some audio frame to be analyzed
def send(self,
event: OverlappingSpeechEvent) ‑> None-
Expand source code
def send(self, event: OverlappingSpeechEvent) -> None: self._event_ch.send_nowait(event) self._model.emit(event.type, event)
class InterruptionWSDetectedMessage (**data: Any)-
Expand source code
class InterruptionWSDetectedMessage(BaseModel): type: Literal[InterruptionWSMessageType.INTERRUPTION_DETECTED] = ( InterruptionWSMessageType.INTERRUPTION_DETECTED ) created_at: int prediction_duration: float = Field(default=0.0) probabilities: list[float] = Field(default_factory=list)Usage Documentation
A base class for creating Pydantic models.
Attributes
__class_vars__- The names of the class variables defined on the model.
__private_attributes__- Metadata about the private attributes of the model.
__signature__- The synthesized
__init__[Signature][inspect.Signature] of the model. __pydantic_complete__- Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__- The core schema of the model.
__pydantic_custom_init__- Whether the model has a custom
__init__function. __pydantic_decorators__- Metadata containing the decorators defined on the model.
This replaces
Model.__validators__andModel.__root_validators__from Pydantic V1. __pydantic_generic_metadata__- Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__- Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__- The name of the post-init method for the model, if defined.
__pydantic_root_model__- Whether the model is a [
RootModel][pydantic.root_model.RootModel]. __pydantic_serializer__- The
pydantic-coreSchemaSerializerused to dump instances of the model. __pydantic_validator__- The
pydantic-coreSchemaValidatorused to validate instances of the model. __pydantic_fields__- A dictionary of field names and their corresponding [
FieldInfo][pydantic.fields.FieldInfo] objects. __pydantic_computed_fields__- A dictionary of computed field names and their corresponding [
ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects. __pydantic_extra__- A dictionary containing extra values, if [
extra][pydantic.config.ConfigDict.extra] is set to'allow'. __pydantic_fields_set__- The names of fields explicitly set during instantiation.
__pydantic_private__- Values of private attributes set on the model instance.
Create a new model by parsing and validating input data from keyword arguments.
Raises [
ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.selfis explicitly positional-only to allowselfas a field name.Ancestors
- pydantic.main.BaseModel
Class variables
var created_at : intvar model_configvar prediction_duration : floatvar probabilities : list[float]var type : Literal[<InterruptionWSMessageType.INTERRUPTION_DETECTED: 'bargein_detected'>]
class InterruptionWSErrorMessage (**data: Any)-
Expand source code
class InterruptionWSErrorMessage(BaseModel): type: Literal[InterruptionWSMessageType.ERROR] = InterruptionWSMessageType.ERROR message: str code: int session_id: strUsage Documentation
A base class for creating Pydantic models.
Attributes
__class_vars__- The names of the class variables defined on the model.
__private_attributes__- Metadata about the private attributes of the model.
__signature__- The synthesized
__init__[Signature][inspect.Signature] of the model. __pydantic_complete__- Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__- The core schema of the model.
__pydantic_custom_init__- Whether the model has a custom
__init__function. __pydantic_decorators__- Metadata containing the decorators defined on the model.
This replaces
Model.__validators__andModel.__root_validators__from Pydantic V1. __pydantic_generic_metadata__- Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__- Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__- The name of the post-init method for the model, if defined.
__pydantic_root_model__- Whether the model is a [
RootModel][pydantic.root_model.RootModel]. __pydantic_serializer__- The
pydantic-coreSchemaSerializerused to dump instances of the model. __pydantic_validator__- The
pydantic-coreSchemaValidatorused to validate instances of the model. __pydantic_fields__- A dictionary of field names and their corresponding [
FieldInfo][pydantic.fields.FieldInfo] objects. __pydantic_computed_fields__- A dictionary of computed field names and their corresponding [
ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects. __pydantic_extra__- A dictionary containing extra values, if [
extra][pydantic.config.ConfigDict.extra] is set to'allow'. __pydantic_fields_set__- The names of fields explicitly set during instantiation.
__pydantic_private__- Values of private attributes set on the model instance.
Create a new model by parsing and validating input data from keyword arguments.
Raises [
ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.selfis explicitly positional-only to allowselfas a field name.Ancestors
- pydantic.main.BaseModel
Class variables
var code : intvar message : strvar model_configvar session_id : strvar type : Literal[<InterruptionWSMessageType.ERROR: 'error'>]
class InterruptionWSInferenceDoneMessage (**data: Any)-
Expand source code
class InterruptionWSInferenceDoneMessage(BaseModel):
    """Server -> client message emitted when an inference pass completes without a detection."""

    # Discriminator for the websocket message union; fixed to "inference_done".
    type: Literal[InterruptionWSMessageType.INFERENCE_DONE] = (
        InterruptionWSMessageType.INFERENCE_DONE
    )
    # Client-side timestamp echoed back by the server (perf_counter_ns ticks,
    # packed into the binary frame header by the sender); used to match the
    # response to its cached request entry.
    created_at: int
    # Model-side inference time, in seconds.
    prediction_duration: float = Field(default=0.0)
    # Raw interruption probabilities produced by the model.
    probabilities: list[float] = Field(default_factory=list)
A base class for creating Pydantic models.
Attributes
__class_vars__- The names of the class variables defined on the model.
__private_attributes__- Metadata about the private attributes of the model.
__signature__- The synthesized
__init__[Signature][inspect.Signature] of the model. __pydantic_complete__- Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__- The core schema of the model.
__pydantic_custom_init__- Whether the model has a custom
__init__function. __pydantic_decorators__- Metadata containing the decorators defined on the model.
This replaces
Model.__validators__andModel.__root_validators__from Pydantic V1. __pydantic_generic_metadata__- Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__- Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__- The name of the post-init method for the model, if defined.
__pydantic_root_model__- Whether the model is a [
RootModel][pydantic.root_model.RootModel]. __pydantic_serializer__- The
pydantic-coreSchemaSerializerused to dump instances of the model. __pydantic_validator__- The
pydantic-coreSchemaValidatorused to validate instances of the model. __pydantic_fields__- A dictionary of field names and their corresponding [
FieldInfo][pydantic.fields.FieldInfo] objects. __pydantic_computed_fields__- A dictionary of computed field names and their corresponding [
ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects. __pydantic_extra__- A dictionary containing extra values, if [
extra][pydantic.config.ConfigDict.extra] is set to'allow'. __pydantic_fields_set__- The names of fields explicitly set during instantiation.
__pydantic_private__- Values of private attributes set on the model instance.
Create a new model by parsing and validating input data from keyword arguments.
Raises [
ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.selfis explicitly positional-only to allowselfas a field name.Ancestors
- pydantic.main.BaseModel
Class variables
var created_at : intvar model_configvar prediction_duration : floatvar probabilities : list[float]var type : Literal[<InterruptionWSMessageType.INFERENCE_DONE: 'inference_done'>]
class InterruptionWSMessageType (*args, **kwds)-
Expand source code
class InterruptionWSMessageType(str, Enum):
    """Wire-level message types for the adaptive interruption websocket protocol."""

    # Client -> server
    SESSION_CREATE = "session.create"
    SESSION_CLOSE = "session.close"
    # Server -> client acks
    SESSION_CREATED = "session.created"
    SESSION_CLOSED = "session.closed"
    # NOTE: the wire value differs from the member name -- "bargein" is the
    # protocol's term for an interruption (the endpoint path is /bargein too).
    INTERRUPTION_DETECTED = "bargein_detected"
    INFERENCE_DONE = "inference_done"
    ERROR = "error"
Create a new string object from the given object. If encoding or errors is specified, then the object must expose a data buffer that will be decoded using the given encoding and error handler. Otherwise, returns the result of object.str() (if defined) or repr(object). encoding defaults to sys.getdefaultencoding(). errors defaults to 'strict'.
Ancestors
- builtins.str
- enum.Enum
Class variables
var ERRORvar INFERENCE_DONEvar INTERRUPTION_DETECTEDvar SESSION_CLOSEvar SESSION_CLOSEDvar SESSION_CREATEvar SESSION_CREATED
class InterruptionWSSessionCloseMessage (**data: Any)-
Expand source code
class InterruptionWSSessionCloseMessage(BaseModel):
    """Client -> server request to close the current detection session."""

    # Discriminator; fixed to "session.close".
    type: Literal[InterruptionWSMessageType.SESSION_CLOSE] = InterruptionWSMessageType.SESSION_CLOSE
A base class for creating Pydantic models.
Attributes
__class_vars__- The names of the class variables defined on the model.
__private_attributes__- Metadata about the private attributes of the model.
__signature__- The synthesized
__init__[Signature][inspect.Signature] of the model. __pydantic_complete__- Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__- The core schema of the model.
__pydantic_custom_init__- Whether the model has a custom
__init__function. __pydantic_decorators__- Metadata containing the decorators defined on the model.
This replaces
Model.__validators__andModel.__root_validators__from Pydantic V1. __pydantic_generic_metadata__- Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__- Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__- The name of the post-init method for the model, if defined.
__pydantic_root_model__- Whether the model is a [
RootModel][pydantic.root_model.RootModel]. __pydantic_serializer__- The
pydantic-coreSchemaSerializerused to dump instances of the model. __pydantic_validator__- The
pydantic-coreSchemaValidatorused to validate instances of the model. __pydantic_fields__- A dictionary of field names and their corresponding [
FieldInfo][pydantic.fields.FieldInfo] objects. __pydantic_computed_fields__- A dictionary of computed field names and their corresponding [
ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects. __pydantic_extra__- A dictionary containing extra values, if [
extra][pydantic.config.ConfigDict.extra] is set to'allow'. __pydantic_fields_set__- The names of fields explicitly set during instantiation.
__pydantic_private__- Values of private attributes set on the model instance.
Create a new model by parsing and validating input data from keyword arguments.
Raises [
ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.selfis explicitly positional-only to allowselfas a field name.Ancestors
- pydantic.main.BaseModel
Class variables
var model_configvar type : Literal[<InterruptionWSMessageType.SESSION_CLOSE: 'session.close'>]
class InterruptionWSSessionClosedMessage (**data: Any)-
Expand source code
class InterruptionWSSessionClosedMessage(BaseModel):
    """Server -> client acknowledgement that the session was closed."""

    # Discriminator; fixed to "session.closed".
    type: Literal[InterruptionWSMessageType.SESSION_CLOSED] = (
        InterruptionWSMessageType.SESSION_CLOSED
    )
A base class for creating Pydantic models.
Attributes
__class_vars__- The names of the class variables defined on the model.
__private_attributes__- Metadata about the private attributes of the model.
__signature__- The synthesized
__init__[Signature][inspect.Signature] of the model. __pydantic_complete__- Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__- The core schema of the model.
__pydantic_custom_init__- Whether the model has a custom
__init__function. __pydantic_decorators__- Metadata containing the decorators defined on the model.
This replaces
Model.__validators__andModel.__root_validators__from Pydantic V1. __pydantic_generic_metadata__- Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__- Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__- The name of the post-init method for the model, if defined.
__pydantic_root_model__- Whether the model is a [
RootModel][pydantic.root_model.RootModel]. __pydantic_serializer__- The
pydantic-coreSchemaSerializerused to dump instances of the model. __pydantic_validator__- The
pydantic-coreSchemaValidatorused to validate instances of the model. __pydantic_fields__- A dictionary of field names and their corresponding [
FieldInfo][pydantic.fields.FieldInfo] objects. __pydantic_computed_fields__- A dictionary of computed field names and their corresponding [
ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects. __pydantic_extra__- A dictionary containing extra values, if [
extra][pydantic.config.ConfigDict.extra] is set to'allow'. __pydantic_fields_set__- The names of fields explicitly set during instantiation.
__pydantic_private__- Values of private attributes set on the model instance.
Create a new model by parsing and validating input data from keyword arguments.
Raises [
ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.selfis explicitly positional-only to allowselfas a field name.Ancestors
- pydantic.main.BaseModel
Class variables
var model_configvar type : Literal[<InterruptionWSMessageType.SESSION_CLOSED: 'session.closed'>]
class InterruptionWSSessionCreateMessage (**data: Any)-
Expand source code
class InterruptionWSSessionCreateMessage(BaseModel):
    """Client -> server request to open a detection session."""

    # Discriminator; fixed to "session.create".
    type: Literal[InterruptionWSMessageType.SESSION_CREATE] = (
        InterruptionWSMessageType.SESSION_CREATE
    )
    # Audio format and detection parameters for the session.
    settings: InterruptionWSSessionCreateSettings
A base class for creating Pydantic models.
Attributes
__class_vars__- The names of the class variables defined on the model.
__private_attributes__- Metadata about the private attributes of the model.
__signature__- The synthesized
__init__[Signature][inspect.Signature] of the model. __pydantic_complete__- Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__- The core schema of the model.
__pydantic_custom_init__- Whether the model has a custom
__init__function. __pydantic_decorators__- Metadata containing the decorators defined on the model.
This replaces
Model.__validators__andModel.__root_validators__from Pydantic V1. __pydantic_generic_metadata__- Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__- Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__- The name of the post-init method for the model, if defined.
__pydantic_root_model__- Whether the model is a [
RootModel][pydantic.root_model.RootModel]. __pydantic_serializer__- The
pydantic-coreSchemaSerializerused to dump instances of the model. __pydantic_validator__- The
pydantic-coreSchemaValidatorused to validate instances of the model. __pydantic_fields__- A dictionary of field names and their corresponding [
FieldInfo][pydantic.fields.FieldInfo] objects. __pydantic_computed_fields__- A dictionary of computed field names and their corresponding [
ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects. __pydantic_extra__- A dictionary containing extra values, if [
extra][pydantic.config.ConfigDict.extra] is set to'allow'. __pydantic_fields_set__- The names of fields explicitly set during instantiation.
__pydantic_private__- Values of private attributes set on the model instance.
Create a new model by parsing and validating input data from keyword arguments.
Raises [
ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.selfis explicitly positional-only to allowselfas a field name.Ancestors
- pydantic.main.BaseModel
Class variables
var model_configvar settings : InterruptionWSSessionCreateSettingsvar type : Literal[<InterruptionWSMessageType.SESSION_CREATE: 'session.create'>]
class InterruptionWSSessionCreateSettings (**data: Any)-
Expand source code
class InterruptionWSSessionCreateSettings(BaseModel):
    """Audio format and detection parameters sent with a session.create message."""

    # Input audio sample rate, in Hz.
    sample_rate: int
    # Number of audio channels (the websocket stream always sends 1).
    num_channels: int
    # Probability threshold for flagging an interruption.
    threshold: float
    # Minimum number of frames required before an interruption is flagged
    # (derived from min_interruption_duration by the caller).
    min_frames: int
    # PCM encoding of the binary audio frames; only signed 16-bit
    # little-endian is supported.
    encoding: Literal["s16le"]
A base class for creating Pydantic models.
Attributes
__class_vars__- The names of the class variables defined on the model.
__private_attributes__- Metadata about the private attributes of the model.
__signature__- The synthesized
__init__[Signature][inspect.Signature] of the model. __pydantic_complete__- Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__- The core schema of the model.
__pydantic_custom_init__- Whether the model has a custom
__init__function. __pydantic_decorators__- Metadata containing the decorators defined on the model.
This replaces
Model.__validators__andModel.__root_validators__from Pydantic V1. __pydantic_generic_metadata__- Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__- Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__- The name of the post-init method for the model, if defined.
__pydantic_root_model__- Whether the model is a [
RootModel][pydantic.root_model.RootModel]. __pydantic_serializer__- The
pydantic-coreSchemaSerializerused to dump instances of the model. __pydantic_validator__- The
pydantic-coreSchemaValidatorused to validate instances of the model. __pydantic_fields__- A dictionary of field names and their corresponding [
FieldInfo][pydantic.fields.FieldInfo] objects. __pydantic_computed_fields__- A dictionary of computed field names and their corresponding [
ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects. __pydantic_extra__- A dictionary containing extra values, if [
extra][pydantic.config.ConfigDict.extra] is set to'allow'. __pydantic_fields_set__- The names of fields explicitly set during instantiation.
__pydantic_private__- Values of private attributes set on the model instance.
Create a new model by parsing and validating input data from keyword arguments.
Raises [
ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.selfis explicitly positional-only to allowselfas a field name.Ancestors
- pydantic.main.BaseModel
Class variables
var encoding : Literal['s16le']var min_frames : intvar model_configvar num_channels : intvar sample_rate : intvar threshold : float
class InterruptionWSSessionCreatedMessage (**data: Any)-
Expand source code
class InterruptionWSSessionCreatedMessage(BaseModel):
    """Server -> client acknowledgement that the session was created."""

    # Discriminator; fixed to "session.created".
    type: Literal[InterruptionWSMessageType.SESSION_CREATED] = (
        InterruptionWSMessageType.SESSION_CREATED
    )
A base class for creating Pydantic models.
Attributes
__class_vars__- The names of the class variables defined on the model.
__private_attributes__- Metadata about the private attributes of the model.
__signature__- The synthesized
__init__[Signature][inspect.Signature] of the model. __pydantic_complete__- Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__- The core schema of the model.
__pydantic_custom_init__- Whether the model has a custom
__init__function. __pydantic_decorators__- Metadata containing the decorators defined on the model.
This replaces
Model.__validators__andModel.__root_validators__from Pydantic V1. __pydantic_generic_metadata__- Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__- Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__- The name of the post-init method for the model, if defined.
__pydantic_root_model__- Whether the model is a [
RootModel][pydantic.root_model.RootModel]. __pydantic_serializer__- The
pydantic-coreSchemaSerializerused to dump instances of the model. __pydantic_validator__- The
pydantic-coreSchemaValidatorused to validate instances of the model. __pydantic_fields__- A dictionary of field names and their corresponding [
FieldInfo][pydantic.fields.FieldInfo] objects. __pydantic_computed_fields__- A dictionary of computed field names and their corresponding [
ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects. __pydantic_extra__- A dictionary containing extra values, if [
extra][pydantic.config.ConfigDict.extra] is set to'allow'. __pydantic_fields_set__- The names of fields explicitly set during instantiation.
__pydantic_private__- Values of private attributes set on the model instance.
Create a new model by parsing and validating input data from keyword arguments.
Raises [
ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.selfis explicitly positional-only to allowselfas a field name.Ancestors
- pydantic.main.BaseModel
Class variables
var model_configvar type : Literal[<InterruptionWSMessageType.SESSION_CREATED: 'session.created'>]
class InterruptionWebSocketStream (*,
model: AdaptiveInterruptionDetector,
conn_options: APIConnectOptions)-
Expand source code
class InterruptionWebSocketStream(InterruptionStreamBase):
    """Streams user audio to the LiveKit Adaptive Interruption websocket endpoint.

    Audio chunks are sent as binary frames: an 8-byte little-endian
    ``perf_counter_ns`` timestamp header followed by s16le PCM samples.
    Control messages are JSON. Server responses echo the timestamp, which is
    used to correlate them with their cached request entries.
    """

    def __init__(
        self, *, model: AdaptiveInterruptionDetector, conn_options: APIConnectOptions
    ) -> None:
        super().__init__(model=model, conn_options=conn_options)
        self._request_id = str(shortuuid("interruption_request_"))
        # set by update_options() to force the ws session to be re-created
        self._reconnect_event = asyncio.Event()

    def update_options(
        self,
        *,
        threshold: NotGivenOr[float] = NOT_GIVEN,
        min_interruption_duration: NotGivenOr[float] = NOT_GIVEN,
    ) -> None:
        """Update detection options; applied by reconnecting the websocket session.

        Args:
            threshold: New interruption probability threshold.
            min_interruption_duration: New minimum interruption duration, in
                seconds; converted to a frame count via _FRAMES_PER_SECOND.
        """
        if is_given(threshold):
            self._opts.threshold = threshold
        if is_given(min_interruption_duration):
            self._opts.min_frames = math.ceil(min_interruption_duration * _FRAMES_PER_SECOND)
        # the new settings take effect on the next session.create
        self._reconnect_event.set()

    async def _run(self) -> None:
        """Main loop: connect, pump audio out and predictions in, reconnect on demand."""
        closing_ws = False

        async def send_task(
            ws: aiohttp.ClientWebSocketResponse, input_ch: aio.Chan[npt.NDArray[np.int16]]
        ) -> None:
            # Forward audio chunks to the server, registering each request in
            # the cache so recv_task can match the server's response to it.
            nonlocal closing_ws
            timeout_ns = int(self._opts.inference_timeout * 1e9)
            async for audio_data in input_ch:
                now = perf_counter_ns()
                # fail fast if the oldest unanswered request exceeded the timeout
                for _key, entry in self._cache.items():
                    if entry.total_duration is not None:
                        continue  # already answered
                    if now - entry.created_at > timeout_ns:
                        raise APIStatusError(
                            f"interruption inference timed out after "
                            f"{(now - entry.created_at) / 1e9:.1f}s (ws)",
                            status_code=408,
                            retryable=False,
                        )
                    break  # oldest unanswered entry is still within timeout
                await self._num_requests.increment()
                created_at = perf_counter_ns()
                header = struct.pack("<Q", created_at)  # 8 bytes
                await ws.send_bytes(header + audio_data.tobytes())
                self._cache[created_at] = InterruptionCacheEntry(
                    created_at=created_at,
                    speech_input=audio_data,
                )

            # input channel exhausted: ask the server to close the session
            closing_ws = True
            msg = InterruptionWSSessionCloseMessage(
                type=InterruptionWSMessageType.SESSION_CLOSE,
            )
            await ws.send_str(msg.model_dump_json())

        async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
            # Dispatch server messages: detections, inference results, errors.
            nonlocal closing_ws
            while True:
                ws_msg = await ws.receive()
                if ws_msg.type in (
                    aiohttp.WSMsgType.CLOSED,
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.CLOSING,
                ):
                    if closing_ws or self._session.closed:
                        return  # expected shutdown
                    raise APIStatusError(
                        message=f"LiveKit Adaptive Interruption connection closed unexpectedly: {ws_msg.data}",
                        status_code=ws.close_code or -1,
                        body=f"{ws_msg.data=} {ws_msg.extra=}",
                    )

                if ws_msg.type != aiohttp.WSMsgType.TEXT:
                    logger.warning(
                        "unexpected LiveKit Adaptive Interruption message type %s", ws_msg.type
                    )
                    continue

                data = json.loads(ws_msg.data)
                msg: AnyInterruptionWSMessage = InterruptionWSMessage.validate_python(data)
                match msg:
                    case (
                        InterruptionWSSessionCreatedMessage()
                        | InterruptionWSSessionClosedMessage()
                    ):
                        pass  # lifecycle acks need no handling
                    case InterruptionWSDetectedMessage():
                        created_at = msg.created_at
                        # only react while an overlap is actually in progress
                        if (
                            overlap_started_at := self._overlap_started_at
                        ) is not None and self._overlap_started:
                            entry = self._cache.set_or_update(
                                created_at,
                                lambda c=created_at: InterruptionCacheEntry(created_at=c),  # type: ignore[misc]
                                total_duration=(perf_counter_ns() - created_at) / 1e9,
                                probabilities=np.array(msg.probabilities, dtype=np.float32),
                                is_interruption=True,
                                prediction_duration=msg.prediction_duration,
                                detection_delay=time.time() - overlap_started_at,
                            )
                            # NOTE(review): reconstructed nesting -- span reset
                            # appears to happen only when a span existed; confirm
                            if self._user_speech_span:
                                self._update_user_speech_span(self._user_speech_span, entry)
                                self._user_speech_span = None
                            logger.debug(
                                "interruption detected",
                                extra={
                                    "total_duration": entry.get_total_duration(),
                                    "prediction_duration": entry.get_prediction_duration(),
                                    "detection_delay": entry.get_detection_delay(),
                                    "probability": entry.get_probability(),
                                },
                            )
                            ev = OverlappingSpeechEvent.from_cache_entry(
                                entry=entry,
                                is_interruption=True,
                                started_at=overlap_started_at,
                                ended_at=time.time(),
                            )
                            ev.num_requests = await self._num_requests.get_and_reset()
                            self.send(ev)
                            # emit at most one interruption per overlap
                            self._overlap_started = False
                    case InterruptionWSInferenceDoneMessage():
                        created_at = msg.created_at
                        if (
                            overlap_started_at := self._overlap_started_at
                        ) is not None and self._overlap_started:
                            entry = self._cache.set_or_update(
                                created_at,
                                lambda c=created_at: InterruptionCacheEntry(created_at=c),  # type: ignore[misc]
                                total_duration=(perf_counter_ns() - created_at) / 1e9,
                                prediction_duration=msg.prediction_duration,
                                probabilities=np.array(msg.probabilities, dtype=np.float32),
                                is_interruption=False,
                                detection_delay=time.time() - overlap_started_at,
                            )
                            logger.trace(
                                "interruption inference done",
                                extra={
                                    "total_duration": entry.get_total_duration(),
                                    "prediction_duration": entry.get_prediction_duration(),
                                    "probability": entry.get_probability(),
                                },
                            )
                    case InterruptionWSErrorMessage():
                        raise APIStatusError(
                            f"LiveKit Adaptive Interruption returned error: {msg.code}",
                            body=msg.message,
                            status_code=msg.code,
                        )
                    case _:
                        logger.warning(
                            "received unexpected message from LiveKit Adaptive Interruption: %s",
                            data,
                        )

        ws: aiohttp.ClientWebSocketResponse | None = None
        while True:
            data_ch = aio.Chan[npt.NDArray[np.int16]]()
            try:
                closing_ws = False
                ws = await self._connect_ws()
                tasks = [
                    asyncio.create_task(self._forward_data(data_ch)),
                    asyncio.create_task(send_task(ws, data_ch)),
                    asyncio.create_task(recv_task(ws)),
                ]
                tasks_group = asyncio.gather(*tasks)
                wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
                try:
                    done, _ = await asyncio.wait(
                        (tasks_group, wait_reconnect_task),
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                    # propagate any exception raised by the worker tasks
                    for task in done:
                        if task != wait_reconnect_task:
                            task.result()
                    if wait_reconnect_task not in done:
                        break  # workers finished normally; no reconnect requested
                    self._reconnect_event.clear()
                finally:
                    closing_ws = True
                    if ws is not None and not ws.closed:
                        await ws.close()
                        ws = None
                    await aio.gracefully_cancel(*tasks, wait_reconnect_task)
                    tasks_group.cancel()
                    try:
                        # retrieve to avoid "exception was never retrieved" warnings
                        tasks_group.exception()
                    except asyncio.CancelledError:
                        pass
            finally:
                closing_ws = True
                if ws is not None and not ws.closed:
                    await ws.close()

    async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
        """Connect to the LiveKit Adaptive Interruption WebSocket."""
        settings = InterruptionWSSessionCreateSettings(
            sample_rate=self._opts.sample_rate,
            num_channels=1,
            threshold=self._opts.threshold,
            min_frames=self._model._opts.min_frames,
            encoding="s16le",
        )
        base_url = self._opts.base_url
        # the inference base URL is configured as http(s); switch scheme to ws(s)
        if base_url.startswith(("http://", "https://")):
            base_url = base_url.replace("http", "ws", 1)
        headers = {
            **get_inference_headers(),
            "Authorization": f"Bearer {create_access_token(self._opts.api_key, self._opts.api_secret)}",
        }
        try:
            ws = await asyncio.wait_for(
                self._session.ws_connect(f"{base_url}/bargein", headers=headers),
                self._conn_options.timeout,
            )
        except (
            aiohttp.ClientConnectorError,
            asyncio.TimeoutError,
            aiohttp.ClientResponseError,
        ) as e:
            if isinstance(e, aiohttp.ClientResponseError) and e.status == 429:
                # quota errors are permanent; do not retry
                raise APIStatusError(
                    "LiveKit Adaptive Interruption quota exceeded",
                    status_code=e.status,
                    retryable=False,
                ) from e
            elif isinstance(e, asyncio.TimeoutError):
                raise APIConnectionError(
                    "failed to connect to LiveKit Adaptive Interruption: timeout",
                    retryable=False,
                ) from e
            raise APIConnectionError("failed to connect to LiveKit Adaptive Interruption") from e
        try:
            msg = InterruptionWSSessionCreateMessage(
                type=InterruptionWSMessageType.SESSION_CREATE,
                settings=settings,
            )
            await ws.send_str(msg.model_dump_json())
        except Exception as e:
            await ws.close()
            raise APIConnectionError(
                "failed to send session.create message to LiveKit Adaptive Interruption"
            ) from e
        return ws
Ancestors
- InterruptionStreamBase
- abc.ABC
Methods
def update_options(self,
*,
threshold: NotGivenOr[float] = NOT_GIVEN,
min_interruption_duration: NotGivenOr[float] = NOT_GIVEN) ‑> None-
Expand source code
def update_options(
    self,
    *,
    threshold: NotGivenOr[float] = NOT_GIVEN,
    min_interruption_duration: NotGivenOr[float] = NOT_GIVEN,
) -> None:
    """Update detection options; applied by reconnecting the websocket session.

    Args:
        threshold: New interruption probability threshold.
        min_interruption_duration: New minimum interruption duration in
            seconds, converted to a frame count via _FRAMES_PER_SECOND.
    """
    if is_given(threshold):
        self._opts.threshold = threshold
    if is_given(min_interruption_duration):
        self._opts.min_frames = math.ceil(min_interruption_duration * _FRAMES_PER_SECOND)
    # force the stream to tear down the ws and reconnect with the new settings
    self._reconnect_event.set()
Inherited members
class OverlappingSpeechEvent (**data: Any)-
Expand source code
class OverlappingSpeechEvent(BaseModel):
    """Represents an overlapping speech event detected during agent speech."""

    # numpy arrays are not pydantic-native types, so allow arbitrary types
    model_config = ConfigDict(arbitrary_types_allowed=True)

    type: Literal["overlapping_speech"] = "overlapping_speech"
    created_at: float = Field(default_factory=time.time)
    """Timestamp (in seconds) when the event was emitted."""
    detected_at: float = Field(default_factory=time.time)
    """Timestamp (in seconds) when the overlap was detected."""
    is_interruption: bool = False
    """Whether interruption is detected."""
    total_duration: float = 0.0
    """RTT (Round Trip Time) time taken to perform the inference, in seconds."""
    prediction_duration: float = 0.0
    """Time taken to perform the inference from the model side, in seconds."""
    detection_delay: float = 0.0
    """Total time from the onset of the speech to the final prediction, in seconds."""
    overlap_started_at: float | None = None
    """Timestamp (in seconds) when the overlap speech started. Useful for emitting held transcripts."""
    speech_input: npt.NDArray[np.int16] | None = None
    """The audio input that was used for the inference."""
    probabilities: npt.NDArray[np.float32] | None = None
    """The raw probabilities for the interruption detection."""
    probability: float = 0.0
    """The conservative estimated probability of the interruption event."""
    num_requests: int = 0
    """Number of requests sent for this event."""

    @model_serializer(mode="wrap")
    def serialize_model(self, handler: SerializerFunctionWrapHandler) -> Any:
        # remove numpy arrays from the model dump
        copy = self.model_copy(deep=True)
        data = copy.speech_input, copy.probabilities
        copy.speech_input, copy.probabilities = None, None
        try:
            serialized = handler(copy)
        finally:
            # restore the arrays on the copy even if serialization fails
            copy.speech_input, copy.probabilities = data
        return serialized

    @classmethod
    def from_cache_entry(
        cls,
        *,
        entry: InterruptionCacheEntry,
        is_interruption: bool,
        started_at: float | None = None,
        ended_at: float | None = None,
    ) -> OverlappingSpeechEvent:
        """Initialize the event from a cache entry.

        Args:
            entry: The cache entry to initialize the event from.
            is_interruption: Whether the interruption is detected.
            started_at: The timestamp when the overlap speech started.
            ended_at: The timestamp when the overlap speech ended.

        Returns:
            The initialized event.
        """
        return cls(
            type="overlapping_speech",
            detected_at=ended_at or time.time(),
            is_interruption=is_interruption,
            overlap_started_at=started_at,
            speech_input=entry.speech_input,
            probabilities=entry.probabilities,
            total_duration=entry.get_total_duration(),
            detection_delay=entry.get_detection_delay(),
            prediction_duration=entry.get_prediction_duration(),
            probability=entry.get_probability(),
        )
Create a new model by parsing and validating input data from keyword arguments.
Raises [
ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.selfis explicitly positional-only to allowselfas a field name.Ancestors
- pydantic.main.BaseModel
Class variables
var created_at : float-
Timestamp (in seconds) when the event was emitted.
var detected_at : float-
Timestamp (in seconds) when the overlap was detected.
var detection_delay : float-
Total time from the onset of the speech to the final prediction, in seconds.
var is_interruption : bool-
Whether interruption is detected.
var model_configvar num_requests : int-
Number of requests sent for this event.
var overlap_started_at : float | None-
Timestamp (in seconds) when the overlap speech started. Useful for emitting held transcripts.
var prediction_duration : float-
Time taken to perform the inference from the model side, in seconds.
var probabilities : numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.float32]] | None-
The raw probabilities for the interruption detection.
var probability : float-
The conservative estimated probability of the interruption event.
var speech_input : numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.int16]] | None-
The audio input that was used for the inference.
var total_duration : float-
RTT (Round Trip Time) time taken to perform the inference, in seconds.
var type : Literal['overlapping_speech']
Static methods
def from_cache_entry(*,
entry: InterruptionCacheEntry,
is_interruption: bool,
started_at: float | None = None,
ended_at: float | None = None) ‑> OverlappingSpeechEvent-
Initialize the event from a cache entry.
Args
entry- The cache entry to initialize the event from.
is_interruption- Whether the interruption is detected.
started_at- The timestamp when the overlap speech started.
ended_at- The timestamp when the overlap speech ended.
Returns
The initialized event.
Methods
def serialize_model(self, handler: SerializerFunctionWrapHandler) ‑> Any-
Expand source code
@model_serializer(mode="wrap")
def serialize_model(self, handler: SerializerFunctionWrapHandler) -> Any:
    """Dump the model while stripping the large numpy-array fields."""
    # Serialize a deep copy so the live event instance is never mutated.
    clone = self.model_copy(deep=True)
    saved_arrays = (clone.speech_input, clone.probabilities)
    clone.speech_input = None
    clone.probabilities = None
    try:
        return handler(clone)
    finally:
        # Put the arrays back on the copy before it goes out of scope.
        clone.speech_input, clone.probabilities = saved_arrays