Skip to content

Commit 80c1e2e

Browse files
authored
Allow InProcessRuntime subscriptions to be processed concurrently (#179)
* Allow InProcessRuntime subscriptions to be processed concurrently * Fix tests
1 parent 18299bc commit 80c1e2e

2 files changed

Lines changed: 11 additions & 20 deletions

File tree

dotnet/src/Microsoft.Extensions.AI.Agents.Runtime.Abstractions/InProcessRuntime.cs

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -262,21 +262,22 @@ async Task WaitAndRemoveAsync(long taskId, ValueTask processTask)
262262
{
263263
Debug.Assert(message.Topic.HasValue);
264264

265-
List<Exception>? exceptions = null;
265+
List<Task>? tasks = null;
266266
TopicId topic = message.Topic!.Value;
267267
foreach (KeyValuePair<string, ISubscriptionDefinition> subscription in message.Runtime._subscriptions)
268268
{
269-
if (!subscription.Value.Matches(topic))
269+
if (subscription.Value.Matches(topic))
270270
{
271-
continue;
271+
(tasks ??= []).Add(ProcessSubscriptionAsync(message, subscription.Value, topic, cancellationToken));
272272
}
273273

274-
try
274+
static async Task ProcessSubscriptionAsync(
275+
MessageToProcess message, ISubscriptionDefinition subscription, TopicId topic, CancellationToken cancellationToken)
275276
{
276277
using CancellationTokenSource combinedSource = CancellationTokenSource.CreateLinkedTokenSource(message.Cancellation, cancellationToken);
277278
combinedSource.Token.ThrowIfCancellationRequested();
278279

279-
ActorId actorId = subscription.Value.MapToActor(topic);
280+
ActorId actorId = subscription.MapToActor(topic);
280281
ActorId? sender = message.Sender;
281282
if (sender is null || sender != actorId)
282283
{
@@ -289,15 +290,11 @@ async Task WaitAndRemoveAsync(long taskId, ValueTask processTask)
289290
}, combinedSource.Token).ConfigureAwait(false);
290291
}
291292
}
292-
catch (Exception ex)
293-
{
294-
(exceptions ??= []).Add(ex);
295-
}
296293
}
297294

298-
if (exceptions is not null)
295+
if (tasks is not null)
299296
{
300-
throw new AggregateException("One or more exceptions occurred while processing the message.", exceptions);
297+
await Task.WhenAll(tasks).ConfigureAwait(false);
301298
}
302299

303300
// This method is effectively void, with the result never being used. But it's typed the same as SendMessageServicerAsync

dotnet/tests/Microsoft.Extensions.AI.Agents.Runtime.Abstractions.UnitTests/PublishMessageTests.cs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

3-
using System;
43
using System.Threading.Tasks;
54

65
namespace Microsoft.Extensions.AI.Agents.Runtime.InProcess.Tests;
@@ -35,8 +34,7 @@ public async Task Test_PublishMessage_SingleFailureAsync()
3534
await fixture.RegisterErrorAgentAsync(topicTypes: "TestTopic");
3635

3736
// Test that we wrap single errors appropriately
38-
var e = await Assert.ThrowsAsync<AggregateException>(async () => await fixture.RunPublishTestAsync(new TopicId("TestTopic"), new BasicMessage { Content = "1" }));
39-
Assert.IsType<TestException>(Assert.Single(e.InnerExceptions));
37+
await Assert.ThrowsAsync<TestException>(async () => await fixture.RunPublishTestAsync(new TopicId("TestTopic"), new BasicMessage { Content = "1" }));
4038

4139
var values = fixture.GetAgentInstances<ReceiverAgent>().Values;
4240
}
@@ -50,9 +48,7 @@ public async Task Test_PublishMessage_MultipleFailuresAsync()
5048
await fixture.RegisterErrorAgentAsync("2", topicTypes: "TestTopic");
5149

5250
// What we are really testing here is that a single exception does not prevent sending to the remaining agents
53-
var e = await Assert.ThrowsAsync<AggregateException>(async () => await fixture.RunPublishTestAsync(new TopicId("TestTopic"), new BasicMessage { Content = "1" }));
54-
Assert.Equal(2, e.InnerExceptions.Count);
55-
Assert.All(e.InnerExceptions, innerException => Assert.IsType<TestException>(innerException));
51+
await Assert.ThrowsAsync<TestException>(async () => await fixture.RunPublishTestAsync(new TopicId("TestTopic"), new BasicMessage { Content = "1" }));
5652

5753
var values = fixture.GetAgentInstances<ErrorAgent>().Values;
5854
Assert.Equal(2, values.Count);
@@ -70,9 +66,7 @@ public async Task Test_PublishMessage_MixedSuccessFailureAsync()
7066
await fixture.RegisterErrorAgentAsync("2", topicTypes: "TestTopic");
7167

7268
// What we are really testing here is that raising exceptions does not prevent sending to the remaining agents
73-
var e = await Assert.ThrowsAsync<AggregateException>(async () => await fixture.RunPublishTestAsync(new TopicId("TestTopic"), new BasicMessage { Content = "1" }));
74-
Assert.Equal(2, e.InnerExceptions.Count);
75-
Assert.All(e.InnerExceptions, innerException => Assert.IsType<TestException>(innerException));
69+
await Assert.ThrowsAsync<TestException>(async () => await fixture.RunPublishTestAsync(new TopicId("TestTopic"), new BasicMessage { Content = "1" }));
7670

7771
var agents = fixture.GetAgentInstances<ReceiverAgent>().Values;
7872
Assert.Equal(2, agents.Count);

0 commit comments

Comments
 (0)