sources/Google.Solutions.Iap/Protocol/SshRelaySession.cs (330 lines of code) (raw):

//
// Copyright 2022 Google LLC
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//

using Google.Solutions.Common.Diagnostics;
using Google.Solutions.Common.Threading;
using Google.Solutions.Iap.Net;
using System;
using System.Diagnostics;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;

namespace Google.Solutions.Iap.Protocol
{
    /// <summary>
    /// Target for an SSH Relay Connection.
    /// </summary>
    public interface ISshRelayTarget
    {
        bool IsMutualTlsEnabled { get; }

        /// <summary>
        /// Open a connection for a new relay session. SshRelaySession
        /// expects the server to answer with a CONNECT_SUCCESS_SID message.
        /// </summary>
        Task<INetworkStream> ConnectAsync(CancellationToken token);

        /// <summary>
        /// Open a connection that resumes an existing relay session.
        /// SshRelaySession expects the server to answer with a
        /// RECONNECT_SUCCESS_ACK message.
        /// </summary>
        /// <param name="sid">
        /// Session ID obtained when the session was first established.
        /// </param>
        /// <param name="lastByteConsumedByClient">
        /// SshRelaySession passes the total number of bytes it has
        /// received in this session so far.
        /// </param>
        Task<INetworkStream> ReconnectAsync(
            string sid,
            ulong lastByteConsumedByClient,
            CancellationToken token);
    }

    /// <summary>
    /// SSH Relay session. A session uses one WebSocket connection
    /// at a time. If that connection breaks, the session attempts
    /// to reconnect.
    /// </summary>
    internal sealed class SshRelaySession
    {
        /// <summary>
        /// Number of times IoAsync disconnects and retries after an
        /// unrecognized server-side close before giving up.
        /// </summary>
        private const uint MaxReconnects = 2;

        public ISshRelayTarget Endpoint { get; }

        //
        // Current connection, guarded by connectLock.
        //
        private INetworkStream? connection = null;
        private readonly AsyncLock connectLock = new AsyncLock();

        internal SessionState State { get; } = new SessionState();

        /// <summary>
        /// Unique identifier of session. Available after initial
        /// connection has been established.
        /// </summary>
        public string? Sid { get; private set; }

        private void TraceVerbose(string message)
        {
            if (IapTraceSource.Log.Switch.ShouldTrace(TraceEventType.Verbose))
            {
                IapTraceSource.Log.TraceVerbose($"{this}: {message}");
            }
        }

        /// <summary>
        /// Dispose a connection that was opened but never published
        /// because its handshake failed. Errors are traced instead of
        /// thrown so that they don't mask the exception that caused
        /// the handshake to fail.
        /// </summary>
        private static void DisposeUnpublishedConnection(INetworkStream connection)
        {
            try
            {
                connection.Dispose();
            }
            catch (Exception e)
            {
                IapTraceSource.Log.TraceError(e);
            }
        }

        /// <summary>
        /// Return the current connection, or establish one if there is none.
        ///
        /// If no ACK has been received yet, a new session is started and
        /// the method waits for a CONNECT_SUCCESS_SID message. Otherwise,
        /// the existing session is resumed and the method waits for a
        /// RECONNECT_SUCCESS_ACK message. In both cases,
        /// resendUnacknoledgedDataAction is invoked on the new connection
        /// before the connection is stored in this.connection.
        ///
        /// If the handshake or the resend fails, the new connection is
        /// disposed and the exception is propagated.
        /// </summary>
        private async Task<INetworkStream> GetConnectionAsync(
            Func<INetworkStream, CancellationToken, Task> resendUnacknoledgedDataAction,
            CancellationToken cancellationToken)
        {
            //
            // This method might be called concurrently by a
            // writer and a reader, so we have to synchronize.
            //
            using (await this.connectLock
                .AcquireAsync(cancellationToken)
                .ConfigureAwait(false))
            {
                if (this.connection != null)
                {
                    //
                    // We're still connected.
                    //
                    return this.connection;
                }
                else if (this.State.LastAckReceived == 0)
                {
                    //
                    // Initial connect.
                    //
                    TraceVerbose($"Establishing new connection");

                    var connection = await this.Endpoint
                        .ConnectAsync(cancellationToken)
                        .ConfigureAwait(false);

                    string? connectionSid = null;
                    try
                    {
                        //
                        // To complete connection establishment, we have to receive
                        // a CONNECT_SUCCESS_SID message.
                        //
                        var message = new byte[SshRelayFormat.MaxMessageSize];
                        var connected = false;
                        while (!connected)
                        {
                            var bytesRead = await connection
                                .ReadAsync(
                                    message,
                                    0,
                                    message.Length,
                                    cancellationToken)
                                .ConfigureAwait(false);

                            //
                            // Validate the read before decoding anything from
                            // the buffer.
                            //
                            if (bytesRead == 0)
                            {
                                throw new WebSocketStreamClosedByServerException(
                                    WebSocketCloseStatus.NormalClosure,
                                    "The connection was closed by the server");
                            }
                            else if (bytesRead < SshRelayFormat.Tag.Length)
                            {
                                throw new SshRelayProtocolViolationException(
                                    "The server sent an incomplete message");
                            }

                            SshRelayFormat.Tag.Decode(message, out var tag);

                            switch (tag)
                            {
                                case SshRelayMessageTag.CONNECT_SUCCESS_SID:
                                    {
                                        var bytesDecoded = SshRelayFormat.ConnectSuccessSid.Decode(
                                            message,
                                            out var sid);
                                        connectionSid = sid;
                                        Debug.Assert(bytesDecoded == bytesRead);
                                        connected = true;
                                        break;
                                    }

                                case SshRelayMessageTag.LONG_CLOSE:
                                default:
                                    //
                                    // Unknown tag, ignore.
                                    //
                                    TraceVerbose($"Received unknown message: {tag}");
                                    break;
                            }
                        }

                        //
                        // If the previous connection broke before we received
                        // the first ACK, then there might be data to be resent.
                        //
                        await resendUnacknoledgedDataAction(
                                connection,
                                cancellationToken)
                            .ConfigureAwait(false);
                    }
                    catch
                    {
                        //
                        // The connection was never published, so nobody
                        // else will dispose it.
                        //
                        DisposeUnpublishedConnection(connection);
                        throw;
                    }

                    this.Sid = connectionSid;
                    this.connection = connection;

                    TraceVerbose($"Received CONNECT_SUCCESS_SID, connected");
                    return connection;
                }
                else
                {
                    //
                    // Reconnect + sync ack's + resend data.
                    //
                    Debug.Assert(this.Sid != null);

                    TraceVerbose($"Attempting reconnect with ack={this.State.LastAckReceived}");

                    var connection = await this.Endpoint
                        .ReconnectAsync(
                            this.Sid!,
                            this.State.BytesReceived,
                            cancellationToken)
                        .ConfigureAwait(false);

                    try
                    {
                        //
                        // To complete connection establishment, we have to receive
                        // a RECONNECT_SUCCESS_ACK message.
                        //
                        var message = new byte[SshRelayFormat.MaxMessageSize];
                        var reconnected = false;
                        while (!reconnected)
                        {
                            var bytesRead = await connection
                                .ReadAsync(
                                    message,
                                    0,
                                    message.Length,
                                    cancellationToken)
                                .ConfigureAwait(false);

                            //
                            // Validate the read before decoding anything from
                            // the buffer.
                            //
                            if (bytesRead == 0)
                            {
                                throw new WebSocketStreamClosedByServerException(
                                    WebSocketCloseStatus.NormalClosure,
                                    "The connection was closed by the server");
                            }
                            else if (bytesRead < SshRelayFormat.Tag.Length)
                            {
                                throw new SshRelayProtocolViolationException(
                                    "The server sent an incomplete message");
                            }

                            SshRelayFormat.Tag.Decode(message, out var tag);

                            switch (tag)
                            {
                                case SshRelayMessageTag.RECONNECT_SUCCESS_ACK:
                                    {
                                        var bytesDecoded = SshRelayFormat.ReconnectAck.Decode(
                                            message,
                                            out var ack);
                                        this.State.LastAckReceived = ack;
                                        Debug.Assert(bytesDecoded == bytesRead);
                                        reconnected = true;
                                        break;
                                    }

                                case SshRelayMessageTag.CONNECT_SUCCESS_SID:
                                    {
                                        //
                                        // We shouldn't be receiving this message after
                                        // a reconnect.
                                        //
                                        throw new SshRelayProtocolViolationException(
                                            "The server sent an unexpected CONNECT_SUCCESS_SID " +
                                            "message in response to a reconnect");
                                    }

                                case SshRelayMessageTag.LONG_CLOSE:
                                default:
                                    //
                                    // Unknown tag, ignore.
                                    //
                                    TraceVerbose($"Received unknown message: {tag}");
                                    break;
                            }
                        }

                        //
                        // Resend all data since the ACK that we just received.
                        //
                        await resendUnacknoledgedDataAction(
                                connection,
                                cancellationToken)
                            .ConfigureAwait(false);
                    }
                    catch
                    {
                        //
                        // The connection was never published, so nobody
                        // else will dispose it.
                        //
                        DisposeUnpublishedConnection(connection);
                        throw;
                    }

                    this.connection = connection;

                    TraceVerbose("Received RECONNECT_SUCCESS_ACK, reconnected");
                    return connection;
                }
            }
        }

        //---------------------------------------------------------------------
        // Publics.
        //---------------------------------------------------------------------

        public SshRelaySession(ISshRelayTarget endpoint)
        {
            this.Endpoint = endpoint;
        }

        /// <summary>
        /// Close and dispose the current connection, if any. Errors during
        /// close or dispose are traced, not thrown. The next I/O operation
        /// establishes a new connection or resumes the session.
        /// </summary>
        internal async Task DisconnectAsync(CancellationToken cancellationToken)
        {
            using (await this.connectLock
                .AcquireAsync(cancellationToken)
                .ConfigureAwait(false))
            {
                //
                // Drop this connection.
                //
                if (this.connection != null)
                {
                    try
                    {
                        await this.connection
                            .CloseAsync(cancellationToken)
                            .ConfigureAwait(false);
                    }
                    catch (Exception e)
                    {
                        IapTraceSource.Log.TraceError(e);
                    }

                    try
                    {
                        this.connection.Dispose();
                    }
                    catch (Exception e)
                    {
                        IapTraceSource.Log.TraceError(e);
                    }

                    this.connection = null;
                }

                TraceVerbose("Disconnected");
            }
        }

        /// <summary>
        /// Run an I/O action against the current connection, connecting or
        /// reconnecting first if necessary.
        /// </summary>
        /// <param name="ioAction">
        /// Action to perform on the connection; its result is returned.
        /// </param>
        /// <param name="resendUnacknoledgedDataAction">
        /// Invoked on every newly established connection before it is used.
        /// </param>
        /// <param name="treatNormalCloseAsError">
        /// If true, a NORMAL or DESTINATION_*_FAILED close by the server is
        /// rethrown. If false, such a close makes this method return 0.
        /// </param>
        /// <exception cref="SshRelayDeniedException">NOT_AUTHORIZED close.</exception>
        /// <exception cref="SshRelayReconnectException">
        /// FAILED_TO_REWIND, SID_UNKNOWN or SID_IN_USE close.
        /// </exception>
        /// <exception cref="SshRelayConnectException">
        /// FAILED_TO_CONNECT_TO_BACKEND close.
        /// </exception>
        /// <exception cref="SshRelayBackendNotFoundException">
        /// LOOKUP_FAILED or LOOKUP_FAILED_RECONNECT close.
        /// </exception>
        public async Task<uint> IoAsync(
            Func<INetworkStream, Task<uint>> ioAction,
            Func<INetworkStream, CancellationToken, Task> resendUnacknoledgedDataAction,
            bool treatNormalCloseAsError,
            CancellationToken cancellationToken)
        {
            var attempt = 0;
            while (true)
            {
                try
                {
                    var connection = await GetConnectionAsync(
                            resendUnacknoledgedDataAction,
                            cancellationToken)
                        .ConfigureAwait(false);

                    Debug.Assert(connection != null);

                    return await ioAction(connection!).ConfigureAwait(false);
                }
                catch (WebSocketStreamClosedByClientException)
                {
                    throw;
                }
                catch (WebSocketStreamClosedByServerException e)
                {
                    IapTraceSource.Log.TraceError(e);

                    switch ((SshRelayCloseCode)e.CloseStatus)
                    {
                        case SshRelayCloseCode.NORMAL:
                        case SshRelayCloseCode.DESTINATION_READ_FAILED:
                        case SshRelayCloseCode.DESTINATION_WRITE_FAILED:
                            //
                            // NB. We get a DESTINATION_*_FAILED if the
                            // backend closed the connection (as opposed
                            // to the relay).
                            //
                            if (treatNormalCloseAsError)
                            {
                                throw;
                            }
                            else
                            {
                                //
                                // Server closed the connection normally.
                                //
                                return 0;
                            }

                        case SshRelayCloseCode.NOT_AUTHORIZED:
                            throw new SshRelayDeniedException(
                                $"The server denied access: " +
                                e.CloseStatusDescription);

                        case SshRelayCloseCode.FAILED_TO_REWIND:
                        case SshRelayCloseCode.SID_UNKNOWN:
                        case SshRelayCloseCode.SID_IN_USE:
                            throw new SshRelayReconnectException(
                                "The server closed the connection unexpectedly and " +
                                "reestablishing the connection failed: " +
                                e.CloseStatusDescription);

                        case SshRelayCloseCode.FAILED_TO_CONNECT_TO_BACKEND:
                            throw new SshRelayConnectException(
                                "The server could not connect to the backend: " +
                                e.CloseStatusDescription);

                        case SshRelayCloseCode.LOOKUP_FAILED:
                        case SshRelayCloseCode.LOOKUP_FAILED_RECONNECT:
                            throw new SshRelayBackendNotFoundException(
                                "The backend could not be found");

                        default:
                            {
                                if (attempt++ >= MaxReconnects)
                                {
                                    TraceVerbose($"Failed to reconnect after {attempt} attempts");
                                    throw;
                                }
                                else
                                {
                                    //
                                    // Try again. No synchronization context is
                                    // needed to continue, so don't capture it.
                                    //
                                    await DisconnectAsync(cancellationToken)
                                        .ConfigureAwait(false);

                                    TraceVerbose("Attempting to reconnect");
                                    break;
                                }
                            }
                    }
                }
            }
        }

        /// <summary>
        /// Dispose the connection lock and the current connection, if any.
        /// </summary>
        public void Dispose()
        {
            this.connectLock.Dispose();
            this.connection?.Dispose();
        }

        public override string ToString()
        {
            //
            // Read the property once so that both uses see the same value.
            //
            var sid = this.Sid;
            var sidToken = sid != null
                ? sid.Substring(0, Math.Min(sid.Length, 10))
                : "(unknown)";

            return $"[SshRelaySession {sidToken} {this.State}]";
        }

        //---------------------------------------------------------------------
        // Inner classes.
        //---------------------------------------------------------------------

        internal class SessionState
        {
            //
            // Counters for keeping track of the connection state.
            // The values can be read from any thread, but written
            // only by the current reader or writer thread.
            //
            // All accesses go through Interlocked so that 64-bit reads
            // and writes are atomic and ordered, even in 32-bit processes.
            //
            private long lastAckReceived = 0;
            private long lastAckSent = 0;
            private long bytesReceived = 0;
            private long bytesSent = 0;

            public ulong LastAckReceived
            {
                get => (ulong)Interlocked.Read(ref this.lastAckReceived);
                set => Interlocked.Exchange(ref this.lastAckReceived, (long)value);
            }

            public ulong LastAckSent
            {
                get => (ulong)Interlocked.Read(ref this.lastAckSent);
                set => Interlocked.Exchange(ref this.lastAckSent, (long)value);
            }

            public ulong BytesReceived
            {
                get => (ulong)Interlocked.Read(ref this.bytesReceived);
            }

            public ulong BytesSent
            {
                get => (ulong)Interlocked.Read(ref this.bytesSent);
            }

            public void AddBytesReceived(uint delta)
            {
                Debug.Assert(delta > 0);
                Interlocked.Add(ref this.bytesReceived, delta);
            }

            public void AddBytesSent(uint delta)
            {
                Debug.Assert(delta > 0);
                Interlocked.Add(ref this.bytesSent, delta);
            }

            public override string ToString()
            {
                return $"AR: {this.LastAckReceived} AS: {this.LastAckSent} " +
                    $"TR: {this.BytesReceived} TX: {this.BytesSent}";
            }
        }
    }
}