using Microsoft.Azure.Cosmos;
using Microsoft.Azure.Cosmos.Fluent;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Newtonsoft.Json.Linq;
using System.Diagnostics;
using VectorSearchAiAssistant.Common.Models;
using VectorSearchAiAssistant.Common.Models.BusinessDomain;
using VectorSearchAiAssistant.Common.Models.Chat;
using VectorSearchAiAssistant.Service.Interfaces;
using VectorSearchAiAssistant.Service.Models.Chat;
using VectorSearchAiAssistant.Service.Models.ConfigurationOptions;
using VectorSearchAiAssistant.Service.Utils;

namespace VectorSearchAiAssistant.Service.Services
{
    /// <summary>
    /// Service to access Azure Cosmos DB for NoSQL.
    /// </summary>
    public class CosmosDbService : ICosmosDbService
    {
        private readonly Container _completions;
        private readonly Container _customer;
        private readonly Container _product;
        private readonly Container _leases;
        private readonly Database _database;

        // Container references keyed by (trimmed) container name, one per entry in the Containers setting.
        private readonly Dictionary<string, Container> _containers;

        // Built from ModelRegistry.Models (key -> model type); the types are handed to the search service on startup.
        readonly Dictionary<string, Type> _memoryTypes;

        private readonly IRAGService _ragService;
        private readonly IAISearchService _aiSearchService;
        private readonly CosmosDbSettings _settings;
        private readonly ILogger _logger;

        private List<ChangeFeedProcessor> _changeFeedProcessors;
        private bool _changeFeedsInitialized = false;

        /// <summary>
        /// True once every change feed processor has been started successfully.
        /// </summary>
        public bool IsInitialized => _changeFeedsInitialized;

        /// <summary>
        /// Creates the Cosmos DB client, resolves the configured containers and starts the
        /// change feed processors on a background task.
        /// </summary>
        /// <param name="ragService">RAG service that receives a memory for every changed entity.</param>
        /// <param name="AISearchService">Search service that indexes every changed entity.</param>
        /// <param name="settings">Cosmos DB settings (endpoint, key, database, container lists).</param>
        /// <param name="logger">Logger used by this service.</param>
        /// <exception cref="ArgumentException">
        /// A required setting is null or empty, or a database/container reference could not be obtained.
        /// </exception>
        public CosmosDbService(
            IRAGService ragService,
            IAISearchService AISearchService,
            IOptions<CosmosDbSettings> settings,
            ILogger<CosmosDbService> logger)
        {
            _ragService = ragService;
            _aiSearchService = AISearchService;

            _settings = settings.Value;
            ArgumentException.ThrowIfNullOrEmpty(_settings.Endpoint);
            ArgumentException.ThrowIfNullOrEmpty(_settings.Key);
            ArgumentException.ThrowIfNullOrEmpty(_settings.Database);
            ArgumentException.ThrowIfNullOrEmpty(_settings.Containers);

            _logger = logger;
            _logger.LogInformation("Initializing Cosmos DB service.");

            if (!_settings.EnableTracing)
            {
                // Removes the listeners from the SDK's internal DefaultTrace source.
                // NOTE(review): presumably the SDK's documented workaround for default trace listener overhead - confirm.
                // The lookup is reflection-based, so skip quietly if the type or property cannot be found.
                Type? defaultTrace = Type.GetType("Microsoft.Azure.Cosmos.Core.Trace.DefaultTrace,Microsoft.Azure.Cosmos.Direct");
                if (defaultTrace?.GetProperty("TraceSource")?.GetValue(null) is TraceSource traceSource)
                {
                    traceSource.Switch.Level = SourceLevels.All;
                    traceSource.Listeners.Clear();
                }
            }

            CosmosSerializationOptions options = new()
            {
                PropertyNamingPolicy = CosmosPropertyNamingPolicy.CamelCase
            };

            CosmosClient client = new CosmosClientBuilder(_settings.Endpoint, _settings.Key)
                .WithSerializerOptions(options)
                .WithConnectionModeGateway()
                .Build();

            Database? database = client?.GetDatabase(_settings.Database);
            _database = database ?? throw new ArgumentException("Unable to connect to existing Azure Cosmos DB database.");

            // Dictionary of container references for all containers listed in config.
            // Entries are trimmed; empty entries (e.g. a trailing comma) are ignored.
            _containers = new Dictionary<string, Container>();
            string[] containerNames = _settings.Containers.Split(
                ',', StringSplitOptions.TrimEntries | StringSplitOptions.RemoveEmptyEntries);

            foreach (string containerName in containerNames)
            {
                Container container = _database.GetContainer(containerName) ??
                    throw new ArgumentException("Unable to connect to existing Azure Cosmos DB container or database.");

                _containers.Add(containerName, container);
            }

            _completions = _containers["completions"];
            _customer = _containers["customer"];
            _product = _containers["product"];

            _leases = _database.GetContainer(_settings.ChangeFeedLeaseContainer) ??
                throw new ArgumentException($"Unable to connect to the {_settings.ChangeFeedLeaseContainer} container required to listen to the CosmosDB change feed.");

            _memoryTypes = ModelRegistry.Models.ToDictionary(m => m.Key, m => m.Value.Type);

            // Fire-and-forget: the method logs its own failures, and IsInitialized reports completion.
            _ = Task.Run(() => StartChangeFeedProcessors());

            _logger.LogInformation("Cosmos DB service initialized.");
        }

        /// <summary>
        /// Initializes the search index and then starts one change feed processor per monitored container.
        /// Any failure is logged and leaves <see cref="IsInitialized"/> false.
        /// </summary>
        private async Task StartChangeFeedProcessors()
        {
            _changeFeedProcessors = new List<ChangeFeedProcessor>();

            try
            {
                // Kept inside the try block: this method runs as an un-awaited background task, so an
                // exception escaping it would never be observed or logged.
                _logger.LogInformation("Initializing the Cognitive Search index...");
                await _aiSearchService.Initialize(_memoryTypes.Values.ToList());

                _logger.LogInformation("Initializing the change feed processors...");

                foreach (string monitoredContainerName in _settings.MonitoredContainers.Split(',').Select(s => s.Trim()))
                {
                    var changeFeedProcessor = _containers[monitoredContainerName]
                        .GetChangeFeedProcessorBuilder<dynamic>($"{monitoredContainerName}ChangeFeed", GenericChangeFeedHandler)
                        .WithInstanceName($"{monitoredContainerName}ChangeInstance")
                        .WithErrorNotification(GenericChangeFeedErrorHandler)
                        .WithLeaseContainer(_leases)
                        .Build();

                    await changeFeedProcessor.StartAsync();
                    _changeFeedProcessors.Add(changeFeedProcessor);

                    _logger.LogInformation($"Initialized the change feed processor for the {monitoredContainerName} container.");
                }

                _changeFeedsInitialized = true;
                _logger.LogInformation("Cosmos DB change feed processors initialized.");
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "Error initializing change feed processors.");
            }
        }

        // This is an example of a dynamic change feed handler that can handle a range of preconfigured entities.
        /// <summary>
        /// Handles a batch of change feed items: each recognized entity is indexed by the search
        /// service and added to the RAG service memory. A failure on one item is logged and does
        /// not stop the rest of the batch.
        /// </summary>
        private async Task GenericChangeFeedHandler(
            ChangeFeedProcessorContext context,
            IReadOnlyCollection<dynamic> changes,
            CancellationToken cancellationToken)
        {
            if (changes.Count == 0)
            {
                return;
            }

            var batchRef = Guid.NewGuid().ToString();
            _logger.LogInformation($"Starting to generate embeddings for {changes.Count} entities (batch ref {batchRef}).");

            // Using dynamic type as the monitored containers hold different entity types.
            foreach (var item in changes)
            {
                try
                {
                    if (cancellationToken.IsCancellationRequested)
                    {
                        break;
                    }

                    var jObject = item as JObject;
                    var typeMetadata = ModelRegistry.IdentifyType(jObject);

                    if (typeMetadata == null)
                    {
                        _logger.LogError($"Unsupported entity type in Cosmos DB change feed handler: {jObject}");
                        continue;
                    }

                    var entity = jObject.ToObject(typeMetadata.Type);

                    // Add the entity to the Cognitive Search content index.
                    // The content index is used by the Cognitive Search memory source to create memories from faceted queries.
                    await _aiSearchService.IndexItem(entity);

                    // Add the entity to the Semantic Kernel memory used by the RAG service.
                    // We want to keep the VectorSearchAiAssistant.SemanticKernel project isolated from any domain-specific
                    // references/dependencies, so we use a generic mechanism to get the name of the entity as well as to
                    // set the vector property on the entity.
                    await _ragService.AddMemory(
                        entity,
                        string.Join(" ", entity.GetPropertyValues(typeMetadata.NamingProperties)));
                }
                catch (Exception ex)
                {
                    _logger.LogError(ex, $"Error processing an item in the change feed handler: {item}");
                }
            }

            _logger.LogInformation($"Finished generating embeddings (batch ref {batchRef}).");
        }

        /// <summary>
        /// Error notification callback for the change feed processors. Failures are reported through
        /// the service logger so they land in the same sink as the rest of the service's diagnostics.
        /// </summary>
        private Task GenericChangeFeedErrorHandler(
            string LeaseToken,
            Exception exception)
        {
            if (exception is ChangeFeedProcessorUserException userException)
            {
                _logger.LogError(
                    userException.InnerException ?? userException,
                    "Lease {LeaseToken} processing failed with unhandled exception from user delegate.",
                    LeaseToken);
            }
            else
            {
                _logger.LogError(exception, "Lease {LeaseToken} failed.", LeaseToken);
            }

            return Task.CompletedTask;
        }

        /// <summary>
        /// Gets a list of all current chat sessions.
        /// </summary>
        /// <returns>List of distinct chat session items.</returns>
        public async Task<List<Session>> GetSessionsAsync()
        {
            QueryDefinition query = new QueryDefinition("SELECT DISTINCT * FROM c WHERE c.type = @type")
                .WithParameter("@type", nameof(Session));

            FeedIterator<Session> response = _completions.GetItemQueryIterator<Session>(query);

            List<Session> output = new();
            while (response.HasMoreResults)
            {
                FeedResponse<Session> results = await response.ReadNextAsync();
                output.AddRange(results);
            }

            return output;
        }

        /// <summary>
        /// Performs a point read to retrieve a single chat session item.
        /// </summary>
        /// <param name="id">Chat session identifier; also used as the partition key value.</param>
        /// <returns>The chat session item.</returns>
        public async Task<Session> GetSessionAsync(string id)
        {
            return await _completions.ReadItemAsync<Session>(
                id: id,
                partitionKey: new PartitionKey(id));
        }

        /// <summary>
        /// Gets a list of all current chat messages for a specified session identifier.
        /// </summary>
        /// <param name="sessionId">Chat session identifier used to filter messages.</param>
        /// <returns>List of chat message items for the specified session.</returns>
        public async Task<List<Message>> GetSessionMessagesAsync(string sessionId)
        {
            QueryDefinition query =
                new QueryDefinition("SELECT * FROM c WHERE c.sessionId = @sessionId AND c.type = @type")
                    .WithParameter("@sessionId", sessionId)
                    .WithParameter("@type", nameof(Message));

            FeedIterator<Message> results = _completions.GetItemQueryIterator<Message>(query);

            List<Message> output = new();
            while (results.HasMoreResults)
            {
                FeedResponse<Message> response = await results.ReadNextAsync();
                output.AddRange(response);
            }

            return output;
        }

        /// <summary>
        /// Creates a new chat session.
        /// </summary>
        /// <param name="session">Chat session item to create.</param>
        /// <returns>Newly created chat session item.</returns>
        public async Task<Session> InsertSessionAsync(Session session)
        {
            PartitionKey partitionKey = new(session.SessionId);
            return await _completions.CreateItemAsync<Session>(
                item: session,
                partitionKey: partitionKey
            );
        }

        /// <summary>
        /// Creates a new chat message.
        /// </summary>
        /// <param name="message">Chat message item to create.</param>
        /// <returns>Newly created chat message item.</returns>
        public async Task<Message> InsertMessageAsync(Message message)
        {
            PartitionKey partitionKey = new(message.SessionId);
            return await _completions.CreateItemAsync<Message>(
                item: message,
                partitionKey: partitionKey
            );
        }

        /// <summary>
        /// Updates an existing chat message.
        /// </summary>
        /// <param name="message">Chat message item to update.</param>
        /// <returns>Revised chat message item.</returns>
        public async Task<Message> UpdateMessageAsync(Message message)
        {
            PartitionKey partitionKey = new(message.SessionId);
            return await _completions.ReplaceItemAsync(
                item: message,
                id: message.Id,
                partitionKey: partitionKey
            );
        }

        /// <summary>
        /// Updates a message's rating through a patch operation.
        /// </summary>
        /// <param name="id">The message id.</param>
        /// <param name="sessionId">The message's partition key (session id).</param>
        /// <param name="rating">The rating to replace.</param>
        /// <returns>Revised chat message item.</returns>
        public async Task<Message> UpdateMessageRatingAsync(string id, string sessionId, bool? rating)
        {
            var response = await _completions.PatchItemAsync<Message>(
                id: id,
                partitionKey: new PartitionKey(sessionId),
                patchOperations: new[]
                {
                    PatchOperation.Set("/rating", rating),
                }
            );

            return response.Resource;
        }

        /// <summary>
        /// Updates an existing chat session.
        /// </summary>
        /// <param name="session">Chat session item to update.</param>
        /// <returns>Revised chat session item.</returns>
        public async Task<Session> UpdateSessionAsync(Session session)
        {
            PartitionKey partitionKey = new(session.SessionId);
            return await _completions.ReplaceItemAsync(
                item: session,
                id: session.Id,
                partitionKey: partitionKey
            );
        }

        /// <summary>
        /// Updates a session's name through a patch operation.
        /// </summary>
        /// <param name="id">The session id; also used as the partition key value.</param>
        /// <param name="name">The session's new name.</param>
        /// <returns>Revised chat session item.</returns>
        public async Task<Session> UpdateSessionNameAsync(string id, string name)
        {
            var response = await _completions.PatchItemAsync<Session>(
                id: id,
                partitionKey: new PartitionKey(id),
                patchOperations: new[]
                {
                    PatchOperation.Set("/name", name),
                }
            );

            return response.Resource;
        }

        /// <summary>
        /// Batch create or update chat messages and session.
        /// </summary>
        /// <param name="messages">Chat message and session items to create or replace.</param>
        /// <exception cref="ArgumentException">The items do not all share the same SessionId.</exception>
        public async Task UpsertSessionBatchAsync(params dynamic[] messages)
        {
            // Nothing to write; also avoids First() throwing on an empty array.
            if (messages is null || messages.Length == 0)
            {
                return;
            }

            if (messages.Select(m => m.SessionId).Distinct().Count() > 1)
            {
                throw new ArgumentException("All items must have the same partition key.");
            }

            PartitionKey partitionKey = new(messages.First().SessionId);
            var batch = _completions.CreateTransactionalBatch(partitionKey);

            foreach (var message in messages)
            {
                batch.UpsertItem(
                    item: message
                );
            }

            // ExecuteAsync reports failure through the response rather than by throwing, so the
            // outcome has to be inspected explicitly or a failed batch goes unnoticed.
            using TransactionalBatchResponse response = await batch.ExecuteAsync();
            if (!response.IsSuccessStatusCode)
            {
                _logger.LogError(
                    "Transactional batch upsert failed with status code {StatusCode}: {ErrorMessage}",
                    response.StatusCode,
                    response.ErrorMessage);
            }
        }

        /// <summary>
        /// Batch deletes an existing chat session and all related messages.
        /// </summary>
        /// <param name="sessionId">Chat session identifier used to flag messages and sessions for deletion.</param>
        public async Task DeleteSessionAndMessagesAsync(string sessionId)
        {
            PartitionKey partitionKey = new(sessionId);

            // TODO: await container.DeleteAllItemsByPartitionKeyStreamAsync(partitionKey);

            var query = new QueryDefinition("SELECT c.id FROM c WHERE c.sessionId = @sessionId")
                .WithParameter("@sessionId", sessionId);

            // Every matching item lives in the session's partition (the batch below requires it),
            // so the query is scoped to that partition instead of fanning out.
            var response = _completions.GetItemQueryIterator<Message>(
                query,
                requestOptions: new QueryRequestOptions { PartitionKey = partitionKey });

            var batch = _completions.CreateTransactionalBatch(partitionKey);
            int pendingDeletes = 0;

            while (response.HasMoreResults)
            {
                var results = await response.ReadNextAsync();
                foreach (var item in results)
                {
                    batch.DeleteItem(
                        id: item.Id
                    );
                    pendingDeletes++;
                }
            }

            // Nothing matched: there is no work to submit.
            if (pendingDeletes == 0)
            {
                return;
            }

            // NOTE(review): transactional batches have a per-batch operation limit; sessions with
            // more items than that will fail here and be logged - confirm the limit and consider chunking.
            using TransactionalBatchResponse batchResponse = await batch.ExecuteAsync();
            if (!batchResponse.IsSuccessStatusCode)
            {
                _logger.LogError(
                    "Transactional batch delete for session '{SessionId}' failed with status code {StatusCode}: {ErrorMessage}",
                    sessionId,
                    batchResponse.StatusCode,
                    batchResponse.ErrorMessage);
            }
        }

        /// <summary>
        /// Inserts a product into the product container.
        /// </summary>
        /// <param name="product">Product item to create.</param>
        /// <returns>Newly created product item, or the input item when it already exists.</returns>
        public async Task<Product> InsertProductAsync(Product product)
        {
            try
            {
                return await _product.CreateItemAsync(product);
            }
            catch (CosmosException ex) when (ex.StatusCode == System.Net.HttpStatusCode.Conflict)
            {
                // Ignore conflict errors.
                _logger.LogInformation("Product already added.");
                return product;
            }
            catch (CosmosException ex)
            {
                // Message passed as an argument, not as the template: Cosmos messages contain braces.
                _logger.LogError(ex, "{Message}", ex.Message);
                throw;
            }
        }

        /// <summary>
        /// Inserts a customer into the customer container.
        /// </summary>
        /// <param name="customer">Customer item to create.</param>
        /// <returns>Newly created customer item, or the input item when it already exists.</returns>
        public async Task<Customer> InsertCustomerAsync(Customer customer)
        {
            try
            {
                return await _customer.CreateItemAsync(customer);
            }
            catch (CosmosException ex) when (ex.StatusCode == System.Net.HttpStatusCode.Conflict)
            {
                // Ignore conflict errors.
                _logger.LogInformation("Customer already added.");
                return customer;
            }
            catch (CosmosException ex)
            {
                // Message passed as an argument, not as the template: Cosmos messages contain braces.
                _logger.LogError(ex, "{Message}", ex.Message);
                throw;
            }
        }

        /// <summary>
        /// Inserts a sales order into the customer container.
        /// </summary>
        /// <param name="salesOrder">Sales order item to create.</param>
        /// <returns>Newly created sales order item, or the input item when it already exists.</returns>
        public async Task<SalesOrder> InsertSalesOrderAsync(SalesOrder salesOrder)
        {
            try
            {
                return await _customer.CreateItemAsync(salesOrder);
            }
            catch (CosmosException ex) when (ex.StatusCode == System.Net.HttpStatusCode.Conflict)
            {
                // Ignore conflict errors.
                _logger.LogInformation("Sales order already added.");
                return salesOrder;
            }
            catch (CosmosException ex)
            {
                // Message passed as an argument, not as the template: Cosmos messages contain braces.
                _logger.LogError(ex, "{Message}", ex.Message);
                throw;
            }
        }

        /// <summary>
        /// Deletes a product by its Id and category (its partition key).
        /// </summary>
        /// <param name="productId">The Id of the product to delete.</param>
        /// <param name="categoryId">The category Id of the product to delete.</param>
        public async Task DeleteProductAsync(string productId, string categoryId)
        {
            try
            {
                // Delete from Cosmos product container
                await _product.DeleteItemAsync<Product>(id: productId, partitionKey: new PartitionKey(categoryId));
            }
            catch (CosmosException ex) when (ex.StatusCode == System.Net.HttpStatusCode.NotFound)
            {
                _logger.LogInformation("The product has already been removed.");
            }
        }

        /// <summary>
        /// Reads all documents retrieved by Vector Search.
        /// </summary>
        /// <param name="vectorDocuments">Vector search hits identifying container, item id and partition key.</param>
        /// <returns>
        /// The raw JSON of every successfully read document, joined with a newline followed by "-".
        /// Documents that fail to load are logged and skipped.
        /// </returns>
        public async Task<string> GetVectorSearchDocumentsAsync(List<DocumentVector> vectorDocuments)
        {
            List<string> searchDocuments = new List<string>();

            foreach (var document in vectorDocuments)
            {
                try
                {
                    using var response = await _containers[document.containerName].ReadItemStreamAsync(
                        document.itemId, new PartitionKey(document.partitionKey));

                    if ((int)response.StatusCode < 200 || (int)response.StatusCode >= 400)
                    {
                        _logger.LogError(
                            $"Failed to retrieve an item for id '{document.itemId}' - status code '{response.StatusCode}'");

                        // A failed read carries an error payload, not the document; never return it as a result.
                        continue;
                    }

                    if (response.Content == null)
                    {
                        _logger.LogInformation(
                            $"Null content received for document '{document.itemId}' - status code '{response.StatusCode}'");
                        continue;
                    }

                    using StreamReader reader = new StreamReader(response.Content);
                    searchDocuments.Add(await reader.ReadToEndAsync());
                }
                catch (Exception ex)
                {
                    // Exception goes first so the logger records it (and its stack trace) instead of
                    // treating it as an unused format argument.
                    _logger.LogError(ex, "{Message}", ex.Message);
                }
            }

            return string.Join(Environment.NewLine + "-", searchDocuments);
        }

        /// <summary>
        /// Performs a point read to retrieve a completion prompt.
        /// </summary>
        /// <param name="sessionId">Session identifier, used as the partition key value.</param>
        /// <param name="completionPromptId">Identifier of the completion prompt item.</param>
        /// <returns>The completion prompt item.</returns>
        public async Task<CompletionPrompt> GetCompletionPrompt(string sessionId, string completionPromptId)
        {
            return await _completions.ReadItemAsync<CompletionPrompt>(
                id: completionPromptId,
                partitionKey: new PartitionKey(sessionId));
        }
    }
}