Module livekit.plugins.turn_detector.base

Classes

class EOUModelBase (model_type: EOUModelType = 'en',
inference_executor: InferenceExecutor | None = None)
Expand source code
class EOUModelBase(ABC):
    """Abstract base class for end-of-utterance (EOU) turn-detection models.

    Loads per-language probability thresholds from the Hugging Face hub at
    construction time and delegates the actual model inference to an
    ``InferenceExecutor``. Subclasses provide the executor-side entry point
    via ``_inference_method``.
    """

    def __init__(
        self,
        model_type: EOUModelType = "en",  # default to smaller, english-only model
        inference_executor: InferenceExecutor | None = None,
    ) -> None:
        self._model_type = model_type
        # Fall back to the executor of the currently running job when none is given.
        self._executor = (
            inference_executor or get_current_job_context().inference_executor
        )

        # languages.json maps language codes to their per-language EOU thresholds.
        # local_files_only=True: the file must already be cached locally (e.g. by a
        # prior download step); this call will not hit the network.
        config_fname = _download_from_hf_hub(
            HG_MODEL,
            "languages.json",
            revision=MODEL_REVISIONS[self._model_type],
            local_files_only=True,
        )
        # Explicit encoding: JSON is UTF-8 by spec — don't depend on the locale default.
        with open(config_fname, "r", encoding="utf-8") as f:
            self._languages = json.load(f)

    @abstractmethod
    def _inference_method(self): ...

    def unlikely_threshold(self, language: str | None) -> float | None:
        """Return the EOU probability threshold for *language*.

        Falls back from a regional code (e.g. ``en-US``) to its base language
        (``en``). Returns ``None`` when *language* is ``None`` or unsupported.
        """
        if language is None:
            return None
        lang = language.lower()
        if lang in self._languages:
            return self._languages[lang]["threshold"]
        if "-" in lang:
            part = lang.split("-")[0]
            if part in self._languages:
                return self._languages[part]["threshold"]
        # Lazy %-args: the message is only built if the warning is actually emitted.
        logger.warning("Language %s not supported by EOU model", language)
        return None

    def supports_language(self, language: str | None) -> bool:
        """Return True if the model has a threshold for *language*."""
        return self.unlikely_threshold(language) is not None

    async def predict_eou(self, chat_ctx: llm.ChatContext) -> float:
        """Alias for :meth:`predict_end_of_turn`."""
        return await self.predict_end_of_turn(chat_ctx)

    # our EOU model inference should be fast, 3 seconds is more than enough
    async def predict_end_of_turn(
        self, chat_ctx: llm.ChatContext, *, timeout: float | None = 3
    ) -> float:
        """Predict the probability that the user's turn has ended.

        Serializes the last ``MAX_HISTORY_TURNS`` user/assistant messages from
        *chat_ctx* and runs them through the inference executor.

        Raises:
            asyncio.TimeoutError: if inference exceeds *timeout* seconds.
            RuntimeError: if the executor returns no result.
        """
        messages = []

        for msg in chat_ctx.messages:
            # Only user/assistant turns are relevant to end-of-turn prediction.
            if msg.role not in ("user", "assistant"):
                continue

            if isinstance(msg.content, str):
                messages.append(
                    {
                        "role": msg.role,
                        "content": msg.content,
                    }
                )
            elif isinstance(msg.content, list):
                # Multi-part content: keep only the first string part.
                for cnt in msg.content:
                    if isinstance(cnt, str):
                        messages.append(
                            {
                                "role": msg.role,
                                "content": cnt,
                            }
                        )
                        break

        messages = messages[-MAX_HISTORY_TURNS:]

        json_data = json.dumps({"chat_ctx": messages}).encode()

        result = await asyncio.wait_for(
            self._executor.do_inference(self._inference_method(), json_data),
            timeout=timeout,
        )

        # Explicit check instead of `assert`: asserts are stripped under -O.
        if result is None:
            raise RuntimeError(
                "end_of_utterance prediction should always return a result"
            )

        result_json = json.loads(result.decode())
        logger.debug(
            "eou prediction",
            extra=result_json,
        )
        return result_json["eou_probability"]

Abstract base class for end-of-utterance (EOU) turn-detection models; subclasses implement _inference_method and inference runs on an InferenceExecutor.

Ancestors

  • abc.ABC

Subclasses

Methods

async def predict_end_of_turn(self, chat_ctx: llm.ChatContext, *, timeout: float | None = 3) -> float
Expand source code
async def predict_end_of_turn(
    self, chat_ctx: llm.ChatContext, *, timeout: float | None = 3
) -> float:
    """Predict the probability that the user's turn has ended.

    Serializes the last ``MAX_HISTORY_TURNS`` user/assistant messages from
    *chat_ctx* and runs them through the inference executor.

    Raises:
        asyncio.TimeoutError: if inference exceeds *timeout* seconds.
        RuntimeError: if the executor returns no result.
    """
    messages = []

    for msg in chat_ctx.messages:
        # Only user/assistant turns are relevant to end-of-turn prediction.
        if msg.role not in ("user", "assistant"):
            continue

        if isinstance(msg.content, str):
            messages.append(
                {
                    "role": msg.role,
                    "content": msg.content,
                }
            )
        elif isinstance(msg.content, list):
            # Multi-part content: keep only the first string part.
            for cnt in msg.content:
                if isinstance(cnt, str):
                    messages.append(
                        {
                            "role": msg.role,
                            "content": cnt,
                        }
                    )
                    break

    messages = messages[-MAX_HISTORY_TURNS:]

    json_data = json.dumps({"chat_ctx": messages}).encode()

    result = await asyncio.wait_for(
        self._executor.do_inference(self._inference_method(), json_data),
        timeout=timeout,
    )

    # Explicit check instead of `assert`: asserts are stripped under -O.
    if result is None:
        raise RuntimeError(
            "end_of_utterance prediction should always return a result"
        )

    result_json = json.loads(result.decode())
    logger.debug(
        "eou prediction",
        extra=result_json,
    )
    return result_json["eou_probability"]
async def predict_eou(self, chat_ctx: llm.ChatContext) -> float
Expand source code
async def predict_eou(self, chat_ctx: llm.ChatContext) -> float:
    """Alias for predict_end_of_turn, kept for backwards compatibility."""
    probability = await self.predict_end_of_turn(chat_ctx)
    return probability
def supports_language(self, language: str | None) -> bool
Expand source code
def supports_language(self, language: str | None) -> bool:
    return self.unlikely_threshold(language) is not None
def unlikely_threshold(self, language: str | None) -> float | None
Expand source code
def unlikely_threshold(self, language: str | None) -> float | None:
    if language is None:
        return None
    lang = language.lower()
    if lang in self._languages:
        return self._languages[lang]["threshold"]
    if "-" in lang:
        part = lang.split("-")[0]
        if part in self._languages:
            return self._languages[part]["threshold"]
    logger.warning(f"Language {language} not supported by EOU model")
    return None