1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
|
import os
import chromadb
import markdown
import requests
class EmbeddingBot:
    def __init__(self, api_url="https://api.siliconflow.cn/v1/embeddings",
                 api_key="", model_name="BAAI/bge-m3", dimension=1024):
        """Initialize the embedding API client and the vector database.

        Args:
            api_url: URL of the embedding API endpoint.
            api_key: Optional API key; when non-empty it is sent as a
                Bearer token on every request.
            model_name: Name of the embedding model to use.
            dimension: Expected dimensionality of the returned vectors
                (default 1024, matching the original hard-coded value).
        """
        self.api_url = api_url
        self.api_key = api_key
        self.model_name = model_name
        self.dimension = dimension
        # Initialize ChromaDB in persistent mode under db/files.
        db_path = os.path.join("db", "files")
        os.makedirs(db_path, exist_ok=True)
        from chromadb.config import Settings
        self.chroma_client = chromadb.PersistentClient(
            path=db_path,
            settings=Settings(anonymized_telemetry=False)
        )
        self.collection = self.chroma_client.get_or_create_collection(
            name="knowledge_base",
            metadata={"dimension": dimension}
        )

    def extract_text(self, file_path):
        """Extract text content from a file.

        Currently only Markdown (.md) files are supported; their content
        is converted to HTML. Any other extension yields "".

        Args:
            file_path: Path of the file to read.

        Returns:
            str: Extracted text, or "" for unsupported file types.
        """
        text = ""
        if file_path.endswith('.md'):
            with open(file_path, 'r', encoding='utf-8') as f:
                text = markdown.markdown(f.read())
        return text

    def process_file(self, file_path, metadata=None):
        """Embed a file's content and store the vector in ChromaDB.

        The document id is the file's basename, so re-processing the
        same filename overwrites/collides with the previous entry.

        Args:
            file_path: Path of the file to process.
            metadata: Optional extra metadata merged into the stored record.

        Returns:
            bool: True on success, False if no text could be extracted
            or no embedding was produced.
        """
        text = self.extract_text(file_path)
        if not text:
            return False
        # Vectorize the extracted text (single-element batch).
        embeddings = self.embedding([text])
        if not embeddings:
            return False
        # Merge caller-supplied metadata with the file path.
        doc_metadata = {"file_path": file_path}
        if metadata:
            doc_metadata.update(metadata)
        # Store document, vector and metadata in ChromaDB.
        self.collection.add(
            documents=[text],
            embeddings=[embeddings[0]],
            metadatas=[doc_metadata],
            ids=[os.path.basename(file_path)]
        )
        return True

    def process_files(self, file_paths, metadata=None):
        """Process several files, sharing one metadata dict across all.

        Args:
            file_paths: Iterable of file paths.
            metadata: Optional metadata applied to every file.

        Returns:
            list[bool]: Per-file success flags, in input order.
        """
        return [self.process_file(file_path, metadata)
                for file_path in file_paths]

    def delete_embeddings(self, file_path, delete_file=False):
        """Delete the stored vectors for a file, optionally the file too.

        Args:
            file_path: Path whose basename was used as the document id.
            delete_file: When True, also remove the file from disk.
        """
        file_id = os.path.basename(file_path)
        # Remove the vectors first so the collection is cleaned up even
        # if deleting the file from disk fails afterwards.
        self.collection.delete(ids=[file_id])
        if delete_file and os.path.exists(file_path):
            os.remove(file_path)

    def query(self, query_text, n_results=5, include_metadata=True):
        """Query the knowledge base.

        Args:
            query_text: Query text.
            n_results: Number of results to return.
            include_metadata: Whether to include metadata in the results.

        Returns:
            dict: Query results in the form:
                {
                    'documents': matched documents,
                    'metadatas': matched metadata (optional),
                    'scores': similarity scores
                }
            or {'error': message} if the query failed.
        """
        try:
            # Vectorize the query text first.
            query_embedding = self.embedding([query_text])
            # Query the collection by vector.
            include = (["documents", "metadatas", "distances"]
                       if include_metadata
                       else ["documents", "distances"])
            results = self.collection.query(
                query_embeddings=query_embedding,
                n_results=n_results,
                include=include
            )
            # Map distances to similarity scores in [0, 1].
            # NOTE(review): 1 - d/2 assumes cosine distance in [0, 2];
            # verify the collection's distance metric matches.
            if 'distances' in results:
                results['scores'] = [1 - (d / 2) for d in results['distances'][0]]
                del results['distances']
            return results
        except Exception as e:
            print(f"查询失败:{str(e)}")
            return {'error': str(e)}

    def embedding(self, input_data):
        """Generate embedding vectors via the remote API.

        Args:
            input_data: List of strings to embed.

        Returns:
            list: One embedding vector (list of floats) per input string,
            in the same order as ``input_data``.

        Raises:
            requests.HTTPError: If the API responds with an error status.
            ValueError: If a returned vector has an unexpected dimension.
        """
        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        response = requests.post(
            self.api_url,
            json={"input": input_data, "model": self.model_name,
                  "dimensions": self.dimension},
            headers=headers,
            timeout=30
        )
        response.raise_for_status()
        embedding_data = response.json()
        # BUG FIX: the previous version returned only data[0]['embedding'],
        # silently dropping vectors for every input after the first in a
        # batch. Sort by the API-provided index so output order matches
        # input order, then collect every vector.
        items = sorted(embedding_data['data'],
                       key=lambda item: item.get('index', 0))
        vectors = [item['embedding'] for item in items]
        # Validate the dimension of every returned vector.
        for vec in vectors:
            if len(vec) != self.dimension:
                raise ValueError(
                    f"维度不匹配:预期 {self.dimension} 维,实际{len(vec)}维")
        return vectors
|