github.com/khulnasoft/cli@v0.0.0-20240402070845-01bcad7beefa/cli/compose/loader/merge.go (about)

     1  // FIXME(thaJeztah): remove once we are a module; the go:build directive prevents go from downgrading language version to go1.16:
     2  //go:build go1.19
     3  
     4  package loader
     5  
     6  import (
     7  	"reflect"
     8  	"sort"
     9  
    10  	"dario.cat/mergo"
    11  	"github.com/khulnasoft/cli/cli/compose/types"
    12  	"github.com/pkg/errors"
    13  )
    14  
    15  type specials struct {
    16  	m map[reflect.Type]func(dst, src reflect.Value) error
    17  }
    18  
    19  func (s *specials) Transformer(t reflect.Type) func(dst, src reflect.Value) error {
    20  	if fn, ok := s.m[t]; ok {
    21  		return fn
    22  	}
    23  	return nil
    24  }
    25  
    26  func merge(configs []*types.Config) (*types.Config, error) {
    27  	base := configs[0]
    28  	for _, override := range configs[1:] {
    29  		var err error
    30  		base.Services, err = mergeServices(base.Services, override.Services)
    31  		if err != nil {
    32  			return base, errors.Wrapf(err, "cannot merge services from %s", override.Filename)
    33  		}
    34  		base.Volumes, err = mergeVolumes(base.Volumes, override.Volumes)
    35  		if err != nil {
    36  			return base, errors.Wrapf(err, "cannot merge volumes from %s", override.Filename)
    37  		}
    38  		base.Networks, err = mergeNetworks(base.Networks, override.Networks)
    39  		if err != nil {
    40  			return base, errors.Wrapf(err, "cannot merge networks from %s", override.Filename)
    41  		}
    42  		base.Secrets, err = mergeSecrets(base.Secrets, override.Secrets)
    43  		if err != nil {
    44  			return base, errors.Wrapf(err, "cannot merge secrets from %s", override.Filename)
    45  		}
    46  		base.Configs, err = mergeConfigs(base.Configs, override.Configs)
    47  		if err != nil {
    48  			return base, errors.Wrapf(err, "cannot merge configs from %s", override.Filename)
    49  		}
    50  	}
    51  	return base, nil
    52  }
    53  
    54  func mergeServices(base, override []types.ServiceConfig) ([]types.ServiceConfig, error) {
    55  	baseServices := mapByName(base)
    56  	overrideServices := mapByName(override)
    57  	specials := &specials{
    58  		m: map[reflect.Type]func(dst, src reflect.Value) error{
    59  			reflect.TypeOf(&types.LoggingConfig{}):           safelyMerge(mergeLoggingConfig),
    60  			reflect.TypeOf([]types.ServicePortConfig{}):      mergeSlice(toServicePortConfigsMap, toServicePortConfigsSlice),
    61  			reflect.TypeOf([]types.ServiceSecretConfig{}):    mergeSlice(toServiceSecretConfigsMap, toServiceSecretConfigsSlice),
    62  			reflect.TypeOf([]types.ServiceConfigObjConfig{}): mergeSlice(toServiceConfigObjConfigsMap, toSServiceConfigObjConfigsSlice),
    63  			reflect.TypeOf(&types.UlimitsConfig{}):           mergeUlimitsConfig,
    64  			reflect.TypeOf([]types.ServiceVolumeConfig{}):    mergeSlice(toServiceVolumeConfigsMap, toServiceVolumeConfigsSlice),
    65  			reflect.TypeOf(types.ShellCommand{}):             mergeShellCommand,
    66  			reflect.TypeOf(&types.ServiceNetworkConfig{}):    mergeServiceNetworkConfig,
    67  			reflect.PointerTo(reflect.TypeOf(uint64(1))):     mergeUint64,
    68  		},
    69  	}
    70  	for name, overrideService := range overrideServices {
    71  		overrideService := overrideService
    72  		if baseService, ok := baseServices[name]; ok {
    73  			if err := mergo.Merge(&baseService, &overrideService, mergo.WithAppendSlice, mergo.WithOverride, mergo.WithTransformers(specials)); err != nil {
    74  				return base, errors.Wrapf(err, "cannot merge service %s", name)
    75  			}
    76  			baseServices[name] = baseService
    77  			continue
    78  		}
    79  		baseServices[name] = overrideService
    80  	}
    81  	services := []types.ServiceConfig{}
    82  	for _, baseService := range baseServices {
    83  		services = append(services, baseService)
    84  	}
    85  	sort.Slice(services, func(i, j int) bool { return services[i].Name < services[j].Name })
    86  	return services, nil
    87  }
    88  
    89  func toServiceSecretConfigsMap(s any) (map[any]any, error) {
    90  	secrets, ok := s.([]types.ServiceSecretConfig)
    91  	if !ok {
    92  		return nil, errors.Errorf("not a serviceSecretConfig: %v", s)
    93  	}
    94  	m := map[any]any{}
    95  	for _, secret := range secrets {
    96  		m[secret.Source] = secret
    97  	}
    98  	return m, nil
    99  }
   100  
   101  func toServiceConfigObjConfigsMap(s any) (map[any]any, error) {
   102  	secrets, ok := s.([]types.ServiceConfigObjConfig)
   103  	if !ok {
   104  		return nil, errors.Errorf("not a serviceSecretConfig: %v", s)
   105  	}
   106  	m := map[any]any{}
   107  	for _, secret := range secrets {
   108  		m[secret.Source] = secret
   109  	}
   110  	return m, nil
   111  }
   112  
   113  func toServicePortConfigsMap(s any) (map[any]any, error) {
   114  	ports, ok := s.([]types.ServicePortConfig)
   115  	if !ok {
   116  		return nil, errors.Errorf("not a servicePortConfig slice: %v", s)
   117  	}
   118  	m := map[any]any{}
   119  	for _, p := range ports {
   120  		m[p.Published] = p
   121  	}
   122  	return m, nil
   123  }
   124  
   125  func toServiceVolumeConfigsMap(s any) (map[any]any, error) {
   126  	volumes, ok := s.([]types.ServiceVolumeConfig)
   127  	if !ok {
   128  		return nil, errors.Errorf("not a serviceVolumeConfig slice: %v", s)
   129  	}
   130  	m := map[any]any{}
   131  	for _, v := range volumes {
   132  		m[v.Target] = v
   133  	}
   134  	return m, nil
   135  }
   136  
   137  func toServiceSecretConfigsSlice(dst reflect.Value, m map[any]any) error {
   138  	s := []types.ServiceSecretConfig{}
   139  	for _, v := range m {
   140  		s = append(s, v.(types.ServiceSecretConfig))
   141  	}
   142  	sort.Slice(s, func(i, j int) bool { return s[i].Source < s[j].Source })
   143  	dst.Set(reflect.ValueOf(s))
   144  	return nil
   145  }
   146  
   147  func toSServiceConfigObjConfigsSlice(dst reflect.Value, m map[any]any) error {
   148  	s := []types.ServiceConfigObjConfig{}
   149  	for _, v := range m {
   150  		s = append(s, v.(types.ServiceConfigObjConfig))
   151  	}
   152  	sort.Slice(s, func(i, j int) bool { return s[i].Source < s[j].Source })
   153  	dst.Set(reflect.ValueOf(s))
   154  	return nil
   155  }
   156  
   157  func toServicePortConfigsSlice(dst reflect.Value, m map[any]any) error {
   158  	s := []types.ServicePortConfig{}
   159  	for _, v := range m {
   160  		s = append(s, v.(types.ServicePortConfig))
   161  	}
   162  	sort.Slice(s, func(i, j int) bool { return s[i].Published < s[j].Published })
   163  	dst.Set(reflect.ValueOf(s))
   164  	return nil
   165  }
   166  
   167  func toServiceVolumeConfigsSlice(dst reflect.Value, m map[any]any) error {
   168  	s := []types.ServiceVolumeConfig{}
   169  	for _, v := range m {
   170  		s = append(s, v.(types.ServiceVolumeConfig))
   171  	}
   172  	sort.Slice(s, func(i, j int) bool { return s[i].Target < s[j].Target })
   173  	dst.Set(reflect.ValueOf(s))
   174  	return nil
   175  }
   176  
   177  type (
   178  	tomapFn             func(s any) (map[any]any, error)
   179  	writeValueFromMapFn func(reflect.Value, map[any]any) error
   180  )
   181  
   182  func safelyMerge(mergeFn func(dst, src reflect.Value) error) func(dst, src reflect.Value) error {
   183  	return func(dst, src reflect.Value) error {
   184  		if src.IsNil() {
   185  			return nil
   186  		}
   187  		if dst.IsNil() {
   188  			dst.Set(src)
   189  			return nil
   190  		}
   191  		return mergeFn(dst, src)
   192  	}
   193  }
   194  
   195  func mergeSlice(tomap tomapFn, writeValue writeValueFromMapFn) func(dst, src reflect.Value) error {
   196  	return func(dst, src reflect.Value) error {
   197  		dstMap, err := sliceToMap(tomap, dst)
   198  		if err != nil {
   199  			return err
   200  		}
   201  		srcMap, err := sliceToMap(tomap, src)
   202  		if err != nil {
   203  			return err
   204  		}
   205  		if err := mergo.Map(&dstMap, srcMap, mergo.WithOverride); err != nil {
   206  			return err
   207  		}
   208  		return writeValue(dst, dstMap)
   209  	}
   210  }
   211  
   212  func sliceToMap(tomap tomapFn, v reflect.Value) (map[any]any, error) {
   213  	// check if valid
   214  	if !v.IsValid() {
   215  		return nil, errors.Errorf("invalid value : %+v", v)
   216  	}
   217  	return tomap(v.Interface())
   218  }
   219  
   220  func mergeLoggingConfig(dst, src reflect.Value) error {
   221  	// Same driver, merging options
   222  	if getLoggingDriver(dst.Elem()) == getLoggingDriver(src.Elem()) ||
   223  		getLoggingDriver(dst.Elem()) == "" || getLoggingDriver(src.Elem()) == "" {
   224  		if getLoggingDriver(dst.Elem()) == "" {
   225  			dst.Elem().FieldByName("Driver").SetString(getLoggingDriver(src.Elem()))
   226  		}
   227  		dstOptions := dst.Elem().FieldByName("Options").Interface().(map[string]string)
   228  		srcOptions := src.Elem().FieldByName("Options").Interface().(map[string]string)
   229  		return mergo.Merge(&dstOptions, srcOptions, mergo.WithOverride)
   230  	}
   231  	// Different driver, override with src
   232  	dst.Set(src)
   233  	return nil
   234  }
   235  
   236  //nolint:unparam
   237  func mergeUlimitsConfig(dst, src reflect.Value) error {
   238  	if src.Interface() != reflect.Zero(reflect.TypeOf(src.Interface())).Interface() {
   239  		dst.Elem().Set(src.Elem())
   240  	}
   241  	return nil
   242  }
   243  
   244  //nolint:unparam
   245  func mergeShellCommand(dst, src reflect.Value) error {
   246  	if src.Len() != 0 {
   247  		dst.Set(src)
   248  	}
   249  	return nil
   250  }
   251  
   252  //nolint:unparam
   253  func mergeServiceNetworkConfig(dst, src reflect.Value) error {
   254  	if src.Interface() != reflect.Zero(reflect.TypeOf(src.Interface())).Interface() {
   255  		dst.Elem().FieldByName("Aliases").Set(src.Elem().FieldByName("Aliases"))
   256  		if ipv4 := src.Elem().FieldByName("Ipv4Address").Interface().(string); ipv4 != "" {
   257  			dst.Elem().FieldByName("Ipv4Address").SetString(ipv4)
   258  		}
   259  		if ipv6 := src.Elem().FieldByName("Ipv6Address").Interface().(string); ipv6 != "" {
   260  			dst.Elem().FieldByName("Ipv6Address").SetString(ipv6)
   261  		}
   262  	}
   263  	return nil
   264  }
   265  
   266  //nolint:unparam
   267  func mergeUint64(dst, src reflect.Value) error {
   268  	if !src.IsNil() {
   269  		dst.Elem().Set(src.Elem())
   270  	}
   271  	return nil
   272  }
   273  
   274  func getLoggingDriver(v reflect.Value) string {
   275  	return v.FieldByName("Driver").String()
   276  }
   277  
   278  func mapByName(services []types.ServiceConfig) map[string]types.ServiceConfig {
   279  	m := map[string]types.ServiceConfig{}
   280  	for _, service := range services {
   281  		m[service.Name] = service
   282  	}
   283  	return m
   284  }
   285  
   286  func mergeVolumes(base, override map[string]types.VolumeConfig) (map[string]types.VolumeConfig, error) {
   287  	err := mergo.Map(&base, &override, mergo.WithOverride)
   288  	return base, err
   289  }
   290  
   291  func mergeNetworks(base, override map[string]types.NetworkConfig) (map[string]types.NetworkConfig, error) {
   292  	err := mergo.Map(&base, &override, mergo.WithOverride)
   293  	return base, err
   294  }
   295  
   296  func mergeSecrets(base, override map[string]types.SecretConfig) (map[string]types.SecretConfig, error) {
   297  	err := mergo.Map(&base, &override, mergo.WithOverride)
   298  	return base, err
   299  }
   300  
   301  func mergeConfigs(base, override map[string]types.ConfigObjConfig) (map[string]types.ConfigObjConfig, error) {
   302  	err := mergo.Map(&base, &override, mergo.WithOverride)
   303  	return base, err
   304  }