sources/Google.Solutions.Ssh/Native/Libssh2Session.cs (333 lines of code) (raw):

//
// Copyright 2020 Google LLC
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
//

using Google.Solutions.Common.Diagnostics;
using Google.Solutions.Common.Runtime;
using Google.Solutions.Common.Util;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Text;

namespace Google.Solutions.Ssh.Native
{
    /// <summary>
    /// An (unconnected) Libssh2 session.
    /// </summary>
    /// <remarks>
    /// The native session is created lazily, either by <see cref="Connect"/>
    /// or by <see cref="GetSupportedAlgorithms"/>. Settings applied before
    /// that point (banner, timeout, blocking mode, trace handler, preferred
    /// methods) are cached in fields and applied when the native session
    /// is created.
    /// </remarks>
    internal class Libssh2Session : IDisposable
    {
        private const string BannerPrefix = "SSH-2.0-";

        //
        // Upper bound compared against the zero-based attempt counter
        // in Connect() to decide whether a failed key exchange is retried.
        //
        private const ushort MaxKexAttempts = 3;

        private Libssh2SessionHandle? sessionHandle;
        private bool disposed = false;
        private string? banner;
        private bool blocking;
        private TimeSpan timeout = TimeSpan.Zero;

        private NativeMethods.TraceHandler? traceHandlerDelegate;
        private LIBSSH2_TRACE traceMask = (LIBSSH2_TRACE)0;

        private readonly Dictionary<LIBSSH2_METHOD, string[]> preferredMethods =
            new Dictionary<LIBSSH2_METHOD, string[]>();

        internal static readonly NativeMethods.Alloc Alloc;
        internal static readonly NativeMethods.Free Free;
        internal static readonly NativeMethods.Realloc Realloc;

        //---------------------------------------------------------------------
        // Initialization.
        //---------------------------------------------------------------------

        /// <summary>
        /// Set up the memory callbacks and initialize the libssh2 library.
        /// </summary>
        static Libssh2Session()
        {
            //
            // Store these delegates in fields to prevent them from being
            // garbage collected. Otherwise callbacks will suddenly
            // start hitting GC'ed memory.
            //
            Alloc = (size, context) => Marshal.AllocHGlobal(size);
            Realloc = (ptr, size, context) => Marshal.ReAllocHGlobal(ptr, size);
            Free = (ptr, context) => Marshal.FreeHGlobal(ptr);

            try
            {
                var result = (LIBSSH2_ERROR)NativeMethods.libssh2_init(0);
                if (result != LIBSSH2_ERROR.NONE)
                {
                    throw new Libssh2Exception(
                        result,
                        "Failed to initialize libssh2");
                }
            }
            catch (EntryPointNotFoundException)
            {
                throw new SshException("libssh2 DLL not found or could not be loaded");
            }
        }

        /// <summary>
        /// Create a session object. The native session is not created
        /// until it is needed.
        /// </summary>
        internal Libssh2Session()
        {
            //
            // Use blocking I/O by default.
            //
            this.blocking = true;
        }

        /// <summary>
        /// Lazily initialize a Libssh2 session.
        /// </summary>
        /// <param name="force">
        /// If true, an existing native session is disposed and
        /// a new one is created in its place.
        /// </param>
        /// <exception cref="Libssh2Exception">
        /// Applying one of the preferred methods failed.
        /// </exception>
        private void InitializeSession(bool force)
        {
            using (SshTraceSource.Log.TraceMethod().WithParameters(force))
            {
                if (this.sessionHandle != null && force)
                {
                    //
                    // Close existing session to force re-initialization.
                    //
                    this.sessionHandle.CheckCurrentThreadOwnsHandle();
                    this.sessionHandle.Dispose();
                    this.sessionHandle = null;
                }

                if (this.sessionHandle == null)
                {
                    this.sessionHandle = NativeMethods.libssh2_session_init_ex(
                        Alloc,
                        Free,
                        Realloc,
                        IntPtr.Zero);

                    if (this.traceHandlerDelegate != null)
                    {
                        //
                        // NB. We must not pass an anonymous delegate to
                        // libssh2_trace_sethandler as it might be garbage-
                        // collected while still being referenced by native
                        // code.
                        //
                        NativeMethods.libssh2_trace_sethandler(
                            this.sessionHandle,
                            IntPtr.Zero,
                            this.traceHandlerDelegate);

                        NativeMethods.libssh2_trace(
                            this.sessionHandle,
                            this.traceMask);
                    }

                    //
                    // Apply the settings that were cached while there was
                    // no native session yet.
                    //
                    NativeMethods.libssh2_session_set_timeout(
                        this.sessionHandle,
                        (int)this.timeout.TotalMilliseconds);

                    NativeMethods.libssh2_session_set_blocking(
                        this.sessionHandle,
                        this.blocking ? 1 : 0);

                    if (this.Banner != null)
                    {
                        _ = NativeMethods.libssh2_session_banner_set(
                            this.sessionHandle,
                            BannerPrefix + this.Banner);
                    }

                    foreach (var preferredMethod in this.preferredMethods)
                    {
                        var prefs = string.Join(",", preferredMethod.Value);
                        var result = (LIBSSH2_ERROR)NativeMethods.libssh2_session_method_pref(
                            this.sessionHandle,
                            preferredMethod.Key,
                            prefs);

                        if (result != LIBSSH2_ERROR.NONE)
                        {
                            throw CreateException(result);
                        }
                    }
                }

                Debug.Assert(this.sessionHandle != null);
            }
        }

        /// <summary>
        /// Handle of the native session.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// The native session has not been initialized yet.
        /// </exception>
        internal Libssh2SessionHandle Handle
        {
            get => this.sessionHandle ??
                throw new InvalidOperationException("The session has not been initialized");
        }

        //---------------------------------------------------------------------
        // Publics.
        //---------------------------------------------------------------------

        /// <summary>
        /// Custom banner to use during handshake, or null to not set
        /// a custom banner. The value is prefixed with "SSH-2.0-" when
        /// it is passed to libssh2 and must not contain whitespace or dashes.
        /// </summary>
        public string? Banner
        {
            get => this.banner;
            set
            {
                //
                // Null is a legal value (it means "no custom banner"), so
                // only validate the characters of a non-null value.
                //
                Precondition.Expect(
                    value == null || value.All(c => c != '-' && !char.IsWhiteSpace(c)),
                    "Banner must not contain whitespace or dashes");
                this.banner = value;
            }
        }

        /// <summary>
        /// Set whether to use blocking I/O or Unix-style non-blocking I/O.
        /// </summary>
        public bool IsBlocking
        {
            get => this.blocking;
            private set
            {
                this.blocking = value;

                if (this.sessionHandle != null)
                {
                    //
                    // Apply to existing session.
                    //
                    this.sessionHandle.CheckCurrentThreadOwnsHandle();
                    NativeMethods.libssh2_session_set_blocking(
                        this.sessionHandle,
                        value ? 1 : 0);
                }
            }
        }

        /// <summary>
        /// Enable blocking I/O for a using block.
        /// </summary>
        /// <remarks>
        /// Disposing the returned object switches the session to
        /// non-blocking I/O, regardless of the mode that was in effect
        /// before this method was called.
        /// </remarks>
        public IDisposable AsBlocking()
        {
            this.IsBlocking = true;
            return Disposable.Create(() => this.IsBlocking = false);
        }

        /// <summary>
        /// Enable non-blocking I/O for a using block.
        /// </summary>
        /// <remarks>
        /// Disposing the returned object switches the session to
        /// blocking I/O, regardless of the mode that was in effect
        /// before this method was called.
        /// </remarks>
        public IDisposable AsNonBlocking()
        {
            this.IsBlocking = false;
            return Disposable.Create(() => this.IsBlocking = true);
        }

        /// <summary>
        /// Query the list of supported algorithms.
        /// </summary>
        /// <remarks>This forces the session to be initialized.</remarks>
        /// <param name="methodType">Kind of method to query algorithms for.</param>
        /// <returns>
        /// Names of the algorithms, or an empty array if libssh2
        /// reported none.
        /// </returns>
        /// <exception cref="Libssh2Exception">
        /// libssh2 returned a negative (error) result.
        /// </exception>
        internal string[] GetSupportedAlgorithms(LIBSSH2_METHOD methodType)
        {
            //
            // Initialize session if that hasn't happened yet.
            //
            InitializeSession(false);
            this.sessionHandle!.CheckCurrentThreadOwnsHandle();

            using (SshTraceSource.Log.TraceMethod().WithParameters(methodType))
            {
                var count = NativeMethods.libssh2_session_supported_algs(
                    this.sessionHandle,
                    methodType,
                    out var algorithmsPtrPtr);

                if (count > 0 && algorithmsPtrPtr != IntPtr.Zero)
                {
                    try
                    {
                        //
                        // The native result is an array of pointers to
                        // ANSI strings. Copy the strings into managed memory.
                        //
                        var algorithmsPtrs = new IntPtr[count];
                        Marshal.Copy(algorithmsPtrPtr, algorithmsPtrs, 0, algorithmsPtrs.Length);

                        return algorithmsPtrs
                            .Select(ptr => Marshal.PtrToStringAnsi(ptr))
                            .ToArray();
                    }
                    finally
                    {
                        //
                        // Release the native list even if marshaling failed
                        // so that the memory isn't leaked.
                        //
                        _ = NativeMethods.libssh2_free(
                            this.sessionHandle,
                            algorithmsPtrPtr);
                    }
                }
                else if (count < 0)
                {
                    throw CreateException((LIBSSH2_ERROR)count);
                }
                else
                {
                    return Array.Empty<string>();
                }
            }
        }

        /// <summary>
        /// Set preferred methods.
        /// </summary>
        /// <remarks>Must be called before Connect()</remarks>
        /// <param name="methodType">Kind of method to set preferences for.</param>
        /// <param name="methods">
        /// Methods to use. They are joined with commas when they are
        /// passed to libssh2.
        /// </param>
        internal void SetPreferredMethods(
            LIBSSH2_METHOD methodType,
            string[] methods)
        {
            Precondition.ExpectNotNullOrZeroSized(methods, nameof(methods));
            Precondition.Expect(
                this.sessionHandle == null,
                "Method must be called before the session is initialized");

            this.preferredMethods[methodType] = methods;
        }

        /// <summary>
        /// Timeout for blocking operations. The value is passed to
        /// libssh2 in (whole) milliseconds.
        /// </summary>
        public TimeSpan Timeout
        {
            get => this.timeout;
            set
            {
                this.timeout = value;

                //
                // Update existing session.
                //
                if (this.sessionHandle != null)
                {
                    this.sessionHandle.CheckCurrentThreadOwnsHandle();
                    NativeMethods.libssh2_session_set_timeout(
                        this.sessionHandle,
                        (int)value.TotalMilliseconds);
                }
            }
        }

        /// <summary>
        /// Temporarily adjust timeout for the duration of a using block.
        /// </summary>
        /// <param name="timeout">Timeout to use until the result is disposed.</param>
        /// <returns>
        /// Object that restores the previous timeout when disposed.
        /// </returns>
        public IDisposable WithTimeout(TimeSpan timeout)
        {
            var originalTimeout = this.Timeout;
            this.Timeout = timeout;

            return Disposable.Create(() => this.Timeout = originalTimeout);
        }

        /// <summary>
        /// Time to wait for user to react to keyboard/interactive prompts.
        /// </summary>
        public TimeSpan KeyboardInteractivePromptTimeout { get; set; }
            = TimeSpan.FromMinutes(1);

        /// <summary>
        /// Connect to remote server and initiate handshake.
        /// </summary>
        /// <remarks>Forces the native session to be initialized.</remarks>
        /// <param name="remoteEndpoint">Endpoint of the server to connect to.</param>
        /// <returns>Session that is connected to the server.</returns>
        /// <exception cref="Libssh2Exception">
        /// The handshake failed, and the failure was either not a key
        /// exchange failure or the retries were exhausted.
        /// </exception>
        public Libssh2ConnectedSession Connect(EndPoint remoteEndpoint)
        {
            using (SshTraceSource.Log.TraceMethod().WithParameters(remoteEndpoint))
            {
                for (var kexAttempt = 0; ; kexAttempt++)
                {
                    var socket = new Socket(
                        AddressFamily.InterNetwork,
                        SocketType.Stream,
                        ProtocolType.Tcp)
                    {
                        //
                        // Flush input data immediately so that the user does not
                        // experience a lag.
                        //
                        NoDelay = true
                    };

                    try
                    {
                        //
                        // Initialize session if that hasn't happened yet. On
                        // a retry, force a fresh native session.
                        //
                        InitializeSession(kexAttempt > 0);
                        this.sessionHandle!.CheckCurrentThreadOwnsHandle();

                        SshEventSource.Log.ConnectionHandshakeInitiated(remoteEndpoint.ToString());

                        socket.Connect(remoteEndpoint);

                        var result = (LIBSSH2_ERROR)NativeMethods.libssh2_session_handshake(
                            this.sessionHandle,
                            socket.Handle);

                        if (result == LIBSSH2_ERROR.KEY_EXCHANGE_FAILURE &&
                            kexAttempt < MaxKexAttempts)
                        {
                            //
                            // When using the WinCNG backend, key exchanges can occasionally fail,
                            // see https://github.com/libssh2/libssh2/issues/804.
                            //
                            // Retry a few times.
                            //
                            SshTraceSource.Log.TraceWarning(
                                "KEX failed (attempt {0}/{1}), retrying...",
                                kexAttempt,
                                MaxKexAttempts);
                        }
                        else if (result != LIBSSH2_ERROR.NONE)
                        {
                            //
                            // Some other error occurred, don't retry.
                            //
                            throw CreateException(result);
                        }
                        else
                        {
                            //
                            // Success. The connected session takes over
                            // the socket.
                            //
                            SshEventSource.Log.ConnectionHandshakeCompleted(remoteEndpoint.ToString());
                            return new Libssh2ConnectedSession(this, socket);
                        }
                    }
                    catch
                    {
                        //
                        // Don't leak the socket if connecting, the handshake,
                        // or anything else in this attempt failed.
                        //
                        socket.Close();
                        throw;
                    }

                    //
                    // Key exchange failed. Discard the socket, the next
                    // attempt uses a new one.
                    //
                    socket.Close();
                }
            }
        }

        //---------------------------------------------------------------------
        // Error.
        //---------------------------------------------------------------------

        /// <summary>
        /// Query last error encountered. Returns NONE if the native
        /// session has not been initialized.
        /// </summary>
        internal LIBSSH2_ERROR LastError
        {
            get
            {
                if (this.sessionHandle == null)
                {
                    return LIBSSH2_ERROR.NONE;
                }
                else
                {
                    this.sessionHandle.CheckCurrentThreadOwnsHandle();
                    return (LIBSSH2_ERROR)
                        NativeMethods.libssh2_session_last_errno(this.sessionHandle);
                }
            }
        }

        /// <summary>
        /// Create an exception for an error code.
        /// </summary>
        /// <remarks>
        /// If the session's last error matches the error code, the
        /// exception carries the error message reported by libssh2.
        /// Otherwise, a generic message is used.
        /// </remarks>
        /// <param name="error">Error code to create an exception for.</param>
        /// <returns>Exception, to be thrown by the caller.</returns>
        internal Libssh2Exception CreateException(LIBSSH2_ERROR error)
        {
            if (this.sessionHandle != null)
            {
                var lastError = (LIBSSH2_ERROR)NativeMethods.libssh2_session_last_error(
                    this.sessionHandle,
                    out var errorMessage,
                    out var errorMessageLength,
                    0);

                SshEventSource.Log.ConnectionErrorEncountered((int)error);

                if (lastError == error)
                {
                    return new Libssh2Exception(
                        error,
                        Marshal.PtrToStringAnsi(errorMessage, errorMessageLength));
                }
            }

            //
            // Fall back to using a generic error message.
            //
            return new Libssh2Exception(
                error,
                $"SSH operation failed: {error}");
        }

        /// <summary>
        /// Register a handler that receives trace output as text.
        /// </summary>
        /// <remarks>
        /// The handler is installed when the native session is
        /// initialized, it is not applied to a native session that
        /// already exists.
        /// </remarks>
        /// <param name="mask">Kinds of trace messages to enable.</param>
        /// <param name="handler">Callback that receives each message.</param>
        internal void SetTraceHandler(
            LIBSSH2_TRACE mask,
            Action<string> handler)
        {
            //
            // Keep the delegate in a field so that it stays alive for as
            // long as native code might call it.
            //
            this.traceHandlerDelegate = (sessionPtr, contextPtr, dataPtr, length) =>
            {
                Debug.Assert(contextPtr == IntPtr.Zero);

                var size = length.ToInt32();
                var data = new byte[size];
                Marshal.Copy(dataPtr, data, 0, size);

                handler(Encoding.ASCII.GetString(data));
            };
            this.traceMask = mask;
        }

        //---------------------------------------------------------------------
        // IDisposable.
        //---------------------------------------------------------------------

        /// <summary>
        /// Dispose the native session, if there is one.
        /// </summary>
        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Detach the trace handler from the native session and dispose
        /// the session handle. Subsequent calls have no effect.
        /// </summary>
        protected virtual void Dispose(bool disposing)
        {
            if (this.disposed)
            {
                return;
            }

            if (disposing)
            {
                if (this.sessionHandle != null)
                {
                    this.sessionHandle.CheckCurrentThreadOwnsHandle();

                    //
                    // Remove the trace handler first so that native code
                    // doesn't call back into the delegate anymore.
                    //
                    NativeMethods.libssh2_trace_sethandler(
                        this.sessionHandle,
                        IntPtr.Zero,
                        null);

                    this.sessionHandle.Dispose();
                }

                this.disposed = true;
            }
        }
    }
}