diff --git a/server/channels/utils/merge.go b/server/channels/utils/merge.go index cdeaa6e908..8071ec59ec 100644 --- a/server/channels/utils/merge.go +++ b/server/channels/utils/merge.go @@ -36,31 +36,27 @@ type MergeConfig struct { // retTS := ret.(testStruct) // return &retTS, nil // } -func Merge(base any, patch any, mergeConfig *MergeConfig) (any, error) { - if reflect.TypeOf(base) != reflect.TypeOf(patch) { - return nil, fmt.Errorf( - "cannot merge different types. base type: %s, patch type: %s", - reflect.TypeOf(base), - reflect.TypeOf(patch), - ) - } - +func Merge[T any](base T, patch T, mergeConfig *MergeConfig) (T, error) { commonType := reflect.TypeOf(base) baseVal := reflect.ValueOf(base) patchVal := reflect.ValueOf(patch) - if commonType.Kind() == reflect.Ptr { - commonType = commonType.Elem() - baseVal = baseVal.Elem() - patchVal = patchVal.Elem() - } - ret := reflect.New(commonType) val, ok := merge(baseVal, patchVal, mergeConfig) if ok { ret.Elem().Set(val) } - return ret.Elem().Interface(), nil + + r, ok := ret.Elem().Interface().(T) + if !ok { + return r, fmt.Errorf( + "Unexpected type of return element, expected %s, is %s", + commonType, + reflect.TypeOf(r), + ) + } + + return r, nil } // merge recursively merges patch into base and returns the new struct, ptr, slice/map, or value diff --git a/server/channels/utils/merge_test.go b/server/channels/utils/merge_test.go index d900658bd6..17b6ebe0e9 100644 --- a/server/channels/utils/merge_test.go +++ b/server/channels/utils/merge_test.go @@ -1549,8 +1549,8 @@ func mergeSimple(base, patch simple) (*simple, error) { if err != nil { return nil, err } - retS := ret.(simple) - return &retS, nil + + return &ret, nil } func mergeEvenSimpler(base, patch evenSimpler) (*evenSimpler, error) { @@ -1558,8 +1558,8 @@ func mergeEvenSimpler(base, patch evenSimpler) (*evenSimpler, error) { if err != nil { return nil, err } - retTS := ret.(evenSimpler) - return &retTS, nil + + return &ret, nil } func mergeEvenSimplerWithConfig(base, patch evenSimpler, mergeConfig *utils.MergeConfig) (*evenSimpler, error) { @@ -1567,8 +1567,8 @@ func mergeEvenSimplerWithConfig(base, patch evenSimpler, mergeConfig *utils.Merg if err != nil { return nil, err } - retTS := ret.(evenSimpler) - return &retTS, nil + + return &ret, nil } func mergeSliceStruct(base, patch sliceStruct) (*sliceStruct, error) { @@ -1576,8 +1576,8 @@ func mergeSliceStruct(base, patch sliceStruct) (*sliceStruct, error) { if err != nil { return nil, err } - retTS := ret.(sliceStruct) - return &retTS, nil + + return &ret, nil } func mergeMapPtr(base, patch mapPtr) (*mapPtr, error) { @@ -1585,8 +1585,8 @@ func mergeMapPtr(base, patch mapPtr) (*mapPtr, error) { if err != nil { return nil, err } - retTS := ret.(mapPtr) - return &retTS, nil + + return &ret, nil } func mergeMapPtrState(base, patch mapPtrState) (*mapPtrState, error) { @@ -1594,8 +1594,8 @@ func mergeMapPtrState(base, patch mapPtrState) (*mapPtrState, error) { if err != nil { return nil, err } - retTS := ret.(mapPtrState) - return &retTS, nil + + return &ret, nil } func mergeMapPtrState2(base, patch mapPtrState2) (*mapPtrState2, error) { @@ -1603,8 +1603,8 @@ func mergeMapPtrState2(base, patch mapPtrState2) (*mapPtrState2, error) { if err != nil { return nil, err } - retTS := ret.(mapPtrState2) - return &retTS, nil + + return &ret, nil } func mergeTestStructs(base, patch testStruct) (*testStruct, error) { @@ -1612,71 +1612,36 @@ func mergeTestStructs(base, patch testStruct) (*testStruct, error) { if err != nil { return nil, err } - retTS := ret.(testStruct) - return &retTS, nil + + return &ret, nil } func mergeStringIntMap(base, patch map[string]int) (map[string]int, error) { - ret, err := utils.Merge(base, patch, nil) - if err != nil { - return nil, err - } - retTS := ret.(map[string]int) - return retTS, nil + return utils.Merge(base, patch, nil) } func mergeStringPtrIntMap(base, patch map[string]*int) (map[string]*int, error) { - ret, err := utils.Merge(base, patch, nil) - if err != nil { - return nil, err - } - retTS := ret.(map[string]*int) - return retTS, nil + return utils.Merge(base, patch, nil) } func mergeStringSliceIntMap(base, patch map[string][]int) (map[string][]int, error) { - ret, err := utils.Merge(base, patch, nil) - if err != nil { - return nil, err - } - retTS := ret.(map[string][]int) - return retTS, nil + return utils.Merge(base, patch, nil) } func mergeMapOfMap(base, patch map[string]map[string]*int) (map[string]map[string]*int, error) { - ret, err := utils.Merge(base, patch, nil) - if err != nil { - return nil, err - } - retTS := ret.(map[string]map[string]*int) - return retTS, nil + return utils.Merge(base, patch, nil) } func mergeInterfaceMap(base, patch map[string]any) (map[string]any, error) { - ret, err := utils.Merge(base, patch, nil) - if err != nil { - return nil, err - } - retTS := ret.(map[string]any) - return retTS, nil + return utils.Merge(base, patch, nil) } func mergeStringSlices(base, patch []string) ([]string, error) { - ret, err := utils.Merge(base, patch, nil) - if err != nil { - return nil, err - } - retTS := ret.([]string) - return retTS, nil + return utils.Merge(base, patch, nil) } func mergeTestStructsPtrs(base, patch *testStruct) (*testStruct, error) { - ret, err := utils.Merge(base, patch, nil) - if err != nil { - return nil, err - } - retTS := ret.(testStruct) - return &retTS, nil + return utils.Merge(base, patch, nil) } func newBool(b bool) *bool { return &b } diff --git a/server/config/utils.go b/server/config/utils.go index ba8de8de95..53a74c9f19 100644 --- a/server/config/utils.go +++ b/server/config/utils.go @@ -164,13 +164,7 @@ func FixInvalidLocales(cfg *model.Config) bool { // Merge merges two configs together. The receiver's values are overwritten with the patch's // values except when the patch's values are nil. func Merge(cfg *model.Config, patch *model.Config, mergeConfig *utils.MergeConfig) (*model.Config, error) { - ret, err := utils.Merge(cfg, patch, mergeConfig) - if err != nil { - return nil, err - } - - retCfg := ret.(model.Config) - return &retCfg, nil + return utils.Merge(cfg, patch, mergeConfig) } func IsDatabaseDSN(dsn string) bool {