Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 43 additions & 23 deletions src/Renci.SshNet/SftpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,7 @@ public void DownloadFile(string path, Stream output, Action<ulong>? downloadCall

if (downloadCallback != null)
{
downloadProgress = new Progress<DownloadFileProgressReport>(r => downloadCallback(r.TotalBytesDownloaded));
downloadProgress = new ThreadPoolProgress<DownloadFileProgressReport>(r => downloadCallback(r.TotalBytesDownloaded));
}

InternalDownloadFile(
Expand Down Expand Up @@ -934,7 +934,7 @@ public Task DownloadFileAsync(string path, Stream output, IProgress<DownloadFile
path,
output,
asyncResult: null,
downloadProgress: downloadProgress,
downloadProgress,
isAsync: true,
cancellationToken);
}
Expand Down Expand Up @@ -1011,7 +1011,11 @@ public IAsyncResult BeginDownloadFile(string path, Stream output, AsyncCallback?

if (downloadCallback != null)
{
downloadProgress = new Progress<DownloadFileProgressReport>(r => downloadCallback(r.TotalBytesDownloaded));
// The System.Progress<T> ctor captures the current synchronization context
// and posts the progress reports to it. For back-compat with previous
// versions which always posted the callback to the threadpool regardless of
// sync context, we use a custom IProgress<T> impl.
downloadProgress = new ThreadPoolProgress<DownloadFileProgressReport>(r => downloadCallback(r.TotalBytesDownloaded));
}

var asyncResult = new SftpDownloadAsyncResult(asyncCallback, state);
Expand Down Expand Up @@ -1089,7 +1093,7 @@ public void UploadFile(Stream input, string path, bool canOverride, Action<ulong

if (uploadCallback != null)
{
uploadProgress = new Progress<UploadFileProgressReport>(r => uploadCallback(r.TotalBytesUploaded));
uploadProgress = new ThreadPoolProgress<UploadFileProgressReport>(r => uploadCallback(r.TotalBytesUploaded));
}
Comment thread
Rob-Hague marked this conversation as resolved.

InternalUploadFile(
Expand Down Expand Up @@ -1273,7 +1277,11 @@ public IAsyncResult BeginUploadFile(Stream input, string path, bool canOverride,

if (uploadCallback != null)
{
uploadProgress = new Progress<UploadFileProgressReport>(r => uploadCallback(r.TotalBytesUploaded));
// The System.Progress<T> ctor captures the current synchronization context
// and posts the progress reports to it. For back-compat with previous
// versions which always posted the callback to the threadpool regardless of
// sync context, we use a custom IProgress<T> impl.
uploadProgress = new ThreadPoolProgress<UploadFileProgressReport>(r => uploadCallback(r.TotalBytesUploaded));
}

var asyncResult = new SftpUploadAsyncResult(asyncCallback, state);
Expand Down Expand Up @@ -2417,16 +2425,10 @@ private async Task InternalDownloadFile(

asyncResult?.Update(totalBytesRead);

if (downloadProgress is not null)
downloadProgress?.Report(new DownloadFileProgressReport()
{
// Copy offset to ensure it's not modified between now and execution of callback
var report = new DownloadFileProgressReport()
{
TotalBytesDownloaded = totalBytesRead,
};

downloadProgress.Report(report);
}
TotalBytesDownloaded = totalBytesRead
});
}
}
finally
Expand Down Expand Up @@ -2536,16 +2538,10 @@ private async Task InternalUploadFile(

asyncResult?.Update(writtenBytes);

// Call callback to report number of bytes written
if (uploadProgress is not null)
uploadProgress?.Report(new UploadFileProgressReport()
{
UploadFileProgressReport report = new()
{
TotalBytesUploaded = writtenBytes,
};

uploadProgress.Report(report);
}
TotalBytesUploaded = writtenBytes
});
}
finally
{
Expand Down Expand Up @@ -2652,5 +2648,29 @@ private ISftpSession CreateAndConnectToSftpSession()
throw;
}
}

/// <summary>
/// An <see cref="IProgress{T}"/> implementation that posts callbacks to the threadpool.
/// </summary>
private sealed class ThreadPoolProgress<T> : IProgress<T>
{
private readonly Action<T> _handler;

public ThreadPoolProgress(Action<T> handler)
{
Debug.Assert(handler != null);
_handler = handler!;
}
Comment thread
Rob-Hague marked this conversation as resolved.

void IProgress<T>.Report(T value)
{
_ = ThreadPool.QueueUserWorkItem(static state =>
{
var (handler, value) = ((Action<T>, T))state!;
handler(value);
},
(_handler, value));
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -324,77 +324,104 @@ public void Test_Sftp_Ensure_Async_Delegates_Called_For_BeginFileUpload_BeginFil

var remoteFileName = Path.GetRandomFileName();
var localFileName = Path.GetRandomFileName();
var uploadDelegateCalled = false;
var downloadDelegateCalled = false;
var listDirectoryDelegateCalled = false;
using var uploadDelegateEvent = new ManualResetEventSlim();
using var downloadDelegateEvent = new ManualResetEventSlim();
using var listDirectoryDelegateEvent = new ManualResetEventSlim();
using var uploadCallbackEvent = new ManualResetEventSlim();
using var downloadCallbackEvent = new ManualResetEventSlim();
using var listDirectoryCallbackEvent = new ManualResetEventSlim();
IAsyncResult asyncResult;

// Test for BeginUploadFile.

CreateTestFile(localFileName, 1);

using (var fileStream = File.OpenRead(localFileName))
var originalContext = SynchronizationContext.Current;
try
{
asyncResult = sftp.BeginUploadFile(fileStream,
remoteFileName,
delegate (IAsyncResult ar)
{
sftp.EndUploadFile(ar);
uploadDelegateCalled = true;
},
null);

while (!asyncResult.IsCompleted)
// Set a throwing context to verify it's not captured by the callback
SynchronizationContext.SetSynchronizationContext(new ThrowingSynchronizationContext());

using (var fileStream = File.OpenRead(localFileName))
{
Thread.Sleep(500);
asyncResult = sftp.BeginUploadFile(fileStream,
remoteFileName,
delegate (IAsyncResult ar)
{
uploadDelegateEvent.Set();
},
state: null,
uploadCallback: _ => uploadCallbackEvent.Set());

sftp.EndUploadFile(asyncResult);
}
}
finally
{
SynchronizationContext.SetSynchronizationContext(originalContext);
}

File.Delete(localFileName);

Assert.IsTrue(uploadDelegateCalled, "BeginUploadFile");
Assert.IsTrue(uploadDelegateEvent.Wait(1000));
Assert.IsTrue(uploadCallbackEvent.Wait(1000));
Comment thread
Rob-Hague marked this conversation as resolved.

// Test for BeginDownloadFile.

asyncResult = null;
using (var fileStream = File.OpenWrite(localFileName))
try
{
asyncResult = sftp.BeginDownloadFile(remoteFileName,
fileStream,
delegate (IAsyncResult ar)
{
sftp.EndDownloadFile(ar);
downloadDelegateCalled = true;
},
null);
// Set a throwing context to verify it's not captured by the callback
SynchronizationContext.SetSynchronizationContext(new ThrowingSynchronizationContext());

while (!asyncResult.IsCompleted)
using (var fileStream = File.OpenWrite(localFileName))
{
Thread.Sleep(500);
asyncResult = sftp.BeginDownloadFile(remoteFileName,
fileStream,
delegate (IAsyncResult ar)
{
downloadDelegateEvent.Set();
},
state: null,
downloadCallback: _ => downloadCallbackEvent.Set());

sftp.EndDownloadFile(asyncResult);
}
}
finally
{
SynchronizationContext.SetSynchronizationContext(originalContext);
}

File.Delete(localFileName);

Assert.IsTrue(downloadDelegateCalled, "BeginDownloadFile");
Assert.IsTrue(downloadDelegateEvent.Wait(1000));
Assert.IsTrue(downloadCallbackEvent.Wait(1000));
Comment thread
Rob-Hague marked this conversation as resolved.

// Test for BeginListDirectory.

asyncResult = null;
asyncResult = sftp.BeginListDirectory(sftp.WorkingDirectory,
delegate (IAsyncResult ar)
{
_ = sftp.EndListDirectory(ar);
listDirectoryDelegateCalled = true;
},
null);

while (!asyncResult.IsCompleted)
try
{
Thread.Sleep(500);
// Set a throwing context to verify it's not captured by the callback
SynchronizationContext.SetSynchronizationContext(new ThrowingSynchronizationContext());

asyncResult = sftp.BeginListDirectory(sftp.WorkingDirectory,
delegate (IAsyncResult ar)
{
listDirectoryDelegateEvent.Set();
},
state: null,
listCallback: _ => listDirectoryCallbackEvent.Set());

_ = sftp.EndListDirectory(asyncResult);
}
finally
{
SynchronizationContext.SetSynchronizationContext(originalContext);
}

Assert.IsTrue(listDirectoryDelegateCalled, "BeginListDirectory");
Assert.IsTrue(listDirectoryDelegateEvent.Wait(1000));
Assert.IsTrue(listDirectoryCallbackEvent.Wait(1000));
Comment thread
Rob-Hague marked this conversation as resolved.
}
}

Expand Down Expand Up @@ -482,5 +509,18 @@ public async Task Test_Sftp_UploadFileAsync_UploadProgress()
Assert.IsTrue(callbackCalled);
}
}

private sealed class ThrowingSynchronizationContext : SynchronizationContext
{
public override void Post(SendOrPostCallback d, object state)
{
throw new InvalidOperationException();
}
Comment thread
Rob-Hague marked this conversation as resolved.

public override void Send(SendOrPostCallback d, object state)
{
throw new InvalidOperationException();
}
Comment thread
Rob-Hague marked this conversation as resolved.
}
}
}