Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Call EnsureLegalAccess from EntityFeature in dotnet-isolated #2633

Merged
merged 4 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions release_notes.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Release Notes

## Microsoft.Azure.Functions.Worker.Extensions.DurableTask v1.1.0-preview.1
jviau marked this conversation as resolved.
Show resolved Hide resolved

### New Features

- Updates to take advantage of new core-entity support
Expand All @@ -8,6 +10,7 @@
### Bug Fixes

- Address input issues when using .NET isolated (#2581)[https://github.com/Azure/azure-functions-durable-extension/issues/2581]
- No longer fail orchestrations which return before accessing the `TaskOrchestrationContext`.

### Breaking Changes

Expand Down
29 changes: 9 additions & 20 deletions src/Worker.Extensions.DurableTask/FunctionsOrchestrationContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ public FunctionsOrchestrationContext(TaskOrchestrationContext innerContext, Func
this.LoggerFactory = functionContext.InstanceServices.GetRequiredService<ILoggerFactory>();
}

public bool IsAccessed { get; private set; }

public override TaskName Name => this.innerContext.Name;

public override string InstanceId => this.innerContext.InstanceId;
Expand All @@ -53,7 +51,7 @@ public FunctionsOrchestrationContext(TaskOrchestrationContext innerContext, Func

public override T GetInput<T>()
{
this.EnsureLegalAccess();
this.ThrowIfIllegalAccess();

object? input = this.innerContext.GetInput<object>();
if (input is T typed)
Expand All @@ -71,60 +69,51 @@ public override T GetInput<T>()

public override Guid NewGuid()
{
this.EnsureLegalAccess();
this.ThrowIfIllegalAccess();
return this.innerContext.NewGuid();
}

public override Task<T> CallActivityAsync<T>(TaskName name, object? input = null, TaskOptions? options = null)
{
this.EnsureLegalAccess();
this.ThrowIfIllegalAccess();
return this.innerContext.CallActivityAsync<T>(name, input, options);
}

public override Task<TResult> CallSubOrchestratorAsync<TResult>(
TaskName orchestratorName, object? input = null, TaskOptions? options = null)
{
this.EnsureLegalAccess();
this.ThrowIfIllegalAccess();
return this.innerContext.CallSubOrchestratorAsync<TResult>(orchestratorName, input, options);
}

public override void ContinueAsNew(object? newInput = null, bool preserveUnprocessedEvents = true)
{
this.EnsureLegalAccess();
this.ThrowIfIllegalAccess();
this.innerContext.ContinueAsNew(newInput, preserveUnprocessedEvents);
}

public override Task CreateTimer(DateTime fireAt, CancellationToken cancellationToken)
{
this.EnsureLegalAccess();
this.ThrowIfIllegalAccess();
return this.innerContext.CreateTimer(fireAt, cancellationToken);
}

public override void SetCustomStatus(object? customStatus)
{
this.EnsureLegalAccess();
this.ThrowIfIllegalAccess();
this.innerContext.SetCustomStatus(customStatus);
}

public override void SendEvent(string instanceId, string eventName, object payload)
{
this.EnsureLegalAccess();
this.ThrowIfIllegalAccess();
this.innerContext.SendEvent(instanceId, eventName, payload);
}

public override Task<T> WaitForExternalEvent<T>(string eventName, CancellationToken cancellationToken = default)
{
this.EnsureLegalAccess();
return this.innerContext.WaitForExternalEvent<T>(eventName, cancellationToken);
}

/// <summary>
/// Throws if accessed by a non-orchestrator thread or marks the current object as accessed successfully.
/// </summary>
private void EnsureLegalAccess()
{
this.ThrowIfIllegalAccess();
this.IsAccessed = true;
return this.innerContext.WaitForExternalEvent<T>(eventName, cancellationToken);
jviau marked this conversation as resolved.
Show resolved Hide resolved
}

internal void ThrowIfIllegalAccess()
Expand Down
8 changes: 0 additions & 8 deletions src/Worker.Extensions.DurableTask/FunctionsOrchestrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,6 @@ private static async Task EnsureSynchronousExecution(
FunctionsOrchestrationContext orchestrationContext)
{
Task orchestratorTask = next(functionContext);
if (!orchestratorTask.IsCompleted && !orchestrationContext.IsAccessed)
jviau marked this conversation as resolved.
Show resolved Hide resolved
{
// If the middleware returns before the orchestrator function's context object was accessed and before
// it completes its execution, then we know that either some middleware component went async or that the
// orchestrator function did some illegal await as its very first action.
throw new InvalidOperationException(Constants.IllegalAwaitErrorMessage);
}

await orchestratorTask;

// This will throw if either the orchestrator performed an illegal await or if some middleware ahead of this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public ValueTask<ConversionResult> ConvertAsync(ConverterContext context)
// 3. The TargetType matches our cached type.
// If these are met, then we assume this parameter is the orchestration input.
if (context.Source is null
&& context.FunctionContext.Items.TryGetValue(OrchestrationInputKey, out object value)
&& context.FunctionContext.Items.TryGetValue(OrchestrationInputKey, out object? value)
&& context.TargetType == value?.GetType())
{
// Remove this from the items so we bind this only once.
Expand Down
2 changes: 1 addition & 1 deletion src/Worker.Extensions.DurableTask/TaskEntityDispatcher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ private class StateEntity<T> : TaskEntity<T>

private class DelegateEntity : ITaskEntity
{
readonly Func<TaskEntityOperation, ValueTask<object?>> handler;
private readonly Func<TaskEntityOperation, ValueTask<object?>> handler;

public DelegateEntity(Func<TaskEntityOperation, ValueTask<object?>> handler)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

<!-- Version information -->
<VersionPrefix>1.1.0</VersionPrefix>
<VersionSuffix>entities-preview.2</VersionSuffix>
<VersionSuffix>entities-preview.3</VersionSuffix>
<AssemblyVersion>$(VersionPrefix).0</AssemblyVersion>
<!-- FileVersionRevision is expected to be set by the CI. -->
<FileVersion Condition="'$(FileVersionRevision)' != ''">$(VersionPrefix).$(FileVersionRevision)</FileVersion>
Expand Down
11 changes: 9 additions & 2 deletions test/SmokeTests/OOProcSmokeTests/DotNetIsolated/Counter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,22 @@ public static async Task<HttpResponseData> ReadCounter(

logger.LogInformation($"Reading state of {entityId}...");
var response = await client.Entities.GetEntityAsync(entityId, includeState: true);
logger.LogInformation($"Read state of {entityId}: {response?.SerializedState ?? "(null: entity does not exist)"}");
if (response?.IncludesState ?? false)
{
logger.LogInformation("Entity does not exist.");
}
else
{
logger.LogInformation("Entity state is: {State}", response!.State.Value);
}

if (response == null)
{
return request.CreateResponse(System.Net.HttpStatusCode.NotFound);
}
else
{
int currentValue = response.ReadStateAs<Counter>()!.CurrentValue;
int currentValue = response.State.ReadAs<Counter>()!.CurrentValue;
var httpResponse = request.CreateResponse(System.Net.HttpStatusCode.OK);
httpResponse.Headers.Add("Content-Type", "text/plain; charset=utf-8");
httpResponse.WriteString($"{currentValue}\n");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net6.0</TargetFramework>
<AzureFunctionsVersion>v4</AzureFunctionsVersion>
Expand All @@ -11,7 +11,7 @@
<ItemGroup>
<PackageReference Include="Microsoft.Azure.Functions.Worker" Version="1.19.0" />
<PackageReference Include="Microsoft.Azure.Functions.Worker.Extensions.Http" Version="3.0.13" />
<PackageReference Include="Microsoft.Azure.Functions.Worker.Sdk" Version="1.14.1" OutputItemType="Analyzer" />
<PackageReference Include="Microsoft.Azure.Functions.Worker.Sdk" Version="1.16.0-preview2" OutputItemType="Analyzer" />
<PackageReference Include="Microsoft.DurableTask.Generators" Version="1.0.0-preview.1" OutputItemType="Analyzer" />
</ItemGroup>

Expand Down