Module livekit.plugins.nvidia

Classes

class STT (*,
model: str = 'parakeet-1.1b-en-US-asr-streaming-silero-vad-sortformer',
function_id: str = '1598d209-5e27-4d3c-8079-4751568b1081',
punctuate: bool = True,
language_code: str = 'en-US',
sample_rate: int = 16000,
server: str = 'grpc.nvcf.nvidia.com:443',
use_ssl: bool = True,
api_key: str | livekit.agents.types.NotGiven = NOT_GIVEN)
Expand source code
class STT(stt.STT):
    """Streaming speech-to-text client for an NVIDIA Riva ASR service.

    Only streaming recognition is supported: ``stream()`` returns a
    ``SpeechStream``; one-shot recognition (``_recognize_impl``) raises
    ``NotImplementedError``.
    """

    def __init__(
        self,
        *,
        model: str = "parakeet-1.1b-en-US-asr-streaming-silero-vad-sortformer",
        function_id: str = "1598d209-5e27-4d3c-8079-4751568b1081",
        punctuate: bool = True,
        language_code: str = "en-US",
        sample_rate: int = 16000,
        server: str = "grpc.nvcf.nvidia.com:443",
        use_ssl: bool = True,
        api_key: NotGivenOr[str] = NOT_GIVEN,
    ):
        """Configure the STT client.

        Args:
            model: ASR model name placed in the Riva recognition config.
            function_id: Function id passed to ``auth.create_riva_auth``.
            punctuate: Enables automatic punctuation in recognition results.
            language_code: Default language for streams opened without one.
            sample_rate: Sample rate (Hz) of the audio sent to Riva.
            server: gRPC endpoint (``host:port``) of the Riva service.
            use_ssl: Whether to connect over SSL.
            api_key: NVIDIA API key. When not given, the ``NVIDIA_API_KEY``
                environment variable is used.

        Raises:
            ValueError: If ``api_key`` is not given, ``use_ssl`` is true and
                ``NVIDIA_API_KEY`` is unset or empty.
        """
        super().__init__(
            capabilities=stt.STTCapabilities(
                streaming=True,
                interim_results=True,
            ),
        )

        if is_given(api_key):
            self.nvidia_api_key = api_key
        else:
            self.nvidia_api_key = os.getenv("NVIDIA_API_KEY")
            if use_ssl and not self.nvidia_api_key:
                raise ValueError(
                    "NVIDIA_API_KEY is not set while using SSL. Either pass api_key parameter, set NVIDIA_API_KEY environment variable "
                    + "or disable SSL and use a locally hosted Riva NIM service."
                )

        logger.info("Initializing NVIDIA STT with model: %s, server: %s", model, server)
        logger.debug(
            "Function ID: %s, Language: %s, Sample rate: %s",
            function_id,
            language_code,
            sample_rate,
        )

        self._opts = STTOptions(
            model=model,
            function_id=function_id,
            punctuate=punctuate,
            language_code=language_code,
            sample_rate=sample_rate,
            server=server,
            use_ssl=use_ssl,
        )

    def _recognize_impl(
        self,
        buffer: AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> stt.SpeechEvent:
        """One-shot recognition is not supported; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    def stream(
        self,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> stt.RecognizeStream:
        """Open a streaming recognition session.

        ``language`` overrides the configured ``language_code`` for this
        stream only.
        """
        if is_given(language):
            stream_language = language
        else:
            stream_language = self._opts.language_code
        return SpeechStream(stt=self, conn_options=conn_options, language=stream_language)

    def log_asr_models(self, asr_service: riva.client.ASRService) -> dict:
        """Return the streaming ("online") ASR models exposed by the Riva server.

        Despite the name, nothing is logged: the server's recognition config is
        queried and the online models are returned grouped by language code,
        sorted by that code, e.g. ``{"en-US": [{"model": ["<name>"]}, ...]}``.
        Models whose config lacks a ``language_code`` parameter are grouped
        under ``"unknown"`` rather than raising ``KeyError``.
        """
        config_response = asr_service.stub.GetRivaSpeechRecognitionConfig(
            riva.client.RivaSpeechRecognitionConfigRequest()
        )

        asr_models: dict = {}
        for model_config in config_response.model_config:
            if model_config.parameters.get("type") != "online":
                continue
            language_code = model_config.parameters.get("language_code", "unknown")
            asr_models.setdefault(language_code, []).append(
                {"model": [model_config.model_name]}
            )

        return dict(sorted(asr_models.items()))

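Streaming speech-to-text client for an NVIDIA Riva ASR service. By default it targets the hosted endpoint grpc.nvcf.nvidia.com:443. With SSL disabled, it can target a locally hosted Riva NIM service. Only streaming recognition is supported; one-shot recognition raises NotImplementedError.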

Ancestors

  • livekit.agents.stt.stt.STT
  • abc.ABC
  • EventEmitter
  • typing.Generic

Methods

def log_asr_models(self, asr_service: riva.client.asr.ASRService) ‑> dict
Expand source code
def log_asr_models(self, asr_service: riva.client.ASRService) -> dict:
    """Return the streaming ("online") ASR models exposed by the Riva server.

    Despite the name, nothing is logged: the server's recognition config is
    queried and the online models are returned grouped by language code,
    sorted by that code, e.g. ``{"en-US": [{"model": ["<name>"]}, ...]}``.
    Models whose config lacks a ``language_code`` parameter are grouped
    under ``"unknown"`` rather than raising ``KeyError``.
    """
    config_response = asr_service.stub.GetRivaSpeechRecognitionConfig(
        riva.client.RivaSpeechRecognitionConfigRequest()
    )

    asr_models: dict = {}
    for model_config in config_response.model_config:
        if model_config.parameters.get("type") != "online":
            continue
        language_code = model_config.parameters.get("language_code", "unknown")
        asr_models.setdefault(language_code, []).append(
            {"model": [model_config.model_name]}
        )

    return dict(sorted(asr_models.items()))
def stream(self,
*,
language: str | livekit.agents.types.NotGiven = NOT_GIVEN,
conn_options: livekit.agents.types.APIConnectOptions = APIConnectOptions(max_retry=3, retry_interval=2.0, timeout=10.0)) ‑> livekit.agents.stt.stt.RecognizeStream
Expand source code
def stream(
    self,
    *,
    language: NotGivenOr[str] = NOT_GIVEN,
    conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> stt.RecognizeStream:
    """Open a streaming recognition session.

    ``language`` overrides the configured ``language_code`` for this
    stream only.
    """
    if is_given(language):
        stream_language = language
    else:
        stream_language = self._opts.language_code
    return SpeechStream(stt=self, conn_options=conn_options, language=stream_language)

Inherited members

class SpeechStream (*,
stt: livekit.plugins.nvidia.stt.STT,
conn_options: livekit.agents.types.APIConnectOptions,
language: str)
Expand source code
class SpeechStream(stt.SpeechStream):
    """Streaming recognition session that runs Riva ASR on a background thread.

    ``_run`` forwards audio frames from the input channel into a thread-safe
    queue. ``_recognition_worker``, on its own thread, feeds that queue to
    ``ASRService.streaming_response_generator``. It hands every response back
    to the event loop as ``stt.SpeechEvent`` objects via
    ``call_soon_threadsafe``.
    """

    def __init__(self, *, stt: STT, conn_options: APIConnectOptions, language: str):
        super().__init__(stt=stt, conn_options=conn_options, sample_rate=stt._opts.sample_rate)
        self._stt = stt
        self._language = language

        # Raw audio bytes produced on the event loop and consumed by the worker
        # thread; a falsy item (the ``None`` put by ``_run``) ends the stream.
        self._audio_queue = queue.Queue()
        self._shutdown_event = threading.Event()
        self._recognition_thread = None

        # True while an utterance is open: START_OF_SPEECH has been emitted and
        # the matching END_OF_SPEECH has not.
        self._speaking = False
        self._request_id = ""

        self._auth = auth.create_riva_auth(
            api_key=self._stt.nvidia_api_key,
            function_id=self._stt._opts.function_id,
            server=stt._opts.server,
            use_ssl=stt._opts.use_ssl,
        )
        self._asr_service = riva.client.ASRService(self._auth)

        self._event_loop = asyncio.get_running_loop()
        # Resolved from the worker thread once recognition has finished.
        self._done_fut = self._event_loop.create_future()

    async def _run(self) -> None:
        """Start the recognition thread, pump audio into it, and wait for it to finish."""
        config = self._create_streaming_config()

        self._recognition_thread = threading.Thread(
            target=self._recognition_worker,
            args=(config,),
            name="nvidia-asr-recognition",
            daemon=True,
        )
        self._recognition_thread.start()

        try:
            await self._collect_audio()
        finally:
            # Unblock the audio generator so the Riva stream can close, then
            # wait until the worker has processed the remaining responses.
            self._audio_queue.put(None)
            await self._done_fut

    def _create_streaming_config(self) -> riva.client.StreamingRecognitionConfig:
        """Build the Riva streaming config (mono LINEAR_PCM, top-1 alternative, interim results)."""
        return riva.client.StreamingRecognitionConfig(
            config=riva.client.RecognitionConfig(
                encoding=riva.client.AudioEncoding.LINEAR_PCM,
                language_code=self._language,
                model=self._stt._opts.model,
                max_alternatives=1,
                enable_automatic_punctuation=self._stt._opts.punctuate,
                sample_rate_hertz=self._stt._opts.sample_rate,
                audio_channel_count=1,
            ),
            interim_results=True,
        )

    async def _collect_audio(self) -> None:
        """Copy non-empty audio frames into the queue until input ends or a flush arrives."""
        async for data in self._input_ch:
            if isinstance(data, rtc.AudioFrame):
                audio_bytes = data.data.tobytes()
                if audio_bytes:
                    self._audio_queue.put(audio_bytes)
            elif isinstance(data, self._FlushSentinel):
                # NOTE(review): a flush ends this stream's input entirely.
                break

    def _recognition_worker(self, config: riva.client.StreamingRecognitionConfig) -> None:
        """Thread target: run Riva streaming recognition and dispatch every response.

        Exceptions are logged rather than raised; ``_done_fut`` is always
        resolved so ``_run`` can finish.
        """
        try:
            audio_generator = self._audio_chunk_generator()

            response_generator = self._asr_service.streaming_response_generator(
                audio_generator, config
            )

            for response in response_generator:
                self._handle_response(response)

        except Exception:
            logger.exception("Error in NVIDIA recognition thread")
        finally:
            self._event_loop.call_soon_threadsafe(self._done_fut.set_result, None)

    def _audio_chunk_generator(self) -> Generator[bytes, None, None]:
        """Yield queued audio chunks until the end-of-input marker arrives.

        The Riva SDK's realtime STT API consumes a generator of audio chunks,
        so this wraps ``_audio_queue`` as one. Each ``get()`` blocks the worker
        thread until more audio is queued. A falsy item ends the generator,
        which ends the Riva stream.
        """
        while True:
            audio_chunk = self._audio_queue.get()

            if not audio_chunk:
                break

            yield audio_chunk

    def _emit(self, event: stt.SpeechEvent) -> None:
        """Schedule ``event`` onto the event channel from the worker thread."""
        self._event_loop.call_soon_threadsafe(self._event_ch.send_nowait, event)

    def _handle_response(self, response) -> None:
        """Translate one Riva streaming response into LiveKit speech events.

        Each result with a non-blank top transcript is handled as follows:
        - START_OF_SPEECH is emitted first if no utterance is open.
        - A non-final result then emits INTERIM_TRANSCRIPT.
        - A final result emits FINAL_TRANSCRIPT, then END_OF_SPEECH, and closes
          the utterance so the next transcript opens a new one.

        Errors are logged and swallowed so one bad response does not stop the
        worker.
        """
        try:
            if not hasattr(response, "results") or not response.results:
                return

            for result in response.results:
                if not hasattr(result, "alternatives") or not result.alternatives:
                    continue

                alternative = result.alternatives[0]
                transcript = getattr(alternative, "transcript", "")
                is_final = getattr(result, "is_final", False)

                if not transcript.strip():
                    continue

                # NOTE(review): id() values may be reused after a response is
                # freed, so this is only unique among live responses.
                self._request_id = f"nvidia-{id(response)}"

                if not self._speaking:
                    self._speaking = True
                    self._emit(stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH))

                speech_data = self._convert_to_speech_data(alternative)

                if is_final:
                    self._emit(
                        stt.SpeechEvent(
                            type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                            request_id=self._request_id,
                            alternatives=[speech_data],
                        )
                    )
                    # Close the utterance and clear the flag so the next
                    # transcript emits START_OF_SPEECH again; otherwise only the
                    # first utterance would ever get one.
                    self._speaking = False
                    self._emit(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH))
                else:
                    self._emit(
                        stt.SpeechEvent(
                            type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
                            request_id=self._request_id,
                            alternatives=[speech_data],
                        )
                    )

        except Exception:
            logger.exception("Error handling response")

    def _convert_to_speech_data(self, alternative) -> stt.SpeechData:
        """Build ``stt.SpeechData`` from a Riva alternative.

        Start/end times come from the first/last word and are divided by 1000
        (word offsets are treated as milliseconds) to give seconds. Without
        word info both times are 0.0.
        """
        transcript = getattr(alternative, "transcript", "")
        confidence = getattr(alternative, "confidence", 0.0)
        words = getattr(alternative, "words", [])

        start_time = 0.0
        end_time = 0.0
        if words:
            start_time = getattr(words[0], "start_time", 0) / 1000.0
            end_time = getattr(words[-1], "end_time", 0) / 1000.0

        return stt.SpeechData(
            language=self._language,
            start_time=start_time,
            end_time=end_time,
            confidence=confidence,
            text=transcript,
        )

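Streaming recognition session created by STT.stream(). Audio frames are queued to a background thread, which runs Riva streaming recognition and emits START_OF_SPEECH, INTERIM_TRANSCRIPT, FINAL_TRANSCRIPT and END_OF_SPEECH events.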

Args (inherited from RecognizeStream):
  • sample_rate (int or None, optional): The desired sample rate for the audio input. If specified, input audio is automatically resampled to this rate before speech-to-text processing. If None, the input keeps its original sample rate. SpeechStream always passes the sample rate configured on its STT.

Ancestors

  • livekit.agents.stt.stt.RecognizeStream
  • abc.ABC
class SynthesizeStream (*,
tts: livekit.plugins.nvidia.tts.TTS,
conn_options: livekit.agents.types.APIConnectOptions,
opts: livekit.plugins.nvidia.tts.TTSOptions)
Expand source code
class SynthesizeStream(tts.SynthesizeStream):
    """Streaming synthesis session: text in, Riva PCM audio out.

    Input text is split by the configured tokenizer. Each token is passed
    through a thread-safe queue to a worker thread, which calls the blocking
    ``synthesize_online``. The worker pushes the resulting audio to the
    ``AudioEmitter`` on the event loop. All audio goes into a single segment.
    """

    def __init__(self, *, tts: TTS, conn_options: APIConnectOptions, opts: TTSOptions):
        super().__init__(tts=tts, conn_options=conn_options)
        self._opts = opts
        self._context_id = utils.shortuuid()
        # TTS fills the ``word_tokenizer`` option with a sentence tokenizer.
        self._sent_tokenizer_stream = self._opts.word_tokenizer.stream()
        # Tokens for the synthesis thread; ``None`` tells it to stop.
        self._token_q = queue.Queue()
        self._event_loop = asyncio.get_running_loop()

    async def _run(self, output_emitter: tts.AudioEmitter) -> None:
        """Tokenize input, synthesize each token on a worker thread, then close the segment."""
        output_emitter.initialize(
            request_id=self._context_id,
            sample_rate=self._opts.sample_rate,
            num_channels=1,
            stream=True,
            mime_type="audio/pcm",
        )
        output_emitter.start_segment(segment_id=self._context_id)

        # Resolved from the worker thread when it exits.
        done_fut = self._event_loop.create_future()

        async def _input_task() -> None:
            # Forward pushed text to the tokenizer; flush sentinels become
            # tokenizer flushes, and the end of input ends the tokenizer.
            async for data in self._input_ch:
                if isinstance(data, self._FlushSentinel):
                    self._sent_tokenizer_stream.flush()
                    continue
                self._sent_tokenizer_stream.push_text(data)
            self._sent_tokenizer_stream.end_input()

        async def _process_segments() -> None:
            # Hand each token to the worker, then the stop marker.
            async for word_stream in self._sent_tokenizer_stream:
                self._token_q.put(word_stream)
            self._token_q.put(None)

        def _synthesize_worker() -> None:
            # Thread target: synthesize queued tokens until the stop marker.
            try:
                service = self._tts._ensure_session()
                while True:
                    token = self._token_q.get()

                    if token is None:
                        break

                    try:
                        responses = service.synthesize_online(
                            token.token,
                            self._opts.voice,
                            self._opts.language_code,
                            sample_rate_hz=self._opts.sample_rate,
                            encoding=AudioEncoding.LINEAR_PCM,
                        )
                        for response in responses:
                            self._event_loop.call_soon_threadsafe(
                                output_emitter.push, response.audio
                            )

                    except Exception as e:
                        # Best effort: skip the failed token and keep going.
                        logger.exception("Error in synthesis: %s", e)
                        continue
            except Exception:
                # e.g. session creation failed; without this the error would
                # only reach threading.excepthook.
                logger.exception("Error in NVIDIA TTS synthesis thread")
            finally:
                self._event_loop.call_soon_threadsafe(done_fut.set_result, None)

        synthesize_thread = threading.Thread(
            target=_synthesize_worker,
            name="nvidia-tts-synthesize",
            daemon=True,
        )
        synthesize_thread.start()

        tasks = [
            asyncio.create_task(_input_task()),
            asyncio.create_task(_process_segments()),
        ]

        try:
            await asyncio.gather(*tasks)
        finally:
            # gather() propagates the first child failure without cancelling
            # the other child; cancel any survivor so it cannot outlive this
            # stream or enqueue tokens after the stop marker below.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            self._token_q.put(None)
            await done_fut
            output_emitter.end_segment()

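Streaming synthesis session created by TTS.stream(). Input text is split into sentences, and each sentence is synthesized by Riva on a background thread. The resulting PCM audio is emitted as a single segment.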

Ancestors

  • livekit.agents.tts.tts.SynthesizeStream
  • abc.ABC
class TTS (*,
server: str = 'grpc.nvcf.nvidia.com:443',
voice: str = 'Magpie-Multilingual.EN-US.Leo',
function_id: str = '877104f7-e885-42b9-8de8-f6e4c6303969',
language_code: str = 'en-US',
use_ssl: bool = True,
api_key: str | None = None)
Expand source code
class TTS(tts.TTS):
    """Text-to-speech client for an NVIDIA Riva speech synthesis service.

    Audio is produced as 16 kHz mono PCM. Only streaming synthesis is
    available: ``stream()`` returns a ``SynthesizeStream``, while
    ``synthesize()`` raises ``NotImplementedError``.
    """

    def __init__(
        self,
        *,
        server: str = "grpc.nvcf.nvidia.com:443",
        voice: str = "Magpie-Multilingual.EN-US.Leo",
        function_id: str = "877104f7-e885-42b9-8de8-f6e4c6303969",
        language_code: str = "en-US",
        use_ssl: bool = True,
        api_key: str | None = None,
    ):
        """Configure the TTS client.

        Args:
            server: gRPC endpoint (``host:port``) of the Riva service.
            voice: Voice name passed to ``synthesize_online``.
            function_id: Function id passed to ``auth.create_riva_auth``.
            language_code: Language code passed to ``synthesize_online``.
            use_ssl: Whether to connect over SSL; an API key is then required.
            api_key: NVIDIA API key. If empty or omitted, the
                ``NVIDIA_API_KEY`` environment variable is used instead.

        Raises:
            ValueError: If ``use_ssl`` is true and no API key is found in
                either place.
        """
        # One rate for both the advertised output and the Riva request.
        pcm_rate = 16000
        super().__init__(
            capabilities=tts.TTSCapabilities(streaming=True),
            sample_rate=pcm_rate,
            num_channels=1,
        )

        key = api_key if api_key else os.getenv("NVIDIA_API_KEY")
        if use_ssl and not key:
            raise ValueError(
                "NVIDIA_API_KEY is not set while using SSL. Either pass api_key parameter, set NVIDIA_API_KEY environment variable "
                + "or disable SSL and use a locally hosted Riva NIM service."
            )
        self.nvidia_api_key = key

        self._opts = TTSOptions(
            voice=voice,
            function_id=function_id,
            server=server,
            sample_rate=pcm_rate,
            use_ssl=use_ssl,
            language_code=language_code,
            word_tokenizer=tokenize.blingfire.SentenceTokenizer(),
        )
        # Created lazily by _ensure_session().
        self._tts_service = None

    def _ensure_session(self) -> riva.client.SpeechSynthesisService:
        """Return the cached Riva synthesis client, creating it on first use."""
        if self._tts_service:
            return self._tts_service

        riva_auth = auth.create_riva_auth(
            api_key=self.nvidia_api_key,
            function_id=self._opts.function_id,
            server=self._opts.server,
            use_ssl=self._opts.use_ssl,
        )
        self._tts_service = riva.client.SpeechSynthesisService(riva_auth)
        return self._tts_service

    def list_voices(self) -> dict:
        """Return the server's voices grouped by language code, sorted by that code.

        Shape: ``{"en-US": {"voices": ["<voice_name>.<subvoice>", ...]}, ...}``.
        When a model has a ``subvoices`` parameter (comma-separated entries,
        with anything after ``:`` dropped), each subvoice is appended to the
        voice name. Otherwise the bare voice name is listed. Missing
        ``language_code``/``voice_name`` parameters become ``"unknown"``.
        """
        service = self._ensure_session()
        request = riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
        config_response = service.stub.GetRivaSynthesisConfig(request)

        voices_by_language: dict = {}
        for model_config in config_response.model_config:
            params = model_config.parameters
            language_code = params.get("language_code", "unknown")
            base_name = params.get("voice_name", "unknown")
            subvoices_str = params.get("subvoices", "")

            if subvoices_str:
                names = [
                    f"{base_name}.{entry.split(':')[0]}" for entry in subvoices_str.split(",")
                ]
            else:
                names = [base_name]

            voices_by_language.setdefault(language_code, {"voices": []})["voices"].extend(names)

        return dict(sorted(voices_by_language.items()))

    def synthesize(
        self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
    ) -> tts.ChunkedStream:
        """Not supported: NVIDIA TTS only offers streaming synthesis (use ``stream``).

        Raises:
            NotImplementedError: Always.
        """
        reason = "Chunked synthesis is not supported for NVIDIA TTS"
        raise NotImplementedError(reason)

    def stream(
        self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
    ) -> tts.SynthesizeStream:
        """Open a streaming synthesis session using this client's options."""
        session = SynthesizeStream(tts=self, conn_options=conn_options, opts=self._opts)
        return session

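Streaming text-to-speech client for an NVIDIA Riva speech synthesis service. Output is 16 kHz mono PCM, and only streaming synthesis is supported; synthesize() raises NotImplementedError.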

Ancestors

  • livekit.agents.tts.tts.TTS
  • abc.ABC
  • EventEmitter
  • typing.Generic

Methods

def list_voices(self) ‑> dict
Expand source code
def list_voices(self) -> dict:
    """Return the server's voices grouped by language code, sorted by that code.

    Shape: ``{"en-US": {"voices": ["<voice_name>.<subvoice>", ...]}, ...}``.
    When a model has a ``subvoices`` parameter (comma-separated entries, with
    anything after ``:`` dropped), each subvoice is appended to the voice
    name. Otherwise the bare voice name is listed. Missing
    ``language_code``/``voice_name`` parameters become ``"unknown"``.
    """
    service = self._ensure_session()
    request = riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
    config_response = service.stub.GetRivaSynthesisConfig(request)

    voices_by_language: dict = {}
    for model_config in config_response.model_config:
        params = model_config.parameters
        language_code = params.get("language_code", "unknown")
        base_name = params.get("voice_name", "unknown")
        subvoices_str = params.get("subvoices", "")

        if subvoices_str:
            names = [
                f"{base_name}.{entry.split(':')[0]}" for entry in subvoices_str.split(",")
            ]
        else:
            names = [base_name]

        voices_by_language.setdefault(language_code, {"voices": []})["voices"].extend(names)

    return dict(sorted(voices_by_language.items()))
def stream(self,
*,
conn_options: livekit.agents.types.APIConnectOptions = APIConnectOptions(max_retry=3, retry_interval=2.0, timeout=10.0)) ‑> livekit.agents.tts.tts.SynthesizeStream
Expand source code
def stream(
    self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
) -> tts.SynthesizeStream:
    """Open a streaming synthesis session using this client's options."""
    session = SynthesizeStream(tts=self, conn_options=conn_options, opts=self._opts)
    return session
def synthesize(self,
text: str,
*,
conn_options: livekit.agents.types.APIConnectOptions = APIConnectOptions(max_retry=3, retry_interval=2.0, timeout=10.0)) ‑> livekit.agents.tts.tts.ChunkedStream
Expand source code
def synthesize(
    self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
) -> tts.ChunkedStream:
    """Not supported: NVIDIA TTS only offers streaming synthesis (use ``stream``).

    Raises:
        NotImplementedError: Always.
    """
    reason = "Chunked synthesis is not supported for NVIDIA TTS"
    raise NotImplementedError(reason)

Inherited members