LinuxCommunicator/LinuxCommunicator.cs (444 lines of code) (raw):

using System;
using System.Collections.Concurrent;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Http.Formatting;
using System.Net.Security;
using System.Security.Principal;
using System.Threading;
using System.Threading.Tasks;
using System.Xml.Linq;
using Microsoft.Hpc.Activation;
using Microsoft.Hpc.Communicators.LinuxCommunicator.HostsFile;
using Microsoft.Hpc.Communicators.LinuxCommunicator.Monitoring;
using Microsoft.Hpc.Scheduler.Communicator;
using Microsoft.Hpc.Scheduler.Properties;

namespace Microsoft.Hpc.Communicators.LinuxCommunicator
{
    /// <summary>
    /// Node communicator for Linux compute nodes. Requests are POSTed as JSON to the node's
    /// REST endpoint (<c>/api/{node}/{action}</c>); every request carries a callback URI header
    /// that points back at the <see cref="WebServer"/> hosted by this class.
    /// Only one instance may exist at a time; it is published through <see cref="Instance"/>.
    /// </summary>
    public class LinuxCommunicator : IManagedResourceCommunicator, IDisposable
    {
        private const string HttpsResourceUriFormat = "https://{0}:40002/api/{1}/{2}";
        private const string HttpResourceUriFormat = "http://{0}:40000/api/{1}/{2}";
        private const string CallbackUriHeaderName = "CallbackUri";
        private const string HpcFullKeyName = @"HKEY_LOCAL_MACHINE\SOFTWARE\Microsoft\HPC";
        private const string ClusterNameKeyName = "ClusterName";
        private const string ClusterAuthenticationKeyName = "ClusterAuthenticationKey";
        private const string LinuxHttpsKeyName = "LinuxHttps";

        /// <summary>Number of automatic re-sends after a retryable send failure.</summary>
        private const int AutoRetrySendLimit = 3;

        /// <summary>Number of attempts made to start the web server.</summary>
        private const int AutoRetryStartLimit = 3;

        /// <summary>Value of the <c>ClusterAuthenticationKey</c> registry entry; sent as a header on every request.</summary>
        public readonly string ClusterAuthenticationKey = (string)Microsoft.Win32.Registry.GetValue(HpcFullKeyName, ClusterAuthenticationKeyName, null);

        /// <summary>Value of the <c>LinuxHttps</c> registry entry; a value greater than zero selects HTTPS.</summary>
        public readonly int IsHttps = ReadRegistryInt32(LinuxHttpsKeyName, 0);

        private readonly string HeadNode = (string)Microsoft.Win32.Registry.GetValue(HpcFullKeyName, ClusterNameKeyName, null);
        private readonly int MonitoringPort = 9894;
        private readonly TimeSpan RequestTimeout = TimeSpan.FromSeconds(40);
        private readonly TimeSpan DelayBetweenRetry = TimeSpan.FromSeconds(3);

        // NOTE(review): this interval is declared but never read in this class; start retries
        // currently run back-to-back. Confirm whether a delay between attempts was intended.
        private readonly TimeSpan RetryStartInterval = TimeSpan.FromSeconds(10);

        private WebServer server;
        private CancellationTokenSource cancellationTokenSource;

        private static LinuxCommunicator instance;

        private Lazy<string> headNodeFqdn;
        private ConcurrentDictionary<string, Guid> cachedNodeGuids = new ConcurrentDictionary<string, Guid>();

        // Kept in a field so the exact delegate that was registered can be removed in Dispose.
        private RemoteCertificateValidationCallback certificateValidationCallback;

        private string ResourceUriFormat
        {
            get { return this.IsHttps > 0 ? HttpsResourceUriFormat : HttpResourceUriFormat; }
        }

        /// <summary>
        /// Creates the communicator and registers it as the process-wide instance.
        /// </summary>
        /// <exception cref="InvalidOperationException">Another instance is already registered.</exception>
        public LinuxCommunicator()
        {
            this.headNodeFqdn = new Lazy<string>(() => Dns.GetHostEntryAsync(this.HeadNode).Result.HostName, LazyThreadSafetyMode.ExecutionAndPublication);

            // Atomic claim: two concurrent constructions cannot both succeed.
            if (Interlocked.CompareExchange(ref instance, this, null) != null)
            {
                throw new InvalidOperationException("An instance of LinuxCommunicator already exists.");
            }
        }

        /// <summary>Raised by <see cref="OnRegisterRequested"/>.</summary>
        public event EventHandler<RegisterEventArgs> RegisterRequested;

        public MonitoringConfigManager MonitoringConfigManager { get; private set; }

        public HostsFileManager HostsManager { get; private set; }

        /// <summary>
        /// Releases the web server, the managers and the cancellation source, removes the
        /// certificate validation hook installed by <see cref="Initialize"/>, and frees the
        /// singleton slot so a new instance can be created afterwards.
        /// </summary>
        public void Dispose()
        {
            RemoteCertificateValidationCallback validationCallback = this.certificateValidationCallback;
            if (validationCallback != null)
            {
                ServicePointManager.ServerCertificateValidationCallback -= validationCallback;
                this.certificateValidationCallback = null;
            }

            this.server?.Dispose();
            this.MonitoringConfigManager?.Dispose();
            this.HostsManager?.Dispose();

            CancellationTokenSource cts = Interlocked.Exchange(ref this.cancellationTokenSource, null);
            if (cts != null)
            {
                cts.Cancel();
                cts.Dispose();
            }

            // Only clear the slot if it still refers to this object.
            Interlocked.CompareExchange(ref instance, null, this);
            GC.SuppressFinalize(this);
        }

        /// <summary>The currently registered instance, or null when none exists.</summary>
        public static LinuxCommunicator Instance
        {
            get { return instance; }
        }

        public string LinuxServiceHost { get; set; }

        public INodeCommTracing Tracer { get; private set; }

        public ISchedulerCallbacks SchedulerCallbacks { get; set; }

        public NodeLocation Location
        {
            get { return NodeLocation.Linux; }
        }

        public void SetTracer(INodeCommTracing tracer)
        {
            this.Tracer = tracer;
        }

        /// <summary>
        /// Reads an integer value from the HPC registry key.
        /// </summary>
        /// <param name="valueName">Name of the registry value.</param>
        /// <param name="defaultValue">Returned when the key or the value does not exist.</param>
        /// <returns>The stored value converted to <see cref="int"/>, or <paramref name="defaultValue"/>.</returns>
        private static int ReadRegistryInt32(string valueName, int defaultValue)
        {
            // Registry.GetValue returns null (not the supplied default) when the key itself is
            // missing, so a direct (int) unboxing cast would throw NullReferenceException.
            object raw = Microsoft.Win32.Registry.GetValue(HpcFullKeyName, valueName, defaultValue);
            return raw == null ? defaultValue : Convert.ToInt32(raw, CultureInfo.InvariantCulture);
        }

        /// <summary>
        /// Determines whether the given credentials belong to the system account or to a member
        /// of the built-in Administrators role.
        /// </summary>
        private static bool IsAdmin(string userName, string password)
        {
            // Both the logon token and the identity wrap native handles; release them promptly.
            // NOTE(review): assumes SafeToken is a SafeHandle (it exposes DangerousGetHandle) — confirm.
            using (SafeToken token = Credentials.GetTokenFromCredentials(userName, password))
            using (WindowsIdentity identity = new WindowsIdentity(token.DangerousGetHandle()))
            {
                if (identity.IsSystem)
                {
                    return true;
                }

                return new WindowsPrincipal(identity).IsInRole(WindowsBuiltInRole.Administrator);
            }
        }

        /// <summary>
        /// Creates the monitoring/hosts managers and the web server and installs the certificate
        /// validation hook. <see cref="SetTracer"/> must have been called first.
        /// </summary>
        /// <returns>Always true; failures are reported by exception.</returns>
        /// <exception cref="ArgumentNullException">The <c>ClusterName</c> registry value is missing.</exception>
        public bool Initialize()
        {
            this.Tracer.TraceInfo("Initializing LinuxCommunicator.");

            // Must run before headNodeFqdn.Value is touched: resolving a null host name would
            // fail first and hide this diagnostic.
            if (this.HeadNode == null)
            {
                ArgumentNullException exp = new ArgumentNullException(ClusterNameKeyName);
                this.Tracer.TraceError("Failed to find registry value: {0}. {1}", ClusterNameKeyName, exp);
                throw exp;
            }

            this.MonitoringConfigManager = new MonitoringConfigManager(this.headNodeFqdn.Value);

            // Fire-and-forget, but make sure a failure is traced instead of going unobserved.
            Task.Run(() => this.MonitoringConfigManager.Initialize()).ContinueWith(
                t => this.Tracer.TraceError("Failed to initialize MonitoringConfigManager. {0}", t.Exception),
                TaskContinuationOptions.OnlyOnFaulted);

            this.HostsManager = new HostsFileManager();

            if (this.certificateValidationCallback == null)
            {
                // NOTE(security): this hook is process-wide and accepts certificates whose only
                // problem is a host name mismatch.
                this.certificateValidationCallback = (s, cert, chain, sslPolicyErrors) =>
                {
                    this.Tracer.TraceDetail("sslPolicyErrors {0}", sslPolicyErrors);
                    sslPolicyErrors &= ~SslPolicyErrors.RemoteCertificateNameMismatch;
                    return sslPolicyErrors == SslPolicyErrors.None;
                };

                ServicePointManager.ServerCertificateValidationCallback += this.certificateValidationCallback;
            }

            this.server = new WebServer(this.IsHttps > 0);

            // Push every configuration change to all nodes whose metric GUID is cached.
            this.MonitoringConfigManager.ConfigChanged += (s, e) =>
            {
                Parallel.ForEach(this.cachedNodeGuids, kvp =>
                {
                    this.SetMetricConfig(kvp.Key, kvp.Value, e.CurrentConfig);
                });
            };

            this.Tracer.TraceInfo("Initialized LinuxCommunicator.");
            return true;
        }

        /// <summary>Starts the monitoring manager and the web server.</summary>
        /// <returns>True when the web server started; false when every attempt failed.</returns>
        public bool Start()
        {
            this.MonitoringConfigManager?.Start();
            return this.Start(0);
        }

        /// <summary>
        /// Tries to start the web server, retrying on <see cref="HttpListenerException"/> until
        /// <see cref="AutoRetryStartLimit"/> attempts have been made. A fresh cancellation source
        /// is created for each attempt.
        /// </summary>
        private bool Start(int retryCount)
        {
            for (int attempt = retryCount; ; attempt++)
            {
                this.Tracer.TraceInfo("Starting LinuxCommunicator. RetryCount {0}", attempt);

                if (attempt >= AutoRetryStartLimit)
                {
                    this.Tracer.TraceInfo("Exceeding the auto retry start limit {0}", AutoRetryStartLimit);
                    return false;
                }

                this.cancellationTokenSource?.Dispose();
                this.cancellationTokenSource = new CancellationTokenSource();

                try
                {
                    this.server?.Start().Wait();
                    return true;
                }
                catch (AggregateException aggrEx) when (aggrEx.InnerExceptions.Any(e => e is HttpListenerException))
                {
                    // Aggregates without a listener failure are not filtered and propagate unchanged.
                    this.Tracer.TraceWarning("Failed to start http listener {0}", aggrEx);
                }
                catch (HttpListenerException ex)
                {
                    this.Tracer.TraceWarning("Failed to start http listener {0}", ex);
                }
            }
        }

        /// <summary>Stops the web server and monitoring manager and cancels in-flight requests.</summary>
        /// <returns>Always true.</returns>
        public bool Stop()
        {
            this.Tracer?.TraceInfo("Stopping LinuxCommunicator.");
            this.server?.Stop();
            this.MonitoringConfigManager?.Stop();

            // Detach the source first so concurrent senders observe "stopped" (null) rather than
            // a disposed object.
            CancellationTokenSource cts = Interlocked.Exchange(ref this.cancellationTokenSource, null);
            if (cts != null)
            {
                cts.Cancel();
                cts.Dispose();
            }

            return true;
        }

        public void OnNodeStatusChange(string nodeName, bool reachable)
        {
        }

        /// <summary>
        /// Sends "endjob" to the node; on success the response body is deserialized into
        /// <c>arg.JobInfo</c> before <paramref name="callback"/> is invoked.
        /// </summary>
        public void EndJob(string nodeName, EndJobArg arg, NodeCommunicatorCallBack<EndJobArg> callback)
        {
            this.SendRequest("endjob", this.GetCallbackUri(nodeName, "taskcompleted"), nodeName, async (content, ex) =>
            {
                ex = await this.ReadResultAsync<ComputeClusterJobInformation>(content, ex, "job", info => arg.JobInfo = info);
                callback(nodeName, arg, ex);
            }, arg);
        }

        /// <summary>
        /// Sends "endtask" to the node; on success the response body is deserialized into
        /// <c>arg.TaskInfo</c> before <paramref name="callback"/> is invoked.
        /// </summary>
        public void EndTask(string nodeName, EndTaskArg arg, NodeCommunicatorCallBack<EndTaskArg> callback)
        {
            this.SendRequest("endtask", this.GetCallbackUri(nodeName, "taskcompleted"), nodeName, async (content, ex) =>
            {
                ex = await this.ReadResultAsync<ComputeClusterTaskInformation>(content, ex, "task", info => arg.TaskInfo = info);
                callback(nodeName, arg, ex);
            }, arg);
        }

        /// <summary>
        /// Sends "startjobandtask" with the credentials. When the credentials are administrative,
        /// <c>CCP_ISADMIN=1</c> is added to the start info first.
        /// </summary>
        public void StartJobAndTask(string nodeName, StartJobAndTaskArg arg, string userName, string password, ProcessStartInfo startInfo, NodeCommunicatorCallBack<StartJobAndTaskArg> callback)
        {
            if (IsAdmin(userName, password))
            {
                startInfo.EnvironmentVariables["CCP_ISADMIN"] = "1";
            }

            this.SendRequest("startjobandtask", this.GetCallbackUri(nodeName, "taskcompleted"), nodeName, async (content, ex) =>
            {
                await Task.Yield();
                callback(nodeName, arg, ex);
            }, Tuple.Create(arg, startInfo, userName, password));
        }

        /// <summary>
        /// Sends "startjobandtask" with the credentials and a certificate. Unlike the other
        /// overloads, no administrator check is performed here.
        /// </summary>
        public void StartJobAndTaskSoftCardCred(string nodeName, StartJobAndTaskArg arg, string userName, string password, byte[] certificate, ProcessStartInfo startInfo, NodeCommunicatorCallBack<StartJobAndTaskArg> callback)
        {
            this.SendRequest("startjobandtask", this.GetCallbackUri(nodeName, "taskcompleted"), nodeName, async (content, ex) =>
            {
                await Task.Yield();
                callback(nodeName, arg, ex);
            }, Tuple.Create(arg, startInfo, userName, password, certificate));
        }

        /// <summary>
        /// Sends "startjobandtask" with the credentials plus the <c>PrivateKey</c>/<c>PublicKey</c>
        /// elements extracted from <paramref name="extendedData"/> (XML). Unparsable data is
        /// traced and treated as "no keys".
        /// </summary>
        public void StartJobAndTaskExtendedData(string nodeName, StartJobAndTaskArg arg, string userName, string password, string extendedData, ProcessStartInfo startInfo, NodeCommunicatorCallBack<StartJobAndTaskArg> callback)
        {
            string privateKey = null, publicKey = null;

            if (extendedData != null)
            {
                try
                {
                    XDocument xDoc = XDocument.Parse(extendedData);
                    privateKey = xDoc.Descendants("PrivateKey").FirstOrDefault()?.Value;
                    publicKey = xDoc.Descendants("PublicKey").FirstOrDefault()?.Value;
                }
                catch (Exception ex)
                {
                    // The payload may carry a private key, so it must never be written to the trace log.
                    this.Tracer.TraceWarning("Error parsing extended data (length {0}), ex {1}", extendedData.Length, ex);
                }
            }

            if (IsAdmin(userName, password))
            {
                startInfo.EnvironmentVariables["CCP_ISADMIN"] = "1";
            }

            this.SendRequest("startjobandtask", this.GetCallbackUri(nodeName, "taskcompleted"), nodeName, async (content, ex) =>
            {
                await Task.Yield();
                callback(nodeName, arg, ex);
            }, Tuple.Create(arg, startInfo, userName, password, privateKey, publicKey));
        }

        /// <summary>Sends "starttask" to the node and reports the outcome through <paramref name="callback"/>.</summary>
        public void StartTask(string nodeName, StartTaskArg arg, ProcessStartInfo startInfo, NodeCommunicatorCallBack<StartTaskArg> callback)
        {
            this.SendRequest("starttask", this.GetCallbackUri(nodeName, "taskcompleted"), nodeName, async (content, ex) =>
            {
                await Task.Yield();
                callback(nodeName, arg, ex);
            }, Tuple.Create(arg, startInfo));
        }

        /// <summary>Sends "ping" to the node; the outcome is only traced.</summary>
        public void Ping(string nodeName)
        {
            this.SendRequest<NodeCommunicatorCallBackArg>("ping", this.GetCallbackUri(nodeName, "computenodereported"), nodeName, async (content, ex) =>
            {
                await Task.Yield();
                this.Tracer.TraceInfo("Compute node {0} pinged. Ex {1}", nodeName, ex);
            }, null);
        }

        /// <summary>
        /// Remembers the node's metric GUID (so later configuration changes reach it) and sends
        /// the current metric configuration.
        /// </summary>
        public void SetMetricGuid(string nodeName, Guid nodeGuid)
        {
            this.cachedNodeGuids[nodeName] = nodeGuid;
            this.SetMetricConfig(nodeName, nodeGuid, this.MonitoringConfigManager.MetricCountersConfig);
        }

        /// <summary>
        /// Sends a "metricconfig" and a "metric" request to the node, both carrying
        /// <paramref name="config"/> and the UDP metric callback URI for <paramref name="nodeGuid"/>.
        /// </summary>
        public void SetMetricConfig(string nodeName, Guid nodeGuid, MetricCountersConfig config)
        {
            var callbackUri = this.GetMetricCallbackUri(this.headNodeFqdn.Value, this.MonitoringPort, nodeGuid);

            this.SendRequest("metricconfig", callbackUri, nodeName, async (content, ex) =>
            {
                await Task.Yield();
                this.Tracer.TraceInfo("Compute node {0} metricconfig requested, callback {1}. Ex {2}", nodeGuid, callbackUri, ex);
            }, config);

            this.SendRequest("metric", callbackUri, nodeName, async (content, ex) =>
            {
                await Task.Yield();
                this.Tracer.TraceInfo("Compute node {0} metric requested, callback {1}. Ex {2}", nodeGuid, callbackUri, ex);
            }, config);
        }

        /// <summary>
        /// Deserializes the response body when the send succeeded.
        /// </summary>
        /// <param name="content">Response body; may be null.</param>
        /// <param name="sendException">Failure from the send phase, or null.</param>
        /// <param name="description">Word inserted into the trace message ("job" or "task").</param>
        /// <param name="assign">Receives the deserialized value.</param>
        /// <returns>
        /// <paramref name="sendException"/> when the send failed (the body is then not read),
        /// the read failure when deserialization threw, otherwise null.
        /// </returns>
        private async Task<Exception> ReadResultAsync<TResult>(HttpContent content, Exception sendException, string description, Action<TResult> assign)
        {
            if (content == null || sendException != null)
            {
                return sendException;
            }

            try
            {
                assign(await content.ReadAsAsync<TResult>());
                return null;
            }
            catch (Exception e)
            {
                this.Tracer.TraceError("Exception while read the {0} info {1}", description, e);
                return e;
            }
        }

        /// <summary>
        /// Performs one POST of <paramref name="arg"/> (JSON) to the node. Retryable failures are
        /// re-sent after <see cref="DelayBetweenRetry"/> up to <see cref="AutoRetrySendLimit"/>
        /// times; otherwise <paramref name="callback"/> is invoked exactly once with the response
        /// content and/or the failure. Exceptions thrown by the callback are traced and swallowed.
        /// </summary>
        private async Task SendRequestInternal<T>(string action, string callbackUri, string nodeName, Func<HttpContent, Exception, Task> callback, T arg, int retryCount = 0)
        {
            this.Tracer.TraceDetail("Sending out request, action {0}, callback {1}, nodeName {2}", action, callbackUri, nodeName);

            Exception ex = null;
            HttpContent content = null;
            HttpRequestMessage request = null;
            HttpResponseMessage response = null;
            bool retry = false;

            try
            {
                using (HttpClient client = new HttpClient())
                {
                    client.Timeout = this.RequestTimeout;

                    try
                    {
                        // Snapshot: Stop()/Dispose() may null the field concurrently.
                        CancellationTokenSource cts = this.cancellationTokenSource;
                        if (cts == null)
                        {
                            throw new OperationCanceledException("LinuxCommunicator is not started or has been stopped.");
                        }

                        // Built inside the try so that a malformed URI or header still reaches the callback.
                        request = new HttpRequestMessage(HttpMethod.Post, this.GetResourceUri(nodeName, action));
                        request.Headers.Add(CallbackUriHeaderName, callbackUri);
                        request.Headers.Add(MessageAuthenticationHandler.AuthenticationHeaderKey, this.ClusterAuthenticationKey);
                        request.Content = new ObjectContent<T>(arg, new JsonMediaTypeFormatter());

                        response = await client.SendAsync(request, cts.Token);
                    }
                    catch (Exception e)
                    {
                        ex = e;
                    }

                    this.Tracer.TraceDetail("Sending out request task completed, action {0}, callback {1}, nodeName {2} ex {3}", action, callbackUri, nodeName, ex);

                    if (ex == null)
                    {
                        try
                        {
                            content = response.Content;
                            if (!response.IsSuccessStatusCode)
                            {
                                using (content)
                                {
                                    if (response.StatusCode >= HttpStatusCode.InternalServerError)
                                    {
                                        // For 5xx the response body becomes the exception message.
                                        throw new InvalidProgramException(await content.ReadAsStringAsync());
                                    }
                                    else
                                    {
                                        response.EnsureSuccessStatusCode();
                                    }
                                }
                            }
                        }
                        catch (Exception e)
                        {
                            ex = e;
                        }
                    }

                    if (this.CanRetry(ex) && retryCount < AutoRetrySendLimit)
                    {
                        retry = true;
                    }
                    else
                    {
                        try
                        {
                            await callback(content, ex);
                        }
                        catch (Exception callbackEx)
                        {
                            this.Tracer.TraceError("Finished sending, callback error: action {0}, callback {1}, nodeName {2} retry count {3}, ex {4}", action, callbackUri, nodeName, retryCount, callbackEx);
                        }
                    }
                }
            }
            finally
            {
                // The callback has completed by now, so the response content is no longer needed.
                response?.Dispose();
                request?.Dispose();
            }

            if (retry)
            {
                await Task.Delay(DelayBetweenRetry);
                this.SendRequest(action, callbackUri, nodeName, callback, arg, retryCount + 1);
            }
        }

        /// <summary>
        /// Starts <see cref="SendRequestInternal{T}"/> without waiting for it and traces its completion.
        /// </summary>
        private void SendRequest<T>(string action, string callbackUri, string nodeName, Func<HttpContent, Exception, Task> callback, T arg, int retryCount = 0)
        {
            this.SendRequestInternal(action, callbackUri, nodeName, callback, arg, retryCount).ContinueWith(t =>
            {
                if (t.IsFaulted)
                {
                    // Reading t.Exception observes the fault so it is not lost silently.
                    this.Tracer.TraceError("Sending failed unexpectedly, action {0}, callback {1}, nodeName {2} retry count {3}, ex {4}", action, callbackUri, nodeName, retryCount, t.Exception);
                }

                this.Tracer.TraceDetail("Finished sending, action {0}, callback {1}, nodeName {2} retry count {3}", action, callbackUri, nodeName, retryCount);
            });
        }

        /// <summary>
        /// Decides whether a send failure should be retried: never once the communicator is
        /// stopped or cancelled; otherwise for HTTP/web/cancellation (timeout) failures, directly
        /// or anywhere inside an <see cref="AggregateException"/>.
        /// </summary>
        private bool CanRetry(Exception exception)
        {
            CancellationTokenSource cts = this.cancellationTokenSource;
            if (cts == null || cts.IsCancellationRequested)
            {
                return false;
            }

            if (exception is HttpRequestException || exception is WebException || exception is TaskCanceledException)
            {
                return true;
            }

            var aggregateEx = exception as AggregateException;
            if (aggregateEx != null)
            {
                return aggregateEx.Flatten().InnerExceptions.Any(e => this.CanRetry(e));
            }

            return false;
        }

        /// <summary>Builds the node endpoint URI; the node name is both the host and the first path segment.</summary>
        private Uri GetResourceUri(string nodeName, string action)
        {
            return new Uri(string.Format(CultureInfo.InvariantCulture, this.ResourceUriFormat, nodeName, nodeName, action));
        }

        private string GetMetricCallbackUri(string headNodeName, int port, Guid nodeGuid)
        {
            return string.Format(CultureInfo.InvariantCulture, "udp://{0}:{1}/api/{2}/metricreported", headNodeName, port, nodeGuid);
        }

        /// <summary>Builds the callback URI on this head node's web server for the given node and action.</summary>
        private string GetCallbackUri(string nodeName, string action)
        {
            string baseUri = string.Format(CultureInfo.InvariantCulture, this.server.LinuxCommunicatorUriTemplate, this.headNodeFqdn.Value);
            return string.Format(CultureInfo.InvariantCulture, "{0}/api/{1}/{2}", baseUri, nodeName, action);
        }

        /// <summary>Raises <see cref="RegisterRequested"/> with this instance as the sender.</summary>
        public void OnRegisterRequested(RegisterEventArgs registerEventArgs)
        {
            this.RegisterRequested?.Invoke(this, registerEventArgs);
        }
    }
}