sources/Google.Solutions.Ssh/SshWorkerThread.cs (241 lines of code) (raw):
//
// Copyright 2020 Google LLC
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//
using Google.Solutions.Common.Diagnostics;
using Google.Solutions.Common.Runtime;
using Google.Solutions.Common.Threading;
using Google.Solutions.Common.Util;
using Google.Solutions.Ssh.Native;
using System;
using System.ComponentModel;
using System.Diagnostics;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
namespace Google.Solutions.Ssh
{
public abstract class SshWorkerThread : IDisposable
{
    private readonly IPEndPoint endpoint;
    private readonly ISshCredential credential;
    private readonly IKeyboardInteractiveHandler keyboardHandler;
    private readonly Thread workerThread;
    private readonly CancellationTokenSource workerCancellationSource;
    private readonly WsaEventHandle readyToSend;

    /// <summary>
    /// Set once StartConnection has started the worker thread. Used instead
    /// of Thread.IsAlive, which becomes false again once the thread exits.
    /// </summary>
    private bool workerThreadStarted;
    private bool disposed;

    private static readonly RundownProtection workerThreadRundownProtection
        = new RundownProtection();

    //---------------------------------------------------------------------
    // Properties.
    //---------------------------------------------------------------------

    /// <summary>
    /// Interval to send keep alive messages in. Longer intervals cause
    /// connection failures to be detected later.
    /// </summary>
    public TimeSpan KeepAliveInterval { get; set; } = TimeSpan.FromSeconds(5);

    /// <summary>
    /// Maximum time the worker thread waits for socket or send events
    /// before checking for cancellation and giving libssh2 a chance to
    /// send a keepalive.
    /// </summary>
    public TimeSpan SocketWaitInterval { get; set; } = TimeSpan.FromSeconds(2);

    /// <summary>
    /// Timeout for blocking operations (used during connection phase).
    /// </summary>
    public TimeSpan ConnectionTimeout { get; set; } = TimeSpan.FromSeconds(15);

    /// <summary>
    /// If true, Dispose blocks until the worker thread has terminated
    /// (unless Dispose is invoked on the worker thread itself).
    /// </summary>
    public bool JoinWorkerThreadOnDispose { get; set; } = true;

    /// <summary>
    /// Optional banner to pass to the libssh2 session. Ignored if null or empty.
    /// </summary>
    public string? Banner { get; set; }

    internal bool IsRunningOnWorkerThread
        => Environment.CurrentManagedThreadId == this.workerThread.ManagedThreadId;

    //---------------------------------------------------------------------
    // Ctor.
    //---------------------------------------------------------------------

    protected SshWorkerThread(
        IPEndPoint endpoint,
        ISshCredential credential,
        IKeyboardInteractiveHandler keyboardHandler)
    {
        this.endpoint = endpoint.ExpectNotNull(nameof(endpoint));
        this.credential = credential.ExpectNotNull(nameof(credential));
        this.keyboardHandler = keyboardHandler.ExpectNotNull(nameof(keyboardHandler));

        this.readyToSend = NativeMethods.WSACreateEvent();

        this.workerCancellationSource = new CancellationTokenSource();
        this.workerThread = new Thread(WorkerThreadProc)
        {
            Name = $"SSH worker for {credential.Username}@{this.endpoint}",
            IsBackground = true
        };
    }

    //---------------------------------------------------------------------
    // Methods for subclasses.
    //---------------------------------------------------------------------

    /// <summary>
    /// Start the worker thread, which connects, authenticates, and then
    /// processes socket I/O until the object is disposed or an error occurs.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// Thrown if the method has been called before.
    /// </exception>
    protected void StartConnection()
    {
        //
        // NB. Checking Thread.IsAlive is not sufficient: once the worker
        // thread has terminated, IsAlive is false again and Thread.Start
        // would throw a ThreadStateException instead.
        //
        if (this.workerThreadStarted)
        {
            throw new InvalidOperationException(
                "Connect must only be called once");
        }

        this.workerThread.Start();
        this.workerThreadStarted = true;
    }

    /// <summary>
    /// Handle an error that occurred while sending data, sending a
    /// keepalive, or waiting for socket events.
    ///
    /// Called on worker thread, method should not block for any
    /// significant amount of time.
    /// </summary>
    private protected abstract void OnSendError(Exception exception);

    /// <summary>
    /// Handle an error that occurred while receiving data.
    ///
    /// Called on worker thread, method should not block for any
    /// significant amount of time.
    /// </summary>
    private protected abstract void OnReceiveError(Exception exception);

    /// <summary>
    /// Called once after SSH connection has been established successfully.
    ///
    /// Called on worker thread, method should not block for any
    /// significant amount of time.
    /// </summary>
    private protected abstract void OnConnected();

    /// <summary>
    /// Called when establishing the SSH connection failed, or when an
    /// exception escaped the worker thread outside of the I/O loop
    /// (for example, during teardown).
    ///
    /// Called on worker thread, method should not block for any
    /// significant amount of time.
    /// </summary>
    private protected abstract void OnConnectionError(Exception exception);

    /// <summary>
    /// Perform any operation that sends data.
    ///
    /// Called on worker thread, method should not block for any
    /// significant amount of time.
    /// </summary>
    private protected abstract void OnReadyToSend(Libssh2AuthenticatedSession session);

    /// <summary>
    /// Perform any operation that receives data. Must consume all data
    /// that is currently available because the socket event has already
    /// been reset when this method is called.
    ///
    /// Called on worker thread, method should not block for any
    /// significant amount of time.
    /// </summary>
    private protected abstract void OnReadyToReceive(Libssh2AuthenticatedSession session);

    /// <summary>
    /// Close channels and other resources before session is closed.
    ///
    /// Called on worker thread.
    /// </summary>
    private protected abstract void OnBeforeCloseSession();

    protected bool IsConnected
        => this.workerThread.IsAlive &&
           !this.workerCancellationSource.IsCancellationRequested;

    /// <summary>
    /// Notify that data is available for sending.
    /// </summary>
    /// <param name="ready">
    /// True to signal the worker thread that data is pending, false
    /// to reset the signal.
    /// </param>
    protected void NotifyReadyToSend(bool ready)
    {
        if (ready)
        {
            NativeMethods.WSASetEvent(this.readyToSend);
        }
        else
        {
            NativeMethods.WSAResetEvent(this.readyToSend);
        }
    }

    //---------------------------------------------------------------------
    // Worker thread
    //---------------------------------------------------------------------

    /// <summary>
    /// Operation(s) in progress when a socket I/O error occurs. Values
    /// must be distinct non-zero bits so that they can be tested with
    /// a bitwise AND.
    /// </summary>
    [Flags]
    private enum Operation
    {
        Receiving = 1,
        Sending = 2
    }

    private void WorkerThreadProc()
    {
        //
        // NB. libssh2 has limited support for multi-threading and in general,
        // it's best to use a libssh2 session from a single thread only.
        // Therefore, all libssh2 operations are performed by this one thread.
        //
        using (SshTraceSource.Log.TraceMethod().WithoutParameters())
        {
            try
            {
                using (workerThreadRundownProtection.Acquire())
                using (var session = new Libssh2Session())
                {
                    session.SetTraceHandler(
                        LIBSSH2_TRACE.SOCKET |
                        LIBSSH2_TRACE.ERROR |
                        LIBSSH2_TRACE.CONN |
                        LIBSSH2_TRACE.AUTH |
                        LIBSSH2_TRACE.KEX |
                        LIBSSH2_TRACE.SFTP,
                        SshTraceSource.Log.TraceVerbose);

                    if (!string.IsNullOrEmpty(this.Banner))
                    {
                        session.Banner = this.Banner!;
                    }

                    session.Timeout = this.ConnectionTimeout;

                    //
                    // Open connection and perform handshake using blocking I/O.
                    //
                    using (var connectedSession = session.Connect(this.endpoint))
                    using (var authenticatedSession = connectedSession.Authenticate(
                        this.credential,
                        this.keyboardHandler))
                    using (Disposable.Create(() => OnBeforeCloseSession()))
                    {
                        //
                        // Make sure the readyToSend handle remains valid throughout
                        // this thread's lifetime, even if Dispose closes it while
                        // this thread is still running. Only release the reference
                        // if it was actually acquired.
                        //
                        var readyToSendHandleSafeToUse = false;
                        try
                        {
                            this.readyToSend.DangerousAddRef(ref readyToSendHandleSafeToUse);
                            Debug.Assert(readyToSendHandleSafeToUse);

                            //
                            // With the channel established, switch to non-blocking I/O.
                            // Use a disposable scope to make sure that tearing down the
                            // connection is done using blocking I/O again.
                            //
                            using (session.AsNonBlocking())
                            using (var readyToReceive = NativeMethods.WSACreateEvent())
                            {
                                //
                                // Create an event that is signalled whenever there is data
                                // available to read on the socket.
                                //
                                // NB. This is a manual-reset event that must be reset by
                                // calling WSAEnumNetworkEvents.
                                //
                                if (NativeMethods.WSAEventSelect(
                                    connectedSession.Socket.Handle,
                                    readyToReceive,
                                    NativeMethods.FD_READ) != 0)
                                {
                                    throw new Win32Exception(
                                        NativeMethods.WSAGetLastError(),
                                        "WSAEventSelect failed");
                                }

                                //
                                // Looks good so far, consider the connection successful.
                                //
                                OnConnected();

                                //
                                // Set up keepalives. Because we use non-blocking I/O, we have to
                                // send keepalives by ourselves.
                                //
                                // NB. This method must not be called before the handshake has completed.
                                //
                                connectedSession.ConfigureKeepAlive(false, this.KeepAliveInterval);

                                var waitHandles = new[]
                                {
                                    readyToReceive.DangerousGetHandle(),
                                    this.readyToSend.DangerousGetHandle()
                                };

                                while (!this.workerCancellationSource.IsCancellationRequested)
                                {
                                    //
                                    // Until we know which event fired, attribute errors to
                                    // both operations (which is reported as a send error).
                                    //
                                    var currentOperation = Operation.Receiving | Operation.Sending;
                                    try
                                    {
                                        //
                                        // In each iteration, wait for
                                        // (data received on socket) OR (user data to send)
                                        //
                                        // NB. The timeout should not be lower than approx.
                                        // one second, otherwise we spend too much time calling
                                        // libssh2's keepalive function, which causes the terminal
                                        // to become sluggish.
                                        //
                                        var waitResult = NativeMethods.WSAWaitForMultipleEvents(
                                            (uint)waitHandles.Length,
                                            waitHandles,
                                            false,
                                            (uint)this.SocketWaitInterval.TotalMilliseconds,
                                            false);

                                        if (waitResult == NativeMethods.WSA_WAIT_EVENT_0)
                                        {
                                            //
                                            // Socket has data available.
                                            //
                                            currentOperation = Operation.Receiving;

                                            //
                                            // Reset the WSA event.
                                            //
                                            // NB. The array size presumably matches FD_MAX_EVENTS (10)
                                            // of the native WSANETWORKEVENTS struct - verify against
                                            // the NativeMethods declaration.
                                            //
                                            var wsaEvents = new NativeMethods.WSANETWORKEVENTS()
                                            {
                                                iErrorCode = new int[10]
                                            };

                                            if (NativeMethods.WSAEnumNetworkEvents(
                                                connectedSession.Socket.Handle,
                                                readyToReceive,
                                                ref wsaEvents) != 0)
                                            {
                                                throw new Win32Exception(
                                                    NativeMethods.WSAGetLastError(),
                                                    "WSAEnumNetworkEvents failed");
                                            }

                                            //
                                            // Perform whatever receiving operation we need to do.
                                            //
                                            // NB. We already reset the WSA event, so we must now read
                                            // all data that's available.
                                            //
                                            OnReadyToReceive(authenticatedSession);
                                        }
                                        else if (waitResult == NativeMethods.WSA_WAIT_EVENT_0 + 1)
                                        {
                                            //
                                            // User has data to send. Perform whatever send operation
                                            // we need to do.
                                            //
                                            currentOperation = Operation.Sending;
                                            OnReadyToSend(authenticatedSession);
                                        }
                                        else if (waitResult == NativeMethods.WSA_WAIT_TIMEOUT)
                                        {
                                            //
                                            // Channel is idle - use the opportunity to send a
                                            // keepalive. Libssh2 will ignore the call if no
                                            // keepalive is due yet.
                                            //
                                            connectedSession.SendKeepAlive();
                                        }
                                        else if (waitResult == NativeMethods.WSA_WAIT_FAILED)
                                        {
                                            throw new Win32Exception(
                                                NativeMethods.WSAGetLastError(),
                                                "WSAWaitForMultipleEvents failed");
                                        }
                                    }
                                    catch (Libssh2Exception e) when (e.ErrorCode == LIBSSH2_ERROR.EAGAIN)
                                    {
                                        //
                                        // Retry operation.
                                        //
                                    }
                                    catch (Exception e)
                                    {
                                        SshTraceSource.Log.TraceError(
                                            "Socket I/O failed for {0}: {1}",
                                            Thread.CurrentThread.Name,
                                            e);

                                        if ((currentOperation & Operation.Sending) != 0)
                                        {
                                            OnSendError(e);
                                        }
                                        else
                                        {
                                            //
                                            // Consider it a receive error.
                                            //
                                            OnReceiveError(e);
                                        }

                                        //
                                        // Bail out.
                                        //
                                        return;
                                    }
                                } // while
                            } // nonblocking
                        }
                        finally
                        {
                            if (readyToSendHandleSafeToUse)
                            {
                                this.readyToSend.DangerousRelease();
                            }
                        }
                    }
                }
            }
            catch (Exception e)
            {
                SshTraceSource.Log.TraceError(
                    "Connection failed for {0}: {1}",
                    Thread.CurrentThread.Name,
                    e);

                OnConnectionError(e);
            }
        }
    }

    //---------------------------------------------------------------------
    // Dispose.
    //---------------------------------------------------------------------

    /// <summary>
    /// Stop the worker thread and, if disposing, release the event handle
    /// and cancellation source. Repeated calls after a successful
    /// Dispose(true) are no-ops.
    /// </summary>
    protected virtual void Dispose(bool disposing)
    {
        if (this.disposed)
        {
            return;
        }

        //
        // Stop worker thread.
        //
        this.workerCancellationSource.Cancel();

        //
        // NB. Joining a thread that has never been started throws a
        // ThreadStateException, and joining the current thread would
        // block forever.
        //
        if (this.JoinWorkerThreadOnDispose && this.workerThreadStarted)
        {
            Debug.Assert(
                !this.IsRunningOnWorkerThread,
                "Join on worker thread would cause deadlock");

            if (!this.IsRunningOnWorkerThread)
            {
                this.workerThread.Join();
            }
        }

        if (disposing)
        {
            //
            // If the worker thread is still running, it holds a reference
            // on readyToSend, which keeps the native handle valid until
            // the thread releases it.
            //
            this.readyToSend.Dispose();
            this.workerCancellationSource.Dispose();
            this.disposed = true;
        }
    }

    public void Dispose()
    {
        Dispose(true);
        GC.SuppressFinalize(this);
    }

    /// <summary>
    /// Wait for all worker threads to complete. Typically only needed
    /// for test cases to prevent worker threads from being aborted.
    /// </summary>
    public static Task JoinAllWorkerThreadsAsync()
    {
        return workerThreadRundownProtection.WaitAsync();
    }
}
}