Module livekit.plugins.baseten

Classes

class LLM (*,
model: str | LLMModels = 'meta-llama/Llama-4-Maverick-17B-128E-Instruct',
api_key: NotGivenOr[str] = NOT_GIVEN,
user: NotGivenOr[str] = NOT_GIVEN,
safety_identifier: NotGivenOr[str] = NOT_GIVEN,
prompt_cache_key: NotGivenOr[str] = NOT_GIVEN,
temperature: NotGivenOr[float] = NOT_GIVEN,
top_p: NotGivenOr[float] = NOT_GIVEN,
parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
reasoning_effort: NotGivenOr[ReasoningEffort] = NOT_GIVEN,
base_url: NotGivenOr[str] = 'https://inference.baseten.co/v1',
client: openai.AsyncClient | None = None,
timeout: httpx.Timeout | None = None)
Expand source code
class LLM(OpenAILLM):
    """Baseten LLM plugin built on top of the OpenAI-compatible chat API."""

    def __init__(
        self,
        *,
        model: str | LLMModels = "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
        api_key: NotGivenOr[str] = NOT_GIVEN,
        user: NotGivenOr[str] = NOT_GIVEN,
        safety_identifier: NotGivenOr[str] = NOT_GIVEN,
        prompt_cache_key: NotGivenOr[str] = NOT_GIVEN,
        temperature: NotGivenOr[float] = NOT_GIVEN,
        top_p: NotGivenOr[float] = NOT_GIVEN,
        parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
        tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
        reasoning_effort: NotGivenOr[ReasoningEffort] = NOT_GIVEN,
        base_url: NotGivenOr[str] = "https://inference.baseten.co/v1",
        client: openai.AsyncClient | None = None,
        timeout: httpx.Timeout | None = None,
    ):
        """
        Create a new instance of Baseten LLM.

        ``api_key`` must be set to your Baseten API key, either using the argument or by setting
        the ``BASETEN_API_KEY`` environment variable.
        """
        # Resolve the key: explicit argument wins, otherwise fall back to the env var.
        if not is_given(api_key):
            api_key = os.environ.get("BASETEN_API_KEY", "")
        if not api_key:
            raise ValueError(
                "BASETEN_API_KEY is required, either as argument or set BASETEN_API_KEY environmental variable"  # noqa: E501
            )

        # gpt-oss-120b defaults to low reasoning effort unless the caller chose one.
        if model == "openai/gpt-oss-120b" and not is_given(reasoning_effort):
            reasoning_effort = "low"

        super().__init__(
            model=model,
            api_key=api_key,
            base_url=base_url,
            client=client,
            user=user,
            safety_identifier=safety_identifier,
            prompt_cache_key=prompt_cache_key,
            temperature=temperature,
            top_p=top_p,
            parallel_tool_calls=parallel_tool_calls,
            tool_choice=tool_choice,
            timeout=timeout,
            reasoning_effort=reasoning_effort,
        )

    @property
    def model(self) -> str:
        """The configured model identifier."""
        return self._opts.model

    @property
    def provider(self) -> str:
        """Provider name reported for this LLM instance."""
        return "Baseten"

Helper class that provides a standard way to create an ABC using inheritance.

Create a new instance of Baseten LLM.

api_key must be set to your Baseten API key, either using the argument or by setting the BASETEN_API_KEY environment variable.

Ancestors

  • livekit.plugins.openai.llm.LLM
  • livekit.agents.llm.llm.LLM
  • abc.ABC
  • EventEmitter
  • typing.Generic

Instance variables

prop model : str
Expand source code
@property
def model(self) -> str:
    """Return the configured Baseten model identifier."""
    return self._opts.model

Get the model name/identifier for this LLM instance.

Returns

The model name if available, "unknown" otherwise.

Note

Plugins should override this property to provide their model information.

prop provider : str
Expand source code
@property
def provider(self) -> str:
    """Return the provider name for this LLM instance."""
    return "Baseten"

Get the provider name/identifier for this LLM instance.

Returns

The provider name if available, "unknown" otherwise.

Note

Plugins should override this property to provide their provider information.

Inherited members

class STT (*,
api_key: str | None = None,
model_endpoint: str | None = None,
model_id: str | None = None,
chain_id: str | None = None,
sample_rate: int = 16000,
encoding: NotGivenOr[STTEncoding] = NOT_GIVEN,
buffer_size_seconds: float = 0.032,
language: str = 'en',
enable_partial_transcripts: bool = True,
partial_transcript_interval_s: float = 1.0,
final_transcript_max_duration_s: int = 30,
show_word_timestamps: bool = True,
vad_threshold: float = 0.5,
vad_min_silence_duration_ms: int = 300,
vad_speech_pad_ms: int = 30,
http_session: aiohttp.ClientSession | None = None)
Expand source code
class STT(stt.STT):
    """Baseten streaming speech-to-text provider (see ``__init__`` for full usage)."""

    # WebSocket URL templates for the two Baseten deployment kinds.
    _TRUSS_URL_TEMPLATE = "wss://model-{model_id}.api.baseten.co/environments/production/websocket"
    _CHAIN_URL_TEMPLATE = "wss://chain-{chain_id}.api.baseten.co/environments/production/websocket"

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model_endpoint: str | None = None,
        model_id: str | None = None,
        chain_id: str | None = None,
        sample_rate: int = 16000,
        encoding: NotGivenOr[STTEncoding] = NOT_GIVEN,
        buffer_size_seconds: float = 0.032,
        language: str = "en",
        enable_partial_transcripts: bool = True,
        partial_transcript_interval_s: float = 1.0,
        final_transcript_max_duration_s: int = 30,
        show_word_timestamps: bool = True,
        vad_threshold: float = 0.5,
        vad_min_silence_duration_ms: int = 300,
        vad_speech_pad_ms: int = 30,
        http_session: aiohttp.ClientSession | None = None,
    ):
        """Baseten Speech-to-Text provider.

        Connects to a Baseten Whisper Streaming WebSocket model for real-time
        transcription.  Works with both **truss** and **chain** deployments.

        There are three ways to specify the endpoint (in priority order):

        1. ``model_endpoint`` – pass the full WebSocket URL directly.
        2. ``model_id`` – auto-constructs a **truss** endpoint URL::

               wss://model-{model_id}.api.baseten.co/environments/production/websocket

        3. ``chain_id`` – auto-constructs a **chain** endpoint URL::

               wss://chain-{chain_id}.api.baseten.co/environments/production/websocket

        If none of the above are provided, the ``BASETEN_MODEL_ENDPOINT`` environment
        variable is used as a fallback.

        Args:
            api_key: Baseten API key.  Falls back to the ``BASETEN_API_KEY`` env var.
            model_endpoint: Full WebSocket URL of the deployed model.  Takes
                priority over ``model_id`` and ``chain_id``.
            model_id: Baseten **truss** model ID.  The plugin builds the endpoint
                URL automatically.  Ignored when ``model_endpoint`` is given.
            chain_id: Baseten **chain** ID.  The plugin builds the endpoint URL
                automatically.  Ignored when ``model_endpoint`` is given.
            sample_rate: Audio sample rate in Hz (default ``16000``).
            encoding: Audio encoding – ``pcm_s16le`` (default) or ``pcm_mulaw``.
            buffer_size_seconds: Audio buffer size in seconds.
            language: BCP-47 language code (default ``en``).  Use ``auto`` for
                automatic language detection.
            enable_partial_transcripts: Emit interim transcripts while the speaker
                is still talking.  Defaults to ``True``.
            partial_transcript_interval_s: Interval (seconds) between partial
                transcript updates.
            final_transcript_max_duration_s: Maximum seconds of audio before the
                server forces a final transcript.
            show_word_timestamps: Include word-level timestamps in results.
            vad_threshold: Server-side VAD threshold (0.0–1.0).
            vad_min_silence_duration_ms: Minimum silence (ms) to end an utterance.
            vad_speech_pad_ms: Padding (ms) around detected speech.
            http_session: Optional :class:`aiohttp.ClientSession` to reuse.

        Raises:
            ValueError: If no API key or no endpoint can be resolved.
        """
        super().__init__(
            capabilities=stt.STTCapabilities(
                streaming=True,
                interim_results=True,
                aligned_transcript="word",
                offline_recognize=False,
            ),
        )

        # Explicit argument takes priority over the environment variable.
        api_key = api_key or os.environ.get("BASETEN_API_KEY")

        if not api_key:
            raise ValueError(
                "Baseten API key is required. "
                "Pass one in via the `api_key` parameter, "
                "or set it as the `BASETEN_API_KEY` environment variable"
            )

        self._api_key = api_key

        # Resolve the WebSocket endpoint URL.
        # Priority: model_endpoint > model_id > chain_id > env var
        endpoint: str | None = None
        if model_endpoint:
            endpoint = model_endpoint
        elif model_id:
            endpoint = self._TRUSS_URL_TEMPLATE.format(model_id=model_id)
        elif chain_id:
            endpoint = self._CHAIN_URL_TEMPLATE.format(chain_id=chain_id)
        else:
            endpoint = os.environ.get("BASETEN_MODEL_ENDPOINT")

        if not endpoint:
            raise ValueError(
                "A Baseten endpoint is required.  Provide one of: "
                "model_endpoint, model_id, or chain_id.  "
                "Alternatively, set the BASETEN_MODEL_ENDPOINT environment variable."
            )

        self._model_endpoint = endpoint

        self._opts = STTOptions(
            sample_rate=sample_rate,
            buffer_size_seconds=buffer_size_seconds,
            language=LanguageCode(language),
            enable_partial_transcripts=enable_partial_transcripts,
            partial_transcript_interval_s=partial_transcript_interval_s,
            final_transcript_max_duration_s=final_transcript_max_duration_s,
            show_word_timestamps=show_word_timestamps,
            vad_threshold=vad_threshold,
            vad_min_silence_duration_ms=vad_min_silence_duration_ms,
            vad_speech_pad_ms=vad_speech_pad_ms,
        )

        # Only override the dataclass default when the caller supplied an encoding.
        if is_given(encoding):
            self._opts.encoding = encoding

        self._session = http_session
        # WeakSet so closed streams can be garbage-collected while still letting
        # update_options() propagate to the ones that are alive.
        self._streams = weakref.WeakSet[SpeechStream]()

    @property
    def model(self) -> str:
        """The model name; the streaming endpoint does not expose one, so "unknown"."""
        return "unknown"

    @property
    def provider(self) -> str:
        """Provider name reported for this STT instance."""
        return "Baseten"

    @property
    def session(self) -> aiohttp.ClientSession:
        """Lazily obtain the shared aiohttp session if none was injected."""
        if not self._session:
            self._session = utils.http_context.http_session()
        return self._session

    async def _recognize_impl(
        self,
        buffer: AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions,
    ) -> stt.SpeechEvent:
        # Batch (non-streaming) recognition is not supported by this plugin;
        # use stream() instead.
        raise NotImplementedError("Not implemented")

    def stream(
        self,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> SpeechStream:
        """Open a new streaming recognition session.

        A shallow copy of the current options is taken so later
        ``update_options()`` calls don't silently mutate an active stream.
        NOTE(review): the ``language`` argument is accepted but not applied
        here — confirm whether it should override ``self._opts.language``.
        """
        config = dataclasses.replace(self._opts)
        stream = SpeechStream(
            stt=self,
            conn_options=conn_options,
            opts=config,
            api_key=self._api_key,
            model_endpoint=self._model_endpoint,
            http_session=self.session,
        )
        self._streams.add(stream)
        return stream

    def update_options(
        self,
        *,
        vad_threshold: NotGivenOr[float] = NOT_GIVEN,
        vad_min_silence_duration_ms: NotGivenOr[int] = NOT_GIVEN,
        vad_speech_pad_ms: NotGivenOr[int] = NOT_GIVEN,
        language: NotGivenOr[str] = NOT_GIVEN,
        buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN,
    ) -> None:
        """Update default options and propagate them to every live stream."""
        if is_given(vad_threshold):
            self._opts.vad_threshold = vad_threshold
        if is_given(vad_min_silence_duration_ms):
            self._opts.vad_min_silence_duration_ms = vad_min_silence_duration_ms
        if is_given(vad_speech_pad_ms):
            self._opts.vad_speech_pad_ms = vad_speech_pad_ms
        if is_given(language):
            self._opts.language = LanguageCode(language)
        if is_given(buffer_size_seconds):
            self._opts.buffer_size_seconds = buffer_size_seconds

        # Forward the raw (possibly NOT_GIVEN) values; each stream applies
        # only the ones that were actually provided.
        for stream in self._streams:
            stream.update_options(
                vad_threshold=vad_threshold,
                vad_min_silence_duration_ms=vad_min_silence_duration_ms,
                vad_speech_pad_ms=vad_speech_pad_ms,
                language=language,
                buffer_size_seconds=buffer_size_seconds,
            )

Helper class that provides a standard way to create an ABC using inheritance.

Baseten Speech-to-Text provider.

Connects to a Baseten Whisper Streaming WebSocket model for real-time transcription. Works with both truss and chain deployments.

There are three ways to specify the endpoint (in priority order):

  1. model_endpoint – pass the full WebSocket URL directly.
  2. model_id – auto-constructs a truss endpoint URL::

    wss://model-{model_id}.api.baseten.co/environments/production/websocket

  3. chain_id – auto-constructs a chain endpoint URL::

    wss://chain-{chain_id}.api.baseten.co/environments/production/websocket

If none of the above are provided, the BASETEN_MODEL_ENDPOINT environment variable is used as a fallback.

Args

api_key
Baseten API key. Falls back to the BASETEN_API_KEY env var.
model_endpoint
Full WebSocket URL of the deployed model. Takes priority over model_id and chain_id.
model_id
Baseten truss model ID. The plugin builds the endpoint URL automatically. Ignored when model_endpoint is given.
chain_id
Baseten chain ID. The plugin builds the endpoint URL automatically. Ignored when model_endpoint is given.
sample_rate
Audio sample rate in Hz (default 16000).
encoding
Audio encoding – pcm_s16le (default) or pcm_mulaw.
buffer_size_seconds
Audio buffer size in seconds.
language
BCP-47 language code (default en). Use auto for automatic language detection.
enable_partial_transcripts
Emit interim transcripts while the speaker is still talking. Defaults to True.
partial_transcript_interval_s
Interval (seconds) between partial transcript updates.
final_transcript_max_duration_s
Maximum seconds of audio before the server forces a final transcript.
show_word_timestamps
Include word-level timestamps in results.
vad_threshold
Server-side VAD threshold (0.0–1.0).
vad_min_silence_duration_ms
Minimum silence (ms) to end an utterance.
vad_speech_pad_ms
Padding (ms) around detected speech.
http_session
Optional :class:aiohttp.ClientSession to reuse.

Ancestors

  • livekit.agents.stt.stt.STT
  • abc.ABC
  • EventEmitter
  • typing.Generic

Instance variables

prop model : str
Expand source code
@property
def model(self) -> str:
    """Model name; the streaming endpoint does not expose one, so "unknown"."""
    return "unknown"

Get the model name/identifier for this STT instance.

Returns

The model name if available, "unknown" otherwise.

Note

Plugins should override this property to provide their model information.

prop provider : str
Expand source code
@property
def provider(self) -> str:
    """Return the provider name for this STT instance."""
    return "Baseten"

Get the provider name/identifier for this STT instance.

Returns

The provider name if available, "unknown" otherwise.

Note

Plugins should override this property to provide their provider information.

prop session : aiohttp.ClientSession
Expand source code
@property
def session(self) -> aiohttp.ClientSession:
    """Lazily obtain the shared aiohttp session if none was injected."""
    if not self._session:
        self._session = utils.http_context.http_session()
    return self._session

Methods

def stream(self,
*,
language: NotGivenOr[str] = NOT_GIVEN,
conn_options: APIConnectOptions = APIConnectOptions(max_retry=3, retry_interval=2.0, timeout=10.0)) ‑> livekit.plugins.baseten.stt.SpeechStream
Expand source code
def stream(
    self,
    *,
    language: NotGivenOr[str] = NOT_GIVEN,
    conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> SpeechStream:
    """Open a new streaming recognition session with a snapshot of the options.

    NOTE(review): ``language`` is accepted but not applied here — confirm
    whether it should override the configured language.
    """
    # copy so later update_options() calls don't mutate an active stream's view
    config = dataclasses.replace(self._opts)
    stream = SpeechStream(
        stt=self,
        conn_options=conn_options,
        opts=config,
        api_key=self._api_key,
        model_endpoint=self._model_endpoint,
        http_session=self.session,
    )
    self._streams.add(stream)
    return stream
def update_options(self,
*,
vad_threshold: NotGivenOr[float] = NOT_GIVEN,
vad_min_silence_duration_ms: NotGivenOr[int] = NOT_GIVEN,
vad_speech_pad_ms: NotGivenOr[int] = NOT_GIVEN,
language: NotGivenOr[str] = NOT_GIVEN,
buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN) ‑> None
Expand source code
def update_options(
    self,
    *,
    vad_threshold: NotGivenOr[float] = NOT_GIVEN,
    vad_min_silence_duration_ms: NotGivenOr[int] = NOT_GIVEN,
    vad_speech_pad_ms: NotGivenOr[int] = NOT_GIVEN,
    language: NotGivenOr[str] = NOT_GIVEN,
    buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN,
) -> None:
    """Apply any provided option overrides, then propagate them to live streams."""
    opts = self._opts
    if is_given(language):
        opts.language = LanguageCode(language)
    if is_given(buffer_size_seconds):
        opts.buffer_size_seconds = buffer_size_seconds
    if is_given(vad_threshold):
        opts.vad_threshold = vad_threshold
    if is_given(vad_min_silence_duration_ms):
        opts.vad_min_silence_duration_ms = vad_min_silence_duration_ms
    if is_given(vad_speech_pad_ms):
        opts.vad_speech_pad_ms = vad_speech_pad_ms

    # forward the raw (possibly NOT_GIVEN) values so each stream applies
    # exactly the same subset of overrides
    for active_stream in self._streams:
        active_stream.update_options(
            vad_threshold=vad_threshold,
            vad_min_silence_duration_ms=vad_min_silence_duration_ms,
            vad_speech_pad_ms=vad_speech_pad_ms,
            language=language,
            buffer_size_seconds=buffer_size_seconds,
        )

Inherited members

class SpeechStream (*,
stt: STT,
opts: STTOptions,
conn_options: APIConnectOptions,
api_key: str,
model_endpoint: str,
http_session: aiohttp.ClientSession)
Expand source code
class SpeechStream(stt.SpeechStream):
    """A streaming speech-to-text session connected to Baseten via WebSocket."""

    # Used to close websocket
    # NOTE(review): _CLOSE_MSG is never sent anywhere in this class — confirm
    # whether the server needs it to terminate a session cleanly.
    _CLOSE_MSG: str = json.dumps({"terminate_session": True})

    def __init__(
        self,
        *,
        stt: STT,
        opts: STTOptions,
        conn_options: APIConnectOptions,
        api_key: str,
        model_endpoint: str,
        http_session: aiohttp.ClientSession,
    ) -> None:
        """Create a stream bound to one Baseten websocket endpoint.

        Args:
            stt: Parent STT instance that created this stream.
            opts: Snapshot of the streaming options for this session.
            conn_options: Retry/timeout configuration for the connection.
            api_key: Baseten API key sent in the ``Authorization`` header.
            model_endpoint: Full WebSocket URL of the deployed model.
            http_session: Shared :class:`aiohttp.ClientSession` to connect with.
        """
        super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)

        self._opts = opts
        self._api_key = api_key
        self._model_endpoint = model_endpoint
        self._session = http_session
        self._speech_duration: float = 0

        # keep a list of final transcripts to combine them inside the END_OF_SPEECH event
        self._final_events: list[SpeechEvent] = []
        self._reconnect_event = asyncio.Event()

    def update_options(
        self,
        *,
        vad_threshold: NotGivenOr[float] = NOT_GIVEN,
        vad_min_silence_duration_ms: NotGivenOr[int] = NOT_GIVEN,
        vad_speech_pad_ms: NotGivenOr[int] = NOT_GIVEN,
        language: NotGivenOr[str] = NOT_GIVEN,
        buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN,
    ) -> None:
        """Apply any provided option overrides, then trigger a reconnect so the
        new settings are renegotiated with the server."""
        if is_given(vad_threshold):
            self._opts.vad_threshold = vad_threshold
        if is_given(vad_min_silence_duration_ms):
            self._opts.vad_min_silence_duration_ms = vad_min_silence_duration_ms
        if is_given(vad_speech_pad_ms):
            self._opts.vad_speech_pad_ms = vad_speech_pad_ms
        if is_given(language):
            self._opts.language = LanguageCode(language)
        if is_given(buffer_size_seconds):
            self._opts.buffer_size_seconds = buffer_size_seconds

        self._reconnect_event.set()

    async def _run(self) -> None:
        """
        Run a single websocket connection to Baseten and make sure to reconnect
        when something went wrong.
        """

        # NOTE(review): closing_ws is never set to True in this class, so the
        # "graceful close" branches in recv_task are currently unreachable —
        # confirm whether send_task should flip it when input is exhausted.
        closing_ws = False

        async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
            # Forward incoming audio frames as raw int16 PCM over the socket.
            samples_per_buffer = 512

            audio_bstream = utils.audio.AudioByteStream(
                sample_rate=self._opts.sample_rate,
                num_channels=1,
                samples_per_channel=samples_per_buffer,
            )

            async for data in self._input_ch:
                if isinstance(data, self._FlushSentinel):
                    frames = audio_bstream.flush()
                else:
                    frames = audio_bstream.write(data.data.tobytes())

                for frame in frames:
                    # audio is int16 PCM, so the byte length must be even
                    if len(frame.data) % 2 != 0:
                        logger.warning("Frame data size not aligned to int16 (multiple of 2)")

                    int16_array = np.frombuffer(frame.data, dtype=np.int16)
                    await ws.send_bytes(int16_array.tobytes())

        async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
            # Receive transcription messages and emit interim/final speech events.
            nonlocal closing_ws
            while True:
                try:
                    # short timeout so we can notice a graceful close promptly
                    msg = await asyncio.wait_for(ws.receive(), timeout=5)
                except asyncio.TimeoutError:
                    if closing_ws:
                        break
                    continue

                if msg.type in (
                    aiohttp.WSMsgType.CLOSED,
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.CLOSING,
                ):
                    if closing_ws:
                        return
                    raise APIStatusError(
                        "Baseten connection closed unexpectedly",
                        status_code=ws.close_code or -1,
                        body=f"{msg.data=} {msg.extra=}",
                    )

                if msg.type != aiohttp.WSMsgType.TEXT:
                    logger.error("Unexpected Baseten message type: %s", msg.type)
                    continue

                try:
                    data = json.loads(msg.data)

                    # Skip non-transcription messages (e.g. error, status)
                    msg_type = data.get("type")
                    if msg_type and msg_type not in ("transcription",):
                        logger.debug("Ignoring message type: %s", msg_type)
                        continue

                    is_final = data.get("is_final", True)
                    segments = data.get("segments", [])

                    # Build transcript text: prefer top-level "transcript" if present,
                    # otherwise concatenate segment texts (Baseten standard format).
                    text = (
                        data.get("transcript")
                        or " ".join(seg.get("text", "") for seg in segments).strip()
                    )

                    confidence = data.get("confidence", 0.0)

                    # Build timed words – prefer word-level timestamps when available,
                    # fall back to segment-level timing.
                    timed_words: list[TimedString] = []
                    for segment in segments:
                        word_timestamps = segment.get("word_timestamps", [])
                        if word_timestamps:
                            for w in word_timestamps:
                                timed_words.append(
                                    TimedString(
                                        text=w.get("word", ""),
                                        start_time=(
                                            w.get("start_time", 0.0) + self.start_time_offset
                                        ),
                                        end_time=(w.get("end_time", 0.0) + self.start_time_offset),
                                        start_time_offset=self.start_time_offset,
                                    )
                                )
                        else:
                            timed_words.append(
                                TimedString(
                                    text=segment.get("text", ""),
                                    start_time=(
                                        segment.get("start_time", 0.0) + self.start_time_offset
                                    ),
                                    end_time=(
                                        segment.get("end_time", 0.0) + self.start_time_offset
                                    ),
                                    start_time_offset=self.start_time_offset,
                                )
                            )

                    start_time = (
                        segments[0].get("start_time", 0.0) if segments else 0.0
                    ) + self.start_time_offset
                    end_time = (
                        segments[-1].get("end_time", 0.0) if segments else 0.0
                    ) + self.start_time_offset

                    if not is_final:
                        if text:
                            event = stt.SpeechEvent(
                                type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
                                alternatives=[
                                    stt.SpeechData(
                                        language=LanguageCode(""),
                                        text=text,
                                        confidence=confidence,
                                        start_time=start_time,
                                        end_time=end_time,
                                        words=timed_words,
                                    )
                                ],
                            )
                            self._event_ch.send_nowait(event)

                    elif is_final:
                        language = LanguageCode(data.get("language_code", self._opts.language))

                        if text:
                            event = stt.SpeechEvent(
                                type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                                alternatives=[
                                    stt.SpeechData(
                                        language=language,
                                        text=text,
                                        confidence=confidence,
                                        start_time=start_time,
                                        end_time=end_time,
                                        words=timed_words,
                                    )
                                ],
                            )
                            self._final_events.append(event)
                            self._event_ch.send_nowait(event)

                except Exception:
                    logger.exception("Failed to process message from Baseten")

        ws: aiohttp.ClientWebSocketResponse | None = None

        while True:
            try:
                ws = await self._connect_ws()
                tasks = [
                    asyncio.create_task(send_task(ws)),
                    asyncio.create_task(recv_task(ws)),
                ]
                wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())

                try:
                    # Race the send/recv pair against a reconnect request.
                    done, _ = await asyncio.wait(
                        (asyncio.gather(*tasks), wait_reconnect_task),
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                    for task in done:
                        if task != wait_reconnect_task:
                            task.result()  # re-raise any send/recv error

                    if wait_reconnect_task not in done:
                        break  # send/recv finished normally; no reconnect requested

                    self._reconnect_event.clear()
                finally:
                    await utils.aio.gracefully_cancel(*tasks, wait_reconnect_task)
            finally:
                if ws is not None:
                    await ws.close()

    async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
        """Open a WebSocket and send the ``StreamingWhisperInput`` metadata message.

        The metadata schema must match the Baseten server's ``StreamingWhisperInput``
        Pydantic model exactly (which uses ``extra="forbid"``).  Field names are:

        - ``whisper_params``  – Whisper model parameters (language, word timestamps, …)
        - ``streaming_params`` – encoding, sample rate, partial transcript settings
        - ``streaming_vad_config`` – server-side Silero VAD configuration
        """
        headers = {
            "Authorization": f"Api-Key {self._api_key}",
        }

        ws = await self._session.ws_connect(self._model_endpoint, headers=headers, ssl=ssl_context)

        # Build metadata matching Baseten's StreamingWhisperInput schema.
        # See: https://docs.baseten.co/reference/inference-api/predict-endpoints/streaming-transcription-api
        metadata = {
            "whisper_params": {
                "audio_language": self._opts.language,
                "show_word_timestamps": self._opts.show_word_timestamps,
            },
            "streaming_params": {
                "encoding": self._opts.encoding,
                "sample_rate": self._opts.sample_rate,
                "enable_partial_transcripts": self._opts.enable_partial_transcripts,
                "partial_transcript_interval_s": self._opts.partial_transcript_interval_s,
                "final_transcript_max_duration_s": self._opts.final_transcript_max_duration_s,
            },
            "streaming_vad_config": {
                "threshold": self._opts.vad_threshold,
                "min_silence_duration_ms": self._opts.vad_min_silence_duration_ms,
                "speech_pad_ms": self._opts.vad_speech_pad_ms,
            },
        }

        await ws.send_str(json.dumps(metadata))
        return ws

A streaming speech-to-text session connected to Baseten via WebSocket.

Args: sample_rate (int or None, optional): The desired sample rate for the audio input. If specified, the audio input is automatically resampled to the given sample rate before speech-to-text processing. If not provided (None), the input retains its original sample rate.

Ancestors

  • livekit.agents.stt.stt.RecognizeStream
  • abc.ABC

Methods

def update_options(self,
*,
vad_threshold: NotGivenOr[float] = NOT_GIVEN,
vad_min_silence_duration_ms: NotGivenOr[int] = NOT_GIVEN,
vad_speech_pad_ms: NotGivenOr[int] = NOT_GIVEN,
language: NotGivenOr[str] = NOT_GIVEN,
buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN) ‑> None
Expand source code
def update_options(
    self,
    *,
    vad_threshold: NotGivenOr[float] = NOT_GIVEN,
    vad_min_silence_duration_ms: NotGivenOr[int] = NOT_GIVEN,
    vad_speech_pad_ms: NotGivenOr[int] = NOT_GIVEN,
    language: NotGivenOr[str] = NOT_GIVEN,
    buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN,
) -> None:
    """Store any provided option overrides, then request a reconnect so the
    server session is renegotiated with the new settings."""
    opts = self._opts
    if is_given(language):
        opts.language = LanguageCode(language)
    if is_given(buffer_size_seconds):
        opts.buffer_size_seconds = buffer_size_seconds
    if is_given(vad_threshold):
        opts.vad_threshold = vad_threshold
    if is_given(vad_min_silence_duration_ms):
        opts.vad_min_silence_duration_ms = vad_min_silence_duration_ms
    if is_given(vad_speech_pad_ms):
        opts.vad_speech_pad_ms = vad_speech_pad_ms

    # wake the run loop so the websocket is re-established with the new options
    self._reconnect_event.set()
class SynthesizeStream (*,
tts: TTS,
conn_options: APIConnectOptions)
Expand source code
class SynthesizeStream(tts.SynthesizeStream):
    """Streaming synthesis session: sends text over a WebSocket and emits PCM audio."""

    def __init__(
        self,
        *,
        tts: TTS,
        conn_options: APIConnectOptions,
    ) -> None:
        super().__init__(tts=tts, conn_options=conn_options)
        self._tts: TTS = tts
        # snapshot of the parent's options so later update_options() calls
        # on the TTS instance don't affect this in-flight stream
        self._opts = replace(tts._opts)

    async def _run(self, output_emitter: tts.AudioEmitter) -> None:
        """Drive one WebSocket synthesis round-trip.

        Opens a WebSocket to the configured Baseten endpoint, sends the
        voice/generation settings, then runs a sender and a receiver task
        concurrently until the input text is exhausted and the server
        closes the stream (or an error occurs).
        """
        request_id = utils.shortuuid()
        output_emitter.initialize(
            request_id=request_id,
            sample_rate=24000,  # matches the sample_rate declared by the TTS class
            num_channels=1,
            mime_type="audio/pcm",
            stream=True,
        )

        async def _send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
            # forward input text chunks; flush markers carry no payload here
            async for data in self._input_ch:
                if isinstance(data, self._FlushSentinel):
                    continue
                self._mark_started()
                await ws.send_str(data)
            # signal the server that no more text is coming
            await ws.send_str(_END_SENTINEL)

        async def _recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
            output_emitter.start_segment(segment_id=request_id)
            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.BINARY:
                    # raw PCM audio bytes from the model server
                    output_emitter.push(msg.data)
                elif msg.type in (
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.CLOSED,
                    aiohttp.WSMsgType.CLOSING,
                ):
                    break
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    raise APIConnectionError()
            output_emitter.end_input()

        try:
            async with self._tts._ensure_session().ws_connect(
                self._tts._model_endpoint,
                headers={"Authorization": f"Api-Key {self._tts._api_key}"},
                ssl=ssl_context,
            ) as ws:
                # initial configuration message consumed by the Baseten endpoint
                await ws.send_json(
                    {
                        "voice": self._opts.voice,
                        "max_tokens": self._opts.max_tokens,
                        "buffer_size": self._opts.buffer_size,
                    }
                )

                tasks = [
                    asyncio.create_task(_send_task(ws)),
                    asyncio.create_task(_recv_task(ws)),
                ]
                try:
                    await asyncio.gather(*tasks)
                finally:
                    # ensure neither task outlives the websocket context
                    await utils.aio.gracefully_cancel(*tasks)
        except asyncio.TimeoutError:
            raise APITimeoutError() from None
        except aiohttp.ClientResponseError as e:
            raise APIStatusError(
                message=e.message, status_code=e.status, request_id=None, body=None
            ) from None
        except (APIConnectionError, APIStatusError, APITimeoutError):
            raise
        except Exception as e:
            raise APIConnectionError() from e

A streaming text-to-speech session that forwards text to a Baseten model endpoint over a WebSocket and emits the synthesized PCM audio.

Ancestors

  • livekit.agents.tts.tts.SynthesizeStream
  • abc.ABC
class TTS (*,
api_key: str | None = None,
model_endpoint: str | None = None,
voice: str = 'tara',
language: str = 'en',
temperature: float = 0.6,
max_tokens: int = 2000,
buffer_size: int = 10,
http_session: aiohttp.ClientSession | None = None)
Expand source code
class TTS(tts.TTS):
    """Baseten text-to-speech for LiveKit Agents.

    Talks to a user-deployed Baseten model endpoint. A ``wss://``/``ws://``
    endpoint enables streaming synthesis; any other URL (``https://``) is
    used for non-streaming synthesis.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model_endpoint: str | None = None,
        voice: str = "tara",
        language: str = "en",
        temperature: float = 0.6,
        max_tokens: int = 2000,
        buffer_size: int = 10,
        http_session: aiohttp.ClientSession | None = None,
    ) -> None:
        """Initialize the Baseten TTS.

        Args:
            api_key: Baseten API key; falls back to the ``BASETEN_API_KEY`` env var.
            model_endpoint: Baseten model endpoint; falls back to the
                ``BASETEN_MODEL_ENDPOINT`` env var. Pass a ``wss://`` URL for
                streaming or an ``https://`` URL for non-streaming.
            voice: Speaker voice. Defaults to "tara".
            language: Language code. Defaults to "en".
            temperature: Sampling temperature. Defaults to 0.6.
            max_tokens: Maximum tokens for generation. Defaults to 2000.
            buffer_size: Number of words per chunk for streaming. Defaults to 10.
            http_session: Optional aiohttp session to reuse.

        Raises:
            ValueError: If no API key or no model endpoint can be resolved.
        """
        resolved_key = api_key or os.environ.get("BASETEN_API_KEY")
        if not resolved_key:
            raise ValueError(
                "Baseten API key is required. "
                "Pass one in via the `api_key` parameter, "
                "or set it as the `BASETEN_API_KEY` environment variable"
            )

        resolved_endpoint = model_endpoint or os.environ.get("BASETEN_MODEL_ENDPOINT")
        if not resolved_endpoint:
            raise ValueError(
                "model_endpoint is required. "
                "Provide it via the constructor or BASETEN_MODEL_ENDPOINT env var."
            )

        # a websocket endpoint means the deployment supports streaming synthesis
        is_websocket_endpoint = resolved_endpoint.startswith(("wss://", "ws://"))

        super().__init__(
            capabilities=tts.TTSCapabilities(streaming=is_websocket_endpoint),
            sample_rate=24000,
            num_channels=1,
        )

        self._api_key = resolved_key
        self._model_endpoint = resolved_endpoint
        self._opts = _TTSOptions(
            voice=voice,
            language=language,
            temperature=temperature,
            max_tokens=max_tokens,
            buffer_size=buffer_size,
        )
        self._session = http_session
        # track live streams weakly so aclose() can shut them down without
        # keeping already-finished streams alive
        self._streams = weakref.WeakSet[SynthesizeStream]()

    @property
    def model(self) -> str:
        """Model identifier; Baseten endpoints don't expose one, so "unknown"."""
        return "unknown"

    @property
    def provider(self) -> str:
        """Name of the TTS provider."""
        return "Baseten"

    def _ensure_session(self) -> aiohttp.ClientSession:
        # lazily fall back to the shared HTTP session from the job context
        if self._session is None:
            self._session = utils.http_context.http_session()
        return self._session

    def update_options(
        self,
        *,
        voice: NotGivenOr[str] = NOT_GIVEN,
        language: NotGivenOr[str] = NOT_GIVEN,
        temperature: NotGivenOr[float] = NOT_GIVEN,
        max_tokens: NotGivenOr[int] = NOT_GIVEN,
        buffer_size: NotGivenOr[int] = NOT_GIVEN,
    ) -> None:
        """Overwrite synthesis options; ``NOT_GIVEN`` arguments stay unchanged.

        Only affects streams created after this call — existing streams keep
        the snapshot of options they were created with.
        """
        opts = self._opts
        if is_given(voice):
            opts.voice = voice
        if is_given(language):
            opts.language = language
        if is_given(temperature):
            opts.temperature = temperature
        if is_given(max_tokens):
            opts.max_tokens = max_tokens
        if is_given(buffer_size):
            opts.buffer_size = buffer_size

    def synthesize(
        self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
    ) -> ChunkedStream:
        """Synthesize *text* in one non-streaming request."""
        return ChunkedStream(tts=self, input_text=text, conn_options=conn_options)

    def stream(
        self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
    ) -> SynthesizeStream:
        """Open a streaming synthesis session and register it for cleanup."""
        new_stream = SynthesizeStream(tts=self, conn_options=conn_options)
        self._streams.add(new_stream)
        return new_stream

    async def aclose(self) -> None:
        """Close every tracked stream and drop the references."""
        for active_stream in list(self._streams):
            await active_stream.aclose()
        self._streams.clear()
Baseten text-to-speech plugin for the LiveKit Agents framework.

Initialize the Baseten TTS.

Args

api_key
Baseten API key, or BASETEN_API_KEY env var.
model_endpoint
Baseten model endpoint, or BASETEN_MODEL_ENDPOINT env var. Pass a wss:// URL for streaming or an https:// URL for non-streaming.
voice
Speaker voice. Defaults to "tara".
language
Language code. Defaults to "en".
temperature
Sampling temperature. Defaults to 0.6.
max_tokens
Maximum tokens for generation. Defaults to 2000.
buffer_size
Number of words per chunk for streaming. Defaults to 10.
http_session
Optional aiohttp session to reuse.

Ancestors

  • livekit.agents.tts.tts.TTS
  • abc.ABC
  • EventEmitter
  • typing.Generic

Instance variables

prop model : str
Expand source code
@property
def model(self) -> str:
    return "unknown"

Get the model name/identifier for this TTS instance.

Returns

The model name if available, "unknown" otherwise.

Note

Plugins should override this property to provide their model information.

prop provider : str
Expand source code
@property
def provider(self) -> str:
    return "Baseten"

Get the provider name/identifier for this TTS instance.

Returns

The provider name if available, "unknown" otherwise.

Note

Plugins should override this property to provide their provider information.

Methods

async def aclose(self) ‑> None
Expand source code
async def aclose(self) -> None:
    for stream in list(self._streams):
        await stream.aclose()
    self._streams.clear()
def stream(self,
*,
conn_options: APIConnectOptions = APIConnectOptions(max_retry=3, retry_interval=2.0, timeout=10.0)) ‑> livekit.plugins.baseten.tts.SynthesizeStream
Expand source code
def stream(
    self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
) -> SynthesizeStream:
    stream = SynthesizeStream(
        tts=self,
        conn_options=conn_options,
    )
    self._streams.add(stream)
    return stream
def synthesize(self,
text: str,
*,
conn_options: APIConnectOptions = APIConnectOptions(max_retry=3, retry_interval=2.0, timeout=10.0)) ‑> livekit.plugins.baseten.tts.ChunkedStream
Expand source code
def synthesize(
    self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
) -> ChunkedStream:
    return ChunkedStream(
        tts=self,
        input_text=text,
        conn_options=conn_options,
    )
def update_options(self,
*,
voice: NotGivenOr[str] = NOT_GIVEN,
language: NotGivenOr[str] = NOT_GIVEN,
temperature: NotGivenOr[float] = NOT_GIVEN,
max_tokens: NotGivenOr[int] = NOT_GIVEN,
buffer_size: NotGivenOr[int] = NOT_GIVEN) ‑> None
Expand source code
def update_options(
    self,
    *,
    voice: NotGivenOr[str] = NOT_GIVEN,
    language: NotGivenOr[str] = NOT_GIVEN,
    temperature: NotGivenOr[float] = NOT_GIVEN,
    max_tokens: NotGivenOr[int] = NOT_GIVEN,
    buffer_size: NotGivenOr[int] = NOT_GIVEN,
) -> None:
    if is_given(voice):
        self._opts.voice = voice
    if is_given(language):
        self._opts.language = language
    if is_given(temperature):
        self._opts.temperature = temperature
    if is_given(max_tokens):
        self._opts.max_tokens = max_tokens
    if is_given(buffer_size):
        self._opts.buffer_size = buffer_size

Inherited members