sources/Google.Solutions.Iap/Protocol/SshRelayStream.cs (322 lines of code) (raw):
//
// Copyright 2022 Google LLC
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//
using Google.Solutions.Common.Diagnostics;
using Google.Solutions.Common.Threading;
using Google.Solutions.Iap.Net;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace Google.Solutions.Iap.Protocol
{
/// <summary>
/// NetworkStream for reading/writing from a SshRelaySession.
/// </summary>
internal class SshRelayStream : SingleReaderSingleWriterStream
{
private readonly SshRelaySession session;
//
// Queue of un-ack'ed messages that might require re-sending.
//
private readonly AsyncLock unacknoledgedQueueLock = new AsyncLock();
private readonly Queue<UnacknoledgedWrite> unacknoledgedQueue = new Queue<UnacknoledgedWrite>();
//---------------------------------------------------------------------
// Privates
//---------------------------------------------------------------------
private void TraceVerbose(string message)
{
if (IapTraceSource.Log.Switch.ShouldTrace(TraceEventType.Verbose))
{
IapTraceSource.Log.TraceVerbose($"{this.session}: {message}");
}
}
private async Task ResendUnacknoledgedDataAsync(
INetworkStream stream,
CancellationToken cancellationToken)
{
using (await this.unacknoledgedQueueLock
.AcquireAsync(cancellationToken)
.ConfigureAwait(false))
{
while (this.unacknoledgedQueue.Any())
{
var write = this.unacknoledgedQueue.Dequeue();
if (write.ExpectedAck > this.session.State.LastAckReceived)
{
//
// We never got an ACK for this one, resend.
//
TraceVerbose($"Resending DATA #{write.SequenceNumber}...");
await stream
.WriteAsync(
write.Data,
0,
write.Data.Length,
cancellationToken)
.ConfigureAwait(false);
}
else
{
//
// This was ACKed, so don't resend.
//
}
}
}
}
//---------------------------------------------------------------------
// Internal - for testing only.
//---------------------------------------------------------------------
internal ulong ExpectedAck
{
get
{
using (this.unacknoledgedQueueLock.Acquire())
{
return this.unacknoledgedQueue.Any()
? this.unacknoledgedQueue.Last().ExpectedAck
: 0;
}
}
}
internal int UnacknoledgedMessageCount
{
get
{
using (this.unacknoledgedQueueLock.Acquire())
{
return this.unacknoledgedQueue.Count;
}
}
}
//---------------------------------------------------------------------
// Publics
//---------------------------------------------------------------------
public SshRelayStream(ISshRelayTarget endpoint)
{
this.session = new SshRelaySession(endpoint);
}
/// <summary>
/// Maximum amount of data (in byte) that can be written at once.
/// </summary>
public const int MaxWriteSize = (int)SshRelayFormat.Data.MaxPayloadLength;
/// <summary>
/// Minimum amount of data (in byte) that can be read at once.
/// </summary>
public const int MinReadSize = (int)SshRelayFormat.Data.MaxPayloadLength;
public async Task ProbeConnectionAsync(TimeSpan timeout)
{
try
{
using (var cts = new CancellationTokenSource())
{
//
// If access to the instance is allowed, but the instance
// simply does not listen on this port, the connect or read
// will hang. Therefore, apply a timeout.
//
cts.CancelAfter(timeout);
await this.session.IoAsync(
stream =>
{
//
// If we get here, then we've successfully established
// a connection.
//
return Task.FromResult(0u);
},
(s, t) => Task.CompletedTask,
true,
cts.Token);
await CloseAsync(cts.Token)
.ConfigureAwait(false);
}
}
catch (OperationCanceledException)
{
throw new NetworkStreamClosedException(
"The server did not respond within the allotted time");
}
}
//---------------------------------------------------------------------
// Overrides.
//---------------------------------------------------------------------
public string? Sid => this.session.Sid;
protected override async Task<int> ProtectedReadAsync(
byte[] buffer,
int offset,
int count,
CancellationToken cancellationToken)
{
if (count < MinReadSize)
{
throw new IndexOutOfRangeException(
$"Read buffer too small ({count}), must be at least {MinReadSize}");
}
var message = new byte[Math.Max(
SshRelayFormat.MinMessageSize,
SshRelayFormat.Data.HeaderLength + count)];
return (int)await this.session.IoAsync(
async stream =>
{
while (true)
{
var bytesRead = await stream
.ReadAsync(
message,
0,
message.Length,
cancellationToken)
.ConfigureAwait(false);
if (bytesRead == 0)
{
return 0;
}
else if (bytesRead < SshRelayFormat.Tag.Length)
{
throw new SshRelayProtocolViolationException(
"The server sent an incomplete message");
}
SshRelayFormat.Tag.Decode(message, out var tag);
switch (tag)
{
case SshRelayMessageTag.DATA:
{
var bytesDecoded = SshRelayFormat.Data.Decode(
message,
buffer,
(uint)offset,
(uint)count,
out var dataLength);
Debug.Assert(dataLength < bytesDecoded);
Debug.Assert(bytesDecoded == bytesRead);
TraceVerbose($"Received DATA message ({dataLength} bytes)");
this.session.State.AddBytesReceived(dataLength);
return dataLength;
}
case SshRelayMessageTag.ACK:
{
var bytesDecoded = SshRelayFormat.Ack.Decode(message, out var ack);
Debug.Assert(bytesDecoded == bytesRead);
if (ack == 0)
{
throw new SshRelayProtocolViolationException(
"The server sent an invalid zero-ack");
}
else if (ack > (ulong)this.session.State.BytesSent)
{
throw new SshRelayProtocolViolationException(
"The server sent a mismatched ack");
}
this.session.State.LastAckReceived = ack;
using (await this.unacknoledgedQueueLock
.AcquireAsync(cancellationToken)
.ConfigureAwait(false))
{
//
// The server might be acknolodging multiple messages at once.
//
while (this.unacknoledgedQueue.Count > 0 &&
this.unacknoledgedQueue.Peek().ExpectedAck <= ack)
{
this.unacknoledgedQueue.Dequeue();
}
}
TraceVerbose($"Received ACK #{ack}");
break;
}
case SshRelayMessageTag.LONG_CLOSE:
default:
//
// Unknown tag, ignore.
//
TraceVerbose($"Received unknown message: {tag}");
break;
}
}
},
ResendUnacknoledgedDataAsync,
false, // Normal closes are ok.
cancellationToken);
}
protected override async Task ProtectedWriteAsync(
byte[] buffer,
int offset,
int count,
CancellationToken cancellationToken)
{
if (count > MaxWriteSize)
{
throw new IndexOutOfRangeException(
$"Write buffer too large ({count}), must be at most {MaxWriteSize}");
}
await this.session.IoAsync(
async stream =>
{
//
// Take care of outstanding ACKs.
//
var bytesToAck = this.session.State.BytesReceived;
if (this.session.State.LastAckSent < bytesToAck)
{
var ackBuffer = new byte[SshRelayFormat.Ack.MessageLength];
SshRelayFormat.Ack.Encode(ackBuffer, bytesToAck);
TraceVerbose($"Sending ACK #{bytesToAck}...");
await stream
.WriteAsync(
ackBuffer,
0,
ackBuffer.Length,
cancellationToken)
.ConfigureAwait(false);
this.session.State.LastAckSent = bytesToAck;
}
//
// Send data.
//
var sequenceNumber = this.session.State.BytesSent;
var message = new byte[SshRelayFormat.Data.HeaderLength + count];
SshRelayFormat.Data.Encode(message, buffer, (uint)offset, (uint)count);
TraceVerbose($"Sending DATA #{sequenceNumber}...");
//
// Update bytesSent before we write the data to the wire,
// otherwise we might see an ACK before bytesSent even reflects
// that the data has been sent.
//
this.session.State.AddBytesSent((uint)count);
await stream
.WriteAsync(
message,
0,
message.Length,
cancellationToken)
.ConfigureAwait(false);
//
// We should get an ACK for this message.
//
using (await this.unacknoledgedQueueLock
.AcquireAsync(cancellationToken)
.ConfigureAwait(false))
{
this.unacknoledgedQueue.Enqueue(new UnacknoledgedWrite(
message,
sequenceNumber,
sequenceNumber + (ulong)count));
}
return 0;
},
ResendUnacknoledgedDataAsync,
true, // Normal closes are unexpected.
cancellationToken);
}
public override async Task ProtectedCloseAsync(CancellationToken cancellationToken)
{
await this.session
.DisconnectAsync(cancellationToken)
.ConfigureAwait(false);
}
public override string ToString()
{
return this.session.ToString();
}
protected override void Dispose(bool disposing)
{
base.Dispose(disposing);
this.session.Dispose();
this.unacknoledgedQueueLock.Dispose();
}
//---------------------------------------------------------------------
// Helper structs.
//---------------------------------------------------------------------
private readonly struct UnacknoledgedWrite
{
public readonly byte[] Data;
public readonly ulong SequenceNumber;
public readonly ulong ExpectedAck;
public UnacknoledgedWrite(
byte[] data,
ulong sequenceNumber,
ulong expectedAck)
{
this.Data = data;
this.SequenceNumber = sequenceNumber;
this.ExpectedAck = expectedAck;
}
}
}
public abstract class SshRelayException : NetworkStreamClosedException
{
public SshRelayException(string message) : base(message)
{
}
}
public class SshRelayConnectException : SshRelayException
{
public SshRelayConnectException(string message) : base(message)
{
}
}
public class SshRelayReconnectException : SshRelayException
{
public SshRelayReconnectException(string message) : base(message)
{
}
}
public class SshRelayProtocolViolationException : SshRelayException
{
public SshRelayProtocolViolationException(string message) : base(message)
{
}
}
public class SshRelayDeniedException : SshRelayException
{
public SshRelayDeniedException(string message) : base(message)
{
}
}
public class SshRelayBackendNotFoundException : SshRelayException
{
public SshRelayBackendNotFoundException(string message) : base(message)
{
}
}
}