github.com/sacloud/iaas-api-go@v1.12.0/mapconv/mapconv.go (about)

     1  // Copyright 2022-2023 The sacloud/iaas-api-go Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mapconv
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"reflect"
    21  	"strings"
    22  
    23  	"github.com/fatih/structs"
    24  	"github.com/mitchellh/mapstructure"
    25  	"github.com/sacloud/packages-go/objutil"
    26  )
    27  
    28  // DefaultMapConvTag デフォルトのmapconvタグ名
    29  const DefaultMapConvTag = "mapconv"
    30  
    31  // DecoderConfig mapconvでの変換の設定
    32  type DecoderConfig struct {
    33  	TagName     string
    34  	FilterFuncs map[string]FilterFunc
    35  }
    36  
    37  // FilterFunc mapconvでの変換時に適用するフィルタ
    38  type FilterFunc func(v interface{}) (interface{}, error)
    39  
    40  // TagInfo mapconvタグの情報
    41  type TagInfo struct {
    42  	Ignore       bool
    43  	SourceFields []string
    44  	Filters      []string
    45  	DefaultValue interface{}
    46  	OmitEmpty    bool
    47  	Recursive    bool
    48  	Squash       bool
    49  	IsSlice      bool
    50  }
    51  
    52  // Decoder mapconvでの変換
    53  type Decoder struct {
    54  	Config *DecoderConfig
    55  }
    56  
    57  func (d *Decoder) ConvertTo(source interface{}, dest interface{}) error {
    58  	s := structs.New(source)
    59  	mappedValues := Map(make(map[string]interface{}))
    60  
    61  	// recursiveの際に参照するためのdestのmap
    62  	destValues := Map(make(map[string]interface{}))
    63  	if structs.IsStruct(dest) {
    64  		destValues = Map(structs.Map(dest))
    65  	}
    66  
    67  	fields := s.Fields()
    68  	for _, f := range fields {
    69  		if !f.IsExported() {
    70  			continue
    71  		}
    72  
    73  		tags := d.ParseMapConvTag(f.Tag(d.Config.TagName))
    74  		if tags.Ignore {
    75  			continue
    76  		}
    77  		for _, key := range tags.SourceFields {
    78  			destKey := f.Name()
    79  			value := f.Value()
    80  
    81  			if key != "" {
    82  				destKey = key
    83  			}
    84  			if f.IsZero() {
    85  				if tags.OmitEmpty {
    86  					continue
    87  				}
    88  				if tags.DefaultValue != nil {
    89  					value = tags.DefaultValue
    90  				}
    91  			}
    92  
    93  			for _, filter := range tags.Filters {
    94  				filterFunc, ok := d.Config.FilterFuncs[filter]
    95  				if !ok {
    96  					return fmt.Errorf("filter %s not exists", filter)
    97  				}
    98  				filtered, err := filterFunc(value)
    99  				if err != nil {
   100  					return fmt.Errorf("failed to apply the filter: %s", err)
   101  				}
   102  				value = filtered
   103  			}
   104  
   105  			if tags.Squash {
   106  				dest := Map(make(map[string]interface{}))
   107  				err := d.ConvertTo(value, &dest)
   108  				if err != nil {
   109  					return err
   110  				}
   111  				for k, v := range dest {
   112  					mappedValues.Set(k, v)
   113  				}
   114  				continue
   115  			}
   116  
   117  			if tags.Recursive {
   118  				current, err := destValues.Get(destKey)
   119  				if err != nil {
   120  					return err
   121  				}
   122  
   123  				var dest []interface{}
   124  				values := valueToSlice(value)
   125  				currentValues := valueToSlice(current)
   126  				for i, v := range values {
   127  					if structs.IsStruct(v) {
   128  						var currentDest interface{}
   129  						if len(currentValues) > i {
   130  							currentDest = currentValues[i]
   131  						}
   132  						destMap := Map(make(map[string]interface{}))
   133  						if err := d.ConvertTo(v, &destMap); err != nil {
   134  							return err
   135  						}
   136  						// 宛先が存在しstructであれば(map[string]interface{}になっているはずなので)マージする
   137  						if currentDest != nil {
   138  							mv, ok := currentDest.(map[string]interface{})
   139  							// 元の値から空の値を除去する(structs:",omitempty"でも可)
   140  							for k, v := range mv {
   141  								if objutil.IsEmpty(v) {
   142  									delete(mv, k)
   143  								}
   144  							}
   145  							if ok {
   146  								for k, v := range destMap.Map() {
   147  									mv[k] = v
   148  								}
   149  								destMap = Map(mv)
   150  							}
   151  						}
   152  						dest = append(dest, destMap)
   153  					} else {
   154  						dest = append(dest, v)
   155  					}
   156  				}
   157  				if tags.IsSlice || dest == nil || len(dest) > 1 {
   158  					value = dest
   159  				} else {
   160  					value = dest[0]
   161  				}
   162  			}
   163  
   164  			mappedValues.Set(destKey, value)
   165  		}
   166  	}
   167  
   168  	config := &mapstructure.DecoderConfig{
   169  		WeaklyTypedInput: true,
   170  		Result:           dest,
   171  		ZeroFields:       true,
   172  	}
   173  	decoder, err := mapstructure.NewDecoder(config)
   174  	if err != nil {
   175  		return err
   176  	}
   177  	return decoder.Decode(mappedValues.Map())
   178  }
   179  
   180  func (d *Decoder) ConvertFrom(source interface{}, dest interface{}) error {
   181  	var sourceMap Map
   182  	if m, ok := source.(map[string]interface{}); ok {
   183  		sourceMap = Map(m)
   184  	} else {
   185  		sourceMap = Map(structs.New(source).Map())
   186  	}
   187  	destMap := Map(make(map[string]interface{}))
   188  
   189  	s := structs.New(dest)
   190  	fields := s.Fields()
   191  	for _, f := range fields {
   192  		if !f.IsExported() {
   193  			continue
   194  		}
   195  
   196  		tags := d.ParseMapConvTag(f.Tag(d.Config.TagName))
   197  		if tags.Ignore {
   198  			continue
   199  		}
   200  		if tags.Squash {
   201  			return errors.New("ConvertFrom is not allowed squash")
   202  		}
   203  		for _, key := range tags.SourceFields {
   204  			sourceKey := f.Name()
   205  			if key != "" {
   206  				sourceKey = key
   207  			}
   208  
   209  			value, err := sourceMap.Get(sourceKey)
   210  			if err != nil {
   211  				return err
   212  			}
   213  			if value == nil || reflect.ValueOf(value).IsZero() {
   214  				continue
   215  			}
   216  
   217  			for _, filter := range tags.Filters {
   218  				filterFunc, ok := d.Config.FilterFuncs[filter]
   219  				if !ok {
   220  					return fmt.Errorf("filter %s not exists", filter)
   221  				}
   222  				filtered, err := filterFunc(value)
   223  				if err != nil {
   224  					return fmt.Errorf("failed to apply the filter: %s", err)
   225  				}
   226  				value = filtered
   227  			}
   228  
   229  			if tags.Recursive {
   230  				t := reflect.TypeOf(f.Value())
   231  				if t.Kind() == reflect.Slice {
   232  					t = t.Elem().Elem()
   233  				} else {
   234  					t = t.Elem()
   235  				}
   236  
   237  				var dest []interface{}
   238  				values := valueToSlice(value)
   239  				for _, v := range values {
   240  					if v == nil {
   241  						dest = append(dest, v)
   242  						continue
   243  					}
   244  					dt := reflect.New(t).Interface()
   245  					if err := d.ConvertFrom(v, dt); err != nil {
   246  						return err
   247  					}
   248  					dest = append(dest, dt)
   249  				}
   250  
   251  				if dest != nil {
   252  					if tags.IsSlice || len(dest) > 1 {
   253  						value = dest
   254  					} else {
   255  						value = dest[0]
   256  					}
   257  				}
   258  			}
   259  
   260  			destMap.Set(f.Name(), value)
   261  		}
   262  	}
   263  	config := &mapstructure.DecoderConfig{
   264  		WeaklyTypedInput: true,
   265  		Result:           dest,
   266  		ZeroFields:       true,
   267  	}
   268  	decoder, err := mapstructure.NewDecoder(config)
   269  	if err != nil {
   270  		return err
   271  	}
   272  	return decoder.Decode(destMap.Map())
   273  }
   274  
   275  // ConvertTo converts struct which input by mapconv to plain models
   276  func ConvertTo(source interface{}, dest interface{}) error {
   277  	decoder := &Decoder{Config: &DecoderConfig{TagName: DefaultMapConvTag}}
   278  	return decoder.ConvertTo(source, dest)
   279  }
   280  
   281  // ConvertFrom converts struct which input by mapconv from plain models
   282  func ConvertFrom(source interface{}, dest interface{}) error {
   283  	decoder := &Decoder{Config: &DecoderConfig{TagName: DefaultMapConvTag}}
   284  	return decoder.ConvertFrom(source, dest)
   285  }
   286  
   287  // ParseMapConvTag mapconvタグを文字列で受け取りパースしてTagInfoを返す
   288  func (d *Decoder) ParseMapConvTag(tagBody string) TagInfo {
   289  	tokens := strings.Split(tagBody, ",")
   290  	key := strings.TrimSpace(tokens[0])
   291  
   292  	keys := strings.Split(key, "/")
   293  	var defaultValue interface{}
   294  	var filters []string
   295  	var ignore, omitEmpty, recursive, squash, isSlice bool
   296  
   297  	for _, k := range keys {
   298  		if k == "-" {
   299  			ignore = true
   300  			break
   301  		}
   302  		if strings.Contains(k, "[]") {
   303  			isSlice = true
   304  		}
   305  	}
   306  
   307  	for i, token := range tokens {
   308  		if i == 0 {
   309  			continue
   310  		}
   311  
   312  		token = strings.TrimSpace(token)
   313  
   314  		switch {
   315  		case strings.HasPrefix(token, "omitempty"):
   316  			omitEmpty = true
   317  		case strings.HasPrefix(token, "recursive"):
   318  			recursive = true
   319  		case strings.HasPrefix(token, "squash"):
   320  			squash = true
   321  		case strings.HasPrefix(token, "filters"):
   322  			keyValue := strings.Split(token, "=")
   323  			if len(keyValue) > 1 {
   324  				filters = strings.Split(strings.Join(keyValue[1:], ""), " ")
   325  			}
   326  		case strings.HasPrefix(token, "default"):
   327  			keyValue := strings.Split(token, "=")
   328  			if len(keyValue) > 1 {
   329  				defaultValue = strings.Join(keyValue[1:], "")
   330  			}
   331  		}
   332  	}
   333  	return TagInfo{
   334  		Ignore:       ignore,
   335  		SourceFields: keys,
   336  		DefaultValue: defaultValue,
   337  		OmitEmpty:    omitEmpty,
   338  		Recursive:    recursive,
   339  		Squash:       squash,
   340  		IsSlice:      isSlice,
   341  		Filters:      filters,
   342  	}
   343  }