Module livekit.plugins.turn_detector.base

Classes

class EOUModelBase (model_type: EOUModelType = 'en',
inference_executor: InferenceExecutor | None = None,
unlikely_threshold: float | None = None,
load_languages: bool = True)
class EOUModelBase(ABC):
    def __init__(
        self,
        model_type: EOUModelType = "en",  # default to smaller, english-only model
        inference_executor: InferenceExecutor | None = None,
        # if set, overrides the per-language threshold tuned for accuracy.
        # not recommended unless you're confident in the impact.
        unlikely_threshold: float | None = None,
        load_languages: bool = True,
    ) -> None:
        self._model_type = model_type
        self._executor = inference_executor or get_job_context().inference_executor
        self._unlikely_threshold = unlikely_threshold
        self._languages: dict[str, Any] = {}

        if load_languages:
            config_fname = _download_from_hf_hub(
                HG_MODEL,
                "languages.json",
                revision=MODEL_REVISIONS[self._model_type],
                local_files_only=True,
            )
            with open(config_fname) as f:
                self._languages = json.load(f)

    @abstractmethod
    def _inference_method(self) -> str: ...

    async def unlikely_threshold(self, language: str | None) -> float | None:
        if language is None:
            return None

        lang = language.lower()
        # try the full language code first
        lang_data = self._languages.get(lang)

        # try the base language if the full language code is not found
        if lang_data is None and "-" in lang:
            base_lang = lang.split("-")[0]
            lang_data = self._languages.get(base_lang)

        if not lang_data:
            return None
        # if a custom threshold is provided, use it
        if self._unlikely_threshold is not None:
            return self._unlikely_threshold
        else:
            return lang_data["threshold"]  # type: ignore

    async def supports_language(self, language: str | None) -> bool:
        return await self.unlikely_threshold(language) is not None

    # our EOU model inference should be fast, 3 seconds is more than enough
    async def predict_end_of_turn(
        self,
        chat_ctx: llm.ChatContext,
        *,
        timeout: float | None = 3,
    ) -> float:
        messages: list[dict[str, Any]] = []
        for item in chat_ctx.items:
            if item.type != "message":
                continue

            if item.role not in ("user", "assistant"):
                continue

            text_content = item.text_content
            if text_content:
                messages.append(
                    {
                        "role": item.role,
                        "content": text_content,
                    }
                )

        messages = messages[-MAX_HISTORY_TURNS:]
        json_data = json.dumps({"chat_ctx": messages}).encode()

        result = await asyncio.wait_for(
            self._executor.do_inference(self._inference_method(), json_data),
            timeout=timeout,
        )

        assert result is not None, "end_of_utterance prediction should always return a result"

        result_json: dict[str, Any] = json.loads(result.decode())
        logger.debug(
            "eou prediction",
            extra=result_json,
        )
        return result_json["eou_probability"]  # type: ignore

Abstract base class for end-of-utterance (EOU) detection models. It loads per-language probability thresholds from the model's languages.json and delegates inference to an InferenceExecutor; concrete subclasses supply the inference method name via _inference_method().
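
A minimal sketch of a concrete subclass; the class name and the inference-method identifier below are hypothetical and shown for illustration only:

from livekit.plugins.turn_detector.base import EOUModelBase


class MyEOUModel(EOUModelBase):
    def _inference_method(self) -> str:
        # name under which the EOU inference runner is registered with the
        # InferenceExecutor (assumed identifier, not a shipped one)
        return "my_eou_inference"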

Ancestors

  • abc.ABC

Subclasses

Methods

async def predict_end_of_turn(self, chat_ctx: llm.ChatContext, *, timeout: float | None = 3) -> float
async def predict_end_of_turn(
    self,
    chat_ctx: llm.ChatContext,
    *,
    timeout: float | None = 3,
) -> float:
    messages: list[dict[str, Any]] = []
    for item in chat_ctx.items:
        if item.type != "message":
            continue

        if item.role not in ("user", "assistant"):
            continue

        text_content = item.text_content
        if text_content:
            messages.append(
                {
                    "role": item.role,
                    "content": text_content,
                }
            )

    messages = messages[-MAX_HISTORY_TURNS:]
    json_data = json.dumps({"chat_ctx": messages}).encode()

    result = await asyncio.wait_for(
        self._executor.do_inference(self._inference_method(), json_data),
        timeout=timeout,
    )

    assert result is not None, "end_of_utterance prediction should always return a result"

    result_json: dict[str, Any] = json.loads(result.decode())
    logger.debug(
        "eou prediction",
        extra=result_json,
    )
    return result_json["eou_probability"]  # type: ignore
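
A hedged usage sketch: model stands for an instance of a concrete EOUModelBase subclass running inside a job (so an InferenceExecutor is available) and chat_ctx for the session's llm.ChatContext; both names are illustrative, and the comparison against unlikely_threshold reflects how the probability is typically interpreted, not a requirement of this method.

# inside an async function:
probability = await model.predict_end_of_turn(chat_ctx, timeout=3)
threshold = await model.unlikely_threshold("en-US")
if threshold is not None and probability < threshold:
    # the model considers the turn unlikely to be finished; keep listening
    ...
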
async def supports_language(self, language: str | None) -> bool
async def supports_language(self, language: str | None) -> bool:
    return await self.unlikely_threshold(language) is not None
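
A short illustrative guard, assuming model is a concrete EOUModelBase subclass and language comes from the application's STT or session configuration:

# inside an async function:
if not await model.supports_language(language):
    # no tuned threshold in languages.json for this language; fall back to the
    # application's default endpointing behaviour
    ...
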
async def unlikely_threshold(self, language: str | None) -> float | None
async def unlikely_threshold(self, language: str | None) -> float | None:
    if language is None:
        return None

    lang = language.lower()
    # try the full language code first
    lang_data = self._languages.get(lang)

    # try the base language if the full language code is not found
    if lang_data is None and "-" in lang:
        base_lang = lang.split("-")[0]
        lang_data = self._languages.get(base_lang)

    if not lang_data:
        return None
    # if a custom threshold is provided, use it
    if self._unlikely_threshold is not None:
        return self._unlikely_threshold
    else:
        return lang_data["threshold"]  # type: ignore
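
The lookup tries the exact language code first, then the base language of a regional code, and a constructor-provided unlikely_threshold takes precedence over the per-language value for supported languages. A few illustrative calls, assuming model was constructed with load_languages=True:

# inside an async function:
await model.unlikely_threshold("pt-BR")  # exact entry first, then the "pt" entry
await model.unlikely_threshold("xx")     # unknown language -> None
await model.unlikely_threshold(None)     # no language given -> None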