github.com/panekj/cli@v0.0.0-20230304125325-467dd2f3797e/cli/compose/loader/merge.go (about)

     1  package loader
     2  
     3  import (
     4  	"reflect"
     5  	"sort"
     6  
     7  	"github.com/docker/cli/cli/compose/types"
     8  	"github.com/imdario/mergo"
     9  	"github.com/pkg/errors"
    10  )
    11  
    12  type specials struct {
    13  	m map[reflect.Type]func(dst, src reflect.Value) error
    14  }
    15  
    16  func (s *specials) Transformer(t reflect.Type) func(dst, src reflect.Value) error {
    17  	if fn, ok := s.m[t]; ok {
    18  		return fn
    19  	}
    20  	return nil
    21  }
    22  
    23  func merge(configs []*types.Config) (*types.Config, error) {
    24  	base := configs[0]
    25  	for _, override := range configs[1:] {
    26  		var err error
    27  		base.Services, err = mergeServices(base.Services, override.Services)
    28  		if err != nil {
    29  			return base, errors.Wrapf(err, "cannot merge services from %s", override.Filename)
    30  		}
    31  		base.Volumes, err = mergeVolumes(base.Volumes, override.Volumes)
    32  		if err != nil {
    33  			return base, errors.Wrapf(err, "cannot merge volumes from %s", override.Filename)
    34  		}
    35  		base.Networks, err = mergeNetworks(base.Networks, override.Networks)
    36  		if err != nil {
    37  			return base, errors.Wrapf(err, "cannot merge networks from %s", override.Filename)
    38  		}
    39  		base.Secrets, err = mergeSecrets(base.Secrets, override.Secrets)
    40  		if err != nil {
    41  			return base, errors.Wrapf(err, "cannot merge secrets from %s", override.Filename)
    42  		}
    43  		base.Configs, err = mergeConfigs(base.Configs, override.Configs)
    44  		if err != nil {
    45  			return base, errors.Wrapf(err, "cannot merge configs from %s", override.Filename)
    46  		}
    47  	}
    48  	return base, nil
    49  }
    50  
    51  func mergeServices(base, override []types.ServiceConfig) ([]types.ServiceConfig, error) {
    52  	baseServices := mapByName(base)
    53  	overrideServices := mapByName(override)
    54  	specials := &specials{
    55  		m: map[reflect.Type]func(dst, src reflect.Value) error{
    56  			reflect.TypeOf(&types.LoggingConfig{}):           safelyMerge(mergeLoggingConfig),
    57  			reflect.TypeOf([]types.ServicePortConfig{}):      mergeSlice(toServicePortConfigsMap, toServicePortConfigsSlice),
    58  			reflect.TypeOf([]types.ServiceSecretConfig{}):    mergeSlice(toServiceSecretConfigsMap, toServiceSecretConfigsSlice),
    59  			reflect.TypeOf([]types.ServiceConfigObjConfig{}): mergeSlice(toServiceConfigObjConfigsMap, toSServiceConfigObjConfigsSlice),
    60  			reflect.TypeOf(&types.UlimitsConfig{}):           mergeUlimitsConfig,
    61  			reflect.TypeOf([]types.ServiceVolumeConfig{}):    mergeSlice(toServiceVolumeConfigsMap, toServiceVolumeConfigsSlice),
    62  			reflect.TypeOf(types.ShellCommand{}):             mergeShellCommand,
    63  			reflect.TypeOf(&types.ServiceNetworkConfig{}):    mergeServiceNetworkConfig,
    64  			reflect.PointerTo(reflect.TypeOf(uint64(1))):     mergeUint64,
    65  		},
    66  	}
    67  	for name, overrideService := range overrideServices {
    68  		overrideService := overrideService
    69  		if baseService, ok := baseServices[name]; ok {
    70  			if err := mergo.Merge(&baseService, &overrideService, mergo.WithAppendSlice, mergo.WithOverride, mergo.WithTransformers(specials)); err != nil {
    71  				return base, errors.Wrapf(err, "cannot merge service %s", name)
    72  			}
    73  			baseServices[name] = baseService
    74  			continue
    75  		}
    76  		baseServices[name] = overrideService
    77  	}
    78  	services := []types.ServiceConfig{}
    79  	for _, baseService := range baseServices {
    80  		services = append(services, baseService)
    81  	}
    82  	sort.Slice(services, func(i, j int) bool { return services[i].Name < services[j].Name })
    83  	return services, nil
    84  }
    85  
    86  func toServiceSecretConfigsMap(s interface{}) (map[interface{}]interface{}, error) {
    87  	secrets, ok := s.([]types.ServiceSecretConfig)
    88  	if !ok {
    89  		return nil, errors.Errorf("not a serviceSecretConfig: %v", s)
    90  	}
    91  	m := map[interface{}]interface{}{}
    92  	for _, secret := range secrets {
    93  		m[secret.Source] = secret
    94  	}
    95  	return m, nil
    96  }
    97  
    98  func toServiceConfigObjConfigsMap(s interface{}) (map[interface{}]interface{}, error) {
    99  	secrets, ok := s.([]types.ServiceConfigObjConfig)
   100  	if !ok {
   101  		return nil, errors.Errorf("not a serviceSecretConfig: %v", s)
   102  	}
   103  	m := map[interface{}]interface{}{}
   104  	for _, secret := range secrets {
   105  		m[secret.Source] = secret
   106  	}
   107  	return m, nil
   108  }
   109  
   110  func toServicePortConfigsMap(s interface{}) (map[interface{}]interface{}, error) {
   111  	ports, ok := s.([]types.ServicePortConfig)
   112  	if !ok {
   113  		return nil, errors.Errorf("not a servicePortConfig slice: %v", s)
   114  	}
   115  	m := map[interface{}]interface{}{}
   116  	for _, p := range ports {
   117  		m[p.Published] = p
   118  	}
   119  	return m, nil
   120  }
   121  
   122  func toServiceVolumeConfigsMap(s interface{}) (map[interface{}]interface{}, error) {
   123  	volumes, ok := s.([]types.ServiceVolumeConfig)
   124  	if !ok {
   125  		return nil, errors.Errorf("not a serviceVolumeConfig slice: %v", s)
   126  	}
   127  	m := map[interface{}]interface{}{}
   128  	for _, v := range volumes {
   129  		m[v.Target] = v
   130  	}
   131  	return m, nil
   132  }
   133  
   134  func toServiceSecretConfigsSlice(dst reflect.Value, m map[interface{}]interface{}) error {
   135  	s := []types.ServiceSecretConfig{}
   136  	for _, v := range m {
   137  		s = append(s, v.(types.ServiceSecretConfig))
   138  	}
   139  	sort.Slice(s, func(i, j int) bool { return s[i].Source < s[j].Source })
   140  	dst.Set(reflect.ValueOf(s))
   141  	return nil
   142  }
   143  
   144  func toSServiceConfigObjConfigsSlice(dst reflect.Value, m map[interface{}]interface{}) error {
   145  	s := []types.ServiceConfigObjConfig{}
   146  	for _, v := range m {
   147  		s = append(s, v.(types.ServiceConfigObjConfig))
   148  	}
   149  	sort.Slice(s, func(i, j int) bool { return s[i].Source < s[j].Source })
   150  	dst.Set(reflect.ValueOf(s))
   151  	return nil
   152  }
   153  
   154  func toServicePortConfigsSlice(dst reflect.Value, m map[interface{}]interface{}) error {
   155  	s := []types.ServicePortConfig{}
   156  	for _, v := range m {
   157  		s = append(s, v.(types.ServicePortConfig))
   158  	}
   159  	sort.Slice(s, func(i, j int) bool { return s[i].Published < s[j].Published })
   160  	dst.Set(reflect.ValueOf(s))
   161  	return nil
   162  }
   163  
   164  func toServiceVolumeConfigsSlice(dst reflect.Value, m map[interface{}]interface{}) error {
   165  	s := []types.ServiceVolumeConfig{}
   166  	for _, v := range m {
   167  		s = append(s, v.(types.ServiceVolumeConfig))
   168  	}
   169  	sort.Slice(s, func(i, j int) bool { return s[i].Target < s[j].Target })
   170  	dst.Set(reflect.ValueOf(s))
   171  	return nil
   172  }
   173  
   174  type (
   175  	tomapFn             func(s interface{}) (map[interface{}]interface{}, error)
   176  	writeValueFromMapFn func(reflect.Value, map[interface{}]interface{}) error
   177  )
   178  
   179  func safelyMerge(mergeFn func(dst, src reflect.Value) error) func(dst, src reflect.Value) error {
   180  	return func(dst, src reflect.Value) error {
   181  		if src.IsNil() {
   182  			return nil
   183  		}
   184  		if dst.IsNil() {
   185  			dst.Set(src)
   186  			return nil
   187  		}
   188  		return mergeFn(dst, src)
   189  	}
   190  }
   191  
   192  func mergeSlice(tomap tomapFn, writeValue writeValueFromMapFn) func(dst, src reflect.Value) error {
   193  	return func(dst, src reflect.Value) error {
   194  		dstMap, err := sliceToMap(tomap, dst)
   195  		if err != nil {
   196  			return err
   197  		}
   198  		srcMap, err := sliceToMap(tomap, src)
   199  		if err != nil {
   200  			return err
   201  		}
   202  		if err := mergo.Map(&dstMap, srcMap, mergo.WithOverride); err != nil {
   203  			return err
   204  		}
   205  		return writeValue(dst, dstMap)
   206  	}
   207  }
   208  
   209  func sliceToMap(tomap tomapFn, v reflect.Value) (map[interface{}]interface{}, error) {
   210  	// check if valid
   211  	if !v.IsValid() {
   212  		return nil, errors.Errorf("invalid value : %+v", v)
   213  	}
   214  	return tomap(v.Interface())
   215  }
   216  
   217  func mergeLoggingConfig(dst, src reflect.Value) error {
   218  	// Same driver, merging options
   219  	if getLoggingDriver(dst.Elem()) == getLoggingDriver(src.Elem()) ||
   220  		getLoggingDriver(dst.Elem()) == "" || getLoggingDriver(src.Elem()) == "" {
   221  		if getLoggingDriver(dst.Elem()) == "" {
   222  			dst.Elem().FieldByName("Driver").SetString(getLoggingDriver(src.Elem()))
   223  		}
   224  		dstOptions := dst.Elem().FieldByName("Options").Interface().(map[string]string)
   225  		srcOptions := src.Elem().FieldByName("Options").Interface().(map[string]string)
   226  		return mergo.Merge(&dstOptions, srcOptions, mergo.WithOverride)
   227  	}
   228  	// Different driver, override with src
   229  	dst.Set(src)
   230  	return nil
   231  }
   232  
   233  //nolint:unparam
   234  func mergeUlimitsConfig(dst, src reflect.Value) error {
   235  	if src.Interface() != reflect.Zero(reflect.TypeOf(src.Interface())).Interface() {
   236  		dst.Elem().Set(src.Elem())
   237  	}
   238  	return nil
   239  }
   240  
   241  //nolint:unparam
   242  func mergeShellCommand(dst, src reflect.Value) error {
   243  	if src.Len() != 0 {
   244  		dst.Set(src)
   245  	}
   246  	return nil
   247  }
   248  
   249  //nolint:unparam
   250  func mergeServiceNetworkConfig(dst, src reflect.Value) error {
   251  	if src.Interface() != reflect.Zero(reflect.TypeOf(src.Interface())).Interface() {
   252  		dst.Elem().FieldByName("Aliases").Set(src.Elem().FieldByName("Aliases"))
   253  		if ipv4 := src.Elem().FieldByName("Ipv4Address").Interface().(string); ipv4 != "" {
   254  			dst.Elem().FieldByName("Ipv4Address").SetString(ipv4)
   255  		}
   256  		if ipv6 := src.Elem().FieldByName("Ipv6Address").Interface().(string); ipv6 != "" {
   257  			dst.Elem().FieldByName("Ipv6Address").SetString(ipv6)
   258  		}
   259  	}
   260  	return nil
   261  }
   262  
   263  //nolint:unparam
   264  func mergeUint64(dst, src reflect.Value) error {
   265  	if !src.IsNil() {
   266  		dst.Elem().Set(src.Elem())
   267  	}
   268  	return nil
   269  }
   270  
   271  func getLoggingDriver(v reflect.Value) string {
   272  	return v.FieldByName("Driver").String()
   273  }
   274  
   275  func mapByName(services []types.ServiceConfig) map[string]types.ServiceConfig {
   276  	m := map[string]types.ServiceConfig{}
   277  	for _, service := range services {
   278  		m[service.Name] = service
   279  	}
   280  	return m
   281  }
   282  
   283  func mergeVolumes(base, override map[string]types.VolumeConfig) (map[string]types.VolumeConfig, error) {
   284  	err := mergo.Map(&base, &override, mergo.WithOverride)
   285  	return base, err
   286  }
   287  
   288  func mergeNetworks(base, override map[string]types.NetworkConfig) (map[string]types.NetworkConfig, error) {
   289  	err := mergo.Map(&base, &override, mergo.WithOverride)
   290  	return base, err
   291  }
   292  
   293  func mergeSecrets(base, override map[string]types.SecretConfig) (map[string]types.SecretConfig, error) {
   294  	err := mergo.Map(&base, &override, mergo.WithOverride)
   295  	return base, err
   296  }
   297  
   298  func mergeConfigs(base, override map[string]types.ConfigObjConfig) (map[string]types.ConfigObjConfig, error) {
   299  	err := mergo.Map(&base, &override, mergo.WithOverride)
   300  	return base, err
   301  }