Module livekit.plugins.baseten
Classes
class LLM (*,
model: str | LLMModels = 'meta-llama/Llama-4-Maverick-17B-128E-Instruct',
api_key: NotGivenOr[str] = NOT_GIVEN,
user: NotGivenOr[str] = NOT_GIVEN,
safety_identifier: NotGivenOr[str] = NOT_GIVEN,
prompt_cache_key: NotGivenOr[str] = NOT_GIVEN,
temperature: NotGivenOr[float] = NOT_GIVEN,
top_p: NotGivenOr[float] = NOT_GIVEN,
parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
reasoning_effort: NotGivenOr[ReasoningEffort] = NOT_GIVEN,
base_url: NotGivenOr[str] = 'https://inference.baseten.co/v1',
client: openai.AsyncClient | None = None,
timeout: httpx.Timeout | None = None)-
Expand source code
class LLM(OpenAILLM):
    """LiveKit LLM plugin for Baseten's OpenAI-compatible inference endpoint."""

    def __init__(
        self,
        *,
        model: str | LLMModels = "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
        api_key: NotGivenOr[str] = NOT_GIVEN,
        user: NotGivenOr[str] = NOT_GIVEN,
        safety_identifier: NotGivenOr[str] = NOT_GIVEN,
        prompt_cache_key: NotGivenOr[str] = NOT_GIVEN,
        temperature: NotGivenOr[float] = NOT_GIVEN,
        top_p: NotGivenOr[float] = NOT_GIVEN,
        parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
        tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
        reasoning_effort: NotGivenOr[ReasoningEffort] = NOT_GIVEN,
        base_url: NotGivenOr[str] = "https://inference.baseten.co/v1",
        client: openai.AsyncClient | None = None,
        timeout: httpx.Timeout | None = None,
    ):
        """
        Create a new instance of Baseten LLM.

        ``api_key`` must be set to your Baseten API key, either using the argument
        or by setting the ``BASETEN_API_KEY`` environmental variable.
        """
        # Explicit argument wins; otherwise fall back to the environment.
        resolved_key = api_key if is_given(api_key) else os.environ.get("BASETEN_API_KEY", "")
        if not resolved_key:
            raise ValueError(
                "BASETEN_API_KEY is required, either as argument or set BASETEN_API_KEY environmental variable"  # noqa: E501
            )

        # gpt-oss-120b defaults to "low" reasoning effort unless explicitly overridden.
        if not is_given(reasoning_effort) and model == "openai/gpt-oss-120b":
            reasoning_effort = "low"

        super().__init__(
            model=model,
            api_key=resolved_key,
            base_url=base_url,
            client=client,
            user=user,
            safety_identifier=safety_identifier,
            prompt_cache_key=prompt_cache_key,
            temperature=temperature,
            top_p=top_p,
            parallel_tool_calls=parallel_tool_calls,
            tool_choice=tool_choice,
            timeout=timeout,
            reasoning_effort=reasoning_effort,
        )

    @property
    def model(self) -> str:
        # Model identifier as recorded on the options object by the base class.
        return self._opts.model

    @property
    def provider(self) -> str:
        return "Baseten"
Create a new instance of Baseten LLM.
``api_key`` must be set to your Baseten API key, either using the argument or by setting the ``BASETEN_API_KEY`` environmental variable.
Ancestors
- livekit.plugins.openai.llm.LLM
- livekit.agents.llm.llm.LLM
- abc.ABC
- EventEmitter
- typing.Generic
Instance variables
prop model : str-
Expand source code
@property def model(self) -> str: return self._opts.modelGet the model name/identifier for this LLM instance.
Returns
The model name if available, "unknown" otherwise.
Note
Plugins should override this property to provide their model information.
prop provider : str-
Expand source code
@property def provider(self) -> str: return "Baseten"Get the provider name/identifier for this LLM instance.
Returns
The provider name if available, "unknown" otherwise.
Note
Plugins should override this property to provide their provider information.
Inherited members
class STT (*,
api_key: str | None = None,
model_endpoint: str | None = None,
model_id: str | None = None,
chain_id: str | None = None,
sample_rate: int = 16000,
encoding: NotGivenOr[STTEncoding] = NOT_GIVEN,
buffer_size_seconds: float = 0.032,
language: str = 'en',
enable_partial_transcripts: bool = True,
partial_transcript_interval_s: float = 1.0,
final_transcript_max_duration_s: int = 30,
show_word_timestamps: bool = True,
vad_threshold: float = 0.5,
vad_min_silence_duration_ms: int = 300,
vad_speech_pad_ms: int = 30,
http_session: aiohttp.ClientSession | None = None)-
Expand source code
class STT(stt.STT):
    # WebSocket URL templates for the two Baseten deployment flavors.
    _TRUSS_URL_TEMPLATE = "wss://model-{model_id}.api.baseten.co/environments/production/websocket"
    _CHAIN_URL_TEMPLATE = "wss://chain-{chain_id}.api.baseten.co/environments/production/websocket"

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model_endpoint: str | None = None,
        model_id: str | None = None,
        chain_id: str | None = None,
        sample_rate: int = 16000,
        encoding: NotGivenOr[STTEncoding] = NOT_GIVEN,
        buffer_size_seconds: float = 0.032,
        language: str = "en",
        enable_partial_transcripts: bool = True,
        partial_transcript_interval_s: float = 1.0,
        final_transcript_max_duration_s: int = 30,
        show_word_timestamps: bool = True,
        vad_threshold: float = 0.5,
        vad_min_silence_duration_ms: int = 300,
        vad_speech_pad_ms: int = 30,
        http_session: aiohttp.ClientSession | None = None,
    ):
        """Baseten Speech-to-Text provider.

        Connects to a Baseten Whisper Streaming WebSocket model for real-time
        transcription. Works with both **truss** and **chain** deployments.

        There are three ways to specify the endpoint (in priority order):

        1. ``model_endpoint`` – pass the full WebSocket URL directly.
        2. ``model_id`` – auto-constructs a **truss** endpoint URL::

               wss://model-{model_id}.api.baseten.co/environments/production/websocket

        3. ``chain_id`` – auto-constructs a **chain** endpoint URL::

               wss://chain-{chain_id}.api.baseten.co/environments/production/websocket

        If none of the above are provided, the ``BASETEN_MODEL_ENDPOINT``
        environment variable is used as a fallback.

        Args:
            api_key: Baseten API key. Falls back to the ``BASETEN_API_KEY`` env var.
            model_endpoint: Full WebSocket URL of the deployed model. Takes
                priority over ``model_id`` and ``chain_id``.
            model_id: Baseten **truss** model ID. The plugin builds the endpoint
                URL automatically. Ignored when ``model_endpoint`` is given.
            chain_id: Baseten **chain** ID. The plugin builds the endpoint URL
                automatically. Ignored when ``model_endpoint`` is given.
            sample_rate: Audio sample rate in Hz (default ``16000``).
            encoding: Audio encoding – ``pcm_s16le`` (default) or ``pcm_mulaw``.
            buffer_size_seconds: Audio buffer size in seconds.
            language: BCP-47 language code (default ``en``). Use ``auto`` for
                automatic language detection.
            enable_partial_transcripts: Emit interim transcripts while the speaker
                is still talking. Defaults to ``True``.
            partial_transcript_interval_s: Interval (seconds) between partial
                transcript updates.
            final_transcript_max_duration_s: Maximum seconds of audio before the
                server forces a final transcript.
            show_word_timestamps: Include word-level timestamps in results.
            vad_threshold: Server-side VAD threshold (0.0–1.0).
            vad_min_silence_duration_ms: Minimum silence (ms) to end an utterance.
            vad_speech_pad_ms: Padding (ms) around detected speech.
            http_session: Optional :class:`aiohttp.ClientSession` to reuse.
        """
        # Streaming-only provider: interim results and word alignment are
        # supported, batch (offline) recognition is not.
        super().__init__(
            capabilities=stt.STTCapabilities(
                streaming=True,
                interim_results=True,
                aligned_transcript="word",
                offline_recognize=False,
            ),
        )
        api_key = api_key or os.environ.get("BASETEN_API_KEY")
        if not api_key:
            raise ValueError(
                "Baseten API key is required. "
                "Pass one in via the `api_key` parameter, "
                "or set it as the `BASETEN_API_KEY` environment variable"
            )
        self._api_key = api_key

        # Resolve the WebSocket endpoint URL.
        # Priority: model_endpoint > model_id > chain_id > env var
        endpoint: str | None = None
        if model_endpoint:
            endpoint = model_endpoint
        elif model_id:
            endpoint = self._TRUSS_URL_TEMPLATE.format(model_id=model_id)
        elif chain_id:
            endpoint = self._CHAIN_URL_TEMPLATE.format(chain_id=chain_id)
        else:
            endpoint = os.environ.get("BASETEN_MODEL_ENDPOINT")

        if not endpoint:
            raise ValueError(
                "A Baseten endpoint is required. Provide one of: "
                "model_endpoint, model_id, or chain_id. "
                "Alternatively, set the BASETEN_MODEL_ENDPOINT environment variable."
            )

        self._model_endpoint = endpoint

        self._opts = STTOptions(
            sample_rate=sample_rate,
            buffer_size_seconds=buffer_size_seconds,
            language=LanguageCode(language),
            enable_partial_transcripts=enable_partial_transcripts,
            partial_transcript_interval_s=partial_transcript_interval_s,
            final_transcript_max_duration_s=final_transcript_max_duration_s,
            show_word_timestamps=show_word_timestamps,
            vad_threshold=vad_threshold,
            vad_min_silence_duration_ms=vad_min_silence_duration_ms,
            vad_speech_pad_ms=vad_speech_pad_ms,
        )
        # Only override the STTOptions default encoding when explicitly given.
        if is_given(encoding):
            self._opts.encoding = encoding

        self._session = http_session
        # Track live streams weakly so update_options can fan out without
        # keeping closed streams alive.
        self._streams = weakref.WeakSet[SpeechStream]()

    @property
    def model(self) -> str:
        # The endpoint URL does not expose a concrete model name.
        return "unknown"

    @property
    def provider(self) -> str:
        return "Baseten"

    @property
    def session(self) -> aiohttp.ClientSession:
        # Lazily fall back to the shared LiveKit HTTP session.
        if not self._session:
            self._session = utils.http_context.http_session()
        return self._session

    async def _recognize_impl(
        self,
        buffer: AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions,
    ) -> stt.SpeechEvent:
        # Batch recognition is intentionally unsupported; use stream() instead.
        raise NotImplementedError("Not implemented")

    def stream(
        self,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> SpeechStream:
        """Open a new streaming recognition session with a snapshot of the options."""
        # Copy the options so later update_options calls on self don't mutate
        # an in-flight stream's configuration underneath it.
        config = dataclasses.replace(self._opts)
        stream = SpeechStream(
            stt=self,
            conn_options=conn_options,
            opts=config,
            api_key=self._api_key,
            model_endpoint=self._model_endpoint,
            http_session=self.session,
        )
        self._streams.add(stream)
        return stream

    def update_options(
        self,
        *,
        vad_threshold: NotGivenOr[float] = NOT_GIVEN,
        vad_min_silence_duration_ms: NotGivenOr[int] = NOT_GIVEN,
        vad_speech_pad_ms: NotGivenOr[int] = NOT_GIVEN,
        language: NotGivenOr[str] = NOT_GIVEN,
        buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN,
    ) -> None:
        """Update defaults for future streams and propagate to all live streams."""
        if is_given(vad_threshold):
            self._opts.vad_threshold = vad_threshold
        if is_given(vad_min_silence_duration_ms):
            self._opts.vad_min_silence_duration_ms = vad_min_silence_duration_ms
        if is_given(vad_speech_pad_ms):
            self._opts.vad_speech_pad_ms = vad_speech_pad_ms
        if is_given(language):
            self._opts.language = LanguageCode(language)
        if is_given(buffer_size_seconds):
            self._opts.buffer_size_seconds = buffer_size_seconds

        for stream in self._streams:
            stream.update_options(
                vad_threshold=vad_threshold,
                vad_min_silence_duration_ms=vad_min_silence_duration_ms,
                vad_speech_pad_ms=vad_speech_pad_ms,
                language=language,
                buffer_size_seconds=buffer_size_seconds,
            )
Baseten Speech-to-Text provider.
Connects to a Baseten Whisper Streaming WebSocket model for real-time transcription. Works with both truss and chain deployments.
There are three ways to specify the endpoint (in priority order):
- ``model_endpoint`` – pass the full WebSocket URL directly.
- ``model_id`` – auto-constructs a **truss** endpoint URL: ``wss://model-{model_id}.api.baseten.co/environments/production/websocket``
- ``chain_id`` – auto-constructs a **chain** endpoint URL: ``wss://chain-{chain_id}.api.baseten.co/environments/production/websocket``
If none of the above are provided, the ``BASETEN_MODEL_ENDPOINT`` environment variable is used as a fallback.
Args
api_key – Baseten API key. Falls back to the ``BASETEN_API_KEY`` env var.
model_endpoint – Full WebSocket URL of the deployed model. Takes priority over ``model_id`` and ``chain_id``.
model_id – Baseten **truss** model ID. The plugin builds the endpoint URL automatically. Ignored when ``model_endpoint`` is given.
chain_id – Baseten **chain** ID. The plugin builds the endpoint URL automatically. Ignored when ``model_endpoint`` is given.
sample_rate – Audio sample rate in Hz (default ``16000``).
encoding – Audio encoding: ``pcm_s16le`` (default) or ``pcm_mulaw``.
buffer_size_seconds – Audio buffer size in seconds.
language – BCP-47 language code (default ``en``). Use ``auto`` for automatic language detection.
enable_partial_transcripts – Emit interim transcripts while the speaker is still talking. Defaults to ``True``.
partial_transcript_interval_s – Interval (seconds) between partial transcript updates.
final_transcript_max_duration_s – Maximum seconds of audio before the server forces a final transcript.
show_word_timestamps – Include word-level timestamps in results.
vad_threshold – Server-side VAD threshold (0.0–1.0).
vad_min_silence_duration_ms – Minimum silence (ms) to end an utterance.
vad_speech_pad_ms – Padding (ms) around detected speech.
http_session – Optional ``aiohttp.ClientSession`` to reuse.
Ancestors
- livekit.agents.stt.stt.STT
- abc.ABC
- EventEmitter
- typing.Generic
Instance variables
prop model : str-
Expand source code
@property def model(self) -> str: return "unknown"Get the model name/identifier for this STT instance.
Returns
The model name if available, "unknown" otherwise.
Note
Plugins should override this property to provide their model information.
prop provider : str-
Expand source code
@property def provider(self) -> str: return "Baseten"Get the provider name/identifier for this STT instance.
Returns
The provider name if available, "unknown" otherwise.
Note
Plugins should override this property to provide their provider information.
prop session : aiohttp.ClientSession-
Expand source code
@property def session(self) -> aiohttp.ClientSession: if not self._session: self._session = utils.http_context.http_session() return self._session
Methods
def stream(self,
*,
language: NotGivenOr[str] = NOT_GIVEN,
conn_options: APIConnectOptions = APIConnectOptions(max_retry=3, retry_interval=2.0, timeout=10.0)) ‑> livekit.plugins.baseten.stt.SpeechStream-
Expand source code
def stream( self, *, language: NotGivenOr[str] = NOT_GIVEN, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, ) -> SpeechStream: config = dataclasses.replace(self._opts) stream = SpeechStream( stt=self, conn_options=conn_options, opts=config, api_key=self._api_key, model_endpoint=self._model_endpoint, http_session=self.session, ) self._streams.add(stream) return stream def update_options(self,
*,
vad_threshold: NotGivenOr[float] = NOT_GIVEN,
vad_min_silence_duration_ms: NotGivenOr[int] = NOT_GIVEN,
vad_speech_pad_ms: NotGivenOr[int] = NOT_GIVEN,
language: NotGivenOr[str] = NOT_GIVEN,
buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN) ‑> None-
Expand source code
def update_options( self, *, vad_threshold: NotGivenOr[float] = NOT_GIVEN, vad_min_silence_duration_ms: NotGivenOr[int] = NOT_GIVEN, vad_speech_pad_ms: NotGivenOr[int] = NOT_GIVEN, language: NotGivenOr[str] = NOT_GIVEN, buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN, ) -> None: if is_given(vad_threshold): self._opts.vad_threshold = vad_threshold if is_given(vad_min_silence_duration_ms): self._opts.vad_min_silence_duration_ms = vad_min_silence_duration_ms if is_given(vad_speech_pad_ms): self._opts.vad_speech_pad_ms = vad_speech_pad_ms if is_given(language): self._opts.language = LanguageCode(language) if is_given(buffer_size_seconds): self._opts.buffer_size_seconds = buffer_size_seconds for stream in self._streams: stream.update_options( vad_threshold=vad_threshold, vad_min_silence_duration_ms=vad_min_silence_duration_ms, vad_speech_pad_ms=vad_speech_pad_ms, language=language, buffer_size_seconds=buffer_size_seconds, )
Inherited members
class SpeechStream (*,
stt: STT,
opts: STTOptions,
conn_options: APIConnectOptions,
api_key: str,
model_endpoint: str,
http_session: aiohttp.ClientSession)-
Expand source code
class SpeechStream(stt.SpeechStream):
    """A streaming speech-to-text session connected to Baseten via WebSocket.

    Audio pushed into the stream is forwarded to the model as raw int16 PCM;
    transcription messages from the server are converted into LiveKit
    ``SpeechEvent``s (interim and final).
    """

    # Sent to the server to request a graceful session shutdown.
    _CLOSE_MSG: str = json.dumps({"terminate_session": True})

    def __init__(
        self,
        *,
        stt: STT,
        opts: STTOptions,
        conn_options: APIConnectOptions,
        api_key: str,
        model_endpoint: str,
        http_session: aiohttp.ClientSession,
    ) -> None:
        super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)
        self._opts = opts
        self._api_key = api_key
        self._model_endpoint = model_endpoint
        self._session = http_session
        self._speech_duration: float = 0

        # keep a list of final transcripts to combine them inside the END_OF_SPEECH event
        self._final_events: list[SpeechEvent] = []
        self._reconnect_event = asyncio.Event()

    def update_options(
        self,
        *,
        vad_threshold: NotGivenOr[float] = NOT_GIVEN,
        vad_min_silence_duration_ms: NotGivenOr[int] = NOT_GIVEN,
        vad_speech_pad_ms: NotGivenOr[int] = NOT_GIVEN,
        language: NotGivenOr[str] = NOT_GIVEN,
        buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN,
    ) -> None:
        """Update options and trigger a reconnect so the new metadata is sent."""
        if is_given(vad_threshold):
            self._opts.vad_threshold = vad_threshold
        if is_given(vad_min_silence_duration_ms):
            self._opts.vad_min_silence_duration_ms = vad_min_silence_duration_ms
        if is_given(vad_speech_pad_ms):
            self._opts.vad_speech_pad_ms = vad_speech_pad_ms
        if is_given(language):
            self._opts.language = LanguageCode(language)
        if is_given(buffer_size_seconds):
            self._opts.buffer_size_seconds = buffer_size_seconds

        # Options only take effect via the metadata message of a fresh connection.
        self._reconnect_event.set()

    async def _run(self) -> None:
        """
        Run a single websocket connection to Baseten and make sure to reconnect
        when something went wrong.
        """
        closing_ws = False

        async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
            # Forward audio from the input channel to the server as int16 PCM.
            nonlocal closing_ws
            samples_per_buffer = 512
            audio_bstream = utils.audio.AudioByteStream(
                sample_rate=self._opts.sample_rate,
                num_channels=1,
                samples_per_channel=samples_per_buffer,
            )

            async for data in self._input_ch:
                if isinstance(data, self._FlushSentinel):
                    frames = audio_bstream.flush()
                else:
                    frames = audio_bstream.write(data.data.tobytes())

                for frame in frames:
                    # int16 PCM: each sample is 2 bytes, so the payload length
                    # must be even. (Fixed: the old message wrongly said float32.)
                    if len(frame.data) % 2 != 0:
                        logger.warning("Frame data size not aligned to int16 (multiple of 2)")
                    int16_array = np.frombuffer(frame.data, dtype=np.int16)
                    await ws.send_bytes(int16_array.tobytes())

            # Input exhausted: ask the server to terminate the session so
            # recv_task can exit cleanly instead of raising APIStatusError
            # when the socket closes. (Fix: _CLOSE_MSG was previously defined
            # but never sent, and closing_ws was never set.)
            closing_ws = True
            await ws.send_str(SpeechStream._CLOSE_MSG)

        async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
            # Receive transcription messages and emit SpeechEvents.
            nonlocal closing_ws
            while True:
                try:
                    # Periodic timeout lets us notice a pending shutdown even
                    # if the server goes quiet.
                    msg = await asyncio.wait_for(ws.receive(), timeout=5)
                except asyncio.TimeoutError:
                    if closing_ws:
                        break
                    continue

                if msg.type in (
                    aiohttp.WSMsgType.CLOSED,
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.CLOSING,
                ):
                    if closing_ws:
                        return
                    raise APIStatusError(
                        "Baseten connection closed unexpectedly",
                        status_code=ws.close_code or -1,
                        body=f"{msg.data=} {msg.extra=}",
                    )

                if msg.type != aiohttp.WSMsgType.TEXT:
                    logger.error("Unexpected Baseten message type: %s", msg.type)
                    continue

                try:
                    data = json.loads(msg.data)

                    # Skip non-transcription messages (e.g. error, status)
                    msg_type = data.get("type")
                    if msg_type and msg_type not in ("transcription",):
                        logger.debug("Ignoring message type: %s", msg_type)
                        continue

                    is_final = data.get("is_final", True)
                    segments = data.get("segments", [])

                    # Build transcript text: prefer top-level "transcript" if present,
                    # otherwise concatenate segment texts (Baseten standard format).
                    text = (
                        data.get("transcript")
                        or " ".join(seg.get("text", "") for seg in segments).strip()
                    )
                    confidence = data.get("confidence", 0.0)

                    # Build timed words – prefer word-level timestamps when available,
                    # fall back to segment-level timing.
                    timed_words: list[TimedString] = []
                    for segment in segments:
                        word_timestamps = segment.get("word_timestamps", [])
                        if word_timestamps:
                            for w in word_timestamps:
                                timed_words.append(
                                    TimedString(
                                        text=w.get("word", ""),
                                        start_time=(
                                            w.get("start_time", 0.0) + self.start_time_offset
                                        ),
                                        end_time=(w.get("end_time", 0.0) + self.start_time_offset),
                                        start_time_offset=self.start_time_offset,
                                    )
                                )
                        else:
                            timed_words.append(
                                TimedString(
                                    text=segment.get("text", ""),
                                    start_time=(
                                        segment.get("start_time", 0.0) + self.start_time_offset
                                    ),
                                    end_time=(
                                        segment.get("end_time", 0.0) + self.start_time_offset
                                    ),
                                    start_time_offset=self.start_time_offset,
                                )
                            )

                    start_time = (
                        segments[0].get("start_time", 0.0) if segments else 0.0
                    ) + self.start_time_offset
                    end_time = (
                        segments[-1].get("end_time", 0.0) if segments else 0.0
                    ) + self.start_time_offset

                    if not is_final:
                        if text:
                            event = stt.SpeechEvent(
                                type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
                                alternatives=[
                                    stt.SpeechData(
                                        language=LanguageCode(""),
                                        text=text,
                                        confidence=confidence,
                                        start_time=start_time,
                                        end_time=end_time,
                                        words=timed_words,
                                    )
                                ],
                            )
                            self._event_ch.send_nowait(event)
                    elif is_final:
                        language = LanguageCode(data.get("language_code", self._opts.language))
                        if text:
                            event = stt.SpeechEvent(
                                type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                                alternatives=[
                                    stt.SpeechData(
                                        language=language,
                                        text=text,
                                        confidence=confidence,
                                        start_time=start_time,
                                        end_time=end_time,
                                        words=timed_words,
                                    )
                                ],
                            )
                            self._final_events.append(event)
                            self._event_ch.send_nowait(event)
                except Exception:
                    # Never let a malformed server message kill the stream.
                    logger.exception("Failed to process message from Baseten")

        ws: aiohttp.ClientWebSocketResponse | None = None
        while True:
            try:
                ws = await self._connect_ws()
                tasks = [
                    asyncio.create_task(send_task(ws)),
                    asyncio.create_task(recv_task(ws)),
                ]
                wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
                try:
                    done, _ = await asyncio.wait(
                        (asyncio.gather(*tasks), wait_reconnect_task),
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                    # Re-raise any exception from the send/recv pair.
                    for task in done:
                        if task != wait_reconnect_task:
                            task.result()

                    # Normal completion (no reconnect requested) -> we're done.
                    if wait_reconnect_task not in done:
                        break

                    self._reconnect_event.clear()
                finally:
                    await utils.aio.gracefully_cancel(*tasks, wait_reconnect_task)
            finally:
                if ws is not None:
                    await ws.close()

    async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
        """Open a WebSocket and send the ``StreamingWhisperInput`` metadata message.

        The metadata schema must match the Baseten server's
        ``StreamingWhisperInput`` Pydantic model exactly (which uses
        ``extra="forbid"``). Field names are:

        - ``whisper_params`` – Whisper model parameters (language, word timestamps, …)
        - ``streaming_params`` – encoding, sample rate, partial transcript settings
        - ``streaming_vad_config`` – server-side Silero VAD configuration
        """
        headers = {
            "Authorization": f"Api-Key {self._api_key}",
        }
        ws = await self._session.ws_connect(self._model_endpoint, headers=headers, ssl=ssl_context)

        # Build metadata matching Baseten's StreamingWhisperInput schema.
        # See: https://docs.baseten.co/reference/inference-api/predict-endpoints/streaming-transcription-api
        metadata = {
            "whisper_params": {
                "audio_language": self._opts.language,
                "show_word_timestamps": self._opts.show_word_timestamps,
            },
            "streaming_params": {
                "encoding": self._opts.encoding,
                "sample_rate": self._opts.sample_rate,
                "enable_partial_transcripts": self._opts.enable_partial_transcripts,
                "partial_transcript_interval_s": self._opts.partial_transcript_interval_s,
                "final_transcript_max_duration_s": self._opts.final_transcript_max_duration_s,
            },
            "streaming_vad_config": {
                "threshold": self._opts.vad_threshold,
                "min_silence_duration_ms": self._opts.vad_min_silence_duration_ms,
                "speech_pad_ms": self._opts.vad_speech_pad_ms,
            },
        }
        await ws.send_str(json.dumps(metadata))
        return ws
Args: sample_rate : int or None, optional The desired sample rate for the audio input. If specified, the audio input will be automatically resampled to match the given sample rate before being processed for Speech-to-Text. If not provided (None), the input will retain its original sample rate.
Ancestors
- livekit.agents.stt.stt.RecognizeStream
- abc.ABC
Methods
def update_options(self,
*,
vad_threshold: NotGivenOr[float] = NOT_GIVEN,
vad_min_silence_duration_ms: NotGivenOr[int] = NOT_GIVEN,
vad_speech_pad_ms: NotGivenOr[int] = NOT_GIVEN,
language: NotGivenOr[str] = NOT_GIVEN,
buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN) ‑> None-
Expand source code
def update_options( self, *, vad_threshold: NotGivenOr[float] = NOT_GIVEN, vad_min_silence_duration_ms: NotGivenOr[int] = NOT_GIVEN, vad_speech_pad_ms: NotGivenOr[int] = NOT_GIVEN, language: NotGivenOr[str] = NOT_GIVEN, buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN, ) -> None: if is_given(vad_threshold): self._opts.vad_threshold = vad_threshold if is_given(vad_min_silence_duration_ms): self._opts.vad_min_silence_duration_ms = vad_min_silence_duration_ms if is_given(vad_speech_pad_ms): self._opts.vad_speech_pad_ms = vad_speech_pad_ms if is_given(language): self._opts.language = LanguageCode(language) if is_given(buffer_size_seconds): self._opts.buffer_size_seconds = buffer_size_seconds self._reconnect_event.set()
class SynthesizeStream (*,
tts: TTS,
conn_options: APIConnectOptions)-
Expand source code
class SynthesizeStream(tts.SynthesizeStream):
    """Incremental TTS session over a Baseten WebSocket endpoint.

    Text pushed into the stream is forwarded to the server as it arrives;
    binary audio frames received back are emitted as 24 kHz mono PCM.
    """

    def __init__(
        self,
        *,
        tts: TTS,
        conn_options: APIConnectOptions,
    ) -> None:
        super().__init__(tts=tts, conn_options=conn_options)
        self._tts: TTS = tts
        # Snapshot the TTS options so later update_options() calls on the
        # parent TTS don't affect this in-flight stream.
        self._opts = replace(tts._opts)

    async def _run(self, output_emitter: tts.AudioEmitter) -> None:
        request_id = utils.shortuuid()
        # Output format is fixed by the Baseten model: 24 kHz mono raw PCM.
        output_emitter.initialize(
            request_id=request_id,
            sample_rate=24000,
            num_channels=1,
            mime_type="audio/pcm",
            stream=True,
        )

        async def _send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
            # Forward text chunks to the server; flush markers carry no payload.
            async for data in self._input_ch:
                if isinstance(data, self._FlushSentinel):
                    continue
                self._mark_started()
                await ws.send_str(data)
            # Tell the server no more text is coming.
            await ws.send_str(_END_SENTINEL)

        async def _recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
            # Emit received binary audio until the server closes the socket.
            output_emitter.start_segment(segment_id=request_id)
            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.BINARY:
                    output_emitter.push(msg.data)
                elif msg.type in (
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.CLOSED,
                    aiohttp.WSMsgType.CLOSING,
                ):
                    break
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    raise APIConnectionError()
            output_emitter.end_input()

        try:
            async with self._tts._ensure_session().ws_connect(
                self._tts._model_endpoint,
                headers={"Authorization": f"Api-Key {self._tts._api_key}"},
                ssl=ssl_context,
            ) as ws:
                # First message configures the session before any text is sent.
                await ws.send_json(
                    {
                        "voice": self._opts.voice,
                        "max_tokens": self._opts.max_tokens,
                        "buffer_size": self._opts.buffer_size,
                    }
                )
                tasks = [
                    asyncio.create_task(_send_task(ws)),
                    asyncio.create_task(_recv_task(ws)),
                ]
                try:
                    await asyncio.gather(*tasks)
                finally:
                    await utils.aio.gracefully_cancel(*tasks)
        except asyncio.TimeoutError:
            raise APITimeoutError() from None
        except aiohttp.ClientResponseError as e:
            raise APIStatusError(
                message=e.message, status_code=e.status, request_id=None, body=None
            ) from None
        except (APIConnectionError, APIStatusError, APITimeoutError):
            # Already normalized API errors: propagate unchanged.
            raise
        except Exception as e:
            # Normalize everything else into an APIConnectionError.
            raise APIConnectionError() from e
Ancestors
- livekit.agents.tts.tts.SynthesizeStream
- abc.ABC
class TTS (*,
api_key: str | None = None,
model_endpoint: str | None = None,
voice: str = 'tara',
language: str = 'en',
temperature: float = 0.6,
max_tokens: int = 2000,
buffer_size: int = 10,
http_session: aiohttp.ClientSession | None = None)-
Expand source code
class TTS(tts.TTS):
    """Baseten text-to-speech plugin.

    A ``wss://`` endpoint enables incremental streaming synthesis; any other
    endpoint is used for one-shot (chunked) synthesis.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model_endpoint: str | None = None,
        voice: str = "tara",
        language: str = "en",
        temperature: float = 0.6,
        max_tokens: int = 2000,
        buffer_size: int = 10,
        http_session: aiohttp.ClientSession | None = None,
    ) -> None:
        """
        Initialize the Baseten TTS.

        Args:
            api_key: Baseten API key, or ``BASETEN_API_KEY`` env var.
            model_endpoint: Baseten model endpoint, or ``BASETEN_MODEL_ENDPOINT``
                env var. Pass a ``wss://`` URL for streaming or an ``https://``
                URL for non-streaming.
            voice: Speaker voice. Defaults to "tara".
            language: Language code. Defaults to "en".
            temperature: Sampling temperature. Defaults to 0.6.
            max_tokens: Maximum tokens for generation. Defaults to 2000.
            buffer_size: Number of words per chunk for streaming. Defaults to 10.
            http_session: Optional aiohttp session to reuse.
        """
        resolved_key = api_key or os.environ.get("BASETEN_API_KEY")
        if not resolved_key:
            raise ValueError(
                "Baseten API key is required. "
                "Pass one in via the `api_key` parameter, "
                "or set it as the `BASETEN_API_KEY` environment variable"
            )

        resolved_endpoint = model_endpoint or os.environ.get("BASETEN_MODEL_ENDPOINT")
        if not resolved_endpoint:
            raise ValueError(
                "model_endpoint is required. "
                "Provide it via the constructor or BASETEN_MODEL_ENDPOINT env var."
            )

        # Only websocket endpoints can do incremental streaming synthesis.
        supports_streaming = resolved_endpoint.startswith(("wss://", "ws://"))
        super().__init__(
            capabilities=tts.TTSCapabilities(streaming=supports_streaming),
            sample_rate=24000,
            num_channels=1,
        )

        self._api_key = resolved_key
        self._model_endpoint = resolved_endpoint
        self._opts = _TTSOptions(
            voice=voice,
            language=language,
            temperature=temperature,
            max_tokens=max_tokens,
            buffer_size=buffer_size,
        )
        self._session = http_session
        # Weak references: closed streams drop out automatically.
        self._streams = weakref.WeakSet[SynthesizeStream]()

    @property
    def model(self) -> str:
        # The endpoint configuration does not expose a concrete model name.
        return "unknown"

    @property
    def provider(self) -> str:
        return "Baseten"

    def _ensure_session(self) -> aiohttp.ClientSession:
        # Lazily fall back to the shared LiveKit session when none was supplied.
        if not self._session:
            self._session = utils.http_context.http_session()
        return self._session

    def update_options(
        self,
        *,
        voice: NotGivenOr[str] = NOT_GIVEN,
        language: NotGivenOr[str] = NOT_GIVEN,
        temperature: NotGivenOr[float] = NOT_GIVEN,
        max_tokens: NotGivenOr[int] = NOT_GIVEN,
        buffer_size: NotGivenOr[int] = NOT_GIVEN,
    ) -> None:
        """Apply only the explicitly given synthesis options."""
        updates = (
            ("voice", voice),
            ("language", language),
            ("temperature", temperature),
            ("max_tokens", max_tokens),
            ("buffer_size", buffer_size),
        )
        for attr, value in updates:
            if is_given(value):
                setattr(self._opts, attr, value)

    def synthesize(
        self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
    ) -> ChunkedStream:
        """Synthesize *text* in one shot and return the chunked audio stream."""
        return ChunkedStream(
            tts=self,
            input_text=text,
            conn_options=conn_options,
        )

    def stream(
        self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
    ) -> SynthesizeStream:
        """Open an incremental streaming synthesis session."""
        new_stream = SynthesizeStream(
            tts=self,
            conn_options=conn_options,
        )
        self._streams.add(new_stream)
        return new_stream

    async def aclose(self) -> None:
        """Close every tracked stream and forget them."""
        for active in list(self._streams):
            await active.aclose()
        self._streams.clear()
Initialize the Baseten TTS.
Args
api_key – Baseten API key, or ``BASETEN_API_KEY`` env var.
model_endpoint – Baseten model endpoint, or ``BASETEN_MODEL_ENDPOINT`` env var. Pass a ``wss://`` URL for streaming or an ``https://`` URL for non-streaming.
voice – Speaker voice. Defaults to "tara".
language – Language code. Defaults to "en".
temperature – Sampling temperature. Defaults to 0.6.
max_tokens – Maximum tokens for generation. Defaults to 2000.
buffer_size – Number of words per chunk for streaming. Defaults to 10.
http_session – Optional aiohttp session to reuse.
Ancestors
- livekit.agents.tts.tts.TTS
- abc.ABC
- EventEmitter
- typing.Generic
Instance variables
prop model : str-
Expand source code
@property def model(self) -> str: return "unknown"Get the model name/identifier for this TTS instance.
Returns
The model name if available, "unknown" otherwise.
Note
Plugins should override this property to provide their model information.
prop provider : str-
Expand source code
@property def provider(self) -> str: return "Baseten"Get the provider name/identifier for this TTS instance.
Returns
The provider name if available, "unknown" otherwise.
Note
Plugins should override this property to provide their provider information.
Methods
async def aclose(self) ‑> None-
Expand source code
async def aclose(self) -> None: for stream in list(self._streams): await stream.aclose() self._streams.clear() def stream(self,
*,
conn_options: APIConnectOptions = APIConnectOptions(max_retry=3, retry_interval=2.0, timeout=10.0)) ‑> livekit.plugins.baseten.tts.SynthesizeStream-
Expand source code
def stream( self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS ) -> SynthesizeStream: stream = SynthesizeStream( tts=self, conn_options=conn_options, ) self._streams.add(stream) return stream def synthesize(self,
text: str,
*,
conn_options: APIConnectOptions = APIConnectOptions(max_retry=3, retry_interval=2.0, timeout=10.0)) ‑> livekit.plugins.baseten.tts.ChunkedStream-
Expand source code
def synthesize( self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS ) -> ChunkedStream: return ChunkedStream( tts=self, input_text=text, conn_options=conn_options, ) def update_options(self,
*,
voice: NotGivenOr[str] = NOT_GIVEN,
language: NotGivenOr[str] = NOT_GIVEN,
temperature: NotGivenOr[float] = NOT_GIVEN,
max_tokens: NotGivenOr[int] = NOT_GIVEN,
buffer_size: NotGivenOr[int] = NOT_GIVEN) ‑> None-
Expand source code
def update_options( self, *, voice: NotGivenOr[str] = NOT_GIVEN, language: NotGivenOr[str] = NOT_GIVEN, temperature: NotGivenOr[float] = NOT_GIVEN, max_tokens: NotGivenOr[int] = NOT_GIVEN, buffer_size: NotGivenOr[int] = NOT_GIVEN, ) -> None: if is_given(voice): self._opts.voice = voice if is_given(language): self._opts.language = language if is_given(temperature): self._opts.temperature = temperature if is_given(max_tokens): self._opts.max_tokens = max_tokens if is_given(buffer_size): self._opts.buffer_size = buffer_size
Inherited members