Skip to content

Commit 502612c

Browse files
author
Rex Scaria
committed
DRAFT: give clear error messages on gateway policy drift
1 parent c0a9e89 commit 502612c

File tree

1 file changed

+245
-0
lines changed

1 file changed

+245
-0
lines changed

internal/services/zero_trust_gateway_policy/resource.go

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@ package zero_trust_gateway_policy
44

55
import (
66
"context"
7+
"encoding/json"
78
"fmt"
89
"io"
910
"net/http"
11+
"strings"
1012

1113
"github.com/cloudflare/cloudflare-go/v6"
1214
"github.com/cloudflare/cloudflare-go/v6/option"
1315
"github.com/cloudflare/cloudflare-go/v6/zero_trust"
1416
"github.com/cloudflare/terraform-provider-cloudflare/internal/apijson"
1517
"github.com/cloudflare/terraform-provider-cloudflare/internal/importpath"
1618
"github.com/cloudflare/terraform-provider-cloudflare/internal/logging"
19+
"github.com/hashicorp/terraform-plugin-framework/diag"
1720
"github.com/hashicorp/terraform-plugin-framework/resource"
1821
"github.com/hashicorp/terraform-plugin-framework/types"
1922
)
@@ -112,6 +115,19 @@ func (r *ZeroTrustGatewayPolicyResource) Update(ctx context.Context, req resourc
112115
return
113116
}
114117

118+
// First, get the current API state to detect drift before applying changes
119+
currentAPIState, err := r.getCurrentAPIState(ctx, data.ID.ValueString(), data.AccountID.ValueString())
120+
if err != nil {
121+
resp.Diagnostics.AddError("failed to retrieve current API state for drift detection", err.Error())
122+
return
123+
}
124+
125+
// Detect and report drift between the current API state and planned configuration
126+
if currentAPIState != nil {
127+
driftDiags := r.detectDriftOnUpdate(ctx, currentAPIState, data)
128+
resp.Diagnostics.Append(driftDiags...)
129+
}
130+
115131
dataBytes, err := data.MarshalJSONForUpdate(*state)
116132
if err != nil {
117133
resp.Diagnostics.AddError("failed to serialize http request", err.Error())
@@ -153,6 +169,9 @@ func (r *ZeroTrustGatewayPolicyResource) Read(ctx context.Context, req resource.
153169
return
154170
}
155171

172+
// Store the current Terraform state for drift comparison
173+
currentTerraformState := *data
174+
156175
res := new(http.Response)
157176
env := ZeroTrustGatewayPolicyResultEnvelope{*data}
158177
_, err := r.client.ZeroTrust.Gateway.Rules.Get(
@@ -181,6 +200,11 @@ func (r *ZeroTrustGatewayPolicyResource) Read(ctx context.Context, req resource.
181200
}
182201
data = &env.Result
183202

203+
// Detect drift between the current API state and Terraform state
204+
// Only compare user-configurable fields (exclude computed-only fields)
205+
driftDiags := r.detectConfigurationDriftOnRead(ctx, data, &currentTerraformState)
206+
resp.Diagnostics.Append(driftDiags...)
207+
184208
resp.Diagnostics.Append(resp.State.Set(ctx, &data)...)
185209
}
186210

@@ -257,3 +281,224 @@ func (r *ZeroTrustGatewayPolicyResource) ImportState(ctx context.Context, req re
257281
func (r *ZeroTrustGatewayPolicyResource) ModifyPlan(_ context.Context, _ resource.ModifyPlanRequest, _ *resource.ModifyPlanResponse) {
258282

259283
}
284+
285+
// detectDriftOnUpdate compares the current API state with the planned configuration and returns diagnostic messages
286+
// showing the differences between what's configured vs what exists in the API
287+
func (r *ZeroTrustGatewayPolicyResource) detectDriftOnUpdate(ctx context.Context, apiState, plannedConfig *ZeroTrustGatewayPolicyModel) diag.Diagnostics {
288+
var diags diag.Diagnostics
289+
290+
if apiState == nil || plannedConfig == nil {
291+
return diags
292+
}
293+
294+
differences := r.compareModels(apiState, plannedConfig)
295+
296+
if len(differences) > 0 {
297+
var driftDetails strings.Builder
298+
driftDetails.WriteString("Configuration drift detected between API state and Terraform configuration.\n\n")
299+
// We expect a single consolidated difference (policy JSON)
300+
diff := differences[0]
301+
driftDetails.WriteString("detectDriftOnUpdate@UPDATE")
302+
driftDetails.WriteString("Side-by-side (Terraform | API):\n")
303+
driftDetails.WriteString(sideBySideDiff(diff.ConfigValue, diff.APIValue, 60))
304+
305+
driftDetails.WriteString("\n\nTo fix the drift, update your terraform declaration to match the current API state.")
306+
307+
diags.AddWarning(
308+
"Configuration Drift Detected",
309+
driftDetails.String(),
310+
)
311+
}
312+
313+
return diags
314+
}
315+
316+
// DriftDifference represents a single field difference between API state and configuration
317+
type DriftDifference struct {
318+
Field string
319+
APIValue string
320+
ConfigValue string
321+
}
322+
323+
// compareModels performs a detailed comparison between API state and planned configuration
324+
func (r *ZeroTrustGatewayPolicyResource) compareModels(apiState, plannedConfig *ZeroTrustGatewayPolicyModel) []DriftDifference {
325+
// Marshal using model's JSON to respect tags/omissions
326+
apiBytes, errA := apiState.MarshalJSON()
327+
cfgBytes, errC := plannedConfig.MarshalJSON()
328+
329+
if errA != nil || errC != nil {
330+
return []DriftDifference{
331+
{
332+
Field: "policy",
333+
APIValue: fmt.Sprintf("error marshalling api state: %v", errA),
334+
ConfigValue: fmt.Sprintf("error marshalling config: %v", errC),
335+
},
336+
}
337+
}
338+
339+
// Normalize JSON for stable comparison
340+
var apiObj any
341+
var cfgObj any
342+
if err := json.Unmarshal(apiBytes, &apiObj); err != nil {
343+
return []DriftDifference{{Field: "policy", APIValue: string(apiBytes), ConfigValue: string(cfgBytes)}}
344+
}
345+
if err := json.Unmarshal(cfgBytes, &cfgObj); err != nil {
346+
return []DriftDifference{{Field: "policy", APIValue: string(apiBytes), ConfigValue: string(cfgBytes)}}
347+
}
348+
normAPI, _ := json.MarshalIndent(apiObj, "", " ")
349+
normCfg, _ := json.MarshalIndent(cfgObj, "", " ")
350+
351+
if string(normAPI) == string(normCfg) {
352+
return nil
353+
}
354+
355+
return []DriftDifference{{Field: "policy", APIValue: string(normAPI), ConfigValue: string(normCfg)}}
356+
}
357+
358+
// getCurrentAPIState retrieves the current state of the resource from the API
359+
func (r *ZeroTrustGatewayPolicyResource) getCurrentAPIState(ctx context.Context, ruleID, accountID string) (*ZeroTrustGatewayPolicyModel, error) {
360+
res := new(http.Response)
361+
var data ZeroTrustGatewayPolicyModel
362+
env := ZeroTrustGatewayPolicyResultEnvelope{data}
363+
364+
_, err := r.client.ZeroTrust.Gateway.Rules.Get(
365+
ctx,
366+
ruleID,
367+
zero_trust.GatewayRuleGetParams{
368+
AccountID: cloudflare.F(accountID),
369+
},
370+
option.WithResponseBodyInto(&res),
371+
option.WithMiddleware(logging.Middleware(ctx)),
372+
)
373+
374+
if err != nil {
375+
return nil, fmt.Errorf("failed to retrieve current API state: %w", err)
376+
}
377+
378+
if res.StatusCode != 200 {
379+
return nil, fmt.Errorf("failed to retrieve current API state: %w", err)
380+
}
381+
382+
bytes, err := io.ReadAll(res.Body)
383+
if err != nil {
384+
return nil, fmt.Errorf("failed to read response body: %w", err)
385+
}
386+
387+
err = apijson.Unmarshal(bytes, &env)
388+
if err != nil {
389+
return nil, fmt.Errorf("failed to deserialize API response: %w", err)
390+
}
391+
392+
return &env.Result, nil
393+
}
394+
395+
// This is used during Read operations to detect drift in the configuration
396+
func (r *ZeroTrustGatewayPolicyResource) detectConfigurationDriftOnRead(ctx context.Context, apiState, terraformState *ZeroTrustGatewayPolicyModel) diag.Diagnostics {
397+
var diags diag.Diagnostics
398+
399+
if apiState == nil || terraformState == nil {
400+
return diags
401+
}
402+
403+
// Serialize both objects using model-aware JSON marshaling
404+
apiBytes, errA := apiState.MarshalJSON()
405+
cfgBytes, errC := terraformState.MarshalJSON()
406+
407+
if errA != nil || errC != nil {
408+
diags.AddWarning(
409+
"Configuration Drift Detected",
410+
fmt.Sprintf("error marshalling for drift detection (api=%v, config=%v)", errA, errC),
411+
)
412+
return diags
413+
}
414+
415+
// Normalize: drop computed-only fields, then pretty-print
416+
var apiMap map[string]any
417+
var cfgMap map[string]any
418+
if err := json.Unmarshal(apiBytes, &apiMap); err != nil {
419+
diags.AddWarning("Configuration Drift Detected", "failed to parse API state for drift detection")
420+
return diags
421+
}
422+
if err := json.Unmarshal(cfgBytes, &cfgMap); err != nil {
423+
diags.AddWarning("Configuration Drift Detected", "failed to parse Terraform state for drift detection")
424+
return diags
425+
}
426+
427+
// Remove computed-only keys to focus on user-configurable fields
428+
remove := func(m map[string]any) {
429+
delete(m, "id")
430+
delete(m, "account_id")
431+
delete(m, "created_at")
432+
delete(m, "deleted_at")
433+
delete(m, "read_only")
434+
delete(m, "sharable")
435+
delete(m, "source_account")
436+
delete(m, "updated_at")
437+
delete(m, "version")
438+
delete(m, "warning_status")
439+
}
440+
remove(apiMap)
441+
remove(cfgMap)
442+
443+
normAPI, _ := json.MarshalIndent(apiMap, "", " ")
444+
normCfg, _ := json.MarshalIndent(cfgMap, "", " ")
445+
446+
if string(normAPI) == string(normCfg) {
447+
return diags
448+
}
449+
450+
var msg strings.Builder
451+
452+
msg.WriteString("detectConfigurationDriftOnRead@READ")
453+
454+
msg.WriteString("Configuration drift detected! The actual API state differs from your Terraform configuration.\n\n")
455+
msg.WriteString("Side-by-side (Terraform | API) for user-configurable fields:\n")
456+
msg.WriteString(sideBySideDiff(string(normCfg), string(normAPI), 60))
457+
458+
msg.WriteString("\n\nTo fix the drift, update your terraform declaration to match the current API state.")
459+
460+
diags.AddWarning("Configuration Drift Detected", msg.String())
461+
return diags
462+
}
463+
464+
// sideBySideDiff renders a simple side-by-side view of two multi-line strings.
465+
// leftWidth controls the width of the left column (API). Differences are marked with '≠'.
466+
func sideBySideDiff(right, left string, leftWidth int) string {
467+
// Intentionally interpret first arg as Terraform (left), second as API (right)
468+
lnsL := strings.Split(right, "\n") // Terraform
469+
lnsR := strings.Split(left, "\n") // API
470+
n := len(lnsL)
471+
if len(lnsR) > n {
472+
n = len(lnsR)
473+
}
474+
rightWidth := leftWidth
475+
var b strings.Builder
476+
477+
// Header (Terraform left, API right)
478+
b.WriteString(fmt.Sprintf("%4s %-*s %4s %s\n", "#", leftWidth, "Terraform", "#", "API"))
479+
b.WriteString(fmt.Sprintf("%4s %-*s %4s %s\n", strings.Repeat("-", 4), leftWidth, strings.Repeat("-", leftWidth), strings.Repeat("-", 4), strings.Repeat("-", rightWidth)))
480+
481+
// Rows
482+
for i := 0; i < n; i++ {
483+
var L, R string
484+
if i < len(lnsL) {
485+
L = lnsL[i]
486+
}
487+
if i < len(lnsR) {
488+
R = lnsR[i]
489+
}
490+
marker := " "
491+
if L != R {
492+
marker = "≠"
493+
}
494+
// Truncate to fit columns
495+
if leftWidth > 1 && len(L) > leftWidth {
496+
L = L[:leftWidth-1] + "…"
497+
}
498+
if rightWidth > 1 && len(R) > rightWidth {
499+
R = R[:rightWidth-1] + "…"
500+
}
501+
b.WriteString(fmt.Sprintf("%4d %-*s %s %4d %s\n", i+1, leftWidth, L, marker, i+1, R))
502+
}
503+
return b.String()
504+
}

0 commit comments

Comments
 (0)