src/Common/Utilities/GeneralUtilities.cs (424 lines of code) (raw):

// ----------------------------------------------------------------------------------
//
// Copyright Microsoft Corporation
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------------

using Hyak.Common;
using Microsoft.Azure.Commands.Common.Authentication;
using Microsoft.Azure.Commands.Common.Authentication.Abstractions;
using Microsoft.WindowsAzure.Commands.Common;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Reflection;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Text.RegularExpressions;
using System.Xml.Linq;

namespace Microsoft.WindowsAzure.Commands.Utilities.Common
{
    /// <summary>
    /// Miscellaneous static helpers: certificate lookup, collection extensions,
    /// HTTP request/response log formatting, URI helpers and process probing.
    /// </summary>
    public static class GeneralUtilities
    {
        private static readonly Assembly assembly = Assembly.GetExecutingAssembly();

        // Header names that are dropped when HttpHeaders are converted for logging.
        private static readonly List<string> AuthorizationHeaderNames = new List<string>() { "Authorization" };

        // this is only used to determine cutoff for streams (not xml or json).
        private const int StreamCutOffSize = 10 * 1024; //10KB

        // Matches a JSON "access_token" property; group 1 keeps the property name and colon
        // so only the value is replaced. Cached so it is not rebuilt on every call.
        private static readonly Regex AccessTokenMatcher = new Regex("(\\s*\"access_token\"\\s*:\\s*)\"[^\"]+\"");

        /// <summary>
        /// Looks up certificates by thumbprint in the "My" store of the given location.
        /// </summary>
        /// <param name="thumbprint">The thumbprint to search for.</param>
        /// <param name="location">The store location to search.</param>
        /// <param name="certificates">The certificates found, or null if the store callback never ran.</param>
        /// <returns>True if at least one certificate was found.</returns>
        private static bool TryFindCertificatesInStore(string thumbprint,
            System.Security.Cryptography.X509Certificates.StoreLocation location, out X509Certificate2Collection certificates)
        {
            X509Certificate2Collection found = null;
            DiskDataStore.X509StoreWrapper(StoreName.My, location, (store) =>
            {
                store.Open(OpenFlags.ReadOnly);
                // validOnly: false - expired or untrusted certificates are returned as well.
                found = store.Certificates.Find(X509FindType.FindByThumbprint, thumbprint, false);
            });

            certificates = found;
            return certificates != null && certificates.Count > 0;
        }

        /// <summary>
        /// Gets the first certificate with the given thumbprint, searching CurrentUser\My
        /// and then LocalMachine\My.
        /// </summary>
        /// <param name="thumbprint">The certificate thumbprint.</param>
        /// <returns>The first matching certificate.</returns>
        /// <exception cref="ArgumentNullException">The thumbprint is null, empty or white space.</exception>
        /// <exception cref="ArgumentException">No matching certificate exists in either store.</exception>
        public static X509Certificate2 GetCertificateFromStore(string thumbprint)
        {
            if (string.IsNullOrWhiteSpace(thumbprint))
            {
                // The single-string constructor takes the parameter name, not a description.
                throw new ArgumentNullException(nameof(thumbprint));
            }

            X509Certificate2Collection certificates;
            if (TryFindCertificatesInStore(thumbprint, StoreLocation.CurrentUser, out certificates) ||
                TryFindCertificatesInStore(thumbprint, StoreLocation.LocalMachine, out certificates))
            {
                return certificates[0];
            }

            throw new ArgumentException(string.Format(
                "Certificate {0} was not found in the certificate store.  Please ensure the referenced " +
                "certificate exists in the the LocalMachine\\My or CurrentUser\\My store",
                thumbprint));
        }

        /// <summary>
        /// Compares two strings with handling special case that base string can be empty.
        /// </summary>
        /// <param name="leftHandSide">The base string.</param>
        /// <param name="rightHandSide">The comparer string.</param>
        /// <returns>True if equals (ordinal, ignoring case) or leftHandSide is null/empty, false otherwise.</returns>
        public static bool TryEquals(string leftHandSide, string rightHandSide)
        {
            return string.IsNullOrEmpty(leftHandSide) ||
                   leftHandSide.Equals(rightHandSide, StringComparison.OrdinalIgnoreCase);
        }

        /// <summary>
        /// Reads a file and returns its lines concatenated without line separators.
        /// </summary>
        /// <param name="configurationPath">Path of the file to read.</param>
        /// <returns>The file content with all line breaks removed.</returns>
        public static string GetConfiguration(string configurationPath)
        {
            return string.Join(string.Empty, File.ReadAllLines(configurationPath));
        }

        /// <summary>
        /// Get the value for a given key in a dictionary or return a default
        /// value if the key isn't present in the dictionary.
        /// </summary>
        /// <typeparam name="K">The type of the key.</typeparam>
        /// <typeparam name="V">The type of the value.</typeparam>
        /// <param name="dictionary">The dictionary.</param>
        /// <param name="key">The key.</param>
        /// <param name="defaultValue">A default value</param>
        /// <returns>The corresponding value or default value.</returns>
        public static V GetValueOrDefault<K, V>(this IDictionary<K, V> dictionary, K key, V defaultValue)
        {
            Debug.Assert(dictionary != null, "dictionary cannot be null!");

            V value;
            return dictionary.TryGetValue(key, out value) ? value : defaultValue;
        }

        /// <summary>
        /// Returns a non-null sequence by either passing back the original
        /// sequence or creating a new empty sequence if the original was null.
        /// </summary>
        /// <typeparam name="T">Type of elements in the sequence.</typeparam>
        /// <param name="sequence">The sequence.</param>
        /// <returns>A non-null sequence.</returns>
        public static IEnumerable<T> NonNull<T>(this IEnumerable<T> sequence)
        {
            return sequence ?? Enumerable.Empty<T>();
        }

        /// <summary>
        /// Perform an action on each element of a sequence.
        /// </summary>
        /// <typeparam name="T">Type of elements in the sequence.</typeparam>
        /// <param name="sequence">The sequence.</param>
        /// <param name="action">The action to perform.</param>
        public static void ForEach<T>(this IEnumerable<T> sequence, Action<T> action)
        {
            Debug.Assert(sequence != null, "sequence cannot be null!");
            Debug.Assert(action != null, "action cannot be null!");

            foreach (T element in sequence)
            {
                action(element);
            }
        }

        /// <summary>
        /// Append an element to the end of an array.
        /// </summary>
        /// <typeparam name="T">Type of the array elements.</typeparam>
        /// <param name="left">The array to append to; may be null.</param>
        /// <param name="right">The element to append; a null element is not appended.</param>
        /// <returns>
        /// A new array ending with <paramref name="right"/>, or <paramref name="left"/> itself
        /// (not a copy) when <paramref name="right"/> is null. Never null.
        /// </returns>
        public static T[] Append<T>(T[] left, T right)
        {
            if (left == null)
            {
                return right != null ? new T[] { right } : new T[] { };
            }

            if (right == null)
            {
                return left;
            }

            return Enumerable.Concat(left, new T[] { right }).ToArray();
        }

        /// <summary>
        /// Returns the maximum projected value of a sequence, or a default value
        /// when the sequence is null or has no elements.
        /// </summary>
        /// <typeparam name="T">Type of elements in the sequence.</typeparam>
        /// <typeparam name="TResult">Type of the projected value.</typeparam>
        /// <param name="sequence">The sequence; may be null.</param>
        /// <param name="selector">Projection applied to each element.</param>
        /// <param name="defaultValue">Value returned for a null or empty sequence.</param>
        /// <returns>The maximum projected value, or <paramref name="defaultValue"/>.</returns>
        public static TResult MaxOrDefault<T, TResult>(this IEnumerable<T> sequence, Func<T, TResult> selector, TResult defaultValue)
        {
            if (sequence == null)
            {
                return defaultValue;
            }

            // Max(selector) throws InvalidOperationException on an empty sequence of non-nullable
            // values; DefaultIfEmpty supplies the fallback while still enumerating only once.
            return sequence.Select(selector).DefaultIfEmpty(defaultValue).Max();
        }

        /// <summary>
        /// Extends the array with one element.
        /// </summary>
        /// <typeparam name="T">The array type</typeparam>
        /// <param name="collection">The array holding elements; null is treated as empty</param>
        /// <param name="item">The item to add</param>
        /// <returns>New array with added item</returns>
        public static T[] ExtendArray<T>(IEnumerable<T> collection, T item)
        {
            List<T> list = collection == null ? new List<T>() : new List<T>(collection);
            list.Add(item);
            return list.ToArray();
        }

        /// <summary>
        /// Extends the array with another array
        /// </summary>
        /// <typeparam name="T">The array type</typeparam>
        /// <param name="collection">The array holding elements; null is treated as empty</param>
        /// <param name="items">The items to add; null is treated as empty</param>
        /// <returns>New array with added items</returns>
        public static T[] ExtendArray<T>(IEnumerable<T> collection, IEnumerable<T> items)
        {
            IEnumerable<T> head = collection ?? new T[0];
            IEnumerable<T> tail = items ?? new T[0];
            return head.Concat(tail).ToArray();
        }

        /// <summary>
        /// Initializes given object if its set to null.
        /// </summary>
        /// <typeparam name="T">The object type</typeparam>
        /// <param name="obj">The object to initialize</param>
        /// <returns>The given object, or a new instance when it was null</returns>
        public static T InitializeIfNull<T>(T obj) where T : new()
        {
            return obj == null ? new T() : obj;
        }

        /// <summary>
        /// Ensures that the path component of a URL ends with a slash.
        /// </summary>
        /// <param name="url">The URL to normalize.</param>
        /// <returns>The absolute URI whose path ends with "/".</returns>
        public static string EnsureTrailingSlash(string url)
        {
            UriBuilder address = new UriBuilder(url);
            if (!address.Path.EndsWith("/", StringComparison.Ordinal))
            {
                address.Path += "/";
            }

            return address.Uri.AbsoluteUri;
        }

        /// <summary>
        /// Builds the log text for an HTTP response.
        /// </summary>
        /// <param name="statusCode">The status code text.</param>
        /// <param name="headers">The headers to print; null prints no headers.</param>
        /// <param name="body">The response body.</param>
        /// <param name="matchers">Optional patterns whose matches are redacted from the body.</param>
        /// <returns>The formatted log text.</returns>
        public static string GetHttpResponseLog(string statusCode, IDictionary<string, IEnumerable<string>> headers, string body, IList<Regex> matchers = null)
        {
            StringBuilder httpResponseLog = new StringBuilder();
            httpResponseLog.AppendLine($"============================ HTTP RESPONSE ============================{Environment.NewLine}");
            httpResponseLog.AppendLine($"Status Code:{Environment.NewLine}{statusCode}{Environment.NewLine}");
            httpResponseLog.AppendLine($"Headers:{Environment.NewLine}{MessageHeadersToString(headers)}");
            httpResponseLog.AppendLine($"Body:{Environment.NewLine}{TransformBody(body, matchers)}{Environment.NewLine}");
            return httpResponseLog.ToString();
        }

        /// <summary>
        /// Builds the log text for an HTTP response; the Authorization header is omitted.
        /// </summary>
        /// <param name="statusCode">The status code text.</param>
        /// <param name="headers">The response headers; null prints no headers.</param>
        /// <param name="body">The response body.</param>
        /// <param name="matchers">Optional patterns whose matches are redacted from the body.</param>
        /// <returns>The formatted log text.</returns>
        public static string GetHttpResponseLog(string statusCode, HttpHeaders headers, string body, IList<Regex> matchers = null)
        {
            return GetHttpResponseLog(statusCode, ConvertHttpHeadersToWebHeaderCollection(headers), body, matchers);
        }

        /// <summary>
        /// Redacts the value of any JSON "access_token" property in the body.
        /// </summary>
        /// <param name="inBody">The body text.</param>
        /// <returns>The body with access token values replaced by "&lt;redacted&gt;".</returns>
        public static string TransformBody(string inBody)
        {
            return TransformBody(inBody, new List<Regex> { AccessTokenMatcher });
        }

        /// <summary>
        /// Applies each matcher to the body, replacing every match with its first capture
        /// group followed by "&lt;redacted&gt;" in quotes.
        /// </summary>
        /// <param name="inBody">The body text; null or empty is returned unchanged.</param>
        /// <param name="matchers">The patterns to redact; null means no redaction.</param>
        /// <returns>The redacted body.</returns>
        public static string TransformBody(string inBody, IList<Regex> matchers)
        {
            // Regex.Replace throws on a null input, and there is nothing to redact in an empty one.
            if (matchers == null || string.IsNullOrEmpty(inBody))
            {
                return inBody;
            }

            foreach (Regex match in matchers)
            {
                if (match != null)
                {
                    inBody = match.Replace(inBody, "$1\"<redacted>\"");
                }
            }

            return inBody;
        }

        /// <summary>
        /// Builds the log text for an HTTP request.
        /// </summary>
        /// <param name="method">The HTTP method.</param>
        /// <param name="requestUri">The request URI.</param>
        /// <param name="headers">The headers to print; null prints no headers.</param>
        /// <param name="body">The request body.</param>
        /// <param name="matchers">Optional patterns whose matches are redacted from the body.</param>
        /// <returns>The formatted log text.</returns>
        public static string GetHttpRequestLog(
            string method,
            string requestUri,
            IDictionary<string, IEnumerable<string>> headers,
            string body,
            IList<Regex> matchers = null)
        {
            StringBuilder httpRequestLog = new StringBuilder();
            httpRequestLog.AppendLine(string.Format("============================ HTTP REQUEST ============================{0}", Environment.NewLine));
            httpRequestLog.AppendLine(string.Format("HTTP Method:{0}{1}{0}", Environment.NewLine, method));
            httpRequestLog.AppendLine(string.Format("Absolute Uri:{0}{1}{0}", Environment.NewLine, requestUri));
            httpRequestLog.AppendLine(string.Format("Headers:{0}{1}", Environment.NewLine, MessageHeadersToString(headers)));
            httpRequestLog.AppendLine(string.Format("Body:{0}{1}{0}", Environment.NewLine, TransformBody(body, matchers)));
            return httpRequestLog.ToString();
        }

        /// <summary>
        /// Builds the log text for an HTTP request; the Authorization header is omitted.
        /// </summary>
        /// <param name="method">The HTTP method.</param>
        /// <param name="requestUri">The request URI.</param>
        /// <param name="headers">The request headers; null prints no headers.</param>
        /// <param name="body">The request body.</param>
        /// <param name="matchers">Optional patterns whose matches are redacted from the body.</param>
        /// <returns>The formatted log text.</returns>
        public static string GetHttpRequestLog(string method, string requestUri, HttpHeaders headers, string body, IList<Regex> matchers = null)
        {
            return GetHttpRequestLog(method, requestUri, ConvertHttpHeadersToWebHeaderCollection(headers), body, matchers);
        }

        /// <summary>
        /// Builds the log text for a response message. The content is read synchronously.
        /// </summary>
        /// <param name="response">The response; null yields an empty string.</param>
        /// <param name="matchers">Optional patterns whose matches are redacted from the body.</param>
        /// <returns>The formatted log text.</returns>
        public static string GetLog(HttpResponseMessage response, IList<Regex> matchers = null)
        {
            if (response == null)
            {
                return string.Empty;
            }

            string body = response.Content == null
                ? string.Empty
                : FormatString(response.Content.ReadAsStringAsync().Result);

            return GetHttpResponseLog(
                response.StatusCode.ToString(),
                response.Headers,
                body,
                matchers);
        }

        /// <summary>
        /// Builds the log text for a response message without body redaction.
        /// </summary>
        /// <param name="response">The response; null yields an empty string.</param>
        /// <returns>The formatted log text.</returns>
        public static string GetLog(HttpResponseMessage response)
        {
            return GetLog(response, null);
        }

        /// <summary>
        /// Builds the log text for a request message. The content is read synchronously.
        /// </summary>
        /// <param name="request">The request; null yields an empty string.</param>
        /// <param name="matchers">Optional patterns whose matches are redacted from the body.</param>
        /// <returns>The formatted log text.</returns>
        public static string GetLog(HttpRequestMessage request, IList<Regex> matchers = null)
        {
            if (request == null)
            {
                return string.Empty;
            }

            string body = request.Content == null
                ? string.Empty
                : FormatString(request.Content.ReadAsStringAsync().Result);

            return GetHttpRequestLog(
                request.Method.ToString(),
                // RequestUri is null until the caller (or HttpClient) assigns one.
                request.RequestUri?.ToString(),
                (HttpHeaders)request.Headers,
                body,
                matchers);
        }

        /// <summary>
        /// Builds the log text for a request message without body redaction.
        /// </summary>
        /// <param name="request">The request; null yields an empty string.</param>
        /// <returns>The formatted log text.</returns>
        public static string GetLog(HttpRequestMessage request)
        {
            return GetLog(request, null);
        }

        /// <summary>
        /// Formats content for logging: XML and JSON are re-formatted, anything else is
        /// truncated to <see cref="StreamCutOffSize"/> characters.
        /// </summary>
        /// <param name="content">The content; null or empty is returned unchanged.</param>
        /// <returns>The formatted or truncated content.</returns>
        public static string FormatString(string content)
        {
            if (string.IsNullOrEmpty(content))
            {
                return content;
            }

            if (CloudException.IsXml(content))
            {
                return TryFormatXml(content);
            }

            if (CloudException.IsJson(content))
            {
                return TryFormatJson(content);
            }

            return content.Length <= StreamCutOffSize
                ? content
                : content.Substring(0, StreamCutOffSize) + "\r\nDATA TRUNCATED DUE TO SIZE\r\n";
        }

        /// <summary>
        /// Re-serializes JSON text with indentation; returns the input if it cannot be parsed.
        /// </summary>
        /// <param name="str">The JSON text.</param>
        /// <returns>The indented JSON, or the original string on failure.</returns>
        private static string TryFormatJson(string str)
        {
            try
            {
                object parsedJson = JsonConvert.DeserializeObject(str);
                return JsonConvert.SerializeObject(parsedJson, Newtonsoft.Json.Formatting.Indented);
            }
            catch (Exception)
            {
                // can't parse JSON, return the original string
                return str;
            }
        }

        /// <summary>
        /// Re-formats XML text via <see cref="XDocument"/>; returns the input if it cannot be parsed.
        /// </summary>
        /// <param name="content">The XML text.</param>
        /// <returns>The re-formatted XML, or the original string on failure.</returns>
        private static string TryFormatXml(string content)
        {
            try
            {
                return XDocument.Parse(content).ToString();
            }
            catch (Exception)
            {
                return content;
            }
        }

        /// <summary>
        /// Copies headers into a dictionary, leaving out the authorization headers.
        /// </summary>
        /// <param name="headers">The headers to copy; null yields an empty dictionary.</param>
        /// <returns>A new dictionary of header names to values.</returns>
        private static IDictionary<string, IEnumerable<string>> ConvertHttpHeadersToWebHeaderCollection(HttpHeaders headers)
        {
            IDictionary<string, IEnumerable<string>> webHeaders = new Dictionary<string, IEnumerable<string>>();
            if (headers == null)
            {
                return webHeaders;
            }

            foreach (KeyValuePair<string, IEnumerable<string>> pair in headers)
            {
                if (AuthorizationHeaderNames.Any(h => h.Equals(pair.Key, StringComparison.OrdinalIgnoreCase)))
                {
                    // Skip adding the authorization header
                    continue;
                }

                webHeaders.Add(pair.Key, pair.Value);
            }

            return webHeaders;
        }

        /// <summary>
        /// Renders one line per header: the name left-aligned in 30 columns, then its
        /// comma-separated values.
        /// </summary>
        /// <param name="headers">The headers; null yields an empty string.</param>
        /// <returns>The rendered header lines.</returns>
        private static string MessageHeadersToString(IDictionary<string, IEnumerable<string>> headers)
        {
            StringBuilder result = new StringBuilder();
            if (headers == null)
            {
                return result.ToString();
            }

            foreach (KeyValuePair<string, IEnumerable<string>> header in headers)
            {
                string[] values = (header.Value ?? Enumerable.Empty<string>()).ToArray();
                result.AppendLine(string.Format(
                    "{0,-30}: {1}",
                    header.Key,
                    ConversionUtilities.ArrayToString(values, ",")));
            }

            return result.ToString();
        }

        /// <summary>
        /// Creates https endpoint from the given endpoint.
        /// </summary>
        /// <param name="endpointUri">The endpoint uri.</param>
        /// <returns>The https endpoint uri, with any port removed.</returns>
        public static Uri CreateHttpsEndpoint(string endpointUri)
        {
            UriBuilder builder = new UriBuilder(endpointUri) { Scheme = "https" };

            // Drop the port: the builder still carries the port of the original scheme.
            string endpoint = builder.Uri.GetComponents(
                UriComponents.AbsoluteUri & ~UriComponents.Port,
                UriFormat.UriEscaped);

            return new Uri(endpoint);
        }

        /// <summary>
        /// Pad a string using the given separator string
        /// </summary>
        /// <param name="amount">The number of repetitions of the separator; zero or less yields an empty string</param>
        /// <param name="separator">The separator string to use</param>
        /// <returns>A string containing the given number of repetitions of the separator string</returns>
        public static string GenerateSeparator(int amount, string separator)
        {
            StringBuilder result = new StringBuilder();

            // A "!= 0" countdown would run through the whole int range for a negative amount.
            for (int i = 0; i < amount; i++)
            {
                result.Append(separator);
            }

            return result.ToString();
        }

        /// <summary>
        /// Ensure the default profile directory exists
        /// </summary>
        public static void EnsureDefaultProfileDirectoryExists()
        {
            if (!AzureSession.Instance.DataStore.DirectoryExists(AzureSession.Instance.ProfileDirectory))
            {
                AzureSession.Instance.DataStore.CreateDirectory(AzureSession.Instance.ProfileDirectory);
            }
        }

        /// <summary>
        /// Checks if collection has more than one element
        /// </summary>
        /// <typeparam name="T">Type of the collection.</typeparam>
        /// <param name="collection">Collection; may be null.</param>
        /// <returns>True if the collection is not null and holds at least two elements.</returns>
        public static bool HasMoreThanOneElement<T>(ICollection<T> collection)
        {
            return collection != null && collection.Count > 1;
        }

        /// <summary>
        /// Checks if collection has only one element
        /// </summary>
        /// <typeparam name="T">Type of the collection.</typeparam>
        /// <param name="collection">Collection; may be null.</param>
        /// <returns>True if the collection is not null and holds exactly one element.</returns>
        public static bool HasSingleElement<T>(ICollection<T> collection)
        {
            return collection != null && collection.Count == 1;
        }

        /// <summary>
        /// Clear the current storage account from the context - guarantees that only one storage account will be active
        /// at a time.
        /// </summary>
        /// <param name="clearSMContext">Whether to clear the service management context.</param>
        public static void ClearCurrentStorageAccount(bool clearSMContext = false)
        {
            if (AzureRmProfileProvider.Instance != null)
            {
                var RMProfile = AzureRmProfileProvider.Instance.Profile;
                if (RMProfile != null && RMProfile.DefaultContext != null &&
                    RMProfile.DefaultContext.Subscription != null &&
                    RMProfile.DefaultContext.Subscription.IsPropertySet(AzureSubscription.Property.StorageAccount))
                {
                    RMProfile.DefaultContext.Subscription.SetProperty(AzureSubscription.Property.StorageAccount, null);
                }
            }

#if !NETSTANDARD
            if (clearSMContext && AzureSMProfileProvider.Instance != null)
            {
                var SMProfile = AzureSMProfileProvider.Instance.Profile;
                if (SMProfile != null && SMProfile.DefaultContext != null &&
                    SMProfile.DefaultContext.Subscription != null &&
                    SMProfile.DefaultContext.Subscription.IsPropertySet(AzureSubscription.Property.StorageAccount))
                {
                    SMProfile.DefaultContext.Subscription.SetProperty(AzureSubscription.Property.StorageAccount, null);
                }
            }
#endif
        }

        /// <summary>
        /// Execute a process and check for a clean exit to determine if the process exists.
        /// </summary>
        /// <param name="programName">Name of the program to start.</param>
        /// <param name="args">Command line arguments provided to the program.</param>
        /// <param name="waitTime">Time in milliseconds to wait for the process to close.</param>
        /// <param name="criterion">Function to evaluate the process response to determine success. The default implementation returns true if the exit code equals 0.</param>
        /// <returns>
        /// The result of <paramref name="criterion"/> (or "exit code is 0" when none is given);
        /// false if the program could not be started or did not exit within
        /// <paramref name="waitTime"/>, in which case the process is killed.
        /// </returns>
        public static bool Probe(string programName, string args = "", int waitTime = 3000, Func<ProcessExitInfo, bool> criterion = null)
        {
            try
            {
                var stdout = new List<string>();
                var stderr = new List<string>();
                var stdoutClosed = new System.Threading.Tasks.TaskCompletionSource<bool>();
                var stderrClosed = new System.Threading.Tasks.TaskCompletionSource<bool>();

                using (var process = new Process
                {
                    StartInfo = new ProcessStartInfo
                    {
                        FileName = programName,
                        Arguments = args,
                        RedirectStandardOutput = true,
                        RedirectStandardError = true,
                        UseShellExecute = false
                    }
                })
                {
                    process.OutputDataReceived += (s, e) => CaptureProcessLine(e.Data, stdout, stdoutClosed);
                    process.ErrorDataReceived += (s, e) => CaptureProcessLine(e.Data, stderr, stderrClosed);

                    Stopwatch elapsed = Stopwatch.StartNew();
                    process.Start();
                    process.BeginErrorReadLine();
                    process.BeginOutputReadLine();

                    if (!process.WaitForExit(waitTime))
                    {
                        // Timed out: do not leave the probe running in the background.
                        try
                        {
                            process.Kill();
                        }
                        catch (InvalidOperationException)
                        {
                            // The process exited between the timeout and the kill.
                        }

                        return false;
                    }

                    // WaitForExit(int) does not wait for the asynchronous output handlers, so give
                    // them the rest of the time budget to deliver the remaining lines.
                    int drainTime = waitTime < 0
                        ? -1
                        : (int)Math.Max(0L, waitTime - elapsed.ElapsedMilliseconds);
                    System.Threading.Tasks.Task.WaitAll(
                        new System.Threading.Tasks.Task[] { stdoutClosed.Task, stderrClosed.Task },
                        drainTime);

                    var exitInfo = new ProcessExitInfo
                    {
                        ExitCode = process.ExitCode,
                        StdOut = SnapshotProcessLines(stdout),
                        StdErr = SnapshotProcessLines(stderr)
                    };

                    return criterion == null ? exitInfo.ExitCode == 0 : criterion(exitInfo);
                }
            }
            catch (InvalidOperationException)
            {
                // The executable failed to execute prior wait time expiring.
                return false;
            }
            catch (SystemException)
            {
                // The executable doesn't exist on path. Rather than handling Win32 exception, chose to handle a less platform specific sys exception.
                return false;
            }
        }

        /// <summary>
        /// Records one line of process output; a null line marks the end of the stream.
        /// </summary>
        private static void CaptureProcessLine(string line, List<string> sink, System.Threading.Tasks.TaskCompletionSource<bool> closed)
        {
            if (line == null)
            {
                closed.TrySetResult(true);
                return;
            }

            // Output events arrive on pool threads while the caller may be taking a snapshot.
            lock (sink)
            {
                sink.Add(line);
            }
        }

        /// <summary>
        /// Copies the lines captured so far under the same lock the writers use.
        /// </summary>
        private static IList<string> SnapshotProcessLines(List<string> lines)
        {
            lock (lines)
            {
                return new List<string>(lines);
            }
        }

        /// <summary>
        /// Process exit information
        /// </summary>
        public class ProcessExitInfo
        {
            /// <summary>
            /// Exit code of a process
            /// </summary>
            public int ExitCode { get; set; }

            /// <summary>
            /// List of all lines from STDOUT
            /// </summary>
            public IList<string> StdOut { get; set; }

            /// <summary>
            /// List of all lines from STDERR
            /// </summary>
            public IList<string> StdErr { get; set; }
        }

        /// <summary>
        /// Downloads the resource at the given URI as a string, on a best-effort basis.
        /// </summary>
        /// <param name="uri">The URI to download.</param>
        /// <returns>The downloaded text, or null if the URI is invalid or the download fails.</returns>
        public static string DownloadFile(string uri)
        {
            string contents = null;

            using (WebClient webClient = new WebClient())
            {
                try
                {
                    contents = webClient.DownloadString(new Uri(uri));
                }
                catch
                {
                    // Best effort: ignore the failure and return null contents
                }
            }

            return contents;
        }
    }
}