Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/Renci.SshNet/Common/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,14 @@ async Task<T> WaitCore()
}
}

extension(Array)
{
internal static int MaxLength
{
get { return 0X7FFFFFC7; }
}
}

extension(Task t)
{
internal bool IsCompletedSuccessfully
Expand Down
92 changes: 71 additions & 21 deletions src/Renci.SshNet/Common/SshDataStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,28 +59,57 @@ public bool IsEndOfData
}
}

#if !NET
private void Write(ReadOnlySpan<byte> buffer)
// Because this type derives from MemoryStream, the base Write(ReadOnlySpan) chooses
// to rent an array, copy the data in and delegate to Write(byte[], int, int) for
// backwards compatibility.
// With a bit of extra ceremony, we can instead allow the various Write methods here
// to write directly into the underlying buffer without the need for any intermediate
// arrays (rented or otherwise).

#if NET9_0_OR_GREATER
/// <inheritdoc/>
public override void Write(ReadOnlySpan<byte> buffer)
{
var sharedBuffer = System.Buffers.ArrayPool<byte>.Shared.Rent(buffer.Length);
Write(buffer, buffer.Length, static (span, buffer) => buffer.CopyTo(span));
}
#endif

private delegate void WriteAction<in TArg>(Span<byte> span, TArg arg)
#if NET9_0_OR_GREATER
where TArg : allows ref struct
#endif
;

buffer.CopyTo(sharedBuffer);
private void Write<TArg>(TArg arg, int numBytesToWrite, WriteAction<TArg> writeAction)
#if NET9_0_OR_GREATER
where TArg : allows ref struct
#endif
{
var endPosition = Position + numBytesToWrite;

Write(sharedBuffer, 0, buffer.Length);
if (Capacity < endPosition)
{
var newCapacity = Math.Max(endPosition, Math.Min(2 * (uint)Capacity, Array.MaxLength));
Capacity = checked((int)newCapacity);
}

System.Buffers.ArrayPool<byte>.Shared.Return(sharedBuffer);
if (endPosition > Length)
{
SetLength(endPosition);
}

writeAction(GetRemainingBuffer().AsSpan(0, numBytesToWrite), arg);

Position = endPosition;
}
#endif

/// <summary>
/// Writes an <see cref="uint"/> to the SSH data stream.
/// </summary>
/// <param name="value"><see cref="uint"/> data to write.</param>
public void Write(uint value)
{
Span<byte> bytes = stackalloc byte[4];
BinaryPrimitives.WriteUInt32BigEndian(bytes, value);
Write(bytes);
Write(value, 4, static (span, value) => BinaryPrimitives.WriteUInt32BigEndian(span, value));
}

/// <summary>
Expand All @@ -89,9 +118,7 @@ public void Write(uint value)
/// <param name="value"><see cref="ulong"/> data to write.</param>
public void Write(ulong value)
{
Span<byte> bytes = stackalloc byte[8];
BinaryPrimitives.WriteUInt64BigEndian(bytes, value);
Write(bytes);
Write(value, 8, static (span, value) => BinaryPrimitives.WriteUInt64BigEndian(span, value));
}

/// <summary>
Expand All @@ -100,9 +127,22 @@ public void Write(ulong value)
/// <param name="data">The <see cref="BigInteger" /> to write.</param>
public void Write(BigInteger data)
{
#if NET
var byteCount = data.GetByteCount();

Write((data, byteCount), 4 + byteCount, static (span, args) =>
{
BinaryPrimitives.WriteUInt32BigEndian(span, (uint)args.byteCount);

var success = args.data.TryWriteBytes(span.Slice(4), out var bytesWritten, isBigEndian: true);

Debug.Assert(success && bytesWritten == span.Length - 4);
});
#else
var bytes = data.ToByteArray(isBigEndian: true);

WriteBinary(bytes, 0, bytes.Length);
#endif
}

/// <summary>
Expand All @@ -129,16 +169,26 @@ public void Write(string s, Encoding encoding)
ArgumentNullException.ThrowIfNull(s);
ArgumentNullException.ThrowIfNull(encoding);

var byteCount = encoding.GetByteCount(s);
#if NET
ReadOnlySpan<char> value = s;
var count = encoding.GetByteCount(value);
var bytes = count <= 256 ? stackalloc byte[count] : new byte[count];
encoding.GetBytes(value, bytes);
Write((uint)count);
Write(bytes);
Write((s, byteCount, encoding), 4 + byteCount, static (span, args) =>
{
BinaryPrimitives.WriteUInt32BigEndian(span, (uint)args.byteCount);

var bytesWritten = args.encoding.GetBytes(args.s, span.Slice(4));

Debug.Assert(bytesWritten == span.Length - 4);
});
#else
var bytes = encoding.GetBytes(s);
WriteBinary(bytes, 0, bytes.Length);
var rentedBuffer = System.Buffers.ArrayPool<byte>.Shared.Rent(byteCount);

var bytesWritten = encoding.GetBytes(s, 0, s.Length, rentedBuffer, 0);

Debug.Assert(bytesWritten == byteCount);

WriteBinary(rentedBuffer, 0, bytesWritten);

System.Buffers.ArrayPool<byte>.Shared.Return(rentedBuffer);
#endif
}

Expand Down