Module livekit.plugins.gradium
Sub-modules
livekit.plugins.gradium.models
Classes
class STT (*,
api_key: str | None = None,
model_endpoint: str | None = 'wss://eu.api.gradium.ai/api/speech/asr',
model_name: str = 'default',
sample_rate: int = 24000,
encoding: NotGivenOr[STTEncoding] = NOT_GIVEN,
buffer_size_seconds: float = 0.08,
http_session: aiohttp.ClientSession | None = None,
vad_threshold: float = 0.9,
vad_bucket: int | None = 2,
vad_flush: bool = True,
temperature: float | None = None)-
Expand source code
class STT(stt.STT):
    """Streaming speech-to-text client for the Gradium ASR websocket endpoint.

    Recognition is streaming-only: ``stream()`` returns a ``SpeechStream``,
    whereas ``_recognize_impl`` always raises ``NotImplementedError``.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model_endpoint: str | None = "wss://eu.api.gradium.ai/api/speech/asr",
        model_name: str = "default",
        sample_rate: int = SUPPORTED_SAMPLE_RATE,
        encoding: NotGivenOr[STTEncoding] = NOT_GIVEN,
        buffer_size_seconds: float = 0.08,
        http_session: aiohttp.ClientSession | None = None,
        vad_threshold: float = 0.9,
        vad_bucket: int | None = 2,
        vad_flush: bool = True,
        temperature: float | None = None,
    ):
        """Validate the configuration and store it for later streams.

        Args:
            api_key: Gradium API key; falls back to the ``GRADIUM_API_KEY``
                environment variable when empty.
            model_endpoint: Websocket URL; falls back to the
                ``GRADIUM_MODEL_ENDPOINT`` environment variable when empty.
            model_name: Model name handed to every ``SpeechStream``.
            sample_rate: Must equal ``SUPPORTED_SAMPLE_RATE``.
            encoding: Optional encoding stored on the options when given.
            buffer_size_seconds: Stored on the options object.
            http_session: Optional aiohttp session to reuse.
            vad_threshold: Stored on the options object.
            vad_bucket: Stored on the options object.
            vad_flush: Stored on the options object.
            temperature: Stored on the options object.

        Raises:
            ValueError: If the sample rate is unsupported, or if no API key
                or no model endpoint can be resolved.
        """
        # The stream implementation emits interim transcripts as well as
        # final ones, so both capabilities are advertised.
        capabilities = stt.STTCapabilities(streaming=True, interim_results=True)
        super().__init__(capabilities=capabilities)

        if sample_rate != SUPPORTED_SAMPLE_RATE:
            raise ValueError(f"Only {SUPPORTED_SAMPLE_RATE}Hz sample rate is supported")

        resolved_key = api_key or os.environ.get("GRADIUM_API_KEY")
        if not resolved_key:
            raise ValueError(
                "Gradium API key is required. "
                "Pass one in via the `api_key` parameter, "
                "or set it as the `GRADIUM_API_KEY` environment variable"
            )
        self._api_key = resolved_key

        resolved_endpoint = model_endpoint or os.environ.get("GRADIUM_MODEL_ENDPOINT")
        if not resolved_endpoint:
            raise ValueError(
                "The model endpoint is required, you can find it in the Gradium dashboard"
            )
        self._model_endpoint = resolved_endpoint
        self._model_name = model_name

        options = STTOptions(
            sample_rate=sample_rate,
            buffer_size_seconds=buffer_size_seconds,
            vad_threshold=vad_threshold,
            vad_bucket=vad_bucket,
            vad_flush=vad_flush,
            temperature=temperature,
        )
        if is_given(encoding):
            options.encoding = encoding
        self._opts = options

        self._session = http_session
        # Weak references: finished streams drop out of the set on their own.
        self._streams = weakref.WeakSet[SpeechStream]()

    @property
    def model(self) -> str:
        """Model identifier reported for this instance (always "unknown")."""
        return "unknown"

    @property
    def provider(self) -> str:
        """Provider identifier reported for this instance."""
        return "Gradium"

    @property
    def session(self) -> aiohttp.ClientSession:
        """Return the aiohttp session, fetching the shared one on first use."""
        current = self._session
        if current:
            return current
        self._session = utils.http_context.http_session()
        return self._session

    async def _recognize_impl(
        self,
        buffer: AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions,
    ) -> stt.SpeechEvent:
        """Non-streaming recognition is not available; always raises."""
        raise NotImplementedError("Not implemented")

    def stream(
        self,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> SpeechStream:
        """Create and register a new ``SpeechStream``.

        The stream receives its own copy of the current options.
        ``language`` is accepted but not used by this method.
        """
        new_stream = SpeechStream(
            stt=self,
            conn_options=conn_options,
            opts=dataclasses.replace(self._opts),
            api_key=self._api_key,
            model_endpoint=self._model_endpoint,
            model_name=self._model_name,
            http_session=self.session,
        )
        self._streams.add(new_stream)
        return new_stream

    def update_options(
        self,
        *,
        buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN,
    ) -> None:
        """Update the stored options and forward the call to live streams."""
        if is_given(buffer_size_seconds):
            self._opts.buffer_size_seconds = buffer_size_seconds

        # Every tracked stream is notified, whether or not a value was given.
        for live_stream in self._streams:
            live_stream.update_options(buffer_size_seconds=buffer_size_seconds)
Ancestors
- livekit.agents.stt.stt.STT
- abc.ABC
- EventEmitter
- typing.Generic
Instance variables
prop model : str-
Expand source code
@property def model(self) -> str: return "unknown"
Get the model name/identifier for this STT instance.
Returns
The model name if available, "unknown" otherwise.
Note
Plugins should override this property to provide their model information.
prop provider : str-
Expand source code
@property def provider(self) -> str: return "Gradium"
Get the provider name/identifier for this STT instance.
Returns
The provider name if available, "unknown" otherwise.
Note
Plugins should override this property to provide their provider information.
prop session : aiohttp.ClientSession-
Expand source code
@property
def session(self) -> aiohttp.ClientSession:
    """Return the aiohttp session, fetching the shared one on first use."""
    current = self._session
    if current:
        return current
    # Nothing usable stored yet: fetch a session from the shared http
    # context and remember it for subsequent calls.
    self._session = utils.http_context.http_session()
    return self._session
Methods
def stream(self,
*,
language: NotGivenOr[str] = NOT_GIVEN,
conn_options: APIConnectOptions = APIConnectOptions(max_retry=3, retry_interval=2.0, timeout=10.0)) ‑> livekit.plugins.gradium.stt.SpeechStream-
Expand source code
def stream( self, *, language: NotGivenOr[str] = NOT_GIVEN, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, ) -> SpeechStream: config = dataclasses.replace(self._opts) stream = SpeechStream( stt=self, conn_options=conn_options, opts=config, api_key=self._api_key, model_endpoint=self._model_endpoint, model_name=self._model_name, http_session=self.session, ) self._streams.add(stream) return stream
def update_options(self, *, buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN) ‑> None-
Expand source code
def update_options(
    self,
    *,
    buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN,
) -> None:
    """Update the stored options and forward the call to live streams.

    Args:
        buffer_size_seconds: New buffer size; the stored options are left
            untouched when it is not given.
    """
    if is_given(buffer_size_seconds):
        self._opts.buffer_size_seconds = buffer_size_seconds

    # Every tracked stream is notified, whether or not a value was given.
    for live_stream in self._streams:
        live_stream.update_options(buffer_size_seconds=buffer_size_seconds)
Inherited members
class SpeechStream (*,
stt: STT,
opts: STTOptions,
conn_options: APIConnectOptions,
api_key: str,
model_endpoint: str,
model_name: str,
http_session: aiohttp.ClientSession)-
Expand source code
class SpeechStream(stt.SpeechStream):
    """Streaming recognition session over a single Gradium ASR websocket.

    Audio pushed into the stream is re-chunked, base64-encoded and sent as
    JSON ``audio`` messages. Incoming ``text`` messages become interim
    transcripts; incoming ``step`` messages drive a small VAD state machine
    that decides when to emit the final transcript and end of speech.
    """

    # Used to close websocket
    _CLOSE_MSG: str = json.dumps({"terminate_session": True})

    def __init__(
        self,
        *,
        stt: STT,
        opts: STTOptions,
        conn_options: APIConnectOptions,
        api_key: str,
        model_endpoint: str,
        model_name: str,
        http_session: aiohttp.ClientSession,
    ) -> None:
        """Store the connection parameters; no network I/O happens here.

        Args:
            stt: Owning STT instance.
            opts: Options for this stream.
            conn_options: Connection options forwarded to the base class.
            api_key: Value sent in the ``x-api-key`` header.
            model_endpoint: Websocket URL to connect to.
            model_name: Model name sent in the setup message.
            http_session: aiohttp session used to open the websocket.
        """
        super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)
        self._opts = opts
        self._api_key = api_key
        self._model_endpoint = model_endpoint
        self._model_name = model_name
        self._session = http_session
        self._speech_duration: float = 0

        # Set by update_options() to make _run() reconnect.
        self._reconnect_event = asyncio.Event()
        # Last "ready" message received from the server, if any.
        self._ready_msg: dict[str, Any] | None = None

    @property
    def delay_in_tokens(self) -> int:
        """``delay_in_tokens`` from the "ready" message, defaulting to 6."""
        if self._ready_msg is not None:
            return int(self._ready_msg.get("delay_in_tokens", 6))
        return 6

    @property
    def frame_size(self) -> int:
        """``frame_size`` from the "ready" message, defaulting to 1920."""
        if self._ready_msg is not None:
            return int(self._ready_msg.get("frame_size", 1920))
        return 1920

    def update_options(
        self,
        *,
        buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN,
    ) -> None:
        """Apply new options and request a reconnect of the running stream."""
        if is_given(buffer_size_seconds):
            self._opts.buffer_size_seconds = buffer_size_seconds
        self._reconnect_event.set()

    async def _run(self) -> None:
        """
        Run a single websocket connection to Gradium and make sure to reconnect
        when something went wrong.
        """
        # NOTE(review): nothing in this method ever sets closing_ws to True,
        # so the "closing" branches in recv_task are currently unreachable.
        # Confirm whether send_task is meant to flip it once input ends.
        closing_ws = False

        async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
            """Forward buffered input audio to the server as JSON messages."""
            samples_per_buffer = 1920
            audio_bstream = utils.audio.AudioByteStream(
                sample_rate=self._opts.sample_rate,
                num_channels=1,
                samples_per_channel=samples_per_buffer,
            )

            async for data in self._input_ch:
                if isinstance(data, self._FlushSentinel):
                    frames = audio_bstream.flush()
                else:
                    frames = audio_bstream.write(data.data.tobytes())

                for frame in frames:
                    if len(frame.data) % 2 != 0:
                        logger.warning("Frame data size not aligned to int16 (multiple of 2)")
                    audio_data = base64.b64encode(frame.data.tobytes()).decode("utf-8")
                    audio_msg = {
                        "type": "audio",
                        "audio": audio_data,
                    }
                    await ws.send_str(json.dumps(audio_msg))

        async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
            """Translate server messages into speech events."""
            nonlocal closing_ws
            buffered_text: list[str] = []
            speaking = False
            # None means "no end-of-speech countdown in progress". This has
            # to start as None (not False): the branch below tests
            # `is None`, and a False start made the very first positive VAD
            # step decrement to -1 and finalize the utterance immediately.
            remaining_vad_steps: int | None = None

            while True:
                try:
                    msg = await asyncio.wait_for(ws.receive(), timeout=5)
                except asyncio.TimeoutError:
                    if closing_ws:
                        break
                    continue

                if msg.type in (
                    aiohttp.WSMsgType.CLOSED,
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.CLOSING,
                ):
                    if closing_ws:
                        return
                    raise APIStatusError("Gradium connection closed unexpectedly")

                if msg.type != aiohttp.WSMsgType.TEXT:
                    logger.error("Unexpected Gradium message type: %s", msg.type)
                    continue

                try:
                    data = json.loads(msg.data)
                    type_ = data.get("type", "")
                    if type_ == "text":
                        if speaking is False:
                            speaking = True
                            start_event = stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
                            self._event_ch.send_nowait(start_event)
                        buffered_text.append(data["text"])
                        event = stt.SpeechEvent(
                            type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
                            alternatives=[
                                stt.SpeechData(
                                    text=data["text"],
                                    language=self._opts.language,
                                    start_time=data["start_s"],
                                )
                            ],
                        )
                        self._event_ch.send_nowait(event)
                    elif type_ == "step":
                        if not speaking:
                            continue
                        vad_bucket = self._opts.vad_bucket
                        # Compare against None explicitly: bucket index 0 is
                        # a valid choice and must not be treated as "off".
                        if vad_bucket is not None:
                            positive_vad = (
                                data["vad"][vad_bucket]["inactivity_prob"]
                                > self._opts.vad_threshold
                            )
                            if positive_vad:
                                if remaining_vad_steps is None:
                                    # Start the countdown and, if enabled,
                                    # queue silence on the input channel.
                                    remaining_vad_steps = self.delay_in_tokens
                                    if self._opts.vad_flush:
                                        samples_per_channel = (
                                            self.frame_size * self.delay_in_tokens
                                        )
                                        zeros = AudioFrame.create(
                                            sample_rate=self._opts.sample_rate,
                                            num_channels=1,
                                            samples_per_channel=samples_per_channel,
                                        )
                                        await self._input_ch.send(zeros)
                                else:
                                    remaining_vad_steps -= 1
                                    if remaining_vad_steps <= 0:
                                        speaking = False
                                        remaining_vad_steps = None
                                        event = stt.SpeechEvent(
                                            type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                                            alternatives=[
                                                stt.SpeechData(
                                                    text=" ".join(buffered_text),
                                                    language=self._opts.language,
                                                )
                                            ],
                                        )
                                        self._event_ch.send_nowait(event)
                                        buffered_text = []
                                        self._event_ch.send_nowait(
                                            stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
                                        )
                            else:
                                # Activity resumed: cancel any countdown.
                                remaining_vad_steps = None
                    elif type_ == "ready":
                        self._ready_msg = data
                    elif type_ == "end_text":
                        # This message provides the end timestamp of the previous word
                        # in the stop_s field.
                        pass
                    else:
                        logger.warning(f"Unknown message type from Gradium {type_}")
                except Exception:
                    logger.exception("Failed to process message from Gradium")

        ws: aiohttp.ClientWebSocketResponse | None = None

        while True:
            try:
                ws = await self._connect_ws()
                tasks = [
                    asyncio.create_task(send_task(ws)),
                    asyncio.create_task(recv_task(ws)),
                ]
                wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
                try:
                    done, _ = await asyncio.wait(
                        (asyncio.gather(*tasks), wait_reconnect_task),
                        return_when=asyncio.FIRST_COMPLETED,
                    )

                    # Surface any exception raised by the send/recv pair.
                    for task in done:
                        if task != wait_reconnect_task:
                            task.result()

                    # Leave the loop unless a reconnect was requested.
                    if wait_reconnect_task not in done:
                        break

                    self._reconnect_event.clear()
                finally:
                    await utils.aio.gracefully_cancel(*tasks, wait_reconnect_task)
            finally:
                if ws is not None:
                    await ws.close()

    async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
        """Open the websocket and send the initial ``setup`` message."""
        headers = {"x-api-key": self._api_key, "x-api-source": "livekit"}
        ws = await self._session.ws_connect(self._model_endpoint, headers=headers)

        # Build and send the setup payload as the first message
        setup_msg: dict[str, Any] = {
            "type": "setup",
            "model_name": self._model_name,
            "input_format": "pcm",
        }
        if self._opts.temperature is not None:
            setup_msg["json_config"] = {"temp": self._opts.temperature}
        await ws.send_str(json.dumps(setup_msg))
        return ws
Args: sample_rate : int or None, optional The desired sample rate for the audio input. If specified, the audio input will be automatically resampled to match the given sample rate before being processed for Speech-to-Text. If not provided (None), the input will retain its original sample rate.
Ancestors
- livekit.agents.stt.stt.RecognizeStream
- abc.ABC
Instance variables
prop delay_in_tokens : int-
Expand source code
@property def delay_in_tokens(self) -> int: if self._ready_msg is not None: return int(self._ready_msg.get("delay_in_tokens", 6)) return 6
prop frame_size : int-
Expand source code
@property
def frame_size(self) -> int:
    """``frame_size`` from the stored "ready" message, defaulting to 1920."""
    ready = self._ready_msg
    return 1920 if ready is None else int(ready.get("frame_size", 1920))
Methods
def update_options(self, *, buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN) ‑> None-
Expand source code
def update_options(
    self,
    *,
    buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN,
) -> None:
    """Apply new options and signal a reconnect.

    Args:
        buffer_size_seconds: New buffer size; the stored options are left
            untouched when it is not given.
    """
    if is_given(buffer_size_seconds):
        self._opts.buffer_size_seconds = buffer_size_seconds

    # The reconnect event is set unconditionally, even when no value was
    # supplied.
    self._reconnect_event.set()
class TTS (*,
api_key: str | None = None,
model_endpoint: str | None = 'wss://eu.api.gradium.ai/api/speech/tts',
model_name: str = 'default',
voice: str | None = None,
voice_id: str | None = 'YTpq7expH9539ERJ',
json_config: dict[str, Any] | None = None,
http_session: aiohttp.ClientSession | None = None,
word_tokenizer: tokenize.WordTokenizer | None = None)-
Expand source code
class TTS(tts.TTS):
    """Text-to-speech client for the Gradium TTS websocket endpoint."""

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model_endpoint: str | None = "wss://eu.api.gradium.ai/api/speech/tts",
        model_name: str = "default",
        voice: str | None = None,
        voice_id: str | None = "YTpq7expH9539ERJ",
        json_config: dict[str, Any] | None = None,
        http_session: aiohttp.ClientSession | None = None,
        word_tokenizer: tokenize.WordTokenizer | None = None,
    ) -> None:
        """
        Initialize the Gradium TTS.

        Args:
            api_key (str): Gradium API key, or `GRADIUM_API_KEY` env var.
            model_endpoint (str): Gradium model endpoint, or `GRADIUM_MODEL_ENDPOINT` env var.
            model_name (str): Model name.
            voice (str): Speaker voice.
            voice_id (str): Speaker voice ID.
            json_config (dict): Extra configuration stored on the options.
            http_session (aiohttp.ClientSession): Session to reuse; one is
                fetched lazily when omitted.
            word_tokenizer (tokenize.WordTokenizer): Tokenizer for processing text.
                Defaults to basic WordTokenizer.

        Raises:
            ValueError: If no API key or no model endpoint can be resolved.
        """
        capabilities = tts.TTSCapabilities(streaming=True)
        super().__init__(
            capabilities=capabilities,
            sample_rate=SUPPORTED_SAMPLE_RATE,
            num_channels=1,
        )

        resolved_key = api_key or os.environ.get("GRADIUM_API_KEY")
        if not resolved_key:
            raise ValueError(
                "Gradium API key is required. "
                "Pass one in via the `api_key` parameter, "
                "or set it as the `GRADIUM_API_KEY` environment variable"
            )

        resolved_endpoint = model_endpoint or os.environ.get("GRADIUM_MODEL_ENDPOINT")
        if not resolved_endpoint:
            raise ValueError(
                "The model endpoint is required, you can find it in the Gradium dashboard"
            )

        self._api_key = resolved_key
        self._model_endpoint = resolved_endpoint
        self._model_name = model_name

        # Fall back to the basic tokenizer that keeps punctuation.
        tokenizer = word_tokenizer or tokenize.basic.WordTokenizer(ignore_punctuation=False)
        self._opts = _TTSOptions(
            voice=voice,
            voice_id=voice_id,
            word_tokenizer=tokenizer,
            json_config=json_config,
        )
        self._session = http_session

    @property
    def model(self) -> str:
        """Model identifier reported for this instance (always "unknown")."""
        return "unknown"

    @property
    def provider(self) -> str:
        """Provider identifier reported for this instance."""
        return "Gradium"

    async def _connect_ws(self, timeout: float) -> aiohttp.ClientWebSocketResponse:
        """Open a websocket to the model endpoint, bounded by ``timeout``."""
        auth_headers = {"x-api-key": self._api_key, "x-api-source": "livekit"}
        pending = self._ensure_session().ws_connect(
            self._model_endpoint,
            headers=auth_headers,
        )
        return await asyncio.wait_for(pending, timeout)

    async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse) -> None:
        """Close the given websocket."""
        await ws.close()

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the aiohttp session, fetching the shared one on first use."""
        current = self._session
        if current:
            return current
        self._session = utils.http_context.http_session()
        return self._session

    def update_options(
        self,
        *,
        voice: NotGivenOr[str] = NOT_GIVEN,
        json_config: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
    ) -> None:
        """Overwrite the stored voice and/or JSON config when given."""
        if is_given(voice):
            self._opts.voice = voice
        if is_given(json_config):
            self._opts.json_config = json_config

    def stream(
        self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
    ) -> SynthesizeStream:
        """Create a ``SynthesizeStream`` bound to this instance."""
        return SynthesizeStream(tts=self, conn_options=conn_options)

    def synthesize(
        self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
    ) -> ChunkedStream:
        """Create a ``ChunkedStream`` that synthesizes ``text``."""
        return ChunkedStream(
            tts=self,
            api_key=self._api_key,
            input_text=text,
            model_endpoint=self._model_endpoint,
            model_name=self._model_name,
            conn_options=conn_options,
        )
Initialize the Gradium TTS.
Args
api_key:str- Gradium API key, or GRADIUM_API_KEY env var.
model_endpoint:str- Gradium model endpoint, or GRADIUM_MODEL_ENDPOINT env var.
model_name:str- Model name.
voice:str- Speaker voice.
voice_id:str- Speaker voice ID.
word_tokenizer:tokenize.WordTokenizer- Tokenizer for processing text. Defaults to basic WordTokenizer.
Ancestors
- livekit.agents.tts.tts.TTS
- abc.ABC
- EventEmitter
- typing.Generic
Instance variables
prop model : str-
Expand source code
@property def model(self) -> str: return "unknown"
Get the model name/identifier for this TTS instance.
Returns
The model name if available, "unknown" otherwise.
Note
Plugins should override this property to provide their model information.
prop provider : str-
Expand source code
@property def provider(self) -> str: return "Gradium"
Get the provider name/identifier for this TTS instance.
Returns
The provider name if available, "unknown" otherwise.
Note
Plugins should override this property to provide their provider information.
Methods
def stream(self,
*,
conn_options: APIConnectOptions = APIConnectOptions(max_retry=3, retry_interval=2.0, timeout=10.0)) ‑> livekit.plugins.gradium.tts.SynthesizeStream-
Expand source code
def stream( self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS ) -> SynthesizeStream: return SynthesizeStream(tts=self, conn_options=conn_options)
def synthesize(self,
text: str,
*,
conn_options: APIConnectOptions = APIConnectOptions(max_retry=3, retry_interval=2.0, timeout=10.0)) ‑> livekit.plugins.gradium.tts.ChunkedStream-
Expand source code
def synthesize( self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS ) -> ChunkedStream: return ChunkedStream( tts=self, api_key=self._api_key, input_text=text, model_endpoint=self._model_endpoint, model_name=self._model_name, conn_options=conn_options, )
def update_options(self,
*,
voice: NotGivenOr[str] = NOT_GIVEN,
json_config: NotGivenOr[dict[str, Any]] = NOT_GIVEN) ‑> None-
Expand source code
def update_options(
    self,
    *,
    voice: NotGivenOr[str] = NOT_GIVEN,
    json_config: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
) -> None:
    """Overwrite selected TTS options in place.

    Args:
        voice: New speaker voice; ignored when not given.
        json_config: New JSON configuration; ignored when not given.
    """
    options = self._opts
    if is_given(voice):
        options.voice = voice
    if is_given(json_config):
        options.json_config = json_config
Inherited members