# -*- coding: utf-8 -*-
"""
OpenClaw 向量记忆系统 - 核心引擎

基于 硅基流动 BGE-M3 + Chroma + SQLite
"""


import chromadb
from chromadb.config import Settings
from openai import OpenAI
import sqlite3
import json
import os
from datetime import datetime
import uuid


class VectorMemorySystem:
    """Vector memory core: BGE-M3 embeddings + Chroma search + SQLite storage.

    Embeddings come from SiliconFlow's OpenAI-compatible API (model
    BAAI/bge-m3); Chroma holds the vectors for semantic search, while
    SQLite keeps a durable copy of every memory record.
    """

    def __init__(self, persist_dir: str = "./data/memory", api_key: str = None):
        """Initialize the vector memory system.

        Args:
            persist_dir: Directory for Chroma and SQLite data
                (created automatically if missing).
            api_key: SiliconFlow API key; falls back to the
                SILICONFLOW_API_KEY environment variable.

        Raises:
            ValueError: If no API key is supplied or found in the environment.
        """
        # 1. SiliconFlow client (OpenAI-compatible endpoint).
        if not api_key:
            api_key = os.getenv("SILICONFLOW_API_KEY")
            if not api_key:
                raise ValueError("请设置 SILICONFLOW_API_KEY 环境变量")

        self.client = OpenAI(
            api_key=api_key,
            base_url="https://api.siliconflow.cn/v1"
        )

        # Fix: sqlite3.connect() does not create parent directories, so a
        # fresh checkout crashed here before the directory existed.
        os.makedirs(persist_dir, exist_ok=True)

        # 2. Chroma vector store.
        # NOTE(review): chromadb.Client(Settings(persist_directory=...)) is the
        # legacy API; recent chromadb versions persist via
        # chromadb.PersistentClient(path=...) — verify against the pinned
        # chromadb version before migrating.
        self.chroma = chromadb.Client(Settings(
            persist_directory=persist_dir,
            anonymized_telemetry=False
        ))
        self.collection = self.chroma.get_or_create_collection(
            name="openclaw_memory",
            metadata={"description": "OpenClaw long-term memory"}
        )

        # 3. SQLite for durable persistence of memory records.
        self.db_path = f"{persist_dir}/memory.db"
        self._init_sqlite()

        print(f"✅ 向量记忆系统初始化完成")
        print(f" - 数据目录: {persist_dir}")
        print(f" - 向量模型: BAAI/bge-m3")

    def _init_sqlite(self):
        """Open the SQLite database and create the schema if absent."""
        # NOTE(review): default connect() is single-thread only
        # (check_same_thread=True) — confirm callers stay on one thread.
        self.conn = sqlite3.connect(self.db_path)
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS memories (
                id TEXT PRIMARY KEY,
                content TEXT NOT NULL,
                metadata TEXT,
                importance INTEGER DEFAULT 3,
                tier TEXT DEFAULT 'hot',
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        self.conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_importance ON memories(importance)
        """)
        self.conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_created_at ON memories(created_at)
        """)
        self.conn.commit()

    def _get_embedding(self, text: str) -> list:
        """Embed *text* with BGE-M3 and return the embedding vector."""
        response = self.client.embeddings.create(
            model="BAAI/bge-m3",
            input=text
        )
        return response.data[0].embedding

    def add_memory(self, content: str, metadata: dict = None, importance: int = 3) -> str:
        """Add a memory (written to both the vector store and SQLite).

        Args:
            content: Memory text.
            metadata: Optional metadata dict.
            importance: Importance score 1-5.

        Returns:
            The new memory's unique id (UUID4 string).
        """
        memory_id = str(uuid.uuid4())

        # 1. Embed and store in Chroma.
        # Fix: always include the importance score in the Chroma metadata —
        # previously it was dropped whenever caller metadata was present.
        embedding = self._get_embedding(content)
        self.collection.add(
            ids=[memory_id],
            embeddings=[embedding],
            documents=[content],
            metadatas=[{**(metadata or {}), "importance": importance}]
        )

        # 2. Persist the record in SQLite.
        self.conn.execute(
            """INSERT INTO memories (id, content, metadata, importance)
               VALUES (?, ?, ?, ?)""",
            (memory_id, content, json.dumps(metadata or {}), importance)
        )
        self.conn.commit()

        print(f"✅ 添加记忆: {content[:30]}...")
        return memory_id

    def search(self, query: str, top_k: int = 5) -> list:
        """Semantic search over stored memories.

        Args:
            query: Query text.
            top_k: Number of results to return.

        Returns:
            List of dicts with id, content, distance, similarity, metadata.
        """
        # 1. Embed the query.
        query_embedding = self._get_embedding(query)

        # 2. Vector similarity search.
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k
        )

        # 3. Flatten Chroma's nested result arrays.
        # NOTE(review): similarity = 1 - distance assumes a cosine-like
        # distance in [0, 2]; for L2 space this is not a true similarity —
        # confirm the collection's distance metric.
        memories = []
        for i, doc in enumerate(results['documents'][0]):
            memories.append({
                'id': results['ids'][0][i],
                'content': doc,
                'distance': results['distances'][0][i],
                'similarity': 1 - results['distances'][0][i],
                'metadata': results['metadatas'][0][i]
            })

        return memories

    def get_recent(self, limit: int = 10) -> list:
        """Return the most recently created memories, newest first."""
        cursor = self.conn.execute("""
            SELECT id, content, metadata, importance, created_at
            FROM memories
            ORDER BY created_at DESC
            LIMIT ?
        """, (limit,))

        return [{
            'id': row[0],
            'content': row[1],
            'metadata': json.loads(row[2]) if row[2] else {},
            'importance': row[3],
            'created_at': row[4]
        } for row in cursor.fetchall()]

    def count(self) -> int:
        """Return the total number of stored memories."""
        cursor = self.conn.execute("SELECT COUNT(*) FROM memories")
        return cursor.fetchone()[0]

    def delete(self, memory_id: str):
        """Delete the memory with *memory_id* from both stores."""
        self.collection.delete(ids=[memory_id])
        self.conn.execute("DELETE FROM memories WHERE id=?", (memory_id,))
        self.conn.commit()
        print(f"🗑️ 删除记忆: {memory_id}")

    def close(self):
        """Close the SQLite connection (new helper; previously it leaked)."""
        self.conn.close()


# ============ 便捷函数 ============

def get_memory_system(api_key: str = None) -> VectorMemorySystem:
    """Return the module-wide VectorMemorySystem singleton.

    The instance is created lazily on first call (using *api_key* if
    given); subsequent calls return the cached instance.
    """
    global _memory_system

    existing = globals().get('_memory_system')
    if existing is None:
        existing = VectorMemorySystem(api_key=api_key)
        _memory_system = existing

    return existing


# 测试代码
if __name__ == "__main__":
    import sys

    # Resolve the API key: environment variable first, then first CLI arg.
    api_key = os.getenv("SILICONFLOW_API_KEY")
    if not api_key and len(sys.argv) > 1:
        api_key = sys.argv[1]

    if not api_key:
        print("❌ 请设置 SILICONFLOW_API_KEY 环境变量或作为参数传入")
        print(" export SILICONFLOW_API_KEY='sk-xxx'")
        sys.exit(1)

    # Bring up the memory system.
    vm = VectorMemorySystem(api_key=api_key)

    # Exercise add_memory.
    print("\n📝 测试添加记忆...")
    vm.add_memory(
        content="今天学习了向量数据库+语义搜索的方案,采用硅基流动BGE-M3模型",
        metadata={"tags": ["学习", "向量"], "source": "test"},
        importance=4
    )

    # Exercise semantic search.
    print("\n🔍 测试语义搜索...")
    for hit in vm.search("AI 记忆系统"):
        print(f" - {hit['content'][:50]}... (相似度: {hit['similarity']:.2%})")

    # Summary stats.
    print(f"\n📊 当前记忆总数: {vm.count()}")