diff --git a/src/tai/ai_client.py b/src/tai/ai_client.py index a50457f..c411fd3 100644 --- a/src/tai/ai_client.py +++ b/src/tai/ai_client.py @@ -10,6 +10,7 @@ from openai import OpenAI DEFAULT_AI_HOST = "http://localhost:11434/v1" DEFAULT_MODEL = "gemma3:4b" +DEFAULT_EMBED_MODEL = "nomic-embed-text" @dataclass(slots=True) @@ -21,6 +22,7 @@ class AIConfig: api_key: str = "ollama" # Ollama ignores this; required by the openai client timeout_seconds: float = 120.0 max_tokens: int = 4096 + embed_model: str = DEFAULT_EMBED_MODEL extra_headers: dict[str, str] = field(default_factory=dict) @@ -106,3 +108,11 @@ class AIClient: def summary(self) -> str: """Human-readable description of the AI config.""" return f"host={self._config.host} model={self._config.model}" + + def embed(self, text: str) -> list[float]: + """Embed *text* using the configured embedding model via the OpenAI-compatible endpoint.""" + response = self._client.embeddings.create( + model=self._config.embed_model, + input=text, + ) + return list(response.data[0].embedding)