Skip to content

[Blazor] Add ability to filter persistent component state callbacks based on persistence reason #62394

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/Components/Components/src/IPersistenceReason.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Microsoft.AspNetCore.Components;

/// <summary>
/// Represents a reason for persisting component state.
/// </summary>
public interface IPersistenceReason
{
/// <summary>
/// Gets a value indicating whether state should be persisted by default for this reason.
/// </summary>
bool PersistByDefault { get; }
}
17 changes: 17 additions & 0 deletions src/Components/Components/src/IPersistenceReasonFilter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Microsoft.AspNetCore.Components;

/// <summary>
/// Filters component state persistence based on the reason for persistence.
/// </summary>
public interface IPersistenceReasonFilter
{
/// <summary>
/// Determines whether state should be persisted for the given reason.
/// </summary>
/// <param name="reason">The reason for persistence.</param>
/// <returns><c>true</c> to persist state, <c>false</c> to skip persistence, or <c>null</c> to defer to other filters or default behavior.</returns>
bool? ShouldPersist(IPersistenceReason reason);
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ namespace Microsoft.AspNetCore.Components;

internal readonly struct PersistComponentStateRegistration(
Func<Task> callback,
IComponentRenderMode? renderMode)
IComponentRenderMode? renderMode,
IReadOnlyList<IPersistenceReasonFilter> reasonFilters)
{
public Func<Task> Callback { get; } = callback;

public IComponentRenderMode? RenderMode { get; } = renderMode;

public IReadOnlyList<IPersistenceReasonFilter> ReasonFilters { get; } = reasonFilters ?? Array.Empty<IPersistenceReasonFilter>();
}
34 changes: 34 additions & 0 deletions src/Components/Components/src/PersistReasonFilter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Microsoft.AspNetCore.Components;

/// <summary>
/// Base class for filtering component state persistence based on specific persistence reasons.
/// </summary>
/// <typeparam name="TReason">The type of persistence reason this filter handles.</typeparam>
public abstract class PersistReasonFilter<TReason> : Attribute, IPersistenceReasonFilter
where TReason : IPersistenceReason
{
private readonly bool _persist;

/// <summary>
/// Initializes a new instance of the <see cref="PersistReasonFilter{TReason}"/> class.
/// </summary>
/// <param name="persist">Whether to persist state for the specified reason type.</param>
protected PersistReasonFilter(bool persist)
{
_persist = persist;
}

/// <inheritdoc />
public bool? ShouldPersist(IPersistenceReason reason)
{
if (reason is TReason)
{
return _persist;
}

return null;
}
}
28 changes: 26 additions & 2 deletions src/Components/Components/src/PersistentComponentState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,31 @@ internal void InitializeExistingState(IDictionary<string, byte[]> existingState)
/// <param name="callback">The callback to invoke when the application is being paused.</param>
/// <returns>A subscription that can be used to unregister the callback when disposed.</returns>
public PersistingComponentStateSubscription RegisterOnPersisting(Func<Task> callback)
=> RegisterOnPersisting(callback, null);
=> RegisterOnPersisting(callback, null, Array.Empty<IPersistenceReasonFilter>());

/// <summary>
/// Register a callback to persist the component state when the application is about to be paused.
/// Registered callbacks can use this opportunity to persist their state so that it can be retrieved when the application resumes.
/// </summary>
/// <param name="callback">The callback to invoke when the application is being paused.</param>
/// <param name="renderMode"></param>
/// <param name="reasonFilters">Filters to control when the callback should be invoked based on the persistence reason.</param>
/// <returns>A subscription that can be used to unregister the callback when disposed.</returns>
public PersistingComponentStateSubscription RegisterOnPersisting(Func<Task> callback, IComponentRenderMode? renderMode, IReadOnlyList<IPersistenceReasonFilter> reasonFilters)
{
ArgumentNullException.ThrowIfNull(callback);

if (PersistingState)
{
throw new InvalidOperationException("Registering a callback while persisting state is not allowed.");
}

var persistenceCallback = new PersistComponentStateRegistration(callback, renderMode, reasonFilters);

_registeredCallbacks.Add(persistenceCallback);

return new PersistingComponentStateSubscription(_registeredCallbacks, persistenceCallback);
}

/// <summary>
/// Register a callback to persist the component state when the application is about to be paused.
Expand All @@ -61,7 +85,7 @@ public PersistingComponentStateSubscription RegisterOnPersisting(Func<Task> call
throw new InvalidOperationException("Registering a callback while persisting state is not allowed.");
}

var persistenceCallback = new PersistComponentStateRegistration(callback, renderMode);
var persistenceCallback = new PersistComponentStateRegistration(callback, renderMode, Array.Empty<IPersistenceReasonFilter>());

_registeredCallbacks.Add(persistenceCallback);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ public async Task RestoreStateAsync(IPersistentComponentStateStore store)
/// </summary>
/// <param name="store">The <see cref="IPersistentComponentStateStore"/> to restore the application state from.</param>
/// <param name="renderer">The <see cref="Renderer"/> that components are being rendered.</param>
/// <param name="persistenceReason">The reason for persisting the state.</param>
/// <returns>A <see cref="Task"/> that will complete when the state has been restored.</returns>
public Task PersistStateAsync(IPersistentComponentStateStore store, Renderer renderer)
public Task PersistStateAsync(IPersistentComponentStateStore store, Renderer renderer, IPersistenceReason? persistenceReason = null)
{
if (_stateIsPersisted)
{
Expand Down Expand Up @@ -113,7 +114,7 @@ async Task PauseAndPersistState()

async Task<bool> TryPersistState(IPersistentComponentStateStore store)
{
if (!await TryPauseAsync(store))
if (!await TryPauseAsync(store, persistenceReason))
{
_currentState.Clear();
return false;
Expand Down Expand Up @@ -159,7 +160,7 @@ private void InferRenderModes(Renderer renderer)
var componentRenderMode = renderer.GetComponentRenderMode(component);
if (componentRenderMode != null)
{
_registeredCallbacks[i] = new PersistComponentStateRegistration(registration.Callback, componentRenderMode);
_registeredCallbacks[i] = new PersistComponentStateRegistration(registration.Callback, componentRenderMode, registration.ReasonFilters);
}
else
{
Expand All @@ -176,7 +177,7 @@ private void InferRenderModes(Renderer renderer)
}
}

internal Task<bool> TryPauseAsync(IPersistentComponentStateStore store)
internal Task<bool> TryPauseAsync(IPersistentComponentStateStore store, IPersistenceReason? persistenceReason = null)
{
List<Task<bool>>? pendingCallbackTasks = null;

Expand All @@ -199,6 +200,27 @@ internal Task<bool> TryPauseAsync(IPersistentComponentStateStore store)
continue;
}

// Evaluate reason filters to determine if the callback should be executed for this persistence reason
if (registration.ReasonFilters.Count > 0)
{
var shouldPersist = EvaluateReasonFilters(registration.ReasonFilters, persistenceReason);
if (shouldPersist.HasValue && !shouldPersist.Value)
{
// Filters explicitly indicate not to persist for this reason
continue;
}
else if (!shouldPersist.HasValue && !(persistenceReason?.PersistByDefault ?? true))
{
// No filter matched and default is not to persist
continue;
}
}
else if (!(persistenceReason?.PersistByDefault ?? true))
{
// No filters defined and default is not to persist
continue;
}

var result = TryExecuteCallback(registration.Callback, _logger);
if (!result.IsCompletedSuccessfully)
{
Expand Down Expand Up @@ -271,4 +293,25 @@ static async Task<bool> AnyTaskFailed(List<Task<bool>> pendingCallbackTasks)
return true;
}
}

private static bool? EvaluateReasonFilters(IReadOnlyList<IPersistenceReasonFilter> reasonFilters, IPersistenceReason? persistenceReason)
{
if (persistenceReason is null)
{
// No reason provided, can't evaluate filters
return null;
}

foreach (var reasonFilter in reasonFilters)
{
var shouldPersist = reasonFilter.ShouldPersist(persistenceReason);
if (shouldPersist.HasValue)
{
return shouldPersist.Value;
}
}

// No filter matched
return null;
}
}
2 changes: 1 addition & 1 deletion src/Components/Components/src/PublicAPI.Shipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ Microsoft.AspNetCore.Components.IHandleEvent
Microsoft.AspNetCore.Components.IHandleEvent.HandleEventAsync(Microsoft.AspNetCore.Components.EventCallbackWorkItem item, object? arg) -> System.Threading.Tasks.Task!
Microsoft.AspNetCore.Components.Infrastructure.ComponentStatePersistenceManager
Microsoft.AspNetCore.Components.Infrastructure.ComponentStatePersistenceManager.ComponentStatePersistenceManager(Microsoft.Extensions.Logging.ILogger<Microsoft.AspNetCore.Components.Infrastructure.ComponentStatePersistenceManager!>! logger) -> void
Microsoft.AspNetCore.Components.Infrastructure.ComponentStatePersistenceManager.PersistStateAsync(Microsoft.AspNetCore.Components.IPersistentComponentStateStore! store, Microsoft.AspNetCore.Components.RenderTree.Renderer! renderer) -> System.Threading.Tasks.Task!
Microsoft.AspNetCore.Components.Infrastructure.ComponentStatePersistenceManager.PersistStateAsync(Microsoft.AspNetCore.Components.IPersistentComponentStateStore! store, Microsoft.AspNetCore.Components.RenderTree.Renderer! renderer, Microsoft.AspNetCore.Components.IPersistenceReason? persistenceReason = null) -> System.Threading.Tasks.Task!
Microsoft.AspNetCore.Components.Infrastructure.ComponentStatePersistenceManager.RestoreStateAsync(Microsoft.AspNetCore.Components.IPersistentComponentStateStore! store) -> System.Threading.Tasks.Task!
Microsoft.AspNetCore.Components.Infrastructure.ComponentStatePersistenceManager.State.get -> Microsoft.AspNetCore.Components.PersistentComponentState!
Microsoft.AspNetCore.Components.InjectAttribute
Expand Down
8 changes: 8 additions & 0 deletions src/Components/Components/src/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ Microsoft.AspNetCore.Components.Infrastructure.ComponentStatePersistenceManager.
Microsoft.AspNetCore.Components.Infrastructure.RegisterPersistentComponentStateServiceCollectionExtensions
Microsoft.AspNetCore.Components.SupplyParameterFromPersistentComponentStateAttribute
Microsoft.AspNetCore.Components.SupplyParameterFromPersistentComponentStateAttribute.SupplyParameterFromPersistentComponentStateAttribute() -> void
Microsoft.AspNetCore.Components.IPersistenceReason
Microsoft.AspNetCore.Components.IPersistenceReason.PersistByDefault.get -> bool
Microsoft.AspNetCore.Components.IPersistenceReasonFilter
Microsoft.AspNetCore.Components.IPersistenceReasonFilter.ShouldPersist(Microsoft.AspNetCore.Components.IPersistenceReason! reason) -> bool?
Microsoft.AspNetCore.Components.PersistReasonFilter<TReason>
Microsoft.AspNetCore.Components.PersistReasonFilter<TReason>.PersistReasonFilter(bool persist) -> void
Microsoft.AspNetCore.Components.PersistReasonFilter<TReason>.ShouldPersist(Microsoft.AspNetCore.Components.IPersistenceReason! reason) -> bool?
Microsoft.AspNetCore.Components.PersistentComponentState.RegisterOnPersisting(System.Func<System.Threading.Tasks.Task!>! callback, Microsoft.AspNetCore.Components.IComponentRenderMode? renderMode, System.Collections.Generic.IReadOnlyList<Microsoft.AspNetCore.Components.IPersistenceReasonFilter!>! reasonFilters) -> Microsoft.AspNetCore.Components.PersistingComponentStateSubscription
Microsoft.Extensions.DependencyInjection.SupplyParameterFromPersistentComponentStateProviderServiceCollectionExtensions
static Microsoft.AspNetCore.Components.Infrastructure.RegisterPersistentComponentStateServiceCollectionExtensions.AddPersistentServiceRegistration<TService>(Microsoft.Extensions.DependencyInjection.IServiceCollection! services, Microsoft.AspNetCore.Components.IComponentRenderMode! componentRenderMode) -> Microsoft.Extensions.DependencyInjection.IServiceCollection!
static Microsoft.AspNetCore.Components.Infrastructure.ComponentsMetricsServiceCollectionExtensions.AddComponentsMetrics(Microsoft.Extensions.DependencyInjection.IServiceCollection! services) -> Microsoft.Extensions.DependencyInjection.IServiceCollection!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -418,10 +418,118 @@ IEnumerator IEnumerable.GetEnumerator()
}
}

[Fact]
public void PersistenceReasons_HaveCorrectDefaults()
{
// Arrange & Act
var prerenderingReason = new TestPersistOnPrerendering();
var enhancedNavReason = new TestPersistOnEnhancedNavigation();
var circuitPauseReason = new TestPersistOnCircuitPause();

// Assert
Assert.True(prerenderingReason.PersistByDefault);
Assert.False(enhancedNavReason.PersistByDefault);
Assert.True(circuitPauseReason.PersistByDefault);
}

[Fact]
public async Task PersistStateAsync_RespectsReasonFilters()
{
// Arrange
var logger = NullLogger<ComponentStatePersistenceManager>.Instance;
var manager = new ComponentStatePersistenceManager(logger);
var renderer = new TestRenderer();
var store = new TestStore([]);
var callbackExecuted = false;

// Register callback with filter that blocks enhanced navigation
var filters = new List<IPersistenceReasonFilter>
{
new TestPersistenceReasonFilter<TestPersistOnEnhancedNavigation>(false)
};

manager.State.RegisterOnPersisting(() =>
{
callbackExecuted = true;
return Task.CompletedTask;
}, new TestRenderMode(), filters);

// Act - persist with enhanced navigation reason
await manager.PersistStateAsync(store, renderer, new TestPersistOnEnhancedNavigation());

// Assert - callback should not be executed
Assert.False(callbackExecuted);
}

[Fact]
public async Task PersistStateAsync_AllowsWhenFilterMatches()
{
// Arrange
var logger = NullLogger<ComponentStatePersistenceManager>.Instance;
var manager = new ComponentStatePersistenceManager(logger);
var renderer = new TestRenderer();
var store = new TestStore([]);
var callbackExecuted = false;

// Register callback with filter that allows prerendering
var filters = new List<IPersistenceReasonFilter>
{
new TestPersistenceReasonFilter<TestPersistOnPrerendering>(true)
};

manager.State.RegisterOnPersisting(() =>
{
callbackExecuted = true;
return Task.CompletedTask;
}, new TestRenderMode(), filters);

// Act - persist with prerendering reason
await manager.PersistStateAsync(store, renderer, new TestPersistOnPrerendering());

// Assert - callback should be executed
Assert.True(callbackExecuted);
}

private class TestPersistenceReasonFilter<TReason> : IPersistenceReasonFilter
where TReason : IPersistenceReason
{
private readonly bool _allow;

public TestPersistenceReasonFilter(bool allow)
{
_allow = allow;
}

public bool? ShouldPersist(IPersistenceReason reason)
{
if (reason is TReason)
{
return _allow;
}
return null;
}
}

private class TestRenderMode : IComponentRenderMode
{
}

// Test implementations of persistence reasons
private class TestPersistOnPrerendering : IPersistenceReason
{
public bool PersistByDefault => true;
}

private class TestPersistOnEnhancedNavigation : IPersistenceReason
{
public bool PersistByDefault => false;
}

private class TestPersistOnCircuitPause : IPersistenceReason
{
public bool PersistByDefault => true;
}

private class PersistentService : IPersistentServiceRegistration
{
public string Assembly { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ public async ValueTask<IHtmlContent> PrerenderPersistedStateAsync(HttpContext ht

if (store != null)
{
await manager.PersistStateAsync(store, this);
IPersistenceReason persistenceReason = IsProgressivelyEnhancedNavigation(httpContext.Request)
? PersistOnEnhancedNavigation.Instance
: PersistOnPrerendering.Instance;
await manager.PersistStateAsync(store, this, persistenceReason);
return store switch
{
ProtectedPrerenderComponentApplicationStore protectedStore => new ComponentStateHtmlContent(protectedStore, null),
Expand Down Expand Up @@ -80,7 +83,10 @@ public async ValueTask<IHtmlContent> PrerenderPersistedStateAsync(HttpContext ht
var webAssembly = new CopyOnlyStore<InteractiveWebAssemblyRenderMode>();
store = new CompositeStore(server, auto, webAssembly);

await manager.PersistStateAsync(store, this);
IPersistenceReason persistenceReason = IsProgressivelyEnhancedNavigation(httpContext.Request)
? PersistOnEnhancedNavigation.Instance
: PersistOnPrerendering.Instance;
await manager.PersistStateAsync(store, this, persistenceReason);

foreach (var kvp in auto.Saved)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public async Task PauseCircuitAsync(CircuitHost circuit, bool saveStateToClient
collector.PersistRootComponents,
RenderMode.InteractiveServer);

await persistenceManager.PersistStateAsync(collector, renderer);
await persistenceManager.PersistStateAsync(collector, renderer, PersistOnCircuitPause.Instance);

if (saveStateToClient)
{
Expand Down
Loading
Loading