feat(rag): implement Tier 1 in-memory RAG for interactive follow-ups
Some checks failed
CI / test (push) Failing after 15s
Some checks failed
CI / test (push) Failing after 15s
- Add embed() to AIClient using Ollama nomic-embed-text via /v1/embeddings - Add DEFAULT_EMBED_MODEL and embed_model field to AIConfig - New rag_retriever.py: chunk_report(), EmbeddedChunk, retrieve() (pure-Python cosine) - prompt_builder: add build_message_with_chunks() for RAG-aware follow-up prompts - cli: add --no-rag flag, embed report chunks after collection, retrieve top-5 per question - Graceful fallback to full-context if embedding model unavailable - 16 new tests in test_rag_retriever.py (67 total, all passing) - Add chromadb>=0.5 as optional [rag] dep in pyproject.toml - README: add step 3 (pull nomic-embed-text), update Suggested Tooling table
This commit is contained in:
@@ -18,7 +18,13 @@ from tai.collectors import CollectionReport, collect_from_plan
|
|||||||
from tai.input_parser import InputValidationError, build_request
|
from tai.input_parser import InputValidationError, build_request
|
||||||
from tai.models import TroubleshootRequest
|
from tai.models import TroubleshootRequest
|
||||||
from tai.plan import plan_from_request
|
from tai.plan import plan_from_request
|
||||||
from tai.prompt_builder import build_followup_message, build_system_prompt, build_user_message
|
from tai.prompt_builder import (
|
||||||
|
build_followup_message,
|
||||||
|
build_message_with_chunks,
|
||||||
|
build_system_prompt,
|
||||||
|
build_user_message,
|
||||||
|
)
|
||||||
|
from tai.rag_retriever import EmbeddedChunk, chunk_report, retrieve
|
||||||
from tai.session_log import SessionLogger
|
from tai.session_log import SessionLogger
|
||||||
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession
|
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession
|
||||||
|
|
||||||
@@ -97,6 +103,13 @@ def run(
|
|||||||
help="Optional JSONL file path to log AI and session output.",
|
help="Optional JSONL file path to log AI and session output.",
|
||||||
),
|
),
|
||||||
] = None,
|
] = None,
|
||||||
|
no_rag: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(
|
||||||
|
"--no-rag",
|
||||||
|
help="Disable RAG; send full diagnostics to AI instead of retrieved chunks.",
|
||||||
|
),
|
||||||
|
] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start an interactive troubleshooting session scaffold."""
|
"""Start an interactive troubleshooting session scaffold."""
|
||||||
try:
|
try:
|
||||||
@@ -147,6 +160,7 @@ def run(
|
|||||||
analyze=analyze,
|
analyze=analyze,
|
||||||
interactive=interactive,
|
interactive=interactive,
|
||||||
ai_config=ai_config,
|
ai_config=ai_config,
|
||||||
|
no_rag=no_rag,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -169,6 +183,7 @@ async def _async_main(
|
|||||||
analyze: bool,
|
analyze: bool,
|
||||||
interactive: bool,
|
interactive: bool,
|
||||||
ai_config: AIConfig,
|
ai_config: AIConfig,
|
||||||
|
no_rag: bool,
|
||||||
logger: SessionLogger | None,
|
logger: SessionLogger | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Open a single SSH session and run probe / collection / analysis through it."""
|
"""Open a single SSH session and run probe / collection / analysis through it."""
|
||||||
@@ -219,7 +234,7 @@ async def _async_main(
|
|||||||
_run_analysis(ai_config, req.issue, report, logger=logger)
|
_run_analysis(ai_config, req.issue, report, logger=logger)
|
||||||
|
|
||||||
if interactive:
|
if interactive:
|
||||||
await _interactive_loop(session, req, ai_config, report, logger=logger)
|
await _interactive_loop(session, req, ai_config, report, no_rag=no_rag, logger=logger)
|
||||||
|
|
||||||
|
|
||||||
async def _interactive_loop(
|
async def _interactive_loop(
|
||||||
@@ -227,6 +242,8 @@ async def _interactive_loop(
|
|||||||
req: TroubleshootRequest,
|
req: TroubleshootRequest,
|
||||||
ai_config: AIConfig,
|
ai_config: AIConfig,
|
||||||
report: CollectionReport | None,
|
report: CollectionReport | None,
|
||||||
|
*,
|
||||||
|
no_rag: bool = False,
|
||||||
logger: SessionLogger | None,
|
logger: SessionLogger | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run a follow-up loop for collecting and conversational analysis."""
|
"""Run a follow-up loop for collecting and conversational analysis."""
|
||||||
@@ -241,6 +258,13 @@ async def _interactive_loop(
|
|||||||
)
|
)
|
||||||
|
|
||||||
prior_questions: list[str] = []
|
prior_questions: list[str] = []
|
||||||
|
embedded_chunks: list[EmbeddedChunk] | None = None
|
||||||
|
ai_embed = AIClient(ai_config)
|
||||||
|
|
||||||
|
if not no_rag and report is not None:
|
||||||
|
embedded_chunks = await asyncio.to_thread(_try_embed_report, report, ai_embed)
|
||||||
|
if embedded_chunks is not None:
|
||||||
|
console.print(f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@@ -280,6 +304,12 @@ async def _interactive_loop(
|
|||||||
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
||||||
report = await collect_from_plan(session, plan)
|
report = await collect_from_plan(session, plan)
|
||||||
_handle_collection_report(report)
|
_handle_collection_report(report)
|
||||||
|
if not no_rag:
|
||||||
|
embedded_chunks = await asyncio.to_thread(_try_embed_report, report, ai_embed)
|
||||||
|
if embedded_chunks is not None:
|
||||||
|
console.print(
|
||||||
|
f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]"
|
||||||
|
)
|
||||||
if logger is not None:
|
if logger is not None:
|
||||||
logger.log_event(
|
logger.log_event(
|
||||||
"collection_summary",
|
"collection_summary",
|
||||||
@@ -306,6 +336,7 @@ async def _interactive_loop(
|
|||||||
report,
|
report,
|
||||||
"Provide an updated diagnosis from the current diagnostics.",
|
"Provide an updated diagnosis from the current diagnostics.",
|
||||||
prior_questions,
|
prior_questions,
|
||||||
|
embedded_chunks=embedded_chunks,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
prior_questions.append("/analyze")
|
prior_questions.append("/analyze")
|
||||||
@@ -318,6 +349,12 @@ async def _interactive_loop(
|
|||||||
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
||||||
report = await collect_from_plan(session, plan)
|
report = await collect_from_plan(session, plan)
|
||||||
_handle_collection_report(report)
|
_handle_collection_report(report)
|
||||||
|
if not no_rag:
|
||||||
|
embedded_chunks = await asyncio.to_thread(_try_embed_report, report, ai_embed)
|
||||||
|
if embedded_chunks is not None:
|
||||||
|
console.print(
|
||||||
|
f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]"
|
||||||
|
)
|
||||||
|
|
||||||
if report is None:
|
if report is None:
|
||||||
console.print("[red]No diagnostics available to analyze.[/red]")
|
console.print("[red]No diagnostics available to analyze.[/red]")
|
||||||
@@ -329,6 +366,7 @@ async def _interactive_loop(
|
|||||||
report,
|
report,
|
||||||
command,
|
command,
|
||||||
prior_questions,
|
prior_questions,
|
||||||
|
embedded_chunks=embedded_chunks,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
prior_questions.append(command)
|
prior_questions.append(command)
|
||||||
@@ -336,6 +374,24 @@ async def _interactive_loop(
|
|||||||
logger.log_event("interactive_followup", {"question": command})
|
logger.log_event("interactive_followup", {"question": command})
|
||||||
|
|
||||||
|
|
||||||
|
def _try_embed_report(
|
||||||
|
report: CollectionReport, ai: AIClient
|
||||||
|
) -> list[EmbeddedChunk] | None:
|
||||||
|
"""Embed all diagnostic chunks from *report*; returns None on any failure.
|
||||||
|
|
||||||
|
Failures are expected when the embedding model is not yet pulled or the
|
||||||
|
AI backend is unavailable — in those cases the caller falls back to
|
||||||
|
sending the full report as context.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
chunks = chunk_report(report)
|
||||||
|
if not chunks:
|
||||||
|
return None
|
||||||
|
return [EmbeddedChunk(chunk=c, embedding=ai.embed(c.content)) for c in chunks]
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _handle_probe_result(result: SSHCommandResult) -> None:
|
def _handle_probe_result(result: SSHCommandResult) -> None:
|
||||||
"""Handle and render probe output for success or failure."""
|
"""Handle and render probe output for success or failure."""
|
||||||
console.print("[dim]▶ SSH probe:[/dim] uname -a")
|
console.print("[dim]▶ SSH probe:[/dim] uname -a")
|
||||||
@@ -417,14 +473,32 @@ def _run_followup_analysis(
|
|||||||
question: str,
|
question: str,
|
||||||
prior_questions: list[str],
|
prior_questions: list[str],
|
||||||
*,
|
*,
|
||||||
|
embedded_chunks: list[EmbeddedChunk] | None = None,
|
||||||
logger: SessionLogger | None,
|
logger: SessionLogger | None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Run grounded follow-up analysis re-anchored to current diagnostics."""
|
"""Run grounded follow-up analysis re-anchored to current diagnostics.
|
||||||
|
|
||||||
|
When *embedded_chunks* is provided the question is embedded and the top-5
|
||||||
|
most relevant chunks are used instead of the full report, reducing token
|
||||||
|
usage. Falls back to full-context on any embedding failure.
|
||||||
|
"""
|
||||||
console.print()
|
console.print()
|
||||||
console.print(Rule("[bold cyan]AI Response[/bold cyan]", style="cyan"))
|
console.print(Rule("[bold cyan]AI Response[/bold cyan]", style="cyan"))
|
||||||
console.print()
|
console.print()
|
||||||
ai = AIClient(ai_config)
|
ai = AIClient(ai_config)
|
||||||
system_prompt = build_system_prompt()
|
system_prompt = build_system_prompt()
|
||||||
|
|
||||||
|
user_message: str
|
||||||
|
if embedded_chunks is not None:
|
||||||
|
try:
|
||||||
|
q_embedding = ai.embed(question)
|
||||||
|
retrieved = retrieve(q_embedding, embedded_chunks, top_k=5)
|
||||||
|
user_message = build_message_with_chunks(
|
||||||
|
issue, report.host, retrieved, question, prior_questions
|
||||||
|
)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
user_message = build_followup_message(issue, report, question, prior_questions)
|
||||||
|
else:
|
||||||
user_message = build_followup_message(issue, report, question, prior_questions)
|
user_message = build_followup_message(issue, report, question, prior_questions)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from tai.collectors import CollectionReport
|
from tai.collectors import CollectionReport
|
||||||
|
from tai.rag_retriever import Chunk
|
||||||
|
|
||||||
_SYSTEM_PROMPT = """\
|
_SYSTEM_PROMPT = """\
|
||||||
You are an expert Linux systems administrator and troubleshooting assistant.
|
You are an expert Linux systems administrator and troubleshooting assistant.
|
||||||
@@ -99,3 +100,41 @@ def build_followup_message(
|
|||||||
"If evidence is insufficient, explicitly say so."
|
"If evidence is insufficient, explicitly say so."
|
||||||
)
|
)
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def build_message_with_chunks(
    issue: str,
    host: str,
    chunks: list[Chunk],
    question: str,
    prior_questions: list[str],
) -> str:
    """Compose the RAG follow-up prompt from retrieved diagnostic chunks only.

    The message contains, in order: the reported issue, the target host, one
    ``###`` section per chunk (name as heading, content as body), up to the
    last five prior questions (numbered from 1), the current question, and a
    closing instruction to answer strictly from the retrieved diagnostics.

    Returns:
        The sections joined with newlines.
    """
    sections: list[str] = [
        f"## Issue reported\n\n{issue}\n",
        f"## Target host\n\n{host}\n",
        "## Most relevant diagnostics (retrieved by semantic similarity)\n",
    ]

    for retrieved in chunks:
        # Heading, body, then an empty entry that becomes a blank separator line.
        sections.extend((f"### {retrieved.name}\n", retrieved.content, ""))

    sections.append("## Follow-up")

    if prior_questions:
        sections.append("\nRecent user follow-up questions:")
        sections.extend(
            f"{number}. {asked}"
            for number, asked in enumerate(prior_questions[-5:], start=1)
        )

    sections += [
        "\nCurrent follow-up question:",
        question,
        "\nAnswer strictly from the retrieved diagnostics above. "
        "If evidence is insufficient, explicitly say so.",
    ]
    return "\n".join(sections)
|
||||||
|
|||||||
87
src/tai/rag_retriever.py
Normal file
87
src/tai/rag_retriever.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
"""In-memory RAG retriever for diagnostic report chunks (Tier 1).
|
||||||
|
|
||||||
|
Chunks one CollectionReport item per Chunk, embeds via AIClient, then
|
||||||
|
ranks chunks against a question using pure-Python cosine similarity.
|
||||||
|
No external vector store required — everything lives in process memory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from tai.collectors import CollectionReport
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class Chunk:
|
||||||
|
"""A single retrievable piece of diagnostic content."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class EmbeddedChunk:
|
||||||
|
"""A Chunk paired with its embedding vector."""
|
||||||
|
|
||||||
|
chunk: Chunk
|
||||||
|
embedding: list[float]
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_report(report: CollectionReport) -> list[Chunk]:
    """Turn each diagnostic item of *report* into one Chunk.

    A chunk's content lists the command, its exit code, and whichever of
    stdout / stderr is non-empty (stripped); when both are empty the content
    ends with ``(no output)``.

    Items with exit code 255 and neither stdout nor stderr are skipped: SSH
    could not execute them at all, so they carry no diagnostic signal.
    """
    collected: list[Chunk] = []
    for entry in report.items:
        outcome = entry.result
        has_output = bool(outcome.stdout or outcome.stderr)

        if outcome.exit_code == 255 and not has_output:
            continue

        sections = [
            f"Command: {outcome.command}",
            f"Exit code: {outcome.exit_code}",
        ]
        if outcome.stdout:
            sections.append(f"stdout:\n{outcome.stdout.strip()}")
        if outcome.stderr:
            sections.append(f"stderr:\n{outcome.stderr.strip()}")
        if not has_output:
            sections.append("(no output)")

        collected.append(Chunk(name=entry.name, content="\n".join(sections)))
    return collected
|
||||||
|
|
||||||
|
|
||||||
|
def _cosine_similarity(a: list[float], b: list[float]) -> float:
|
||||||
|
"""Return cosine similarity in [-1, 1] using pure Python (no numpy)."""
|
||||||
|
dot = sum(x * y for x, y in zip(a, b, strict=False))
|
||||||
|
norm_a = math.sqrt(sum(x * x for x in a))
|
||||||
|
norm_b = math.sqrt(sum(x * x for x in b))
|
||||||
|
if norm_a == 0.0 or norm_b == 0.0:
|
||||||
|
return 0.0
|
||||||
|
return dot / (norm_a * norm_b)
|
||||||
|
|
||||||
|
|
||||||
|
def retrieve(
    question_embedding: list[float],
    embedded_chunks: list[EmbeddedChunk],
    *,
    top_k: int = 5,
) -> list[Chunk]:
    """Pick the *top_k* chunks whose embeddings best match *question_embedding*.

    Every candidate is scored with cosine similarity against the question and
    the results are returned best-first. An empty candidate list or a
    non-positive *top_k* yields an empty list.
    """
    if not embedded_chunks or top_k <= 0:
        return []

    ranked = sorted(
        (
            (_cosine_similarity(question_embedding, candidate.embedding), candidate.chunk)
            for candidate in embedded_chunks
        ),
        key=lambda pair: pair[0],
        reverse=True,
    )
    return [best for _, best in ranked[:top_k]]
|
||||||
173
tests/test_rag_retriever.py
Normal file
173
tests/test_rag_retriever.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
"""Tests for rag_retriever — pure-Python, no network calls."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from tai.collectors import CollectedItem, CollectionReport
|
||||||
|
from tai.rag_retriever import Chunk, EmbeddedChunk, _cosine_similarity, chunk_report, retrieve
|
||||||
|
from tai.ssh_client import SSHCommandResult
|
||||||
|
|
||||||
|
|
||||||
|
def _report(*items: tuple[str, str, int]) -> CollectionReport:
    """Assemble a CollectionReport for "test-host" from (name, stdout, exit_code) tuples.

    Each tuple becomes one CollectedItem whose command is ``cmd-<name>`` and
    whose stderr is empty.
    """
    collected = []
    for name, stdout, code in items:
        outcome = SSHCommandResult(
            command=f"cmd-{name}",
            exit_code=code,
            stdout=stdout,
            stderr="",
        )
        collected.append(CollectedItem(name=name, result=outcome))
    return CollectionReport(host="test-host", items=collected)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# chunk_report
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunk_report_creates_one_chunk_per_item() -> None:
    """Two collected items yield two chunks, named in input order."""
    produced = chunk_report(
        _report(("kernel", "Linux test 6.1", 0), ("journal", "Started nginx.", 0))
    )
    assert len(produced) == 2
    first, second = produced[0], produced[1]
    assert first.name == "kernel"
    assert second.name == "journal"
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunk_report_includes_stdout_in_content() -> None:
    """The command's stdout text appears in the chunk content."""
    first = chunk_report(_report(("kernel", "Linux test 6.1", 0)))[0]
    assert "Linux test 6.1" in first.content
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunk_report_includes_exit_code_in_content() -> None:
    """A non-zero exit code is rendered as an ``Exit code:`` line."""
    first = chunk_report(_report(("fail", "error output", 1)))[0]
    assert "Exit code: 1" in first.content
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunk_report_skips_ssh_unreachable_items() -> None:
    """An exit-255 item with no stdout or stderr is dropped; the normal item is kept."""
    dead = CollectedItem(
        name="unreachable",
        result=SSHCommandResult(command="some-cmd", exit_code=255, stdout="", stderr=""),
    )
    alive = CollectedItem(
        name="ok",
        result=SSHCommandResult(command="uname -a", exit_code=0, stdout="Linux", stderr=""),
    )
    produced = chunk_report(CollectionReport(host="test-host", items=[dead, alive]))
    assert len(produced) == 1
    assert produced[0].name == "ok"
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunk_report_keeps_exit_255_with_output() -> None:
    """Exit 255 accompanied by stderr is a real failure, so the item is kept."""
    outcome = SSHCommandResult(
        command="some-cmd",
        exit_code=255,
        stdout="",
        stderr="Permission denied",
    )
    failing = CollectedItem(name="partial", result=outcome)
    produced = chunk_report(CollectionReport(host="test-host", items=[failing]))
    assert len(produced) == 1
    assert "Permission denied" in produced[0].content
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunk_report_notes_no_output() -> None:
    """A successful command with empty stdout and stderr is marked ``(no output)``."""
    quiet = CollectedItem(
        name="silent",
        result=SSHCommandResult(command="cmd", exit_code=0, stdout="", stderr=""),
    )
    produced = chunk_report(CollectionReport(host="test-host", items=[quiet]))
    assert "(no output)" in produced[0].content
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _cosine_similarity
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_cosine_similarity_identical_vectors() -> None:
    """A vector compared with itself scores 1 (within 1e-9)."""
    unit = [1.0, 0.0, 0.0]
    score = _cosine_similarity(unit, unit)
    assert abs(score - 1.0) < 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def test_cosine_similarity_orthogonal_vectors() -> None:
    """Perpendicular vectors score 0 (within 1e-9)."""
    score = _cosine_similarity([1.0, 0.0], [0.0, 1.0])
    assert abs(score) < 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def test_cosine_similarity_opposite_vectors() -> None:
    """Vectors pointing in opposite directions score -1 (within 1e-9)."""
    score = _cosine_similarity([1.0, 0.0], [-1.0, 0.0])
    assert abs(score + 1.0) < 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def test_cosine_similarity_zero_vector_returns_zero() -> None:
    """A zero-magnitude operand gives exactly 0.0 instead of dividing by zero."""
    score = _cosine_similarity([0.0, 0.0], [1.0, 0.0])
    assert score == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# retrieve
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _embedded(name: str, vec: list[float]) -> EmbeddedChunk:
    """Build an EmbeddedChunk named *name* with placeholder content and embedding *vec*."""
    payload = Chunk(name=name, content=f"content of {name}")
    return EmbeddedChunk(chunk=payload, embedding=vec)
|
||||||
|
|
||||||
|
|
||||||
|
def test_retrieve_returns_top_k_by_similarity() -> None:
    """With top_k=2 the two chunks closest to the query come back, best first."""
    pool = [
        _embedded("close", [1.0, 0.0]),  # same direction as the query
        _embedded("mid", [0.7, 0.7]),
        _embedded("far", [0.0, 1.0]),  # perpendicular to the query
    ]
    ranked = retrieve([1.0, 0.0], pool, top_k=2)
    assert len(ranked) == 2
    assert ranked[0].name == "close"
    assert ranked[1].name == "mid"
|
||||||
|
|
||||||
|
|
||||||
|
def test_retrieve_respects_top_k_larger_than_pool() -> None:
    """Asking for more chunks than exist returns just the ones available."""
    pool = [_embedded("only", [1.0, 0.0])]
    ranked = retrieve([1.0, 0.0], pool, top_k=10)
    assert len(ranked) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_retrieve_empty_pool_returns_empty() -> None:
    """No candidates means no results."""
    ranked = retrieve([1.0, 0.0], [], top_k=5)
    assert ranked == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_retrieve_top_k_zero_returns_empty() -> None:
    """top_k=0 returns nothing even when candidates exist."""
    pool = [_embedded("x", [1.0, 0.0])]
    ranked = retrieve([1.0, 0.0], pool, top_k=0)
    assert ranked == []
|
||||||
Reference in New Issue
Block a user