3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
@@ -60,3 +60,6 @@ jobs:

- name: Run ruff format check
run: uv run ruff format --check src/ tests/

- name: Run mypy type check
run: uv run mypy src/mnemebrain_core/ --ignore-missing-imports
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -12,6 +12,7 @@ dependencies = [
"pydantic>=2.0",
"kuzu>=0.8",
"numpy>=1.26,<2.0",
"httpx>=0.27",
]

[project.optional-dependencies]
@@ -37,6 +38,7 @@ dev = [
"httpx>=0.27",
"pytest-cov>=5.0",
"ruff>=0.9",
"mypy>=1.19",
"asgi-lifespan>=2.0",
]

66 changes: 46 additions & 20 deletions src/mnemebrain_core/memory.py
@@ -2,6 +2,7 @@

from __future__ import annotations

import logging
from dataclasses import dataclass
from datetime import datetime, timezone
from uuid import UUID
@@ -18,6 +19,8 @@
from mnemebrain_core.providers.base import EmbeddingProvider, EvidenceInput
from mnemebrain_core.store import KuzuGraphStore

logger = logging.getLogger(__name__)


@dataclass
class BeliefResult:
@@ -55,10 +58,20 @@ def __init__(
self._embedder = embedding_provider
if self._embedder is None:
self._embedder = self._auto_detect_embedder()
if self._embedder is None:
logger.warning(
"No embedding provider available. Running in degraded mode: "
"believe/explain/search will use text matching instead of "
"semantic similarity. Install sentence-transformers, set "
"EMBEDDING_BASE_URL+EMBEDDING_MODEL, or set OPENAI_API_KEY "
"to enable embeddings."
)

@staticmethod
def _auto_detect_embedder() -> EmbeddingProvider | None:
"""Try available embedding providers in order of preference."""
import os

# 1. Local sentence-transformers (no API key needed)
try:
from mnemebrain_core.providers.embeddings.sentence_transformers import (
@@ -68,10 +81,21 @@ def _auto_detect_embedder() -> EmbeddingProvider | None:
return SentenceTransformerProvider()
except ImportError:
pass
# 2. OpenAI API (requires OPENAI_API_KEY)
try:
import os
# 2. OpenAI-compatible server (Ollama, LM Studio, vLLM, etc.)
base_url = os.environ.get("EMBEDDING_BASE_URL")
model = os.environ.get("EMBEDDING_MODEL")
if base_url and model:
from mnemebrain_core.providers.embeddings.openai_compatible import (
OpenAICompatibleProvider,
)

return OpenAICompatibleProvider(
base_url=base_url,
model=model,
api_key=os.environ.get("EMBEDDING_API_KEY"),
)
# 3. OpenAI API (requires OPENAI_API_KEY)
try:
if os.environ.get("OPENAI_API_KEY"):
from mnemebrain_core.providers.embeddings.openai import (
OpenAIEmbeddingProvider,
@@ -80,17 +104,7 @@ def _auto_detect_embedder() -> EmbeddingProvider | None:
return OpenAIEmbeddingProvider()
except ImportError:
pass
return None # Will fail at use-time with clear message

def _get_embedder(self) -> EmbeddingProvider:
if self._embedder is None:
raise ImportError(
"No embedding provider available. "
"Install with: pip install mnemebrain-lite[embeddings] "
"(local) or pip install mnemebrain-lite[openai] and set "
"OPENAI_API_KEY"
)
return self._embedder
return None # Degraded mode — text matching only

def believe(
self,
@@ -101,8 +115,14 @@ def believe(
source_agent: str = "",
) -> BeliefResult:
"""Store a new belief with evidence. Merges if similar belief exists."""
embedding = self._get_embedder().embed(claim)
existing = self._store.find_similar(embedding, threshold=0.92)
embedding: list[float]
if self._embedder is not None:
embedding = self._embedder.embed(claim)
existing = self._store.find_similar(embedding, threshold=0.92)
else:
embedding = []
exact = self._store.find_by_claim(claim)
existing = [(exact, 1.0)] if exact else []

if existing:
belief = existing[0][0]
@@ -175,8 +195,11 @@ def retract(self, evidence_id: UUID) -> list[BeliefResult]:

def explain(self, claim: str) -> ExplanationResult | None:
"""Return full justification chain for a belief."""
embedding = self._get_embedder().embed(claim)
matches = self._store.find_similar(embedding, threshold=0.8)
if self._embedder is not None:
embedding = self._embedder.embed(claim)
matches = self._store.find_similar(embedding, threshold=0.8)
else:
matches = []

if not matches:
exact = self._store.find_by_claim(claim)
@@ -240,8 +263,11 @@ def search(
"""
from mnemebrain_core.engine import apply_conflict_policy, rank_score

embedding = self._get_embedder().embed(query)
raw_matches = self._store.find_similar(embedding, threshold=0.3)
if self._embedder is not None:
embedding = self._embedder.embed(query)
raw_matches = self._store.find_similar(embedding, threshold=0.3)
else:
raw_matches = self._store.find_by_text(query, limit=limit)

scored = [
(
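For context, the new auto-detection order in memory.py is: local sentence-transformers first, then an OpenAI-compatible server selected via environment variables, then the OpenAI API, and finally degraded text-matching mode. A minimal sketch of opting into the second path follows; the endpoint URL and model name are illustrative placeholders, not values defined by this PR:

    import os

    # Hypothetical values: any server exposing an OpenAI-compatible
    # /v1/embeddings endpoint (Ollama, LM Studio, vLLM, ...) will do.
    os.environ["EMBEDDING_BASE_URL"] = "http://localhost:11434/v1"
    os.environ["EMBEDDING_MODEL"] = "nomic-embed-text"
    # Only needed if the server enforces authentication:
    # os.environ["EMBEDDING_API_KEY"] = "sk-..."

With both variables set, _auto_detect_embedder() constructs an OpenAICompatibleProvider; with neither these nor OPENAI_API_KEY set and sentence-transformers not installed, believe/explain/search fall back to the text-matching paths added above.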
45 changes: 45 additions & 0 deletions src/mnemebrain_core/providers/embeddings/openai_compatible.py
@@ -0,0 +1,45 @@
"""OpenAI-compatible embedding provider — works with Ollama, LM Studio, vLLM, etc."""

from __future__ import annotations

import numpy as np

from mnemebrain_core.providers.base import EmbeddingProvider


class OpenAICompatibleProvider(EmbeddingProvider):
"""Embedding provider using any OpenAI-compatible /v1/embeddings endpoint."""

def __init__(
self,
base_url: str,
model: str,
api_key: str | None = None,
) -> None:
import httpx

headers: dict[str, str] = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
self._client = httpx.Client(base_url=base_url, headers=headers, timeout=30.0)
self._model = model

def embed(self, text: str) -> list[float]:
"""Embed text via the /embeddings endpoint."""
response = self._client.post(
"/embeddings",
json={"input": text, "model": self._model},
)
if response.status_code != 200:
raise RuntimeError(
f"Embedding request failed ({response.status_code}): {response.text}"
)
return response.json()["data"][0]["embedding"]

def similarity(self, a: list[float], b: list[float]) -> float:
"""Cosine similarity between two vectors."""
a_arr, b_arr = np.array(a), np.array(b)
norm_a, norm_b = np.linalg.norm(a_arr), np.linalg.norm(b_arr)
if norm_a == 0 or norm_b == 0:
return 0.0
return float(np.dot(a_arr, b_arr) / (norm_a * norm_b))
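A minimal usage sketch of the new provider, assuming a local OpenAI-compatible server; the base URL and model name below are placeholders, not part of this change:

    from mnemebrain_core.providers.embeddings.openai_compatible import (
        OpenAICompatibleProvider,
    )

    # Placeholder endpoint and model: point these at whatever your server hosts.
    provider = OpenAICompatibleProvider(
        base_url="http://localhost:11434/v1",
        model="nomic-embed-text",
    )

    vec_a = provider.embed("The sky is blue")
    vec_b = provider.embed("The sky appears blue during the day")
    print(provider.similarity(vec_a, vec_b))  # cosine similarity; 1.0 = same direction

Because embed() raises RuntimeError on any non-200 response, callers see the server's error body in the exception message rather than receiving a silent empty vector.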
72 changes: 58 additions & 14 deletions src/mnemebrain_core/store.py
@@ -3,6 +3,7 @@
from __future__ import annotations

import json
from typing import Any, cast
from uuid import UUID

import kuzu
@@ -22,6 +23,27 @@ def __init__(self, db_path: str, *, max_db_size: int = 0) -> None:
self._conn = kuzu.Connection(self._db)
self._init_schema()

def _query(
self, statement: str, parameters: dict[str, Any] | None = None
) -> kuzu.QueryResult:
"""Execute a single Cypher statement and return the QueryResult.

Kuzu's ``execute`` is typed as returning ``QueryResult | list[QueryResult]``
but only returns a list when multiple statements are passed. This helper
narrows the type for single-statement calls used throughout the store.
"""
kwargs: dict[str, Any] = {}
if parameters is not None:
kwargs["parameters"] = parameters
result = self._conn.execute(statement, **kwargs)
assert isinstance(result, kuzu.QueryResult)
return result

@staticmethod
def _next_row(result: kuzu.QueryResult) -> list[Any]:
"""Get the next row as a list (Kuzu types ``list | dict``)."""
return cast(list[Any], result.get_next())

def _init_schema(self) -> None:
"""Create node/rel tables if they don't exist."""
self._conn.execute(
@@ -69,38 +91,38 @@ def upsert(self, belief: Belief, embedding: list[float] | None = None) -> None:

def get(self, belief_id: UUID) -> Belief | None:
"""Retrieve a belief by ID."""
result = self._conn.execute(
result = self._query(
"MATCH (b:Belief {id: $id}) RETURN b.data",
parameters={"id": str(belief_id)},
)
if not result.has_next():
return None
row = result.get_next()
row = self._next_row(result)
belief_data = json.loads(row[0])

# Load evidence
ev_result = self._conn.execute(
ev_result = self._query(
"MATCH (b:Belief {id: $id})-[:HAS_EVIDENCE]->(e:EvidenceNode) "
"RETURN e.data",
parameters={"id": str(belief_id)},
)
evidence_list = []
while ev_result.has_next():
ev_row = ev_result.get_next()
ev_row = self._next_row(ev_result)
evidence_list.append(json.loads(ev_row[0]))

belief_data["evidence"] = evidence_list
return Belief.model_validate(belief_data)

def get_evidence(self, evidence_id: UUID) -> Evidence | None:
"""Retrieve a single evidence item by ID."""
result = self._conn.execute(
result = self._query(
"MATCH (e:EvidenceNode {id: $id}) RETURN e.data",
parameters={"id": str(evidence_id)},
)
if not result.has_next():
return None
row = result.get_next()
row = self._next_row(result)
return Evidence.model_validate(json.loads(row[0]))

def update_evidence(self, evidence: Evidence) -> None:
@@ -113,13 +135,13 @@ def update_evidence(self, evidence: Evidence) -> None:

def find_beliefs_using(self, evidence_id: UUID) -> list[Belief]:
"""Find all beliefs that reference a given evidence item."""
result = self._conn.execute(
result = self._query(
"MATCH (b:Belief)-[:HAS_EVIDENCE]->(e:EvidenceNode {id: $eid}) RETURN b.id",
parameters={"eid": str(evidence_id)},
)
beliefs = []
while result.has_next():
row = result.get_next()
row = self._next_row(result)
belief = self.get(UUID(row[0]))
if belief:
beliefs.append(belief)
@@ -129,7 +151,7 @@ def find_similar(
self, embedding: list[float], threshold: float = 0.92
) -> list[tuple[Belief, float]]:
"""Find beliefs with similar embeddings. Returns (belief, similarity) pairs."""
result = self._conn.execute(
result = self._query(
"MATCH (b:Belief) WHERE size(b.embedding) > 0 RETURN b.id, b.embedding"
)
matches: list[tuple[Belief, float]] = []
@@ -139,7 +161,7 @@
return []

while result.has_next():
row = result.get_next()
row = self._next_row(result)
stored_emb = np.array(row[1])
if stored_emb.shape != query_vec.shape:
continue # skip embeddings from a different provider
@@ -156,10 +178,10 @@

def list_beliefs(self) -> list[Belief]:
"""List all beliefs in the store."""
result = self._conn.execute("MATCH (b:Belief) RETURN b.id")
result = self._query("MATCH (b:Belief) RETURN b.id")
beliefs = []
while result.has_next():
row = result.get_next()
row = self._next_row(result)
belief = self.get(UUID(row[0]))
if belief:
beliefs.append(belief)
@@ -193,11 +215,33 @@ def list_beliefs_filtered(
total = len(filtered)
return filtered[offset : offset + limit], total

def find_by_text(self, query: str, limit: int = 10) -> list[tuple[Belief, float]]:
"""Find beliefs by case-insensitive substring match on claim text.

Returns (belief, score) pairs sorted by relevance score.
"""
result = self._query("MATCH (b:Belief) RETURN b.id, b.data")
matches: list[tuple[Belief, float]] = []
query_lower = query.lower()

while result.has_next():
row = self._next_row(result)
data = json.loads(row[1])
claim = data.get("claim", "")
if query_lower in claim.lower():
score = len(query) / len(claim) if claim else 0.0
belief = self.get(UUID(row[0]))
if belief:
matches.append((belief, score))

matches.sort(key=lambda x: x[1], reverse=True)
return matches[:limit]

def find_by_claim(self, claim: str) -> Belief | None:
"""Find a belief by exact claim match."""
result = self._conn.execute("MATCH (b:Belief) RETURN b.id, b.data")
result = self._query("MATCH (b:Belief) RETURN b.id, b.data")
while result.has_next():
row = result.get_next()
row = self._next_row(result)
data = json.loads(row[1])
if data.get("claim") == claim:
return self.get(UUID(row[0]))
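A small usage sketch of the new text fallback in store.py; the database path and query are illustrative, and it assumes Belief exposes the claim field that find_by_text reads from the stored JSON:

    from mnemebrain_core.store import KuzuGraphStore

    store = KuzuGraphStore("./beliefs.kuzu")  # hypothetical path
    # ... after some beliefs have been stored via upsert() ...
    for belief, score in store.find_by_text("python", limit=5):
        # score = len(query) / len(claim): shorter claims containing the
        # query rank higher, and 1.0 means the claim is exactly the query.
        print(f"{score:.2f}  {belief.claim}")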