提交 f6af0f99 编写于 作者: J John Wang

feat: query the embeddings table for vectors that Milvus search does not return, instead of marking hit testing as unsupported

上级 5eddcaae
......@@ -73,7 +73,3 @@ class InvalidMetadataError(BaseHTTPException):
code = 400
class CurrentVectorStoreNotSupportHitTestingError(BaseHTTPException):
    """HTTP 400 error stating that the current vector store does not support hit testing."""
    # Identifier string for this error.
    error_code = 'current_vector_store_not_support_hit_testing'
    # Human-readable message for this error.
    description = "The current vector store does not support hit testing."
    # HTTP status code for this error.
    code = 400
......@@ -8,14 +8,12 @@ import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError, ProviderQuotaExceededError, \
ProviderModelCurrentlyNotSupportError
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError, \
CurrentVectorStoreNotSupportHitTestingError
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import TimestampField
from services.dataset_service import DatasetService
from services.errors.dataset import VectorStoreNotSupportHitTestingError
from services.hit_testing_service import HitTestingService
document_fields = {
......@@ -103,8 +101,6 @@ class HitTestingApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except VectorStoreNotSupportHitTestingError:
raise CurrentVectorStoreNotSupportHitTestingError()
except Exception as e:
logging.exception("Hit testing failed.")
raise InternalServerError(str(e))
......
......@@ -74,10 +74,3 @@ class VectorStore:
raise Exception("Vector store client is not initialized.")
return self._client
def support_hit_testing(self):
    """Report whether the active vector store client can be used for hit testing.

    Returns:
        bool: False when the client is a MilvusVectorStoreClient, True for
        any other client.
    """
    # The Milvus search API does not return vector data, so it is the one
    # client excluded here.
    return not isinstance(self._client, MilvusVectorStoreClient)
\ No newline at end of file
......@@ -3,7 +3,3 @@ from services.errors.base import BaseServiceError
class DatasetNameDuplicateError(BaseServiceError):
    # NOTE(review): presumably raised when a dataset name is already in use —
    # confirm against the callers; no raise site is visible here.
    pass
class VectorStoreNotSupportHitTestingError(BaseServiceError):
    """Service-layer error for a vector store that cannot be used for hit testing."""
    pass
......@@ -11,19 +11,14 @@ from sklearn.manifold import TSNE
from core.docstore.empty_docstore import EmptyDocumentStore
from core.index.vector_index import VectorIndex
from extensions.ext_database import db
from extensions.ext_vector_store import vector_store
from models.account import Account
from models.dataset import Dataset, DocumentSegment, DatasetQuery
from services.errors.dataset import VectorStoreNotSupportHitTestingError
from models.dataset import Dataset, DocumentSegment, DatasetQuery, Embedding
from services.errors.index import IndexNotInitializedError
class HitTestingService:
@classmethod
def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
if not vector_store.support_hit_testing():
raise VectorStoreNotSupportHitTestingError()
index = VectorIndex(dataset=dataset).query_index
if not index:
......@@ -74,6 +69,11 @@ class HitTestingService:
for node in nodes:
if node.node.embedding:
embeddings.append(node.node.embedding)
else:
embedding = db.session.query(Embedding).filter_by(hash=node.node.doc_hash).first()
if embedding:
node.node.embedding = embedding.get_embedding()
embeddings.append(node.node.embedding)
tsne_position_data = cls.get_tsne_positions_from_embeddings(embeddings)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册