"""Tests for session_store with mocked ChromaDB.""" from __future__ import annotations from pathlib import Path from unittest.mock import MagicMock, patch from tai.session_store import PastSession, SessionStore, _build_embed_text def _make_chromadb_mock() -> MagicMock: collection = MagicMock() collection.count.return_value = 0 client = MagicMock() client.get_or_create_collection.return_value = collection chroma_mod = MagicMock() chroma_mod.PersistentClient.return_value = client return chroma_mod def _make_ai_mock(embedding: list[float] | None = None) -> MagicMock: ai = MagicMock() ai.embed.return_value = embedding or [0.1, 0.2, 0.3] return ai def test_build_embed_text_contains_host_issue_and_summary() -> None: text = _build_embed_text(host="web01", issue="sssd broken", summary="Unit missing") assert "host: web01" in text assert "issue: sssd broken" in text assert "Unit missing" in text def test_index_session_upserts_with_metadata(tmp_path: Path) -> None: chroma_mock = _make_chromadb_mock() collection = chroma_mock.PersistentClient.return_value.get_or_create_collection.return_value ai = _make_ai_mock() with patch.dict("sys.modules", {"chromadb": chroma_mock}): store = SessionStore(tmp_path / "store") session_id = store.index_session("web01", "sssd broken", "summary text", ai) assert session_id collection.upsert.assert_called_once() args = collection.upsert.call_args.kwargs assert args["metadatas"][0]["host"] == "web01" assert args["metadatas"][0]["issue"] == "sssd broken" def test_query_returns_empty_when_no_docs(tmp_path: Path) -> None: chroma_mock = _make_chromadb_mock() ai = _make_ai_mock() with patch.dict("sys.modules", {"chromadb": chroma_mock}): store = SessionStore(tmp_path / "store") results = store.query("why sssd", "web01", ai) assert results == [] def test_query_returns_past_sessions(tmp_path: Path) -> None: chroma_mock = _make_chromadb_mock() collection = chroma_mock.PersistentClient.return_value.get_or_create_collection.return_value collection.count.return_value = 1 collection.query.return_value = { "ids": [["20260506T120000Z"]], "documents": [["Root cause: package missing"]], "metadatas": [[{"host": "web01", "issue": "sssd broken"}]], } ai = _make_ai_mock() with patch.dict("sys.modules", {"chromadb": chroma_mock}): store = SessionStore(tmp_path / "store") results = store.query("sssd issue", "web01", ai) assert len(results) == 1 assert isinstance(results[0], PastSession) assert results[0].host == "web01" assert "package missing" in results[0].summary def test_list_recent_returns_sessions_sorted_desc(tmp_path: Path) -> None: chroma_mock = _make_chromadb_mock() collection = chroma_mock.PersistentClient.return_value.get_or_create_collection.return_value collection.count.return_value = 3 collection.get.return_value = { "ids": ["20260506T120000Z", "20260507T120000Z", "20260505T120000Z"], "documents": ["older", "newer", "oldest"], "metadatas": [ {"host": "web01", "issue": "i1"}, {"host": "web01", "issue": "i2"}, {"host": "db01", "issue": "i3"}, ], } with patch.dict("sys.modules", {"chromadb": chroma_mock}): store = SessionStore(tmp_path / "store") results = store.list_recent(limit=2) assert len(results) == 2 assert results[0].session_id == "20260507T120000Z" assert results[1].session_id == "20260506T120000Z" def test_search_keyword_filters_by_term_and_host(tmp_path: Path) -> None: chroma_mock = _make_chromadb_mock() collection = chroma_mock.PersistentClient.return_value.get_or_create_collection.return_value collection.count.return_value = 3 collection.get.return_value = { "ids": ["20260505T120000Z", "20260506T120000Z", "20260507T120000Z"], "documents": [ "Root cause: nginx config typo", "Root cause: package missing", "Root cause: nginx port conflict", ], "metadatas": [ {"host": "web01", "issue": "nginx fails"}, {"host": "web01", "issue": "sssd fails"}, {"host": "db01", "issue": "nginx start failed"}, ], } with patch.dict("sys.modules", {"chromadb": chroma_mock}): store = SessionStore(tmp_path / "store") results = store.search_keyword("nginx", host="web01", limit=5) assert len(results) == 1 assert results[0].host == "web01" assert "nginx" in results[0].issue.lower()