From 50a8c24ead9d87cf8337af309b6185859fde99b4 Mon Sep 17 00:00:00 2001
From: Stepan Vladovskiy
Date: Fri, 21 Mar 2025 15:40:29 -0300
Subject: [PATCH] feat(search.py): documents for bulk indexing are categorized

---
 services/search.py | 324 ++++++++++++++++++++++++++++++---------------
 1 file changed, 215 insertions(+), 109 deletions(-)

diff --git a/services/search.py b/services/search.py
index aca6ba1e..a29ead9b 100644
--- a/services/search.py
+++ b/services/search.py
@@ -4,6 +4,7 @@ import logging
 import os
 import httpx
 import time
+import random
 
 # Set up proper logging
 logger = logging.getLogger("search")
@@ -11,7 +12,7 @@ logger.setLevel(logging.INFO)  # Change to INFO to see more details
 
 # Configuration for search service
 SEARCH_ENABLED = bool(os.environ.get("SEARCH_ENABLED", "true").lower() in ["true", "1", "yes"])
-TXTAI_SERVICE_URL = os.environ.get("TXTAI_SERVICE_URL", "http://search-txtai.web.1:8000")
+TXTAI_SERVICE_URL = os.environ.get("TXTAI_SERVICE_URL", "none")
 MAX_BATCH_SIZE = int(os.environ.get("SEARCH_MAX_BATCH_SIZE", "25"))
 
 
@@ -87,7 +88,7 @@ class SearchService:
             logger.error(f"Indexing error for shout {shout.id}: {e}")
 
     async def bulk_index(self, shouts):
-        """Index multiple documents at once"""
+        """Index multiple documents at once with adaptive batch sizing"""
         if not self.available or not shouts:
             logger.warning(f"Bulk indexing skipped: available={self.available}, shouts_count={len(shouts) if shouts else 0}")
             return
@@ -96,122 +97,227 @@ class SearchService:
         logger.info(f"Starting bulk indexing of {len(shouts)} documents")
 
         MAX_TEXT_LENGTH = 8000  # Maximum text length to send in a single request
-        batch_size = MAX_BATCH_SIZE
+        max_batch_size = MAX_BATCH_SIZE
         total_indexed = 0
         total_skipped = 0
         total_truncated = 0
-        i = 0
-        
-        for i in range(0, len(shouts), batch_size):
-            batch = shouts[i:i+batch_size]
-            logger.info(f"Processing batch {i//batch_size + 1} of {(len(shouts)-1)//batch_size + 1}, size {len(batch)}")
-            
-            documents = []
-            for shout in batch:
-                try:
-                    text_fields = []
-                    for field_name in ['title', 'subtitle', 'lead', 'body']:
-                        field_value = getattr(shout, field_name, None)
-                        if field_value and isinstance(field_value, str) and field_value.strip():
-                            text_fields.append(field_value.strip())
-                    
-                    media = getattr(shout, 'media', None)
-                    if media:
-                        if isinstance(media, str):
-                            try:
-                                media_json = json.loads(media)
-                                if isinstance(media_json, dict):
-                                    if 'title' in media_json:
-                                        text_fields.append(media_json['title'])
-                                    if 'body' in media_json:
-                                        text_fields.append(media_json['body'])
-                            except json.JSONDecodeError:
-                                text_fields.append(media)
-                        elif isinstance(media, dict):
-                            if 'title' in media:
-                                text_fields.append(media['title'])
-                            if 'body' in media:
-                                text_fields.append(media['body'])
-                    
-                    text = " ".join(text_fields)
-                    
-                    if not text.strip():
-                        logger.debug(f"Skipping shout {shout.id}: no text content")
-                        total_skipped += 1
-                        continue
-                    
-                    # Truncate text if it exceeds the maximum length
-                    original_length = len(text)
-                    if original_length > MAX_TEXT_LENGTH:
-                        text = text[:MAX_TEXT_LENGTH]
-                        logger.info(f"Truncated document {shout.id} from {original_length} to {MAX_TEXT_LENGTH} chars")
-                        total_truncated += 1
-                    
-                    documents.append({
-                        "id": str(shout.id),
-                        "text": text
-                    })
-                    total_indexed += 1
-                    
-                except Exception as e:
-                    logger.error(f"Error processing shout {getattr(shout, 'id', 'unknown')} for indexing: {e}")
-                    total_skipped += 1
-            
-            if not documents:
-                logger.warning(f"No valid documents in batch {i//batch_size + 1}")
-                continue
-            
+        total_retries = 0
+        
+        # Group 
documents by size to process smaller documents in larger batches + small_docs = [] + medium_docs = [] + large_docs = [] + + # First pass: prepare all documents and categorize by size + for shout in shouts: try: - if documents: - sample = documents[0] - logger.info(f"Sample document: id={sample['id']}, text_length={len(sample['text'])}") + text_fields = [] + for field_name in ['title', 'subtitle', 'lead', 'body']: + field_value = getattr(shout, field_name, None) + if field_value and isinstance(field_value, str) and field_value.strip(): + text_fields.append(field_value.strip()) - logger.info(f"Sending batch of {len(documents)} documents to search service") - response = await self.index_client.post( - "/bulk-index", - json=documents - ) - # Error Handling - if response.status_code == 422: - error_detail = response.json() - - # Create a truncated version of the error detail for logging - truncated_detail = error_detail.copy() if isinstance(error_detail, dict) else error_detail - - # If it's a validation error with details list - if isinstance(truncated_detail, dict) and 'detail' in truncated_detail and isinstance(truncated_detail['detail'], list): - for i, item in enumerate(truncated_detail['detail']): - # Handle case where input contains document text - if isinstance(item, dict) and 'input' in item: - if isinstance(item['input'], dict) and any(k in item['input'] for k in ['documents', 'text']): - # Check for documents list - if 'documents' in item['input'] and isinstance(item['input']['documents'], list): - for j, doc in enumerate(item['input']['documents']): - if 'text' in doc and isinstance(doc['text'], str) and len(doc['text']) > 100: - item['input']['documents'][j]['text'] = f"{doc['text'][:100]}... [truncated, total {len(doc['text'])} chars]" - - # Check for direct text field - if 'text' in item['input'] and isinstance(item['input']['text'], str) and len(item['input']['text']) > 100: - item['input']['text'] = f"{item['input']['text'][:100]}... 
[truncated, total {len(item['input']['text'])} chars]"
-                    
-                    logger.error(f"Validation error from search service: {truncated_detail}")
-                    
-                    # Try to identify problematic documents
-                    for doc in documents:
-                        if len(doc['text']) > 10000:  # Adjust threshold as needed
-                            logger.warning(f"Document {doc['id']} has very long text: {len(doc['text'])} chars")
-                    
-                    # Continue with next batch instead of failing completely
+                # Extract any additional text from the media field
+                media = getattr(shout, 'media', None)
+                if media:
+                    # media may arrive as a JSON string or as an already-parsed dict
+                    if isinstance(media, str):
+                        try:
+                            media_json = json.loads(media)
+                            if isinstance(media_json, dict):
+                                if 'title' in media_json:
+                                    text_fields.append(media_json['title'])
+                                if 'body' in media_json:
+                                    text_fields.append(media_json['body'])
+                        except json.JSONDecodeError:
+                            text_fields.append(media)
+                    elif isinstance(media, dict):
+                        if 'title' in media:
+                            text_fields.append(media['title'])
+                        if 'body' in media:
+                            text_fields.append(media['body'])
+                
+                text = " ".join(text_fields)
+                
+                if not text.strip():
+                    logger.debug(f"Skipping shout {shout.id}: no text content")
+                    total_skipped += 1
                     continue
-                
-                response.raise_for_status()
-                result = response.json()
-                logger.info(f"Batch {i//batch_size + 1} indexed successfully: {result}")
+                
+                # Truncate text if it exceeds the maximum length
+                original_length = len(text)
+                if original_length > MAX_TEXT_LENGTH:
+                    text = text[:MAX_TEXT_LENGTH]
+                    logger.info(f"Truncated document {shout.id} from {original_length} to {MAX_TEXT_LENGTH} chars")
+                    total_truncated += 1
+                
+                document = {
+                    "id": str(shout.id),
+                    "text": text
+                }
+                
+                # Categorize by size
+                text_len = len(text)
+                if text_len > 5000:
+                    large_docs.append(document)
+                elif text_len > 2000:
+                    medium_docs.append(document)
+                else:
+                    small_docs.append(document)
+                
+                total_indexed += 1
+                
             except Exception as e:
-                logger.error(f"Bulk indexing error for batch {i//batch_size + 1}: {e}")
+                logger.error(f"Error processing shout {getattr(shout, 'id', 'unknown')} for indexing: {e}")
+                total_skipped += 1
+        
+        # Process each category with appropriate batch sizes
+        logger.info(f"Documents categorized: {len(small_docs)} small, {len(medium_docs)} medium, {len(large_docs)} large")
+        
+        # Process small documents (larger batches)
+        if small_docs:
+            batch_size = min(max_batch_size, 25)
+            await self._process_document_batches(small_docs, batch_size, "small")
+            
+        # Process medium documents (medium batches)
+        if medium_docs:
+            batch_size = min(max_batch_size, 15)
+            await self._process_document_batches(medium_docs, batch_size, "medium")
+            
+        # Process large documents (small batches)
+        if large_docs:
+            batch_size = min(max_batch_size, 5)
+            await self._process_document_batches(large_docs, batch_size, "large")
         
         elapsed = time.time() - start_time
-        logger.info(f"Bulk indexing completed in {elapsed:.2f}s: {total_indexed} indexed, {total_skipped} skipped")
+        logger.info(f"Bulk indexing completed in {elapsed:.2f}s: {total_indexed} indexed, {total_skipped} skipped, {total_truncated} truncated, {total_retries} retries")
+    
+    async def _process_document_batches(self, documents, batch_size, size_category):
+        """Process document batches with retry logic"""
+        for i in range(0, len(documents), batch_size):
+            batch = documents[i:i+batch_size]
+            batch_id = f"{size_category}-{i//batch_size + 1}"
+            logger.info(f"Processing {size_category} batch {batch_id} of {len(batch)} documents")
+            
+            retry_count = 0
+            max_retries = 3
+            success = False
+            
+            # Process with retries
+            while not success and retry_count < max_retries:
+                
try: + if batch: + sample = batch[0] + logger.info(f"Sample document in batch {batch_id}: id={sample['id']}, text_length={len(sample['text'])}") + + logger.info(f"Sending batch {batch_id} of {len(batch)} documents to search service (attempt {retry_count+1})") + response = await self.index_client.post( + "/bulk-index", + json=batch, + timeout=120.0 # Explicit longer timeout for large batches + ) + + # Handle 422 validation errors - these won't be fixed by retrying + if response.status_code == 422: + error_detail = response.json() + truncated_error = self._truncate_error_detail(error_detail) + logger.error(f"Validation error from search service for batch {batch_id}: {truncated_error}") + + # Individual document validation often won't benefit from splitting + break + + # Handle 500 server errors - these might be fixed by retrying with smaller batches + elif response.status_code == 500: + if retry_count < max_retries - 1: + retry_count += 1 + wait_time = (2 ** retry_count) + (random.random() * 0.5) # Exponential backoff with jitter + logger.warning(f"Server error for batch {batch_id}, retrying in {wait_time:.1f}s (attempt {retry_count+1}/{max_retries})") + await asyncio.sleep(wait_time) + continue + + # Final retry, split the batch + elif len(batch) > 1: + logger.warning(f"Splitting batch {batch_id} after repeated failures") + mid = len(batch) // 2 + await self._process_single_batch(batch[:mid], f"{batch_id}-A") + await self._process_single_batch(batch[mid:], f"{batch_id}-B") + break + else: + # Can't split a single document + logger.error(f"Failed to index document {batch[0]['id']} after {max_retries} attempts") + break + + # Normal success case + response.raise_for_status() + result = response.json() + logger.info(f"Batch {batch_id} indexed successfully: {result}") + success = True + + except Exception as e: + if retry_count < max_retries - 1: + retry_count += 1 + wait_time = (2 ** retry_count) + (random.random() * 0.5) + logger.warning(f"Error for batch {batch_id}, retrying in {wait_time:.1f}s: {str(e)[:200]}") + await asyncio.sleep(wait_time) + else: + # Last resort - try to split the batch + if len(batch) > 1: + logger.warning(f"Splitting batch {batch_id} after exception: {str(e)[:200]}") + mid = len(batch) // 2 + await self._process_single_batch(batch[:mid], f"{batch_id}-A") + await self._process_single_batch(batch[mid:], f"{batch_id}-B") + else: + logger.error(f"Failed to index document {batch[0]['id']} after {max_retries} attempts: {e}") + break + + async def _process_single_batch(self, documents, batch_id): + """Process a single batch with maximum reliability""" + try: + if not documents: + return + + logger.info(f"Processing sub-batch {batch_id} with {len(documents)} documents") + response = await self.index_client.post( + "/bulk-index", + json=documents, + timeout=90.0 + ) + response.raise_for_status() + result = response.json() + logger.info(f"Sub-batch {batch_id} indexed successfully: {result}") + except Exception as e: + logger.error(f"Error indexing sub-batch {batch_id}: {str(e)[:200]}") + + # For tiny batches, try one-by-one as last resort + if len(documents) > 1: + logger.info(f"Processing documents in sub-batch {batch_id} individually") + for i, doc in enumerate(documents): + try: + resp = await self.index_client.post("/index", json=doc, timeout=30.0) + resp.raise_for_status() + logger.info(f"Indexed document {doc['id']} individually") + except Exception as e2: + logger.error(f"Failed to index document {doc['id']} individually: {str(e2)[:100]}") + + def 
_truncate_error_detail(self, error_detail): + """Truncate error details for logging""" + truncated_detail = error_detail.copy() if isinstance(error_detail, dict) else error_detail + + if isinstance(truncated_detail, dict) and 'detail' in truncated_detail and isinstance(truncated_detail['detail'], list): + for i, item in enumerate(truncated_detail['detail']): + if isinstance(item, dict) and 'input' in item: + if isinstance(item['input'], dict) and any(k in item['input'] for k in ['documents', 'text']): + # Check for documents list + if 'documents' in item['input'] and isinstance(item['input']['documents'], list): + for j, doc in enumerate(item['input']['documents']): + if 'text' in doc and isinstance(doc['text'], str) and len(doc['text']) > 100: + item['input']['documents'][j]['text'] = f"{doc['text'][:100]}... [truncated, total {len(doc['text'])} chars]" + + # Check for direct text field + if 'text' in item['input'] and isinstance(item['input']['text'], str) and len(item['input']['text']) > 100: + item['input']['text'] = f"{item['input']['text'][:100]}... [truncated, total {len(item['input']['text'])} chars]" + + return truncated_detail async def search(self, text, limit, offset): """Search documents"""
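Note on the retry policy introduced in `_process_document_batches`: it is exponential backoff with jitter, followed by a halving split of the batch once retries are exhausted. Below is a minimal standalone sketch of that policy; the `send` callable and the `send_with_backoff` name are illustrative only and are not part of the patch.

import asyncio
import random

async def send_with_backoff(send, batch, max_retries=3):
    # Retry an async send with exponential backoff plus jitter; once retries
    # are exhausted, split the batch in half and try each half separately.
    # Usage (illustrative): await send_with_backoff(lambda b: client.post("/bulk-index", json=b), batch)
    for attempt in range(max_retries):
        try:
            return await send(batch)
        except Exception:
            if attempt < max_retries - 1:
                # Waits ~2s, then ~4s, each with up to 0.5s of jitter, mirroring
                # (2 ** retry_count) + (random.random() * 0.5) in the patch.
                await asyncio.sleep((2 ** (attempt + 1)) + random.random() * 0.5)
            elif len(batch) > 1:
                mid = len(batch) // 2
                await send_with_backoff(send, batch[:mid], max_retries)
                await send_with_backoff(send, batch[mid:], max_retries)
            # A single document that still fails is logged and skipped, as in the patch.

Splitting only after backoff keeps the common case to a single bulk request while bounding how long one oversized or malformed batch can stall indexing.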