github.com/masterhung0112/hk_server/v5@v5.0.0-20220302090640-ec71aef15e1c/utils/merge.go (about)

     1  package utils
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  )
     7  
     8  // StructFieldFilter defines a callback function used to decide if a patch value should be applied.
     9  type StructFieldFilter func(structField reflect.StructField, base reflect.Value, patch reflect.Value) bool
    10  
    11  // MergeConfig allows for optional merge customizations.
    12  type MergeConfig struct {
    13  	StructFieldFilter StructFieldFilter
    14  }
    15  
    16  // Merge will return a new value of the same type as base and patch, recursively merging non-nil values from patch on top of base.
    17  //
    18  // Restrictions/guarantees:
    19  //   - base and patch must be the same type
    20  //   - base and patch will never be modified
    21  //   - values from patch are always selected when non-nil
    22  //   - structs are merged recursively
    23  //   - maps and slices are treated as pointers, and merged as a single value
    24  //
    25  // Note that callers need to cast the returned interface back into the original type:
    26  // func mergeTestStruct(base, patch *testStruct) (*testStruct, error) {
    27  //     ret, err := merge(base, patch)
    28  //     if err != nil {
    29  //         return nil, err
    30  //     }
    31  //
    32  //     retTS := ret.(testStruct)
    33  //     return &retTS, nil
    34  // }
    35  func Merge(base interface{}, patch interface{}, mergeConfig *MergeConfig) (interface{}, error) {
    36  	if reflect.TypeOf(base) != reflect.TypeOf(patch) {
    37  		return nil, fmt.Errorf(
    38  			"cannot merge different types. base type: %s, patch type: %s",
    39  			reflect.TypeOf(base),
    40  			reflect.TypeOf(patch),
    41  		)
    42  	}
    43  
    44  	commonType := reflect.TypeOf(base)
    45  	baseVal := reflect.ValueOf(base)
    46  	patchVal := reflect.ValueOf(patch)
    47  	if commonType.Kind() == reflect.Ptr {
    48  		commonType = commonType.Elem()
    49  		baseVal = baseVal.Elem()
    50  		patchVal = patchVal.Elem()
    51  	}
    52  
    53  	ret := reflect.New(commonType)
    54  
    55  	val, ok := merge(baseVal, patchVal, mergeConfig)
    56  	if ok {
    57  		ret.Elem().Set(val)
    58  	}
    59  	return ret.Elem().Interface(), nil
    60  }
    61  
    62  // merge recursively merges patch into base and returns the new struct, ptr, slice/map, or value
    63  func merge(base, patch reflect.Value, mergeConfig *MergeConfig) (reflect.Value, bool) {
    64  	commonType := base.Type()
    65  
    66  	switch commonType.Kind() {
    67  	case reflect.Struct:
    68  		merged := reflect.New(commonType).Elem()
    69  		for i := 0; i < base.NumField(); i++ {
    70  			if !merged.Field(i).CanSet() {
    71  				continue
    72  			}
    73  			if mergeConfig != nil && mergeConfig.StructFieldFilter != nil {
    74  				if !mergeConfig.StructFieldFilter(commonType.Field(i), base.Field(i), patch.Field(i)) {
    75  					merged.Field(i).Set(base.Field(i))
    76  					continue
    77  				}
    78  			}
    79  			val, ok := merge(base.Field(i), patch.Field(i), mergeConfig)
    80  			if ok {
    81  				merged.Field(i).Set(val)
    82  			}
    83  		}
    84  		return merged, true
    85  
    86  	case reflect.Ptr:
    87  		mergedPtr := reflect.New(commonType.Elem())
    88  		if base.IsNil() && patch.IsNil() {
    89  			return mergedPtr, false
    90  		}
    91  
    92  		// clone reference values (if any)
    93  		if base.IsNil() {
    94  			val, _ := merge(patch.Elem(), patch.Elem(), mergeConfig)
    95  			mergedPtr.Elem().Set(val)
    96  		} else if patch.IsNil() {
    97  			val, _ := merge(base.Elem(), base.Elem(), mergeConfig)
    98  			mergedPtr.Elem().Set(val)
    99  		} else {
   100  			val, _ := merge(base.Elem(), patch.Elem(), mergeConfig)
   101  			mergedPtr.Elem().Set(val)
   102  		}
   103  		return mergedPtr, true
   104  
   105  	case reflect.Slice:
   106  		if base.IsNil() && patch.IsNil() {
   107  			return reflect.Zero(commonType), false
   108  		}
   109  		if !patch.IsNil() {
   110  			// use patch
   111  			merged := reflect.MakeSlice(commonType, 0, patch.Len())
   112  			for i := 0; i < patch.Len(); i++ {
   113  				// recursively merge patch with itself. This will clone reference values.
   114  				val, _ := merge(patch.Index(i), patch.Index(i), mergeConfig)
   115  				merged = reflect.Append(merged, val)
   116  			}
   117  			return merged, true
   118  		}
   119  		// use base
   120  		merged := reflect.MakeSlice(commonType, 0, base.Len())
   121  		for i := 0; i < base.Len(); i++ {
   122  
   123  			// recursively merge base with itself. This will clone reference values.
   124  			val, _ := merge(base.Index(i), base.Index(i), mergeConfig)
   125  			merged = reflect.Append(merged, val)
   126  		}
   127  		return merged, true
   128  
   129  	case reflect.Map:
   130  		// maps are merged according to these rules:
   131  		// - if patch is not nil, replace the base map completely
   132  		// - otherwise, keep the base map
   133  		// - reference values (eg. slice/ptr/map) will be cloned
   134  		if base.IsNil() && patch.IsNil() {
   135  			return reflect.Zero(commonType), false
   136  		}
   137  		merged := reflect.MakeMap(commonType)
   138  		mapPtr := base
   139  		if !patch.IsNil() {
   140  			mapPtr = patch
   141  		}
   142  		for _, key := range mapPtr.MapKeys() {
   143  			// clone reference values
   144  			val, ok := merge(mapPtr.MapIndex(key), mapPtr.MapIndex(key), mergeConfig)
   145  			if !ok {
   146  				val = reflect.New(mapPtr.MapIndex(key).Type()).Elem()
   147  			}
   148  			merged.SetMapIndex(key, val)
   149  		}
   150  		return merged, true
   151  
   152  	case reflect.Interface:
   153  		var val reflect.Value
   154  		if base.IsNil() && patch.IsNil() {
   155  			return reflect.Zero(commonType), false
   156  		}
   157  
   158  		// clone reference values (if any)
   159  		if base.IsNil() {
   160  			val, _ = merge(patch.Elem(), patch.Elem(), mergeConfig)
   161  		} else if patch.IsNil() {
   162  			val, _ = merge(base.Elem(), base.Elem(), mergeConfig)
   163  		} else {
   164  			val, _ = merge(base.Elem(), patch.Elem(), mergeConfig)
   165  		}
   166  		return val, true
   167  
   168  	default:
   169  		return patch, true
   170  	}
   171  }