This commit is contained in:
@@ -6,7 +6,7 @@ from typing import Any, Dict, List
|
|||||||
import muvera
|
import muvera
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from settings import SEARCH_MAX_BATCH_SIZE, SEARCH_PREFETCH_SIZE
|
from settings import MUVERA_INDEX_NAME, SEARCH_MAX_BATCH_SIZE, SEARCH_PREFETCH_SIZE
|
||||||
from utils.logger import root_logger as logger
|
from utils.logger import root_logger as logger
|
||||||
|
|
||||||
# Global collection for background tasks
|
# Global collection for background tasks
|
||||||
@@ -34,13 +34,13 @@ class MuveraWrapper:
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def search(self, query: str, limit: int) -> List[Dict[str, Any]]:
|
async def search(self, query: str, limit: int) -> List[Dict[str, Any]]:
|
||||||
"""Simple search implementation using FDE encoding"""
|
"""Simple search implementation using FDE encoding with deterministic results"""
|
||||||
if not query.strip():
|
if not query.strip():
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# For demo purposes, create a simple query embedding
|
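# Seed the query embedding from the query hash (NOTE: built-in hash() of str is salted per process unless PYTHONHASHSEED is fixed, so results are repeatable only within one process)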
|
||||||
# In a real implementation, you'd use a proper text embedding model
|
query_hash = hash(query.strip().lower())
|
||||||
rng = np.random.default_rng()
|
rng = np.random.default_rng(seed=query_hash & 0x7FFFFFFF) # Use positive seed
|
||||||
query_embedding = rng.standard_normal((32, self.vector_dimension)).astype(np.float32)
|
query_embedding = rng.standard_normal((32, self.vector_dimension)).astype(np.float32)
|
||||||
|
|
||||||
# Encode query using FDE
|
# Encode query using FDE
|
||||||
@@ -62,19 +62,20 @@ class MuveraWrapper:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sort by score and limit results
|
# Sort by score and limit results - also sort by ID so the ordering is stable
|
||||||
results.sort(key=lambda x: x["score"], reverse=True)
|
results.sort(key=lambda x: (x["score"], x["id"]), reverse=True)
|
||||||
return results[:limit]
|
return results[:limit]
|
||||||
|
|
||||||
async def index(self, documents: List[Dict[str, Any]]) -> None:
|
async def index(self, documents: List[Dict[str, Any]]) -> None:
|
||||||
"""Index documents using FDE encoding"""
|
"""Index documents using FDE encoding with deterministic embeddings"""
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
doc_id = doc["id"]
|
doc_id = doc["id"]
|
||||||
self.documents[doc_id] = doc
|
self.documents[doc_id] = doc
|
||||||
|
|
||||||
# Create a simple document embedding (in real implementation, use proper text embedding)
|
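# Seed the document embedding from the content hash (NOTE: built-in hash() of str is salted per process unless PYTHONHASHSEED is fixed, so results are repeatable only within one process)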
|
||||||
# For now, create random embeddings for demo
|
doc_content = f"{doc.get('title', '')} {doc.get('body', '')}"
|
||||||
rng = np.random.default_rng()
|
content_hash = hash(doc_content.strip().lower() + str(doc_id))
|
||||||
|
rng = np.random.default_rng(seed=content_hash & 0x7FFFFFFF) # Use positive seed
|
||||||
doc_embedding = rng.standard_normal((32, self.vector_dimension)).astype(np.float32)
|
doc_embedding = rng.standard_normal((32, self.vector_dimension)).astype(np.float32)
|
||||||
|
|
||||||
# Encode document using FDE (average aggregation for documents)
|
# Encode document using FDE (average aggregation for documents)
|
||||||
@@ -104,16 +105,16 @@ class SearchService:
|
|||||||
self.muvera_client: Any = None
|
self.muvera_client: Any = None
|
||||||
self.client: Any = None
|
self.client: Any = None
|
||||||
|
|
||||||
# Initialize Muvera
|
# Initialize local Muvera
|
||||||
try:
|
try:
|
||||||
# Initialize Muvera wrapper with your configuration
|
|
||||||
self.muvera_client = MuveraWrapper(
|
self.muvera_client = MuveraWrapper(
|
||||||
vector_dimension=768, # Standard embedding dimension
|
vector_dimension=768, # Standard embedding dimension
|
||||||
cache_enabled=True,
|
cache_enabled=True,
|
||||||
batch_size=SEARCH_MAX_BATCH_SIZE,
|
batch_size=SEARCH_MAX_BATCH_SIZE,
|
||||||
)
|
)
|
||||||
self.available = True
|
self.available = True
|
||||||
logger.info("Muvera wrapper initialized successfully - enhanced search enabled")
|
logger.info(f"Local Muvera wrapper initialized - index: {MUVERA_INDEX_NAME}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize Muvera: {e}")
|
logger.error(f"Failed to initialize Muvera: {e}")
|
||||||
self.available = False
|
self.available = False
|
||||||
@@ -126,7 +127,7 @@ class SearchService:
|
|||||||
# Get Muvera service info
|
# Get Muvera service info
|
||||||
if self.muvera_client:
|
if self.muvera_client:
|
||||||
muvera_info = await self.muvera_client.info()
|
muvera_info = await self.muvera_client.info()
|
||||||
return {"status": "enabled", "provider": "muvera", "muvera_info": muvera_info}
|
return {"status": "enabled", "provider": "muvera", "mode": "local", "muvera_info": muvera_info}
|
||||||
return {"status": "error", "message": "Muvera client not available"}
|
return {"status": "error", "message": "Muvera client not available"}
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to get search info")
|
logger.exception("Failed to get search info")
|
||||||
@@ -403,6 +404,7 @@ class SearchService:
|
|||||||
if hasattr(self, "muvera_client") and self.muvera_client:
|
if hasattr(self, "muvera_client") and self.muvera_client:
|
||||||
try:
|
try:
|
||||||
await self.muvera_client.close()
|
await self.muvera_client.close()
|
||||||
|
logger.info("Local Muvera client closed")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error closing Muvera client: {e}")
|
logger.warning(f"Error closing Muvera client: {e}")
|
||||||
logger.info("Search service closed")
|
logger.info("Search service closed")
|
||||||
|
|||||||
@@ -96,3 +96,4 @@ SEARCH_MAX_BATCH_SIZE = int(os.environ.get("SEARCH_MAX_BATCH_SIZE", "25"))
|
|||||||
SEARCH_CACHE_ENABLED = bool(os.environ.get("SEARCH_CACHE_ENABLED", "true").lower() in ["true", "1", "yes"])
|
SEARCH_CACHE_ENABLED = bool(os.environ.get("SEARCH_CACHE_ENABLED", "true").lower() in ["true", "1", "yes"])
|
||||||
SEARCH_CACHE_TTL_SECONDS = int(os.environ.get("SEARCH_CACHE_TTL_SECONDS", "300"))
|
SEARCH_CACHE_TTL_SECONDS = int(os.environ.get("SEARCH_CACHE_TTL_SECONDS", "300"))
|
||||||
SEARCH_PREFETCH_SIZE = int(os.environ.get("SEARCH_PREFETCH_SIZE", "200"))
|
SEARCH_PREFETCH_SIZE = int(os.environ.get("SEARCH_PREFETCH_SIZE", "200"))
|
||||||
|
MUVERA_INDEX_NAME = "discours"
|
||||||
|
|||||||
Reference in New Issue
Block a user