From ab2ae015f6b22be7a5bafc0e536c491d345557a5 Mon Sep 17 00:00:00 2001 From: stevenebutler Date: Sun, 23 Apr 2023 13:18:05 +1000 Subject: [PATCH] Rename TimeoutSec and add NetworkTimeoutSec TimeoutSec becomes ConnectTimeoutSec and has same behaviour as before. NetworkTimeoutSec specifies the maximum allowed time between reads from the network when streaming http response messages. The default for both items is to have no timeout. --- .../BasicHtmlWebResponseObject.Common.cs | 9 +- .../Common/InvokeRestMethodCommand.Common.cs | 16 +- .../Common/WebRequestPSCmdlet.Common.cs | 39 ++++- .../Common/WebResponseObject.Common.cs | 18 ++- .../InvokeWebRequestCommand.CoreClr.cs | 8 +- .../utility/WebCmdlet/StreamHelper.cs | 150 +++++++++++++----- .../utility/WebCmdlet/WebRequestSession.cs | 6 +- .../WebCmdlets.Tests.ps1 | 106 ++++++++++++- .../Controllers/DelayController.cs | 59 ++++--- 9 files changed, 318 insertions(+), 93 deletions(-) diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/BasicHtmlWebResponseObject.Common.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/BasicHtmlWebResponseObject.Common.cs index 8eaa95c224a..9413e7d7b03 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/BasicHtmlWebResponseObject.Common.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/BasicHtmlWebResponseObject.Common.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System; using System.Collections.Generic; using System.Diagnostics; using System.IO; @@ -23,8 +24,9 @@ public class BasicHtmlWebResponseObject : WebResponseObject /// Initializes a new instance of the class. /// /// The response. + /// Time permitted between reads or Timeout.InfiniteTimeSpan for no timeout. /// Cancellation token. 
- public BasicHtmlWebResponseObject(HttpResponseMessage response, CancellationToken cancellationToken) : this(response, null, cancellationToken) { } + public BasicHtmlWebResponseObject(HttpResponseMessage response, TimeSpan perReadTimeout, CancellationToken cancellationToken) : this(response, null, perReadTimeout, cancellationToken) { } /// /// Initializes a new instance of the class @@ -32,8 +34,9 @@ public BasicHtmlWebResponseObject(HttpResponseMessage response, CancellationToke /// /// The response. /// The content stream associated with the response. + /// Time permitted between reads or Timeout.InfiniteTimeSpan for no timeout. /// Cancellation token. - public BasicHtmlWebResponseObject(HttpResponseMessage response, Stream contentStream, CancellationToken cancellationToken) : base(response, contentStream, cancellationToken) + public BasicHtmlWebResponseObject(HttpResponseMessage response, Stream contentStream, TimeSpan perReadTimeout, CancellationToken cancellationToken) : base(response, contentStream, perReadTimeout, cancellationToken) { InitializeContent(cancellationToken); InitializeRawContent(response); @@ -153,7 +156,7 @@ protected void InitializeContent(CancellationToken cancellationToken) // Fill the Content buffer string characterSet = WebResponseHelper.GetCharacterSet(BaseResponse); - Content = StreamHelper.DecodeStream(RawContentStream, characterSet, out Encoding encoding, cancellationToken); + Content = StreamHelper.DecodeStream(RawContentStream, characterSet, out Encoding encoding, perReadMillisecondsTimeout, cancellationToken); Encoding = encoding; } else diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/InvokeRestMethodCommand.Common.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/InvokeRestMethodCommand.Common.cs index 3fe6e2a91ad..56963666cf6 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/InvokeRestMethodCommand.Common.cs +++ 
b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/InvokeRestMethodCommand.Common.cs @@ -77,11 +77,12 @@ internal override void ProcessResponse(HttpResponseMessage response) ArgumentNullException.ThrowIfNull(response); ArgumentNullException.ThrowIfNull(_cancelToken); + TimeSpan perReadTimeout = NetworkTimeout; Stream baseResponseStream = StreamHelper.GetResponseStream(response, _cancelToken.Token); if (ShouldWriteToPipeline) { - using var responseStream = new BufferingStreamReader(baseResponseStream, _cancelToken.Token); + using var responseStream = new BufferingStreamReader(baseResponseStream, perReadTimeout, _cancelToken.Token); // First see if it is an RSS / ATOM feed, in which case we can // stream it - unless the user has overridden it with a return type of "XML" @@ -96,8 +97,7 @@ internal override void ProcessResponse(HttpResponseMessage response) // Try to get the response encoding from the ContentType header. string charSet = WebResponseHelper.GetCharacterSet(response); - - string str = StreamHelper.DecodeStream(responseStream, charSet, out Encoding encoding, _cancelToken.Token); + string str = StreamHelper.DecodeStream(responseStream, charSet, out Encoding encoding, perReadTimeout, _cancelToken.Token); object obj = null; Exception ex = null; @@ -137,12 +137,12 @@ internal override void ProcessResponse(HttpResponseMessage response) } } else if (ShouldSaveToOutFile) - { + { string outFilePath = WebResponseHelper.GetOutFilePath(response, _qualifiedOutFile); WriteVerbose(string.Create(System.Globalization.CultureInfo.InvariantCulture, $"File Name: {Path.GetFileName(_qualifiedOutFile)}")); - StreamHelper.SaveStreamToFile(baseResponseStream, outFilePath, this, response.Content.Headers.ContentLength.GetValueOrDefault(), _cancelToken.Token); + StreamHelper.SaveStreamToFile(baseResponseStream, outFilePath, this, response.Content.Headers.ContentLength.GetValueOrDefault(), perReadTimeout, _cancelToken.Token); } if 
(!string.IsNullOrEmpty(StatusCodeVariable)) @@ -351,18 +351,20 @@ public enum RestReturnType internal class BufferingStreamReader : Stream { - internal BufferingStreamReader(Stream baseStream, CancellationToken cancellationToken) + internal BufferingStreamReader(Stream baseStream, TimeSpan perReadTimeout, CancellationToken cancellationToken) { _baseStream = baseStream; _streamBuffer = new MemoryStream(); _length = long.MaxValue; _copyBuffer = new byte[4096]; + _perReadTimeout = perReadTimeout; _cancellationToken = cancellationToken; } private readonly Stream _baseStream; private readonly MemoryStream _streamBuffer; private readonly byte[] _copyBuffer; + private readonly TimeSpan _perReadTimeout; private readonly CancellationToken _cancellationToken; public override bool CanRead => true; @@ -397,7 +399,7 @@ public override int Read(byte[] buffer, int offset, int count) // If we don't have enough data to fill this from memory, cache more. // We try to read 4096 bytes from base stream every time, so at most we // may cache 4095 bytes more than what is required by the Read operation. 
- int bytesRead = _baseStream.ReadAsync(_copyBuffer, 0, _copyBuffer.Length, _cancellationToken).GetAwaiter().GetResult(); + int bytesRead = _baseStream.ReadAsync(_copyBuffer.AsMemory(), _perReadTimeout, _cancellationToken).GetAwaiter().GetResult(); if (_streamBuffer.Position < _streamBuffer.Length) { diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebRequestPSCmdlet.Common.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebRequestPSCmdlet.Common.cs index 60ee931591a..416bac38988 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebRequestPSCmdlet.Common.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebRequestPSCmdlet.Common.cs @@ -266,11 +266,25 @@ public abstract class WebRequestPSCmdlet : PSCmdlet, IDisposable public virtual SwitchParameter DisableKeepAlive { get; set; } /// - /// Gets or sets the TimeOut property. + /// Gets or sets the ConnectTimeOutSec property. /// + /// + /// This property applies to sending the request and receiving the response headers only. + /// + [Alias("TimeoutSec")] [Parameter] [ValidateRange(0, int.MaxValue)] - public virtual int TimeoutSec { get; set; } + public virtual int ConnectTimeoutSec { get; set; } + + /// + /// Gets or sets the NetworkTimeoutSec property. + /// + /// + /// This property applies to receiving the response body. + /// + [Parameter] + [ValidateRange(0, int.MaxValue)] + public virtual int NetworkTimeoutSec { get; set; } /// /// Gets or sets the Headers property. @@ -497,6 +511,8 @@ public virtual string CustomMethod internal bool ShouldWriteToPipeline => !ShouldSaveToOutFile || PassThru; + internal TimeSpan NetworkTimeout => NetworkTimeoutSec > 0 ? 
TimeSpan.FromSeconds(NetworkTimeoutSec) : Timeout.InfiniteTimeSpan; + #endregion Helper Properties #region Abstract Methods @@ -570,7 +586,7 @@ protected override void ProcessRecord() string respVerboseMsg = contentLength is null ? string.Format(CultureInfo.CurrentCulture, WebCmdletStrings.WebResponseNoSizeVerboseMsg, contentType) : string.Format(CultureInfo.CurrentCulture, WebCmdletStrings.WebResponseVerboseMsg, contentLength, contentType); - + WriteVerbose(respVerboseMsg); bool _isSuccess = response.IsSuccessStatusCode; @@ -621,12 +637,19 @@ protected override void ProcessRecord() string detailMsg = string.Empty; try { - string error = StreamHelper.GetResponseString(response, _cancelToken.Token); + // We can't use ReadAsStringAsync because it doesn't have per read timeouts + TimeSpan perReadTimeout = NetworkTimeout; + string characterSet = WebResponseHelper.GetCharacterSet(response); + var responseStream = StreamHelper.GetResponseStream(response, _cancelToken.Token); + int initialCapacity = (int)Math.Min(contentLength ?? 
StreamHelper.DefaultReadBuffer, StreamHelper.DefaultReadBuffer); + var bufferedStream = new WebResponseContentMemoryStream(responseStream, initialCapacity, this, contentLength, perReadTimeout, _cancelToken.Token); + string error = StreamHelper.DecodeStream(bufferedStream, characterSet, out Encoding encoding, perReadTimeout, _cancelToken.Token); detailMsg = FormatErrorMessage(error, contentType); } - catch + catch (Exception ex) { // Catch all + er.ErrorDetails = new ErrorDetails(ex.ToString()); } if (!string.IsNullOrEmpty(detailMsg)) @@ -666,7 +689,7 @@ protected override void ProcessRecord() ThrowTerminatingError(er); } - finally + finally { _cancelToken?.Dispose(); _cancelToken = null; @@ -970,7 +993,7 @@ internal virtual void PrepareSession() } else { - webProxy.UseDefaultCredentials = ProxyUseDefaultCredentials; + webProxy.UseDefaultCredentials = ProxyUseDefaultCredentials; } // We don't want to update the WebSession unless the proxies are different @@ -1020,7 +1043,7 @@ internal virtual void PrepareSession() WebSession.RetryIntervalInSeconds = RetryIntervalSec; } - WebSession.TimeoutSec = TimeoutSec; + WebSession.ConnectTimeoutSec = ConnectTimeoutSec; } internal virtual HttpClient GetHttpClient(bool handleRedirect) diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebResponseObject.Common.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebResponseObject.Common.cs index 8bfdeed2da4..f63297e2a94 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebResponseObject.Common.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebResponseObject.Common.cs @@ -70,14 +70,24 @@ public class WebResponseObject #endregion Properties + #region Protected Fields + + /// + /// Time permitted between reads or Timeout.InfiniteTimeSpan for no timeout. 
+ /// + protected TimeSpan perReadMillisecondsTimeout; + + #endregion Protected Fields + #region Constructors /// /// Initializes a new instance of the class. /// /// The Http response. + /// Time permitted between reads or Timeout.InfiniteTimeSpan for no timeout. /// The cancellation token. - public WebResponseObject(HttpResponseMessage response, CancellationToken cancellationToken) : this(response, null, cancellationToken) + public WebResponseObject(HttpResponseMessage response, TimeSpan perReadTimeout, CancellationToken cancellationToken) : this(response, null, perReadTimeout, cancellationToken) { } /// @@ -86,9 +96,11 @@ public WebResponseObject(HttpResponseMessage response, CancellationToken cancell /// /// Http response. /// The http content stream. + /// Time permitted between reads or Timeout.InfiniteTimeSpan for no timeout. /// The cancellation token. - public WebResponseObject(HttpResponseMessage response, Stream contentStream, CancellationToken cancellationToken) + public WebResponseObject(HttpResponseMessage response, Stream contentStream, TimeSpan perReadTimeout, CancellationToken cancellationToken) { + this.perReadMillisecondsTimeout = perReadTimeout; SetResponse(response, contentStream, cancellationToken); InitializeContent(); InitializeRawContent(response); @@ -151,7 +163,7 @@ private void SetResponse(HttpResponseMessage response, Stream contentStream, Can } int initialCapacity = (int)Math.Min(contentLength, StreamHelper.DefaultReadBuffer); - RawContentStream = new WebResponseContentMemoryStream(st, initialCapacity, cmdlet: null, response.Content.Headers.ContentLength.GetValueOrDefault(), cancellationToken); + RawContentStream = new WebResponseContentMemoryStream(st, initialCapacity, cmdlet: null, response.Content.Headers.ContentLength.GetValueOrDefault(), perReadMillisecondsTimeout, cancellationToken); } // Set the position of the content stream to the beginning diff --git 
a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/InvokeWebRequestCommand.CoreClr.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/InvokeWebRequestCommand.CoreClr.cs index 19ebc294c10..22973c646cb 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/InvokeWebRequestCommand.CoreClr.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/InvokeWebRequestCommand.CoreClr.cs @@ -5,6 +5,7 @@ using System.IO; using System.Management.Automation; using System.Net.Http; +using System.Threading; namespace Microsoft.PowerShell.Commands { @@ -33,7 +34,7 @@ public InvokeWebRequestCommand() : base() internal override void ProcessResponse(HttpResponseMessage response) { ArgumentNullException.ThrowIfNull(response); - + TimeSpan perReadTimeout = NetworkTimeout; Stream responseStream = StreamHelper.GetResponseStream(response, _cancelToken.Token); if (ShouldWriteToPipeline) { @@ -43,8 +44,9 @@ internal override void ProcessResponse(HttpResponseMessage response) StreamHelper.ChunkSize, this, response.Content.Headers.ContentLength.GetValueOrDefault(), + perReadTimeout, _cancelToken.Token); - WebResponseObject ro = WebResponseHelper.IsText(response) ? new BasicHtmlWebResponseObject(response, responseStream, _cancelToken.Token) : new WebResponseObject(response, responseStream, _cancelToken.Token); + WebResponseObject ro = WebResponseHelper.IsText(response) ? 
new BasicHtmlWebResponseObject(response, responseStream, perReadTimeout, _cancelToken.Token) : new WebResponseObject(response, responseStream, perReadTimeout, _cancelToken.Token); ro.RelationLink = _relationLink; WriteObject(ro); @@ -61,7 +63,7 @@ internal override void ProcessResponse(HttpResponseMessage response) WriteVerbose(string.Create(System.Globalization.CultureInfo.InvariantCulture, $"File Name: {Path.GetFileName(_qualifiedOutFile)}")); - StreamHelper.SaveStreamToFile(responseStream, outFilePath, this, response.Content.Headers.ContentLength.GetValueOrDefault(), _cancelToken.Token); + StreamHelper.SaveStreamToFile(responseStream, outFilePath, this, response.Content.Headers.ContentLength.GetValueOrDefault(), perReadTimeout, _cancelToken.Token); } } diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/StreamHelper.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/StreamHelper.cs index bec01ee8a1d..772a6b5ccf9 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/StreamHelper.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/StreamHelper.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System; +using System.Buffers; using System.IO; using System.Management.Automation; using System.Management.Automation.Internal; @@ -27,6 +28,7 @@ internal class WebResponseContentMemoryStream : MemoryStream private readonly Stream _originalStreamToProxy; private readonly Cmdlet _ownerCmdlet; private readonly CancellationToken _cancellationToken; + private readonly TimeSpan _perReadTimeout; private bool _isInitialized = false; #endregion Data @@ -39,13 +41,15 @@ internal class WebResponseContentMemoryStream : MemoryStream /// Presize the memory stream. /// Owner cmdlet if any. /// Expected download size in Bytes. + /// Time permitted between reads or Timeout.InfiniteTimeSpan for no timeout. /// Cancellation token. 
- internal WebResponseContentMemoryStream(Stream stream, int initialCapacity, Cmdlet cmdlet, long? contentLength, CancellationToken cancellationToken) : base(initialCapacity) + internal WebResponseContentMemoryStream(Stream stream, int initialCapacity, Cmdlet cmdlet, long? contentLength, TimeSpan perReadTimeout, CancellationToken cancellationToken) : base(initialCapacity) { this._contentLength = contentLength; _originalStreamToProxy = stream; _ownerCmdlet = cmdlet; _cancellationToken = cancellationToken; + _perReadTimeout = perReadTimeout; } #endregion Constructors @@ -226,7 +230,7 @@ private void Initialize(CancellationToken cancellationToken = default) } } - read = _originalStreamToProxy.ReadAsync(buffer, 0, buffer.Length, cancellationToken).GetAwaiter().GetResult(); + read = _originalStreamToProxy.ReadAsync(buffer.AsMemory(), _perReadTimeout, cancellationToken).GetAwaiter().GetResult(); if (read > 0) { @@ -253,6 +257,59 @@ private void Initialize(CancellationToken cancellationToken = default) } } + internal static class StreamTimeoutExtensions + { + internal static async Task<int> ReadAsync(this Stream stream, Memory<byte> buffer, TimeSpan readTimeout, CancellationToken cancellationToken) + { + if (readTimeout == Timeout.InfiniteTimeSpan) + { + return await stream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); + } + + using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + cts.CancelAfter(readTimeout); + return await stream.ReadAsync(buffer, cts.Token).ConfigureAwait(false); + } + + internal static async Task CopyToAsync(this Stream source, Stream destination, TimeSpan perReadTimeout, CancellationToken cancellationToken) + { + if (perReadTimeout == Timeout.InfiniteTimeSpan) + { + // No timeout - use fast path + await source.CopyToAsync(destination, cancellationToken).ConfigureAwait(false); + return; + } + + byte[] buffer = ArrayPool<byte>.Shared.Rent(StreamHelper.ChunkSize); + CancellationTokenSource cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + try
+ { + while (true) + { + if (!cts.TryReset()) + { + cts.Dispose(); + cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + } + + cts.CancelAfter(perReadTimeout); + int bytesRead = await source.ReadAsync(buffer, cts.Token).ConfigureAwait(false); + if (bytesRead == 0) + { + break; + } + + await destination.WriteAsync(buffer.AsMemory(0, bytesRead), cancellationToken).ConfigureAwait(false); + } + } + finally + { + cts.Dispose(); + ArrayPool<byte>.Shared.Return(buffer); + } + } + } + internal static class StreamHelper { #region Constants @@ -268,11 +325,11 @@ internal static class StreamHelper #region Static Methods - internal static void WriteToStream(Stream input, Stream output, PSCmdlet cmdlet, long? contentLength, CancellationToken cancellationToken) + internal static void WriteToStream(Stream input, Stream output, PSCmdlet cmdlet, long? contentLength, TimeSpan perReadTimeout, CancellationToken cancellationToken) { ArgumentNullException.ThrowIfNull(cmdlet); - Task copyTask = input.CopyToAsync(output, cancellationToken); + Task copyTask = input.CopyToAsync(output, perReadTimeout, cancellationToken); bool wroteProgress = false; ProgressRecord record = new( @@ -326,16 +383,17 @@ internal static void WriteToStream(Stream input, Stream output, PSCmdlet cmdlet, /// Output file name. /// Current cmdlet (Invoke-WebRequest or Invoke-RestMethod). /// Expected download size in Bytes. + /// Time permitted between reads or Timeout.InfiniteTimeSpan for no timeout. /// CancellationToken to track the cmdlet cancellation. - internal static void SaveStreamToFile(Stream stream, string filePath, PSCmdlet cmdlet, long? contentLength, CancellationToken cancellationToken) + internal static void SaveStreamToFile(Stream stream, string filePath, PSCmdlet cmdlet, long? contentLength, TimeSpan perReadTimeout, CancellationToken cancellationToken) { // If the web cmdlet should resume, append the file instead of overwriting.
FileMode fileMode = cmdlet is WebRequestPSCmdlet webCmdlet && webCmdlet.ShouldResume ? FileMode.Append : FileMode.Create; using FileStream output = new(filePath, fileMode, FileAccess.Write, FileShare.Read); - WriteToStream(stream, output, cmdlet, contentLength, cancellationToken); + WriteToStream(stream, output, cmdlet, contentLength, perReadTimeout, cancellationToken); } - private static string StreamToString(Stream stream, Encoding encoding, CancellationToken cancellationToken) + private static string StreamToString(Stream stream, Encoding encoding, TimeSpan perReadTimeout, CancellationToken cancellationToken) { StringBuilder result = new(capacity: ChunkSize); Decoder decoder = encoding.GetDecoder(); @@ -346,51 +404,59 @@ private static string StreamToString(Stream stream, Encoding encoding, Cancellat useBufferSize = encoding.GetMaxCharCount(10); } - char[] chars = new char[useBufferSize]; - byte[] bytes = new byte[useBufferSize * 4]; - int bytesRead = 0; - do + char[] chars = ArrayPool<char>.Shared.Rent(useBufferSize); + byte[] bytes = ArrayPool<byte>.Shared.Rent(useBufferSize * 4); + try { - // Read at most the number of bytes that will fit in the input buffer. The - // return value is the actual number of bytes read, or zero if no bytes remain. - bytesRead = stream.ReadAsync(bytes, 0, useBufferSize * 4, cancellationToken).GetAwaiter().GetResult(); + int bytesRead = 0; + do + { + // Read at most the number of bytes that will fit in the input buffer. The + // return value is the actual number of bytes read, or zero if no bytes remain. + bytesRead = stream.ReadAsync(bytes.AsMemory(), perReadTimeout, cancellationToken).GetAwaiter().GetResult(); - bool completed = false; - int byteIndex = 0; + bool completed = false; + int byteIndex = 0; - while (!completed) - { - // If this is the last input data, flush the decoder's internal buffer and state.
- bool flush = bytesRead == 0; - decoder.Convert(bytes, byteIndex, bytesRead - byteIndex, chars, 0, useBufferSize, flush, out int bytesUsed, out int charsUsed, out completed); - - // The conversion produced the number of characters indicated by charsUsed. Write that number - // of characters to our result buffer - result.Append(chars, 0, charsUsed); - - // Increment byteIndex to the next block of bytes in the input buffer, if any, to convert. - byteIndex += bytesUsed; - - // The behavior of decoder.Convert changed start .NET 3.1-preview2. - // The change was made in https://github.com/dotnet/coreclr/pull/27229 - // The recommendation from .NET team is to not check for 'completed' if 'flush' is false. - // Break out of the loop if all bytes have been read. - if (!flush && bytesRead == byteIndex) + while (!completed) { - break; + // If this is the last input data, flush the decoder's internal buffer and state. + bool flush = bytesRead == 0; + decoder.Convert(bytes, byteIndex, bytesRead - byteIndex, chars, 0, useBufferSize, flush, out int bytesUsed, out int charsUsed, out completed); + + // The conversion produced the number of characters indicated by charsUsed. Write that number + // of characters to our result buffer + result.Append(chars, 0, charsUsed); + + // Increment byteIndex to the next block of bytes in the input buffer, if any, to convert. + byteIndex += bytesUsed; + + // The behavior of decoder.Convert changed start .NET 3.1-preview2. + // The change was made in https://github.com/dotnet/coreclr/pull/27229 + // The recommendation from .NET team is to not check for 'completed' if 'flush' is false. + // Break out of the loop if all bytes have been read.
+ if (!flush && bytesRead == byteIndex) + { + break; + } } } - } - while (bytesRead != 0); + while (bytesRead != 0); - return result.ToString(); + return result.ToString(); + } + finally + { + ArrayPool<char>.Shared.Return(chars); + ArrayPool<byte>.Shared.Return(bytes); + } } - internal static string DecodeStream(Stream stream, string characterSet, out Encoding encoding, CancellationToken cancellationToken) + internal static string DecodeStream(Stream stream, string characterSet, out Encoding encoding, TimeSpan perReadTimeout, CancellationToken cancellationToken) { bool isDefaultEncoding = !TryGetEncoding(characterSet, out encoding); - string content = StreamToString(stream, encoding, cancellationToken); + string content = StreamToString(stream, encoding, perReadTimeout, cancellationToken); if (isDefaultEncoding) { // We only look within the first 1k characters as the meta element and @@ -413,7 +479,7 @@ internal static string DecodeStream(Stream stream, string characterSet, out Enco if (TryGetEncoding(characterSet, out Encoding localEncoding)) { stream.Seek(0, SeekOrigin.Begin); - content = StreamToString(stream, localEncoding, cancellationToken); + content = StreamToString(stream, localEncoding, perReadTimeout, cancellationToken); encoding = localEncoding; } } @@ -457,8 +523,6 @@ internal static byte[] EncodeToBytes(string str, Encoding encoding) return encoding.GetBytes(str); } - internal static string GetResponseString(HttpResponseMessage response, CancellationToken cancellationToken) => response.Content.ReadAsStringAsync(cancellationToken).GetAwaiter().GetResult(); - internal static Stream GetResponseStream(HttpResponseMessage response, CancellationToken cancellationToken) => response.Content.ReadAsStreamAsync(cancellationToken).GetAwaiter().GetResult(); #endregion Static Methods diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/WebRequestSession.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/WebRequestSession.cs
index d4a9a5cc48b..0d2ec20089e 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/WebRequestSession.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/WebRequestSession.cs @@ -32,7 +32,7 @@ public class WebRequestSession : IDisposable private bool _skipCertificateCheck; private bool _noProxy; private bool _disposed; - private int _timeoutSec; + private int _connectTimeoutSec; /// /// Contains true if an existing HttpClient had to be disposed and recreated since the WebSession was last used. @@ -142,7 +142,7 @@ public WebRequestSession() internal bool SkipCertificateCheck { set => SetStructVar(ref _skipCertificateCheck, value); } - internal int TimeoutSec { set => SetStructVar(ref _timeoutSec, value); } + internal int ConnectTimeoutSec { set => SetStructVar(ref _connectTimeoutSec, value); } internal bool NoProxy { @@ -240,7 +240,7 @@ private HttpClient CreateHttpClient() // Check timeout setting (in seconds instead of milliseconds as in HttpWebRequest) return new HttpClient(handler) { - Timeout = _timeoutSec is 0 ? TimeSpan.FromMilliseconds(Timeout.Infinite) : TimeSpan.FromSeconds(_timeoutSec) + Timeout = _connectTimeoutSec is 0 ? 
TimeSpan.FromMilliseconds(Timeout.Infinite) : TimeSpan.FromSeconds(_connectTimeoutSec) }; } diff --git a/test/powershell/Modules/Microsoft.PowerShell.Utility/WebCmdlets.Tests.ps1 b/test/powershell/Modules/Microsoft.PowerShell.Utility/WebCmdlets.Tests.ps1 index 2b2968ea3da..a438f3b80d7 100644 --- a/test/powershell/Modules/Microsoft.PowerShell.Utility/WebCmdlets.Tests.ps1 +++ b/test/powershell/Modules/Microsoft.PowerShell.Utility/WebCmdlets.Tests.ps1 @@ -242,7 +242,7 @@ function ExecuteRequestWithCustomUserAgent { try { $Params = @{ Uri = $Uri - TimeoutSec = 5 + ConnectTimeoutSec = 5 UserAgent = $UserAgent SkipHeaderValidation = $SkipHeaderValidation.IsPresent } @@ -608,7 +608,15 @@ Describe "Invoke-WebRequest tests" -Tags "Feature", "RequireAdminOnWindows" { $Result.Output.Content | Should -Match '测试123' } - It "Invoke-WebRequest validate timeout option" { + It "Invoke-WebRequest validate ConnectTimeoutSec option" { + $uri = Get-WebListenerUrl -Test 'Delay' -TestValue '5' + $command = "Invoke-WebRequest -Uri '$uri' -ConnectTimeoutSec 2" + + $result = ExecuteWebCommand -command $command + $result.Error.FullyQualifiedErrorId | Should -Be "System.Threading.Tasks.TaskCanceledException,Microsoft.PowerShell.Commands.InvokeWebRequestCommand" + } + + It "Invoke-WebRequest validate TimeoutSec alias" { $uri = Get-WebListenerUrl -Test 'Delay' -TestValue '5' $command = "Invoke-WebRequest -Uri '$uri' -TimeoutSec 2" @@ -2596,7 +2604,15 @@ Describe "Invoke-RestMethod tests" -Tags "Feature", "RequireAdminOnWindows" { $Result.Output | Should -Match '测试123' } - It "Invoke-RestMethod validate timeout option" { + It "Invoke-RestMethod validate ConnectTimeoutSec option" { + $uri = Get-WebListenerUrl -Test 'Delay' -TestValue '5' + $command = "Invoke-RestMethod -Uri '$uri' -ConnectTimeoutSec 2" + + $result = ExecuteWebCommand -command $command + $result.Error.FullyQualifiedErrorId | Should -Be 
"System.Threading.Tasks.TaskCanceledException,Microsoft.PowerShell.Commands.InvokeRestMethodCommand" + } + + It "Invoke-RestMethod validate TimeoutSec alias" { $uri = Get-WebListenerUrl -Test 'Delay' -TestValue '5' $command = "Invoke-RestMethod -Uri '$uri' -TimeoutSec 2" @@ -4266,7 +4282,12 @@ Describe 'Invoke-WebRequest and Invoke-RestMethod support Cancellation through C RunWithCancellation -Uri $uri } - It 'Invoke-WebRequest: Defalate Compression CTRL-C Cancels request after request headers' { + It 'Invoke-WebRequest: Gzip Compression CTRL-C Cancels request after request headers with Content-Length' { + $uri = Get-WebListenerUrl -Test StallGzip -TestValue '30/application%2fjson?contentLength=true' + RunWithCancellation -Uri $uri + } + + It 'Invoke-WebRequest: Deflate Compression CTRL-C Cancels request after request headers' { $uri = Get-WebListenerUrl -Test StallDeflate -TestValue '30/application%2fjson' RunWithCancellation -Uri $uri } @@ -4336,3 +4357,80 @@ Describe 'Invoke-WebRequest and Invoke-RestMethod support Cancellation through C RunWithCancellation -Command 'Invoke-RestMethod' -Uri $uri } } + +Describe 'Invoke-WebRequest and Invoke-RestMethod support NetworkTimeoutSec' -Tags "CI", "RequireAdminOnWindows" { + BeforeAll { + $oldProgress = $ProgressPreference + $ProgressPreference = 'SilentlyContinue' + $WebListener = Start-WebListener + } + + AfterAll { + $ProgressPreference = $oldProgress + } + + function RunWithNetworkTimeout { + param( + [ValidateSet('Invoke-WebRequest', 'Invoke-RestMethod')] + [string]$Command = 'Invoke-WebRequest', + [string]$Arguments = '', + [uri]$Uri, + [int]$NetworkTimeoutSec, + [switch]$WillTimeout + ) + + $invoke = "$Command -Uri `"$Uri`" $Arguments" + if ($PSBoundParameters.ContainsKey('NetworkTimeoutSec')) { + $invoke = "$invoke -NetworkTimeoutSec $NetworkTimeoutSec" + } + + $result = ExecuteWebCommand -command $invoke + if ($WillTimeout) { + $result.Error | Should -Not -BeNullOrEmpty + $result.Output | Should -BeNullOrEmpty 
+ } else { + $result.Error | Should -BeNullOrEmpty + $result.Output | Should -Not -BeNullOrEmpty + } + } + + It 'Invoke-WebRequest: NetworkTimeoutSec does not cancel if download takes longer than timeout' { + $uri = Get-WebListenerUrl -Test Stall -TestValue '2' -Query @{ chunks = 5 } + RunWithNetworkTimeout -Uri $uri -NetworkTimeoutSec 4 + } + + It 'Invoke-WebRequest: NetworkTimeoutSec cancels if stall lasts longer than NetworkTimeoutSec value' { + $uri = Get-WebListenerUrl -Test Stall -TestValue 30 + RunWithNetworkTimeout -Uri $uri -NetworkTimeoutSec 3 -WillTimeout + } + + It 'Invoke-WebRequest: NetworkTimeoutSec cancels if stall lasts longer than NetworkTimeoutSec value for HTTPS/gzip compression' { + $uri = Get-WebListenerUrl -Https -Test StallGzip -TestValue 30 + RunWithNetworkTimeout -Uri $uri -NetworkTimeoutSec 3 -WillTimeout -Arguments '-SkipCertificateCheck' + } + + It 'Invoke-RestMethod: NetworkTimeoutSec does not cancel if download takes longer than timeout' { + $uri = Get-WebListenerUrl -Test Stall -TestValue '2' -Query @{ chunks = 5 } + RunWithNetworkTimeout -Command Invoke-RestMethod -Uri $uri -NetworkTimeoutSec 4 + } + + It 'Invoke-RestMethod: NetworkTimeoutSec cancels if stall lasts longer than NetworkTimeoutSec value' { + $uri = Get-WebListenerUrl -Test Stall -TestValue 30 + RunWithNetworkTimeout -Command Invoke-RestMethod -Uri $uri -NetworkTimeoutSec 2 -WillTimeout + } + + It 'Invoke-RestMethod: NetworkTimeoutSec cancels when doing XML atom processing' { + $uri = Get-WebListenerUrl -Test Stall -TestValue '30/application%2fxml' + RunWithNetworkTimeout -Command Invoke-RestMethod -Uri $uri -NetworkTimeoutSec 2 -WillTimeout + } + + It 'Invoke-RestMethod: NetworkTimeoutSec cancels when doing JSON processing' { + $uri = Get-WebListenerUrl -Test Stall -TestValue '30/application%2fjson' + RunWithNetworkTimeout -Command Invoke-RestMethod -Uri $uri -NetworkTimeoutSec 2 -WillTimeout + } + + It 'Invoke-RestMethod: NetworkTimeoutSec cancels when doing XML atom 
processing for HTTPS/gzip compression' { + $uri = Get-WebListenerUrl -Https -Test StallGzip -TestValue '30/application%2fXML' + RunWithNetworkTimeout -Command Invoke-RestMethod -Uri $uri -NetworkTimeoutSec 2 -WillTimeout -Arguments '-SkipCertificateCheck' + } +} diff --git a/test/tools/WebListener/Controllers/DelayController.cs b/test/tools/WebListener/Controllers/DelayController.cs index 5ab081d580d..5fdd9051d3d 100644 --- a/test/tools/WebListener/Controllers/DelayController.cs +++ b/test/tools/WebListener/Controllers/DelayController.cs @@ -54,33 +54,33 @@ public JsonResult Index(int seconds) return getController.Index(); } - public async Task Stall(int seconds, string contentType, CancellationToken cancellationToken) + public async Task Stall(int seconds, string contentType, int chunks, bool contentLength, CancellationToken cancellationToken) { - await WriteStallResponse(seconds, contentType, null, null, cancellationToken); + await WriteStallResponse(seconds, contentType, chunks, contentLength, null, null, cancellationToken); } - public async Task StallBrotli(int seconds, string contentType, CancellationToken cancellationToken) + public async Task StallBrotli(int seconds, string contentType, int chunks, bool contentLength, CancellationToken cancellationToken) { using var memStream = new MemoryStream(); using var compressedStream = new BrotliStream(memStream, CompressionLevel.Fastest); Response.Headers.ContentEncoding = "br"; - await WriteStallResponse(seconds, contentType, compressedStream, memStream, cancellationToken); + await WriteStallResponse(seconds, contentType, chunks, contentLength, compressedStream, memStream, cancellationToken); } - public async Task StallDeflate(int seconds, string contentType, CancellationToken cancellationToken) + public async Task StallDeflate(int seconds, string contentType, int chunks, bool contentLength, CancellationToken cancellationToken) { using var memStream = new MemoryStream(); using var compressedStream = new 
DeflateStream(memStream, CompressionLevel.Fastest); Response.Headers.ContentEncoding = "deflate"; - await WriteStallResponse(seconds, contentType, compressedStream, memStream, cancellationToken); + await WriteStallResponse(seconds, contentType, chunks, contentLength, compressedStream, memStream, cancellationToken); } - public async Task StallGZip(int seconds, string contentType, CancellationToken cancellationToken) + public async Task StallGZip(int seconds, string contentType, int chunks, bool contentLength, CancellationToken cancellationToken) { using var memStream = new MemoryStream(); using var compressedStream = new GZipStream(memStream, CompressionLevel.Fastest); Response.Headers.ContentEncoding = "gzip"; - await WriteStallResponse(seconds, contentType, compressedStream, memStream, cancellationToken); + await WriteStallResponse(seconds, contentType, chunks, contentLength, compressedStream, memStream, cancellationToken); } public IActionResult Error() @@ -88,7 +88,7 @@ public IActionResult Error() return View(new ErrorViewModel { RequestId = Activity.Current?.Id ?? HttpContext.TraceIdentifier }); } - private async Task WriteStallResponse(int seconds, string contentType, Stream stream, MemoryStream memStream, CancellationToken cancellationToken) + private async Task WriteStallResponse(int seconds, string contentType, int chunks, bool contentLength, Stream stream, MemoryStream memStream, CancellationToken cancellationToken) { if (string.IsNullOrWhiteSpace(contentType)) { @@ -124,20 +124,41 @@ private async Task WriteStallResponse(int seconds, string contentType, Stream st stream.Close(); response = memStream.ToArray(); } - int midPoint = response.Length / 2; - - // Start writing approx half the content, including headers and then delay before writing the rest. 
- await Response.Body.WriteAsync(response, 0, midPoint, cancellationToken); - await Response.Body.FlushAsync(cancellationToken); + if (chunks < 2) // at least two chunks, so the body is always split around a stall + { + chunks = 2; + } + if (chunks > response.Length) + { + throw new InvalidDataException($"Response message is not big enough to break into {chunks} chunks. (Size {response.Length} bytes)."); + } - if (seconds > 0) + if (contentLength) { - int milliseconds = seconds * 1000; - await Task.Delay(milliseconds); + Response.ContentLength = response.Length; } + int chunkSize = response.Length / chunks; + int currentPos = 0; - await Response.Body.WriteAsync(response, midPoint, response.Length - midPoint, cancellationToken); - await Response.Body.FlushAsync(cancellationToken); + // Write each content chunk followed by a delay of 'seconds' (no delay after the final chunk). + // The last segment makes up the remainder of the content if + // it doesn't divide neatly into the required chunks; delays observe cancellationToken. + for (int i = 0; i < chunks; i++) + { + if (i == chunks - 1) + { + chunkSize = response.Length - currentPos; + seconds = 0; + } + await Response.Body.WriteAsync(response, currentPos, chunkSize, cancellationToken); + await Response.Body.FlushAsync(cancellationToken); + currentPos += chunkSize; + if (seconds > 0) + { + int milliseconds = seconds * 1000; + await Task.Delay(milliseconds, cancellationToken); + } + } } } }