src/extensions/onnx_extension/OnnxGenAIChatCompletionsStreamingClass.cs (69 lines of code) (raw):
using System.Text;
using Microsoft.ML.OnnxRuntimeGenAI;
namespace Azure.AI.Details.Common.CLI.Extensions.ONNX;
/// <summary>
/// One entry in a chat transcript: who spoke and what was said.
/// </summary>
public class ContentMessage
{
    /// <summary>
    /// Role of the speaker. Starts out as the empty string.
    /// </summary>
    public string Role { get; set; } = string.Empty;

    /// <summary>
    /// Text of the message. Starts out as the empty string.
    /// </summary>
    public string Content { get; set; } = string.Empty;
}
/// <summary>
/// Runs multi-turn chat completions against a local ONNX Runtime GenAI model,
/// keeping the conversation history in memory and reporting generated text
/// piece by piece through an optional callback.
/// </summary>
/// <remarks>
/// The instance owns a <see cref="Model"/> and a <see cref="Tokenizer"/>.
/// Call <see cref="Dispose"/> when finished so they are released.
/// </remarks>
public class OnnxGenAIChatCompletionsStreamingClass : IDisposable
{
    // Value passed to the generator's "max_length" search option.
    private const int MaxLength = 2048;

    /// <summary>
    /// Loads the model and tokenizer from <paramref name="modelDirectory"/> and
    /// prepares the conversation history.
    /// </summary>
    /// <param name="modelDirectory">Directory handed to the <see cref="Model"/> constructor.</param>
    /// <param name="systemPrompt">Content of the "system" message used when the conversation is (re)started.</param>
    /// <param name="chatHistoryFile">
    /// Optional file from which the history is read. When supplied, the history
    /// comes from that file and the system message is not added by this constructor.
    /// </param>
    public OnnxGenAIChatCompletionsStreamingClass(string modelDirectory, string systemPrompt, string? chatHistoryFile = null)
    {
        _modelDirectory = modelDirectory;
        _systemPrompt = systemPrompt;
        _messages = new List<ContentMessage>();

        if (chatHistoryFile != null)
        {
            _messages.ReadChatHistoryFromFile(chatHistoryFile);
        }
        else
        {
            ClearConversation();
        }

        _model = new Model(modelDirectory);
        try
        {
            _tokenizer = new Tokenizer(_model);
        }
        catch
        {
            // Nobody else can dispose the model if construction fails half-way.
            _model.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Discards the history and starts over with only the system message.
    /// </summary>
    public void ClearConversation()
    {
        _messages.Clear();
        _messages.Add(new ContentMessage { Role = "system", Content = _systemPrompt });
    }

    /// <summary>
    /// The live conversation history (the same list instance this class mutates).
    /// </summary>
    public List<ContentMessage> Messages { get => _messages; }

    /// <summary>
    /// Appends <paramref name="userPrompt"/> to the history, generates the assistant
    /// reply, appends that reply to the history, and returns it.
    /// </summary>
    /// <param name="userPrompt">Text of the new "user" message.</param>
    /// <param name="callback">Invoked with each decoded piece of text as it is generated; may be null.</param>
    /// <returns>The complete assistant reply.</returns>
    /// <exception cref="ObjectDisposedException">The instance has already been disposed.</exception>
    public string GetChatCompletionStreaming(string userPrompt, Action<string>? callback = null)
    {
        if (_disposed)
        {
            throw new ObjectDisposedException(nameof(OnnxGenAIChatCompletionsStreamingClass));
        }

        _messages.Add(new ContentMessage { Role = "user", Content = userPrompt });

        using var tokens = _tokenizer.Encode(BuildPrompt());

        using var generatorParams = new GeneratorParams(_model);
        generatorParams.SetSearchOption("max_length", MaxLength);
        generatorParams.SetInputSequences(tokens);

        using var generator = new Generator(_model, generatorParams);

        // Decode through a tokenizer stream rather than one token at a time in
        // isolation: a character whose bytes span several tokens cannot be
        // reconstructed from any single one of them.
        using var tokenizerStream = _tokenizer.CreateStream();

        var reply = new StringBuilder();
        while (!generator.IsDone())
        {
            generator.ComputeLogits();
            generator.GenerateNextToken();

            // Only the most recently generated token is new on each iteration.
            var sequence = generator.GetSequence(0);
            var piece = tokenizerStream.Decode(sequence[^1]);

            reply.Append(piece);
            callback?.Invoke(piece);
        }

        var responseContent = reply.ToString();
        _messages.Add(new ContentMessage { Role = "assistant", Content = responseContent });
        return responseContent;
    }

    /// <summary>
    /// Releases the tokenizer and the model. Calling it more than once is harmless.
    /// </summary>
    public void Dispose()
    {
        if (_disposed)
        {
            return;
        }

        _disposed = true;

        // The tokenizer was created from the model, so release it first.
        _tokenizer.Dispose();
        _model.Dispose();
        GC.SuppressFinalize(this);
    }

    // Renders the whole history as role-tagged sections and ends with an open
    // assistant tag so that generation continues with the assistant's turn.
    private string BuildPrompt()
    {
        var sections = _messages.Select(m => $"<|{m.Role}|>\n{m.Content}\n<|end|>");
        return string.Join("\n", sections) + "<|assistant|>\n";
    }

    private readonly string _modelDirectory;
    private readonly string _systemPrompt;
    private readonly Model _model;
    private readonly Tokenizer _tokenizer;
    private readonly List<ContentMessage> _messages;
    private bool _disposed;
}