Module livekit.plugins.gradium

Sub-modules

livekit.plugins.gradium.models

Classes

class STT (*,
api_key: str | None = None,
model_endpoint: str | None = 'wss://eu.api.gradium.ai/api/speech/asr',
model_name: str = 'default',
sample_rate: int = 24000,
encoding: NotGivenOr[STTEncoding] = NOT_GIVEN,
buffer_size_seconds: float = 0.08,
http_session: aiohttp.ClientSession | None = None,
vad_threshold: float = 0.9,
vad_bucket: int | None = 2,
vad_flush: bool = True,
temperature: float | None = None)
Expand source code
class STT(stt.STT):
    """Gradium speech-to-text for LiveKit Agents.

    Only streaming recognition is implemented: :meth:`stream` opens a websocket
    session against the configured model endpoint, while the non-streaming
    :meth:`_recognize_impl` raises ``NotImplementedError``.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model_endpoint: str | None = "wss://eu.api.gradium.ai/api/speech/asr",
        model_name: str = "default",
        sample_rate: int = SUPPORTED_SAMPLE_RATE,
        encoding: NotGivenOr[STTEncoding] = NOT_GIVEN,
        buffer_size_seconds: float = 0.08,
        http_session: aiohttp.ClientSession | None = None,
        vad_threshold: float = 0.9,
        vad_bucket: int | None = 2,
        vad_flush: bool = True,
        temperature: float | None = None,
    ):
        """Create a Gradium STT instance.

        Args:
            api_key: Gradium API key; falls back to the ``GRADIUM_API_KEY`` env var.
            model_endpoint: Websocket URL of the ASR service; an empty value falls
                back to the ``GRADIUM_MODEL_ENDPOINT`` env var.
            model_name: Model name sent to the service when a stream connects.
            sample_rate: Input sample rate; only ``SUPPORTED_SAMPLE_RATE`` is accepted.
            encoding: Optional audio encoding stored on the options when given.
            buffer_size_seconds: Buffer size option copied into every stream;
                changing it later via :meth:`update_options` reconnects streams.
            http_session: aiohttp session to reuse; obtained lazily from the
                HTTP context when omitted.
            vad_threshold: Inactivity probability above which a VAD step from
                the server counts as silence.
            vad_bucket: Index into the per-step ``vad`` list used for
                endpointing; ``None`` disables VAD-based finalization.
            vad_flush: When True, silence is pushed into the stream after VAD
                first reports inactivity.
            temperature: Optional decoding temperature sent in the setup message.

        Raises:
            ValueError: If the sample rate is unsupported, or if no API key or
                model endpoint can be resolved.
        """
        super().__init__(
            capabilities=stt.STTCapabilities(
                streaming=True,
                # interim transcripts are emitted as text arrives; finals follow VAD
                interim_results=True,
            ),
        )

        resolved_key = api_key or os.environ.get("GRADIUM_API_KEY")

        if sample_rate != SUPPORTED_SAMPLE_RATE:
            raise ValueError(f"Only {SUPPORTED_SAMPLE_RATE}Hz sample rate is supported")
        if not resolved_key:
            raise ValueError(
                "Gradium API key is required. "
                "Pass one in via the `api_key` parameter, "
                "or set it as the `GRADIUM_API_KEY` environment variable"
            )
        self._api_key = resolved_key

        endpoint = model_endpoint or os.environ.get("GRADIUM_MODEL_ENDPOINT")
        if not endpoint:
            raise ValueError(
                "The model endpoint is required, you can find it in the Gradium dashboard"
            )
        self._model_endpoint = endpoint
        self._model_name = model_name

        self._opts = STTOptions(
            sample_rate=sample_rate,
            buffer_size_seconds=buffer_size_seconds,
            vad_threshold=vad_threshold,
            vad_bucket=vad_bucket,
            vad_flush=vad_flush,
            temperature=temperature,
        )
        if is_given(encoding):
            self._opts.encoding = encoding

        self._session = http_session
        # Weak references so that finished streams can be garbage-collected.
        self._streams = weakref.WeakSet[SpeechStream]()

    @property
    def model(self) -> str:
        """Model identifier reported by this plugin (always ``"unknown"``)."""
        reported = "unknown"
        return reported

    @property
    def provider(self) -> str:
        """Provider name reported by this plugin."""
        provider_name = "Gradium"
        return provider_name

    @property
    def session(self) -> aiohttp.ClientSession:
        """The aiohttp session, taken from the HTTP context on first access."""
        current = self._session
        if not current:
            current = self._session = utils.http_context.http_session()
        return current

    async def _recognize_impl(
        self,
        buffer: AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions,
    ) -> stt.SpeechEvent:
        """Non-streaming recognition is not supported by this plugin."""
        raise NotImplementedError("Not implemented")

    def stream(
        self,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> SpeechStream:
        """Open a streaming session with a private copy of the current options.

        ``language`` is accepted for interface compatibility but is not used
        here. The new stream is tracked so :meth:`update_options` reaches it.
        """
        stream_opts = dataclasses.replace(self._opts)
        new_stream = SpeechStream(
            stt=self,
            opts=stream_opts,
            conn_options=conn_options,
            api_key=self._api_key,
            model_endpoint=self._model_endpoint,
            model_name=self._model_name,
            http_session=self.session,
        )
        self._streams.add(new_stream)
        return new_stream

    def update_options(
        self,
        *,
        buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN,
    ) -> None:
        """Update the buffer size and forward the call to every live stream.

        Each tracked stream receives the argument even when it is not given.
        """
        if is_given(buffer_size_seconds):
            self._opts.buffer_size_seconds = buffer_size_seconds

        for active_stream in self._streams:
            active_stream.update_options(buffer_size_seconds=buffer_size_seconds)

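Gradium speech-to-text (STT) plugin. Only streaming recognition is supported: `stream()` opens a websocket session to the Gradium ASR endpoint, while non-streaming recognition raises `NotImplementedError`.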

Ancestors

  • livekit.agents.stt.stt.STT
  • abc.ABC
  • EventEmitter
  • typing.Generic

Instance variables

prop model : str
Expand source code
@property
def model(self) -> str:
    """Model identifier reported by this plugin (always ``"unknown"``)."""
    reported = "unknown"
    return reported

Get the model name/identifier for this STT instance.

Returns

The model name if available, "unknown" otherwise.

Note

Plugins should override this property to provide their model information.

prop provider : str
Expand source code
@property
def provider(self) -> str:
    """Provider name reported by this plugin."""
    provider_name = "Gradium"
    return provider_name

Get the provider name/identifier for this STT instance.

Returns

The provider name if available, "unknown" otherwise.

Note

Plugins should override this property to provide their provider information.

prop session : aiohttp.ClientSession
Expand source code
@property
def session(self) -> aiohttp.ClientSession:
    """The aiohttp session, taken from the HTTP context on first access."""
    current = self._session
    if not current:
        current = self._session = utils.http_context.http_session()
    return current

Methods

def stream(self,
*,
language: NotGivenOr[str] = NOT_GIVEN,
conn_options: APIConnectOptions = APIConnectOptions(max_retry=3, retry_interval=2.0, timeout=10.0)) ‑> livekit.plugins.gradium.stt.SpeechStream
Expand source code
def stream(
    self,
    *,
    language: NotGivenOr[str] = NOT_GIVEN,
    conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> SpeechStream:
    """Open a streaming session with a private copy of the current options.

    ``language`` is accepted but not used. The new stream is tracked so that
    later ``update_options`` calls reach it.
    """
    stream_opts = dataclasses.replace(self._opts)
    new_stream = SpeechStream(
        stt=self,
        opts=stream_opts,
        conn_options=conn_options,
        api_key=self._api_key,
        model_endpoint=self._model_endpoint,
        model_name=self._model_name,
        http_session=self.session,
    )
    self._streams.add(new_stream)
    return new_stream
def update_options(self, *, buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN) ‑> None
Expand source code
def update_options(
    self,
    *,
    buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN,
) -> None:
    """Update the buffer size and forward the call to every live stream.

    Each tracked stream receives the argument even when it is not given.
    """
    if is_given(buffer_size_seconds):
        self._opts.buffer_size_seconds = buffer_size_seconds

    for active_stream in self._streams:
        active_stream.update_options(buffer_size_seconds=buffer_size_seconds)

Inherited members

class SpeechStream (*,
stt: STT,
opts: STTOptions,
conn_options: APIConnectOptions,
api_key: str,
model_endpoint: str,
model_name: str,
http_session: aiohttp.ClientSession)
Expand source code
class SpeechStream(stt.SpeechStream):
    """One streaming recognition session against the Gradium ASR websocket.

    Audio pushed into the stream is re-chunked and sent as base64 ``audio``
    messages. ``text`` messages from the server become interim transcripts.
    Once the server-side VAD (carried by ``step`` messages) has reported
    inactivity for ``delay_in_tokens`` further steps, the buffered text is
    emitted as a final transcript followed by END_OF_SPEECH. Calling
    :meth:`update_options` makes the session reconnect.
    """

    # Used to close websocket.
    # NOTE(review): nothing in this class currently sends this payload; confirm
    # the Gradium end-of-stream message before wiring it into send_task.
    _CLOSE_MSG: str = json.dumps({"terminate_session": True})

    def __init__(
        self,
        *,
        stt: STT,
        opts: STTOptions,
        conn_options: APIConnectOptions,
        api_key: str,
        model_endpoint: str,
        model_name: str,
        http_session: aiohttp.ClientSession,
    ) -> None:
        super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)

        self._opts = opts
        self._api_key = api_key
        self._model_endpoint = model_endpoint
        self._model_name = model_name
        self._session = http_session
        self._speech_duration: float = 0

        # Set by update_options() so that _run() reconnects with the new options.
        self._reconnect_event = asyncio.Event()
        # Payload of the server's "ready" message, once one has been received.
        self._ready_msg: dict[str, Any] | None = None

    @property
    def delay_in_tokens(self) -> int:
        """Number of VAD steps to wait before finalizing an utterance.

        Taken from the server's ``ready`` message; 6 until it arrives or when
        the field is absent.
        """
        if self._ready_msg is not None:
            return int(self._ready_msg.get("delay_in_tokens", 6))
        return 6

    @property
    def frame_size(self) -> int:
        """Samples per model frame from the ``ready`` message (default 1920).

        Used to size the silence pushed when ``vad_flush`` is enabled.
        """
        if self._ready_msg is not None:
            return int(self._ready_msg.get("frame_size", 1920))
        return 1920

    def update_options(
        self,
        *,
        buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN,
    ) -> None:
        """Update stream options and request a reconnect so they take effect."""
        if is_given(buffer_size_seconds):
            self._opts.buffer_size_seconds = buffer_size_seconds

        self._reconnect_event.set()

    async def _run(self) -> None:
        """
        Run a single websocket connection to Gradium and make sure to reconnect
        when something went wrong.

        Each outer-loop iteration opens a fresh connection and runs a sender and
        a receiver task. The loop exits once both tasks finish without a pending
        reconnect request. Exceptions raised by either task propagate.
        """

        # Set once the input channel is drained. It tells recv_task that a quiet
        # or closing socket is the expected end of the session, not a failure.
        closing_ws = False

        async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
            nonlocal closing_ws
            # The chunk size follows buffer_size_seconds. The defaults (0.08 s at
            # 24 kHz) give the 1920 samples that were previously hard-coded.
            samples_per_buffer = max(
                1, round(self._opts.sample_rate * self._opts.buffer_size_seconds)
            )

            audio_bstream = utils.audio.AudioByteStream(
                sample_rate=self._opts.sample_rate,
                num_channels=1,
                samples_per_channel=samples_per_buffer,
            )

            async for data in self._input_ch:
                if isinstance(data, self._FlushSentinel):
                    frames = audio_bstream.flush()
                else:
                    frames = audio_bstream.write(data.data.tobytes())

                for frame in frames:
                    if len(frame.data) % 2 != 0:
                        logger.warning("Frame data size not aligned to int16 (multiple of 2)")

                    audio_data = base64.b64encode(frame.data.tobytes()).decode("utf-8")
                    audio_msg = {
                        "type": "audio",
                        "audio": audio_data,
                    }
                    await ws.send_str(json.dumps(audio_msg))

            # Input is closed and fully sent. Previously this flag was never set,
            # so recv_task waited forever after the input ended.
            closing_ws = True

        async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
            buffered_text: list[str] = []
            speaking = False
            # Steps left in the VAD countdown; None while no countdown is running.
            # It must start as None: a non-None start skipped the countdown and
            # the flush, and finalized the first utterance on its first quiet step.
            remaining_vad_steps: int | None = None

            def emit_final_transcript() -> None:
                # Close the current utterance: emit FINAL_TRANSCRIPT with the
                # joined text, then END_OF_SPEECH, and reset per-utterance state.
                nonlocal speaking, remaining_vad_steps
                speaking = False
                remaining_vad_steps = None
                final_event = stt.SpeechEvent(
                    type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                    alternatives=[
                        stt.SpeechData(
                            text=" ".join(buffered_text),
                            language=self._opts.language,
                        )
                    ],
                )
                self._event_ch.send_nowait(final_event)
                buffered_text.clear()
                self._event_ch.send_nowait(
                    stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
                )

            while True:
                try:
                    msg = await asyncio.wait_for(ws.receive(), timeout=5)
                except asyncio.TimeoutError:
                    if closing_ws:
                        break
                    continue

                if msg.type in (
                    aiohttp.WSMsgType.CLOSED,
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.CLOSING,
                ):
                    if closing_ws:
                        break
                    raise APIStatusError("Gradium connection closed unexpectedly")

                if msg.type != aiohttp.WSMsgType.TEXT:
                    logger.error("Unexpected Gradium message type: %s", msg.type)
                    continue

                try:
                    data = json.loads(msg.data)

                    type_ = data.get("type", "")

                    if type_ == "text":
                        if not speaking:
                            speaking = True
                            start_event = stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
                            self._event_ch.send_nowait(start_event)
                        buffered_text.append(data["text"])
                        event = stt.SpeechEvent(
                            type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
                            alternatives=[
                                stt.SpeechData(
                                    text=data["text"],
                                    language=self._opts.language,
                                    start_time=data["start_s"],
                                )
                            ],
                        )
                        self._event_ch.send_nowait(event)

                    elif type_ == "step":
                        vad_bucket = self._opts.vad_bucket
                        # None disables VAD endpointing. Bucket 0 is a valid index,
                        # so a truthiness test would wrongly disable it.
                        if not speaking or vad_bucket is None:
                            continue
                        positive_vad = (
                            data["vad"][vad_bucket]["inactivity_prob"]
                            > self._opts.vad_threshold
                        )
                        if not positive_vad:
                            # Speech is still active: cancel any running countdown.
                            remaining_vad_steps = None
                        elif remaining_vad_steps is None:
                            # First inactive step: start the countdown and, if
                            # enabled, push silence so the delayed tokens are produced.
                            remaining_vad_steps = self.delay_in_tokens
                            # Skip the flush once input has ended; the channel may
                            # already be closed.
                            if self._opts.vad_flush and not closing_ws:
                                samples_per_channel = self.frame_size * self.delay_in_tokens
                                zeros = AudioFrame.create(
                                    sample_rate=self._opts.sample_rate,
                                    num_channels=1,
                                    samples_per_channel=samples_per_channel,
                                )
                                await self._input_ch.send(zeros)
                        else:
                            remaining_vad_steps -= 1
                            if remaining_vad_steps <= 0:
                                emit_final_transcript()

                    elif type_ == "ready":
                        self._ready_msg = data
                    elif type_ == "end_text":
                        # This message provides the end timestamp of the previous word in the stop_s field.
                        pass
                    else:
                        logger.warning("Unknown message type from Gradium %s", type_)

                except Exception:
                    logger.exception("Failed to process message from Gradium")

            # The session is ending normally. Do not drop an utterance whose
            # VAD countdown never completed.
            if speaking:
                emit_final_transcript()

        ws: aiohttp.ClientWebSocketResponse | None = None

        while True:
            try:
                ws = await self._connect_ws()
                tasks = [
                    asyncio.create_task(send_task(ws)),
                    asyncio.create_task(recv_task(ws)),
                ]
                wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())

                try:
                    done, _ = await asyncio.wait(
                        (asyncio.gather(*tasks), wait_reconnect_task),
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                    for task in done:
                        if task != wait_reconnect_task:
                            task.result()

                    if wait_reconnect_task not in done:
                        break

                    self._reconnect_event.clear()
                finally:
                    await utils.aio.gracefully_cancel(*tasks, wait_reconnect_task)
            finally:
                if ws is not None:
                    await ws.close()
                    # Don't re-close this socket if the next connect attempt fails.
                    ws = None

    async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
        """Open the websocket and send the ``setup`` message as the first frame."""
        headers = {"x-api-key": self._api_key, "x-api-source": "livekit"}

        ws = await self._session.ws_connect(self._model_endpoint, headers=headers)

        # Build and send the setup payload as the first message
        setup_msg: dict[str, Any] = {
            "type": "setup",
            "model_name": self._model_name,
            "input_format": "pcm",
        }
        if self._opts.temperature is not None:
            setup_msg["json_config"] = {"temp": self._opts.temperature}

        await ws.send_str(json.dumps(setup_msg))
        return ws

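A single streaming recognition session. It sends audio to the Gradium ASR websocket and emits interim transcripts as text arrives. It emits a final transcript (followed by end-of-speech) once the server-side VAD has reported sustained inactivity. Calling `update_options()` reconnects the session.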

Args: sample_rate : int or None, optional The desired sample rate for the audio input. If specified, the audio input will be automatically resampled to match the given sample rate before being processed for Speech-to-Text. If not provided (None), the input will retain its original sample rate.

Ancestors

  • livekit.agents.stt.stt.RecognizeStream
  • abc.ABC

Instance variables

prop delay_in_tokens : int
Expand source code
@property
def delay_in_tokens(self) -> int:
    """Number of VAD steps to wait before finalizing an utterance.

    Taken from the server's ``ready`` message; 6 until it arrives or when the
    field is absent.
    """
    ready = self._ready_msg
    if ready is None:
        return 6
    return int(ready.get("delay_in_tokens", 6))
prop frame_size : int
Expand source code
@property
def frame_size(self) -> int:
    """Samples per model frame from the ``ready`` message (default 1920)."""
    ready = self._ready_msg
    if ready is None:
        return 1920
    return int(ready.get("frame_size", 1920))

Methods

def update_options(self, *, buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN) ‑> None
Expand source code
def update_options(
    self,
    *,
    buffer_size_seconds: NotGivenOr[float] = NOT_GIVEN,
) -> None:
    """Update stream options and request a reconnect so they take effect.

    The reconnect event is set on every call, whether or not a value is given.
    """
    if is_given(buffer_size_seconds):
        self._opts.buffer_size_seconds = buffer_size_seconds
    reconnect = self._reconnect_event
    reconnect.set()
class TTS (*,
api_key: str | None = None,
model_endpoint: str | None = 'wss://eu.api.gradium.ai/api/speech/tts',
model_name: str = 'default',
voice: str | None = None,
voice_id: str | None = 'YTpq7expH9539ERJ',
json_config: dict[str, Any] | None = None,
http_session: aiohttp.ClientSession | None = None,
word_tokenizer: tokenize.WordTokenizer | None = None)
Expand source code
class TTS(tts.TTS):
    """Gradium text-to-speech for LiveKit Agents.

    Produces mono audio at ``SUPPORTED_SAMPLE_RATE``. :meth:`stream` returns a
    streaming synthesis session and :meth:`synthesize` a one-shot chunked
    stream; websocket connections to the model endpoint are opened through
    :meth:`_connect_ws`.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model_endpoint: str | None = "wss://eu.api.gradium.ai/api/speech/tts",
        model_name: str = "default",
        voice: str | None = None,
        voice_id: str | None = "YTpq7expH9539ERJ",
        json_config: dict[str, Any] | None = None,
        http_session: aiohttp.ClientSession | None = None,
        word_tokenizer: tokenize.WordTokenizer | None = None,
    ) -> None:
        """
        Initialize the Gradium TTS.

        Args:
            api_key (str): Gradium API key, or `GRADIUM_API_KEY` env var.
            model_endpoint (str): Gradium model endpoint, or `GRADIUM_MODEL_ENDPOINT` env var.
            model_name (str): Model name.
            voice (str): Speaker voice.
            voice_id (str): Speaker voice ID.
            json_config (dict): Extra JSON configuration stored in the TTS options
                (can be replaced later with `update_options`).
            http_session (aiohttp.ClientSession): Session used for websocket
                connections; taken from the HTTP context on first use when omitted.
            word_tokenizer (tokenize.WordTokenizer): Tokenizer for processing text.
                Defaults to a basic WordTokenizer that keeps punctuation.

        Raises:
            ValueError: If no API key or model endpoint can be resolved.
        """
        super().__init__(
            capabilities=tts.TTSCapabilities(streaming=True),
            sample_rate=SUPPORTED_SAMPLE_RATE,
            num_channels=1,
        )

        resolved_key = api_key or os.environ.get("GRADIUM_API_KEY")
        if not resolved_key:
            raise ValueError(
                "Gradium API key is required. "
                "Pass one in via the `api_key` parameter, "
                "or set it as the `GRADIUM_API_KEY` environment variable"
            )

        endpoint = model_endpoint or os.environ.get("GRADIUM_MODEL_ENDPOINT")
        if not endpoint:
            raise ValueError(
                "The model endpoint is required, you can find it in the Gradium dashboard"
            )

        self._api_key = resolved_key
        self._model_endpoint = endpoint
        self._model_name = model_name

        tokenizer = word_tokenizer or tokenize.basic.WordTokenizer(ignore_punctuation=False)
        self._opts = _TTSOptions(
            voice=voice,
            voice_id=voice_id,
            word_tokenizer=tokenizer,
            json_config=json_config,
        )
        self._session = http_session

    @property
    def model(self) -> str:
        """Model identifier reported by this plugin (always ``"unknown"``)."""
        reported = "unknown"
        return reported

    @property
    def provider(self) -> str:
        """Provider name reported by this plugin."""
        provider_name = "Gradium"
        return provider_name

    async def _connect_ws(self, timeout: float) -> aiohttp.ClientWebSocketResponse:
        """Open a websocket to the model endpoint, bounded by ``timeout`` seconds."""
        session = self._ensure_session()
        request_headers = {"x-api-key": self._api_key, "x-api-source": "livekit"}
        pending = session.ws_connect(self._model_endpoint, headers=request_headers)
        return await asyncio.wait_for(pending, timeout)

    async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse) -> None:
        """Close a websocket previously returned by :meth:`_connect_ws`."""
        await ws.close()

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the aiohttp session, taking one from the HTTP context if unset."""
        current = self._session
        if not current:
            current = self._session = utils.http_context.http_session()
        return current

    def update_options(
        self,
        *,
        voice: NotGivenOr[str] = NOT_GIVEN,
        json_config: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
    ) -> None:
        """Replace the voice and/or JSON config; omitted arguments are left as-is."""
        opts = self._opts
        if is_given(voice):
            opts.voice = voice
        if is_given(json_config):
            opts.json_config = json_config

    def stream(
        self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
    ) -> SynthesizeStream:
        """Open a streaming synthesis session bound to this TTS instance."""
        synth_stream = SynthesizeStream(tts=self, conn_options=conn_options)
        return synth_stream

    def synthesize(
        self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
    ) -> ChunkedStream:
        """Synthesize ``text`` in one request, returning a chunked audio stream."""
        chunked = ChunkedStream(
            tts=self,
            input_text=text,
            api_key=self._api_key,
            model_endpoint=self._model_endpoint,
            model_name=self._model_name,
            conn_options=conn_options,
        )
        return chunked

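Gradium text-to-speech (TTS) plugin. It provides streaming synthesis via `stream()` and one-shot synthesis via `synthesize()`, producing mono audio at the plugin's supported sample rate.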

Initialize the Gradium TTS.

Args

api_key : str
Gradium API key, or GRADIUM_API_KEY env var.
model_endpoint : str
Gradium model endpoint, or GRADIUM_MODEL_ENDPOINT env var.
model_name : str
Model name.
voice : str
Speaker voice.
voice_id : str
Speaker voice ID.
word_tokenizer : tokenize.WordTokenizer
Tokenizer for processing text. Defaults to basic WordTokenizer.

Ancestors

  • livekit.agents.tts.tts.TTS
  • abc.ABC
  • EventEmitter
  • typing.Generic

Instance variables

prop model : str
Expand source code
@property
def model(self) -> str:
    """Model identifier reported by this plugin (always ``"unknown"``)."""
    reported = "unknown"
    return reported

Get the model name/identifier for this TTS instance.

Returns

The model name if available, "unknown" otherwise.

Note

Plugins should override this property to provide their model information.

prop provider : str
Expand source code
@property
def provider(self) -> str:
    """Provider name reported by this plugin."""
    provider_name = "Gradium"
    return provider_name

Get the provider name/identifier for this TTS instance.

Returns

The provider name if available, "unknown" otherwise.

Note

Plugins should override this property to provide their provider information.

Methods

def stream(self,
*,
conn_options: APIConnectOptions = APIConnectOptions(max_retry=3, retry_interval=2.0, timeout=10.0)) ‑> livekit.plugins.gradium.tts.SynthesizeStream
Expand source code
def stream(
    self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
) -> SynthesizeStream:
    """Open a streaming synthesis session bound to this TTS instance."""
    synth_stream = SynthesizeStream(tts=self, conn_options=conn_options)
    return synth_stream
def synthesize(self,
text: str,
*,
conn_options: APIConnectOptions = APIConnectOptions(max_retry=3, retry_interval=2.0, timeout=10.0)) ‑> livekit.plugins.gradium.tts.ChunkedStream
Expand source code
def synthesize(
    self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
) -> ChunkedStream:
    """Synthesize ``text`` in one request, returning a chunked audio stream."""
    chunked = ChunkedStream(
        tts=self,
        input_text=text,
        api_key=self._api_key,
        model_endpoint=self._model_endpoint,
        model_name=self._model_name,
        conn_options=conn_options,
    )
    return chunked
def update_options(self,
*,
voice: NotGivenOr[str] = NOT_GIVEN,
json_config: NotGivenOr[dict[str, Any]] = NOT_GIVEN) ‑> None
Expand source code
def update_options(
    self,
    *,
    voice: NotGivenOr[str] = NOT_GIVEN,
    json_config: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
) -> None:
    """Replace the voice and/or JSON config; omitted arguments are left as-is."""
    opts = self._opts
    if is_given(voice):
        opts.voice = voice
    if is_given(json_config):
        opts.json_config = json_config

Inherited members