Skip to content

Commit 2865407

Browse files
committed
fix(server): Improve synchronization of flag definitions
Previously the synchronization prevented flag definitions from being overwritten. This now ensures the poller and the client don't load flag definitions at the same time.
1 parent 4853718 commit 2865407

File tree

2 files changed

+275
-9
lines changed

2 files changed

+275
-9
lines changed

posthog-server/src/main/java/com/posthog/server/internal/PostHogFeatureFlags.kt

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ internal class PostHogFeatureFlags(
4242

4343
private var definitionsLoaded = false
4444

45+
@Volatile
46+
private var isLoading = false
47+
48+
private val loadLock = Object()
49+
4550
init {
4651
startPoller()
4752
if (!localEvaluation) {
@@ -301,17 +306,34 @@ internal class PostHogFeatureFlags(
301306
return
302307
}
303308

304-
synchronized(this) {
305-
if (definitionsLoaded) {
306-
config.logger.log("Definitions already loaded, skipping")
309+
var wasWaitingForLoad = false
310+
311+
synchronized(loadLock) {
312+
while (isLoading) {
313+
wasWaitingForLoad = true
314+
try {
315+
loadLock.wait()
316+
} catch (e: InterruptedException) {
317+
Thread.currentThread().interrupt()
318+
config.logger.log("Interrupted while waiting for flag definitions to load")
319+
return
320+
}
321+
}
322+
323+
if (wasWaitingForLoad && flagDefinitions != null) {
324+
config.logger.log("Definitions loaded by another thread, skipping duplicate request")
307325
return
308326
}
309327

310-
try {
311-
config.logger.log("Loading feature flags for local evaluation")
312-
val apiResponse = api.localEvaluation(personalApiKey)
328+
isLoading = true
329+
}
313330

314-
if (apiResponse != null) {
331+
try {
332+
config.logger.log("Loading feature flags for local evaluation")
333+
val apiResponse = api.localEvaluation(personalApiKey)
334+
335+
if (apiResponse != null) {
336+
synchronized(loadLock) {
315337
// apiResponse is now LocalEvaluationResponse with properly typed models
316338
featureFlags = apiResponse.flags
317339
flagDefinitions = apiResponse.flags?.associateBy { it.key }
@@ -327,8 +349,13 @@ internal class PostHogFeatureFlags(
327349
config.logger.log("Error in onFeatureFlags callback: ${e.message}")
328350
}
329351
}
330-
} catch (e: Throwable) {
331-
config.logger.log("Failed to load feature flags for local evaluation: ${e.message}")
352+
}
353+
} catch (e: Throwable) {
354+
config.logger.log("Failed to load feature flags for local evaluation: ${e.message}")
355+
} finally {
356+
synchronized(loadLock) {
357+
isLoading = false
358+
loadLock.notifyAll()
332359
}
333360
}
334361
}

posthog-server/src/test/java/com/posthog/server/internal/PostHogFeatureFlagsTest.kt

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,13 @@ import com.posthog.server.errorResponse
1111
import com.posthog.server.jsonResponse
1212
import kotlin.test.Test
1313
import kotlin.test.assertEquals
14+
import kotlin.test.assertFalse
1415
import kotlin.test.assertNull
1516
import kotlin.test.assertTrue
17+
import okhttp3.mockwebserver.Dispatcher
18+
import okhttp3.mockwebserver.MockResponse
19+
import okhttp3.mockwebserver.MockWebServer
20+
import okhttp3.mockwebserver.RecordedRequest
1621

1722
internal class PostHogFeatureFlagsTest {
1823
@Test
@@ -639,4 +644,238 @@ internal class PostHogFeatureFlagsTest {
639644
remoteConfig.shutDown()
640645
mockServer.shutdown()
641646
}
647+
648+
@Test
649+
fun `loadFeatureFlagDefinitions overwrites existing definitions on reload`() {
650+
val logger = TestLogger()
651+
652+
// Create first response with "flag-v1"
653+
val firstResponse =
654+
createLocalEvaluationResponse(
655+
flagKey = "flag-v1",
656+
aggregationGroupTypeIndex = null,
657+
)
658+
659+
// Create second response with "flag-v2" only (no flag-v1)
660+
val secondResponse =
661+
createLocalEvaluationResponse(
662+
flagKey = "flag-v2",
663+
aggregationGroupTypeIndex = null,
664+
)
665+
666+
val mockServer =
667+
createMockHttp(
668+
jsonResponse(firstResponse),
669+
jsonResponse(secondResponse),
670+
jsonResponse(secondResponse), // Extra response for potential poller activity
671+
)
672+
val url = mockServer.url("/")
673+
674+
val config = createTestConfig(logger, url.toString())
675+
val api = PostHogApi(config)
676+
val featureFlags =
677+
PostHogFeatureFlags(
678+
config,
679+
api,
680+
60000,
681+
100,
682+
localEvaluation = true,
683+
personalApiKey = "test-personal-key",
684+
)
685+
686+
// Wait for initial poller load to complete (loads flag-v1)
687+
Thread.sleep(1000)
688+
689+
// Verify first flag is available (loaded by poller)
690+
val firstResult =
691+
featureFlags.getFeatureFlag(
692+
key = "flag-v1",
693+
defaultValue = false,
694+
distinctId = "test-user",
695+
)
696+
assertEquals(true, firstResult)
697+
698+
// Load second set of definitions (should overwrite first with flag-v2)
699+
featureFlags.loadFeatureFlagDefinitions()
700+
701+
// Verify second flag is now available
702+
val secondResult =
703+
featureFlags.getFeatureFlag(
704+
key = "flag-v2",
705+
defaultValue = false,
706+
distinctId = "test-user",
707+
)
708+
assertEquals(true, secondResult)
709+
710+
// Verify first flag is no longer available (was overwritten)
711+
val firstResultAfterReload =
712+
featureFlags.getFeatureFlag(
713+
key = "flag-v1",
714+
defaultValue = false,
715+
distinctId = "test-user",
716+
)
717+
assertEquals(false, firstResultAfterReload)
718+
719+
// Verify we made at least 2 API calls (poller's initial load + our manual loads)
720+
assertTrue(
721+
mockServer.requestCount >= 2,
722+
"Expected at least 2 requests, got ${mockServer.requestCount}",
723+
)
724+
assertTrue(logger.containsLog("Loading feature flags for local evaluation"))
725+
assertTrue(logger.containsLog("Loaded 1 feature flags for local evaluation"))
726+
727+
featureFlags.shutDown()
728+
mockServer.shutdown()
729+
}
730+
731+
@Test
732+
fun `concurrent initial loads only make one API request`() {
733+
val logger = TestLogger()
734+
val localEvalResponse =
735+
createLocalEvaluationResponse(
736+
flagKey = "test-flag",
737+
aggregationGroupTypeIndex = null,
738+
)
739+
740+
// Provide multiple responses in case duplicate requests happen (we want to verify they don't)
741+
val mockServer =
742+
createMockHttp(
743+
jsonResponse(localEvalResponse),
744+
jsonResponse(localEvalResponse),
745+
jsonResponse(localEvalResponse),
746+
)
747+
val url = mockServer.url("/")
748+
749+
val config = createTestConfig(logger, url.toString())
750+
val api = PostHogApi(config)
751+
752+
// Create instance and immediately try to use it
753+
// This simulates the race condition where poller (starts immediately at delay=0)
754+
// and first flag evaluation both try to load definitions concurrently
755+
val featureFlags =
756+
PostHogFeatureFlags(
757+
config,
758+
api,
759+
60000,
760+
100,
761+
localEvaluation = true,
762+
personalApiKey = "test-personal-key",
763+
)
764+
765+
// Immediately trigger flag evaluation (which checks definitions and loads if needed)
766+
// This happens concurrently with poller's initial load
767+
val result =
768+
featureFlags.getFeatureFlag(
769+
key = "test-flag",
770+
defaultValue = false,
771+
distinctId = "test-user",
772+
)
773+
774+
// Wait a bit to ensure both potential loads have time to complete
775+
Thread.sleep(1000)
776+
777+
// Verify the flag works (definitions were loaded successfully)
778+
assertEquals(true, result)
779+
780+
// Critical assertion: only 1 API request should have been made
781+
// The second thread should have waited for the first to complete
782+
assertEquals(
783+
1,
784+
mockServer.requestCount,
785+
"Expected exactly 1 API request due to concurrent load deduplication, got ${mockServer.requestCount}",
786+
)
787+
788+
// Verify we logged the skip message
789+
assertTrue(
790+
logger.containsLog("Definitions loaded by another thread, skipping duplicate request") ||
791+
mockServer.requestCount == 1,
792+
"Should either log skip message or only make 1 request",
793+
)
794+
795+
featureFlags.shutDown()
796+
mockServer.shutdown()
797+
}
798+
799+
@Test
800+
fun `multiple concurrent loadFeatureFlagDefinitions calls make only one API request`() {
801+
val logger = TestLogger()
802+
val localEvalResponse =
803+
createLocalEvaluationResponse(
804+
flagKey = "test-flag",
805+
aggregationGroupTypeIndex = null,
806+
)
807+
808+
// Create mock server with DELAYED response (1 second) to ensure all threads enter wait state
809+
val dispatcher =
810+
object : Dispatcher() {
811+
override fun dispatch(request: RecordedRequest): MockResponse {
812+
Thread.sleep(1000) // Simulate slow API
813+
return MockResponse()
814+
.setResponseCode(200)
815+
.setBody(localEvalResponse)
816+
}
817+
}
818+
val mockServer = MockWebServer()
819+
mockServer.dispatcher = dispatcher
820+
mockServer.start()
821+
val url = mockServer.url("/")
822+
823+
val config = createTestConfig(logger, url.toString())
824+
val api = PostHogApi(config)
825+
val featureFlags =
826+
PostHogFeatureFlags(
827+
config,
828+
api,
829+
60000,
830+
100,
831+
localEvaluation = true,
832+
personalApiKey = "test-personal-key",
833+
)
834+
835+
// Shut down poller to control loading manually
836+
featureFlags.shutDown()
837+
838+
// Launch 5 concurrent threads all calling loadFeatureFlagDefinitions
839+
val threadCount = 5
840+
val threads =
841+
List(threadCount) {
842+
Thread {
843+
featureFlags.loadFeatureFlagDefinitions()
844+
}
845+
}
846+
847+
// Start all threads simultaneously
848+
threads.forEach { it.start() }
849+
850+
// Wait for all to complete
851+
threads.forEach { it.join(5000) } // 5 sec timeout
852+
853+
// All threads should have completed successfully
854+
threads.forEach { thread ->
855+
assertFalse(
856+
thread.isAlive,
857+
"Thread should have completed",
858+
)
859+
}
860+
861+
// Critical assertion: only 1 API request despite 5 concurrent calls
862+
assertEquals(
863+
1,
864+
mockServer.requestCount,
865+
"Expected exactly 1 API request from $threadCount concurrent calls, got ${mockServer.requestCount}",
866+
)
867+
868+
// Verify definitions were loaded
869+
val result = featureFlags.getFeatureFlag("test-flag", false, "test-user")
870+
assertEquals(true, result)
871+
872+
// Verify logging shows threads waited
873+
val skipCount = logger.logs.count { it.contains("Definitions loaded by another thread") }
874+
assertTrue(
875+
skipCount >= threadCount - 1,
876+
"Expected at least ${threadCount - 1} threads to skip duplicate request, but only $skipCount did",
877+
)
878+
879+
mockServer.shutdown()
880+
}
642881
}

0 commit comments

Comments
 (0)