Skip to content

Commit 5154f86

Browse files
[WIP] first push
0 parents  commit 5154f86

File tree

93 files changed

+26016
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

93 files changed

+26016
-0
lines changed

.gitignore

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# If you prefer the allow list template instead of the deny list, see community template:
2+
# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
3+
#
4+
# Binaries for programs and plugins
5+
*.exe
6+
*.exe~
7+
*.dll
8+
*.so
9+
*.dylib
10+
11+
# Test binary, built with `go test -c`
12+
*.test
13+
14+
# Output of the go coverage tool, specifically when used with LiteIDE
15+
*.out
16+
17+
# Dependency directories (remove the comment below to include it)
18+
# vendor/
19+
20+
# Go workspace file
21+
go.work
22+
go.work.sum
23+
24+
# env file
25+
.env

cmd/mysql-schema-diff/apply_cmd.go

+172
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"sort"
8+
"strings"
9+
"time"
10+
11+
"github.com/DoWithLogic/mysql-schema-diff/pkg/diff"
12+
_ "github.com/DoWithLogic/mysql-schema-diff/pkg/diff"
13+
"github.com/DoWithLogic/mysql-schema-diff/pkg/log"
14+
"github.com/go-sql-driver/mysql"
15+
"github.com/spf13/cobra"
16+
)
17+
18+
func buildApplyCmd() *cobra.Command {
19+
cmd := &cobra.Command{
20+
Use: "apply",
21+
Short: "Migrate your database to the match the inputted schema (apply the schema to the database)",
22+
}
23+
24+
connFlags := createConnFlags(cmd)
25+
planFlags := createPlanFlags(cmd)
26+
allowedHazardsTypesStrs := cmd.Flags().StringSlice("allow-hazards", nil,
27+
"Specify the hazards that are allowed. Order does not matter, and duplicates are ignored. If the"+
28+
" migration plan contains unwanted hazards (hazards not in this list), then the migration will fail to run"+
29+
" (example: --allowed-hazards DELETES_DATA,INDEX_BUILD)")
30+
skipConfirmPrompt := cmd.Flags().Bool("skip-confirm-prompt", false, "Skips prompt asking for user to confirm before applying")
31+
cmd.RunE = func(cmd *cobra.Command, args []string) error {
32+
logger := log.SimpleLogger()
33+
34+
connConfig, err := parseConnConfig(*connFlags, logger)
35+
if err != nil {
36+
return err
37+
}
38+
39+
planConfig, err := parsePlanConfig(*planFlags)
40+
if err != nil {
41+
return err
42+
}
43+
44+
cmd.SilenceUsage = true
45+
46+
plan, err := generatePlan(context.Background(), logger, connConfig, planConfig)
47+
if err != nil {
48+
return err
49+
} else if len(plan.Statements) == 0 {
50+
fmt.Println("Schema matches expected. No plan generated")
51+
return nil
52+
}
53+
54+
fmt.Println(header("Review plan"))
55+
fmt.Print(planToPrettyS(plan), "\n\n")
56+
57+
if err := failIfHazardsNotAllowed(plan, *allowedHazardsTypesStrs); err != nil {
58+
return err
59+
}
60+
61+
if !*skipConfirmPrompt {
62+
if err := mustContinuePrompt(
63+
fmt.Sprintf(
64+
"Apply migration with the following hazards: %s?",
65+
strings.Join(*allowedHazardsTypesStrs, ", "),
66+
),
67+
); err != nil {
68+
return err
69+
}
70+
}
71+
72+
if err := runPlan(context.Background(), connConfig, plan); err != nil {
73+
return err
74+
}
75+
fmt.Println("Schema applied successfully")
76+
return nil
77+
}
78+
79+
return cmd
80+
}
81+
82+
func failIfHazardsNotAllowed(plan diff.Plan, allowedHazardsTypesStrs []string) error {
83+
isAllowedByHazardType := make(map[diff.MigrationHazardType]bool)
84+
for _, val := range allowedHazardsTypesStrs {
85+
isAllowedByHazardType[strings.ToUpper(val)] = true
86+
}
87+
var disallowedHazardMsgs []string
88+
for i, stmt := range plan.Statements {
89+
var disallowedTypes []diff.MigrationHazardType
90+
for _, hzd := range stmt.Hazards {
91+
if !isAllowedByHazardType[hzd.Type] {
92+
disallowedTypes = append(disallowedTypes, hzd.Type)
93+
}
94+
}
95+
if len(disallowedTypes) > 0 {
96+
disallowedHazardMsgs = append(disallowedHazardMsgs,
97+
fmt.Sprintf("- Statement %d: %s", getDisplayableStmtIdx(i), strings.Join(disallowedTypes, ", ")),
98+
)
99+
}
100+
101+
}
102+
if len(disallowedHazardMsgs) > 0 {
103+
return errors.New(fmt.Sprintf(
104+
"Prohited hazards found\n"+
105+
"These hazards must be allowed via the allow-hazards flag, e.g., --allow-hazards %s\n"+
106+
"Prohibited hazards in the following statements:\n%s",
107+
strings.Join(getHazardTypes(plan), ","),
108+
strings.Join(disallowedHazardMsgs, "\n"),
109+
))
110+
}
111+
return nil
112+
}
113+
114+
func runPlan(ctx context.Context, connConfig *mysql.Config, plan diff.Plan) error {
115+
connPool, err := newMySQL(connConfig.FormatDSN())
116+
if err != nil {
117+
return err
118+
}
119+
defer connPool.Close()
120+
121+
conn, err := connPool.Conn(ctx)
122+
if err != nil {
123+
return err
124+
}
125+
defer conn.Close()
126+
127+
// Due to the way *sql.Db works, when a statement_timeout is set for the session, it will NOT reset
128+
// by default when it's returned to the pool.
129+
//
130+
// We can't set the timeout at the TRANSACTION-level (for each transaction) because `ADD INDEX CONCURRENTLY`
131+
// must be executed within its own transaction block. Postgres will error if you try to set a TRANSACTION-level
132+
// timeout for it. SESSION-level statement_timeouts are respected by `ADD INDEX CONCURRENTLY`
133+
for i, stmt := range plan.Statements {
134+
fmt.Println(header(fmt.Sprintf("Executing statement %d", getDisplayableStmtIdx(i))))
135+
fmt.Printf("%s\n\n", statementToPrettyS(stmt))
136+
start := time.Now()
137+
if _, err := conn.ExecContext(ctx, fmt.Sprintf("SET SESSION statement_timeout = %d", stmt.Timeout.Milliseconds())); err != nil {
138+
return fmt.Errorf("setting statement timeout: %w", err)
139+
}
140+
if _, err := conn.ExecContext(ctx, fmt.Sprintf("SET SESSION lock_timeout = %d", stmt.Timeout.Milliseconds())); err != nil {
141+
return fmt.Errorf("setting lock timeout: %w", err)
142+
}
143+
if _, err := conn.ExecContext(ctx, stmt.ToSQL()); err != nil {
144+
return fmt.Errorf("executing migration statement. the database maybe be in a dirty state: %s: %w", stmt, err)
145+
}
146+
fmt.Printf("Finished executing statement. Duration: %s\n", time.Since(start))
147+
}
148+
fmt.Println(header("Complete"))
149+
150+
return nil
151+
}
152+
153+
func getHazardTypes(plan diff.Plan) []diff.MigrationHazardType {
154+
seenHazardTypes := make(map[diff.MigrationHazardType]bool)
155+
var hazardTypes []diff.MigrationHazardType
156+
for _, stmt := range plan.Statements {
157+
for _, hazard := range stmt.Hazards {
158+
if !seenHazardTypes[hazard.Type] {
159+
seenHazardTypes[hazard.Type] = true
160+
hazardTypes = append(hazardTypes, hazard.Type)
161+
}
162+
}
163+
}
164+
sort.Slice(hazardTypes, func(i, j int) bool {
165+
return hazardTypes[i] < hazardTypes[j]
166+
})
167+
return hazardTypes
168+
}
169+
170+
func getDisplayableStmtIdx(i int) int {
171+
return i + 1
172+
}

cmd/mysql-schema-diff/cli.go

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"math"
6+
"strings"
7+
8+
"github.com/manifoldco/promptui"
9+
)
10+
11+
func header(header string) string {
12+
const headerTargetWidth = 80
13+
14+
if len(header) > headerTargetWidth {
15+
return header
16+
}
17+
18+
if len(header) > 0 {
19+
header = fmt.Sprintf(" %s ", header)
20+
}
21+
hashTagsOnSide := int(math.Ceil(float64(headerTargetWidth-len(header)) / 2))
22+
23+
rightHashTags := strings.Repeat("#", hashTagsOnSide)
24+
leftHashTags := rightHashTags
25+
if headerTargetWidth-len(header)-2*hashTagsOnSide > 0 {
26+
leftHashTags += "#"
27+
}
28+
return fmt.Sprintf("%s%s%s", leftHashTags, header, rightHashTags)
29+
}
30+
31+
// MustContinuePrompt prompts the user if they want to continue, and returns an error otherwise.
32+
// promptui requires the ContinueLabel to be one line
33+
func mustContinuePrompt(continueLabel string) error {
34+
if len(continueLabel) == 0 {
35+
continueLabel = "Continue?"
36+
}
37+
if _, result, err := (&promptui.Select{
38+
Label: continueLabel,
39+
Items: []string{"No", "Yes"},
40+
}).Run(); err != nil {
41+
return err
42+
} else if result == "No" {
43+
return fmt.Errorf("user aborted")
44+
}
45+
return nil
46+
}

cmd/mysql-schema-diff/datastructs.go

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"sort"
6+
)
7+
8+
func mustGetAndDeleteKey(m map[string]string, key string) (string, error) {
9+
val, ok := m[key]
10+
if !ok {
11+
return "", fmt.Errorf("could not find key %q", key)
12+
}
13+
delete(m, key)
14+
15+
return val, nil
16+
}
17+
18+
func keys(m map[string]string) []string {
19+
var vals []string
20+
for k := range m {
21+
vals = append(vals, k)
22+
}
23+
sort.Strings(vals)
24+
25+
return vals
26+
}
+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package main
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestKeys(t *testing.T) {
10+
for _, tt := range []struct {
11+
name string
12+
m map[string]string
13+
14+
want []string
15+
}{
16+
{
17+
name: "nil map",
18+
19+
want: nil,
20+
},
21+
{
22+
name: "empty map",
23+
24+
want: nil,
25+
},
26+
{
27+
name: "filled map",
28+
m: map[string]string{
29+
// Use an arbitrary order
30+
"key2": "value2",
31+
"key3": "value3",
32+
"key1": "value1",
33+
},
34+
35+
want: []string{"key1", "key2", "key3"},
36+
},
37+
} {
38+
t.Run(tt.name, func(t *testing.T) {
39+
assert.Equal(t, tt.want, keys(tt.m))
40+
})
41+
}
42+
}

cmd/mysql-schema-diff/flags.go

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
7+
"github.com/DoWithLogic/mysql-schema-diff/pkg/log"
8+
"github.com/go-logfmt/logfmt"
9+
"github.com/go-sql-driver/mysql"
10+
"github.com/spf13/cobra"
11+
)
12+
13+
type connFlags struct {
14+
dsn string
15+
}
16+
17+
type Config struct {
18+
Host string
19+
Port uint16
20+
Database string
21+
User string
22+
Password string
23+
}
24+
25+
func createConnFlags(cmd *cobra.Command) *connFlags {
26+
flags := new(connFlags)
27+
28+
// Don't mark dsn as a required flag.
29+
// Allow users to user the MYSQLHOST etc environment variables like `mysql`
30+
cmd.Flags().StringVar(&flags.dsn, "dsn", "", "Connection string for the database (DB password can be specified through MYSQLPASSWORD environment variable)")
31+
32+
return flags
33+
}
34+
35+
func parseConnConfig(c connFlags, logger log.Logger) (*mysql.Config, error) {
36+
if c.dsn == "" {
37+
logger.Warnf("DSN flag not set. Using libpq environment variables and default values.")
38+
}
39+
40+
cfg, err := mysql.ParseDSN(c.dsn)
41+
if err != nil {
42+
return nil, err
43+
}
44+
45+
return cfg, nil
46+
}
47+
48+
func logFmtToMap(logFmt string) (map[string]string, error) {
49+
logMap := make(map[string]string)
50+
decoder := logfmt.NewDecoder(strings.NewReader(logFmt))
51+
for decoder.ScanRecord() {
52+
if _, ok := logMap[string(decoder.Key())]; ok {
53+
return nil, fmt.Errorf("duplicate key %q in logfmt", string(decoder.Key()))
54+
}
55+
logMap[string(decoder.Key())] = string(decoder.Value())
56+
}
57+
58+
if decoder.Err() != nil {
59+
return nil, decoder.Err()
60+
}
61+
62+
return logMap, nil
63+
}

0 commit comments

Comments
 (0)