Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 60 additions & 17 deletions dotnet/src/Microsoft.Agents.AI.Workflows/HandoffWorkflowBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
using Microsoft.Extensions.AI;
using Microsoft.Shared.Diagnostics;

using ExecutorFactoryFunc = System.Func<Microsoft.Agents.AI.Workflows.ExecutorConfig<Microsoft.Agents.AI.Workflows.ExecutorOptions>,
string,
System.Threading.Tasks.ValueTask<Microsoft.Agents.AI.Workflows.Specialized.HandoffAgentExecutor>>;

namespace Microsoft.Agents.AI.Workflows;

internal static class DiagnosticConstants
Expand Down Expand Up @@ -233,24 +237,70 @@ public TBuilder WithHandoff(AIAgent from, AIAgent to, string? handoffReason = nu
return (TBuilder)this;
}

private Dictionary<string, ExecutorBinding> CreateExecutorBindings(WorkflowBuilder builder)
{
HandoffAgentExecutorOptions options = new(this.HandoffInstructions,
this._emitAgentResponseEvents,
this._emitAgentResponseUpdateEvents,
this._toolCallFilteringBehavior);

// There are two types of ids being used in this method, and it is critical that we are clear about
// which one we are using, and where.
// AgentId...: comes from AIAgent.Id, is often an unreadable machine identifier (e.g. a Guid), and is used to address
// the handoffs
// ExecutorId: uses AIAgent.GetDescriptiveId() to use a friendlier name in telemetry, and is used for ExecutorBinding,
// which are subsequently used in building the workflow

// The outgoing dictionary maps from AgentId => ExecutorBinding
return this._allAgents.ToDictionary(keySelector: a => a.Id, elementSelector: CreateFactoryBinding);

ExecutorBinding CreateFactoryBinding(AIAgent agent)
{
if (!this._targets.TryGetValue(agent, out HashSet<HandoffTarget>? handoffs))
{
handoffs = new();
}

// Use the ExecutorId as the placeholder id for a (possibly) future-bound factory
builder.AddSwitch(HandoffAgentExecutor.IdFor(agent), (SwitchBuilder sb) =>
{
foreach (HandoffTarget handoff in handoffs)
{
sb.AddCase<HandoffState>(state => state?.RequestedHandoffTargetAgentId == handoff.Target.Id, // Use AgentId for target matching
HandoffAgentExecutor.IdFor(handoff.Target)); // Use ExecutorId in for routing at the workflow level
}

sb.WithDefault(HandoffEndExecutor.ExecutorId);
});

ExecutorFactoryFunc factory =
(config, sessionId) => new(
new HandoffAgentExecutor(agent,
handoffs,
options));

// Make sure to use ExecutorId when binding the executor, not AgentId
ExecutorBinding binding = factory.BindExecutor(HandoffAgentExecutor.IdFor(agent));

builder.BindExecutor(binding);

return binding;
}
}

/// <summary>
/// Builds a <see cref="Workflow"/> composed of agents that operate via handoffs, with the next
/// agent to process messages selected by the current agent.
/// </summary>
/// <returns>The workflow built based on the handoffs in the builder.</returns>
public Workflow Build()
{
HandoffsStartExecutor start = new(this._returnToPrevious);
HandoffsEndExecutor end = new(this._returnToPrevious);
HandoffStartExecutor start = new(this._returnToPrevious);
HandoffEndExecutor end = new(this._returnToPrevious);
WorkflowBuilder builder = new(start);

HandoffAgentExecutorOptions options = new(this.HandoffInstructions,
this._emitAgentResponseEvents,
this._emitAgentResponseUpdateEvents,
this._toolCallFilteringBehavior);

// Create an AgentExecutor for each agent.
Dictionary<string, HandoffAgentExecutor> executors = this._allAgents.ToDictionary(a => a.Id, a => new HandoffAgentExecutor(a, options));
// Create an factory-based ExecutorBinding for each agent.
Dictionary<string, ExecutorBinding> executors = this.CreateExecutorBindings(builder);

// Connect the start executor to the initial agent (or use dynamic routing when ReturnToPrevious is enabled).
if (this._returnToPrevious)
Expand All @@ -263,7 +313,7 @@ public Workflow Build()
if (agent.Id != initialAgentId)
{
string agentId = agent.Id;
sb.AddCase<HandoffState>(state => state?.CurrentAgentId == agentId, executors[agentId]);
sb.AddCase<HandoffState>(state => state?.PreviousAgentId == agentId, executors[agentId]);
}
}

Expand All @@ -275,13 +325,6 @@ public Workflow Build()
builder.AddEdge(start, executors[this._initialAgent.Id]);
}

// Initialize each executor with its handoff targets to the other executors.
foreach (var agent in this._allAgents)
{
executors[agent.Id].Initialize(builder, end, executors,
this._targets.TryGetValue(agent, out HashSet<HandoffTarget>? targets) ? targets : []);
}

// Build the workflow.
return builder.WithOutputFrom(end).Build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ public static bool ShouldEmitStreamingEvents(this TurnToken token, bool? agentSe

public static bool ShouldEmitStreamingEvents(bool? turnTokenSetting, bool? agentSetting)
=> turnTokenSetting ?? agentSetting ?? false;

public static bool ShouldEmitStreamingEvents(this HandoffState handoffState, bool? agentSetting)
=> handoffState.TurnToken.ShouldEmitStreamingEvents(agentSetting);
}

internal sealed class AIAgentHostExecutor : ChatProtocolExecutor
Expand Down Expand Up @@ -81,7 +84,11 @@ private ValueTask HandleUserInputResponseAsync(
// resumes can be processed in one invocation.
return this.ProcessTurnMessagesAsync(async (pendingMessages, ctx, ct) =>
{
pendingMessages.Add(new ChatMessage(ChatRole.User, [response]));
pendingMessages.Add(new ChatMessage(ChatRole.User, [response])
{
CreatedAt = DateTimeOffset.UtcNow,
MessageId = Guid.NewGuid().ToString("N"),
});

await this.ContinueTurnAsync(pendingMessages, ctx, this._currentTurnEmitEvents ?? false, ct).ConfigureAwait(false);

Expand All @@ -104,7 +111,12 @@ private ValueTask HandleFunctionResultAsync(
// resumes can be processed in one invocation.
return this.ProcessTurnMessagesAsync(async (pendingMessages, ctx, ct) =>
{
pendingMessages.Add(new ChatMessage(ChatRole.Tool, [result]));
pendingMessages.Add(new ChatMessage(ChatRole.Tool, [result])
{
AuthorName = this._agent.Name ?? this._agent.Id,
CreatedAt = DateTimeOffset.UtcNow,
MessageId = Guid.NewGuid().ToString("N"),
});

await this.ContinueTurnAsync(pendingMessages, ctx, this._currentTurnEmitEvents ?? false, ct).ConfigureAwait(false);

Expand Down Expand Up @@ -186,16 +198,13 @@ protected override ValueTask TakeTurnAsync(List<ChatMessage> messages, IWorkflow
TurnExtensions.ShouldEmitStreamingEvents(turnTokenSetting: emitEvents, this._options.EmitAgentUpdateEvents),
cancellationToken);

private async ValueTask<AgentResponse> InvokeAgentAsync(IEnumerable<ChatMessage> messages, IWorkflowContext context, bool emitEvents, CancellationToken cancellationToken = default)
private async ValueTask<AgentResponse> InvokeAgentAsync(IEnumerable<ChatMessage> messages, IWorkflowContext context, bool emitUpdateEvents, CancellationToken cancellationToken = default)
{
#pragma warning disable MEAI001
Dictionary<string, ToolApprovalRequestContent> userInputRequests = new();
Dictionary<string, FunctionCallContent> functionCalls = new();
AgentResponse response;
AIAgentUnservicedRequestsCollector collector = new(this._userInputHandler, this._functionCallHandler);

if (emitEvents)
if (emitUpdateEvents)
{
#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
// Run the agent in streaming mode only when agent run update events are to be emitted.
IAsyncEnumerable<AgentResponseUpdate> agentStream = this._agent.RunStreamingAsync(
messages,
Expand All @@ -206,7 +215,7 @@ await this.EnsureSessionAsync(context, cancellationToken).ConfigureAwait(false),
await foreach (AgentResponseUpdate update in agentStream.ConfigureAwait(false))
{
await context.YieldOutputAsync(update, cancellationToken).ConfigureAwait(false);
ExtractUnservicedRequests(update.Contents);
collector.ProcessAgentResponseUpdate(update);
updates.Add(update);
}

Expand All @@ -220,53 +229,16 @@ await this.EnsureSessionAsync(context, cancellationToken).ConfigureAwait(false),
cancellationToken: cancellationToken)
.ConfigureAwait(false);

ExtractUnservicedRequests(response.Messages.SelectMany(message => message.Contents));
collector.ProcessAgentResponse(response);
}

if (this._options.EmitAgentResponseEvents)
{
await context.YieldOutputAsync(response, cancellationToken).ConfigureAwait(false);
}

if (userInputRequests.Count > 0 || functionCalls.Count > 0)
{
Task userInputTask = this._userInputHandler?.ProcessRequestContentsAsync(userInputRequests, context, cancellationToken) ?? Task.CompletedTask;
Task functionCallTask = this._functionCallHandler?.ProcessRequestContentsAsync(functionCalls, context, cancellationToken) ?? Task.CompletedTask;

await Task.WhenAll(userInputTask, functionCallTask)
.ConfigureAwait(false);
}
await collector.SubmitAsync(context, cancellationToken).ConfigureAwait(false);

return response;

void ExtractUnservicedRequests(IEnumerable<AIContent> contents)
{
foreach (AIContent content in contents)
{
if (content is ToolApprovalRequestContent userInputRequest)
{
// It is an error to simultaneously have multiple outstanding user input requests with the same ID.
userInputRequests.Add(userInputRequest.RequestId, userInputRequest);
}
else if (content is ToolApprovalResponseContent userInputResponse)
{
// If the set of messages somehow already has a corresponding user input response, remove it.
_ = userInputRequests.Remove(userInputResponse.RequestId);
}
else if (content is FunctionCallContent functionCall)
{
// For function calls, we emit an event to notify the workflow.
//
// possibility 1: this will be handled inline by the agent abstraction
// possibility 2: this will not be handled inline by the agent abstraction
functionCalls.Add(functionCall.CallId, functionCall);
}
else if (content is FunctionResultContent functionResult)
{
_ = functionCalls.Remove(functionResult.CallId);
}
}
}
#pragma warning restore MEAI001
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;

namespace Microsoft.Agents.AI.Workflows.Specialized;

internal sealed class AIAgentUnservicedRequestsCollector(AIContentExternalHandler<ToolApprovalRequestContent, ToolApprovalResponseContent>? userInputHandler,
AIContentExternalHandler<FunctionCallContent, FunctionResultContent>? functionCallHandler)
{
private readonly Dictionary<string, ToolApprovalRequestContent> _userInputRequests = [];
private readonly Dictionary<string, FunctionCallContent> _functionCalls = [];

public Task SubmitAsync(IWorkflowContext context, CancellationToken cancellationToken)
{
Task userInputTask = userInputHandler != null && this._userInputRequests.Count > 0
? userInputHandler.ProcessRequestContentsAsync(this._userInputRequests, context, cancellationToken)
: Task.CompletedTask;

Task functionCallTask = functionCallHandler != null && this._functionCalls.Count > 0
? functionCallHandler.ProcessRequestContentsAsync(this._functionCalls, context, cancellationToken)
: Task.CompletedTask;

return Task.WhenAll(userInputTask, functionCallTask);
}

public void ProcessAgentResponseUpdate(AgentResponseUpdate update, Func<FunctionCallContent, bool>? functionCallFilter = null)
=> this.ProcessAIContents(update.Contents, functionCallFilter);

public void ProcessAgentResponse(AgentResponse response)
=> this.ProcessAIContents(response.Messages.SelectMany(message => message.Contents));

public void ProcessAIContents(IEnumerable<AIContent> contents, Func<FunctionCallContent, bool>? functionCallFilter = null)
{
foreach (AIContent content in contents)
{
if (content is ToolApprovalRequestContent userInputRequest)
{
if (this._userInputRequests.ContainsKey(userInputRequest.RequestId))
{
throw new InvalidOperationException($"ToolApprovalRequestContent with duplicate RequestId: ${userInputRequest.RequestId}");
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
throw new InvalidOperationException($"ToolApprovalRequestContent with duplicate RequestId: ${userInputRequest.RequestId}");
throw new InvalidOperationException($"ToolApprovalRequestContent with duplicate RequestId: {userInputRequest.RequestId}");

}

// It is an error to simultaneously have multiple outstanding user input requests with the same ID.
this._userInputRequests.Add(userInputRequest.RequestId, userInputRequest);
}
else if (content is ToolApprovalResponseContent userInputResponse)
{
// If the set of messages somehow already has a corresponding user input response, remove it.
_ = this._userInputRequests.Remove(userInputResponse.RequestId);
}
else if (content is FunctionCallContent functionCall)
{
// For function calls, we emit an event to notify the workflow.
//
// possibility 1: this will be handled inline by the agent abstraction
// possibility 2: this will not be handled inline by the agent abstraction
if (functionCallFilter == null || functionCallFilter(functionCall))
{
if (this._functionCalls.ContainsKey(functionCall.CallId))
{
throw new InvalidOperationException($"FunctionCallContent with duplicate CallId: ${functionCall.CallId}");
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
throw new InvalidOperationException($"FunctionCallContent with duplicate CallId: ${functionCall.CallId}");
throw new InvalidOperationException($"FunctionCallContent with duplicate CallId: {functionCall.CallId}");

}

this._functionCalls.Add(functionCall.CallId, functionCall);
}
}
else if (content is FunctionResultContent functionResult)
{
_ = this._functionCalls.Remove(functionResult.CallId);
}
}
}
}
Loading
Loading