sources/Google.Solutions.Ssh/SshConnection.cs (240 lines of code) (raw):
//
// Copyright 2022 Google LLC
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//
using Google.Solutions.Platform.IO;
using Google.Solutions.Ssh.Native;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.Net;
using System.Threading.Tasks;
namespace Google.Solutions.Ssh
{
public class SshConnection : SshWorkerThread
{
/// <summary>
/// Operations that are waiting to be run, in first-in-first-out
/// order.
/// </summary>
/// <remarks>
/// Every access in this class happens while holding a lock on
/// the queue object itself.
/// </remarks>
private readonly Queue<ISendOperation> sendQueue = new Queue<ISendOperation>();
/// <summary>
/// Task that is completed after the SSH connection has
/// been established.
/// </summary>
/// <remarks>
/// Force continuations to run asynchronously so that they
/// don't block the worker thread.
/// </remarks>
private readonly TaskCompletionSource<int> connectionCompleted
= new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
//
// List of open channels. Only accessed on worker thread,
// so no locking required.
//
// Channels are appended when they are opened. Nothing in this
// file removes entries again.
//
private readonly LinkedList<SshChannelBase> channels
= new LinkedList<SshChannelBase>();
/// <summary>
/// Create a connection.
/// </summary>
/// <remarks>
/// Keyboard handler callbacks are delivered on the designated
/// synchronization context.
/// </remarks>
/// <param name="endpoint">
/// Endpoint, forwarded unchanged to the base class.
/// </param>
/// <param name="credential">
/// Credential, forwarded unchanged to the base class.
/// </param>
/// <param name="keyboardHandler">
/// Keyboard-interactive handler, forwarded unchanged to the
/// base class.
/// </param>
public SshConnection(
IPEndPoint endpoint,
ISshCredential credential,
IKeyboardInteractiveHandler keyboardHandler)
: base(endpoint, credential, keyboardHandler)
{
}
//---------------------------------------------------------------------
// Overrides.
//---------------------------------------------------------------------
/// <summary>
/// Mark the connection task as successfully completed.
/// </summary>
/// <remarks>
/// The completion source is created with
/// <see cref="TaskCreationOptions.RunContinuationsAsynchronously"/>,
/// so continuations of the connection task are not run inline
/// by this call.
/// </remarks>
private protected override void OnConnected()
    => this.connectionCompleted.SetResult(0);
/// <summary>
/// Fault the connection task with the given exception, unless
/// the task has already been completed.
/// </summary>
/// <param name="exception">Error that caused the connection to fail.</param>
private protected override void OnConnectionError(Exception exception)
{
    //
    // Use the Try-variant so that the "already completed" check
    // and the state transition happen atomically. Checking
    // IsCompleted first and calling SetException() afterwards
    // would throw an InvalidOperationException if the task
    // completed in between the two steps.
    //
    // The completion source runs continuations asynchronously,
    // so awaiters are not resumed inline by this call.
    //
    this.connectionCompleted.TrySetException(exception);
}
/// <summary>
/// Run the operation at the head of the send queue and, if it
/// returns without throwing, dequeue and complete it.
/// </summary>
/// <remarks>
/// Expected to be called on the worker thread while at least
/// one operation is queued.
/// </remarks>
/// <param name="session">Session to run the operation against.</param>
private protected override void OnReadyToSend(Libssh2AuthenticatedSession session)
{
    lock (this.sendQueue)
    {
        Debug.Assert(this.IsRunningOnWorkerThread);
        Debug.Assert(this.sendQueue.Count > 0);

        //
        // Only peek at the head for now: if running the operation
        // throws, it must remain queued. The exception is
        // deliberately allowed to escape because it might merely
        // indicate an EAGAIN situation. For a genuinely bad
        // exception, an OnSendError callback follows later.
        //
        var head = this.sendQueue.Peek();
        head.Run(session);

        //
        // The operation ran through, so take it off the queue and
        // complete it. SendOperation completes its task with
        // asynchronous continuations, so awaiters do not resume
        // inline on the current thread.
        //
        this.sendQueue.Dequeue();
        head.OnCompleted();

        var drained = this.sendQueue.Count == 0;
        if (drained)
        {
            //
            // Nothing left to send right now, so stop asking
            // for further callbacks.
            //
            NotifyReadyToSend(false);
        }
    }
}
/// <summary>
/// Remove the operation at the head of the send queue and
/// mark it as failed.
/// </summary>
/// <param name="exception">Error to report to the operation.</param>
private protected override void OnSendError(Exception exception)
{
    lock (this.sendQueue)
    {
        Debug.Assert(this.sendQueue.Count > 0);

        var failedOperation = this.sendQueue.Dequeue();
        failedOperation.OnFailed(exception);
    }
}
/// <summary>
/// Give every open channel the opportunity to receive data.
/// </summary>
/// <param name="session">Not used by this implementation.</param>
private protected override void OnReadyToReceive(Libssh2AuthenticatedSession session)
{
    //
    // Channels are notified in the order in which they
    // were opened.
    //
    foreach (var openChannel in this.channels)
    {
        openChannel.OnReceive();
    }
}
/// <summary>
/// Report a receive error to every open channel.
/// </summary>
/// <param name="exception">Error to pass on to the channels.</param>
private protected override void OnReceiveError(Exception exception)
{
    //
    // Channels are notified in the order in which they
    // were opened.
    //
    foreach (var openChannel in this.channels)
    {
        openChannel.OnReceiveError(exception);
    }
}
/// <summary>
/// Cancel all queued operations and dispose all channels
/// before the session is closed.
/// </summary>
/// <remarks>
/// Expected to be called on the worker thread.
/// </remarks>
private protected override void OnBeforeCloseSession()
{
    Debug.Assert(this.IsRunningOnWorkerThread);

    lock (this.sendQueue)
    {
        //
        // Cancel whatever is still queued to release waiters.
        // The operations themselves are left in the queue.
        //
        foreach (var pending in this.sendQueue)
        {
            pending.OnCancelled();
        }
    }

    //
    // Dispose every channel that was opened on this connection.
    //
    foreach (var openChannel in this.channels)
    {
        openChannel.Dispose();
    }
}
//---------------------------------------------------------------------
// Helper methods for channels.
//---------------------------------------------------------------------
/// <summary>
/// Queue a callback so that it is run on the worker thread
/// once the connection is ready to send data.
/// </summary>
/// <param name="callback">Callback to run against the session.</param>
/// <param name="terminateConnectionOnError">
/// If false, an exception thrown by the callback is captured and
/// reported through the returned task. If true, the exception
/// escapes from the operation instead.
/// </param>
/// <returns>
/// Task that yields the callback's result once the callback
/// has been run.
/// </returns>
/// <exception cref="SshConnectionClosedException">
/// Thrown synchronously when the connection is not connected.
/// </exception>
internal Task<TResult> RunAsync<TResult>(
    Func<Libssh2AuthenticatedSession, TResult> callback,
    bool terminateConnectionOnError = true)
{
    if (!this.IsConnected)
    {
        //
        // NB. Callers are meant to treat a closed connection as
        // a kind of cancellation, which requires
        // SshConnectionClosedException to derive from
        // OperationCanceledException.
        //
        var closedException = new SshConnectionClosedException();
        Debug.Assert(closedException is OperationCanceledException);
        throw closedException;
    }

    var operation = new SendOperation<TResult>(
        callback,
        terminateConnectionOnError);

    lock (this.sendQueue)
    {
        this.sendQueue.Enqueue(operation);

        //
        // Notify that we have data and expect a Send()
        // callback.
        //
        NotifyReadyToSend(true);
    }

    //
    // The task is completed once the operation has actually
    // been run.
    //
    return operation.Task;
}
/// <summary>
/// Queue a callback without a result so that it is run on the
/// worker thread once the connection is ready to send data.
/// </summary>
/// <param name="callback">Callback to run against the session.</param>
/// <param name="terminateConnectionOnError">
/// Forwarded unchanged to the generic overload.
/// </param>
/// <returns>Task that completes once the callback has been run.</returns>
internal Task RunAsync(
    Action<Libssh2AuthenticatedSession> callback,
    bool terminateConnectionOnError = true)
{
    //
    // Adapt the callback to the generic overload by giving it
    // a dummy null result.
    //
    object? InvokeCallback(Libssh2AuthenticatedSession session)
    {
        callback(session);
        return null;
    }

    return RunAsync<object?>(InvokeCallback, terminateConnectionOnError);
}
//---------------------------------------------------------------------
// Publics.
//---------------------------------------------------------------------
/// <summary>
/// Start connecting and return a task to await the outcome.
/// </summary>
/// <returns>
/// Task that is completed by OnConnected and faulted by
/// OnConnectionError.
/// </returns>
public Task ConnectAsync()
{
//
// StartConnection is not defined in this file; presumably it
// is inherited from SshWorkerThread - TODO confirm what happens
// if this method is called more than once.
//
StartConnection();
return this.connectionCompleted.Task;
}
/// <summary>
/// Open a shell channel.
/// </summary>
/// <param name="initialSize">
/// Initial width and height to request for the terminal.
/// </param>
/// <param name="terminalType">Terminal type to request.</param>
/// <param name="locale">
/// Culture to pass to the server as <c>LC_ALL</c>, or null to
/// not pass a locale. Cultures that have an empty name, such
/// as <see cref="CultureInfo.InvariantCulture"/>, are treated
/// like null because they cannot be expressed as a locale.
/// </param>
/// <returns>
/// Task that yields the channel once it has been opened on the
/// worker thread. If opening the channel fails, the task is
/// faulted, but the connection is kept alive.
/// </returns>
public Task<SshShellChannel> OpenShellAsync(
    PseudoTerminalSize initialSize,
    string terminalType,
    CultureInfo? locale)
{
    IEnumerable<EnvironmentVariable>? environmentVariables = null;

    //
    // NB. The invariant culture has an empty name. Formatting
    // it would produce the malformed locale ".UTF-8", so skip
    // the locale entirely in that case.
    //
    if (locale != null && !string.IsNullOrEmpty(locale.Name))
    {
        //
        // Format language so that Linux understands it.
        //
        var languageFormatted = locale.Name.Replace('-', '_');

        environmentVariables = new[]
        {
            //
            // Try to pass locale - but do not fail the connection if
            // the server rejects it.
            //
            new EnvironmentVariable(
                "LC_ALL",
                $"{languageFormatted}.UTF-8",
                false)
        };
    }

    return RunAsync(
        session =>
        {
            Debug.Assert(this.IsRunningOnWorkerThread);

            using (session.Session.AsBlocking())
            {
                var nativeChannel = session.OpenShellChannel(
                    LIBSSH2_CHANNEL_EXTENDED_DATA.MERGE,
                    terminalType,
                    initialSize.Width,
                    initialSize.Height,
                    environmentVariables);

                var channel = new SshShellChannel(
                    this,
                    nativeChannel);

                //
                // Track the channel so that it receives data and
                // is disposed when the session is closed.
                //
                this.channels.AddLast(channel);
                return channel;
            }
        },
        false); // Report errors through the task only.
}
/// <summary>
/// Open an SFTP channel.
/// </summary>
/// <returns>
/// Task that yields the channel once it has been opened on the
/// worker thread.
/// </returns>
public Task<SftpChannel> OpenFileSystemAsync()
{
    SftpChannel OpenChannel(Libssh2AuthenticatedSession session)
    {
        Debug.Assert(this.IsRunningOnWorkerThread);

        using (session.Session.AsBlocking())
        {
            var nativeChannel = session.OpenSftpChannel();
            var sftpChannel = new SftpChannel(this, nativeChannel);

            //
            // Track the channel so that it receives data and
            // is disposed when the session is closed.
            //
            this.channels.AddLast(sftpChannel);
            return sftpChannel;
        }
    }

    return RunAsync<SftpChannel>(OpenChannel);
}
//---------------------------------------------------------------------
// Inner classes.
//---------------------------------------------------------------------
/// <summary>
/// Unit of work that is kept in the send queue until it has
/// been run, has failed, or has been cancelled.
/// </summary>
private interface ISendOperation
{
/// <summary>
/// Run the operation.
/// </summary>
/// <remarks>
/// Any exception thrown by this method causes
/// the connection to be terminated.
///
/// NOTE(review): OnReadyToSend lets exceptions escape and
/// describes them as possibly being an EAGAIN situation, so
/// whether an exception actually terminates the connection is
/// decided outside this file - confirm against SshWorkerThread.
/// </remarks>
void Run(Libssh2AuthenticatedSession session);
/// <summary>
/// Mark the operation as successful.
/// </summary>
/// <remarks>
/// Called by OnReadyToSend after Run returned without throwing.
/// </remarks>
void OnCompleted();
/// <summary>
/// Mark the operation as failed.
/// </summary>
/// <remarks>
/// Called by OnSendError for the operation at the head of
/// the queue.
/// </remarks>
void OnFailed(Exception e);
/// <summary>
/// Cancel operation.
/// </summary>
/// <remarks>
/// Called by OnBeforeCloseSession for every operation that is
/// still queued.
/// </remarks>
void OnCancelled();
}
/// <summary>
/// Operation that runs a callback against the session and
/// reports the outcome through a task.
/// </summary>
protected internal class SendOperation<TResult> : ISendOperation
{
    private readonly Func<Libssh2AuthenticatedSession, TResult> callback;
    private readonly bool terminateConnectionOnError;

    //
    // Force continuations to run asynchronously so that they
    // don't block the worker thread.
    //
    private readonly TaskCompletionSource<TResult> completionSource
        = new TaskCompletionSource<TResult>(
            TaskCreationOptions.RunContinuationsAsynchronously);

    //
    // Outcome of the last run, kept until OnCompleted()
    // publishes it.
    //
    private TResult capturedResult = default!;
    private Exception? capturedException;

    internal SendOperation(
        Func<Libssh2AuthenticatedSession, TResult> operation,
        bool terminateConnectionOnError)
    {
        this.callback = operation;
        this.terminateConnectionOnError = terminateConnectionOnError;
    }

    /// <summary>
    /// Task to await completion.
    /// </summary>
    public Task<TResult> Task => this.completionSource.Task;

    /// <summary>
    /// Invoke the callback and remember its result.
    /// </summary>
    /// <remarks>
    /// If the operation was created with terminateConnectionOnError
    /// set to false, exceptions are captured instead of being
    /// propagated. Otherwise, they escape to the caller.
    /// </remarks>
    void ISendOperation.Run(Libssh2AuthenticatedSession session)
    {
        try
        {
            this.capturedResult = this.callback(session);
        }
        catch (Exception e) when (!this.terminateConnectionOnError)
        {
            //
            // Keep the connection alive by holding on to the
            // exception. It is surfaced to awaiters when the
            // operation is completed.
            //
            this.capturedException = e;
        }
    }

    /// <summary>
    /// Publish the captured result, or the captured exception
    /// if there is one, to awaiters.
    /// </summary>
    void ISendOperation.OnCompleted()
    {
        if (this.capturedException is null)
        {
            this.completionSource.TrySetResult(this.capturedResult);
        }
        else
        {
            this.completionSource.TrySetException(this.capturedException);
        }
    }

    /// <summary>
    /// Fault the task with the given exception.
    /// </summary>
    void ISendOperation.OnFailed(Exception e)
        => this.completionSource.TrySetException(e);

    /// <summary>
    /// Cancel the task.
    /// </summary>
    void ISendOperation.OnCancelled()
        => this.completionSource.TrySetCanceled();
}
}
}