Runtime/Tokenizers/Decoders/Decoders.cs (194 lines of code) (raw):

using System;
using System.Collections;
using System.Collections.Generic;
using System.Text;
using System.Text.RegularExpressions;
using System.Linq;
using UnityEngine;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using HuggingFace.SharpTransformers.NormalizersUtils;

namespace HuggingFace.SharpTransformers.Decoders
{
    /// <summary>
    /// Base class for token decoders. A decoder post-processes the list of tokens produced by a
    /// model; concrete decoders override <see cref="DecodeChain"/>.
    /// </summary>
    class Decoder_
    {
        public JObject Config;
        public List<string> AddedTokens;
        public string EndOfWordSuffix;
        public string TrimOffsets;

        /// <summary>
        /// Creates an instance of `Decoder`.
        /// </summary>
        /// <param name="config">The configuration object.</param>
        public Decoder_(JObject config)
        {
            Config = config;
            AddedTokens = new List<string>();
            EndOfWordSuffix = null;
            // NOTE(review): presumably meant to be populated from config.trim_offsets; not read yet.
            TrimOffsets = null;
        }

        /// <summary>
        /// Creates a decoder instance based on the provided configuration.
        /// </summary>
        /// <param name="config">The configuration object; its "type" entry selects the decoder.</param>
        /// <returns>A decoder instance.</returns>
        /// <exception cref="ArgumentNullException">If <paramref name="config"/> is null.</exception>
        /// <exception cref="ArgumentException">If the configuration has no "type" entry.</exception>
        /// <exception cref="NotSupportedException">If an unknown decoder type is provided.</exception>
        public static Decoder_ FromConfig(JObject config)
        {
            if (config == null)
            {
                throw new ArgumentNullException(nameof(config));
            }

            string configType = config["type"]?.ToString();
            if (configType == null)
            {
                throw new ArgumentException("Decoder config has no \"type\" entry.", nameof(config));
            }

            switch (configType)
            {
                case "Sequence":
                    return new DecoderSequence(config);
                case "Replace":
                    return new ReplaceDecoder(config);
                case "Fuse":
                    return new FuseDecoder(config);
                case "ByteFallback":
                    return new ByteFallback(config);
                case "Strip":
                    return new StripDecoder(config);
                // TODO: "WordPiece" and the remaining Hugging Face decoder types are not ported yet.
                default:
                    throw new NotSupportedException("Unknown Decoder type: " + configType);
            }
        }

        /// <summary>
        /// Transforms a list of tokens into a new list of tokens. Concrete decoders override this.
        /// </summary>
        /// <param name="tokens">The list of tokens.</param>
        /// <returns>The decoded list of tokens.</returns>
        /// <exception cref="NotImplementedException">If the subclass does not override this method.</exception>
        public virtual List<string> DecodeChain(List<string> tokens)
        {
            throw new NotImplementedException("`decode_chain` should be implemented in subclass.");
        }

        /// <summary>
        /// Applies this decoder to a list of tokens by running <see cref="DecodeChain"/>.
        /// Callers that need a single string can concatenate the returned tokens.
        /// </summary>
        /// <param name="tokens">The list of tokens.</param>
        /// <returns>The decoded list of tokens.</returns>
        public virtual List<string> Decode(List<string> tokens)
        {
            // Delegate to DecodeChain so that Decode/Call actually run the configured decoder.
            return DecodeChain(tokens);
        }

        /// <summary>
        /// Alias for <see cref="Decode"/>.
        /// </summary>
        /// <param name="tokens">The list of tokens.</param>
        /// <returns>The decoded list of tokens.</returns>
        public List<string> Call(List<string> tokens)
        {
            return Decode(tokens);
        }
    }

    /// <summary>
    /// Replaces every occurrence of the configured pattern inside each token with the configured content.
    /// The pattern string comes from <c>Utils.createPattern</c> and is matched literally via
    /// <see cref="string.Replace(string, string)"/>, not as a regular expression.
    /// </summary>
    class ReplaceDecoder : Decoder_
    {
        public ReplaceDecoder(JObject config) : base(config) { }

        public override List<string> DecodeChain(List<string> tokens)
        {
            string pattern = Utils.createPattern(this.Config["pattern"]);

            // string.Replace throws on an empty search string, so treat an empty pattern like a missing one.
            if (string.IsNullOrEmpty(pattern))
            {
                return tokens;
            }

            // The replacement text is the same for every token; read it from the config once.
            string content = this.Config["content"].Value<string>();
            return tokens.Select(token => token.Replace(pattern, content)).ToList();
        }
    }

    /// <summary>
    /// Fuses all tokens into one big string.
    /// Usually it's the last decoding step, but this decoder exists in case some decoders
    /// need to run after that step.
    /// </summary>
    class FuseDecoder : Decoder_
    {
        public FuseDecoder(JObject config) : base(config) { }

        public override List<string> DecodeChain(List<string> tokens)
        {
            return new List<string> { string.Concat(tokens) };
        }
    }

    /// <summary>
    /// Handles tokens that represent raw bytes in the "&lt;0xXX&gt;" hexadecimal format.
    /// Consecutive byte tokens are collected and decoded together as UTF-8 into a single string;
    /// all other tokens pass through unchanged.
    /// </summary>
    class ByteFallback : Decoder_
    {
        // Default UTF8Encoding is constructed with throwOnInvalidBytes = false.
        private static readonly UTF8Encoding Utf8 = new UTF8Encoding();

        public ByteFallback(JObject config) : base(config) { }

        public override List<string> DecodeChain(List<string> tokens)
        {
            var newTokens = new List<string>();
            var pendingBytes = new List<byte>();

            foreach (var token in tokens)
            {
                if (TryParseByteToken(token, out byte value))
                {
                    pendingBytes.Add(value);
                    continue;
                }

                // A regular token ends the current run of byte tokens: emit the decoded run first.
                FlushBytes(pendingBytes, newTokens);
                newTokens.Add(token);
            }

            // Emit any byte run that reaches the end of the token list.
            FlushBytes(pendingBytes, newTokens);
            return newTokens;
        }

        /// <summary>
        /// Returns true when <paramref name="token"/> is exactly "&lt;0xXX&gt;" with two hex digits.
        /// </summary>
        private static bool TryParseByteToken(string token, out byte value)
        {
            value = 0;
            // AllowHexSpecifier (unlike HexNumber) rejects whitespace-padded digits such as "<0x A>".
            return token != null
                && token.Length == 6
                && token.StartsWith("<0x", StringComparison.Ordinal)
                && token.EndsWith(">", StringComparison.Ordinal)
                && byte.TryParse(
                    token.Substring(3, 2),
                    System.Globalization.NumberStyles.AllowHexSpecifier,
                    System.Globalization.CultureInfo.InvariantCulture,
                    out value);
        }

        /// <summary>
        /// Decodes the pending bytes as UTF-8 into <paramref name="output"/> (if any) and clears them.
        /// </summary>
        private static void FlushBytes(List<byte> pendingBytes, List<string> output)
        {
            if (pendingBytes.Count == 0)
            {
                return;
            }

            output.Add(Utf8.GetString(pendingBytes.ToArray()));
            pendingBytes.Clear();
        }
    }

    /// <summary>
    /// Strips up to <c>start</c> leading and up to <c>stop</c> trailing occurrences of the
    /// configured <c>content</c> character (its first char) from each token.
    /// </summary>
    class StripDecoder : Decoder_
    {
        private readonly string content;
        private readonly int start;
        private readonly int stop;

        public StripDecoder(JObject config) : base(config)
        {
            this.content = this.Config["content"].Value<string>();
            this.start = this.Config["start"].Value<int>();
            this.stop = this.Config["stop"].Value<int>();
        }

        public override List<string> DecodeChain(List<string> tokens)
        {
            return tokens.Select(token => StripToken(token)).ToList();
        }

        /// <summary>
        /// Removes the leading/trailing content characters from one token, never reading past its
        /// bounds and never letting the trailing cut cross the leading cut.
        /// </summary>
        private string StripToken(string token)
        {
            if (string.IsNullOrEmpty(token) || string.IsNullOrEmpty(content))
            {
                return token;
            }

            char target = content[0];

            int startCut = 0;
            while (startCut < start && startCut < token.Length && token[startCut] == target)
            {
                startCut++;
            }

            int stopCut = token.Length;
            int stripped = 0;
            // Stop at startCut so overlapping cuts yield an empty string instead of a negative length.
            while (stripped < stop && stopCut > startCut && token[stopCut - 1] == target)
            {
                stopCut--;
                stripped++;
            }

            return token.Substring(startCut, stopCut - startCut);
        }
    }

    /// <summary>
    /// Applies a sequence of decoders, feeding the output of each into the next.
    /// </summary>
    class DecoderSequence : Decoder_
    {
        public List<Decoder_> Decoders;

        /// <summary>
        /// Builds every nested decoder listed under the config's "decoders" entry via <see cref="Decoder_.FromConfig"/>.
        /// </summary>
        /// <param name="config">The configuration object.</param>
        /// <exception cref="ArgumentNullException">If <paramref name="config"/> is null.</exception>
        /// <exception cref="ArgumentException">If the config has no "decoders" entry.</exception>
        public DecoderSequence(JObject config) : base(config)
        {
            if (config == null)
            {
                throw new ArgumentNullException(nameof(config));
            }

            if (config["decoders"] == null)
            {
                throw new ArgumentException("No decoders in Sequence");
            }

            Decoders = new List<Decoder_>();
            foreach (JObject decoderConfig in config["decoders"])
            {
                // FromConfig either returns a decoder or throws, so no null check is needed.
                Decoders.Add(Decoder_.FromConfig(decoderConfig));
            }
        }

        /// <summary>
        /// Runs each decoder's <see cref="Decoder_.DecodeChain"/> in order on the tokens.
        /// </summary>
        /// <param name="tokens">The list of tokens.</param>
        /// <returns>The tokens after every decoder in the sequence has been applied.</returns>
        public override List<string> DecodeChain(List<string> tokens)
        {
            return Decoders.Aggregate(tokens, (currentTokens, decoder) => decoder.DecodeChain(currentTokens));
        }
    }
}