Module livekit.plugins.turn_detector.base
Classes
class EOUModelBase (model_type: EOUModelType = 'en', inference_executor: InferenceExecutor | None = None)
class EOUModelBase(ABC):
    def __init__(
        self,
        model_type: EOUModelType = "en",  # default to smaller, english-only model
        inference_executor: InferenceExecutor | None = None,
    ) -> None:
        self._model_type = model_type
        self._executor = (
            inference_executor or get_current_job_context().inference_executor
        )

        config_fname = _download_from_hf_hub(
            HG_MODEL,
            "languages.json",
            revision=MODEL_REVISIONS[self._model_type],
            local_files_only=True,
        )
        with open(config_fname, "r") as f:
            self._languages = json.load(f)

    @abstractmethod
    def _inference_method(self): ...

    def unlikely_threshold(self, language: str | None) -> float | None:
        if language is None:
            return None
        lang = language.lower()
        if lang in self._languages:
            return self._languages[lang]["threshold"]
        if "-" in lang:
            part = lang.split("-")[0]
            if part in self._languages:
                return self._languages[part]["threshold"]
        logger.warning(f"Language {language} not supported by EOU model")
        return None

    def supports_language(self, language: str | None) -> bool:
        return self.unlikely_threshold(language) is not None

    async def predict_eou(self, chat_ctx: llm.ChatContext) -> float:
        return await self.predict_end_of_turn(chat_ctx)

    # our EOU model inference should be fast, 3 seconds is more than enough
    async def predict_end_of_turn(
        self, chat_ctx: llm.ChatContext, *, timeout: float | None = 3
    ) -> float:
        messages = []

        for msg in chat_ctx.messages:
            if msg.role not in ("user", "assistant"):
                continue

            if isinstance(msg.content, str):
                messages.append(
                    {
                        "role": msg.role,
                        "content": msg.content,
                    }
                )
            elif isinstance(msg.content, list):
                for cnt in msg.content:
                    if isinstance(cnt, str):
                        messages.append(
                            {
                                "role": msg.role,
                                "content": cnt,
                            }
                        )
                        break

        messages = messages[-MAX_HISTORY_TURNS:]

        json_data = json.dumps({"chat_ctx": messages}).encode()

        result = await asyncio.wait_for(
            self._executor.do_inference(self._inference_method(), json_data),
            timeout=timeout,
        )
        assert result is not None, (
            "end_of_utterance prediction should always returns a result"
        )

        result_json = json.loads(result.decode())
        logger.debug(
            "eou prediction",
            extra=result_json,
        )
        return result_json["eou_probability"]
Abstract base class for end-of-utterance (EOU) turn-detection models. On construction it loads the per-language probability thresholds from the languages.json file of the Hugging Face model repository (for the selected model_type) and resolves the InferenceExecutor to run predictions on, defaulting to the current job's executor. Subclasses supply the actual inference routine via _inference_method().
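Because only _inference_method() is abstract, a concrete model can be quite small. The sketch below is illustrative only: the subclass name MyEOUModel and the key "my_eou_model" are assumptions, and a matching inference runner would still have to be registered with the job's inference executor for do_inference() to find it.

from livekit.plugins.turn_detector.base import EOUModelBase


class MyEOUModel(EOUModelBase):
    def _inference_method(self):
        # Value passed as the first argument to InferenceExecutor.do_inference()
        # so the serialized chat context reaches the matching runner
        # (the key below is hypothetical).
        return "my_eou_model"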
Ancestors
- abc.ABC
Subclasses
Methods
async def predict_end_of_turn(self, chat_ctx: llm.ChatContext, *, timeout: float | None = 3) ‑> float
Predict the probability that the user has finished their turn, given the recent chat history. The last MAX_HISTORY_TURNS user/assistant messages are serialized to JSON and sent to the inference executor; the call is cancelled via asyncio.wait_for if it exceeds timeout seconds (inference is expected to be fast, hence the 3-second default). Returns the predicted eou_probability.
async def predict_end_of_turn(
    self, chat_ctx: llm.ChatContext, *, timeout: float | None = 3
) -> float:
    messages = []

    for msg in chat_ctx.messages:
        if msg.role not in ("user", "assistant"):
            continue

        if isinstance(msg.content, str):
            messages.append(
                {
                    "role": msg.role,
                    "content": msg.content,
                }
            )
        elif isinstance(msg.content, list):
            for cnt in msg.content:
                if isinstance(cnt, str):
                    messages.append(
                        {
                            "role": msg.role,
                            "content": cnt,
                        }
                    )
                    break

    messages = messages[-MAX_HISTORY_TURNS:]

    json_data = json.dumps({"chat_ctx": messages}).encode()

    result = await asyncio.wait_for(
        self._executor.do_inference(self._inference_method(), json_data),
        timeout=timeout,
    )
    assert result is not None, (
        "end_of_utterance prediction should always returns a result"
    )

    result_json = json.loads(result.decode())
    logger.debug(
        "eou prediction",
        extra=result_json,
    )
    return result_json["eou_probability"]
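A usage sketch that combines predict_end_of_turn() with the per-language threshold documented below. The helper name user_likely_done, the hard-coded "en" language code, and the fallback when the language is unsupported are illustrative assumptions, not part of the plugin:

from livekit.agents import llm
from livekit.plugins.turn_detector.base import EOUModelBase


async def user_likely_done(model: EOUModelBase, chat_ctx: llm.ChatContext) -> bool:
    # Probability (0.0-1.0) that the last user turn is complete.
    probability = await model.predict_end_of_turn(chat_ctx, timeout=3)

    # Per-language threshold below which an end of turn is considered unlikely;
    # "en" is only an example language code.
    threshold = model.unlikely_threshold("en")
    if threshold is None:
        # Unsupported language: assume the turn is finished (illustrative choice).
        return True
    return probability >= threshold

In an agent, a check like this would typically run after each final speech-to-text transcript to decide whether to respond now or keep waiting for more speech.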
async def predict_eou(self, chat_ctx: llm.ChatContext) ‑> float
Backward-compatible alias that simply delegates to predict_end_of_turn() with the default timeout.
async def predict_eou(self, chat_ctx: llm.ChatContext) -> float:
    return await self.predict_end_of_turn(chat_ctx)
def supports_language(self, language: str | None) ‑> bool
Return True if a threshold is available for the given language, i.e. unlikely_threshold() does not return None.
def supports_language(self, language: str | None) -> bool:
    return self.unlikely_threshold(language) is not None
def unlikely_threshold(self, language: str | None) ‑> float | None
Return the probability threshold below which an end of turn is considered unlikely for the given language, or None if the language is not given or not supported (in which case a warning is logged).
def unlikely_threshold(self, language: str | None) -> float | None:
    if language is None:
        return None
    lang = language.lower()
    if lang in self._languages:
        return self._languages[lang]["threshold"]
    if "-" in lang:
        part = lang.split("-")[0]
        if part in self._languages:
            return self._languages[part]["threshold"]
    logger.warning(f"Language {language} not supported by EOU model")
    return None
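The lookup is case-insensitive and falls back from a region-qualified tag to its base language before giving up, so a languages.json that only lists base languages still covers regional variants. A short illustration, assuming model is a concrete subclass instance and that the file contains an "en" entry but nothing for "xx" (the codes are examples only):

model.unlikely_threshold("en")     # exact match: threshold taken from languages.json
model.unlikely_threshold("en-US")  # no "en-us" entry: falls back to the "en" threshold
model.unlikely_threshold("xx")     # unsupported: logs a warning and returns None

# supports_language() is a boolean wrapper around the same lookup.
model.supports_language("en-GB")   # True when "en-gb" or "en" has a threshold
model.supports_language(None)      # False, no language given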