diff --git a/.changeset/funny-peas-compare.md b/.changeset/funny-peas-compare.md new file mode 100644 index 000000000..8075a2b54 --- /dev/null +++ b/.changeset/funny-peas-compare.md @@ -0,0 +1,5 @@ +--- +"@ensembleui/react-runtime": patch +--- + +Fix: memoize conditional branch widgets diff --git a/packages/runtime/src/widgets/Conditional.tsx b/packages/runtime/src/widgets/Conditional.tsx index c98416351..03c57c475 100644 --- a/packages/runtime/src/widgets/Conditional.tsx +++ b/packages/runtime/src/widgets/Conditional.tsx @@ -1,7 +1,8 @@ import type { Expression } from "@ensembleui/react-framework"; import { unwrapWidget, useRegisterBindings } from "@ensembleui/react-framework"; import { cloneDeep, head, isEmpty, last } from "lodash-es"; -import { useMemo } from "react"; +import type { ReactNode } from "react"; +import { useMemo, useRef } from "react"; import { WidgetRegistry } from "../registry"; import { EnsembleRuntime } from "../runtime"; import type { EnsembleWidgetProps } from "../shared/types"; @@ -26,6 +27,7 @@ export const Conditional: React.FC = ({ conditions, ...props }) => { + const matched = useRef<{ [key: string]: ReactNode[] }>({}); const [isValid, errorMessage] = hasProperStructure(conditions); if (!isValid) throw Error(errorMessage); @@ -55,18 +57,27 @@ export const Conditional: React.FC = ({ if (trueIndex === undefined || trueIndex < 0) { return null; } + const key = conditionStatements[trueIndex]?.toString(); + + if (key && matched.current[key]) { + return matched.current[key]; + } + const extractedWidget = extractWidget(conditions[trueIndex]); - return { - ...extractedWidget, - key: conditionStatements[trueIndex]?.toString(), - }; + + const renderWidget = EnsembleRuntime.render([{ ...extractedWidget, key }]); + + if (key && !matched.current[key]) { + matched.current[key] = renderWidget; + } + return renderWidget; }, [conditionStatements, conditions, trueIndex]); if (!widget) { return null; } - return <>{EnsembleRuntime.render([widget])}; + return <>{widget}; }; WidgetRegistry.register(widgetName, Conditional); diff --git a/packages/runtime/src/widgets/__tests__/Conditional.test.tsx b/packages/runtime/src/widgets/__tests__/Conditional.test.tsx index a28df0f80..4a8623650 100644 --- a/packages/runtime/src/widgets/__tests__/Conditional.test.tsx +++ b/packages/runtime/src/widgets/__tests__/Conditional.test.tsx @@ -1,4 +1,10 @@ -import { render, screen } from "@testing-library/react"; +/* eslint import/first: 0 */ +// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment +const framework = jest.requireActual("@ensembleui/react-framework"); +// eslint-disable-next-line @typescript-eslint/no-unsafe-argument, @typescript-eslint/no-unsafe-member-access +const unwrapWidgetSpy = jest.fn().mockImplementation(framework.unwrapWidget); +import { fireEvent, render, screen } from "@testing-library/react"; +import { BrowserRouter } from "react-router-dom"; import type { ConditionalProps, ConditionalElement } from "../Conditional"; import { Conditional, @@ -7,9 +13,20 @@ import { extractCondition, } from "../Conditional"; import "../index"; +import { EnsembleScreen } from "../../runtime/screen"; jest.mock("react-markdown", jest.fn()); +// eslint-disable-next-line @typescript-eslint/no-unsafe-return +jest.mock("@ensembleui/react-framework", () => ({ + ...framework, + unwrapWidget: unwrapWidgetSpy, +})); + +afterEach(() => { + jest.clearAllMocks(); +}); + describe("Conditional Component", () => { test('renders the widget when "if" condition is met', () => { const conditionalProps: ConditionalProps = { @@ -233,3 +250,91 @@ describe("extractCondition Function", () => { expect(extractedCondition).toBe("1 === 1"); }); }); + +describe("conditional widget memoization", () => { + it("should memoize branch widgets and prevent unnecessary re-renders", () => { + render( + 0}`, + Text: { + text: "Greater than 0", + }, + }, + ], + }, + }, + { + name: "Button", + properties: { + label: "Increase", + onTap: { + executeCode: + "ensemble.storage.set('number', ensemble.storage.get('number') + 1)", + }, + }, + }, + { + name: "Button", + properties: { + label: "Decrease", + onTap: { + executeCode: + "ensemble.storage.set('number', ensemble.storage.get('number') - 1)", + }, + }, + }, + ], + }, + }, + onLoad: { executeCode: 'ensemble.storage.set("number", -1)' }, + }} + />, + { + wrapper: BrowserRouter, + }, + ); + + expect(unwrapWidgetSpy).toHaveBeenCalledTimes(1); + expect(screen.getByText("Less than 0")).not.toBeNull(); + + fireEvent.click(screen.getByText("Increase")); + expect(unwrapWidgetSpy).toHaveBeenCalledTimes(2); + expect(screen.getByText("Equals to 0")).not.toBeNull(); + + fireEvent.click(screen.getByText("Increase")); + expect(unwrapWidgetSpy).toHaveBeenCalledTimes(3); + expect(screen.getByText("Greater than 0")).not.toBeNull(); + + fireEvent.click(screen.getByText("Decrease")); + expect(unwrapWidgetSpy).toHaveBeenCalledTimes(3); + expect(screen.getByText("Equals to 0")).not.toBeNull(); + + fireEvent.click(screen.getByText("Decrease")); + expect(unwrapWidgetSpy).toHaveBeenCalledTimes(3); + expect(screen.getByText("Less than 0")).not.toBeNull(); + }); +});