From b390a36e54f8f2e418b66fa21c4bddfe3b313675 Mon Sep 17 00:00:00 2001 From: "popov.dmitriy" <dp@sessia.dev> Date: Fri, 16 May 2025 19:15:02 +0300 Subject: [PATCH 1/2] Add validator --- request/validator.go | 230 ++++++++++++++++++++++++++++++++++++++ request/validator_test.go | 132 ++++++++++++++++++++++ 2 files changed, 362 insertions(+) create mode 100644 request/validator.go create mode 100644 request/validator_test.go diff --git a/request/validator.go b/request/validator.go new file mode 100644 index 0000000..b375b04 --- /dev/null +++ b/request/validator.go @@ -0,0 +1,230 @@ +package request + +import ( + "fmt" + + "github.com/go-playground/validator/v10" +) + +type FilterSetInterface interface { + GetFields() []string + Validate() error +} + +type OrderSetInterface interface { + GetFields() []string + Validate() error +} + +type GroupingSetInterface interface { + GetFields() []string + Validate() error +} + +const ( + errInvalidGroupingFields = "invalid grouping fields" + errInvalidGroupingMethods = "invalid grouping methods" + errInvalidFilterFields = "invalid filter fields" + errInvalidOrderFields = "invalid order fields" + errInvalidOrderMethods = "invalid order methods" +) + +type GroupingRules map[string][]string + +type DefaultRequestValidator struct { + validate *validator.Validate + groupingRules GroupingRules + filterFields map[string]struct{} + orderFields map[string]struct{} +} + +func NewDefaultRequestValidatorWithRules(rules GroupingRules, filterFields, orderFields []string) *DefaultRequestValidator { + v := validator.New() + + validator := &DefaultRequestValidator{ + validate: v, + groupingRules: rules, + filterFields: toSet(filterFields), + orderFields: toSet(orderFields), + } + + validator.registerValidations() + + return validator +} + +func (v *DefaultRequestValidator) registerValidations() { + v.registerGroupingAllowedFieldsValidation() + v.registerGroupingSetStructValidation() + v.registerFilterAllowedFieldsValidation() + v.registerOrderAllowedFieldsValidation() +} + +func (v *DefaultRequestValidator) registerGroupingAllowedFieldsValidation() { + allowed := v.groupingRulesKeysSet() + + v.validate.RegisterValidation("grouping_allowed_fields", func(fl validator.FieldLevel) bool { + return validateFieldsAllowed(fl, allowed) + }) +} + +func (v *DefaultRequestValidator) registerGroupingSetStructValidation() { + allowedMethods := buildAllowedMethodsMap(v.groupingRules) + + v.validate.RegisterStructValidation(func(sl validator.StructLevel) { + groupings, ok := sl.Current().Interface().(GroupingSet) + if !ok { + return + } + + for field, groupingMethod := range groupings.Groups { + methodsForField, fieldAllowed := allowedMethods[field] + if !fieldAllowed { + sl.ReportError(groupings.Groups, "Groups", field, "allowedfield", "") + continue + } + + if _, methodAllowed := methodsForField[groupingMethod.Method]; !methodAllowed { + sl.ReportError(groupingMethod.Method, "Method", field, "allowedmethod", "") + } + } + }, GroupingSet{}) +} + +func (v *DefaultRequestValidator) registerFilterAllowedFieldsValidation() { + v.validate.RegisterValidation("filter_allowed_fields", func(fl validator.FieldLevel) bool { + return validateFieldsAllowed(fl, v.filterFields) + }) +} + +func (v *DefaultRequestValidator) registerOrderAllowedFieldsValidation() { + v.validate.RegisterValidation("orders_allowed_fields", func(fl validator.FieldLevel) bool { + return validateFieldsAllowed(fl, v.orderFields) + }) + + v.validate.RegisterStructValidation(func(sl validator.StructLevel) { + orderSet, ok := sl.Current().Interface().(OrderSet) + if !ok { + return + } + + for field := range orderSet.Orders { + if _, exists := v.orderFields[field]; !exists { + sl.ReportError(orderSet.Orders, "Orders", field, "allowedfield", "") + } + } + }, OrderSet{}) +} + +func validateFieldsAllowed(fl validator.FieldLevel, allowedFields map[string]struct{}) bool { + fields, ok := fl.Field().Interface().([]string) + if !ok { + if str, ok := fl.Field().Interface().(string); ok { + _, exists := allowedFields[str] + return exists + } + return false + } + + for _, f := range fields { + if _, exists := allowedFields[f]; !exists { + return false + } + } + return true +} + +func toSet(fields []string) map[string]struct{} { + set := make(map[string]struct{}, len(fields)) + for _, f := range fields { + set[f] = struct{}{} + } + return set +} + +func (v *DefaultRequestValidator) groupingRulesKeysSet() map[string]struct{} { + set := make(map[string]struct{}, len(v.groupingRules)) + for k := range v.groupingRules { + set[k] = struct{}{} + } + return set +} + +func buildAllowedMethodsMap(rules GroupingRules) map[string]map[string]struct{} { + methodsMap := make(map[string]map[string]struct{}, len(rules)) + for field, methods := range rules { + mSet := make(map[string]struct{}, len(methods)) + for _, m := range methods { + mSet[m] = struct{}{} + } + methodsMap[field] = mSet + } + return methodsMap +} + +func (v *DefaultRequestValidator) ValidateFilters(filters FilterSetInterface) error { + if filters == nil { + return nil + } + + type wrapper struct { + Fields []string `validate:"filter_allowed_fields"` + } + + w := wrapper{Fields: filters.GetFields()} + if err := v.validate.Struct(w); err != nil { + return fmt.Errorf(errInvalidFilterFields) + } + + return filters.Validate() +} + +func (v *DefaultRequestValidator) ValidateGroupings(groupings GroupingSetInterface) error { + if groupings == nil { + return nil + } + + if err := groupings.Validate(); err != nil { + return err + } + + type wrapper struct { + Fields []string `validate:"grouping_allowed_fields"` + } + + w := wrapper{Fields: groupings.GetFields()} + if err := v.validate.Struct(w); err != nil { + return fmt.Errorf(errInvalidGroupingFields) + } + + if err := v.validate.Struct(groupings); err != nil { + return fmt.Errorf(errInvalidGroupingMethods) + } + + return nil +} + +func (v *DefaultRequestValidator) ValidateOrders(orders OrderSetInterface) error { + if orders == nil { + return nil + } + + if err := orders.Validate(); err != nil { + return err + } + + type wrapper struct { + Fields []string `validate:"orders_allowed_fields"` + } + + w := wrapper{Fields: orders.GetFields()} + if err := v.validate.Struct(w); err != nil { + return fmt.Errorf(errInvalidOrderFields) + } + + if err := v.validate.Struct(orders); err != nil { + return fmt.Errorf(errInvalidOrderMethods) + } + + return nil +} diff --git a/request/validator_test.go b/request/validator_test.go new file mode 100644 index 0000000..69328ee --- /dev/null +++ b/request/validator_test.go @@ -0,0 +1,132 @@ +package request + +import ( + "testing" +) + +type filterSetMock struct { + fields []string + err error +} + +func (f *filterSetMock) GetFields() []string { + return f.fields +} + +func (f *filterSetMock) Validate() error { + return f.err +} + +type groupingSetMock struct { + fields []string + err error +} + +func (g *groupingSetMock) GetFields() []string { + return g.fields +} + +func (g *groupingSetMock) Validate() error { + return g.err +} + +type orderSetMock struct { + fields []string + err error +} + +func (o *orderSetMock) GetFields() []string { + return o.fields +} + +func (o *orderSetMock) Validate() error { + return o.err +} + +func TestDefaultRequestValidator_ValidateFilters(t *testing.T) { + validator := NewDefaultRequestValidatorWithRules( + nil, + []string{"field1", "field2"}, + nil, + ) + + tests := []struct { + name string + filter *filterSetMock + wantError bool + }{ + {"valid fields", &filterSetMock{fields: []string{"field1"}}, false}, + {"invalid field", &filterSetMock{fields: []string{"invalidField"}}, true}, + {"validate error", &filterSetMock{fields: []string{"field1"}, err: errFake{}}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validator.ValidateFilters(tt.filter) + if (err != nil) != tt.wantError { + t.Errorf("ValidateFilters() error = %v, wantError %v", err, tt.wantError) + } + }) + } +} + +func TestDefaultRequestValidator_ValidateGroupings(t *testing.T) { + validator := NewDefaultRequestValidatorWithRules( + map[string][]string{"field1": {"method1", "method2"}}, + nil, + nil, + ) + + tests := []struct { + name string + grouping *groupingSetMock + wantError bool + }{ + {"valid fields", &groupingSetMock{fields: []string{"field1"}}, false}, + {"invalid field", &groupingSetMock{fields: []string{"invalidField"}}, true}, + {"validate error", &groupingSetMock{fields: []string{"field1"}, err: errFake{}}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validator.ValidateGroupings(tt.grouping) + if (err != nil) != tt.wantError { + t.Errorf("ValidateGroupings() error = %v, wantError %v", err, tt.wantError) + } + }) + } +} + +func TestDefaultRequestValidator_ValidateOrders(t *testing.T) { + validator := NewDefaultRequestValidatorWithRules( + nil, + nil, + []string{"field1", "field2"}, + ) + + tests := []struct { + name string + order *orderSetMock + wantError bool + }{ + {"valid fields", &orderSetMock{fields: []string{"field1"}}, false}, + {"invalid field", &orderSetMock{fields: []string{"invalidField"}}, true}, + {"validate error", &orderSetMock{fields: []string{"field1"}, err: errFake{}}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validator.ValidateOrders(tt.order) + if (err != nil) != tt.wantError { + t.Errorf("ValidateOrders() error = %v, wantError %v", err, tt.wantError) + } + }) + } +} + +// errFake Ð´Ð»Ñ Ð¸Ð¼Ð¸Ñ‚Ð°Ñ†Ð¸Ð¸ ошибки в Validate() +type errFake struct{} + +func (e errFake) Error() string { + return "fake error" +} -- GitLab From a401b462d29caed76b99804aed73a10637d18151 Mon Sep 17 00:00:00 2001 From: "popov.dmitriy" <dp@sessia.dev> Date: Fri, 16 May 2025 19:36:06 +0300 Subject: [PATCH 2/2] Fix linters --- request/validator.go | 60 +++++++++++++++++++++++++-------------- request/validator_test.go | 42 +++++++++++---------------- 2 files changed, 56 insertions(+), 46 deletions(-) diff --git a/request/validator.go b/request/validator.go index b375b04..c943c52 100644 --- a/request/validator.go +++ b/request/validator.go @@ -1,6 +1,7 @@ package request import ( + "errors" "fmt" "github.com/go-playground/validator/v10" @@ -38,7 +39,7 @@ type DefaultRequestValidator struct { orderFields map[string]struct{} } -func NewDefaultRequestValidatorWithRules(rules GroupingRules, filterFields, orderFields []string) *DefaultRequestValidator { +func NewDefaultRequestValidatorWithRules(rules GroupingRules, filterFields, orderFields []string) (*DefaultRequestValidator, error) { v := validator.New() validator := &DefaultRequestValidator{ @@ -48,27 +49,38 @@ func NewDefaultRequestValidatorWithRules(rules GroupingRules, filterFields, orde orderFields: toSet(orderFields), } - validator.registerValidations() + if err := validator.registerValidations(); err != nil { + return nil, err + } - return validator + return validator, nil } -func (v *DefaultRequestValidator) registerValidations() { - v.registerGroupingAllowedFieldsValidation() - v.registerGroupingSetStructValidation() - v.registerFilterAllowedFieldsValidation() - v.registerOrderAllowedFieldsValidation() +func (v *DefaultRequestValidator) registerValidations() error { + if err := v.registerGroupingAllowedFieldsValidation(); err != nil { + return fmt.Errorf("register grouping_allowed_fields validation: %w", err) + } + if err := v.registerGroupingSetStructValidation(); err != nil { + return fmt.Errorf("register grouping_set_struct validation: %w", err) + } + if err := v.registerFilterAllowedFieldsValidation(); err != nil { + return fmt.Errorf("register filter_allowed_fields validation: %w", err) + } + if err := v.registerOrderAllowedFieldsValidation(); err != nil { + return fmt.Errorf("register orders_allowed_fields validation: %w", err) + } + return nil } -func (v *DefaultRequestValidator) registerGroupingAllowedFieldsValidation() { +func (v *DefaultRequestValidator) registerGroupingAllowedFieldsValidation() error { allowed := v.groupingRulesKeysSet() - v.validate.RegisterValidation("grouping_allowed_fields", func(fl validator.FieldLevel) bool { + return v.validate.RegisterValidation("grouping_allowed_fields", func(fl validator.FieldLevel) bool { return validateFieldsAllowed(fl, allowed) }) } -func (v *DefaultRequestValidator) registerGroupingSetStructValidation() { +func (v *DefaultRequestValidator) registerGroupingSetStructValidation() error { allowedMethods := buildAllowedMethodsMap(v.groupingRules) v.validate.RegisterStructValidation(func(sl validator.StructLevel) { @@ -89,18 +101,22 @@ func (v *DefaultRequestValidator) registerGroupingSetStructValidation() { } } }, GroupingSet{}) + + return nil } -func (v *DefaultRequestValidator) registerFilterAllowedFieldsValidation() { - v.validate.RegisterValidation("filter_allowed_fields", func(fl validator.FieldLevel) bool { +func (v *DefaultRequestValidator) registerFilterAllowedFieldsValidation() error { + return v.validate.RegisterValidation("filter_allowed_fields", func(fl validator.FieldLevel) bool { return validateFieldsAllowed(fl, v.filterFields) }) } -func (v *DefaultRequestValidator) registerOrderAllowedFieldsValidation() { - v.validate.RegisterValidation("orders_allowed_fields", func(fl validator.FieldLevel) bool { +func (v *DefaultRequestValidator) registerOrderAllowedFieldsValidation() error { + if err := v.validate.RegisterValidation("orders_allowed_fields", func(fl validator.FieldLevel) bool { return validateFieldsAllowed(fl, v.orderFields) - }) + }); err != nil { + return err + } v.validate.RegisterStructValidation(func(sl validator.StructLevel) { orderSet, ok := sl.Current().Interface().(OrderSet) @@ -114,6 +130,8 @@ func (v *DefaultRequestValidator) registerOrderAllowedFieldsValidation() { } } }, OrderSet{}) + + return nil } func validateFieldsAllowed(fl validator.FieldLevel, allowedFields map[string]struct{}) bool { @@ -173,7 +191,7 @@ func (v *DefaultRequestValidator) ValidateFilters(filters FilterSetInterface) er w := wrapper{Fields: filters.GetFields()} if err := v.validate.Struct(w); err != nil { - return fmt.Errorf(errInvalidFilterFields) + return errors.New(errInvalidFilterFields) } return filters.Validate() @@ -194,11 +212,11 @@ func (v *DefaultRequestValidator) ValidateGroupings(groupings GroupingSetInterfa w := wrapper{Fields: groupings.GetFields()} if err := v.validate.Struct(w); err != nil { - return fmt.Errorf(errInvalidGroupingFields) + return errors.New(errInvalidGroupingFields) } if err := v.validate.Struct(groupings); err != nil { - return fmt.Errorf(errInvalidGroupingMethods) + return errors.New(errInvalidGroupingMethods) } return nil @@ -219,11 +237,11 @@ func (v *DefaultRequestValidator) ValidateOrders(orders OrderSetInterface) error w := wrapper{Fields: orders.GetFields()} if err := v.validate.Struct(w); err != nil { - return fmt.Errorf(errInvalidOrderFields) + return errors.New(errInvalidOrderFields) } if err := v.validate.Struct(orders); err != nil { - return fmt.Errorf(errInvalidOrderMethods) + return errors.New(errInvalidOrderMethods) } return nil diff --git a/request/validator_test.go b/request/validator_test.go index 69328ee..1620374 100644 --- a/request/validator_test.go +++ b/request/validator_test.go @@ -1,6 +1,7 @@ package request import ( + "errors" "testing" ) @@ -44,11 +45,10 @@ func (o *orderSetMock) Validate() error { } func TestDefaultRequestValidator_ValidateFilters(t *testing.T) { - validator := NewDefaultRequestValidatorWithRules( - nil, - []string{"field1", "field2"}, - nil, - ) + validator, err := NewDefaultRequestValidatorWithRules(nil, []string{"field1", "field2"}, nil) + if err != nil { + t.Fatalf("failed to create validator: %v", err) + } tests := []struct { name string @@ -57,7 +57,7 @@ func TestDefaultRequestValidator_ValidateFilters(t *testing.T) { }{ {"valid fields", &filterSetMock{fields: []string{"field1"}}, false}, {"invalid field", &filterSetMock{fields: []string{"invalidField"}}, true}, - {"validate error", &filterSetMock{fields: []string{"field1"}, err: errFake{}}, true}, + {"validate error", &filterSetMock{fields: []string{"field1"}, err: errors.New("fail")}, true}, } for _, tt := range tests { @@ -71,11 +71,11 @@ func TestDefaultRequestValidator_ValidateFilters(t *testing.T) { } func TestDefaultRequestValidator_ValidateGroupings(t *testing.T) { - validator := NewDefaultRequestValidatorWithRules( - map[string][]string{"field1": {"method1", "method2"}}, - nil, - nil, - ) + rules := map[string][]string{"field1": {"method1", "method2"}} + validator, err := NewDefaultRequestValidatorWithRules(rules, nil, nil) + if err != nil { + t.Fatalf("failed to create validator: %v", err) + } tests := []struct { name string @@ -84,7 +84,7 @@ func TestDefaultRequestValidator_ValidateGroupings(t *testing.T) { }{ {"valid fields", &groupingSetMock{fields: []string{"field1"}}, false}, {"invalid field", &groupingSetMock{fields: []string{"invalidField"}}, true}, - {"validate error", &groupingSetMock{fields: []string{"field1"}, err: errFake{}}, true}, + {"validate error", &groupingSetMock{fields: []string{"field1"}, err: errors.New("fail")}, true}, } for _, tt := range tests { @@ -98,11 +98,10 @@ func TestDefaultRequestValidator_ValidateGroupings(t *testing.T) { } func TestDefaultRequestValidator_ValidateOrders(t *testing.T) { - validator := NewDefaultRequestValidatorWithRules( - nil, - nil, - []string{"field1", "field2"}, - ) + validator, err := NewDefaultRequestValidatorWithRules(nil, nil, []string{"field1", "field2"}) + if err != nil { + t.Fatalf("failed to create validator: %v", err) + } tests := []struct { name string @@ -111,7 +110,7 @@ func TestDefaultRequestValidator_ValidateOrders(t *testing.T) { }{ {"valid fields", &orderSetMock{fields: []string{"field1"}}, false}, {"invalid field", &orderSetMock{fields: []string{"invalidField"}}, true}, - {"validate error", &orderSetMock{fields: []string{"field1"}, err: errFake{}}, true}, + {"validate error", &orderSetMock{fields: []string{"field1"}, err: errors.New("fail")}, true}, } for _, tt := range tests { @@ -123,10 +122,3 @@ func TestDefaultRequestValidator_ValidateOrders(t *testing.T) { }) } } - -// errFake Ð´Ð»Ñ Ð¸Ð¼Ð¸Ñ‚Ð°Ñ†Ð¸Ð¸ ошибки в Validate() -type errFake struct{} - -func (e errFake) Error() string { - return "fake error" -} -- GitLab