text/text.cpp (243 lines of code) (raw):
#include <azure_ai_contentsafety_text.h>
#include <iostream>
#include <csignal>
#include "thread"
#include <filesystem>
#include <fstream>
#include <functional>
#include <cstdlib> // For _dupenv_s
#include <string> // For std::string
#include <stdexcept> // For std::invalid_argument, std::out_of_range
#include <map>
#include <sstream>
#include <Windows.h>
#include <conio.h>
using namespace Azure::AI::ContentSafety;
// Parses a simple "key=value" configuration file into a map.
//
// Each line is split at its first '='; everything after that (including any
// further '=' characters) is the value. Leading and trailing whitespace,
// including a stray '\r' left over from CRLF line endings, is stripped from
// both key and value so that "numThreads = 4" and "numThreads=4" produce the
// same entry. The following lines are ignored:
//   - blank lines,
//   - comment lines whose first non-blank character is '#' or ';',
//   - lines without '=',
//   - entries whose key or value is empty after trimming.
// When a key appears more than once, the last occurrence wins.
//
// Throws std::runtime_error if the file cannot be opened.
std::map<std::string, std::string> readConfig(const std::string &filename) {
    std::ifstream file(filename);
    if (!file.is_open()) {
        throw std::runtime_error("Could not open config file");
    }

    // Returns a copy of `text` without leading/trailing blanks, tabs, CR or LF.
    const auto trim = [](const std::string &text) {
        static const char *const kWhitespace = " \t\r\n";
        const auto first = text.find_first_not_of(kWhitespace);
        if (first == std::string::npos) {
            return std::string();
        }
        const auto last = text.find_last_not_of(kWhitespace);
        return text.substr(first, last - first + 1);
    };

    std::map<std::string, std::string> config;
    std::string line;
    while (std::getline(file, line)) {
        const std::string entry = trim(line);
        if (entry.empty() || entry.front() == '#' || entry.front() == ';') {
            continue;
        }

        const auto separator = entry.find('=');
        if (separator == std::string::npos) {
            continue;
        }

        const std::string key = trim(entry.substr(0, separator));
        const std::string value = trim(entry.substr(separator + 1));
        if (key.empty() || value.empty()) {
            continue;
        }
        config[key] = value;
    }
    return config;
}
// Maps a TextCategory value to the label printed in the console report.
// Any value not listed below is reported as "Unknown".
std::string getCategoryName(TextCategory category) {
    const char *label = "Unknown";
    switch (category) {
        case TextCategory::Hate:
            label = "Hate";
            break;
        case TextCategory::SelfHarm:
            label = "Self Harm";
            break;
        case TextCategory::Sexual:
            label = "Sexual";
            break;
        case TextCategory::Violence:
            label = "Violence";
            break;
        default:
            break;
    }
    return label;
}
// Reads the entire contents of `filename` as raw bytes.
//
// The file is opened in binary mode so no newline translation takes place.
// Returns the bytes on success. On failure a message is written to std::cerr
// and an empty vector is returned; an empty file also yields an empty vector,
// without an error message.
//
// Uses the standard stream library rather than platform-specific file
// mapping: the data is copied into a vector in either case, and this form
// does not depend on the Win32 ANSI/Unicode API selection or a 32-bit size.
std::vector<char> readFile(const std::string& filename) {
    std::vector<char> buffer;

    // Open positioned at the end so tellg() reports the file size directly.
    std::ifstream file(filename, std::ios::binary | std::ios::ate);
    if (!file.is_open()) {
        std::cerr << "Could not open file" << std::endl;
        return buffer;
    }

    const std::streamoff fileSize = file.tellg();
    if (fileSize < 0) {
        std::cerr << "Could not get file size" << std::endl;
        return buffer;
    }
    if (fileSize == 0) {
        return buffer;
    }

    buffer.resize(static_cast<std::size_t>(fileSize));
    file.seekg(0, std::ios::beg);
    file.read(buffer.data(), static_cast<std::streamsize>(buffer.size()));
    if (file.gcount() != static_cast<std::streamsize>(buffer.size())) {
        // A short read would hand the caller a partially filled buffer.
        std::cerr << "Could not read file" << std::endl;
        buffer.clear();
    }
    return buffer;
}
// Creates the text model runtime from the parsed configuration.
//
// Fills a TextModelConfig from the "gpuEnabled", "gpuDeviceId", "numThreads",
// "modelDirectory", "modelName", "spmModelName" and "logEnabled" entries,
// echoes the resulting settings to std::cout, allocates a TextModelRuntime
// with the "licenseText" entry and stores it in *aacs, then calls Reload().
// Exceptions thrown by Reload() are reported on std::cerr and not rethrown.
// std::stoi throws if "gpuDeviceId" or "numThreads" is not a valid integer.
void Init(TextModelRuntime **aacs, std::map<std::string, std::string> config) {
    // Boolean options are enabled only by the exact string "true".
    const auto isTrue = [&config](const char *key) { return config[key] == "true"; };

    const std::string licenseText = config["licenseText"];

    TextModelConfig settings;
    settings.gpuEnabled = isTrue("gpuEnabled");
    settings.gpuDeviceId = std::stoi(config["gpuDeviceId"]);
    settings.numThreads = std::stoi(config["numThreads"]);
    settings.modelDirectory = config["modelDirectory"];
    settings.modelName = config["modelName"];
    settings.spmModelName = config["spmModelName"];
    settings.logEnabled = isTrue("logEnabled");

    std::cout << "gpuEnabled: " << settings.gpuEnabled << std::endl
              << "gpuDeviceId: " << settings.gpuDeviceId << std::endl
              << "numThreads: " << settings.numThreads << std::endl
              << "modelDirectory: " << settings.modelDirectory << std::endl
              << "modelName: " << settings.modelName << std::endl
              << "spmModelName: " << settings.spmModelName << std::endl;

    TextModelRuntime *runtime = new TextModelRuntime(licenseText.c_str(), settings);
    *aacs = runtime;

    try {
        runtime->Reload();
    } catch (const std::exception &ex) {
        std::cerr << "Exception caught: " << ex.what() << std::endl;
    } catch (...) {
        std::cerr << "Unknown exception caught" << std::endl;
    }
}
// Analyzes one piece of text and prints a colour-coded report to std::cout.
//
// For every category in the result, the line is printed in green when the
// severity is 0 (or below), yellow when it is above 0 but below the
// threshold, and red at or above the threshold. The elapsed time of the
// AnalyzeText call is printed afterwards.
//
// aacs      - runtime used for the analysis; must not be null.
// inputText - text to analyze.
void processInputText(TextModelRuntime* aacs, std::string inputText) {
    // Severities at or above this value are highlighted in red.
    constexpr int kSeverityThreshold = 3;

    AnalyzeTextOptions request;
    std::cout << " Your input: " << inputText << std::endl;
    request.text = inputText;
    std::cout << " AnalyzeResult: " << std::endl;

    // Run inference. Both timestamps bracket only the AnalyzeText call so the
    // reported duration is not inflated by console output, and a monotonic
    // clock is used because this is an interval measurement.
    const auto analyzeStart = std::chrono::steady_clock::now();
    auto result = aacs->AnalyzeText(request);
    const auto analyzeEnd = std::chrono::steady_clock::now();

    // Print the result to the console
    for (const auto& categoryAnalysis : result->categoriesAnalysis) {
        const char* color = "\033[32m"; // green
        if (categoryAnalysis.severity >= kSeverityThreshold) {
            color = "\033[31m"; // red
        } else if (categoryAnalysis.severity > 0) {
            color = "\033[33m"; // yellow
        }
        std::cout << color;
        std::cout << " Category: " << getCategoryName(categoryAnalysis.category) << ", Severity: "
                  << static_cast<int>(categoryAnalysis.severity) << std::endl;
        std::cout << "\033[0m"; // Reset the text color
    }

    const auto analyzeTextDuration =
        std::chrono::duration_cast<std::chrono::milliseconds>(analyzeEnd - analyzeStart);
    std::cout << "AnalyzeText duration: " << analyzeTextDuration.count() << " milliseconds" << std::endl;
    std::cout << "--------------------------------------------------------------------------------------"
              << std::endl;
}
// Analyzes one piece of text against the given blocklists and prints a
// colour-coded report to std::cout.
//
// For every category in the result, the line is printed in green when the
// severity is 0 (or below), yellow when it is above 0 but below the
// threshold, and red at or above the threshold. Blocklist matches are then
// listed in red, followed by the elapsed time of the AnalyzeText call.
//
// aacs            - runtime used for the analysis; must not be null.
// inputText       - text to analyze.
// blocklist_names - names of the blocklists to apply to the request.
void processInputTextWithBlockList(TextModelRuntime* aacs, std::string inputText, std::vector<std::string> blocklist_names) {
    // Severities at or above this value are highlighted in red.
    constexpr int kSeverityThreshold = 3;

    AnalyzeTextOptions request;
    std::cout << " Your input: " << inputText << std::endl;
    request.text = inputText;
    request.blocklistNames = blocklist_names;
    std::cout << " AnalyzeResult: " << std::endl;

    // Run inference. Both timestamps bracket only the AnalyzeText call so the
    // reported duration is not inflated by console output, and a monotonic
    // clock is used because this is an interval measurement.
    const auto analyzeStart = std::chrono::steady_clock::now();
    auto result = aacs->AnalyzeText(request);
    const auto analyzeEnd = std::chrono::steady_clock::now();

    // Print the per-category result to the console
    for (const auto& categoryAnalysis : result->categoriesAnalysis) {
        const char* color = "\033[32m"; // green
        if (categoryAnalysis.severity >= kSeverityThreshold) {
            color = "\033[31m"; // red
        } else if (categoryAnalysis.severity > 0) {
            color = "\033[33m"; // yellow
        }
        std::cout << color;
        std::cout << " Category: " << getCategoryName(categoryAnalysis.category) << ", Severity: "
                  << static_cast<int>(categoryAnalysis.severity) << std::endl;
        std::cout << "\033[0m"; // Reset the text color
    }

    // Print every blocklist entry that matched the input
    std::cout << " BlockList Matches : " << result->blocklistsMatched.size() << std::endl;
    for (const auto& blockListItem : result->blocklistsMatched) {
        std::cout << "\033[31m"; // Set the text color to red
        std::cout << " BlockList Name: " << blockListItem.blocklistName << ", Text: "
                  << blockListItem.blocklistItemText << std::endl;
        std::cout << "\033[0m"; // Reset the text color
    }

    const auto analyzeTextDuration =
        std::chrono::duration_cast<std::chrono::milliseconds>(analyzeEnd - analyzeStart);
    std::cout << "AnalyzeText duration: " << analyzeTextDuration.count() << " milliseconds" << std::endl;
    std::cout << "--------------------------------------------------------------------------------------"
              << std::endl;
}
// Analyzes every line of a text file with processInputText.
//
// aacs           - runtime used for the analysis.
// inputDirectory - directory containing the file; an empty string means the
//                  path is just `fileName`, i.e. relative to the current
//                  working directory.
// fileName       - name of the file inside inputDirectory.
//
// If the file cannot be opened, an error is written to std::cerr and the
// function returns without analyzing anything.
void processInputFile(TextModelRuntime* aacs, const std::string& inputDirectory, const std::string& fileName) {
    // Join with std::filesystem so the platform's separator is used and an
    // empty directory does not turn into a root-relative "\file" path.
    const std::string filePath = (std::filesystem::path(inputDirectory) / fileName).string();

    std::ifstream file(filePath);
    if (!file.is_open()) {
        std::cerr << "processInputFile, Could not open file: " << filePath << std::endl;
        return;
    }

    std::string line;
    while (std::getline(file, line)) {
        processInputText(aacs, line);
    }
    // The stream closes itself when it goes out of scope.
}
// Registers every ".csv" file in inputDirectory as a blocklist, then analyzes
// every line of the given text file against those blocklists.
//
// aacs           - runtime that receives the blocklists and runs the analysis.
// inputDirectory - directory holding both the text file and the blocklist
//                  CSV files; an empty string means the current working
//                  directory.
// fileName       - name of the text file inside inputDirectory.
//
// Each blocklist is named after its file name without the extension. CSV
// files that are empty or cannot be read are skipped with a message on
// std::cerr instead of being registered. If the text file cannot be opened,
// an error is written to std::cerr and nothing is registered or analyzed.
void processInputFileWithBlockList(TextModelRuntime* aacs, const std::string& inputDirectory, const std::string& fileName) {
    // Join with std::filesystem so the platform's separator is used and an
    // empty directory does not turn into a root-relative "\file" path.
    const std::filesystem::path directory(inputDirectory);
    const std::string filePath = (directory / fileName).string();

    std::ifstream file(filePath);
    if (!file.is_open()) {
        std::cerr << "processInputFileWithBlockList, Could not open file: " << filePath << std::endl;
        return;
    }

    // directory_iterator rejects an empty path, so scan "." in that case to
    // stay consistent with where the text file was just opened.
    const std::filesystem::path scanDirectory = directory.empty() ? std::filesystem::path(".") : directory;

    std::vector<std::string> blocklist_names;
    for (const auto& entry : std::filesystem::directory_iterator(scanDirectory)) {
        if (!entry.is_regular_file() || entry.path().extension() != ".csv") {
            continue;
        }

        const auto blockListName = entry.path().stem().string();
        std::vector<char> buffer = readFile(entry.path().string());
        if (buffer.empty()) {
            // readFile returns an empty buffer on failure; do not hand a
            // possibly-null data pointer to the runtime.
            std::cerr << "processInputFileWithBlockList, skipping empty or unreadable block list : "
                      << blockListName << std::endl;
            continue;
        }

        aacs->AddBlocklist(blockListName, buffer.data(), buffer.size());
        blocklist_names.push_back(blockListName);
        std::cout << "processInputFileWithBlockList, adding block list : " << blockListName << std::endl;
    }

    std::string line;
    while (std::getline(file, line)) {
        processInputTextWithBlockList(aacs, line, blocklist_names);
    }
    // The stream closes itself when it goes out of scope.
}
// Runs both sample passes over files in inputDirectory: first the plain
// analysis of inputFileName, then the blocklist-aware analysis of
// inputWithBlockListFileName.
void processSampleInputFiles(TextModelRuntime *aacs,
                             const std::string& inputDirectory,
                             const std::string& inputFileName,
                             const std::string& inputWithBlockListFileName) {
    // Pass 1: category analysis only.
    processInputFile(aacs, inputDirectory, inputFileName);

    // Pass 2: category analysis plus blocklist matching.
    processInputFileWithBlockList(aacs, inputDirectory, inputWithBlockListFileName);
}
int main(int argc, char *argv[]) {
std::map<std::string, std::string> config = readConfig("config.ini");
TextModelRuntime *aacs = NULL;
Init(&aacs, config);
processSampleInputFiles(aacs, config["inputTextDirectory"], config["inputTextFile"], config["inputWithBlockListTextFile"]);
std::cout << "Press any key to continue..";
_getch();
return 0;
}