github.com/sacloud/libsacloud/v2@v2.32.3/pkg/mapconv/mapconv.go (about)

     1  // Copyright 2016-2022 The Libsacloud Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mapconv
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"reflect"
    21  	"strings"
    22  
    23  	"github.com/sacloud/libsacloud/v2/pkg/util"
    24  
    25  	"github.com/fatih/structs"
    26  	"github.com/mitchellh/mapstructure"
    27  )
    28  
    29  // DefaultMapConvTag デフォルトのmapconvタグ名
    30  const DefaultMapConvTag = "mapconv"
    31  
    32  // DecoderConfig mapconvでの変換の設定
    33  type DecoderConfig struct {
    34  	TagName     string
    35  	FilterFuncs map[string]FilterFunc
    36  }
    37  
    38  // FilterFunc mapconvでの変換時に適用するフィルタ
    39  type FilterFunc func(v interface{}) (interface{}, error)
    40  
    41  // TagInfo mapconvタグの情報
    42  type TagInfo struct {
    43  	Ignore       bool
    44  	SourceFields []string
    45  	Filters      []string
    46  	DefaultValue interface{}
    47  	OmitEmpty    bool
    48  	Recursive    bool
    49  	Squash       bool
    50  	IsSlice      bool
    51  }
    52  
    53  // Decoder mapconvでの変換
    54  type Decoder struct {
    55  	Config *DecoderConfig
    56  }
    57  
    58  func (d *Decoder) ConvertTo(source interface{}, dest interface{}) error {
    59  	s := structs.New(source)
    60  	mappedValues := Map(make(map[string]interface{}))
    61  
    62  	// recursiveの際に参照するためのdestのmap
    63  	destValues := Map(make(map[string]interface{}))
    64  	if structs.IsStruct(dest) {
    65  		destValues = Map(structs.Map(dest))
    66  	}
    67  
    68  	fields := s.Fields()
    69  	for _, f := range fields {
    70  		if !f.IsExported() {
    71  			continue
    72  		}
    73  
    74  		tags := d.ParseMapConvTag(f.Tag(d.Config.TagName))
    75  		if tags.Ignore {
    76  			continue
    77  		}
    78  		for _, key := range tags.SourceFields {
    79  			destKey := f.Name()
    80  			value := f.Value()
    81  
    82  			if key != "" {
    83  				destKey = key
    84  			}
    85  			if f.IsZero() {
    86  				if tags.OmitEmpty {
    87  					continue
    88  				}
    89  				if tags.DefaultValue != nil {
    90  					value = tags.DefaultValue
    91  				}
    92  			}
    93  
    94  			for _, filter := range tags.Filters {
    95  				filterFunc, ok := d.Config.FilterFuncs[filter]
    96  				if !ok {
    97  					return fmt.Errorf("filter %s not exists", filter)
    98  				}
    99  				filtered, err := filterFunc(value)
   100  				if err != nil {
   101  					return fmt.Errorf("failed to apply the filter: %s", err)
   102  				}
   103  				value = filtered
   104  			}
   105  
   106  			if tags.Squash {
   107  				dest := Map(make(map[string]interface{}))
   108  				err := d.ConvertTo(value, &dest)
   109  				if err != nil {
   110  					return err
   111  				}
   112  				for k, v := range dest {
   113  					mappedValues.Set(k, v)
   114  				}
   115  				continue
   116  			}
   117  
   118  			if tags.Recursive {
   119  				current, err := destValues.Get(destKey)
   120  				if err != nil {
   121  					return err
   122  				}
   123  
   124  				var dest []interface{}
   125  				values := valueToSlice(value)
   126  				currentValues := valueToSlice(current)
   127  				for i, v := range values {
   128  					if structs.IsStruct(v) {
   129  						var currentDest interface{}
   130  						if len(currentValues) > i {
   131  							currentDest = currentValues[i]
   132  						}
   133  						destMap := Map(make(map[string]interface{}))
   134  						if err := d.ConvertTo(v, &destMap); err != nil {
   135  							return err
   136  						}
   137  						// 宛先が存在しstructであれば(map[string]interface{}になっているはずなので)マージする
   138  						if currentDest != nil {
   139  							mv, ok := currentDest.(map[string]interface{})
   140  							// 元の値から空の値を除去する(structs:",omitempty"でも可)
   141  							for k, v := range mv {
   142  								if util.IsEmpty(v) {
   143  									delete(mv, k)
   144  								}
   145  							}
   146  							if ok {
   147  								for k, v := range destMap.Map() {
   148  									mv[k] = v
   149  								}
   150  								destMap = Map(mv)
   151  							}
   152  						}
   153  						dest = append(dest, destMap)
   154  					} else {
   155  						dest = append(dest, v)
   156  					}
   157  				}
   158  				if tags.IsSlice || dest == nil || len(dest) > 1 {
   159  					value = dest
   160  				} else {
   161  					value = dest[0]
   162  				}
   163  			}
   164  
   165  			mappedValues.Set(destKey, value)
   166  		}
   167  	}
   168  
   169  	config := &mapstructure.DecoderConfig{
   170  		WeaklyTypedInput: true,
   171  		Result:           dest,
   172  		ZeroFields:       true,
   173  	}
   174  	decoder, err := mapstructure.NewDecoder(config)
   175  	if err != nil {
   176  		return err
   177  	}
   178  	return decoder.Decode(mappedValues.Map())
   179  }
   180  
   181  func (d *Decoder) ConvertFrom(source interface{}, dest interface{}) error {
   182  	var sourceMap Map
   183  	if m, ok := source.(map[string]interface{}); ok {
   184  		sourceMap = Map(m)
   185  	} else {
   186  		sourceMap = Map(structs.New(source).Map())
   187  	}
   188  	destMap := Map(make(map[string]interface{}))
   189  
   190  	s := structs.New(dest)
   191  	fields := s.Fields()
   192  	for _, f := range fields {
   193  		if !f.IsExported() {
   194  			continue
   195  		}
   196  
   197  		tags := d.ParseMapConvTag(f.Tag(d.Config.TagName))
   198  		if tags.Ignore {
   199  			continue
   200  		}
   201  		if tags.Squash {
   202  			return errors.New("ConvertFrom is not allowed squash")
   203  		}
   204  		for _, key := range tags.SourceFields {
   205  			sourceKey := f.Name()
   206  			if key != "" {
   207  				sourceKey = key
   208  			}
   209  
   210  			value, err := sourceMap.Get(sourceKey)
   211  			if err != nil {
   212  				return err
   213  			}
   214  			if value == nil || reflect.ValueOf(value).IsZero() {
   215  				continue
   216  			}
   217  
   218  			for _, filter := range tags.Filters {
   219  				filterFunc, ok := d.Config.FilterFuncs[filter]
   220  				if !ok {
   221  					return fmt.Errorf("filter %s not exists", filter)
   222  				}
   223  				filtered, err := filterFunc(value)
   224  				if err != nil {
   225  					return fmt.Errorf("failed to apply the filter: %s", err)
   226  				}
   227  				value = filtered
   228  			}
   229  
   230  			if tags.Recursive {
   231  				t := reflect.TypeOf(f.Value())
   232  				if t.Kind() == reflect.Slice {
   233  					t = t.Elem().Elem()
   234  				} else {
   235  					t = t.Elem()
   236  				}
   237  
   238  				var dest []interface{}
   239  				values := valueToSlice(value)
   240  				for _, v := range values {
   241  					if v == nil {
   242  						dest = append(dest, v)
   243  						continue
   244  					}
   245  					dt := reflect.New(t).Interface()
   246  					if err := d.ConvertFrom(v, dt); err != nil {
   247  						return err
   248  					}
   249  					dest = append(dest, dt)
   250  				}
   251  
   252  				if dest != nil {
   253  					if tags.IsSlice || len(dest) > 1 {
   254  						value = dest
   255  					} else {
   256  						value = dest[0]
   257  					}
   258  				}
   259  			}
   260  
   261  			destMap.Set(f.Name(), value)
   262  		}
   263  	}
   264  	config := &mapstructure.DecoderConfig{
   265  		WeaklyTypedInput: true,
   266  		Result:           dest,
   267  		ZeroFields:       true,
   268  	}
   269  	decoder, err := mapstructure.NewDecoder(config)
   270  	if err != nil {
   271  		return err
   272  	}
   273  	return decoder.Decode(destMap.Map())
   274  }
   275  
   276  // ConvertTo converts struct which input by mapconv to plain models
   277  func ConvertTo(source interface{}, dest interface{}) error {
   278  	decoder := &Decoder{Config: &DecoderConfig{TagName: DefaultMapConvTag}}
   279  	return decoder.ConvertTo(source, dest)
   280  }
   281  
   282  // ConvertFrom converts struct which input by mapconv from plain models
   283  func ConvertFrom(source interface{}, dest interface{}) error {
   284  	decoder := &Decoder{Config: &DecoderConfig{TagName: DefaultMapConvTag}}
   285  	return decoder.ConvertFrom(source, dest)
   286  }
   287  
   288  // ParseMapConvTag mapconvタグを文字列で受け取りパースしてTagInfoを返す
   289  func (d *Decoder) ParseMapConvTag(tagBody string) TagInfo {
   290  	tokens := strings.Split(tagBody, ",")
   291  	key := strings.TrimSpace(tokens[0])
   292  
   293  	keys := strings.Split(key, "/")
   294  	var defaultValue interface{}
   295  	var filters []string
   296  	var ignore, omitEmpty, recursive, squash, isSlice bool
   297  
   298  	for _, k := range keys {
   299  		if k == "-" {
   300  			ignore = true
   301  			break
   302  		}
   303  		if strings.Contains(k, "[]") {
   304  			isSlice = true
   305  		}
   306  	}
   307  
   308  	for i, token := range tokens {
   309  		if i == 0 {
   310  			continue
   311  		}
   312  
   313  		token = strings.TrimSpace(token)
   314  
   315  		switch {
   316  		case strings.HasPrefix(token, "omitempty"):
   317  			omitEmpty = true
   318  		case strings.HasPrefix(token, "recursive"):
   319  			recursive = true
   320  		case strings.HasPrefix(token, "squash"):
   321  			squash = true
   322  		case strings.HasPrefix(token, "filters"):
   323  			keyValue := strings.Split(token, "=")
   324  			if len(keyValue) > 1 {
   325  				filters = strings.Split(strings.Join(keyValue[1:], ""), " ")
   326  			}
   327  		case strings.HasPrefix(token, "default"):
   328  			keyValue := strings.Split(token, "=")
   329  			if len(keyValue) > 1 {
   330  				defaultValue = strings.Join(keyValue[1:], "")
   331  			}
   332  		}
   333  	}
   334  	return TagInfo{
   335  		Ignore:       ignore,
   336  		SourceFields: keys,
   337  		DefaultValue: defaultValue,
   338  		OmitEmpty:    omitEmpty,
   339  		Recursive:    recursive,
   340  		Squash:       squash,
   341  		IsSlice:      isSlice,
   342  		Filters:      filters,
   343  	}
   344  }