diff --git a/src/Renci.SshNet/SftpClient.cs b/src/Renci.SshNet/SftpClient.cs index 960c6261f..fc4cf2e51 100644 --- a/src/Renci.SshNet/SftpClient.cs +++ b/src/Renci.SshNet/SftpClient.cs @@ -905,7 +905,7 @@ public void DownloadFile(string path, Stream output, Action? downloadCall if (downloadCallback != null) { - downloadProgress = new Progress(r => downloadCallback(r.TotalBytesDownloaded)); + downloadProgress = new ThreadPoolProgress(r => downloadCallback(r.TotalBytesDownloaded)); } InternalDownloadFile( @@ -934,7 +934,7 @@ public Task DownloadFileAsync(string path, Stream output, IProgress(r => downloadCallback(r.TotalBytesDownloaded)); + // The System.Progress ctor captures the current synchronization context + // and posts the progress reports to it. For back-compat with previous + // versions which always posted the callback to the threadpool regardless of + // sync context, we use a custom IProgress impl. + downloadProgress = new ThreadPoolProgress(r => downloadCallback(r.TotalBytesDownloaded)); } var asyncResult = new SftpDownloadAsyncResult(asyncCallback, state); @@ -1089,7 +1093,7 @@ public void UploadFile(Stream input, string path, bool canOverride, Action(r => uploadCallback(r.TotalBytesUploaded)); + uploadProgress = new ThreadPoolProgress(r => uploadCallback(r.TotalBytesUploaded)); } InternalUploadFile( @@ -1273,7 +1277,11 @@ public IAsyncResult BeginUploadFile(Stream input, string path, bool canOverride, if (uploadCallback != null) { - uploadProgress = new Progress(r => uploadCallback(r.TotalBytesUploaded)); + // The System.Progress ctor captures the current synchronization context + // and posts the progress reports to it. For back-compat with previous + // versions which always posted the callback to the threadpool regardless of + // sync context, we use a custom IProgress impl. + uploadProgress = new ThreadPoolProgress(r => uploadCallback(r.TotalBytesUploaded)); } var asyncResult = new SftpUploadAsyncResult(asyncCallback, state); @@ -2417,16 +2425,10 @@ private async Task InternalDownloadFile( asyncResult?.Update(totalBytesRead); - if (downloadProgress is not null) + downloadProgress?.Report(new DownloadFileProgressReport() { - // Copy offset to ensure it's not modified between now and execution of callback - var report = new DownloadFileProgressReport() - { - TotalBytesDownloaded = totalBytesRead, - }; - - downloadProgress.Report(report); - } + TotalBytesDownloaded = totalBytesRead + }); } } finally @@ -2536,16 +2538,10 @@ private async Task InternalUploadFile( asyncResult?.Update(writtenBytes); - // Call callback to report number of bytes written - if (uploadProgress is not null) + uploadProgress?.Report(new UploadFileProgressReport() { - UploadFileProgressReport report = new() - { - TotalBytesUploaded = writtenBytes, - }; - - uploadProgress.Report(report); - } + TotalBytesUploaded = writtenBytes + }); } finally { @@ -2652,5 +2648,29 @@ private ISftpSession CreateAndConnectToSftpSession() throw; } } + + /// + /// An implementation that posts callbacks to the threadpool. + /// + private sealed class ThreadPoolProgress : IProgress + { + private readonly Action _handler; + + public ThreadPoolProgress(Action handler) + { + Debug.Assert(handler != null); + _handler = handler!; + } + + void IProgress.Report(T value) + { + _ = ThreadPool.QueueUserWorkItem(static state => + { + var (handler, value) = ((Action, T))state!; + handler(value); + }, + (_handler, value)); + } + } } } diff --git a/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SftpClientTest.Upload.cs b/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SftpClientTest.Upload.cs index 40c296c7f..064487465 100644 --- a/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SftpClientTest.Upload.cs +++ b/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SftpClientTest.Upload.cs @@ -324,77 +324,104 @@ public void Test_Sftp_Ensure_Async_Delegates_Called_For_BeginFileUpload_BeginFil var remoteFileName = Path.GetRandomFileName(); var localFileName = Path.GetRandomFileName(); - var uploadDelegateCalled = false; - var downloadDelegateCalled = false; - var listDirectoryDelegateCalled = false; + using var uploadDelegateEvent = new ManualResetEventSlim(); + using var downloadDelegateEvent = new ManualResetEventSlim(); + using var listDirectoryDelegateEvent = new ManualResetEventSlim(); + using var uploadCallbackEvent = new ManualResetEventSlim(); + using var downloadCallbackEvent = new ManualResetEventSlim(); + using var listDirectoryCallbackEvent = new ManualResetEventSlim(); IAsyncResult asyncResult; // Test for BeginUploadFile. CreateTestFile(localFileName, 1); - using (var fileStream = File.OpenRead(localFileName)) + var originalContext = SynchronizationContext.Current; + try { - asyncResult = sftp.BeginUploadFile(fileStream, - remoteFileName, - delegate (IAsyncResult ar) - { - sftp.EndUploadFile(ar); - uploadDelegateCalled = true; - }, - null); - - while (!asyncResult.IsCompleted) + // Set a throwing context to verify it's not captured by the callback + SynchronizationContext.SetSynchronizationContext(new ThrowingSynchronizationContext()); + + using (var fileStream = File.OpenRead(localFileName)) { - Thread.Sleep(500); + asyncResult = sftp.BeginUploadFile(fileStream, + remoteFileName, + delegate (IAsyncResult ar) + { + uploadDelegateEvent.Set(); + }, + state: null, + uploadCallback: _ => uploadCallbackEvent.Set()); + + sftp.EndUploadFile(asyncResult); } } + finally + { + SynchronizationContext.SetSynchronizationContext(originalContext); + } File.Delete(localFileName); - Assert.IsTrue(uploadDelegateCalled, "BeginUploadFile"); + Assert.IsTrue(uploadDelegateEvent.Wait(1000)); + Assert.IsTrue(uploadCallbackEvent.Wait(1000)); // Test for BeginDownloadFile. asyncResult = null; - using (var fileStream = File.OpenWrite(localFileName)) + try { - asyncResult = sftp.BeginDownloadFile(remoteFileName, - fileStream, - delegate (IAsyncResult ar) - { - sftp.EndDownloadFile(ar); - downloadDelegateCalled = true; - }, - null); + // Set a throwing context to verify it's not captured by the callback + SynchronizationContext.SetSynchronizationContext(new ThrowingSynchronizationContext()); - while (!asyncResult.IsCompleted) + using (var fileStream = File.OpenWrite(localFileName)) { - Thread.Sleep(500); + asyncResult = sftp.BeginDownloadFile(remoteFileName, + fileStream, + delegate (IAsyncResult ar) + { + downloadDelegateEvent.Set(); + }, + state: null, + downloadCallback: _ => downloadCallbackEvent.Set()); + + sftp.EndDownloadFile(asyncResult); } } + finally + { + SynchronizationContext.SetSynchronizationContext(originalContext); + } File.Delete(localFileName); - Assert.IsTrue(downloadDelegateCalled, "BeginDownloadFile"); + Assert.IsTrue(downloadDelegateEvent.Wait(1000)); + Assert.IsTrue(downloadCallbackEvent.Wait(1000)); // Test for BeginListDirectory. - asyncResult = null; - asyncResult = sftp.BeginListDirectory(sftp.WorkingDirectory, - delegate (IAsyncResult ar) - { - _ = sftp.EndListDirectory(ar); - listDirectoryDelegateCalled = true; - }, - null); - - while (!asyncResult.IsCompleted) + try { - Thread.Sleep(500); + // Set a throwing context to verify it's not captured by the callback + SynchronizationContext.SetSynchronizationContext(new ThrowingSynchronizationContext()); + + asyncResult = sftp.BeginListDirectory(sftp.WorkingDirectory, + delegate (IAsyncResult ar) + { + listDirectoryDelegateEvent.Set(); + }, + state: null, + listCallback: _ => listDirectoryCallbackEvent.Set()); + + _ = sftp.EndListDirectory(asyncResult); + } + finally + { + SynchronizationContext.SetSynchronizationContext(originalContext); } - Assert.IsTrue(listDirectoryDelegateCalled, "BeginListDirectory"); + Assert.IsTrue(listDirectoryDelegateEvent.Wait(1000)); + Assert.IsTrue(listDirectoryCallbackEvent.Wait(1000)); } } @@ -482,5 +509,18 @@ public async Task Test_Sftp_UploadFileAsync_UploadProgress() Assert.IsTrue(callbackCalled); } } + + private sealed class ThrowingSynchronizationContext : SynchronizationContext + { + public override void Post(SendOrPostCallback d, object state) + { + throw new InvalidOperationException(); + } + + public override void Send(SendOrPostCallback d, object state) + { + throw new InvalidOperationException(); + } + } } }