github.com/boki/go-xmp@v1.0.1/xmp/typeinfo.go (about)

     1  // Copyright (c) 2017-2018 Alexander Eichhorn
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"): you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
    11  // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
    12  // License for the specific language governing permissions and limitations
    13  // under the License.
    14  
    15  package xmp
    16  
    17  import (
    18  	"encoding"
    19  	"fmt"
    20  	"reflect"
    21  	"strings"
    22  	"sync"
    23  )
    24  
    25  // typeInfo holds details for the xml representation of a type.
    26  type typeInfo struct {
    27  	fields []fieldInfo
    28  }
    29  
    30  // fieldInfo holds details for the xmp representation of a single field.
    31  type fieldInfo struct {
    32  	idx        []int
    33  	name       string
    34  	minVersion Version
    35  	maxVersion Version
    36  	flags      fieldFlags
    37  }
    38  
    39  func (f fieldInfo) String() string {
    40  	s := []string{fmt.Sprintf("field %s (%v)", f.name, f.idx)}
    41  	if !f.minVersion.IsZero() {
    42  		s = append(s, "vmin", f.minVersion.String())
    43  	}
    44  	if !f.maxVersion.IsZero() {
    45  		s = append(s, "vmax", f.maxVersion.String())
    46  	}
    47  	if f.flags&fAttr > 0 {
    48  		s = append(s, "Attr")
    49  	}
    50  	if f.flags&fEmpty > 0 {
    51  		s = append(s, "Empty")
    52  	}
    53  	if f.flags&fOmit > 0 {
    54  		s = append(s, "Omit")
    55  	}
    56  	if f.flags&fAny > 0 {
    57  		s = append(s, "Any")
    58  	}
    59  	if f.flags&fFlat > 0 {
    60  		s = append(s, "Flat")
    61  	}
    62  	if f.flags&fArray > 0 {
    63  		s = append(s, "Array")
    64  	}
    65  	if f.flags&fBinaryMarshal > 0 {
    66  		s = append(s, "BinaryMarshal")
    67  	}
    68  	if f.flags&fBinaryUnmarshal > 0 {
    69  		s = append(s, "BinaryUnmarshal")
    70  	}
    71  	if f.flags&fTextMarshal > 0 {
    72  		s = append(s, "TextMarshal")
    73  	}
    74  	if f.flags&fTextUnmarshal > 0 {
    75  		s = append(s, "TextUnmarshal")
    76  	}
    77  	if f.flags&fMarshal > 0 {
    78  		s = append(s, "Marshal")
    79  	}
    80  	if f.flags&fUnmarshal > 0 {
    81  		s = append(s, "Unmarshal")
    82  	}
    83  	return strings.Join(s, " ")
    84  }
    85  
    86  type fieldFlags int
    87  
    88  const (
    89  	fElement fieldFlags = 1 << iota
    90  	fAttr
    91  	fEmpty
    92  	fOmit
    93  	fAny
    94  	fFlat
    95  	fArray
    96  	fBinaryMarshal
    97  	fBinaryUnmarshal
    98  	fTextMarshal
    99  	fTextUnmarshal
   100  	fMarshal
   101  	fUnmarshal
   102  	fMarshalAttr
   103  	fUnmarshalAttr
   104  	fMode = fElement | fAttr | fEmpty | fOmit | fAny | fFlat | fArray | fBinaryMarshal | fBinaryUnmarshal | fTextMarshal | fTextUnmarshal | fMarshal | fUnmarshal | fMarshalAttr | fUnmarshalAttr
   105  )
   106  
   107  type tinfoMap map[reflect.Type]*typeInfo
   108  
   109  var tinfoNsMap = make(map[string]tinfoMap)
   110  var tinfoLock sync.RWMutex
   111  
   112  var (
   113  	binaryUnmarshalerType = reflect.TypeOf((*encoding.BinaryUnmarshaler)(nil)).Elem()
   114  	binaryMarshalerType   = reflect.TypeOf((*encoding.BinaryMarshaler)(nil)).Elem()
   115  	textUnmarshalerType   = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
   116  	textMarshalerType     = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
   117  	marshalerType         = reflect.TypeOf((*Marshaler)(nil)).Elem()
   118  	unmarshalerType       = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
   119  	attrMarshalerType     = reflect.TypeOf((*MarshalerAttr)(nil)).Elem()
   120  	attrUnmarshalerType   = reflect.TypeOf((*UnmarshalerAttr)(nil)).Elem()
   121  	arrayType             = reflect.TypeOf((*Array)(nil)).Elem()
   122  	zeroType              = reflect.TypeOf((*Zero)(nil)).Elem()
   123  	stringerType          = reflect.TypeOf((*fmt.Stringer)(nil)).Elem()
   124  )
   125  
   126  // getTypeInfo returns the typeInfo structure with details necessary
   127  // for marshaling and unmarshaling typ.
   128  func getTypeInfo(typ reflect.Type, ns string) (*typeInfo, error) {
   129  	if ns == "" {
   130  		ns = "xmp"
   131  	}
   132  	tinfoLock.RLock()
   133  	m, ok := tinfoNsMap[ns]
   134  	if !ok {
   135  		m = make(tinfoMap)
   136  		tinfoLock.RUnlock()
   137  		tinfoLock.Lock()
   138  		tinfoNsMap[ns] = m
   139  		tinfoLock.Unlock()
   140  		tinfoLock.RLock()
   141  	}
   142  	tinfo, ok := m[typ]
   143  	tinfoLock.RUnlock()
   144  	if ok {
   145  		return tinfo, nil
   146  	}
   147  	tinfo = &typeInfo{}
   148  	if typ.Kind() != reflect.Struct {
   149  		return nil, fmt.Errorf("xmp: type %s is not a struct", typ.String())
   150  	}
   151  	n := typ.NumField()
   152  	for i := 0; i < n; i++ {
   153  		f := typ.Field(i)
   154  		if (f.PkgPath != "" && !f.Anonymous) || f.Tag.Get(ns) == "-" {
   155  			continue // Private field
   156  		}
   157  
   158  		// For embedded structs, embed its fields.
   159  		if f.Anonymous {
   160  			t := f.Type
   161  			if t.Kind() == reflect.Ptr {
   162  				t = t.Elem()
   163  			}
   164  			if t.Kind() == reflect.Struct {
   165  				inner, err := getTypeInfo(t, ns)
   166  				if err != nil {
   167  					return nil, err
   168  				}
   169  				for _, finfo := range inner.fields {
   170  					finfo.idx = append([]int{i}, finfo.idx...)
   171  					if err := addFieldInfo(typ, tinfo, &finfo, ns); err != nil {
   172  						return nil, err
   173  					}
   174  				}
   175  				continue
   176  			}
   177  		}
   178  
   179  		finfo, err := structFieldInfo(typ, &f, ns)
   180  		if err != nil {
   181  			return nil, err
   182  		}
   183  
   184  		// Add the field if it doesn't conflict with other fields.
   185  		if err := addFieldInfo(typ, tinfo, finfo, ns); err != nil {
   186  			return nil, err
   187  		}
   188  	}
   189  	tinfoLock.Lock()
   190  	m[typ] = tinfo
   191  	tinfoLock.Unlock()
   192  	return tinfo, nil
   193  }
   194  
   195  // structFieldInfo builds and returns a fieldInfo for f.
   196  func structFieldInfo(typ reflect.Type, f *reflect.StructField, ns string) (*fieldInfo, error) {
   197  	finfo := &fieldInfo{idx: f.Index}
   198  	// Split the tag from the xml namespace if necessary.
   199  	tag := f.Tag.Get(ns)
   200  
   201  	// Parse flags.
   202  	tokens := strings.Split(tag, ",")
   203  	if len(tokens) == 1 {
   204  		finfo.flags = fElement
   205  	} else {
   206  		tag = tokens[0]
   207  		for _, flag := range tokens[1:] {
   208  			switch flag {
   209  			case "attr":
   210  				finfo.flags |= fAttr
   211  			case "empty":
   212  				finfo.flags |= fEmpty
   213  			case "omit":
   214  				finfo.flags |= fOmit
   215  			case "any":
   216  				finfo.flags |= fAny
   217  			case "flat":
   218  				finfo.flags |= fFlat
   219  			}
   220  
   221  			// dissect version(s)
   222  			//   v1.0     - only write in version v1.0
   223  			//   v1.0+    - starting at and after v1.0
   224  			//   v1.0-    - only write before and including v1.0
   225  			//   v1.0<1.2 - write from v1.0 until v1.2
   226  			if strings.HasPrefix(flag, "v") {
   227  				flag = flag[1:]
   228  				var op rune
   229  				tokens := strings.FieldsFunc(flag, func(r rune) bool {
   230  					switch r {
   231  					case '+', '-', '<':
   232  						op = r
   233  						return true
   234  					default:
   235  						return false
   236  					}
   237  				})
   238  				var err error
   239  				switch op {
   240  				case '+':
   241  					finfo.minVersion, err = ParseVersion(tokens[0])
   242  				case '-':
   243  					finfo.maxVersion, err = ParseVersion(tokens[0])
   244  				case '<':
   245  					finfo.minVersion, err = ParseVersion(tokens[0])
   246  					if err == nil {
   247  						finfo.maxVersion, err = ParseVersion(tokens[1])
   248  					}
   249  				default:
   250  					finfo.minVersion, err = ParseVersion(flag)
   251  					if err == nil {
   252  						finfo.maxVersion, err = ParseVersion(flag)
   253  					}
   254  				}
   255  
   256  				if err != nil {
   257  					return nil, fmt.Errorf("invalid %s version on field %s of type %s (%q): %v", ns, f.Name, typ, f.Tag.Get(ns), err)
   258  				}
   259  			}
   260  		}
   261  
   262  		// When any flag except `attr` is used it defaults to `element`
   263  		if finfo.flags&fAttr == 0 {
   264  			finfo.flags |= fElement
   265  		}
   266  	}
   267  
   268  	if tag != "" {
   269  		finfo.name = tag
   270  	} else {
   271  		// Use field name as default.
   272  		finfo.name = f.Name
   273  	}
   274  
   275  	// add static type info about interfaces the type implements
   276  	if f.Type.Implements(arrayType) {
   277  		finfo.flags |= fArray
   278  	}
   279  	if f.Type.Implements(binaryUnmarshalerType) {
   280  		finfo.flags |= fBinaryUnmarshal
   281  	}
   282  	if f.Type.Implements(binaryMarshalerType) {
   283  		finfo.flags |= fBinaryMarshal
   284  	}
   285  	if f.Type.Implements(textUnmarshalerType) {
   286  		finfo.flags |= fTextUnmarshal
   287  	}
   288  	if f.Type.Implements(textMarshalerType) {
   289  		finfo.flags |= fTextMarshal
   290  	}
   291  	if f.Type.Implements(unmarshalerType) {
   292  		finfo.flags |= fUnmarshal
   293  	}
   294  	if f.Type.Implements(marshalerType) {
   295  		finfo.flags |= fMarshal
   296  	}
   297  	if f.Type.Implements(attrUnmarshalerType) {
   298  		finfo.flags |= fUnmarshalAttr
   299  	}
   300  	if f.Type.Implements(attrMarshalerType) {
   301  		finfo.flags |= fMarshalAttr
   302  	}
   303  
   304  	return finfo, nil
   305  }
   306  
   307  func addFieldInfo(typ reflect.Type, tinfo *typeInfo, newf *fieldInfo, ns string) error {
   308  	var conflicts []int
   309  	// Find all conflicts.
   310  	for i := range tinfo.fields {
   311  		oldf := &tinfo.fields[i]
   312  
   313  		// Same name is a conflict unless versions don't overlap.
   314  		if newf.name == oldf.name {
   315  			if !newf.minVersion.Between(oldf.minVersion, oldf.maxVersion) {
   316  				continue
   317  			}
   318  			if !newf.maxVersion.Between(oldf.minVersion, oldf.maxVersion) {
   319  				continue
   320  			}
   321  			conflicts = append(conflicts, i)
   322  		}
   323  	}
   324  
   325  	// Return the first error.
   326  	for _, i := range conflicts {
   327  		oldf := &tinfo.fields[i]
   328  		f1 := typ.FieldByIndex(oldf.idx)
   329  		f2 := typ.FieldByIndex(newf.idx)
   330  		return fmt.Errorf("xmp: %s field %q with tag %q conflicts with field %q with tag %q", typ, f1.Name, f1.Tag.Get(ns), f2.Name, f2.Tag.Get(ns))
   331  	}
   332  
   333  	// Without conflicts, add the new field and return.
   334  	tinfo.fields = append(tinfo.fields, *newf)
   335  	return nil
   336  }
   337  
   338  // value returns v's field value corresponding to finfo.
   339  // It's equivalent to v.FieldByIndex(finfo.idx), but initializes
   340  // and dereferences pointers as necessary.
   341  func (finfo *fieldInfo) value(v reflect.Value) reflect.Value {
   342  	for i, x := range finfo.idx {
   343  		if i > 0 {
   344  			t := v.Type()
   345  			if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct {
   346  				if v.IsNil() {
   347  					v.Set(reflect.New(v.Type().Elem()))
   348  				}
   349  				v = v.Elem()
   350  			}
   351  		}
   352  		v = v.Field(x)
   353  	}
   354  
   355  	return v
   356  }
   357  
   358  // Load value from interface, but only if the result will be
   359  // usefully addressable.
   360  func derefIndirect(v interface{}) reflect.Value {
   361  	return derefValue(reflect.ValueOf(v))
   362  }
   363  
   364  func derefValue(val reflect.Value) reflect.Value {
   365  	if val.Kind() == reflect.Interface && !val.IsNil() {
   366  		e := val.Elem()
   367  		if e.Kind() == reflect.Ptr && !e.IsNil() {
   368  			val = e
   369  		}
   370  	}
   371  
   372  	if val.Kind() == reflect.Ptr {
   373  		if val.IsNil() {
   374  			val.Set(reflect.New(val.Type().Elem()))
   375  		}
   376  		val = val.Elem()
   377  	}
   378  	return val
   379  }