Module livekit.agents.inference.interruption

Classes

class AdaptiveInterruptionDetector (*,
threshold: float = 0.5,
min_interruption_duration: float = 0.05,
max_audio_duration: float = 3,
audio_prefix_duration: float = 1.0,
detection_interval: float = 0.1,
inference_timeout: float = 0.7,
base_url: str | None = None,
api_key: str | None = None,
api_secret: str | None = None,
http_session: aiohttp.ClientSession | None = None)
Expand source code
class AdaptiveInterruptionDetector(
    rtc.EventEmitter[
        Literal[
            "overlapping_speech",
            "error",
            "metrics_collected",
        ]
    ],
):
    def __init__(
        self,
        *,
        threshold: float = THRESHOLD,
        min_interruption_duration: float = MIN_INTERRUPTION_DURATION,
        max_audio_duration: float = MAX_AUDIO_DURATION,
        audio_prefix_duration: float = AUDIO_PREFIX_DURATION,
        detection_interval: float = DETECTION_INTERVAL,
        inference_timeout: float = REMOTE_INFERENCE_TIMEOUT,
        base_url: str | None = None,
        api_key: str | None = None,
        api_secret: str | None = None,
        http_session: aiohttp.ClientSession | None = None,
    ) -> None:
        """
        Initialize a AdaptiveInterruptionDetector instance.

        Args:
            threshold (float, optional): The threshold for the interruption detection, defaults to 0.5.
            min_interruption_duration (float, optional): The minimum duration, in seconds, of the interruption event, defaults to 50ms.
            max_audio_duration (float, optional): The maximum audio duration, including the audio prefix, in seconds, for the interruption detection, defaults to 3s.
            audio_prefix_duration (float, optional): The audio prefix duration, in seconds, for the interruption detection, defaults to 0.5s.
            detection_interval (float, optional): The interval between detections, in seconds, for the interruption detection, defaults to 0.1s.
            inference_timeout (float, optional): The timeout for the interruption detection, defaults to 1 second.
            base_url (str, optional): The base URL for the interruption detection, defaults to the shared LIVEKIT_REMOTE_EOT_URL environment variable.
            api_key (str, optional): The API key for the interruption detection, defaults to the LIVEKIT_INFERENCE_API_KEY environment variable.
            api_secret (str, optional): The API secret for the interruption detection, defaults to the LIVEKIT_INFERENCE_API_SECRET environment variable.
            http_session (aiohttp.ClientSession, optional): The HTTP session to use for the interruption detection.
        """
        super().__init__()
        if max_audio_duration > 3.0:
            raise ValueError("max_audio_duration must be less than or equal to 3.0 seconds")

        lk_base_url = (
            base_url
            if base_url
            else os.getenv("LIVEKIT_REMOTE_EOT_URL", get_default_inference_url())
        )
        lk_api_key: str = api_key if api_key else ""
        lk_api_secret: str = api_secret if api_secret else ""
        # use LiveKit credentials if using the inference service (production or staging)
        is_inference_url = lk_base_url in (DEFAULT_INFERENCE_URL, STAGING_INFERENCE_URL)
        if is_inference_url:
            lk_api_key = (
                api_key
                if api_key
                else os.getenv("LIVEKIT_INFERENCE_API_KEY", os.getenv("LIVEKIT_API_KEY", ""))
            )
            if not lk_api_key:
                raise ValueError(
                    "api_key is required, either as argument or set LIVEKIT_API_KEY environmental variable"
                )

            lk_api_secret = (
                api_secret
                if api_secret
                else os.getenv("LIVEKIT_INFERENCE_API_SECRET", os.getenv("LIVEKIT_API_SECRET", ""))
            )
            if not lk_api_secret:
                raise ValueError(
                    "api_secret is required, either as argument or set LIVEKIT_API_SECRET environmental variable"
                )

            use_proxy = True
        else:
            use_proxy = False

        self._opts = InterruptionOptions(
            sample_rate=SAMPLE_RATE,
            threshold=threshold,
            min_frames=math.ceil(min_interruption_duration * _FRAMES_PER_SECOND),
            max_audio_duration=max_audio_duration,
            audio_prefix_duration=audio_prefix_duration,
            detection_interval=detection_interval,
            inference_timeout=inference_timeout,
            base_url=lk_base_url,
            api_key=lk_api_key,
            api_secret=lk_api_secret,
            use_proxy=use_proxy,
        )
        self._label = f"{type(self).__module__}.{type(self).__name__}"
        self._sample_rate = SAMPLE_RATE
        self._session = http_session
        self._streams = weakref.WeakSet[InterruptionHttpStream | InterruptionWebSocketStream]()

        logger.info(
            "adaptive interruption detector initialized",
            extra={
                "base_url": self._opts.base_url,
                "detection_interval": self._opts.detection_interval,
                "audio_prefix_duration": self._opts.audio_prefix_duration,
                "max_audio_duration": self._opts.max_audio_duration,
                "min_frames": self._opts.min_frames,
                "threshold": self._opts.threshold,
                "inference_timeout": self._opts.inference_timeout,
                "use_proxy": self._opts.use_proxy,
            },
        )

    @property
    def model(self) -> str:
        return "adaptive interruption"

    @property
    def provider(self) -> str:
        return "livekit"

    @property
    def label(self) -> str:
        return self._label

    @property
    def sample_rate(self) -> int:
        return self._sample_rate

    def _emit_error(self, api_error: Exception, recoverable: bool) -> None:
        self.emit(
            "error",
            InterruptionDetectionError(
                label=self._label,
                error=api_error,
                recoverable=recoverable,
            ),
        )

    def _ensure_session(self) -> aiohttp.ClientSession:
        if not self._session:
            self._session = http_context.http_session()
        return self._session

    def stream(
        self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
    ) -> InterruptionHttpStream | InterruptionWebSocketStream:
        try:
            stream: InterruptionHttpStream | InterruptionWebSocketStream
            if self._opts.use_proxy:
                stream = InterruptionWebSocketStream(model=self, conn_options=conn_options)
            else:
                stream = InterruptionHttpStream(model=self, conn_options=conn_options)
        except Exception as e:
            self._emit_error(e, recoverable=False)
            raise
        self._streams.add(stream)
        return stream

    def update_options(
        self,
        *,
        threshold: NotGivenOr[float] = NOT_GIVEN,
        min_interruption_duration: NotGivenOr[float] = NOT_GIVEN,
    ) -> None:
        if is_given(threshold):
            self._opts.threshold = threshold
        if is_given(min_interruption_duration):
            self._opts.min_frames = math.ceil(min_interruption_duration * _FRAMES_PER_SECOND)

        for stream in self._streams:
            stream.update_options(
                threshold=threshold, min_interruption_duration=min_interruption_duration
            )

Abstract base class for generic types.

On Python 3.12 and newer, generic classes implicitly inherit from Generic when they declare a parameter list after the class's name::

class Mapping[KT, VT]:
    def __getitem__(self, key: KT) -> VT:
        ...
    # Etc.

On older versions of Python, however, generic classes have to explicitly inherit from Generic.

After a class has been declared to be generic, it can then be used as follows::

def lookup_name[KT, VT](mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
    try:
        return mapping[key]
    except KeyError:
        return default

Initialize an AdaptiveInterruptionDetector instance.

Args

threshold : float, optional
The threshold for the interruption detection, defaults to 0.5.
min_interruption_duration : float, optional
The minimum duration, in seconds, of the interruption event, defaults to 50ms.
max_audio_duration : float, optional
The maximum audio duration, including the audio prefix, in seconds, for the interruption detection, defaults to 3s.
audio_prefix_duration : float, optional
The audio prefix duration, in seconds, for the interruption detection, defaults to 1.0s.
detection_interval : float, optional
The interval between detections, in seconds, for the interruption detection, defaults to 0.1s.
inference_timeout : float, optional
The timeout for the interruption detection, defaults to 0.7 seconds.
base_url : str, optional
The base URL for the interruption detection, defaults to the shared LIVEKIT_REMOTE_EOT_URL environment variable.
api_key : str, optional
The API key for the interruption detection, defaults to the LIVEKIT_INFERENCE_API_KEY environment variable.
api_secret : str, optional
The API secret for the interruption detection, defaults to the LIVEKIT_INFERENCE_API_SECRET environment variable.
http_session : aiohttp.ClientSession, optional
The HTTP session to use for the interruption detection.

Ancestors

Instance variables

prop label : str
Expand source code
@property
def label(self) -> str:
    return self._label
prop model : str
Expand source code
@property
def model(self) -> str:
    return "adaptive interruption"
prop provider : str
Expand source code
@property
def provider(self) -> str:
    return "livekit"
prop sample_rate : int
Expand source code
@property
def sample_rate(self) -> int:
    return self._sample_rate

Methods

def stream(self,
*,
conn_options: APIConnectOptions = APIConnectOptions(max_retry=3, retry_interval=2.0, timeout=10.0)) ‑> InterruptionHttpStream | InterruptionWebSocketStream
Expand source code
def stream(
    self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
) -> InterruptionHttpStream | InterruptionWebSocketStream:
    try:
        stream: InterruptionHttpStream | InterruptionWebSocketStream
        if self._opts.use_proxy:
            stream = InterruptionWebSocketStream(model=self, conn_options=conn_options)
        else:
            stream = InterruptionHttpStream(model=self, conn_options=conn_options)
    except Exception as e:
        self._emit_error(e, recoverable=False)
        raise
    self._streams.add(stream)
    return stream
def update_options(self,
*,
threshold: NotGivenOr[float] = NOT_GIVEN,
min_interruption_duration: NotGivenOr[float] = NOT_GIVEN) ‑> None
Expand source code
def update_options(
    self,
    *,
    threshold: NotGivenOr[float] = NOT_GIVEN,
    min_interruption_duration: NotGivenOr[float] = NOT_GIVEN,
) -> None:
    if is_given(threshold):
        self._opts.threshold = threshold
    if is_given(min_interruption_duration):
        self._opts.min_frames = math.ceil(min_interruption_duration * _FRAMES_PER_SECOND)

    for stream in self._streams:
        stream.update_options(
            threshold=threshold, min_interruption_duration=min_interruption_duration
        )

Inherited members

class InterruptionCacheEntry (*,
created_at: int = <factory>,
speech_input: npt.NDArray[np.int16] | None = None,
total_duration: float | None = None,
prediction_duration: float | None = None,
detection_delay: float | None = None,
probabilities: npt.NDArray[np.float32] | None = None,
is_interruption: bool | None = None)
Expand source code
@dataclass(slots=True, kw_only=True)
class InterruptionCacheEntry:
    """Typed cache entry for interruption inference results."""

    created_at: int = field(default_factory=time.perf_counter_ns)
    """The timestamp when the cache entry was created, in nanoseconds. Used only for indexing and latency calculation."""
    # int16 audio samples submitted for this prediction, if retained
    speech_input: npt.NDArray[np.int16] | None = None
    # round-trip time of the inference call, in seconds (see get_total_duration)
    total_duration: float | None = None
    # model-side inference time, in seconds (see get_prediction_duration)
    prediction_duration: float | None = None
    # speech-onset-to-prediction latency, in seconds (see get_detection_delay)
    detection_delay: float | None = None
    # per-frame probabilities returned by the model; folded via _estimate_probability
    probabilities: npt.NDArray[np.float32] | None = None
    # final verdict from the model, None until a prediction completes
    is_interruption: bool | None = None

    def get_total_duration(self, default: float = 0.0) -> float:
        """RTT (Round Trip Time) time taken to perform the inference, in seconds."""
        return self.total_duration if self.total_duration is not None else default

    def get_prediction_duration(self, default: float = 0.0) -> float:
        """Time taken to perform the inference from the model side, in seconds."""
        return self.prediction_duration if self.prediction_duration is not None else default

    def get_detection_delay(self, default: float = 0.0) -> float:
        """Total time from the onset of the speech to the final prediction, in seconds."""
        return self.detection_delay if self.detection_delay is not None else default

    def get_probability(self, default: float = 0.0) -> float:
        """The conservative estimated probability of the interruption event."""
        return (
            _estimate_probability(self.probabilities) if self.probabilities is not None else default
        )

Typed cache entry for interruption inference results.

Instance variables

var created_at : int
Expand source code
@dataclass(slots=True, kw_only=True)
class InterruptionCacheEntry:
    """Typed cache entry for interruption inference results."""

    created_at: int = field(default_factory=time.perf_counter_ns)
    """The timestamp when the cache entry was created, in nanoseconds. Used only for indexing and latency calculation."""
    speech_input: npt.NDArray[np.int16] | None = None
    total_duration: float | None = None
    prediction_duration: float | None = None
    detection_delay: float | None = None
    probabilities: npt.NDArray[np.float32] | None = None
    is_interruption: bool | None = None

    def get_total_duration(self, default: float = 0.0) -> float:
        """RTT (Round Trip Time) time taken to perform the inference, in seconds."""
        return self.total_duration if self.total_duration is not None else default

    def get_prediction_duration(self, default: float = 0.0) -> float:
        """Time taken to perform the inference from the model side, in seconds."""
        return self.prediction_duration if self.prediction_duration is not None else default

    def get_detection_delay(self, default: float = 0.0) -> float:
        """Total time from the onset of the speech to the final prediction, in seconds."""
        return self.detection_delay if self.detection_delay is not None else default

    def get_probability(self, default: float = 0.0) -> float:
        """The conservative estimated probability of the interruption event."""
        return (
            _estimate_probability(self.probabilities) if self.probabilities is not None else default
        )

The timestamp when the cache entry was created, in nanoseconds. Used only for indexing and latency calculation.

var detection_delay : float | None
Expand source code
@dataclass(slots=True, kw_only=True)
class InterruptionCacheEntry:
    """Typed cache entry for interruption inference results."""

    created_at: int = field(default_factory=time.perf_counter_ns)
    """The timestamp when the cache entry was created, in nanoseconds. Used only for indexing and latency calculation."""
    speech_input: npt.NDArray[np.int16] | None = None
    total_duration: float | None = None
    prediction_duration: float | None = None
    detection_delay: float | None = None
    probabilities: npt.NDArray[np.float32] | None = None
    is_interruption: bool | None = None

    def get_total_duration(self, default: float = 0.0) -> float:
        """RTT (Round Trip Time) time taken to perform the inference, in seconds."""
        return self.total_duration if self.total_duration is not None else default

    def get_prediction_duration(self, default: float = 0.0) -> float:
        """Time taken to perform the inference from the model side, in seconds."""
        return self.prediction_duration if self.prediction_duration is not None else default

    def get_detection_delay(self, default: float = 0.0) -> float:
        """Total time from the onset of the speech to the final prediction, in seconds."""
        return self.detection_delay if self.detection_delay is not None else default

    def get_probability(self, default: float = 0.0) -> float:
        """The conservative estimated probability of the interruption event."""
        return (
            _estimate_probability(self.probabilities) if self.probabilities is not None else default
        )
var is_interruption : bool | None
Expand source code
@dataclass(slots=True, kw_only=True)
class InterruptionCacheEntry:
    """Typed cache entry for interruption inference results."""

    created_at: int = field(default_factory=time.perf_counter_ns)
    """The timestamp when the cache entry was created, in nanoseconds. Used only for indexing and latency calculation."""
    speech_input: npt.NDArray[np.int16] | None = None
    total_duration: float | None = None
    prediction_duration: float | None = None
    detection_delay: float | None = None
    probabilities: npt.NDArray[np.float32] | None = None
    is_interruption: bool | None = None

    def get_total_duration(self, default: float = 0.0) -> float:
        """RTT (Round Trip Time) time taken to perform the inference, in seconds."""
        return self.total_duration if self.total_duration is not None else default

    def get_prediction_duration(self, default: float = 0.0) -> float:
        """Time taken to perform the inference from the model side, in seconds."""
        return self.prediction_duration if self.prediction_duration is not None else default

    def get_detection_delay(self, default: float = 0.0) -> float:
        """Total time from the onset of the speech to the final prediction, in seconds."""
        return self.detection_delay if self.detection_delay is not None else default

    def get_probability(self, default: float = 0.0) -> float:
        """The conservative estimated probability of the interruption event."""
        return (
            _estimate_probability(self.probabilities) if self.probabilities is not None else default
        )
var prediction_duration : float | None
Expand source code
@dataclass(slots=True, kw_only=True)
class InterruptionCacheEntry:
    """Typed cache entry for interruption inference results."""

    created_at: int = field(default_factory=time.perf_counter_ns)
    """The timestamp when the cache entry was created, in nanoseconds. Used only for indexing and latency calculation."""
    speech_input: npt.NDArray[np.int16] | None = None
    total_duration: float | None = None
    prediction_duration: float | None = None
    detection_delay: float | None = None
    probabilities: npt.NDArray[np.float32] | None = None
    is_interruption: bool | None = None

    def get_total_duration(self, default: float = 0.0) -> float:
        """RTT (Round Trip Time) time taken to perform the inference, in seconds."""
        return self.total_duration if self.total_duration is not None else default

    def get_prediction_duration(self, default: float = 0.0) -> float:
        """Time taken to perform the inference from the model side, in seconds."""
        return self.prediction_duration if self.prediction_duration is not None else default

    def get_detection_delay(self, default: float = 0.0) -> float:
        """Total time from the onset of the speech to the final prediction, in seconds."""
        return self.detection_delay if self.detection_delay is not None else default

    def get_probability(self, default: float = 0.0) -> float:
        """The conservative estimated probability of the interruption event."""
        return (
            _estimate_probability(self.probabilities) if self.probabilities is not None else default
        )
var probabilities : numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.float32]] | None
Expand source code
@dataclass(slots=True, kw_only=True)
class InterruptionCacheEntry:
    """Typed cache entry for interruption inference results."""

    created_at: int = field(default_factory=time.perf_counter_ns)
    """The timestamp when the cache entry was created, in nanoseconds. Used only for indexing and latency calculation."""
    speech_input: npt.NDArray[np.int16] | None = None
    total_duration: float | None = None
    prediction_duration: float | None = None
    detection_delay: float | None = None
    probabilities: npt.NDArray[np.float32] | None = None
    is_interruption: bool | None = None

    def get_total_duration(self, default: float = 0.0) -> float:
        """RTT (Round Trip Time) time taken to perform the inference, in seconds."""
        return self.total_duration if self.total_duration is not None else default

    def get_prediction_duration(self, default: float = 0.0) -> float:
        """Time taken to perform the inference from the model side, in seconds."""
        return self.prediction_duration if self.prediction_duration is not None else default

    def get_detection_delay(self, default: float = 0.0) -> float:
        """Total time from the onset of the speech to the final prediction, in seconds."""
        return self.detection_delay if self.detection_delay is not None else default

    def get_probability(self, default: float = 0.0) -> float:
        """The conservative estimated probability of the interruption event."""
        return (
            _estimate_probability(self.probabilities) if self.probabilities is not None else default
        )
var speech_input : numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.int16]] | None
Expand source code
@dataclass(slots=True, kw_only=True)
class InterruptionCacheEntry:
    """Typed cache entry for interruption inference results."""

    created_at: int = field(default_factory=time.perf_counter_ns)
    """The timestamp when the cache entry was created, in nanoseconds. Used only for indexing and latency calculation."""
    speech_input: npt.NDArray[np.int16] | None = None
    total_duration: float | None = None
    prediction_duration: float | None = None
    detection_delay: float | None = None
    probabilities: npt.NDArray[np.float32] | None = None
    is_interruption: bool | None = None

    def get_total_duration(self, default: float = 0.0) -> float:
        """RTT (Round Trip Time) time taken to perform the inference, in seconds."""
        return self.total_duration if self.total_duration is not None else default

    def get_prediction_duration(self, default: float = 0.0) -> float:
        """Time taken to perform the inference from the model side, in seconds."""
        return self.prediction_duration if self.prediction_duration is not None else default

    def get_detection_delay(self, default: float = 0.0) -> float:
        """Total time from the onset of the speech to the final prediction, in seconds."""
        return self.detection_delay if self.detection_delay is not None else default

    def get_probability(self, default: float = 0.0) -> float:
        """The conservative estimated probability of the interruption event."""
        return (
            _estimate_probability(self.probabilities) if self.probabilities is not None else default
        )
var total_duration : float | None
Expand source code
@dataclass(slots=True, kw_only=True)
class InterruptionCacheEntry:
    """Typed cache entry for interruption inference results."""

    created_at: int = field(default_factory=time.perf_counter_ns)
    """The timestamp when the cache entry was created, in nanoseconds. Used only for indexing and latency calculation."""
    speech_input: npt.NDArray[np.int16] | None = None
    total_duration: float | None = None
    prediction_duration: float | None = None
    detection_delay: float | None = None
    probabilities: npt.NDArray[np.float32] | None = None
    is_interruption: bool | None = None

    def get_total_duration(self, default: float = 0.0) -> float:
        """RTT (Round Trip Time) time taken to perform the inference, in seconds."""
        return self.total_duration if self.total_duration is not None else default

    def get_prediction_duration(self, default: float = 0.0) -> float:
        """Time taken to perform the inference from the model side, in seconds."""
        return self.prediction_duration if self.prediction_duration is not None else default

    def get_detection_delay(self, default: float = 0.0) -> float:
        """Total time from the onset of the speech to the final prediction, in seconds."""
        return self.detection_delay if self.detection_delay is not None else default

    def get_probability(self, default: float = 0.0) -> float:
        """The conservative estimated probability of the interruption event."""
        return (
            _estimate_probability(self.probabilities) if self.probabilities is not None else default
        )

Methods

def get_detection_delay(self, default: float = 0.0) ‑> float
Expand source code
def get_detection_delay(self, default: float = 0.0) -> float:
    """Total time from the onset of the speech to the final prediction, in seconds."""
    return self.detection_delay if self.detection_delay is not None else default

Total time from the onset of the speech to the final prediction, in seconds.

def get_prediction_duration(self, default: float = 0.0) ‑> float
Expand source code
def get_prediction_duration(self, default: float = 0.0) -> float:
    """Time taken to perform the inference from the model side, in seconds."""
    return self.prediction_duration if self.prediction_duration is not None else default

Time taken to perform the inference from the model side, in seconds.

def get_probability(self, default: float = 0.0) ‑> float
Expand source code
def get_probability(self, default: float = 0.0) -> float:
    """The conservative estimated probability of the interruption event."""
    return (
        _estimate_probability(self.probabilities) if self.probabilities is not None else default
    )

The conservative estimated probability of the interruption event.

def get_total_duration(self, default: float = 0.0) ‑> float
Expand source code
def get_total_duration(self, default: float = 0.0) -> float:
    """RTT (Round Trip Time) time taken to perform the inference, in seconds."""
    return self.total_duration if self.total_duration is not None else default

RTT (Round Trip Time) time taken to perform the inference, in seconds.

class InterruptionDetectionError (**data: Any)
Expand source code
class InterruptionDetectionError(BaseModel):
    """Payload published on the detector's "error" event when detection fails.

    Carries the originating exception (excluded from serialization) together
    with the emitting detector's label and a recoverability flag.
    """

    # arbitrary_types_allowed lets the raw Exception be stored as a field
    model_config = ConfigDict(arbitrary_types_allowed=True)
    type: Literal["interruption_detection_error"] = "interruption_detection_error"
    timestamp: float = Field(default_factory=time.time)  # when the error was recorded
    label: str  # label of the detector that emitted the error
    error: Exception = Field(..., exclude=True)  # original exception; excluded from dumps
    recoverable: bool  # whether the emitter considers the failure recoverable

Usage Documentation

Models

A base class for creating Pydantic models.

Attributes

__class_vars__
The names of the class variables defined on the model.
__private_attributes__
Metadata about the private attributes of the model.
__signature__
The synthesized __init__ [Signature][inspect.Signature] of the model.
__pydantic_complete__
Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__
The core schema of the model.
__pydantic_custom_init__
Whether the model has a custom __init__ function.
__pydantic_decorators__
Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
__pydantic_generic_metadata__
Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__
Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__
The name of the post-init method for the model, if defined.
__pydantic_root_model__
Whether the model is a [RootModel][pydantic.root_model.RootModel].
__pydantic_serializer__
The pydantic-core SchemaSerializer used to dump instances of the model.
__pydantic_validator__
The pydantic-core SchemaValidator used to validate instances of the model.
__pydantic_fields__
A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
__pydantic_computed_fields__
A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
__pydantic_extra__
A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
__pydantic_fields_set__
The names of fields explicitly set during instantiation.
__pydantic_private__
Values of private attributes set on the model instance.

Create a new model by parsing and validating input data from keyword arguments.

Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Ancestors

  • pydantic.main.BaseModel

Class variables

var error : Exception
var label : str
var model_config
var recoverable : bool
var timestamp : float
var type : Literal['interruption_detection_error']
class InterruptionHttpStream (*,
model: AdaptiveInterruptionDetector,
conn_options: APIConnectOptions)
Expand source code
class InterruptionHttpStream(InterruptionStreamBase):
    """Interruption stream that performs inference over plain HTTP POST requests.

    Each batch of overlap audio produced by ``_forward_data`` is POSTed to the
    ``/bargein`` endpoint; as soon as a response classifies the speech as an
    interruption, an ``OverlappingSpeechEvent`` is emitted via ``send()``.
    """

    def __init__(
        self, *, model: AdaptiveInterruptionDetector, conn_options: APIConnectOptions
    ) -> None:
        super().__init__(model=model, conn_options=conn_options)

    def update_options(
        self,
        *,
        threshold: NotGivenOr[float] = NOT_GIVEN,
        min_interruption_duration: NotGivenOr[float] = NOT_GIVEN,
    ) -> None:
        """Update detection options on the fly.

        Args:
            threshold: New probability threshold for classifying an interruption.
            min_interruption_duration: Minimum interruption duration, in seconds;
                converted to a frame count (rounded up via ``math.ceil``).
        """
        if is_given(threshold):
            self._opts.threshold = threshold
        if is_given(min_interruption_duration):
            self._opts.min_frames = math.ceil(min_interruption_duration * _FRAMES_PER_SECOND)

    async def _run(self) -> None:
        """Pump audio batches from ``_forward_data`` into ``predict`` until input ends."""

        async def _send_task(input_ch: aio.Chan[npt.NDArray[np.int16]]) -> None:
            # one inference request per audio batch forwarded by _forward_data
            async for data in input_ch:
                # drop batches that arrive after the overlap has already ended
                if (
                    overlap_started_at := self._overlap_started_at
                ) is None or not self._overlap_started:
                    continue

                # we don't increment the request counter for hosted agents
                resp: InterruptionResponse = await self.predict(data)
                created_at = resp.created_at
                self._cache[created_at] = entry = InterruptionCacheEntry(
                    created_at=created_at,
                    speech_input=data,
                    prediction_duration=resp.prediction_duration,
                    # created_at is the client perf_counter_ns value sent with the
                    # request and echoed back, so this is the full round-trip time
                    total_duration=(time.perf_counter_ns() - created_at) / 1e9,
                    detection_delay=time.time() - overlap_started_at,
                    probabilities=resp.probabilities,
                    is_interruption=resp.is_bargein,
                )
                # re-check _overlap_started: the overlap may have ended while
                # awaiting predict(), in which case the result is stale
                if entry.is_interruption and self._overlap_started:
                    logger.debug("user interruption detected")
                    if self._user_speech_span:
                        self._update_user_speech_span(self._user_speech_span, entry)
                        self._user_speech_span = None
                    ev = OverlappingSpeechEvent.from_cache_entry(
                        entry=entry,
                        is_interruption=True,
                        started_at=overlap_started_at,
                        ended_at=time.time(),
                    )
                    self.send(ev)
                    self._overlap_started = False

        data_ch = aio.Chan[npt.NDArray[np.int16]]()
        tasks = [
            asyncio.create_task(self._forward_data(data_ch)),
            asyncio.create_task(_send_task(data_ch)),
        ]
        try:
            await asyncio.gather(*tasks)
        finally:
            # tear down the sibling task if either one fails or is cancelled
            await aio.cancel_and_wait(*tasks)

    @log_exceptions(logger=logger)
    async def predict(self, waveform: np.ndarray) -> InterruptionResponse:
        """POST the int16 waveform to the ``/bargein`` endpoint and validate the response.

        Args:
            waveform: Audio samples sent as raw bytes (``waveform.tobytes()``).

        Returns:
            The validated :class:`InterruptionResponse`, with ``prediction_duration``
            and ``probabilities`` overridden by client-side values.

        Raises:
            APIStatusError: On non-2xx responses, or inference timeout (status 408).
            APIConnectionError: On transport-level (aiohttp) failures.
            APIError: On any other unexpected error.
        """
        created_at = perf_counter_ns()
        try:
            async with self._session.post(
                url=f"{self._opts.base_url}/bargein?threshold={self._opts.threshold}&min_frames={int(self._opts.min_frames)}&created_at={int(created_at)}",
                headers={
                    "Content-Type": "application/octet-stream",
                    "Authorization": f"Bearer {create_access_token(self._opts.api_key, self._opts.api_secret)}",
                },
                data=waveform.tobytes(),
                timeout=aiohttp.ClientTimeout(total=self._opts.inference_timeout),
            ) as resp:
                try:
                    resp.raise_for_status()
                    data: dict[str, Any] = await resp.json()
                    # replace prediction_duration/probabilities with locally
                    # computed values before validation
                    result = InterruptionResponse.model_validate(
                        data
                        | {
                            "prediction_duration": (time.perf_counter_ns() - created_at) / 1e9,
                            "probabilities": np.array(
                                data.get("probabilities", []), dtype=np.float32
                            ),
                        }
                    )

                    logger.trace(
                        "interruption inference done",
                        extra={
                            "created_at": created_at,
                            "is_interruption": result.is_bargein,
                            "prediction_duration": result.prediction_duration,
                        },
                    )
                    return result
                except Exception as e:
                    # wrap both HTTP errors and parse/validation failures,
                    # preserving the response body for diagnostics
                    msg = await resp.text()
                    status_code = (
                        e.status if isinstance(e, aiohttp.ClientResponseError) else resp.status
                    )
                    raise APIStatusError(
                        f"error during interruption prediction: {e}",
                        body=msg,
                        status_code=status_code,
                        # 429 (rate-limited) is explicitly marked non-retryable here
                        retryable=False if status_code == 429 else None,
                    ) from e
        except asyncio.TimeoutError as e:
            raise APIStatusError(
                f"interruption inference timeout: {e}",
                status_code=408,
                retryable=False,
            ) from e
        except aiohttp.ClientError as e:
            raise APIConnectionError(f"interruption inference connection error: {e}") from e
        except APIError as e:
            raise e
        except Exception as e:
            raise APIError(f"error during interruption prediction: {e}") from e

Helper class that provides a standard way to create an ABC using inheritance.

Ancestors

Methods

async def predict(self, waveform: np.ndarray) ‑> InterruptionResponse
Expand source code
@log_exceptions(logger=logger)
async def predict(self, waveform: np.ndarray) -> InterruptionResponse:
    created_at = perf_counter_ns()
    try:
        async with self._session.post(
            url=f"{self._opts.base_url}/bargein?threshold={self._opts.threshold}&min_frames={int(self._opts.min_frames)}&created_at={int(created_at)}",
            headers={
                "Content-Type": "application/octet-stream",
                "Authorization": f"Bearer {create_access_token(self._opts.api_key, self._opts.api_secret)}",
            },
            data=waveform.tobytes(),
            timeout=aiohttp.ClientTimeout(total=self._opts.inference_timeout),
        ) as resp:
            try:
                resp.raise_for_status()
                data: dict[str, Any] = await resp.json()
                result = InterruptionResponse.model_validate(
                    data
                    | {
                        "prediction_duration": (time.perf_counter_ns() - created_at) / 1e9,
                        "probabilities": np.array(
                            data.get("probabilities", []), dtype=np.float32
                        ),
                    }
                )

                logger.trace(
                    "interruption inference done",
                    extra={
                        "created_at": created_at,
                        "is_interruption": result.is_bargein,
                        "prediction_duration": result.prediction_duration,
                    },
                )
                return result
            except Exception as e:
                msg = await resp.text()
                status_code = (
                    e.status if isinstance(e, aiohttp.ClientResponseError) else resp.status
                )
                raise APIStatusError(
                    f"error during interruption prediction: {e}",
                    body=msg,
                    status_code=status_code,
                    retryable=False if status_code == 429 else None,
                ) from e
    except asyncio.TimeoutError as e:
        raise APIStatusError(
            f"interruption inference timeout: {e}",
            status_code=408,
            retryable=False,
        ) from e
    except aiohttp.ClientError as e:
        raise APIConnectionError(f"interruption inference connection error: {e}") from e
    except APIError as e:
        raise e
    except Exception as e:
        raise APIError(f"error during interruption prediction: {e}") from e
def update_options(self,
*,
threshold: NotGivenOr[float] = NOT_GIVEN,
min_interruption_duration: NotGivenOr[float] = NOT_GIVEN) ‑> None
Expand source code
def update_options(
    self,
    *,
    threshold: NotGivenOr[float] = NOT_GIVEN,
    min_interruption_duration: NotGivenOr[float] = NOT_GIVEN,
) -> None:
    if is_given(threshold):
        self._opts.threshold = threshold
    if is_given(min_interruption_duration):
        self._opts.min_frames = math.ceil(min_interruption_duration * _FRAMES_PER_SECOND)

Inherited members

class InterruptionOptions (*,
sample_rate: int,
threshold: float,
min_frames: int,
max_audio_duration: float,
audio_prefix_duration: float,
detection_interval: float,
inference_timeout: float,
base_url: str,
api_key: str,
api_secret: str,
use_proxy: bool)
Expand source code
@dataclass(slots=True, kw_only=True)
class InterruptionOptions:
    sample_rate: int
    """The sample rate of the audio frames, defaults to 16000Hz"""
    threshold: float
    """The threshold for the interruption detection, defaults to 0.5"""
    min_frames: int
    """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames"""
    max_audio_duration: float
    """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds"""
    audio_prefix_duration: float
    """The audio prefix duration for the interruption detection, defaults to 1.0 seconds"""
    detection_interval: float
    """The interval between detections, defaults to 0.1 seconds"""
    inference_timeout: float
    """The timeout for the interruption detection, defaults to 1 second"""
    base_url: str
    api_key: str
    api_secret: str
    use_proxy: bool
    """Whether to use the inference instead of the hosted API"""

InterruptionOptions(*, sample_rate: 'int', threshold: 'float', min_frames: 'int', max_audio_duration: 'float', audio_prefix_duration: 'float', detection_interval: 'float', inference_timeout: 'float', base_url: 'str', api_key: 'str', api_secret: 'str', use_proxy: 'bool')

Instance variables

var api_key : str
Expand source code
@dataclass(slots=True, kw_only=True)
class InterruptionOptions:
    sample_rate: int
    """The sample rate of the audio frames, defaults to 16000Hz"""
    threshold: float
    """The threshold for the interruption detection, defaults to 0.5"""
    min_frames: int
    """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames"""
    max_audio_duration: float
    """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds"""
    audio_prefix_duration: float
    """The audio prefix duration for the interruption detection, defaults to 1.0 seconds"""
    detection_interval: float
    """The interval between detections, defaults to 0.1 seconds"""
    inference_timeout: float
    """The timeout for the interruption detection, defaults to 1 second"""
    base_url: str
    api_key: str
    api_secret: str
    use_proxy: bool
    """Whether to use the inference instead of the hosted API"""
var api_secret : str
Expand source code
@dataclass(slots=True, kw_only=True)
class InterruptionOptions:
    sample_rate: int
    """The sample rate of the audio frames, defaults to 16000Hz"""
    threshold: float
    """The threshold for the interruption detection, defaults to 0.5"""
    min_frames: int
    """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames"""
    max_audio_duration: float
    """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds"""
    audio_prefix_duration: float
    """The audio prefix duration for the interruption detection, defaults to 1.0 seconds"""
    detection_interval: float
    """The interval between detections, defaults to 0.1 seconds"""
    inference_timeout: float
    """The timeout for the interruption detection, defaults to 1 second"""
    base_url: str
    api_key: str
    api_secret: str
    use_proxy: bool
    """Whether to use the inference instead of the hosted API"""
var audio_prefix_duration : float
Expand source code
@dataclass(slots=True, kw_only=True)
class InterruptionOptions:
    sample_rate: int
    """The sample rate of the audio frames, defaults to 16000Hz"""
    threshold: float
    """The threshold for the interruption detection, defaults to 0.5"""
    min_frames: int
    """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames"""
    max_audio_duration: float
    """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds"""
    audio_prefix_duration: float
    """The audio prefix duration for the interruption detection, defaults to 1.0 seconds"""
    detection_interval: float
    """The interval between detections, defaults to 0.1 seconds"""
    inference_timeout: float
    """The timeout for the interruption detection, defaults to 1 second"""
    base_url: str
    api_key: str
    api_secret: str
    use_proxy: bool
    """Whether to use the inference instead of the hosted API"""

The audio prefix duration for the interruption detection, defaults to 1.0 seconds

var base_url : str
Expand source code
@dataclass(slots=True, kw_only=True)
class InterruptionOptions:
    sample_rate: int
    """The sample rate of the audio frames, defaults to 16000Hz"""
    threshold: float
    """The threshold for the interruption detection, defaults to 0.5"""
    min_frames: int
    """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames"""
    max_audio_duration: float
    """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds"""
    audio_prefix_duration: float
    """The audio prefix duration for the interruption detection, defaults to 1.0 seconds"""
    detection_interval: float
    """The interval between detections, defaults to 0.1 seconds"""
    inference_timeout: float
    """The timeout for the interruption detection, defaults to 1 second"""
    base_url: str
    api_key: str
    api_secret: str
    use_proxy: bool
    """Whether to use the inference instead of the hosted API"""
var detection_interval : float
Expand source code
@dataclass(slots=True, kw_only=True)
class InterruptionOptions:
    sample_rate: int
    """The sample rate of the audio frames, defaults to 16000Hz"""
    threshold: float
    """The threshold for the interruption detection, defaults to 0.5"""
    min_frames: int
    """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames"""
    max_audio_duration: float
    """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds"""
    audio_prefix_duration: float
    """The audio prefix duration for the interruption detection, defaults to 1.0 seconds"""
    detection_interval: float
    """The interval between detections, defaults to 0.1 seconds"""
    inference_timeout: float
    """The timeout for the interruption detection, defaults to 1 second"""
    base_url: str
    api_key: str
    api_secret: str
    use_proxy: bool
    """Whether to use the inference instead of the hosted API"""

The interval between detections, defaults to 0.1 seconds

var inference_timeout : float
Expand source code
@dataclass(slots=True, kw_only=True)
class InterruptionOptions:
    sample_rate: int
    """The sample rate of the audio frames, defaults to 16000Hz"""
    threshold: float
    """The threshold for the interruption detection, defaults to 0.5"""
    min_frames: int
    """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames"""
    max_audio_duration: float
    """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds"""
    audio_prefix_duration: float
    """The audio prefix duration for the interruption detection, defaults to 1.0 seconds"""
    detection_interval: float
    """The interval between detections, defaults to 0.1 seconds"""
    inference_timeout: float
    """The timeout for the interruption detection, defaults to 1 second"""
    base_url: str
    api_key: str
    api_secret: str
    use_proxy: bool
    """Whether to use the inference instead of the hosted API"""

The timeout for the interruption detection, defaults to 0.7 seconds

var max_audio_duration : float
Expand source code
@dataclass(slots=True, kw_only=True)
class InterruptionOptions:
    sample_rate: int
    """The sample rate of the audio frames, defaults to 16000Hz"""
    threshold: float
    """The threshold for the interruption detection, defaults to 0.5"""
    min_frames: int
    """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames"""
    max_audio_duration: float
    """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds"""
    audio_prefix_duration: float
    """The audio prefix duration for the interruption detection, defaults to 1.0 seconds"""
    detection_interval: float
    """The interval between detections, defaults to 0.1 seconds"""
    inference_timeout: float
    """The timeout for the interruption detection, defaults to 1 second"""
    base_url: str
    api_key: str
    api_secret: str
    use_proxy: bool
    """Whether to use the inference instead of the hosted API"""

The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds

var min_frames : int
Expand source code
@dataclass(slots=True, kw_only=True)
class InterruptionOptions:
    sample_rate: int
    """The sample rate of the audio frames, defaults to 16000Hz"""
    threshold: float
    """The threshold for the interruption detection, defaults to 0.5"""
    min_frames: int
    """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames"""
    max_audio_duration: float
    """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds"""
    audio_prefix_duration: float
    """The audio prefix duration for the interruption detection, defaults to 1.0 seconds"""
    detection_interval: float
    """The interval between detections, defaults to 0.1 seconds"""
    inference_timeout: float
    """The timeout for the interruption detection, defaults to 1 second"""
    base_url: str
    api_key: str
    api_secret: str
    use_proxy: bool
    """Whether to use the inference instead of the hosted API"""

The minimum number of frames required to detect an interruption, defaults to 50ms (2 frames)

var sample_rate : int
Expand source code
@dataclass(slots=True, kw_only=True)
class InterruptionOptions:
    sample_rate: int
    """The sample rate of the audio frames, defaults to 16000Hz"""
    threshold: float
    """The threshold for the interruption detection, defaults to 0.5"""
    min_frames: int
    """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames"""
    max_audio_duration: float
    """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds"""
    audio_prefix_duration: float
    """The audio prefix duration for the interruption detection, defaults to 1.0 seconds"""
    detection_interval: float
    """The interval between detections, defaults to 0.1 seconds"""
    inference_timeout: float
    """The timeout for the interruption detection, defaults to 1 second"""
    base_url: str
    api_key: str
    api_secret: str
    use_proxy: bool
    """Whether to use the inference instead of the hosted API"""

The sample rate of the audio frames, defaults to 16000Hz

var threshold : float
Expand source code
@dataclass(slots=True, kw_only=True)
class InterruptionOptions:
    sample_rate: int
    """The sample rate of the audio frames, defaults to 16000Hz"""
    threshold: float
    """The threshold for the interruption detection, defaults to 0.5"""
    min_frames: int
    """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames"""
    max_audio_duration: float
    """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds"""
    audio_prefix_duration: float
    """The audio prefix duration for the interruption detection, defaults to 1.0 seconds"""
    detection_interval: float
    """The interval between detections, defaults to 0.1 seconds"""
    inference_timeout: float
    """The timeout for the interruption detection, defaults to 1 second"""
    base_url: str
    api_key: str
    api_secret: str
    use_proxy: bool
    """Whether to use the inference instead of the hosted API"""

The threshold for the interruption detection, defaults to 0.5

var use_proxy : bool
Expand source code
@dataclass(slots=True, kw_only=True)
class InterruptionOptions:
    sample_rate: int
    """The sample rate of the audio frames, defaults to 16000Hz"""
    threshold: float
    """The threshold for the interruption detection, defaults to 0.5"""
    min_frames: int
    """The minimum number of frames to detect a interruption, defaults to 50ms/2 frames"""
    max_audio_duration: float
    """The maximum audio duration for the interruption detection, including the audio prefix, defaults to 3 seconds"""
    audio_prefix_duration: float
    """The audio prefix duration for the interruption detection, defaults to 1.0 seconds"""
    detection_interval: float
    """The interval between detections, defaults to 0.1 seconds"""
    inference_timeout: float
    """The timeout for the interruption detection, defaults to 1 second"""
    base_url: str
    api_key: str
    api_secret: str
    use_proxy: bool
    """Whether to use the inference instead of the hosted API"""

Whether to use the inference instead of the hosted API

class InterruptionResponse (**data: Any)
Expand source code
class InterruptionResponse(BaseModel):
    """Validated payload of one ``/bargein`` inference request."""

    # probabilities holds a numpy array, which pydantic cannot validate natively
    model_config = ConfigDict(arbitrary_types_allowed=True)
    created_at: int  # client perf_counter_ns timestamp sent with the request and echoed back
    is_bargein: bool  # True when the overlapping speech was classified as an interruption
    prediction_duration: float  # inference round-trip time in seconds, computed client-side
    probabilities: npt.NDArray[np.float32] = Field(..., exclude=True)  # excluded from serialization

Usage Documentation

Models

A base class for creating Pydantic models.

Attributes

__class_vars__
The names of the class variables defined on the model.
__private_attributes__
Metadata about the private attributes of the model.
__signature__
The synthesized __init__ [Signature][inspect.Signature] of the model.
__pydantic_complete__
Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__
The core schema of the model.
__pydantic_custom_init__
Whether the model has a custom __init__ function.
__pydantic_decorators__
Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
__pydantic_generic_metadata__
Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__
Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__
The name of the post-init method for the model, if defined.
__pydantic_root_model__
Whether the model is a [RootModel][pydantic.root_model.RootModel].
__pydantic_serializer__
The pydantic-core SchemaSerializer used to dump instances of the model.
__pydantic_validator__
The pydantic-core SchemaValidator used to validate instances of the model.
__pydantic_fields__
A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
__pydantic_computed_fields__
A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
__pydantic_extra__
A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
__pydantic_fields_set__
The names of fields explicitly set during instantiation.
__pydantic_private__
Values of private attributes set on the model instance.

Create a new model by parsing and validating input data from keyword arguments.

Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Ancestors

  • pydantic.main.BaseModel

Class variables

var created_at : int
var is_bargein : bool
var model_config
var prediction_duration : float
var probabilities : numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.float32]]
class InterruptionStreamBase (*,
model: AdaptiveInterruptionDetector,
conn_options: APIConnectOptions)
Expand source code
class InterruptionStreamBase(ABC):
    """Base class for interruption-detection streams.

    Consumes audio frames and control sentinels pushed via :meth:`push_frame`,
    tracks agent/user speech overlap, batches audio for inference (performed by
    the subclass in :meth:`_run`), and emits ``OverlappingSpeechEvent`` objects
    that callers consume by async-iterating the stream. Every emitted event is
    also folded into ``InterruptionMetrics`` and published on the owning model
    as a ``metrics_collected`` event.
    """

    def __init__(
        self, *, model: AdaptiveInterruptionDetector, conn_options: APIConnectOptions
    ) -> None:
        self._model = model
        self._opts = model._opts
        self._session = model._ensure_session()

        # input: audio frames interleaved with control sentinels
        self._input_ch = aio.Chan[InterruptionDataFrameType]()
        self._event_ch = aio.Chan[OverlappingSpeechEvent]()
        # rolling buffer bounded to max_audio_duration seconds of int16 samples
        self._audio_buffer = AudioArrayBuffer(
            buffer_size=int(self._opts.max_audio_duration * self._opts.sample_rate),
            dtype=np.int16,
            sample_rate=self._opts.sample_rate,
        )
        # recent inference results keyed by their created_at timestamp
        self._cache = BoundedDict[int, InterruptionCacheEntry](maxsize=10)
        # tee the event channel: one branch for consumers, one for the metrics monitor
        self._tee_aiter = aio.itertools.tee(self._event_ch, 2)
        self._event_aiter, monitor_aiter = self._tee_aiter
        self._metrics_task = asyncio.create_task(
            self._metrics_monitor_task(monitor_aiter), name="InterruptionStreamBase._metrics_task"
        )
        self._task = asyncio.create_task(self._main_task())
        # close the event channel whenever the main task finishes, for any reason
        self._task.add_done_callback(lambda _: self._event_ch.close())

        self._num_retries = 0
        self._conn_options = conn_options
        self._sample_rate = self._opts.sample_rate

        # overlap/agent-speech state machine, driven by sentinels in _forward_data
        self._overlap_started_at: float | None = None
        self._user_speech_span: trace.Span | None = None
        self._agent_speech_started: bool = False
        self._overlap_started: bool = False
        self._overlap_count: int = 0
        self._accumulated_samples: int = 0
        self._num_requests = aio.AsyncAtomicCounter(initial=0)
        # number of new samples to accumulate between two inference submissions
        self._batch_size: int = int(self._opts.detection_interval * self._opts.sample_rate)
        # samples of leading audio kept as context ahead of the overlap start
        self._prefix_size: int = int(self._opts.audio_prefix_duration * self._opts.sample_rate)

    @abstractmethod
    async def _run(self) -> None: ...

    @log_exceptions(logger=logger)
    async def _main_task(self) -> None:
        # run the subclass loop, retrying retryable APIErrors up to max_retry times
        max_retries = self._conn_options.max_retry

        while self._num_retries <= max_retries:
            try:
                return await self._run()
            except APIError as e:
                if max_retries == 0 or not e.retryable:
                    self._emit_error(e, recoverable=False)
                    raise
                elif self._num_retries == max_retries:
                    # retry budget exhausted: surface a terminal connection error
                    self._emit_error(e, recoverable=False)
                    raise APIConnectionError(
                        f"failed to detect interruption after {self._num_retries} attempts",
                    ) from e
                else:
                    self._emit_error(e, recoverable=True)

                    retry_interval = self._conn_options._interval_for_retry(self._num_retries)
                    logger.warning(
                        f"failed to detect interruption, retrying in {retry_interval}s",
                        exc_info=e,
                        extra={
                            "model": self._model._label,
                            "attempt": self._num_retries,
                        },
                    )
                    await asyncio.sleep(retry_interval)

                self._num_retries += 1

            except Exception as e:
                # non-API errors are never retried
                self._emit_error(e, recoverable=False)
                raise

    def _emit_error(self, api_error: Exception, recoverable: bool) -> None:
        # delegate error reporting to the owning model
        self._model._emit_error(api_error, recoverable)

    def push_frame(self, frame: InterruptionDataFrameType) -> None:
        """Push some audio frame to be analyzed"""
        self._check_input_not_ended()
        self._check_not_closed()
        self._input_ch.send_nowait(frame)

    def flush(self) -> None:
        """Mark the end of the current segment"""
        self._check_input_not_ended()
        self._check_not_closed()
        self._input_ch.send_nowait(_FlushSentinel())

    def end_input(self) -> None:
        """Mark the end of input, no more audio will be pushed"""
        self.flush()
        self._input_ch.close()

    async def aclose(self) -> None:
        """Close the stream immediately"""
        self._input_ch.close()
        await aio.cancel_and_wait(self._task)
        self._event_ch.close()
        try:
            await self._metrics_task
        finally:
            await self._tee_aiter.aclose()

    async def __anext__(self) -> OverlappingSpeechEvent:
        try:
            val = await self._event_aiter.__anext__()
        except StopAsyncIteration:
            # surface the main task's failure (if any) rather than a bare StopAsyncIteration
            if not self._task.cancelled() and (exc := self._task.exception()):
                raise exc  # noqa: B904

            raise StopAsyncIteration from None

        return val

    def __aiter__(self) -> AsyncIterator[OverlappingSpeechEvent]:
        return self

    def _check_not_closed(self) -> None:
        # guard against use after the stream has been closed
        if self._event_ch.closed:
            cls = type(self)
            raise RuntimeError(f"{cls.__module__}.{cls.__name__} is closed")

    def _check_input_not_ended(self) -> None:
        # guard against pushing frames after end_input()
        if self._input_ch.closed:
            cls = type(self)
            raise RuntimeError(f"{cls.__module__}.{cls.__name__} input ended")

    @staticmethod
    def _update_user_speech_span(
        user_speech_span: trace.Span, entry: InterruptionCacheEntry
    ) -> None:
        # attach the detection outcome and timing attributes to the user-speech trace span
        user_speech_span.set_attribute(
            trace_types.ATTR_IS_INTERRUPTION, str(entry.is_interruption).lower()
        )
        user_speech_span.set_attribute(
            trace_types.ATTR_INTERRUPTION_PROBABILITY, entry.get_probability()
        )
        user_speech_span.set_attribute(
            trace_types.ATTR_INTERRUPTION_TOTAL_DURATION, entry.get_total_duration()
        )
        user_speech_span.set_attribute(
            trace_types.ATTR_INTERRUPTION_PREDICTION_DURATION, entry.get_prediction_duration()
        )
        user_speech_span.set_attribute(
            trace_types.ATTR_INTERRUPTION_DETECTION_DELAY, entry.get_detection_delay()
        )

    async def _forward_data(self, output_ch: aio.Chan[npt.NDArray[np.int16]]) -> None:
        """Preprocess the audio data and forward it to the output channel for inference."""

        async def _reset_state() -> None:
            # clear the per-utterance state
            # NOTE(review): _overlap_started_at is not cleared here; it is reset in
            # the _OverlapSpeechEndedSentinel branch below — confirm this is intended
            self._agent_speech_started = False
            self._overlap_started = False
            self._overlap_count = 0
            self._accumulated_samples = 0
            await self._num_requests.set(0)

            self._audio_buffer.reset()
            self._cache.clear()
            self._user_speech_span = None

        async for input_frame in self._input_ch:
            match input_frame:
                case _FlushSentinel():
                    continue
                case _AgentSpeechStartedSentinel() | _AgentSpeechEndedSentinel():
                    # agent speech boundary: start from a clean slate either way
                    await _reset_state()
                    self._agent_speech_started = isinstance(
                        input_frame, _AgentSpeechStartedSentinel
                    )
                    continue
                case _OverlapSpeechStartedSentinel() if self._agent_speech_started:
                    self._overlap_started_at = input_frame._started_at
                    self._user_speech_span = input_frame._user_speaking_span
                    self._overlap_started = True
                    self._accumulated_samples = 0
                    self._overlap_count += 1
                    # include the audio prefix in the window and
                    # only shift (remove leading silence) when the first overlap speech started
                    # otherwise, keep the existing data
                    if self._overlap_count == 1:
                        shift_size = max(
                            0,
                            len(self._audio_buffer)
                            - (
                                int(input_frame._speech_duration * self._sample_rate)
                                + self._prefix_size
                            ),
                        )
                        self._audio_buffer.shift(shift_size)
                    logger.trace(
                        "overlap speech started, starting interruption inference",
                        extra={
                            "overlap_count": self._overlap_count,
                        },
                    )
                    self._cache.clear()
                    continue
                case _OverlapSpeechEndedSentinel():
                    if self._overlap_started and self._overlap_started_at is not None:
                        logger.trace("overlap speech ended, stopping interruption inference")
                        self._user_speech_span = None
                        # pop a cached entry whose request completed (total_duration set),
                        # used to report timing data on the final (non-interruption) event
                        _, last_entry = self._cache.pop_if(
                            lambda entry: (
                                entry.total_duration is not None and entry.total_duration > 0
                            )
                        )
                        if last_entry is None:
                            logger.trace("no request made for overlap speech")
                        ev = OverlappingSpeechEvent.from_cache_entry(
                            entry=last_entry or _EMPTY_CACHE_ENTRY,
                            is_interruption=False,
                            started_at=self._overlap_started_at,
                            ended_at=input_frame._ended_at,
                        )
                        ev.num_requests = await self._num_requests.get_and_reset()
                        self.send(ev)

                    self._overlap_started = False
                    self._accumulated_samples = 0
                    self._overlap_started_at = None
                    # we don't clear the cache here since responses might be in flight
                case rtc.AudioFrame() if self._agent_speech_started:
                    samples_written = self._audio_buffer.push_frame(input_frame)
                    self._accumulated_samples += samples_written
                    # once a full detection_interval of new audio arrived during an
                    # overlap, forward the whole buffered window for inference
                    if self._accumulated_samples >= self._batch_size and self._overlap_started:
                        output_ch.send_nowait(self._audio_buffer.read())
                        self._accumulated_samples = 0

        output_ch.close()

    def send(self, event: OverlappingSpeechEvent) -> None:
        # publish on both the stream iterator and the model's event emitter
        self._event_ch.send_nowait(event)
        self._model.emit(event.type, event)

    @utils.log_exceptions(logger=logger)
    async def _metrics_monitor_task(
        self, event_aiter: AsyncIterable[OverlappingSpeechEvent]
    ) -> None:
        # fold every emitted event into an InterruptionMetrics sample
        async for ev in event_aiter:
            metrics = InterruptionMetrics(
                timestamp=time.time(),
                total_duration=ev.total_duration,
                prediction_duration=ev.prediction_duration,
                detection_delay=ev.detection_delay,
                num_interruptions=1 if ev.is_interruption else 0,
                num_backchannels=1 if not ev.is_interruption else 0,
                num_requests=ev.num_requests,
                metadata=Metadata(
                    model_name=self._model.model, model_provider=self._model.provider
                ),
            )
            self._model.emit("metrics_collected", metrics)

Helper class that provides a standard way to create an ABC using inheritance.

Ancestors

  • abc.ABC

Subclasses

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    """Close the stream immediately"""
    # stop accepting input, then cancel the worker task before closing the event channel
    self._input_ch.close()
    await aio.cancel_and_wait(self._task)
    self._event_ch.close()
    try:
        # let the metrics task drain any remaining events
        await self._metrics_task
    finally:
        # always release the tee iterator, even if the metrics task raised
        await self._tee_aiter.aclose()

Close the stream immediately

def end_input(self) ‑> None
Expand source code
def end_input(self) -> None:
    """Mark the end of input, no more audio will be pushed"""
    # flush the current segment first so buffered audio is not dropped
    self.flush()
    self._input_ch.close()

Mark the end of input, no more audio will be pushed

def flush(self) ‑> None
Expand source code
def flush(self) -> None:
    """Mark the end of the current segment"""
    # raises if end_input()/aclose() was already called
    self._check_input_not_ended()
    self._check_not_closed()
    # a sentinel in the channel delimits the segment for the consumer task
    self._input_ch.send_nowait(_FlushSentinel())

Mark the end of the current segment

def push_frame(self, frame: InterruptionDataFrameType) ‑> None
Expand source code
def push_frame(self, frame: InterruptionDataFrameType) -> None:
    """Push some audio frame to be analyzed"""
    # raises if end_input()/aclose() was already called
    self._check_input_not_ended()
    self._check_not_closed()
    self._input_ch.send_nowait(frame)

Push some audio frame to be analyzed

def send(self,
event: OverlappingSpeechEvent) ‑> None
Expand source code
def send(self, event: OverlappingSpeechEvent) -> None:
    """Publish *event*: queue it on the internal event channel, then re-emit it on the model."""
    self._event_ch.send_nowait(event)
    # emit under the event's own type (e.g. "overlapping_speech") so model listeners fire
    self._model.emit(event.type, event)
class InterruptionWSDetectedMessage (**data: Any)
Expand source code
class InterruptionWSDetectedMessage(BaseModel):
    """Server message: an interruption (barge-in) was detected in the audio stream."""

    # discriminator; wire value is "bargein_detected"
    type: Literal[InterruptionWSMessageType.INTERRUPTION_DETECTED] = (
        InterruptionWSMessageType.INTERRUPTION_DETECTED
    )
    # client timestamp (perf_counter_ns) echoed back; used as the cache key to
    # correlate this result with the audio batch that produced it
    created_at: int
    # server-side inference time, in seconds
    prediction_duration: float = Field(default=0.0)
    # raw per-frame probabilities from the model
    probabilities: list[float] = Field(default_factory=list)

Usage Documentation

Models

A base class for creating Pydantic models.

Attributes

__class_vars__
The names of the class variables defined on the model.
__private_attributes__
Metadata about the private attributes of the model.
__signature__
The synthesized __init__ [Signature][inspect.Signature] of the model.
__pydantic_complete__
Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__
The core schema of the model.
__pydantic_custom_init__
Whether the model has a custom __init__ function.
__pydantic_decorators__
Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
__pydantic_generic_metadata__
Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__
Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__
The name of the post-init method for the model, if defined.
__pydantic_root_model__
Whether the model is a [RootModel][pydantic.root_model.RootModel].
__pydantic_serializer__
The pydantic-core SchemaSerializer used to dump instances of the model.
__pydantic_validator__
The pydantic-core SchemaValidator used to validate instances of the model.
__pydantic_fields__
A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
__pydantic_computed_fields__
A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
__pydantic_extra__
A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
__pydantic_fields_set__
The names of fields explicitly set during instantiation.
__pydantic_private__
Values of private attributes set on the model instance.

Create a new model by parsing and validating input data from keyword arguments.

Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Ancestors

  • pydantic.main.BaseModel

Class variables

var created_at : int
var model_config
var prediction_duration : float
var probabilities : list[float]
var type : Literal[<InterruptionWSMessageType.INTERRUPTION_DETECTED: 'bargein_detected'>]
class InterruptionWSErrorMessage (**data: Any)
Expand source code
class InterruptionWSErrorMessage(BaseModel):
    """Server message: an error occurred; surfaced client-side as an APIStatusError."""

    type: Literal[InterruptionWSMessageType.ERROR] = InterruptionWSMessageType.ERROR
    # human-readable error description
    message: str
    # numeric error code; reused as the APIStatusError status code
    code: int
    # identifier of the server-side session that failed
    session_id: str

Usage Documentation

Models

A base class for creating Pydantic models.

Attributes

__class_vars__
The names of the class variables defined on the model.
__private_attributes__
Metadata about the private attributes of the model.
__signature__
The synthesized __init__ [Signature][inspect.Signature] of the model.
__pydantic_complete__
Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__
The core schema of the model.
__pydantic_custom_init__
Whether the model has a custom __init__ function.
__pydantic_decorators__
Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
__pydantic_generic_metadata__
Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__
Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__
The name of the post-init method for the model, if defined.
__pydantic_root_model__
Whether the model is a [RootModel][pydantic.root_model.RootModel].
__pydantic_serializer__
The pydantic-core SchemaSerializer used to dump instances of the model.
__pydantic_validator__
The pydantic-core SchemaValidator used to validate instances of the model.
__pydantic_fields__
A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
__pydantic_computed_fields__
A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
__pydantic_extra__
A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
__pydantic_fields_set__
The names of fields explicitly set during instantiation.
__pydantic_private__
Values of private attributes set on the model instance.

Create a new model by parsing and validating input data from keyword arguments.

Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Ancestors

  • pydantic.main.BaseModel

Class variables

var code : int
var message : str
var model_config
var session_id : str
var type : Literal[<InterruptionWSMessageType.ERROR: 'error'>]
class InterruptionWSInferenceDoneMessage (**data: Any)
Expand source code
class InterruptionWSInferenceDoneMessage(BaseModel):
    """Server message: inference completed without crossing the interruption threshold."""

    type: Literal[InterruptionWSMessageType.INFERENCE_DONE] = (
        InterruptionWSMessageType.INFERENCE_DONE
    )
    # client timestamp (perf_counter_ns) echoed back; used as the cache key to
    # correlate this result with the audio batch that produced it
    created_at: int
    # server-side inference time, in seconds
    prediction_duration: float = Field(default=0.0)
    # raw per-frame probabilities from the model
    probabilities: list[float] = Field(default_factory=list)

Usage Documentation

Models

A base class for creating Pydantic models.

Attributes

__class_vars__
The names of the class variables defined on the model.
__private_attributes__
Metadata about the private attributes of the model.
__signature__
The synthesized __init__ [Signature][inspect.Signature] of the model.
__pydantic_complete__
Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__
The core schema of the model.
__pydantic_custom_init__
Whether the model has a custom __init__ function.
__pydantic_decorators__
Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
__pydantic_generic_metadata__
Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__
Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__
The name of the post-init method for the model, if defined.
__pydantic_root_model__
Whether the model is a [RootModel][pydantic.root_model.RootModel].
__pydantic_serializer__
The pydantic-core SchemaSerializer used to dump instances of the model.
__pydantic_validator__
The pydantic-core SchemaValidator used to validate instances of the model.
__pydantic_fields__
A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
__pydantic_computed_fields__
A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
__pydantic_extra__
A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
__pydantic_fields_set__
The names of fields explicitly set during instantiation.
__pydantic_private__
Values of private attributes set on the model instance.

Create a new model by parsing and validating input data from keyword arguments.

Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Ancestors

  • pydantic.main.BaseModel

Class variables

var created_at : int
var model_config
var prediction_duration : float
var probabilities : list[float]
var type : Literal[<InterruptionWSMessageType.INFERENCE_DONE: 'inference_done'>]
class InterruptionWSMessageType (*args, **kwds)
Expand source code
class InterruptionWSMessageType(str, Enum):
    """Wire-level `type` discriminators for the interruption WebSocket protocol."""

    SESSION_CREATE = "session.create"  # client -> server: open a session with settings
    SESSION_CLOSE = "session.close"  # client -> server: request session teardown
    SESSION_CREATED = "session.created"  # server -> client: session is ready
    SESSION_CLOSED = "session.closed"  # server -> client: session torn down
    # NOTE: wire value intentionally differs from the member name ("bargein")
    INTERRUPTION_DETECTED = "bargein_detected"  # server -> client: interruption found
    INFERENCE_DONE = "inference_done"  # server -> client: inference finished, no interruption
    ERROR = "error"  # server -> client: fatal error

str(object='') -> str str(bytes_or_buffer[, encoding[, errors]]) -> str

Create a new string object from the given object. If encoding or errors is specified, then the object must expose a data buffer that will be decoded using the given encoding and error handler. Otherwise, returns the result of object.__str__() (if defined) or repr(object). encoding defaults to sys.getdefaultencoding(). errors defaults to 'strict'.

Ancestors

  • builtins.str
  • enum.Enum

Class variables

var ERROR
var INFERENCE_DONE
var INTERRUPTION_DETECTED
var SESSION_CLOSE
var SESSION_CLOSED
var SESSION_CREATE
var SESSION_CREATED
class InterruptionWSSessionCloseMessage (**data: Any)
Expand source code
class InterruptionWSSessionCloseMessage(BaseModel):
    """Client message requesting the server close the current session."""

    type: Literal[InterruptionWSMessageType.SESSION_CLOSE] = InterruptionWSMessageType.SESSION_CLOSE

Usage Documentation

Models

A base class for creating Pydantic models.

Attributes

__class_vars__
The names of the class variables defined on the model.
__private_attributes__
Metadata about the private attributes of the model.
__signature__
The synthesized __init__ [Signature][inspect.Signature] of the model.
__pydantic_complete__
Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__
The core schema of the model.
__pydantic_custom_init__
Whether the model has a custom __init__ function.
__pydantic_decorators__
Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
__pydantic_generic_metadata__
Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__
Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__
The name of the post-init method for the model, if defined.
__pydantic_root_model__
Whether the model is a [RootModel][pydantic.root_model.RootModel].
__pydantic_serializer__
The pydantic-core SchemaSerializer used to dump instances of the model.
__pydantic_validator__
The pydantic-core SchemaValidator used to validate instances of the model.
__pydantic_fields__
A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
__pydantic_computed_fields__
A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
__pydantic_extra__
A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
__pydantic_fields_set__
The names of fields explicitly set during instantiation.
__pydantic_private__
Values of private attributes set on the model instance.

Create a new model by parsing and validating input data from keyword arguments.

Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Ancestors

  • pydantic.main.BaseModel

Class variables

var model_config
var type : Literal[<InterruptionWSMessageType.SESSION_CLOSE: 'session.close'>]
class InterruptionWSSessionClosedMessage (**data: Any)
Expand source code
class InterruptionWSSessionClosedMessage(BaseModel):
    """Server acknowledgement that the session has been closed."""

    type: Literal[InterruptionWSMessageType.SESSION_CLOSED] = (
        InterruptionWSMessageType.SESSION_CLOSED
    )

Usage Documentation

Models

A base class for creating Pydantic models.

Attributes

__class_vars__
The names of the class variables defined on the model.
__private_attributes__
Metadata about the private attributes of the model.
__signature__
The synthesized __init__ [Signature][inspect.Signature] of the model.
__pydantic_complete__
Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__
The core schema of the model.
__pydantic_custom_init__
Whether the model has a custom __init__ function.
__pydantic_decorators__
Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
__pydantic_generic_metadata__
Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__
Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__
The name of the post-init method for the model, if defined.
__pydantic_root_model__
Whether the model is a [RootModel][pydantic.root_model.RootModel].
__pydantic_serializer__
The pydantic-core SchemaSerializer used to dump instances of the model.
__pydantic_validator__
The pydantic-core SchemaValidator used to validate instances of the model.
__pydantic_fields__
A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
__pydantic_computed_fields__
A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
__pydantic_extra__
A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
__pydantic_fields_set__
The names of fields explicitly set during instantiation.
__pydantic_private__
Values of private attributes set on the model instance.

Create a new model by parsing and validating input data from keyword arguments.

Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Ancestors

  • pydantic.main.BaseModel

Class variables

var model_config
var type : Literal[<InterruptionWSMessageType.SESSION_CLOSED: 'session.closed'>]
class InterruptionWSSessionCreateMessage (**data: Any)
Expand source code
class InterruptionWSSessionCreateMessage(BaseModel):
    """Client message opening a detection session; first message sent after connect."""

    type: Literal[InterruptionWSMessageType.SESSION_CREATE] = (
        InterruptionWSMessageType.SESSION_CREATE
    )
    # audio format and detection parameters for this session
    settings: InterruptionWSSessionCreateSettings

Usage Documentation

Models

A base class for creating Pydantic models.

Attributes

__class_vars__
The names of the class variables defined on the model.
__private_attributes__
Metadata about the private attributes of the model.
__signature__
The synthesized __init__ [Signature][inspect.Signature] of the model.
__pydantic_complete__
Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__
The core schema of the model.
__pydantic_custom_init__
Whether the model has a custom __init__ function.
__pydantic_decorators__
Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
__pydantic_generic_metadata__
Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__
Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__
The name of the post-init method for the model, if defined.
__pydantic_root_model__
Whether the model is a [RootModel][pydantic.root_model.RootModel].
__pydantic_serializer__
The pydantic-core SchemaSerializer used to dump instances of the model.
__pydantic_validator__
The pydantic-core SchemaValidator used to validate instances of the model.
__pydantic_fields__
A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
__pydantic_computed_fields__
A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
__pydantic_extra__
A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
__pydantic_fields_set__
The names of fields explicitly set during instantiation.
__pydantic_private__
Values of private attributes set on the model instance.

Create a new model by parsing and validating input data from keyword arguments.

Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Ancestors

  • pydantic.main.BaseModel

Class variables

var model_config
var settings : InterruptionWSSessionCreateSettings
var type : Literal[<InterruptionWSMessageType.SESSION_CREATE: 'session.create'>]
class InterruptionWSSessionCreateSettings (**data: Any)
Expand source code
class InterruptionWSSessionCreateSettings(BaseModel):
    """Audio format and detection parameters sent with session.create."""

    # input audio sample rate in Hz
    sample_rate: int
    # number of audio channels (the client always sends 1)
    num_channels: int
    # probability threshold above which the server reports an interruption
    threshold: float
    # minimum number of consecutive frames required to confirm an interruption
    min_frames: int
    # only signed 16-bit little-endian PCM is supported
    encoding: Literal["s16le"]

Usage Documentation

Models

A base class for creating Pydantic models.

Attributes

__class_vars__
The names of the class variables defined on the model.
__private_attributes__
Metadata about the private attributes of the model.
__signature__
The synthesized __init__ [Signature][inspect.Signature] of the model.
__pydantic_complete__
Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__
The core schema of the model.
__pydantic_custom_init__
Whether the model has a custom __init__ function.
__pydantic_decorators__
Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
__pydantic_generic_metadata__
Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__
Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__
The name of the post-init method for the model, if defined.
__pydantic_root_model__
Whether the model is a [RootModel][pydantic.root_model.RootModel].
__pydantic_serializer__
The pydantic-core SchemaSerializer used to dump instances of the model.
__pydantic_validator__
The pydantic-core SchemaValidator used to validate instances of the model.
__pydantic_fields__
A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
__pydantic_computed_fields__
A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
__pydantic_extra__
A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
__pydantic_fields_set__
The names of fields explicitly set during instantiation.
__pydantic_private__
Values of private attributes set on the model instance.

Create a new model by parsing and validating input data from keyword arguments.

Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Ancestors

  • pydantic.main.BaseModel

Class variables

var encoding : Literal['s16le']
var min_frames : int
var model_config
var num_channels : int
var sample_rate : int
var threshold : float
class InterruptionWSSessionCreatedMessage (**data: Any)
Expand source code
class InterruptionWSSessionCreatedMessage(BaseModel):
    """Server acknowledgement that the session was created successfully."""

    type: Literal[InterruptionWSMessageType.SESSION_CREATED] = (
        InterruptionWSMessageType.SESSION_CREATED
    )

Usage Documentation

Models

A base class for creating Pydantic models.

Attributes

__class_vars__
The names of the class variables defined on the model.
__private_attributes__
Metadata about the private attributes of the model.
__signature__
The synthesized __init__ [Signature][inspect.Signature] of the model.
__pydantic_complete__
Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__
The core schema of the model.
__pydantic_custom_init__
Whether the model has a custom __init__ function.
__pydantic_decorators__
Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
__pydantic_generic_metadata__
Metadata for generic models; contains data used for a similar purpose to args, origin, parameters in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__
Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__
The name of the post-init method for the model, if defined.
__pydantic_root_model__
Whether the model is a [RootModel][pydantic.root_model.RootModel].
__pydantic_serializer__
The pydantic-core SchemaSerializer used to dump instances of the model.
__pydantic_validator__
The pydantic-core SchemaValidator used to validate instances of the model.
__pydantic_fields__
A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
__pydantic_computed_fields__
A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
__pydantic_extra__
A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
__pydantic_fields_set__
The names of fields explicitly set during instantiation.
__pydantic_private__
Values of private attributes set on the model instance.

Create a new model by parsing and validating input data from keyword arguments.

Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Ancestors

  • pydantic.main.BaseModel

Class variables

var model_config
var type : Literal[<InterruptionWSMessageType.SESSION_CREATED: 'session.created'>]
class InterruptionWebSocketStream (*,
model: AdaptiveInterruptionDetector,
conn_options: APIConnectOptions)
Expand source code
class InterruptionWebSocketStream(InterruptionStreamBase):
    """Interruption detection stream backed by the LiveKit Adaptive Interruption WebSocket.

    Audio batches are framed with an 8-byte nanosecond timestamp header and sent
    to the server; results are correlated back to pending cache entries by that
    echoed timestamp. Setting ``_reconnect_event`` (see ``update_options``)
    tears down the connection and re-establishes it with fresh settings.
    """

    def __init__(
        self, *, model: AdaptiveInterruptionDetector, conn_options: APIConnectOptions
    ) -> None:
        super().__init__(model=model, conn_options=conn_options)
        self._request_id = str(shortuuid("interruption_request_"))
        # set to force a reconnect so new session settings reach the server
        self._reconnect_event = asyncio.Event()

    def update_options(
        self,
        *,
        threshold: NotGivenOr[float] = NOT_GIVEN,
        min_interruption_duration: NotGivenOr[float] = NOT_GIVEN,
    ) -> None:
        """Update detection options and reconnect so the server applies them.

        Args:
            threshold: New detection probability threshold.
            min_interruption_duration: Minimum interruption duration in seconds,
                converted to a frame count via ``_FRAMES_PER_SECOND``.
        """
        if is_given(threshold):
            self._opts.threshold = threshold
        if is_given(min_interruption_duration):
            self._opts.min_frames = math.ceil(min_interruption_duration * _FRAMES_PER_SECOND)
        # settings are only sent in session.create, so a reconnect is required
        self._reconnect_event.set()

    async def _run(self) -> None:
        """Main loop: connect, pump audio out / results in, reconnect on option changes."""
        # guards against treating an intentional close as an unexpected disconnect
        closing_ws = False

        async def send_task(
            ws: aiohttp.ClientWebSocketResponse, input_ch: aio.Chan[npt.NDArray[np.int16]]
        ) -> None:
            """Send timestamp-framed audio batches; fail if the oldest request times out."""
            nonlocal closing_ws
            timeout_ns = int(self._opts.inference_timeout * 1e9)

            async for audio_data in input_ch:
                now = perf_counter_ns()
                # only the oldest unanswered entry needs checking; cache iterates
                # in insertion order, so we break after the first pending one
                for _key, entry in self._cache.items():
                    if entry.total_duration is not None:
                        continue
                    if now - entry.created_at > timeout_ns:
                        raise APIStatusError(
                            f"interruption inference timed out after "
                            f"{(now - entry.created_at) / 1e9:.1f}s (ws)",
                            status_code=408,
                            retryable=False,
                        )
                    break  # oldest unanswered entry is still within timeout

                await self._num_requests.increment()
                created_at = perf_counter_ns()
                # the timestamp doubles as the correlation key echoed back by the server
                header = struct.pack("<Q", created_at)  # 8 bytes
                await ws.send_bytes(header + audio_data.tobytes())
                self._cache[created_at] = InterruptionCacheEntry(
                    created_at=created_at,
                    speech_input=audio_data,
                )

            # input channel closed: signal an orderly shutdown to the server
            closing_ws = True
            msg = InterruptionWSSessionCloseMessage(
                type=InterruptionWSMessageType.SESSION_CLOSE,
            )
            await ws.send_str(msg.model_dump_json())

        async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
            """Receive server messages and resolve pending cache entries."""
            nonlocal closing_ws

            while True:
                ws_msg = await ws.receive()
                if ws_msg.type in (
                    aiohttp.WSMsgType.CLOSED,
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.CLOSING,
                ):
                    # closure is only an error if we did not initiate it
                    if closing_ws or self._session.closed:
                        return
                    raise APIStatusError(
                        message=f"LiveKit Adaptive Interruption connection closed unexpectedly: {ws_msg.data}",
                        status_code=ws.close_code or -1,
                        body=f"{ws_msg.data=} {ws_msg.extra=}",
                    )

                if ws_msg.type != aiohttp.WSMsgType.TEXT:
                    logger.warning(
                        "unexpected LiveKit Adaptive Interruption message type %s", ws_msg.type
                    )
                    continue

                data = json.loads(ws_msg.data)
                msg: AnyInterruptionWSMessage = InterruptionWSMessage.validate_python(data)

                match msg:
                    case (
                        InterruptionWSSessionCreatedMessage() | InterruptionWSSessionClosedMessage()
                    ):
                        pass
                    case InterruptionWSDetectedMessage():
                        # correlate via the echoed client timestamp
                        created_at = msg.created_at
                        # only act while an overlap is in progress; late results are ignored
                        if (
                            overlap_started_at := self._overlap_started_at
                        ) is not None and self._overlap_started:
                            entry = self._cache.set_or_update(
                                created_at,
                                lambda c=created_at: InterruptionCacheEntry(created_at=c),  # type: ignore[misc]
                                total_duration=(perf_counter_ns() - created_at) / 1e9,
                                probabilities=np.array(msg.probabilities, dtype=np.float32),
                                is_interruption=True,
                                prediction_duration=msg.prediction_duration,
                                detection_delay=time.time() - overlap_started_at,
                            )
                            if self._user_speech_span:
                                self._update_user_speech_span(self._user_speech_span, entry)
                                self._user_speech_span = None
                            logger.debug(
                                "interruption detected",
                                extra={
                                    "total_duration": entry.get_total_duration(),
                                    "prediction_duration": entry.get_prediction_duration(),
                                    "detection_delay": entry.get_detection_delay(),
                                    "probability": entry.get_probability(),
                                },
                            )
                            ev = OverlappingSpeechEvent.from_cache_entry(
                                entry=entry,
                                is_interruption=True,
                                started_at=overlap_started_at,
                                ended_at=time.time(),
                            )
                            ev.num_requests = await self._num_requests.get_and_reset()
                            self.send(ev)
                            # one interruption per overlap: stop until the next overlap begins
                            self._overlap_started = False
                    case InterruptionWSInferenceDoneMessage():
                        created_at = msg.created_at
                        if (
                            overlap_started_at := self._overlap_started_at
                        ) is not None and self._overlap_started:
                            # record the negative result; no event is emitted here
                            entry = self._cache.set_or_update(
                                created_at,
                                lambda c=created_at: InterruptionCacheEntry(created_at=c),  # type: ignore[misc]
                                total_duration=(perf_counter_ns() - created_at) / 1e9,
                                prediction_duration=msg.prediction_duration,
                                probabilities=np.array(msg.probabilities, dtype=np.float32),
                                is_interruption=False,
                                detection_delay=time.time() - overlap_started_at,
                            )
                            logger.trace(
                                "interruption inference done",
                                extra={
                                    "total_duration": entry.get_total_duration(),
                                    "prediction_duration": entry.get_prediction_duration(),
                                    "probability": entry.get_probability(),
                                },
                            )
                    case InterruptionWSErrorMessage():
                        raise APIStatusError(
                            f"LiveKit Adaptive Interruption returned error: {msg.code}",
                            body=msg.message,
                            status_code=msg.code,
                        )
                    case _:
                        logger.warning(
                            "received unexpected message from LiveKit Adaptive Interruption: %s",
                            data,
                        )

        ws: aiohttp.ClientWebSocketResponse | None = None

        # reconnect loop: each iteration is one full connection lifetime
        while True:
            data_ch = aio.Chan[npt.NDArray[np.int16]]()
            try:
                closing_ws = False
                ws = await self._connect_ws()
                tasks = [
                    asyncio.create_task(self._forward_data(data_ch)),
                    asyncio.create_task(send_task(ws, data_ch)),
                    asyncio.create_task(recv_task(ws)),
                ]
                tasks_group = asyncio.gather(*tasks)
                wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())

                try:
                    # race the pump tasks against a reconnect request
                    done, _ = await asyncio.wait(
                        (tasks_group, wait_reconnect_task),
                        return_when=asyncio.FIRST_COMPLETED,
                    )

                    # re-raise any failure from the pump tasks
                    for task in done:
                        if task != wait_reconnect_task:
                            task.result()

                    # pump tasks finished normally (input ended): we're done
                    if wait_reconnect_task not in done:
                        break

                    self._reconnect_event.clear()
                finally:
                    closing_ws = True
                    if ws is not None and not ws.closed:
                        await ws.close()
                        ws = None
                    await aio.gracefully_cancel(*tasks, wait_reconnect_task)
                    tasks_group.cancel()
                    try:
                        # retrieve the exception so asyncio doesn't log it as unhandled
                        tasks_group.exception()
                    except asyncio.CancelledError:
                        pass
            finally:
                closing_ws = True
                if ws is not None and not ws.closed:
                    await ws.close()

    async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
        """Connect to the LiveKit Adaptive Interruption WebSocket."""
        settings = InterruptionWSSessionCreateSettings(
            sample_rate=self._opts.sample_rate,
            num_channels=1,
            threshold=self._opts.threshold,
            min_frames=self._model._opts.min_frames,
            encoding="s16le",
        )

        # allow http(s) base URLs by rewriting the scheme to ws(s)
        base_url = self._opts.base_url
        if base_url.startswith(("http://", "https://")):
            base_url = base_url.replace("http", "ws", 1)
        headers = {
            **get_inference_headers(),
            "Authorization": f"Bearer {create_access_token(self._opts.api_key, self._opts.api_secret)}",
        }
        try:
            ws = await asyncio.wait_for(
                self._session.ws_connect(f"{base_url}/bargein", headers=headers),
                self._conn_options.timeout,
            )
        except (
            aiohttp.ClientConnectorError,
            asyncio.TimeoutError,
            aiohttp.ClientResponseError,
        ) as e:
            # quota and timeout errors are non-retryable; other failures use the default
            if isinstance(e, aiohttp.ClientResponseError) and e.status == 429:
                raise APIStatusError(
                    "LiveKit Adaptive Interruption quota exceeded",
                    status_code=e.status,
                    retryable=False,
                ) from e
            elif isinstance(e, asyncio.TimeoutError):
                raise APIConnectionError(
                    "failed to connect to LiveKit Adaptive Interruption: timeout",
                    retryable=False,
                ) from e
            raise APIConnectionError("failed to connect to LiveKit Adaptive Interruption") from e

        try:
            # session.create must be the first message; it carries the settings
            msg = InterruptionWSSessionCreateMessage(
                type=InterruptionWSMessageType.SESSION_CREATE,
                settings=settings,
            )
            await ws.send_str(msg.model_dump_json())
        except Exception as e:
            await ws.close()
            raise APIConnectionError(
                "failed to send session.create message to LiveKit Adaptive Interruption"
            ) from e

        return ws

Helper class that provides a standard way to create an ABC using inheritance.

Ancestors

Methods

def update_options(self,
*,
threshold: NotGivenOr[float] = NOT_GIVEN,
min_interruption_duration: NotGivenOr[float] = NOT_GIVEN) ‑> None
Expand source code
def update_options(
    self,
    *,
    threshold: NotGivenOr[float] = NOT_GIVEN,
    min_interruption_duration: NotGivenOr[float] = NOT_GIVEN,
) -> None:
    """Apply new detection settings and schedule a session reconnect.

    Args:
        threshold: New probability threshold for interruption detection.
        min_interruption_duration: New minimum interruption duration, in
            seconds; converted internally to a whole frame count.
    """
    if is_given(threshold):
        self._opts.threshold = threshold
    if is_given(min_interruption_duration):
        # translate seconds into a whole number of model frames, rounding up
        frame_count = min_interruption_duration * _FRAMES_PER_SECOND
        self._opts.min_frames = math.ceil(frame_count)
    # wake the worker so the updated options take effect on the next session
    self._reconnect_event.set()

Inherited members

class OverlappingSpeechEvent (**data: Any)
Expand source code
class OverlappingSpeechEvent(BaseModel):
    """Represents an overlapping speech event detected during agent speech."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    type: Literal["overlapping_speech"] = "overlapping_speech"

    created_at: float = Field(default_factory=time.time)
    """Timestamp (in seconds) when the event was emitted."""

    detected_at: float = Field(default_factory=time.time)
    """Timestamp (in seconds) when the overlap was detected."""

    is_interruption: bool = False
    """Whether interruption is detected."""

    total_duration: float = 0.0
    """Round-trip time (RTT) taken to perform the inference, in seconds."""

    prediction_duration: float = 0.0
    """Time taken to perform the inference from the model side, in seconds."""

    detection_delay: float = 0.0
    """Total time from the onset of the speech to the final prediction, in seconds."""

    overlap_started_at: float | None = None
    """Timestamp (in seconds) when the overlap speech started. Useful for emitting held transcripts."""

    speech_input: npt.NDArray[np.int16] | None = None
    """The audio input that was used for the inference."""

    probabilities: npt.NDArray[np.float32] | None = None
    """The raw probabilities for the interruption detection."""

    probability: float = 0.0
    """The conservative estimated probability of the interruption event."""

    num_requests: int = 0
    """Number of requests sent for this event."""

    @model_serializer(mode="wrap")
    def serialize_model(self, handler: SerializerFunctionWrapHandler) -> Any:
        """Serialize the event with the numpy array fields omitted.

        `model_copy(update=...)` yields a shallow copy with the arrays already
        cleared, avoiding the deep copy of potentially large buffers that
        `model_copy(deep=True)` would perform only to discard them (the
        previous restore step mutated a throwaway copy and had no effect).
        """
        stripped = self.model_copy(update={"speech_input": None, "probabilities": None})
        return handler(stripped)

    @classmethod
    def from_cache_entry(
        cls,
        *,
        entry: InterruptionCacheEntry,
        is_interruption: bool,
        started_at: float | None = None,
        ended_at: float | None = None,
    ) -> OverlappingSpeechEvent:
        """Initialize the event from a cache entry.

        Args:
            entry: The cache entry to initialize the event from.
            is_interruption: Whether the interruption is detected.
            started_at: The timestamp when the overlap speech started.
            ended_at: The timestamp when the overlap speech ended; `detected_at`
                falls back to the current time when not provided.

        Returns:
            The initialized event.
        """
        return cls(
            type="overlapping_speech",
            detected_at=ended_at or time.time(),
            is_interruption=is_interruption,
            overlap_started_at=started_at,
            speech_input=entry.speech_input,
            probabilities=entry.probabilities,
            total_duration=entry.get_total_duration(),
            detection_delay=entry.get_detection_delay(),
            prediction_duration=entry.get_prediction_duration(),
            probability=entry.get_probability(),
        )

Represents an overlapping speech event detected during agent speech.

Create a new model by parsing and validating input data from keyword arguments.

Raises `pydantic_core.ValidationError` if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Ancestors

  • pydantic.main.BaseModel

Class variables

var created_at : float

Timestamp (in seconds) when the event was emitted.

var detected_at : float

Timestamp (in seconds) when the overlap was detected.

var detection_delay : float

Total time from the onset of the speech to the final prediction, in seconds.

var is_interruption : bool

Whether interruption is detected.

var model_config
var num_requests : int

Number of requests sent for this event.

var overlap_started_at : float | None

Timestamp (in seconds) when the overlap speech started. Useful for emitting held transcripts.

var prediction_duration : float

Time taken to perform the inference from the model side, in seconds.

var probabilities : numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.float32]] | None

The raw probabilities for the interruption detection.

var probability : float

The conservative estimated probability of the interruption event.

var speech_input : numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.int16]] | None

The audio input that was used for the inference.

var total_duration : float

Round-trip time (RTT) taken to perform the inference, in seconds.

var type : Literal['overlapping_speech']

Static methods

def from_cache_entry(*,
entry: InterruptionCacheEntry,
is_interruption: bool,
started_at: float | None = None,
ended_at: float | None = None) ‑> OverlappingSpeechEvent

Initialize the event from a cache entry.

Args

entry
The cache entry to initialize the event from.
is_interruption
Whether the interruption is detected.
started_at
The timestamp when the overlap speech started.
ended_at
The timestamp when the overlap speech ended.

Returns

The initialized event.

Methods

def serialize_model(self, handler: SerializerFunctionWrapHandler) ‑> Any
Expand source code
@model_serializer(mode="wrap")
def serialize_model(self, handler: SerializerFunctionWrapHandler) -> Any:
    """Serialize the event with the numpy array fields omitted.

    `model_copy(update=...)` yields a shallow copy with the arrays already
    cleared, avoiding the deep copy of potentially large buffers that
    `model_copy(deep=True)` would perform only to discard them (the previous
    restore step mutated a throwaway copy and had no effect).
    """
    stripped = self.model_copy(update={"speech_input": None, "probabilities": None})
    return handler(stripped)