github.com/boki/go-xmp@v1.0.1/xmp/native.go (about)

     1  // Copyright (c) 2017-2018 Alexander Eichhorn
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"): you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
    11  // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
    12  // License for the specific language governing permissions and limitations
    13  // under the License.
    14  
    15  package xmp
    16  
    17  import (
    18  	"encoding"
    19  	"encoding/xml"
    20  	"fmt"
    21  	"reflect"
    22  	"strings"
    23  )
    24  
        // Tag is a flat key/value pair with an optional language code. It is the
        // element type of TagList, which is what ListNativeFields returns.
    25  type Tag struct {
        	// Key is the tag name. TagList.MarshalXMP uses it as the local name of
        	// the attribute it writes.
    26  	Key   string `xmp:"-" json:"key,omitempty"`
        	// Value is the textual value of the tag.
    27  	Value string `xmp:"-" json:"value,omitempty"`
        	// Lang is an optional language code. When it is non-empty,
        	// TagList.MarshalXMP appends "-" + Lang to the attribute name.
    28  	Lang  string `xmp:"-" json:"lang,omitempty"`
    29  }
    30  
        // TagList is an ordered list of Tag values. Its MarshalXMP and UnmarshalXMP
        // methods store every tag as one attribute on a Node.
    31  type TagList []Tag
    32  
    33  func (x TagList) MarshalXMP(e *Encoder, node *Node, m Model) error {
    34  	if len(x) == 0 {
    35  		return nil
    36  	}
    37  	for _, v := range x {
    38  		name := xml.Name{Local: v.Key}
    39  		if v.Lang != "" {
    40  			name.Local += "-" + v.Lang
    41  		}
    42  		node.AddAttr(Attr{
    43  			Name:  name,
    44  			Value: v.Value,
    45  		})
    46  	}
    47  	return nil
    48  }
    49  
    50  func (x *TagList) UnmarshalXMP(d *Decoder, node *Node, m Model) error {
    51  	for _, v := range node.Attr {
    52  		tag := Tag{
    53  			Key:   v.Name.Local,
    54  			Value: v.Value,
    55  		}
    56  		if i := strings.Index(tag.Key, "-"); i > -1 {
    57  			// FIXME: parsing would be better here
    58  			tag.Key, tag.Lang = tag.Key[:i], tag.Key[i+1:]
    59  		}
    60  		*x = append(*x, tag)
    61  	}
    62  	return nil
    63  }
    64  
    65  func GetNativeField(v Model, name string) (string, error) {
    66  	nsName, err := getNamespaceName(v)
    67  	if err != nil {
    68  		return "", err
    69  	}
    70  
    71  	val := derefIndirect(v)
    72  	finfo, err := findField(val, name, nsName)
    73  	if err != nil {
    74  		return "", err
    75  	}
    76  
    77  	fv := finfo.value(val)
    78  	typ := fv.Type()
    79  
    80  	if !fv.IsValid() {
    81  		return "", nil
    82  	}
    83  
    84  	if (fv.Kind() == reflect.Interface || fv.Kind() == reflect.Ptr) && fv.IsNil() {
    85  		return "", nil
    86  	}
    87  
    88  	if finfo.flags&fEmpty == 0 && isEmptyValue(fv) {
    89  		return "", nil
    90  	}
    91  
    92  	// Drill into interfaces and pointers.
    93  	// This can turn into an infinite loop given a cyclic chain,
    94  	// but it matches the Go 1 behavior.
    95  	for fv.Kind() == reflect.Interface || fv.Kind() == reflect.Ptr {
    96  		fv = fv.Elem()
    97  	}
    98  
    99  	// Check for text marshaler and marshal as node value
   100  	if fv.CanAddr() {
   101  		pv := fv.Addr()
   102  		if pv.CanInterface() && (finfo != nil && finfo.flags&fTextMarshal > 0 || pv.Type().Implements(textMarshalerType)) {
   103  			b, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
   104  			if err != nil || b == nil {
   105  				return "", err
   106  			}
   107  			return string(b), nil
   108  		}
   109  	}
   110  
   111  	if fv.CanInterface() && (finfo != nil && finfo.flags&fTextMarshal > 0 || typ.Implements(textMarshalerType)) {
   112  		b, err := fv.Interface().(encoding.TextMarshaler).MarshalText()
   113  		if err != nil || b == nil {
   114  			return "", err
   115  		}
   116  		return string(b), nil
   117  	}
   118  
   119  	// simple values are just fine, but any other type (slice, array, struct)
   120  	// without textmarshaler will fail
   121  	if s, b, err := marshalSimple(typ, fv); err != nil {
   122  		return "", err
   123  	} else {
   124  		if b != nil {
   125  			s = string(b)
   126  		}
   127  		return s, nil
   128  	}
   129  }
   130  
   131  func SetNativeField(v Model, name, value string) error {
   132  	nsName, err := getNamespaceName(v)
   133  	if err != nil {
   134  		return err
   135  	}
   136  
   137  	val := derefIndirect(v)
   138  	finfo, err := findField(val, name, nsName)
   139  	if err != nil {
   140  		return err
   141  	}
   142  
   143  	f := finfo.value(val)
   144  
   145  	// allocate memory for pointer values in structs
   146  	if f.Type().Kind() == reflect.Ptr && f.IsNil() && f.CanSet() {
   147  		f.Set(reflect.New(f.Type().Elem()))
   148  	}
   149  
   150  	// load and potentially create value
   151  	f = derefValue(f)
   152  
   153  	// try unmarshalers
   154  	if f.CanAddr() {
   155  		pv := f.Addr()
   156  		if pv.CanInterface() && (finfo != nil && finfo.flags&fBinaryUnmarshal > 0 || pv.Type().Implements(binaryUnmarshalerType)) {
   157  			return pv.Interface().(encoding.BinaryUnmarshaler).UnmarshalBinary([]byte(value))
   158  		}
   159  	}
   160  
   161  	if f.CanInterface() && (finfo != nil && finfo.flags&fBinaryUnmarshal > 0 || f.Type().Implements(binaryUnmarshalerType)) {
   162  		return f.Interface().(encoding.BinaryUnmarshaler).UnmarshalBinary([]byte(value))
   163  	}
   164  
   165  	if f.CanAddr() {
   166  		pv := f.Addr()
   167  		if pv.CanInterface() && (finfo != nil && finfo.flags&fTextUnmarshal > 0 || pv.Type().Implements(textUnmarshalerType)) {
   168  			return pv.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(value))
   169  		}
   170  	}
   171  
   172  	if f.CanInterface() && (finfo != nil && finfo.flags&fTextUnmarshal > 0 || f.Type().Implements(textUnmarshalerType)) {
   173  		return f.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(value))
   174  	}
   175  
   176  	// otherwise set simple field value directly or fail
   177  	return setValue(f, value)
   178  }
   179  
   180  func SetLocaleField(v Model, lang string, name, value string) error {
   181  	nsName, err := getNamespaceName(v)
   182  	if err != nil {
   183  		return err
   184  	}
   185  
   186  	val := derefIndirect(v)
   187  	finfo, err := findField(val, name, nsName)
   188  	if err != nil {
   189  		return err
   190  	}
   191  
   192  	f := finfo.value(val)
   193  
   194  	// allocate memory for pointer values in structs
   195  	if f.Type().Kind() == reflect.Ptr && f.IsNil() && f.CanSet() {
   196  		f.Set(reflect.New(f.Type().Elem()))
   197  	}
   198  
   199  	// load and potentially create value
   200  	f = derefValue(f)
   201  
   202  	if f.Kind() != reflect.Slice || f.Type().Elem() != reflect.TypeOf(AltItem{}) {
   203  		return fmt.Errorf("field '%s' must be of type xmp.AltString, found type '%s' kind '%s'", name, f.Type().String(), f.Kind())
   204  	}
   205  
   206  	// we need a pointer to AltString slices for appending
   207  	a, ok := f.Addr().Interface().(*AltString)
   208  	if !ok {
   209  		return fmt.Errorf("field '%s' must be of type xmp.AltString", name)
   210  	}
   211  
   212  	// use AltString interface to add value
   213  	a.Set(lang, value)
   214  	return nil
   215  }
   216  
   217  func GetLocaleField(v Model, lang string, name string) (string, error) {
   218  	nsName, err := getNamespaceName(v)
   219  	if err != nil {
   220  		return "", err
   221  	}
   222  
   223  	val := derefIndirect(v)
   224  	finfo, err := findField(val, name, nsName)
   225  	if err != nil {
   226  		return "", err
   227  	}
   228  
   229  	f := finfo.value(val)
   230  
   231  	if f.Kind() != reflect.Slice && f.Type().Elem() != reflect.TypeOf(AltItem{}) {
   232  		return "", fmt.Errorf("field '%s' must be of type AltString, found %s (%s)", name, f.Type().String(), f.Kind())
   233  	}
   234  
   235  	a, ok := f.Interface().(AltString)
   236  	if !ok {
   237  		return "", fmt.Errorf("field '%s' must be of type AltString", name)
   238  	}
   239  
   240  	// use AltString interface to get value
   241  	return a.Get(lang), nil
   242  }
   243  
   244  func ListNativeFields(v Model) (TagList, error) {
   245  	nsName, err := getNamespaceName(v)
   246  	if err != nil {
   247  		return nil, err
   248  	}
   249  
   250  	val := derefIndirect(v)
   251  	typ := val.Type()
   252  
   253  	tinfo, err := getTypeInfo(typ, nsName)
   254  	if err != nil {
   255  		return nil, err
   256  	}
   257  
   258  	tagList := make(TagList, 0)
   259  
   260  	// go through all fields
   261  	for _, finfo := range tinfo.fields {
   262  		fv := finfo.value(val)
   263  
   264  		if !fv.IsValid() {
   265  			continue
   266  		}
   267  
   268  		if (fv.Kind() == reflect.Interface || fv.Kind() == reflect.Ptr) && fv.IsNil() {
   269  			continue
   270  		}
   271  
   272  		if finfo.flags&fEmpty == 0 && isEmptyValue(fv) {
   273  			continue
   274  		}
   275  
   276  		// Drill into interfaces and pointers.
   277  		// This can turn into an infinite loop given a cyclic chain,
   278  		// but it matches the Go 1 behavior.
   279  		for fv.Kind() == reflect.Interface || fv.Kind() == reflect.Ptr {
   280  			fv = fv.Elem()
   281  		}
   282  
   283  		tag := Tag{
   284  			Key: finfo.name,
   285  		}
   286  
   287  		// Check for text marshaler and marshal as node value
   288  		if fv.CanInterface() && typ.Implements(textMarshalerType) {
   289  			b, err := fv.Interface().(encoding.TextMarshaler).MarshalText()
   290  			if err != nil {
   291  				return nil, err
   292  			}
   293  			if len(b) == 0 {
   294  				continue
   295  			}
   296  			tag.Value = string(b)
   297  			tagList = append(tagList, tag)
   298  			continue
   299  		}
   300  
   301  		if fv.CanAddr() {
   302  			pv := fv.Addr()
   303  			if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
   304  				b, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
   305  				if err != nil {
   306  					return nil, err
   307  				}
   308  				if len(b) == 0 {
   309  					continue
   310  				}
   311  				tag.Value = string(b)
   312  				tagList = append(tagList, tag)
   313  				continue
   314  			}
   315  		}
   316  
   317  		// handle multi-language arrays
   318  		if fv.Kind() == reflect.Slice && fv.Type().Elem() == reflect.TypeOf(AltItem{}) {
   319  			a, ok := fv.Interface().(AltString)
   320  			if !ok {
   321  				return nil, fmt.Errorf("field '%s' must be of type AltString", finfo.name)
   322  			}
   323  			for _, v := range a {
   324  				tagList = append(tagList, Tag{
   325  					Key:   finfo.name,
   326  					Lang:  v.Lang,
   327  					Value: v.Value,
   328  				})
   329  			}
   330  			continue
   331  		}
   332  
   333  		// simple values are just fine, but any other type (slice, array, struct)
   334  		// without textmarshaler will fail here
   335  		if s, b, err := marshalSimple(typ, fv); err != nil {
   336  			return nil, err
   337  		} else {
   338  			if b != nil {
   339  				s = string(b)
   340  			}
   341  			if len(s) == 0 {
   342  				continue
   343  			}
   344  			tag.Value = s
   345  			tagList = append(tagList, tag)
   346  		}
   347  	}
   348  
   349  	return tagList, nil
   350  }
   351  
   352  func findField(val reflect.Value, name, ns string) (*fieldInfo, error) {
   353  	typ := val.Type()
   354  	tinfo, err := getTypeInfo(typ, ns)
   355  	if err != nil {
   356  		return nil, err
   357  	}
   358  
   359  	// pick the correct field based on name, flags and version
   360  	var finfo *fieldInfo
   361  	any := -1
   362  	for i, v := range tinfo.fields {
   363  		// version must always match
   364  		// if !d.version.Between(v.minVersion, v.maxVersion) {
   365  		// 	continue
   366  		// }
   367  
   368  		// save `any` field in case
   369  		if v.flags&fAny > 0 {
   370  			any = i
   371  		}
   372  
   373  		// field name must match
   374  		if hasPrefix(name) {
   375  			// exact match when namespace is specified
   376  			if v.name != name {
   377  				continue
   378  			}
   379  		} else {
   380  			// suffix match without namespace
   381  			if stripPrefix(v.name) != name {
   382  				continue
   383  			}
   384  		}
   385  
   386  		finfo = &v
   387  		break
   388  	}
   389  
   390  	if finfo == nil && any >= 0 {
   391  		finfo = &tinfo.fields[any]
   392  	}
   393  
   394  	// nothing found
   395  	if finfo == nil {
   396  		return nil, fmt.Errorf("no field with tag '%s' in type '%s'", name, typ.String())
   397  	}
   398  
   399  	return finfo, nil
   400  }
   401  
   402  func getNamespaceName(v Model) (string, error) {
   403  	if n := v.Namespaces(); len(n) == 0 {
   404  		return "", fmt.Errorf("model '%s' must implement at least one namespace", reflect.TypeOf(v).String())
   405  	} else {
   406  		return n[0].GetName(), nil
   407  	}
   408  }