Runtime/Tokenizers/PostProcessors/PostProcessors.cs (82 lines of code) (raw):
using UnityEngine;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Text;
using System.Text.RegularExpressions;
using System.Linq;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
namespace HuggingFace.SharpTransformers.PostProcessors
{
public abstract class PostProcessor
{
public JObject Config;
public PostProcessor(JObject config)
{
Config = config;
}
/// <summary>
/// Factory method to create a PostProcessor object from a configuration object.
/// </summary>
/// <param name="config">Configuration object representing a PostProcessor.</param>
/// <returns>A PostProcessor object created from the given configuration.</returns>
/// <exception cref="Exception"></exception>
public static PostProcessor FromConfig(JObject config)
{
if (config == null)
{
return null;
}
string configType = config["type"].ToString();
switch (configType)
{
case "TemplateProcessing":
return new TemplateProcessing(config);
default:
throw new Exception("Unknown PostProcessor type");
}
}
/// <summary>
/// Method to be implemented in subclass to apply post-processing on the given tokens.
/// </summary>
/// <param name="tokens">The input tokens to be post-processed.</param>
/// <returns>The post-processed tokens.</returns>
public virtual List<string> PostProcess(List<string> tokens, List<string> tokensPair = null)
{
throw new Exception("PostProcess should be implemented in subclass.");
}
/// <summary>
/// Alias for PostProcess
/// </summary>
/// <param name="tokens">The text or array of texts to post-process.</param>
/// <returns>An array of post-processed tokens.</returns>
public virtual List<string> Call(List<string> tokens, List<string> tokensPair = null)
{
return PostProcess(tokens);
}
}
/// <summary>
/// Post processor that replaces special tokens in a template with actual tokens.
/// </summary>
public class TemplateProcessing : PostProcessor
{
public JObject Config;
// The template for a single sequence of tokens.
public JArray Single;
//public List<SingleItem> Single;
// The template for a pair of sequences of tokens.
public JArray Pair;
//public List<PairItem> Pair;
/// <summary>
/// Creates a new instance of TemplateProcessing
/// </summary>
/// <param name="config"></param>
public TemplateProcessing(JObject config) : base(config)
{
Config = config;
Single = (JArray)config["single"];
Pair = (JArray)config["pair"];
}
// The function's purpose is to replace special tokens and sequence identifiers with actual tokens.
public override List<string> PostProcess(List<string> tokens, List<string> tokensPair = null)
{
// Check the type of sequence (based on if tokensPair is provided or not)
// If tokensPair is null => assign Single to Type
// Else assign Pair to Type
JArray Type = tokensPair == null ? Single : Pair;
// Create an empty List<string> to store the resulting tokens after processing
List<string> ToReturn = new List<string>();
// The function iterates over each item in the Type List
foreach (JToken item in Type)
{
JObject itemJson = (JObject)item;
// If the curent item has a property called "Special Token"
// it means that this item is a special token.
if (itemJson.ContainsKey("SpecialToken"))
{
// We extracts the id of the special token and adds it to the toReturn List.
// We need to parse the JSON string and extract the id here
string specialTokenId = (string)itemJson["SpecialToken"]["id"];
ToReturn.Add(specialTokenId);
}
// If the current item has a property called "Sequence" it means that this item
// represents a sequence identifier (like 'A' or 'B')
else if (itemJson.ContainsKey("Sequence"))
{
string sequenceId = (string)itemJson["Sequence"]["id"];
if (sequenceId == "A")
{
// Add the elements of another collection to the list
// Equivalent to merge in JS
// Merge sequence tokens
ToReturn.AddRange(tokens);
}
else if (sequenceId == "B")
{
// Merge tokens_pair
ToReturn.AddRange(tokensPair);
}
}
}
return ToReturn;
}
}
}