Module livekit.plugins.turn_detector.base
Classes
class EOUModelBase (model_type: EOUModelType = 'en',
                    inference_executor: InferenceExecutor | None = None,
                    unlikely_threshold: float | None = None,
                    load_languages: bool = True)
class EOUModelBase(ABC):
    def __init__(
        self,
        model_type: EOUModelType = "en",  # default to smaller, english-only model
        inference_executor: InferenceExecutor | None = None,
        # if set, overrides the per-language threshold tuned for accuracy.
        # not recommended unless you're confident in the impact.
        unlikely_threshold: float | None = None,
        load_languages: bool = True,
    ) -> None:
        self._model_type = model_type
        self._executor = inference_executor or get_job_context().inference_executor
        self._unlikely_threshold = unlikely_threshold

        self._languages: dict[str, Any] = {}
        if load_languages:
            config_fname = _download_from_hf_hub(
                HG_MODEL,
                "languages.json",
                revision=MODEL_REVISIONS[self._model_type],
                local_files_only=True,
            )
            with open(config_fname) as f:
                self._languages = json.load(f)

    @abstractmethod
    def _inference_method(self) -> str: ...

    async def unlikely_threshold(self, language: str | None) -> float | None:
        if language is None:
            return None

        lang = language.lower()
        # try the full language code first
        lang_data = self._languages.get(lang)

        # try the base language if the full language code is not found
        if lang_data is None and "-" in lang:
            base_lang = lang.split("-")[0]
            lang_data = self._languages.get(base_lang)

        if not lang_data:
            return None

        # if a custom threshold is provided, use it
        if self._unlikely_threshold is not None:
            return self._unlikely_threshold
        else:
            return lang_data["threshold"]  # type: ignore

    async def supports_language(self, language: str | None) -> bool:
        return await self.unlikely_threshold(language) is not None

    # our EOU model inference should be fast, 3 seconds is more than enough
    async def predict_end_of_turn(
        self,
        chat_ctx: llm.ChatContext,
        *,
        timeout: float | None = 3,
    ) -> float:
        messages: list[dict[str, Any]] = []

        for item in chat_ctx.items:
            if item.type != "message":
                continue

            if item.role not in ("user", "assistant"):
                continue

            text_content = item.text_content
            if text_content:
                messages.append(
                    {
                        "role": item.role,
                        "content": text_content,
                    }
                )

        messages = messages[-MAX_HISTORY_TURNS:]
        json_data = json.dumps({"chat_ctx": messages}).encode()

        result = await asyncio.wait_for(
            self._executor.do_inference(self._inference_method(), json_data),
            timeout=timeout,
        )
        assert result is not None, "end_of_utterance prediction should always return a result"

        result_json: dict[str, Any] = json.loads(result.decode())
        logger.debug(
            "eou prediction",
            extra=result_json,
        )
        return result_json["eou_probability"]  # type: ignore
Abstract base class for end-of-utterance (EOU) turn-detection models. It loads per-language probability thresholds from the model's languages.json, serializes the recent chat history, and runs inference through an InferenceExecutor to estimate whether the user has finished their turn. Subclasses supply the inference method name via _inference_method().
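The only abstract hook is _inference_method(), which names the inference job to run on the executor. A minimal sketch of a subclass follows; the class and method names are hypothetical, not the plugin's own, and the real plugin registers its own inference runners.

from livekit.plugins.turn_detector.base import EOUModelBase


class SketchEOUModel(EOUModelBase):
    """Illustrative subclass; not part of the plugin."""

    def _inference_method(self) -> str:
        # Hypothetical identifier: must match an inference runner
        # registered with the InferenceExecutor.
        return "sketch_eou_inference"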
Ancestors
- abc.ABC
Subclasses
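Applications normally instantiate one of the concrete models shipped with this plugin rather than this base class. A hedged wiring sketch, assuming the multilingual model is importable as livekit.plugins.turn_detector.multilingual.MultilingualModel and that AgentSession accepts a turn_detection argument in the installed livekit-agents version:

from livekit.agents import AgentSession
from livekit.plugins.turn_detector.multilingual import MultilingualModel

# Sketch: hand the EOU model to the agent session so it can weigh the
# model's end-of-turn probability against the per-language threshold.
session = AgentSession(
    turn_detection=MultilingualModel(),
    # ... stt/llm/tts and other options omitted ...
)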
Methods
async def predict_end_of_turn(self, chat_ctx: llm.ChatContext, *, timeout: float | None = 3) -> float

Predict the probability that the user has finished their turn. Builds a compact history from the last MAX_HISTORY_TURNS user/assistant messages in chat_ctx, sends it to the inference executor, and returns the model's eou_probability. The call is cancelled if inference takes longer than timeout seconds.
async def predict_end_of_turn(
    self,
    chat_ctx: llm.ChatContext,
    *,
    timeout: float | None = 3,
) -> float:
    messages: list[dict[str, Any]] = []

    for item in chat_ctx.items:
        if item.type != "message":
            continue

        if item.role not in ("user", "assistant"):
            continue

        text_content = item.text_content
        if text_content:
            messages.append(
                {
                    "role": item.role,
                    "content": text_content,
                }
            )

    messages = messages[-MAX_HISTORY_TURNS:]
    json_data = json.dumps({"chat_ctx": messages}).encode()

    result = await asyncio.wait_for(
        self._executor.do_inference(self._inference_method(), json_data),
        timeout=timeout,
    )
    assert result is not None, "end_of_utterance prediction should always return a result"

    result_json: dict[str, Any] = json.loads(result.decode())
    logger.debug(
        "eou prediction",
        extra=result_json,
    )
    return result_json["eou_probability"]  # type: ignore
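A usage sketch, assuming the code runs inside an agent job (so get_job_context() can provide the inference executor) and that ChatContext.add_message(role=..., content=...) is available in the installed livekit-agents version:

from livekit.agents import llm


async def turn_likely_finished(model: EOUModelBase, language: str | None) -> bool:
    chat_ctx = llm.ChatContext()
    chat_ctx.add_message(role="assistant", content="How can I help you today?")
    chat_ctx.add_message(role="user", content="I'd like to book a flight to")

    probability = await model.predict_end_of_turn(chat_ctx)
    threshold = await model.unlikely_threshold(language)
    # Treat the turn as likely finished unless the probability falls below
    # the language's "unlikely" threshold.
    return threshold is None or probability >= threshold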
async def supports_language(self, language: str | None) -> bool

Return True if the model has a tuned threshold for language (directly or via its base language code), i.e. unlikely_threshold() resolves to a value.
async def supports_language(self, language: str | None) -> bool:
    return await self.unlikely_threshold(language) is not None
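Because the check only resolves a threshold, it is cheap to call up front, e.g. to decide whether to rely on the model for a given session language. A sketch; the returned labels are illustrative, not part of the API:

async def choose_turn_detection(model: EOUModelBase, language: str | None) -> str:
    if await model.supports_language(language):
        return "eou_model"  # use model probabilities for endpointing
    return "vad_only"  # illustrative fallback: rely on silence/VAD heuristics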
async def unlikely_threshold(self, language: str | None) -> float | None

Return the "unlikely end of turn" probability threshold for language, falling back from a regional code such as "en-US" to its base code "en". Returns None when the language is unknown or unsupported; a constructor-supplied unlikely_threshold overrides the per-language value.
async def unlikely_threshold(self, language: str | None) -> float | None:
    if language is None:
        return None

    lang = language.lower()
    # try the full language code first
    lang_data = self._languages.get(lang)

    # try the base language if the full language code is not found
    if lang_data is None and "-" in lang:
        base_lang = lang.split("-")[0]
        lang_data = self._languages.get(base_lang)

    if not lang_data:
        return None

    # if a custom threshold is provided, use it
    if self._unlikely_threshold is not None:
        return self._unlikely_threshold
    else:
        return lang_data["threshold"]  # type: ignore
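The lookup assumes languages.json maps base language codes to objects with a "threshold" key. The values below are illustrative, not the shipped ones:

languages = {
    "en": {"threshold": 0.15},
    "fr": {"threshold": 0.20},
}

# With that table:
#   "en-US" -> 0.15  (falls back to the base code "en")
#   "fr"    -> 0.20
#   "xx"    -> None  (unknown language; supports_language() returns False)
# Passing unlikely_threshold=0.1 to the constructor would override every
# per-language value and return 0.1 for any supported language.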