using Microsoft.Azure.Cosmos;
using Microsoft.Azure.Cosmos.Fluent;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Newtonsoft.Json.Linq;
using System.Diagnostics;
using VectorSearchAiAssistant.Common.Models;
using VectorSearchAiAssistant.Common.Models.BusinessDomain;
using VectorSearchAiAssistant.Common.Models.Chat;
using VectorSearchAiAssistant.Service.Interfaces;
using VectorSearchAiAssistant.Service.Models.Chat;
using VectorSearchAiAssistant.Service.Models.ConfigurationOptions;
using VectorSearchAiAssistant.Service.Utils;
namespace VectorSearchAiAssistant.Service.Services
{
    /// <summary>
    /// Service to access Azure Cosmos DB for NoSQL.
    /// </summary>
    /// <remarks>
    /// The constructor starts a background task that initializes the AI Search index and then starts a
    /// change feed processor for every container listed in the MonitoredContainers setting.
    /// <see cref="IsInitialized"/> becomes <c>true</c> once all of those processors have started.
    /// </remarks>
    public class CosmosDbService : ICosmosDbService
    {
        // Cosmos DB rejects transactional batches that contain more than 100 operations.
        private const int MaxTransactionalBatchOperations = 100;

        private readonly Container _completions;
        private readonly Container _customer;
        private readonly Container _product;
        private readonly Container _leases;
        private readonly Database _database;
        private readonly Dictionary<string, Container> _containers;
        private readonly Dictionary<string, Type> _memoryTypes;
        private readonly IRAGService _ragService;
        private readonly IAISearchService _aiSearchService;
        private readonly CosmosDbSettings _settings;
        private readonly ILogger _logger;
        private readonly List<ChangeFeedProcessor> _changeFeedProcessors = new();

        // Written by the background initialization task and read through IsInitialized on other threads.
        private volatile bool _changeFeedsInitialized;

        /// <summary>
        /// Gets a value indicating whether all configured change feed processors have been started.
        /// </summary>
        public bool IsInitialized => _changeFeedsInitialized;

        /// <summary>
        /// Creates the Cosmos DB client and container references, then starts the AI Search index
        /// initialization and the change feed processors on a background task.
        /// </summary>
        /// <param name="ragService">Service that receives a memory for every entity read from the change feed.</param>
        /// <param name="AISearchService">Service whose index is initialized at startup and fed from the change feed.</param>
        /// <param name="settings">Cosmos DB configuration (endpoint, key, database, containers, lease container).</param>
        /// <param name="logger">Logger for this service.</param>
        /// <exception cref="ArgumentException">
        /// A required setting is null or empty, or one of the completions/customer/product containers
        /// is not listed in the Containers setting.
        /// </exception>
        public CosmosDbService(
            IRAGService ragService,
            IAISearchService AISearchService,
            IOptions<CosmosDbSettings> settings,
            ILogger<CosmosDbService> logger)
        {
            _ragService = ragService;
            _aiSearchService = AISearchService;
            _settings = settings.Value;

            ArgumentException.ThrowIfNullOrEmpty(_settings.Endpoint);
            ArgumentException.ThrowIfNullOrEmpty(_settings.Key);
            ArgumentException.ThrowIfNullOrEmpty(_settings.Database);
            ArgumentException.ThrowIfNullOrEmpty(_settings.Containers);
            ArgumentException.ThrowIfNullOrEmpty(_settings.ChangeFeedLeaseContainer);

            _logger = logger;
            _logger.LogInformation("Initializing Cosmos DB service.");

            if (!_settings.EnableTracing)
            {
                SuppressDirectModeTracing();
            }

            CosmosSerializationOptions options = new()
            {
                PropertyNamingPolicy = CosmosPropertyNamingPolicy.CamelCase
            };

            CosmosClient client = new CosmosClientBuilder(_settings.Endpoint, _settings.Key)
                .WithSerializerOptions(options)
                .WithConnectionModeGateway()
                .Build();

            // GetDatabase/GetContainer return proxy references and never return null, so no null checks are needed.
            _database = client.GetDatabase(_settings.Database);

            // Container references for every container listed in the comma-separated Containers setting.
            _containers = new Dictionary<string, Container>();
            foreach (string containerName in _settings.Containers.Split(
                         ',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries))
            {
                _containers[containerName] = _database.GetContainer(containerName);
            }

            _completions = GetRequiredContainer("completions");
            _customer = GetRequiredContainer("customer");
            _product = GetRequiredContainer("product");
            _leases = _database.GetContainer(_settings.ChangeFeedLeaseContainer);

            _memoryTypes = ModelRegistry.Models.ToDictionary(m => m.Key, m => m.Value.Type);

            // Fire-and-forget: StartChangeFeedProcessors catches and logs all of its own failures.
            _ = Task.Run(() => StartChangeFeedProcessors());

            _logger.LogInformation("Cosmos DB service initialized.");
        }

        /// <summary>
        /// Discards output from the internal trace source of the Cosmos DB Direct-mode assembly.
        /// </summary>
        /// <remarks>
        /// The trace source is not public, so it is located by reflection; clearing its listeners means
        /// nothing consumes the traces. If the type or property cannot be found (for example after an SDK
        /// change), a warning is logged instead of failing service construction.
        /// </remarks>
        private void SuppressDirectModeTracing()
        {
            Type? defaultTrace = Type.GetType("Microsoft.Azure.Cosmos.Core.Trace.DefaultTrace,Microsoft.Azure.Cosmos.Direct");
            if (defaultTrace?.GetProperty("TraceSource")?.GetValue(null) is TraceSource traceSource)
            {
                traceSource.Switch.Level = SourceLevels.All;
                traceSource.Listeners.Clear();
            }
            else
            {
                _logger.LogWarning("Unable to locate the Cosmos DB Direct trace source; tracing was not disabled.");
            }
        }

        /// <summary>
        /// Returns the container reference registered under <paramref name="containerName"/>.
        /// </summary>
        /// <param name="containerName">Container name as listed in the Containers setting.</param>
        /// <returns>The matching container reference.</returns>
        /// <exception cref="ArgumentException">The container is not listed in the Containers setting.</exception>
        private Container GetRequiredContainer(string containerName) =>
            _containers.TryGetValue(containerName, out Container? container)
                ? container
                : throw new ArgumentException(
                    $"The '{containerName}' container is not listed in the Cosmos DB Containers setting.");

        /// <summary>
        /// Initializes the AI Search index for all registered model types, then builds and starts one
        /// change feed processor per container in the MonitoredContainers setting.
        /// </summary>
        /// <remarks>
        /// Runs as a detached background task, so every failure (including index initialization) is
        /// logged here; otherwise it would be lost as an unobserved task exception.
        /// </remarks>
        private async Task StartChangeFeedProcessors()
        {
            try
            {
                _logger.LogInformation("Initializing the Cognitive Search index...");
                await _aiSearchService.Initialize(_memoryTypes.Values.ToList());

                _logger.LogInformation("Initializing the change feed processors...");

                foreach (string monitoredContainerName in _settings.MonitoredContainers.Split(
                             ',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries))
                {
                    ChangeFeedProcessor changeFeedProcessor = GetRequiredContainer(monitoredContainerName)
                        .GetChangeFeedProcessorBuilder<dynamic>($"{monitoredContainerName}ChangeFeed", GenericChangeFeedHandler)
                        .WithInstanceName($"{monitoredContainerName}ChangeInstance")
                        .WithErrorNotification(GenericChangeFeedErrorHandler)
                        .WithLeaseContainer(_leases)
                        .Build();

                    await changeFeedProcessor.StartAsync();
                    _changeFeedProcessors.Add(changeFeedProcessor);

                    _logger.LogInformation(
                        "Initialized the change feed processor for the {ContainerName} container.",
                        monitoredContainerName);
                }

                _changeFeedsInitialized = true;
                _logger.LogInformation("Cosmos DB change feed processors initialized.");
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "Error initializing the Cognitive Search index or the change feed processors.");
            }
        }

        /// <summary>
        /// Change feed handler shared by all monitored containers. Each change is matched against the
        /// model registry, deserialized to its registered type, indexed in AI Search, and added as a
        /// memory to the RAG service.
        /// </summary>
        /// <param name="context">Change feed processor context (unused).</param>
        /// <param name="changes">Changed documents; deserialized as <c>dynamic</c> because a container may hold several entity types.</param>
        /// <param name="cancellationToken">Stops processing of the remaining items when cancelled.</param>
        /// <remarks>
        /// Failures are logged per item so that one bad document does not stop the rest of the batch.
        /// </remarks>
        private async Task GenericChangeFeedHandler(
            ChangeFeedProcessorContext context,
            IReadOnlyCollection<dynamic> changes,
            CancellationToken cancellationToken)
        {
            if (changes.Count == 0)
            {
                return;
            }

            string batchRef = Guid.NewGuid().ToString();
            _logger.LogInformation(
                "Starting to generate embeddings for {Count} entities (batch ref {BatchRef}).",
                changes.Count,
                batchRef);

            foreach (object item in changes)
            {
                if (cancellationToken.IsCancellationRequested)
                {
                    break;
                }

                try
                {
                    if (item is not JObject jObject)
                    {
                        _logger.LogError(
                            "Unsupported entity type in Cosmos DB change feed handler: {Entity}",
                            item?.ToString());
                        continue;
                    }

                    var typeMetadata = ModelRegistry.IdentifyType(jObject);
                    if (typeMetadata == null)
                    {
                        // ToString() keeps the JSON rendering; the logger would otherwise enumerate the JObject.
                        _logger.LogError(
                            "Unsupported entity type in Cosmos DB change feed handler: {Entity}",
                            jObject.ToString());
                        continue;
                    }

                    var entity = jObject.ToObject(typeMetadata.Type);

                    // Add the entity to the Cognitive Search content index.
                    // The content index is used by the Cognitive Search memory source to create memories from faceted queries.
                    await _aiSearchService.IndexItem(entity);

                    // Add the entity to the Semantic Kernel memory used by the RAG service.
                    // The VectorSearchAiAssistant.SemanticKernel project is kept free of domain-specific
                    // references, so the entity's name is built generically from its naming properties.
                    await _ragService.AddMemory(
                        entity,
                        string.Join(" ", entity.GetPropertyValues(typeMetadata.NamingProperties)));
                }
                catch (Exception ex)
                {
                    _logger.LogError(
                        ex,
                        "Error processing an item in the change feed handler: {Item}",
                        item?.ToString());
                }
            }

            _logger.LogInformation("Finished generating embeddings (batch ref {BatchRef}).", batchRef);
        }

        /// <summary>
        /// Error notification for the change feed processors: logs lease-level failures.
        /// </summary>
        /// <param name="leaseToken">Token of the lease that failed.</param>
        /// <param name="exception">The failure reported by the change feed processor.</param>
        /// <returns>A completed task.</returns>
        private Task GenericChangeFeedErrorHandler(
            string leaseToken,
            Exception exception)
        {
            if (exception is ChangeFeedProcessorUserException userException)
            {
                _logger.LogError(
                    userException.InnerException,
                    "Lease {LeaseToken} processing failed with unhandled exception from user delegate.",
                    leaseToken);
            }
            else
            {
                _logger.LogError(exception, "Lease {LeaseToken} failed.", leaseToken);
            }

            return Task.CompletedTask;
        }

        /// <summary>
        /// Executes a transactional batch and throws when the service reports failure.
        /// </summary>
        /// <param name="batch">The batch to execute.</param>
        /// <param name="operationDescription">Short description used in the failure message.</param>
        /// <exception cref="CosmosException">The batch did not complete successfully.</exception>
        /// <remarks>
        /// <see cref="TransactionalBatch.ExecuteAsync(CancellationToken)"/> reports failures through the
        /// response status rather than by throwing, so the status must be checked explicitly.
        /// </remarks>
        private static async Task ExecuteBatchOrThrowAsync(TransactionalBatch batch, string operationDescription)
        {
            using TransactionalBatchResponse response = await batch.ExecuteAsync();
            if (!response.IsSuccessStatusCode)
            {
                throw new CosmosException(
                    $"Transactional batch to {operationDescription} failed: {response.ErrorMessage}",
                    response.StatusCode,
                    0,
                    response.ActivityId,
                    response.RequestCharge);
            }
        }

        /// <summary>
        /// Gets a list of all current chat sessions.
        /// </summary>
        /// <returns>List of distinct chat session items.</returns>
        public async Task<List<Session>> GetSessionsAsync()
        {
            QueryDefinition query = new QueryDefinition("SELECT DISTINCT * FROM c WHERE c.type = @type")
                .WithParameter("@type", nameof(Session));

            using FeedIterator<Session> response = _completions.GetItemQueryIterator<Session>(query);

            List<Session> output = new();
            while (response.HasMoreResults)
            {
                FeedResponse<Session> results = await response.ReadNextAsync();
                output.AddRange(results);
            }

            return output;
        }

        /// <summary>
        /// Performs a point read to retrieve a single chat session item.
        /// </summary>
        /// <param name="id">The session id, which is also the session's partition key.</param>
        /// <returns>The chat session item.</returns>
        public async Task<Session> GetSessionAsync(string id)
        {
            return await _completions.ReadItemAsync<Session>(
                id: id,
                partitionKey: new PartitionKey(id));
        }

        /// <summary>
        /// Gets a list of all current chat messages for a specified session identifier.
        /// </summary>
        /// <param name="sessionId">Chat session identifier used to filter messages.</param>
        /// <returns>List of chat message items for the specified session.</returns>
        public async Task<List<Message>> GetSessionMessagesAsync(string sessionId)
        {
            QueryDefinition query =
                new QueryDefinition("SELECT * FROM c WHERE c.sessionId = @sessionId AND c.type = @type")
                    .WithParameter("@sessionId", sessionId)
                    .WithParameter("@type", nameof(Message));

            // Messages are stored with the session id as partition key, so scope the query to that partition.
            using FeedIterator<Message> results = _completions.GetItemQueryIterator<Message>(
                query,
                requestOptions: new QueryRequestOptions { PartitionKey = new PartitionKey(sessionId) });

            List<Message> output = new();
            while (results.HasMoreResults)
            {
                FeedResponse<Message> response = await results.ReadNextAsync();
                output.AddRange(response);
            }

            return output;
        }

        /// <summary>
        /// Creates a new chat session.
        /// </summary>
        /// <param name="session">Chat session item to create.</param>
        /// <returns>Newly created chat session item.</returns>
        public async Task<Session> InsertSessionAsync(Session session)
        {
            PartitionKey partitionKey = new(session.SessionId);
            return await _completions.CreateItemAsync(
                item: session,
                partitionKey: partitionKey
            );
        }

        /// <summary>
        /// Creates a new chat message.
        /// </summary>
        /// <param name="message">Chat message item to create.</param>
        /// <returns>Newly created chat message item.</returns>
        public async Task<Message> InsertMessageAsync(Message message)
        {
            PartitionKey partitionKey = new(message.SessionId);
            return await _completions.CreateItemAsync(
                item: message,
                partitionKey: partitionKey
            );
        }

        /// <summary>
        /// Updates an existing chat message.
        /// </summary>
        /// <param name="message">Chat message item to update.</param>
        /// <returns>Revised chat message item.</returns>
        public async Task<Message> UpdateMessageAsync(Message message)
        {
            PartitionKey partitionKey = new(message.SessionId);
            return await _completions.ReplaceItemAsync(
                item: message,
                id: message.Id,
                partitionKey: partitionKey
            );
        }

        /// <summary>
        /// Updates a message's rating through a patch operation.
        /// </summary>
        /// <param name="id">The message id.</param>
        /// <param name="sessionId">The message's partition key (session id).</param>
        /// <param name="rating">The rating to replace.</param>
        /// <returns>Revised chat message item.</returns>
        public async Task<Message> UpdateMessageRatingAsync(string id, string sessionId, bool? rating)
        {
            var response = await _completions.PatchItemAsync<Message>(
                id: id,
                partitionKey: new PartitionKey(sessionId),
                patchOperations: new[]
                {
                    PatchOperation.Set("/rating", rating),
                }
            );
            return response.Resource;
        }

        /// <summary>
        /// Updates an existing chat session.
        /// </summary>
        /// <param name="session">Chat session item to update.</param>
        /// <returns>Revised chat session item.</returns>
        public async Task<Session> UpdateSessionAsync(Session session)
        {
            PartitionKey partitionKey = new(session.SessionId);
            return await _completions.ReplaceItemAsync(
                item: session,
                id: session.Id,
                partitionKey: partitionKey
            );
        }

        /// <summary>
        /// Updates a session's name through a patch operation.
        /// </summary>
        /// <param name="id">The session id (also its partition key).</param>
        /// <param name="name">The session's new name.</param>
        /// <returns>Revised chat session item.</returns>
        public async Task<Session> UpdateSessionNameAsync(string id, string name)
        {
            var response = await _completions.PatchItemAsync<Session>(
                id: id,
                partitionKey: new PartitionKey(id),
                patchOperations: new[]
                {
                    PatchOperation.Set("/name", name),
                }
            );
            return response.Resource;
        }

        /// <summary>
        /// Batch create or update chat messages and session in a single transactional batch.
        /// </summary>
        /// <param name="messages">
        /// Chat message and session items to create or replace; all must share the same SessionId.
        /// An empty array is a no-op.
        /// </param>
        /// <exception cref="ArgumentNullException"><paramref name="messages"/> is null.</exception>
        /// <exception cref="ArgumentException">The items do not share the same partition key.</exception>
        /// <exception cref="CosmosException">The transactional batch failed.</exception>
        public async Task UpsertSessionBatchAsync(params dynamic[] messages)
        {
            ArgumentNullException.ThrowIfNull(messages);
            if (messages.Length == 0)
            {
                return;
            }

            if (messages.Select(m => m.SessionId).Distinct().Count() > 1)
            {
                throw new ArgumentException("All items must have the same partition key.");
            }

            string sessionId = messages[0].SessionId;
            PartitionKey partitionKey = new(sessionId);

            TransactionalBatch batch = _completions.CreateTransactionalBatch(partitionKey);
            foreach (var message in messages)
            {
                batch.UpsertItem(
                    item: message
                );
            }

            await ExecuteBatchOrThrowAsync(batch, $"upsert items for session {sessionId}");
        }

        /// <summary>
        /// Batch deletes an existing chat session and all related messages.
        /// </summary>
        /// <param name="sessionId">Chat session identifier used to flag messages and sessions for deletion.</param>
        /// <exception cref="CosmosException">A delete batch failed.</exception>
        /// <remarks>
        /// Deletes run in transactional batches of at most 100 operations, the Cosmos DB batch limit.
        /// A session with up to 100 items is therefore deleted atomically; larger sessions are deleted
        /// one batch at a time.
        /// </remarks>
        public async Task DeleteSessionAndMessagesAsync(string sessionId)
        {
            PartitionKey partitionKey = new(sessionId);

            // TODO: await container.DeleteAllItemsByPartitionKeyStreamAsync(partitionKey);
            var query = new QueryDefinition("SELECT c.id FROM c WHERE c.sessionId = @sessionId")
                .WithParameter("@sessionId", sessionId);

            List<string> ids = new();
            using (FeedIterator<Message> response = _completions.GetItemQueryIterator<Message>(
                       query,
                       requestOptions: new QueryRequestOptions { PartitionKey = partitionKey }))
            {
                while (response.HasMoreResults)
                {
                    foreach (Message item in await response.ReadNextAsync())
                    {
                        ids.Add(item.Id);
                    }
                }
            }

            foreach (string[] chunk in ids.Chunk(MaxTransactionalBatchOperations))
            {
                TransactionalBatch batch = _completions.CreateTransactionalBatch(partitionKey);
                foreach (string id in chunk)
                {
                    batch.DeleteItem(id: id);
                }

                await ExecuteBatchOrThrowAsync(batch, $"delete items for session {sessionId}");
            }
        }

        /// <summary>
        /// Creates <paramref name="item"/> in <paramref name="container"/>, treating a conflict
        /// (item already exists) as success.
        /// </summary>
        /// <typeparam name="T">The item type.</typeparam>
        /// <param name="container">Target container.</param>
        /// <param name="item">Item to create.</param>
        /// <param name="entityName">Human-readable entity name used in log messages.</param>
        /// <returns>The created item, or <paramref name="item"/> itself if it already existed.</returns>
        /// <exception cref="CosmosException">Any Cosmos DB failure other than a conflict.</exception>
        private async Task<T> CreateItemIgnoringConflictAsync<T>(Container container, T item, string entityName)
        {
            try
            {
                ItemResponse<T> response = await container.CreateItemAsync(item);
                return response.Resource;
            }
            catch (CosmosException ex) when (ex.StatusCode == System.Net.HttpStatusCode.Conflict)
            {
                _logger.LogInformation("{EntityName} already added.", entityName);
                return item;
            }
            catch (CosmosException ex)
            {
                _logger.LogError(ex, "Error inserting {EntityName}: {ErrorMessage}", entityName, ex.Message);
                throw;
            }
        }

        /// <summary>
        /// Inserts a product into the product container. An existing product is left untouched.
        /// </summary>
        /// <param name="product">Product item to create.</param>
        /// <returns>Newly created product item, or <paramref name="product"/> if it already existed.</returns>
        public Task<Product> InsertProductAsync(Product product) =>
            CreateItemIgnoringConflictAsync(_product, product, "Product");

        /// <summary>
        /// Inserts a customer into the customer container. An existing customer is left untouched.
        /// </summary>
        /// <param name="customer">Customer item to create.</param>
        /// <returns>Newly created customer item, or <paramref name="customer"/> if it already existed.</returns>
        public Task<Customer> InsertCustomerAsync(Customer customer) =>
            CreateItemIgnoringConflictAsync(_customer, customer, "Customer");

        /// <summary>
        /// Inserts a sales order into the customer container. An existing sales order is left untouched.
        /// </summary>
        /// <param name="salesOrder">Sales order item to create.</param>
        /// <returns>Newly created sales order item, or <paramref name="salesOrder"/> if it already existed.</returns>
        public Task<SalesOrder> InsertSalesOrderAsync(SalesOrder salesOrder) =>
            CreateItemIgnoringConflictAsync(_customer, salesOrder, "Sales order");

        /// <summary>
        /// Deletes a product by its Id and category (its partition key). A missing product is not an error.
        /// </summary>
        /// <param name="productId">The Id of the product to delete.</param>
        /// <param name="categoryId">The category Id of the product to delete.</param>
        /// <returns>A task that completes when the delete has finished.</returns>
        public async Task DeleteProductAsync(string productId, string categoryId)
        {
            try
            {
                // Delete from Cosmos product container
                await _product.DeleteItemAsync<Product>(id: productId, partitionKey: new PartitionKey(categoryId));
            }
            catch (CosmosException ex) when (ex.StatusCode == System.Net.HttpStatusCode.NotFound)
            {
                _logger.LogInformation("The product has already been removed.");
            }
        }

        /// <summary>
        /// Reads all documents retrieved by Vector Search.
        /// </summary>
        /// <param name="vectorDocuments">References (container, item id, partition key) to the documents to read.</param>
        /// <returns>
        /// The raw JSON of every document that could be read, joined with a newline followed by "-".
        /// Documents that cannot be read are logged and skipped.
        /// </returns>
        public async Task<string> GetVectorSearchDocumentsAsync(List<DocumentVector> vectorDocuments)
        {
            List<string> searchDocuments = new();

            foreach (var document in vectorDocuments)
            {
                try
                {
                    if (!_containers.TryGetValue(document.containerName, out Container? container))
                    {
                        _logger.LogError(
                            "Unknown container '{ContainerName}' for item '{ItemId}'.",
                            document.containerName,
                            document.itemId);
                        continue;
                    }

                    using ResponseMessage response = await container.ReadItemStreamAsync(
                        document.itemId, new PartitionKey(document.partitionKey));

                    if (!response.IsSuccessStatusCode)
                    {
                        // Skip failures so error payloads never end up in the returned documents.
                        _logger.LogError(
                            "Failed to retrieve an item for id '{ItemId}' - status code '{StatusCode}'.",
                            document.itemId,
                            response.StatusCode);
                        continue;
                    }

                    if (response.Content == null)
                    {
                        _logger.LogInformation(
                            "Null content received for document '{ItemId}' - status code '{StatusCode}'.",
                            document.itemId,
                            response.StatusCode);
                        continue;
                    }

                    using StreamReader reader = new(response.Content);
                    searchDocuments.Add(await reader.ReadToEndAsync());
                }
                catch (Exception ex)
                {
                    _logger.LogError(ex, "Error retrieving vector search document '{ItemId}'.", document.itemId);
                }
            }

            return string.Join(Environment.NewLine + "-", searchDocuments);
        }

        /// <summary>
        /// Performs a point read of a completion prompt.
        /// </summary>
        /// <param name="sessionId">The session id (partition key) the prompt belongs to.</param>
        /// <param name="completionPromptId">The completion prompt id.</param>
        /// <returns>The completion prompt item.</returns>
        public async Task<CompletionPrompt> GetCompletionPrompt(string sessionId, string completionPromptId)
        {
            return await _completions.ReadItemAsync<CompletionPrompt>(
                id: completionPromptId,
                partitionKey: new PartitionKey(sessionId));
        }
    }
}