diff --git a/go.mod b/go.mod index 29db43d..0d93d71 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/a-h/templ v0.2.707 github.com/gorilla/sessions v1.3.0 github.com/stretchr/testify v1.9.0 + golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 ) require ( diff --git a/go.sum b/go.sum index 1f527fe..7b9ebdd 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY= +golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/validate/README.md b/validate/README.md deleted file mode 100644 index 842260c..0000000 --- a/validate/README.md +++ /dev/null @@ -1,4 +0,0 @@ -# Validate -Schema based validation with superpowers for Golang. - -# TODO diff --git a/validate/boolean.go b/validate/boolean.go new file mode 100644 index 0000000..b9f375b --- /dev/null +++ b/validate/boolean.go @@ -0,0 +1,37 @@ +package validate + +import ( + "github.com/anthdm/superkit/validate/primitives" +) + +type boolValidator struct { + Rules []primitives.Rule + IsOptional bool +} + +func Bool() *boolValidator { + return &boolValidator{ + Rules: []primitives.Rule{ + primitives.IsType[bool]("is not a valid boolean"), + }, + } +} + +func (v *boolValidator) Validate(fieldValue any) ([]string, bool) { + return primitives.GenericValidator(fieldValue, v.Rules, v.IsOptional) +} + +func (v *boolValidator) Optional() *boolValidator { + v.IsOptional = true + return v +} + +func (v *boolValidator) True() *boolValidator { + v.Rules = append(v.Rules, primitives.EQ[bool](true, "should be true")) + return v +} + +func (v *boolValidator) False() *boolValidator { + v.Rules = append(v.Rules, primitives.EQ[bool](false, "should be false")) + return v +} diff --git a/validate/env.go b/validate/env.go new file mode 100644 index 0000000..5423b5d --- /dev/null +++ b/validate/env.go @@ -0,0 +1,69 @@ +package validate + +import ( + "fmt" + "os" + "strconv" + + p "github.com/anthdm/superkit/validate/primitives" +) + +// takes a key and a validator and returns the validated and converted environment variable +func Env[T supportedEnvTypes](key string, v fieldValidator, defailtValue ...T) T { + str := os.Getenv(key) + + val, err := coerceString[T](str) + + if err != nil || p.IsZeroValue(val) { + if len(defailtValue) > 0 { + return defailtValue[0] + } else { + panic(fmt.Errorf("failed to parse env %s: %v", key, err)) + } + } + + errs, ok := v.Validate(val) + if !ok { + panic(fmt.Errorf("failed to validate env %s: %v", key, errs)) + } + + return val +} + +type supportedEnvTypes interface { + ~int | ~float64 | ~bool | ~string +} + +func coerceString[T supportedEnvTypes](val string) (T, error) { + var result T + + switch any(result).(type) { + case int: + var tmp int + tmp, err := strconv.Atoi(val) + if err != nil { + return result, err + } + + result = any(tmp).(T) + case float64: + tmp, err := strconv.ParseFloat(val, 64) + if err != nil { + return result, err + } + result = any(tmp).(T) + + case bool: + tmp, err := strconv.ParseBool(val) + if err != nil { + return result, err + } + result = any(tmp).(T) + case string: + result = any(val).(T) + default: + return result, fmt.Errorf("unsupported type: %T", result) + } + + return result, nil +} diff --git a/validate/env_test.go b/validate/env_test.go new file mode 100644 index 0000000..cb8a5c4 --- /dev/null +++ b/validate/env_test.go @@ -0,0 +1,85 @@ +package validate + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +// fmt.Printf("VALUE: %v | err: %v", val, err) +func TestCoerceToString(t *testing.T) { + val, err := coerceString[string]("123") + assert.Nil(t, err) + assert.Equal(t, "123", val) +} + +func TestCoerceToInt(t *testing.T) { + val, err := coerceString[int]("123") + assert.Nil(t, err) + assert.Equal(t, 123, val) +} + +func TestCoerceToFloat(t *testing.T) { + val, err := coerceString[float64]("123.25") + assert.Nil(t, err) + assert.Equal(t, 123.25, val) +} + +func TestCoerceToBool(t *testing.T) { + val, err := coerceString[bool]("true") + assert.Nil(t, err) + assert.Equal(t, true, val) +} + +func TestEmptyEnv(t *testing.T) { + assert.Panics(t, func() { + Env[string]("Test", String().Required()) + }) + + assert.Panics(t, func() { + os.Setenv("TEST", "") + Env[string]("Test", String().Required()) + }) +} + +func TestReturnsValue(t *testing.T) { + + os.Setenv("TEST", "value") + + val := Env[string]("TEST", String().Required()) + + assert.Equal(t, val, "value") +} + +func TestDefault(t *testing.T) { + val := Env[string]("TEST2", String().Required(), "hello") + assert.Equal(t, "hello", val) + + os.Setenv("TEST2", "world") + val = Env[string]("TEST2", String().Required(), "hello") + assert.Equal(t, "world", val) + + assert.Panics(t, func() { + os.Setenv("TEST2", "1") + _ = Env[string]("TEST2", String().Min(4)) + }) +} + +func TestInt(t *testing.T) { + os.Setenv("TEST", "1") + val := Env[int]("TEST", Int().LT(2)) + assert.Equal(t, 1, val) +} + +func TestBool(t *testing.T) { + os.Setenv("TEST", "true") + val := Env[bool]("TEST", Bool().True()) + assert.Equal(t, true, val) +} + +func TestFloat(t *testing.T) { + os.Setenv("TEST", "1.1") + val := Env[float64]("TEST", Float().GT(1.0)) + assert.Equal(t, 1.1, val) +} diff --git a/validate/numbers.go b/validate/numbers.go new file mode 100644 index 0000000..90476e7 --- /dev/null +++ b/validate/numbers.go @@ -0,0 +1,88 @@ +package validate + +import ( + "fmt" + + p "github.com/anthdm/superkit/validate/primitives" +) + +type Numeric interface { + ~int | ~float64 +} + +type numberValidator[T Numeric] struct { + Rules []p.Rule + IsOptional bool +} + +func Float() *numberValidator[float64] { + return &numberValidator[float64]{ + Rules: []p.Rule{ + p.IsType[float64]("should be a decimal number"), + }, + } +} + +func Int() *numberValidator[int] { + return &numberValidator[int]{ + Rules: []p.Rule{ + p.IsType[int]("should be an whole number"), + }, + } +} + +// GLOBAL METHODS + +func (v *numberValidator[T]) Refine(ruleName string, errorMsg string, validateFunc p.RuleValidateFunc) *numberValidator[T] { + v.Rules = append(v.Rules, + p.Rule{ + Name: ruleName, + ErrorMessage: errorMsg, + ValidateFunc: validateFunc, + }, + ) + + return v +} + +// is equal to one of the values +func (v *numberValidator[T]) In(values []T) *numberValidator[T] { + v.Rules = append(v.Rules, p.In(values, fmt.Sprintf("should be in %v", values))) + return v +} + +func (v *numberValidator[Numeric]) Optional() *numberValidator[Numeric] { + v.IsOptional = true + return v +} + +func (v *numberValidator[Numeric]) Validate(fieldValue any) ([]string, bool) { + return p.GenericValidator(fieldValue, v.Rules, v.IsOptional) +} + +// UNIQUE METHODS + +func (v *numberValidator[Numeric]) EQ(n Numeric) *numberValidator[Numeric] { + v.Rules = append(v.Rules, p.EQ(n, fmt.Sprintf("should be equal to %v", n))) + return v +} + +func (v *numberValidator[Numeric]) LTE(n Numeric) *numberValidator[Numeric] { + v.Rules = append(v.Rules, p.LTE(n, fmt.Sprintf("should be lesser or equal than %v", n))) + return v +} + +func (v *numberValidator[Numeric]) GTE(n Numeric) *numberValidator[Numeric] { + v.Rules = append(v.Rules, p.GTE(n, fmt.Sprintf("should be greater or equal to %v", n))) + return v +} + +func (v *numberValidator[Numeric]) LT(n Numeric) *numberValidator[Numeric] { + v.Rules = append(v.Rules, p.LT(n, fmt.Sprintf("should be less than %v", n))) + return v +} + +func (v *numberValidator[Numeric]) GT(n Numeric) *numberValidator[Numeric] { + v.Rules = append(v.Rules, p.GT(n, fmt.Sprintf("should be greater than %v", n))) + return v +} diff --git a/validate/numbers_test.go b/validate/numbers_test.go new file mode 100644 index 0000000..188cad6 --- /dev/null +++ b/validate/numbers_test.go @@ -0,0 +1,145 @@ +package validate + +import "testing" + +func TestEq(t *testing.T) { + validator := Int().EQ(5) + errs, ok := validator.Validate(5) + if !ok || len(errs) > 0 { + t.Errorf("Expected no errors, got %v", errs) + } + errs, ok = validator.Validate(4) + if ok || len(errs) == 0 { + t.Errorf("Expected errors, got none") + } + + validator2 := Float().EQ(5.0) + errs, ok = validator2.Validate(5.0) + if !ok || len(errs) > 0 { + t.Errorf("Expected no errors, got %v", errs) + } + errs, ok = validator2.Validate(4.0) + if ok || len(errs) == 0 { + t.Errorf("Expected errors, got none") + } +} + +func TestGt(t *testing.T) { + validator := Int().GT(5) + errs, ok := validator.Validate(6) + if !ok || len(errs) > 0 { + t.Errorf("Expected no errors, got %v", errs) + } + errs, ok = validator.Validate(5) + if ok || len(errs) == 0 { + t.Errorf("Expected errors, got none") + } + errs, ok = validator.Validate(4) + if ok || len(errs) == 0 { + t.Errorf("Expected errors, got none") + } + + validator2 := Float().GT(5.0) + errs, ok = validator2.Validate(6.0) + if !ok || len(errs) > 0 { + t.Errorf("Expected no errors, got %v", errs) + } + errs, ok = validator2.Validate(5.0) + if ok || len(errs) == 0 { + t.Errorf("Expected errors, got none") + } + errs, ok = validator2.Validate(4.0) + if ok || len(errs) == 0 { + t.Errorf("Expected errors, got none") + } +} + +func TestGte(t *testing.T) { + validator := Int().GTE(5) + errs, ok := validator.Validate(6) + if !ok || len(errs) > 0 { + t.Errorf("Expected no errors, got %v", errs) + } + errs, ok = validator.Validate(5) + if !ok || len(errs) > 0 { + t.Errorf("Expected no errors, got %v", errs) + } + errs, ok = validator.Validate(4) + if ok || len(errs) == 0 { + t.Errorf("Expected errors, got none") + } + + validator2 := Float().GTE(5.0) + errs, ok = validator2.Validate(6.0) + if !ok || len(errs) > 0 { + t.Errorf("Expected no errors, got %v", errs) + } + errs, ok = validator2.Validate(5.0) + if !ok || len(errs) > 0 { + t.Errorf("Expected no errors, got %v", errs) + } + errs, ok = validator2.Validate(4.0) + if ok || len(errs) == 0 { + t.Errorf("Expected errors, got none") + } +} + +func TestLt(t *testing.T) { + validator := Int().LT(5) + errs, ok := validator.Validate(4) + if !ok || len(errs) > 0 { + t.Errorf("Expected no errors, got %v", errs) + } + errs, ok = validator.Validate(5) + if ok || len(errs) == 0 { + t.Errorf("Expected errors, got none") + } + errs, ok = validator.Validate(6) + if ok || len(errs) == 0 { + t.Errorf("Expected errors, got none") + } + + validator2 := Float().LT(5.0) + errs, ok = validator2.Validate(4.0) + if !ok || len(errs) > 0 { + t.Errorf("Expected no errors, got %v", errs) + } + errs, ok = validator2.Validate(5.0) + if ok || len(errs) == 0 { + t.Errorf("Expected errors, got none") + } + errs, ok = validator2.Validate(6.0) + if ok || len(errs) == 0 { + t.Errorf("Expected errors, got none") + } +} + +func TestLte(t *testing.T) { + validator := Int().LTE(5) + errs, ok := validator.Validate(4) + if !ok || len(errs) > 0 { + t.Errorf("Expected no errors, got %v", errs) + } + errs, ok = validator.Validate(5) + if !ok || len(errs) > 0 { + t.Errorf("Expected no errors, got %v", errs) + } + errs, ok = validator.Validate(6) + if ok || len(errs) == 0 { + t.Errorf("Expected errors, got none") + } + + validator2 := Float().LTE(5.0) + errs, ok = validator2.Validate(4.0) + if !ok || len(errs) > 0 { + t.Errorf("Expected no errors, got %v", errs) + } + errs, ok = validator2.Validate(5.0) + if !ok || len(errs) > 0 { + t.Errorf("Expected no errors, got %v", errs) + } + errs, ok = validator2.Validate(6.0) + if ok || len(errs) == 0 { + t.Errorf("Expected errors, got none") + } +} diff --git a/validate/primitives/Rule.go b/validate/primitives/Rule.go new file mode 100644 index 0000000..1be640e --- /dev/null +++ b/validate/primitives/Rule.go @@ -0,0 +1,23 @@ +package primitives + +type RuleValidateFunc func(Rule) bool + +type Rule struct { + Name string + RuleValue any + FieldValue any // TODO I think I can remove this + FieldName any // TODO I think I can remove this + ErrorMessage string + ValidateFunc RuleValidateFunc +} + +// ORIGINAL +// type RuleSet struct { +// Name string +// RuleValue any +// FieldValue any +// FieldName any +// ErrorMessage string +// MessageFunc func(RuleSet) string +// ValidateFunc func(RuleSet) bool +// } diff --git a/validate/primitives/genericValidators.go b/validate/primitives/genericValidators.go new file mode 100644 index 0000000..271a5f6 --- /dev/null +++ b/validate/primitives/genericValidators.go @@ -0,0 +1,159 @@ +package primitives + +import ( + "reflect" + + "golang.org/x/exp/constraints" +) + +type LengthCapable[K any] interface { + ~[]any | ~[]K | string | map[any]any | ~chan any +} + +func IsType[T any](msg string) Rule { + return Rule{ + Name: "isType", + ErrorMessage: msg, + ValidateFunc: func(set Rule) bool { + _, ok := set.FieldValue.(T) + return ok + }, + } +} + +func LenMin[T LengthCapable[any]](n int, msg string) Rule { + return Rule{ + Name: "min", + RuleValue: n, + ValidateFunc: func(set Rule) bool { + val, ok := set.FieldValue.(T) + if !ok { + return false + } + return len(val) >= n + }, + ErrorMessage: msg, + } +} + +func LenMax[T LengthCapable[any]](n int, msg string) Rule { + return Rule{ + Name: "max", + RuleValue: n, + ValidateFunc: func(set Rule) bool { + val, ok := set.FieldValue.(T) + if !ok { + return false + } + return len(val) <= n + }, + ErrorMessage: msg, + } +} + +func Length[T LengthCapable[any]](n int, msg string) Rule { + return Rule{ + Name: "length", + RuleValue: n, + ValidateFunc: func(set Rule) bool { + val, ok := set.FieldValue.(T) + if !ok { + return false + } + return len(val) == n + }, + ErrorMessage: msg, + } +} + +func In[T any](values []T, msg string) Rule { + return Rule{ + Name: "in", + RuleValue: values, + ValidateFunc: func(set Rule) bool { + for _, value := range values { + v := set.FieldValue.(T) + if reflect.DeepEqual(v, value) { + return true + } + } + return false + }, + ErrorMessage: msg, + } +} + +func EQ[T comparable](n T, msg string) Rule { + return Rule{ + Name: "eq", + RuleValue: n, + ValidateFunc: func(set Rule) bool { + v, ok := set.FieldValue.(T) + if !ok { + return false + } + return v == n + }, + ErrorMessage: msg, + } +} + +func LTE[T constraints.Ordered](n T, msg string) Rule { + return Rule{ + Name: "lte", + RuleValue: n, + ValidateFunc: func(set Rule) bool { + v, ok := set.FieldValue.(T) + if !ok { + return false + } + return v <= n + }, + ErrorMessage: msg, + } +} + +func GTE[T constraints.Ordered](n T, msg string) Rule { + return Rule{ + Name: "gte", + RuleValue: n, + ValidateFunc: func(set Rule) bool { + v, ok := set.FieldValue.(T) + if !ok { + return false + } + return v >= n + }, + ErrorMessage: msg, + } +} + +func LT[T constraints.Ordered](n T, msg string) Rule { + return Rule{ + Name: "lt", + RuleValue: n, + ValidateFunc: func(set Rule) bool { + v, ok := set.FieldValue.(T) + if !ok { + return false + } + return v < n + }, + ErrorMessage: msg, + } +} + +func GT[T constraints.Ordered](n T, msg string) Rule { + return Rule{ + Name: "gt", + RuleValue: n, + ValidateFunc: func(set Rule) bool { + v, ok := set.FieldValue.(T) + if !ok { + return false + } + return v > n + }, + ErrorMessage: msg, + } +} diff --git a/validate/primitives/utils.go b/validate/primitives/utils.go new file mode 100644 index 0000000..210f2ee --- /dev/null +++ b/validate/primitives/utils.go @@ -0,0 +1,44 @@ +package primitives + +import "reflect" + +func IsZeroValue(x any) bool { + if x == nil { + return true + } + + v := reflect.ValueOf(x) + if !v.IsValid() { + return true + } + + // Check if the value is the zero value for its type + zeroValue := reflect.Zero(v.Type()) + return reflect.DeepEqual(v.Interface(), zeroValue.Interface()) +} + +func GenericValidator(fieldValue any, rules []Rule, isOptional bool) ([]string, bool) { + + var errors []string = nil + ok := true + + // if its optional and the value is zero we can skip the validation + if isOptional && IsZeroValue(fieldValue) { + return errors, ok + } + + for _, set := range rules { + + set.FieldValue = fieldValue + if !set.ValidateFunc(set) { + ok = false + msg := set.ErrorMessage + if errors == nil { + errors = []string{} + } + errors = append(errors, msg) + } + } + + return errors, ok +} diff --git a/validate/rules.go b/validate/rules.go deleted file mode 100644 index 2053d16..0000000 --- a/validate/rules.go +++ /dev/null @@ -1,306 +0,0 @@ -package validate - -import ( - "fmt" - "reflect" - "regexp" - "time" - "unicode" -) - -var ( - emailRegex = regexp.MustCompile(`^[a-z0-9._%+\-]+@[a-z0-9.\-]+\.[a-z]{2,4}$`) - urlRegex = regexp.MustCompile(`^(https?:\/\/)?(www\.)?([a-zA-Z0-9\-]+\.)+[a-zA-Z]{2,}(\/[a-zA-Z0-9\-._~:\/?#\[\]@!$&'()*+,;=]*)?$`) -) - -// RuleSet holds the state of a single rule. -type RuleSet struct { - Name string - RuleValue any - FieldValue any - FieldName any - ErrorMessage string - MessageFunc func(RuleSet) string - ValidateFunc func(RuleSet) bool -} - -// Message overrides the default message of a RuleSet -func (set RuleSet) Message(msg string) RuleSet { - set.ErrorMessage = msg - return set -} - -type Numeric interface { - int | float64 -} - -func In[T any](values []T) RuleSet { - return RuleSet{ - Name: "in", - RuleValue: values, - ValidateFunc: func(set RuleSet) bool { - for _, value := range values { - v := set.FieldValue.(T) - if reflect.DeepEqual(v, value) { - return true - } - } - return false - }, - MessageFunc: func(set RuleSet) string { - return fmt.Sprintf("should be in %v", values) - }, - } -} - -var ContainsUpper = RuleSet{ - Name: "containsUpper", - ValidateFunc: func(rule RuleSet) bool { - str, ok := rule.FieldValue.(string) - if !ok { - return false - } - for _, ch := range str { - if unicode.IsUpper(rune(ch)) { - return true - } - } - return false - }, - MessageFunc: func(set RuleSet) string { - return "must contain at least 1 uppercase character" - }, -} - -var ContainsDigit = RuleSet{ - Name: "containsDigit", - ValidateFunc: func(rule RuleSet) bool { - str, ok := rule.FieldValue.(string) - if !ok { - return false - } - return hasDigit(str) - }, - MessageFunc: func(set RuleSet) string { - return "must contain at least 1 numeric character" - }, -} - -var ContainsSpecial = RuleSet{ - Name: "containsSpecial", - ValidateFunc: func(rule RuleSet) bool { - str, ok := rule.FieldValue.(string) - if !ok { - return false - } - return hasSpecialChar(str) - }, - MessageFunc: func(set RuleSet) string { - return "must contain at least 1 special character" - }, -} - -var Required = RuleSet{ - Name: "required", - MessageFunc: func(set RuleSet) string { - return "is a required field" - }, - ValidateFunc: func(rule RuleSet) bool { - str, ok := rule.FieldValue.(string) - if !ok { - return false - } - return len(str) > 0 - }, -} - -var URL = RuleSet{ - Name: "url", - MessageFunc: func(set RuleSet) string { - return "is not a valid url" - }, - ValidateFunc: func(set RuleSet) bool { - u, ok := set.FieldValue.(string) - if !ok { - return false - } - return urlRegex.MatchString(u) - }, -} - -var Email = RuleSet{ - Name: "email", - MessageFunc: func(set RuleSet) string { - return "is not a valid email address" - }, - ValidateFunc: func(set RuleSet) bool { - email, ok := set.FieldValue.(string) - if !ok { - return false - } - return emailRegex.MatchString(email) - }, -} - -var Time = RuleSet{ - Name: "time", - ValidateFunc: func(set RuleSet) bool { - t, ok := set.FieldValue.(time.Time) - if !ok { - return false - } - return t.After(time.Time{}) - }, - MessageFunc: func(set RuleSet) string { - return "is not a valid time" - }, -} - -func TimeAfter(t time.Time) RuleSet { - return RuleSet{ - Name: "timeAfter", - ValidateFunc: func(set RuleSet) bool { - t, ok := set.FieldValue.(time.Time) - if !ok { - return false - } - return t.After(t) - }, - MessageFunc: func(set RuleSet) string { - return fmt.Sprintf("is not after %v", set.FieldValue) - }, - } -} - -func TimeBefore(t time.Time) RuleSet { - return RuleSet{ - Name: "timeBefore", - ValidateFunc: func(set RuleSet) bool { - t, ok := set.FieldValue.(time.Time) - if !ok { - return false - } - return t.Before(t) - }, - MessageFunc: func(set RuleSet) string { - return fmt.Sprintf("is not before %v", set.FieldValue) - }, - } -} - -func EQ[T comparable](v T) RuleSet { - return RuleSet{ - Name: "eq", - RuleValue: v, - ValidateFunc: func(set RuleSet) bool { - return set.FieldValue.(T) == v - }, - MessageFunc: func(set RuleSet) string { - return fmt.Sprintf("should be equal to %v", v) - }, - } -} - -func LTE[T Numeric](n T) RuleSet { - return RuleSet{ - Name: "lte", - RuleValue: n, - ValidateFunc: func(set RuleSet) bool { - return set.FieldValue.(T) <= n - }, - MessageFunc: func(set RuleSet) string { - return fmt.Sprintf("should be lesser or equal than %v", n) - }, - } -} - -func GTE[T Numeric](n T) RuleSet { - return RuleSet{ - Name: "gte", - RuleValue: n, - ValidateFunc: func(set RuleSet) bool { - return set.FieldValue.(T) >= n - }, - MessageFunc: func(set RuleSet) string { - return fmt.Sprintf("should be greater or equal than %v", n) - }, - } -} - -func LT[T Numeric](n T) RuleSet { - return RuleSet{ - Name: "lt", - RuleValue: n, - ValidateFunc: func(set RuleSet) bool { - return set.FieldValue.(T) < n - }, - MessageFunc: func(set RuleSet) string { - return fmt.Sprintf("should be lesser than %v", n) - }, - } -} - -func GT[T Numeric](n T) RuleSet { - return RuleSet{ - Name: "gt", - RuleValue: n, - ValidateFunc: func(set RuleSet) bool { - return set.FieldValue.(T) > n - }, - MessageFunc: func(set RuleSet) string { - return fmt.Sprintf("should be greater than %v", n) - }, - } -} - -func Max(n int) RuleSet { - return RuleSet{ - Name: "max", - RuleValue: n, - ValidateFunc: func(set RuleSet) bool { - str, ok := set.FieldValue.(string) - if !ok { - return false - } - return len(str) <= n - }, - MessageFunc: func(set RuleSet) string { - return fmt.Sprintf("should be maximum %d characters long", n) - }, - } -} - -func Min(n int) RuleSet { - return RuleSet{ - Name: "min", - RuleValue: n, - ValidateFunc: func(set RuleSet) bool { - str, ok := set.FieldValue.(string) - if !ok { - return false - } - return len(str) >= n - }, - MessageFunc: func(set RuleSet) string { - return fmt.Sprintf("should be at least %d characters long", n) - }, - } -} - -func hasDigit(s string) bool { - for _, char := range s { - if unicode.IsDigit(char) { - return true - } - } - return false -} - -func hasSpecialChar(s string) bool { - for _, char := range s { - if !unicode.IsLetter(char) && !unicode.IsDigit(char) { - return true - } - } - return false -} diff --git a/validate/slices.go b/validate/slices.go new file mode 100644 index 0000000..d4c3726 --- /dev/null +++ b/validate/slices.go @@ -0,0 +1,171 @@ +package validate + +import ( + "fmt" + "reflect" + + p "github.com/anthdm/superkit/validate/primitives" +) + +type sliceValidator struct { + Rules []p.Rule + IsOptional bool +} + +func Slice(schema fieldValidator) *sliceValidator { + return &sliceValidator{ + Rules: []p.Rule{ + { + Name: "sliceItemsMatchSchema", + RuleValue: schema, + ErrorMessage: "all items should match the schema", + ValidateFunc: func(set p.Rule) bool { + rv := reflect.ValueOf(set.FieldValue) + if rv.Kind() != reflect.Slice { + return false + } + s, ok := set.RuleValue.(fieldValidator) + if !ok { + return false + } + for idx := 0; idx < rv.Len(); idx++ { + v := rv.Index(idx).Interface() + _, ok := s.Validate(v) + if !ok { + return false + } + } + return true + }, + }, + }, + } +} + +// GLOBAL METHODS + +func (v *sliceValidator) Refine(ruleName string, errorMsg string, validateFunc p.RuleValidateFunc) *sliceValidator { + v.Rules = append(v.Rules, + p.Rule{ + Name: ruleName, + ErrorMessage: errorMsg, + ValidateFunc: validateFunc, + }, + ) + + return v +} + +func (v *sliceValidator) Optional() *sliceValidator { + v.IsOptional = true + return v +} + +func (v *sliceValidator) Validate(fieldValue any) ([]string, bool) { + return p.GenericValidator(fieldValue, v.Rules, v.IsOptional) +} + +// UNIQUE METHODS + +// TODO +// some & every -> pass a validator + +func (v *sliceValidator) NotEmpty() *sliceValidator { + v.Rules = append(v.Rules, + sliceMin(1, "should not be empty"), + ) + return v +} + +// Minimum number of items +func (v *sliceValidator) Min(n int) *sliceValidator { + v.Rules = append(v.Rules, + sliceMin(n, fmt.Sprintf("should be at least %d items long", n)), + ) + return v +} + +// Maximum number of items +func (v *sliceValidator) Max(n int) *sliceValidator { + v.Rules = append(v.Rules, + sliceMax(n, fmt.Sprintf("should be at maximum %d items long", n)), + ) + return v +} + +// Exact number of items +func (v *sliceValidator) Len(n int) *sliceValidator { + v.Rules = append(v.Rules, + sliceLength(n, fmt.Sprintf("should be exactly %d items long", n)), + ) + return v +} + +func (v *sliceValidator) Contains(val any) *sliceValidator { + v.Rules = append(v.Rules, + p.Rule{ + Name: "contains", + RuleValue: val, + ErrorMessage: fmt.Sprintf("should contain %v", val), + ValidateFunc: func(set p.Rule) bool { + rv := reflect.ValueOf(set.FieldValue) + if rv.Kind() != reflect.Slice { + return false + } + for idx := 0; idx < rv.Len(); idx++ { + v := rv.Index(idx).Interface() + + if reflect.DeepEqual(v, val) { + return true + } + } + + return false + }, + }, + ) + return v +} + +func sliceMin(n int, errMsg string) p.Rule { + return p.Rule{ + Name: "sliceMin", + RuleValue: n, + ErrorMessage: errMsg, + ValidateFunc: func(set p.Rule) bool { + rv := reflect.ValueOf(set.FieldValue) + if rv.Kind() != reflect.Slice { + return false + } + return rv.Len() >= n + }, + } +} +func sliceMax(n int, errMsg string) p.Rule { + return p.Rule{ + Name: "sliceMax", + RuleValue: n, + ErrorMessage: errMsg, + ValidateFunc: func(set p.Rule) bool { + rv := reflect.ValueOf(set.FieldValue) + if rv.Kind() != reflect.Slice { + return false + } + return rv.Len() <= n + }, + } +} +func sliceLength(n int, errMsg string) p.Rule { + return p.Rule{ + Name: "sliceLength", + RuleValue: n, + ErrorMessage: errMsg, + ValidateFunc: func(set p.Rule) bool { + rv := reflect.ValueOf(set.FieldValue) + if rv.Kind() != reflect.Slice { + return false + } + return rv.Len() == n + }, + } +} diff --git a/validate/slices_test.go b/validate/slices_test.go new file mode 100644 index 0000000..c2e2009 --- /dev/null +++ b/validate/slices_test.go @@ -0,0 +1,98 @@ +package validate + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSlicePassSchema(t *testing.T) { + type TestStruct struct { + Items []any + } + + s := TestStruct{ + Items: []any{"a", "b", "c"}, + } + + errs, ok := Validate(s, Schema{"items": Slice(String().Len(1))}) + assert.True(t, ok) + assert.Empty(t, errs) + + s.Items = []any{"a", "b", "c", "d", 1} + errs, ok = Validate(s, Schema{"items": Slice(String().Len(1))}) + assert.False(t, ok) + assert.Len(t, errs, 1) +} + +func TestSliceNotEmpty(t *testing.T) { + type TestStruct struct { + Items []any + } + s := TestStruct{ + Items: []any{}, + } + + errs, ok := Validate(s, Schema{"items": Slice(String()).NotEmpty()}) + assert.False(t, ok) + assert.NotEmpty(t, errs) + + s.Items = []any{"a", "b", "c"} + errs, ok = Validate(s, Schema{"items": Slice(String()).NotEmpty()}) + assert.True(t, ok) + assert.Empty(t, errs) +} + +func TestSliceLength(t *testing.T) { + type TestStruct struct { + Items []any + } + + s := TestStruct{ + Items: []any{"a", "b", "c"}, + } + + errs, ok := Validate(s, Schema{"items": Slice(String()).Len(3)}) + assert.True(t, ok) + assert.Empty(t, errs) + + errs, ok = Validate(s, Schema{"items": Slice(String()).Len(2)}) + assert.False(t, ok) + assert.Len(t, errs, 1) + + // min & max + errs, ok = Validate(s, Schema{"items": Slice(String()).Min(2)}) + assert.True(t, ok) + assert.Empty(t, errs) + + errs, ok = Validate(s, Schema{"items": Slice(String()).Min(4)}) + assert.False(t, ok) + assert.Len(t, errs, 1) + + errs, ok = Validate(s, Schema{"items": Slice(String()).Max(3)}) + assert.True(t, ok) + assert.Empty(t, errs) + + errs, ok = Validate(s, Schema{"items": Slice(String()).Max(1)}) + assert.False(t, ok) + assert.Len(t, errs, 1) +} + +func TestSliceContains(t *testing.T) { + + type TestStruct struct { + Items []any + } + + s := TestStruct{ + Items: []any{"a", "b", "c"}, + } + + errs, ok := Validate(s, Schema{"items": Slice(String()).Contains("a")}) + assert.True(t, ok) + assert.Empty(t, errs) + + errs, ok = Validate(s, Schema{"items": Slice(String()).Contains("d")}) + assert.False(t, ok) + assert.Len(t, errs, 1) +} diff --git a/validate/string.go b/validate/string.go new file mode 100644 index 0000000..7dc6dd3 --- /dev/null +++ b/validate/string.go @@ -0,0 +1,262 @@ +package validate + +import ( + "fmt" + "regexp" + "strings" + + p "github.com/anthdm/superkit/validate/primitives" +) + +var ( + emailRegex = regexp.MustCompile(`^[a-z0-9._%+\-]+@[a-z0-9.\-]+\.[a-z]{2,4}$`) + // TODO improve this regex? + urlRegex = regexp.MustCompile(`^(http(s)?://)?([\da-z\.-]+)\.([a-z\.]{2,6})([/\w \.-]*)*/?$`) +) + +type StringValidator struct { + Rules []p.Rule + IsOptional bool +} + +func String() *StringValidator { + return &StringValidator{ + Rules: []p.Rule{ + p.IsType[string]("should be a string"), + }, + } +} + +// GLOBAL METHODS + +// is equal to one of the values +func (v *StringValidator) In(values []string) *StringValidator { + v.Rules = append(v.Rules, p.In(values, fmt.Sprintf("should be in %v", values))) + return v +} + +func (v *StringValidator) Validate(fieldValue any) ([]string, bool) { + return p.GenericValidator(fieldValue, v.Rules, v.IsOptional) +} + +func (v *StringValidator) Optional() *StringValidator { + v.IsOptional = true + return v +} + +func (v *StringValidator) Refine(ruleName string, errorMsg string, validateFunc p.RuleValidateFunc) *StringValidator { + v.Rules = append(v.Rules, + p.Rule{ + Name: ruleName, + ErrorMessage: errorMsg, + ValidateFunc: validateFunc, + }, + ) + + return v +} + +// METHODS + +func (v *StringValidator) Min(n int) *StringValidator { + v.Rules = append(v.Rules, + p.LenMin[string](n, fmt.Sprintf("should be at least %d characters long", n))) + return v +} + +func (v *StringValidator) Max(n int) *StringValidator { + v.Rules = append(v.Rules, + p.LenMax[string](n, fmt.Sprintf("should be at most %d characters long", n))) + return v +} +func (v *StringValidator) Len(n int) *StringValidator { + v.Rules = append(v.Rules, + p.Length[string](n, fmt.Sprintf("should be exactly %d characters long", n)), + ) + return v +} + +// THIS IS ONLY HERE FOR CREATING ERROR MSGS FOR FORMS. DOESN'T ACTUALLY PROVIDE ANY VALUE +func (v *StringValidator) Required() *StringValidator { + v.Rules = append(v.Rules, + p.Rule{ + Name: "required", + ValidateFunc: func(rule p.Rule) bool { + str, ok := rule.FieldValue.(string) + if !ok { + return false + } + return str != "" + }, + ErrorMessage: "is a required field", + }, + ) + return v +} + +func (v *StringValidator) Email() *StringValidator { + v.Rules = append(v.Rules, + p.Rule{ + Name: "email", + ErrorMessage: "is not a valid email address", + ValidateFunc: func(set p.Rule) bool { + email, ok := set.FieldValue.(string) + if !ok { + return false + } + return emailRegex.MatchString(email) + }, + }, + ) + return v +} + +func (v *StringValidator) URL() *StringValidator { + v.Rules = append(v.Rules, + p.Rule{ + Name: "url", + ErrorMessage: "is not a valid url", + ValidateFunc: func(set p.Rule) bool { + u, ok := set.FieldValue.(string) + if !ok { + return false + } + isOk := urlRegex.MatchString(u) + return isOk + }, + }, + ) + return v +} + +// Should use the go method name for this? HasPrefix & HasSuffix??? +// TODO ??? +func (v *StringValidator) StartsWith(s string) *StringValidator { + v.Rules = append(v.Rules, + p.Rule{ + Name: "startsWith", + RuleValue: s, + ValidateFunc: func(set p.Rule) bool { + val, ok := set.FieldValue.(string) + if !ok { + return false + } + return strings.HasPrefix(val, s) + }, + ErrorMessage: fmt.Sprintf("should start with %s", s), + }, + ) + return v +} + +func (v *StringValidator) EndsWith(s string) *StringValidator { + v.Rules = append(v.Rules, + p.Rule{ + Name: "startsWith", + RuleValue: s, + ValidateFunc: func(set p.Rule) bool { + val, ok := set.FieldValue.(string) + if !ok { + return false + } + return strings.HasSuffix(val, s) + }, + ErrorMessage: fmt.Sprintf("should end with %s", s), + }, + ) + return v +} + +func (v *StringValidator) Contains(sub string) *StringValidator { + v.Rules = append(v.Rules, + p.Rule{ + Name: "contains", + RuleValue: sub, + ValidateFunc: func(set p.Rule) bool { + val, ok := set.FieldValue.(string) + if !ok { + return false + } + return strings.Contains(val, sub) + }, + ErrorMessage: fmt.Sprintf("should contain %s", sub), + }, + ) + return v +} + +func (v *StringValidator) ContainsUpper() *StringValidator { + v.Rules = append(v.Rules, + p.Rule{ + Name: "containsUpper", + ValidateFunc: func(set p.Rule) bool { + val, ok := set.FieldValue.(string) + if !ok { + return false + } + for _, r := range val { + if r >= 'A' && r <= 'Z' { + return true + } + } + return false + }, + ErrorMessage: "should contain at least one uppercase letter", + }, + ) + return v +} + +func (v *StringValidator) ContainsDigit() *StringValidator { + v.Rules = append(v.Rules, + p.Rule{ + Name: "containsDigit", + ValidateFunc: func(set p.Rule) bool { + val, ok := set.FieldValue.(string) + if !ok { + return false + } + for _, r := range val { + if r >= '0' && r <= '9' { + return true + } + } + return false + }, + ErrorMessage: "should contain at least one digit", + }, + ) + return v +} + +func (v *StringValidator) ContainsSpecial() *StringValidator { + v.Rules = append(v.Rules, + p.Rule{ + Name: "containsSpecial", + ValidateFunc: func(set p.Rule) bool { + val, ok := set.FieldValue.(string) + if !ok { + return false + } + for _, r := range val { + if (r >= '!' && r <= '/') || + (r >= ':' && r <= '@') || + (r >= '[' && r <= '`') || + (r >= '{' && r <= '~') { + return true + } + } + return false + }, + ErrorMessage: "should contain at least one special character", + }, + ) + return v +} + +// TODO +// IP +// date +// datetime +// time +// emoji diff --git a/validate/time.go b/validate/time.go new file mode 100644 index 0000000..b63892e --- /dev/null +++ b/validate/time.go @@ -0,0 +1,103 @@ +package validate + +import ( + "fmt" + "time" + + p "github.com/anthdm/superkit/validate/primitives" +) + +type timeValidator struct { + Rules []p.Rule + IsOptional bool +} + +func Time() *timeValidator { + return &timeValidator{ + Rules: []p.Rule{ + p.IsType[time.Time]("is not a a valid time"), + }, + } +} + +// GLOBAL METHODS + +func (v *timeValidator) Refine(ruleName string, errorMsg string, validateFunc p.RuleValidateFunc) *timeValidator { + v.Rules = append(v.Rules, + p.Rule{ + Name: ruleName, + ErrorMessage: errorMsg, + ValidateFunc: validateFunc, + }, + ) + + return v +} + +func (v *timeValidator) In(values []time.Time) *timeValidator { + v.Rules = append(v.Rules, p.In(values, fmt.Sprintf("should be in %v", values))) + return v +} + +func (v *timeValidator) Optional() *timeValidator { + v.IsOptional = true + return v +} + +func (v *timeValidator) Validate(fieldValue any) ([]string, bool) { + return p.GenericValidator(fieldValue, v.Rules, v.IsOptional) +} + +// UNIQUE METHODS + +func (v *timeValidator) After(t time.Time) *timeValidator { + v.Rules = append(v.Rules, + p.Rule{ + Name: "timeAfter", + ErrorMessage: fmt.Sprintf("is not after %v", t), + ValidateFunc: func(set p.Rule) bool { + val, ok := set.FieldValue.(time.Time) + if !ok { + return false + } + return val.After(t) + }, + }, + ) + return v +} + +func (v *timeValidator) Before(t time.Time) *timeValidator { + v.Rules = append(v.Rules, + p.Rule{ + Name: "timeBefore", + ErrorMessage: fmt.Sprintf("is not before %v", t), + ValidateFunc: func(set p.Rule) bool { + val, ok := set.FieldValue.(time.Time) + if !ok { + return false + } + return val.Before(t) + }, + }, + ) + return v +} + +func (v *timeValidator) Is(t time.Time) *timeValidator { + v.Rules = append(v.Rules, + p.Rule{ + Name: "timeIs", + ErrorMessage: fmt.Sprintf("is not %v", t), + ValidateFunc: func(set p.Rule) bool { + val, ok := set.FieldValue.(time.Time) + if !ok { + return false + } + return val.Equal(t) + }, + }, + ) + + return v +} diff --git a/validate/time_test.go b/validate/time_test.go new file mode 100644 index 0000000..a31de08 --- /dev/null +++ b/validate/time_test.go @@ -0,0 +1,68 @@ +package validate + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTimeAfter(t *testing.T) { + type Foo struct { + CreatedAt time.Time + } + now := time.Now() + foo := Foo{ + CreatedAt: now, + } + schema := Schema{ + "createdAt": Time().After(now), + } + errors, ok := Validate(foo, schema) + assert.False(t, ok) + assert.Len(t, errors["createdAt"], 1) + + foo.CreatedAt = now.Add(time.Second * 10000) + _, ok = Validate(foo, schema) + assert.True(t, ok) +} + +func TestTimeBefore(t *testing.T) { + type Foo struct { + CreatedAt time.Time + } + now := time.Now() + foo := Foo{ + CreatedAt: now, + } + schema := Schema{ + "createdAt": Time().Before(now), + } + errors, ok := Validate(foo, schema) + assert.False(t, ok) + assert.Len(t, errors["createdAt"], 1) + + foo.CreatedAt = now.Add(time.Second * -10000) + _, ok = Validate(foo, schema) + assert.True(t, ok) +} + +func TestTimeIs(t *testing.T) { + type Foo struct { + CreatedAt time.Time + } + now := time.Now() + foo := Foo{ + CreatedAt: now.Add(time.Second * 10000), + } + schema := Schema{ + "createdAt": Time().Is(now), + } + errors, ok := Validate(foo, schema) + assert.False(t, ok) + assert.Len(t, errors["createdAt"], 1) + + foo.CreatedAt = now + _, ok = Validate(foo, schema) + assert.True(t, ok) +} diff --git a/validate/validate.go b/validate/validate.go index 71c3131..1a0b1ab 100644 --- a/validate/validate.go +++ b/validate/validate.go @@ -31,37 +31,58 @@ func (e Errors) Get(field string) []string { return e[field] } -// Has returns true whether the given field has any errors. -func (e Errors) Has(field string) bool { - return len(e[field]) > 0 +// the interface of any set of rules for a field. Eg: String().Min(5) -> this is a fieldValidator +type fieldValidator interface { + Validate(val any) (errors []string, ok bool) } // Schema represents a validation schema. -type Schema map[string][]RuleSet +type Schema map[string]fieldValidator -// Merge merges the two given schemas, returning a new Schema. -func Merge(schema, other Schema) Schema { - newSchema := Schema{} - maps.Copy(newSchema, schema) - maps.Copy(newSchema, other) - return newSchema +// Validate validates data based on the given Schema. +func (s Schema) Validate(data any) (Errors, bool) { + errors := Errors{} + return validateSchema(data, s, errors) } +func Validate(data any, fields Schema) (Errors, bool) { + errors := Errors{} + return validateSchema(data, fields, errors) +} + +func validateSchema(data any, schema Schema, errors Errors) (Errors, bool) { + globalOk := true + ok := true + var fieldErrs []string + + for fieldName, validator := range schema { + fieldName = string(unicode.ToUpper(rune(fieldName[0]))) + fieldName[1:] + fieldValue := getFieldValueByName(data, fieldName) + fieldName = string(unicode.ToLower([]rune(fieldName)[0])) + fieldName[1:] + fieldErrs, ok = validator.Validate(fieldValue) + if !ok { + errors[fieldName] = fieldErrs + globalOk = false + } -// Rules is a function that takes any amount of RuleSets -func Rules(rules ...RuleSet) []RuleSet { - ruleSets := make([]RuleSet, len(rules)) - for i := 0; i < len(ruleSets); i++ { - ruleSets[i] = rules[i] } - return ruleSets + + return errors, globalOk } -// Validate validates data based on the given Schema. -func Validate(data any, fields Schema) (Errors, bool) { - errors := Errors{} - return validate(data, fields, errors) +// Merge merges the two given schemas returning a new Schema. In case of clashing second will take priority +func Merge(first, second Schema, rest ...Schema) Schema { + newSchema := Schema{} + maps.Copy(newSchema, first) + maps.Copy(newSchema, second) + + for _, s := range rest { + maps.Copy(newSchema, s) + } + + return newSchema } +// ! PARSE REQUESTS // Request parses an http.Request into data and validates it based // on the given schema. func Request(r *http.Request, data any, schema Schema) (Errors, bool) { @@ -69,57 +90,55 @@ func Request(r *http.Request, data any, schema Schema) (Errors, bool) { if err := parseRequest(r, data); err != nil { errors["_error"] = []string{err.Error()} } - return validate(data, schema, errors) + return validateSchema(data, schema, errors) } -func validate(data any, schema Schema, errors Errors) (Errors, bool) { - ok := true - for fieldName, ruleSets := range schema { - // Uppercase the field name so we never check un-exported fields. - // But we need to watch out for member fields that are uppercased by - // the user. For example (URL, ID, ...) - if !isUppercase(fieldName) { - fieldName = string(unicode.ToUpper(rune(fieldName[0]))) + fieldName[1:] +// TODO -> Parse requestQueryParams +func RequestParams(r *http.Request, data any, schema Schema) (Errors, bool) { + errors := Errors{} + if err := parseRequestParams(r, data); err != nil { + errors["_error"] = []string{err.Error()} + } + return validateSchema(data, schema, errors) +} + +func parseRequestParams(r *http.Request, v any) error { + + params := r.URL.Query() + val := reflect.ValueOf(v).Elem() + for i := 0; i < val.NumField(); i++ { + field := val.Type().Field(i) + paramTag := field.Tag.Get("param") + param := params[paramTag] + + if len(param) == 0 { + continue } - fieldValue := getFieldAndTagByName(data, fieldName) - for _, set := range ruleSets { - set.FieldValue = fieldValue - set.FieldName = fieldName - fieldName = string(unicode.ToLower([]rune(fieldName)[0])) + fieldName[1:] - if !set.ValidateFunc(set) { - ok = false - msg := set.MessageFunc(set) - if len(set.ErrorMessage) > 0 { - msg = set.ErrorMessage - } - if _, ok := errors[fieldName]; !ok { - errors[fieldName] = []string{} + fieldVal := val.Field(i) + t := fieldVal.Kind() + switch t { + case reflect.Slice: + for idx, v := range param { + if idx < fieldVal.Len() { + fieldVal.Index(idx).Set(reflect.ValueOf(v)) + } else { + newElem := reflect.Append(fieldVal, reflect.ValueOf(v)) + fieldVal.Set(newElem) } - errors[fieldName] = append(errors[fieldName], msg) + } + default: + if err := parsePrimitive(&t, &fieldVal, param[0]); err != nil { + return err } } } - return errors, ok -} - -func getFieldAndTagByName(v any, name string) any { - val := reflect.ValueOf(v) - if val.Kind() == reflect.Ptr { - val = val.Elem() - } - if val.Kind() != reflect.Struct { - return nil - } - fieldVal := val.FieldByName(name) - if !fieldVal.IsValid() { - return nil - } - return fieldVal.Interface() + return nil } func parseRequest(r *http.Request, v any) error { contentType := r.Header.Get("Content-Type") + // TODO support more content types if contentType == "application/x-www-form-urlencoded" { if err := r.ParseForm(); err != nil { return fmt.Errorf("failed to parse form: %v", err) @@ -135,44 +154,9 @@ func parseRequest(r *http.Request, v any) error { } fieldVal := val.Field(i) - switch fieldVal.Kind() { - case reflect.Bool: - // There are cases where frontend libraries use "on" as the bool value - // think about toggles. Hence, let's try this first. - if formValue == "on" { - fieldVal.SetBool(true) - } else if formValue == "off" { - fieldVal.SetBool(false) - return nil - } else { - boolVal, err := strconv.ParseBool(formValue) - if err != nil { - return fmt.Errorf("failed to parse bool: %v", err) - } - fieldVal.SetBool(boolVal) - } - case reflect.String: - fieldVal.SetString(formValue) - case reflect.Int, reflect.Int32, reflect.Int64: - intVal, err := strconv.Atoi(formValue) - if err != nil { - return fmt.Errorf("failed to parse int: %v", err) - } - fieldVal.SetInt(int64(intVal)) - case reflect.Uint, reflect.Uint32, reflect.Uint64: - intVal, err := strconv.Atoi(formValue) - if err != nil { - return fmt.Errorf("failed to parse int: %v", err) - } - fieldVal.SetUint(uint64(intVal)) - case reflect.Float64: - floatVal, err := strconv.ParseFloat(formValue, 64) - if err != nil { - return fmt.Errorf("failed to parse float: %v", err) - } - fieldVal.SetFloat(floatVal) - default: - return fmt.Errorf("unsupported kind %s", fieldVal.Kind()) + typ := fieldVal.Kind() + if err := parsePrimitive(&typ, &fieldVal, formValue); err != nil { + return err } } @@ -180,11 +164,56 @@ func parseRequest(r *http.Request, v any) error { return nil } -func isUppercase(s string) bool { - for _, ch := range s { - if !unicode.IsUpper(rune(ch)) { - return false +func parsePrimitive(typ *reflect.Kind, refObj *reflect.Value, value string) error { + switch *typ { + case reflect.Bool: + // There are cases where frontend libraries use "on" as the bool value + // think about toggles. Hence, let's try this first. + if value == "on" { + refObj.SetBool(true) + } else if value == "off" { + refObj.SetBool(false) + return nil + } else { + boolVal, err := strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("failed to parse bool: %v", err) + } + refObj.SetBool(boolVal) + } + + case reflect.String: + refObj.SetString(value) + case reflect.Int: + intVal, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("failed to parse int: %v", err) } + refObj.SetInt(int64(intVal)) + case reflect.Float64: + floatVal, err := strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("failed to parse float: %v", err) + } + refObj.SetFloat(floatVal) + default: + return fmt.Errorf("unsupported kind %s", refObj.Kind()) + } + + return nil +} + +func getFieldValueByName(v any, name string) any { + val := reflect.ValueOf(v) + if val.Kind() == reflect.Ptr { + val = val.Elem() } - return true + if val.Kind() != reflect.Struct { + return nil + } + fieldVal := val.FieldByName(name) + if !fieldVal.IsValid() { + return nil + } + return fieldVal.Interface() } diff --git a/validate/validate_test.go b/validate/validate_test.go index b29e12a..770dec1 100644 --- a/validate/validate_test.go +++ b/validate/validate_test.go @@ -1,207 +1,129 @@ package validate import ( - "fmt" "net/http" - "net/url" "strings" "testing" - "time" + p "github.com/anthdm/superkit/validate/primitives" "github.com/stretchr/testify/assert" ) -var createdAt = time.Now() - -var testSchema = Schema{ - "createdAt": Rules(Time), - "startedAt": Rules(TimeBefore(time.Now())), - "deletedAt": Rules(TimeAfter(createdAt)), - "email": Rules(Email), - "url": Rules(URL), - "password": Rules( - ContainsSpecial, - ContainsUpper, - ContainsDigit, - Min(7), - Max(50), - ), - "age": Rules(GTE(18)), - "bet": Rules(GT(0), LTE(10)), - "username": Rules(Required), -} +func TestRequest(t *testing.T) { + formData := "name=JohnDoe&email=john@doe.com&age=30&isMarried=true&lights=on&cash=10.5&swagger=doweird" -func TestValidateRequest(t *testing.T) { - var ( - email = "foo@bar.com" - password = "superHunter123@" - firstName = "Anthony" - website = "http://foo.com" - randomNumber = 123 - randomFloat = 9.999 - ) - formValues := url.Values{} - formValues.Set("email", email) - formValues.Set("password", password) - formValues.Set("firstName", firstName) - formValues.Set("url", website) - formValues.Set("brandom", fmt.Sprint(randomNumber)) - formValues.Set("arandom", fmt.Sprint(randomFloat)) - encodedValues := formValues.Encode() - - req, err := http.NewRequest("POST", "http://foo.com", strings.NewReader(encodedValues)) - assert.Nil(t, err) + // Create a fake HTTP request with form data + req, err := http.NewRequest("POST", "/submit?foo=bar&bar=foo&foo=baz", strings.NewReader(formData)) + if err != nil { + t.Fatalf("Error creating request: %v", err) + } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - type SignupData struct { - Email string `form:"email"` - Password string `form:"password"` - FirstName string `form:"firstName"` - URL string `form:"url"` - ARandomRenamedNumber int `form:"brandom"` - ARandomRenamedFloat float64 `form:"arandom"` + type User struct { + Email string `form:"email"` + Name string `form:"name"` + Age int `form:"age"` + IsMarried bool `form:"isMarried"` + Lights bool `form:"lights"` + Cash float64 `form:"cash"` + Swagger string `form:"swagger"` } - schema := Schema{ - "Email": Rules(Email), - "Password": Rules( - Required, - ContainsDigit, - ContainsUpper, - ContainsSpecial, - Min(7), - ), - "FirstName": Rules(Min(3), Max(50)), - "URL": Rules(URL), - "ARandomRenamedNumber": Rules(GT(100), LT(124)), - "ARandomRenamedFloat": Rules(GT(9.0), LT(10.1)), - } - - var data SignupData - errors, ok := Request(req, &data, schema) + "email": String().Email(), + "name": String().Min(3).Max(10), + "age": Int().GT(18), + "isMarried": Bool().True(), + "lights": Bool().True(), + "cash": Float().GT(10.0), + "swagger": String().Refine("swagger", "should be doweird", func(rule p.Rule) bool { + return rule.FieldValue.(string) == "doweird" + }), + } + u := User{} + + errs, ok := Request(req, &u, schema) + + assert.Equal(t, "john@doe.com", u.Email) + assert.Equal(t, "JohnDoe", u.Name) + assert.Equal(t, 30, u.Age) + assert.True(t, u.IsMarried) + assert.True(t, u.Lights) + assert.Equal(t, 10.5, u.Cash) + assert.Equal(t, u.Swagger, "doweird") + assert.Empty(t, errs) assert.True(t, ok) - assert.Empty(t, errors) - - assert.Equal(t, data.Email, email) - assert.Equal(t, data.Password, password) - assert.Equal(t, data.FirstName, firstName) - assert.Equal(t, data.URL, website) - assert.Equal(t, data.ARandomRenamedNumber, randomNumber) - assert.Equal(t, data.ARandomRenamedFloat, randomFloat) } -func TestTime(t *testing.T) { - type Foo struct { - CreatedAt time.Time - } - foo := Foo{ - CreatedAt: time.Now(), - } - schema := Schema{ - "createdAt": Rules(Time), - } - _, ok := Validate(foo, schema) - assert.True(t, ok) +func TestRequestParams(t *testing.T) { + formData := "name=JohnDoe&email=john@doe.com&age=30&age=20&isMarried=true&lights=on&cash=10.5&swagger=doweird&swagger=swagger" - foo.CreatedAt = time.Time{} - _, ok = Validate(foo, schema) - assert.False(t, ok) -} - -func TestURL(t *testing.T) { - type Foo struct { - URL string `v:"URL"` + // Create a fake HTTP request with form data + req, err := http.NewRequest("POST", "/submit?"+formData, nil) + if err != nil { + t.Fatalf("Error creating request: %v", err) } - foo := Foo{ - URL: "not an url", - } - schema := Schema{ - "URL": Rules(URL), - } - errors, ok := Validate(foo, schema) - assert.False(t, ok) - assert.NotEmpty(t, errors) - validURLS := []string{ - "http://google.com", - "http://www.google.com", - "https://www.google.com", - "https://www.google.com", - "www.google.com", - "https://book.com/sales", - "app.book.com", - "app.book.com/signup", - } - - for _, url := range validURLS { - foo.URL = url - errors, ok = Validate(foo, schema) - assert.True(t, ok) - assert.Empty(t, errors) + type User struct { + Email string `param:"email"` + Name string `param:"name"` + Age int `param:"age"` + IsMarried bool `param:"isMarried"` + Lights bool `param:"lights"` + Cash float64 `param:"cash"` + Swagger []string `param:"swagger"` } -} -func TestContainsUpper(t *testing.T) { - type Foo struct { - Password string - } - foo := Foo{"hunter"} schema := Schema{ - "Password": Rules(ContainsUpper), - } - errors, ok := Validate(foo, schema) - assert.False(t, ok) - assert.NotEmpty(t, errors) - - foo.Password = "Hunter" - errors, ok = Validate(foo, schema) + "email": String().Email(), + "name": String().Min(3).Max(10), + "age": Int().GT(18), + "isMarried": Bool().True(), + "lights": Bool().True(), + "cash": Float().GT(10.0), + "swagger": Slice( + String().Min(1)).Min(2), + } + u := User{} + + errs, ok := RequestParams(req, &u, schema) + + assert.Equal(t, "john@doe.com", u.Email) + assert.Equal(t, "JohnDoe", u.Name) + assert.Equal(t, 30, u.Age) + assert.True(t, u.IsMarried) + assert.True(t, u.Lights) + assert.Equal(t, 10.5, u.Cash) + assert.Equal(t, u.Swagger, []string{"doweird", "swagger"}) + assert.Empty(t, errs) assert.True(t, ok) - assert.Empty(t, errors) } -func TestContainsDigit(t *testing.T) { +func TestStringURL(t *testing.T) { type Foo struct { - Password string - } - foo := Foo{"hunter"} - schema := Schema{ - "Password": Rules(ContainsDigit), + Url string } - errors, ok := Validate(foo, schema) - assert.False(t, ok) - assert.NotEmpty(t, errors) - - foo.Password = "Hunter1" - errors, ok = Validate(foo, schema) - assert.True(t, ok) - assert.Empty(t, errors) -} - -func TestContainsSpecial(t *testing.T) { - type Foo struct { - Password string + foo := Foo{ + Url: "not an url", } - foo := Foo{"hunter"} schema := Schema{ - "Password": Rules(ContainsSpecial), + "url": String().URL(), } errors, ok := Validate(foo, schema) assert.False(t, ok) - assert.NotEmpty(t, errors) + assert.Len(t, errors["url"], 1) - foo.Password = "Hunter@" + foo.Url = "https://www.user.com" errors, ok = Validate(foo, schema) assert.True(t, ok) assert.Empty(t, errors) } -func TestRuleIn(t *testing.T) { +func TestStringIn(t *testing.T) { type Foo struct { Currency string } foo := Foo{"eur"} schema := Schema{ - "currency": Rules(In([]string{"eur", "usd", "chz"})), + "currency": String().In([]string{"eur", "usd", "chz"}), } errors, ok := Validate(foo, schema) assert.True(t, ok) @@ -218,9 +140,9 @@ func TestValidate(t *testing.T) { Username string } schema := Schema{ - "email": Rules(Email), + "email": String().Email(), // Test both lower and uppercase - "Username": Rules(Min(3), Max(10)), + "username": String().Min(3).Max(10), } user := User{ Email: "foo@bar.com", @@ -229,23 +151,44 @@ func TestValidate(t *testing.T) { errors, ok := Validate(user, schema) assert.True(t, ok) assert.Empty(t, errors) + assert.Empty(t, errors) } -func TestMergeSchemas(t *testing.T) { - expected := Schema{ - "Name": Rules(), - "Email": Rules(), - "FirstName": Rules(), - "LastName": Rules(), +func TestOptional(t *testing.T) { + type User struct { + Email string + Username string } - a := Schema{ - "Name": Rules(), - "Email": Rules(), + schema := Schema{ + "email": String().Email(), + "username": String().Min(3).Max(10).Optional(), } - b := Schema{ - "FirstName": Rules(), - "LastName": Rules(), + user := User{ + Email: "pedro@gmail.com", } - c := Merge(a, b) - assert.Equal(t, expected, c) + + errors, ok := Validate(user, schema) + assert.True(t, ok) + assert.Empty(t, errors) + assert.Empty(t, errors) +} + +func TestEmpty(t *testing.T) { + type User struct { + Email string + Username string + } + schema := Schema{ + "email": String(), + "username": String(), + } + user := User{ + Email: "", + Username: "", + } + + errors, ok := Validate(user, schema) + assert.True(t, ok) + assert.Empty(t, errors) + assert.Empty(t, errors) }