github.com/boki/go-xmp@v1.0.1/xmp/unmarshal.go (about)

     1  // Copyright (c) 2017-2018 Alexander Eichhorn
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
    11  // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
    12  // License for the specific language governing permissions and limitations
    13  // under the License.
    14  
    15  package xmp
    16  
    17  import (
    18  	"bytes"
    19  	"encoding"
    20  	"encoding/xml"
    21  	"fmt"
    22  	"io"
    23  	"reflect"
    24  	"strconv"
    25  	"strings"
    26  )
    27  
    28  type Unmarshaler interface {
    29  	UnmarshalXMP(d *Decoder, n *Node, model Model) error
    30  }
    31  
    32  type UnmarshalerAttr interface {
    33  	UnmarshalXMPAttr(d *Decoder, a Attr) error
    34  }
    35  
    36  type Decoder struct {
    37  	d        *xml.Decoder
    38  	toolkit  string
    39  	about    string
    40  	nodes    NodeList
    41  	intNsMap map[string]*Namespace
    42  	extNsMap map[string]*Namespace
    43  	version  Version
    44  }
    45  
    46  func NewDecoder(r io.Reader) *Decoder {
    47  	return &Decoder{
    48  		d:        xml.NewDecoder(r),
    49  		nodes:    make(NodeList, 0),
    50  		intNsMap: make(map[string]*Namespace),
    51  		extNsMap: make(map[string]*Namespace),
    52  	}
    53  }
    54  
    55  func (d *Decoder) SetVersion(v Version) {
    56  	d.version = v
    57  }
    58  
    59  func Unmarshal(data []byte, d *Document) error {
    60  	return NewDecoder(bytes.NewReader(data)).Decode(d)
    61  }
    62  
    63  func (d *Decoder) DecodeElement(v interface{}, src *Node) error {
    64  	val := reflect.ValueOf(v)
    65  	if val.Kind() != reflect.Ptr {
    66  		return fmt.Errorf("xmp: model for '%s' is not a pointer", reflect.TypeOf(v))
    67  	}
    68  	return d.unmarshal(val, nil, src)
    69  }
    70  
    71  func (d *Decoder) Decode(x *Document) error {
    72  
    73  	if x == nil {
    74  		return nil
    75  	}
    76  
    77  	// 1  parse node tree from XML
    78  	root := NewNode(emptyName)
    79  	gc := root
    80  	defer gc.Close()
    81  	if err := d.d.Decode(root); err != nil {
    82  		return fmt.Errorf("xmp: parsing xml failed: %v", err)
    83  	}
    84  
    85  	// 2  skip top-level `x:xmpmeta` (optional) and `rdf:RDF` nodes
    86  	if root.FullName() == "x:xmpmeta" {
    87  		if a := root.GetAttr(nsX.GetURI(), "xmptk"); len(a) > 0 {
    88  			x.toolkit = strings.TrimSpace(a[0].Value)
    89  		}
    90  		if len(root.Nodes) == 0 {
    91  			return fmt.Errorf("xmp: invalid XML format: missing rdf:RDF node")
    92  		}
    93  		if len(root.Nodes) > 1 {
    94  			return fmt.Errorf("xmp: invalid XML format: too many child nodes in x:xmpmeta")
    95  		}
    96  		root = root.Nodes[0]
    97  	}
    98  
    99  	if root.FullName() != "rdf:RDF" {
   100  		return fmt.Errorf("xmp: invalid XML format: missing rdf:RDF node, found %s:%s", root.Namespace(), root.Name())
   101  	}
   102  
   103  	// 3  extract document namespaces
   104  	for _, n := range root.Nodes {
   105  		for _, v := range n.GetAttr("xmlns", "") {
   106  			d.addNamespace(v.Name.Local, v.Value)
   107  		}
   108  	}
   109  
   110  	// 4  walk node tree and create model instances
   111  	for _, n := range root.Nodes {
   112  		// we expect outer nodes
   113  		if n.FullName() != "rdf:Description" {
   114  			return fmt.Errorf("xmp: invalid XML format: expected rdf:Description node, found %s:%s", n.Namespace(), n.Name())
   115  		}
   116  
   117  		// process attributes
   118  		for _, v := range n.Attr {
   119  			if v.Name.Space == nsRDF.GetURI() {
   120  				if v.Name.Local == "about" {
   121  					d.about = v.Value
   122  				}
   123  				continue
   124  			}
   125  
   126  			if err := d.decodeAttribute(&d.nodes, v); err != nil {
   127  				return err
   128  			}
   129  		}
   130  
   131  		// process child nodes
   132  		for _, v := range n.Nodes {
   133  			if err := d.decodeNode(&d.nodes, v); err != nil {
   134  				return err
   135  			}
   136  		}
   137  	}
   138  
   139  	// copy decoded values to document
   140  	x.toolkit = d.toolkit
   141  	x.about = d.about
   142  	x.nodes = d.nodes
   143  	x.intNsMap = d.intNsMap
   144  	x.extNsMap = d.extNsMap
   145  	return x.syncFromXMP()
   146  }
   147  
   148  func (d *Decoder) decodeNode(ctx *NodeList, src *Node) error {
   149  
   150  	node, err := d.lookupNode(ctx, src.XMLName)
   151  	if err != nil {
   152  		return err
   153  	}
   154  
   155  	d.translate(&src.XMLName)
   156  	name := src.FullName()
   157  
   158  	// process the node value
   159  	var storeNode bool
   160  	if node.Model != nil {
   161  		finfo, field := d.findStructField(derefIndirect(node.Model), name)
   162  		if field.IsValid() {
   163  			return d.unmarshal(field, finfo, src)
   164  		} else {
   165  			storeNode = finfo == nil || finfo.flags&fOmit == 0
   166  		}
   167  	} else {
   168  		storeNode = true
   169  	}
   170  
   171  	// capture the node and its children into the selected node
   172  	if storeNode {
   173  		src.translate(d)
   174  		if !src.IsZero() {
   175  			node.AddNode(copyNode(src))
   176  			Log.Debugf("xmp: missing struct field for %s, saving as external node in %s model", name, node.FullName())
   177  		}
   178  	}
   179  	return nil
   180  }
   181  
   182  func (d *Decoder) decodeAttribute(ctx *NodeList, src Attr) error {
   183  
   184  	d.translate(&src.Name)
   185  	if skipField(src.Name) {
   186  		return nil
   187  	}
   188  
   189  	node, err := d.lookupNode(ctx, src.Name)
   190  	if err != nil {
   191  		return err
   192  	}
   193  
   194  	// process the attribute value
   195  	var storeAttr bool
   196  	if node.Model != nil {
   197  		finfo, field := d.findStructField(derefIndirect(node.Model), src.Name.Local)
   198  		if field.IsValid() {
   199  			if err := d.unmarshalAttr(field, finfo, src); err != nil {
   200  				return err
   201  			}
   202  		} else {
   203  			// capture the field as external attribute
   204  			storeAttr = finfo == nil || finfo.flags&fOmit == 0
   205  		}
   206  
   207  	} else {
   208  		storeAttr = true
   209  	}
   210  
   211  	if storeAttr {
   212  		if src.Value != "" {
   213  			Log.Debugf("xmp: missing struct field for %s, saving as unknown attr in %s model", src.Name.Local, node.FullName())
   214  			node.AddAttr(src)
   215  		}
   216  	}
   217  
   218  	return nil
   219  }
   220  
   221  func (d *Decoder) unmarshal(val reflect.Value, finfo *fieldInfo, src *Node) error {
   222  	// Load value from interface, but only if the result will be
   223  	// usefully addressable.
   224  	val = derefValue(val)
   225  
   226  	// XMP
   227  	if val.CanAddr() {
   228  		pv := val.Addr()
   229  		if pv.CanInterface() && (finfo != nil && finfo.flags&fUnmarshal > 0 || pv.Type().Implements(unmarshalerType)) {
   230  			return pv.Interface().(Unmarshaler).UnmarshalXMP(d, src, nil)
   231  		}
   232  	}
   233  
   234  	// Text
   235  	if val.CanAddr() {
   236  		pv := val.Addr()
   237  		if pv.CanInterface() && (finfo != nil && finfo.flags&fTextUnmarshal > 0 || pv.Type().Implements(textUnmarshalerType)) {
   238  			return pv.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(src.Value))
   239  		}
   240  	}
   241  
   242  	// structs
   243  	if val.Kind() == reflect.Struct {
   244  		// process attributes first
   245  		for _, a := range src.Attr {
   246  			d.translate(&a.Name)
   247  			if skipField(a.Name) {
   248  				continue
   249  			}
   250  			if finfo, field := d.findStructField(val, a.Name.Local); field.IsValid() {
   251  				if err := d.unmarshalAttr(field, finfo, a); err != nil {
   252  					return err
   253  				}
   254  			} else {
   255  				return fmt.Errorf("xmp: unmarshal model %s: field for attr %s not found in type %v", src.FullName(), a.Name.Local, val.Type())
   256  			}
   257  		}
   258  
   259  		// recurse into child nodes
   260  		for _, n := range src.Nodes {
   261  			d.translate(&n.XMLName)
   262  			name := n.FullName()
   263  			switch name {
   264  			case "rdf:Description":
   265  				if err := d.unmarshal(val, nil, n); err != nil {
   266  					return err
   267  				}
   268  			default:
   269  				if skipField(n.XMLName) {
   270  					break
   271  				}
   272  				if finfo, field := d.findStructField(val, name); field.IsValid() {
   273  					if err := d.unmarshal(field, finfo, n); err != nil {
   274  						return err
   275  					}
   276  				} else {
   277  					return fmt.Errorf("xmp: unmarshal model %s: struct field %s not found (not stored)", src.FullName(), name)
   278  				}
   279  			}
   280  		}
   281  	} else {
   282  		// otherwise set simple value directly
   283  		if err := setValue(val, src.Value); err != nil {
   284  			return fmt.Errorf("xmp: unmarshal %s: %v", finfo.String(), err)
   285  		}
   286  	}
   287  
   288  	return nil
   289  }
   290  
   291  func (d *Decoder) unmarshalAttr(val reflect.Value, finfo *fieldInfo, src Attr) error {
   292  	// Load value from interface
   293  	val = derefValue(val)
   294  
   295  	// attribute unmarshaler
   296  	if val.CanAddr() {
   297  		pv := val.Addr()
   298  		if pv.CanInterface() && (finfo != nil && finfo.flags&fUnmarshalAttr > 0 || pv.Type().Implements(attrUnmarshalerType)) {
   299  			return pv.Interface().(UnmarshalerAttr).UnmarshalXMPAttr(d, src)
   300  		}
   301  	}
   302  
   303  	// text unmarshaler
   304  	if val.CanAddr() {
   305  		pv := val.Addr()
   306  		if pv.CanInterface() && (finfo != nil && finfo.flags&fTextUnmarshal > 0 || pv.Type().Implements(textUnmarshalerType)) {
   307  			return pv.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(src.Value))
   308  		}
   309  	}
   310  
   311  	// Slice of element values.
   312  	if val.Type().Kind() == reflect.Slice && val.Type().Elem().Kind() != reflect.Uint8 {
   313  		// Grow slice.
   314  		n := val.Len()
   315  		val.Set(reflect.Append(val, reflect.Zero(val.Type().Elem())))
   316  
   317  		// Recur to read element into slice.
   318  		if err := d.unmarshalAttr(val.Index(n), nil, src); err != nil {
   319  			val.SetLen(n)
   320  			return fmt.Errorf("xmp: unmarshal %s: %v", finfo.String(), err)
   321  		}
   322  		return nil
   323  	}
   324  
   325  	// otherwise set value directly
   326  	return setValue(val, src.Value)
   327  }
   328  
   329  // Translate an xml name's namespace to XMP format.
   330  // Since rdf and xml namespaces are not registered in a document,
   331  // we look up those namespaces in our registry. This is necessary
   332  // to transform rdf-related attributes like rdf:about, xml:lang,
   333  // rdf:parseType, etc.
   334  func (d Decoder) translate(n *xml.Name) {
   335  	if len(n.Space) == 0 || n.Space == "xmlns" {
   336  		return
   337  	}
   338  	ns := d.findNs(*n)
   339  	if ns == nil {
   340  		ns, _ = NsRegistry.GetNamespace(NsRegistry.GetPrefix(n.Space))
   341  	}
   342  	if ns != nil {
   343  		n.Space = ""
   344  		n.Local = ns.Expand(n.Local)
   345  	}
   346  }
   347  
   348  // Keep track of all used namespaces (registered and unknown).
   349  // In XML a local name's prefix may differ from the XMP standard
   350  // prefix. Even Adobe used to use `xap` instead of `xmp` as prefix
   351  // in documents before the standard was finished.
   352  func (d *Decoder) addNamespace(prefix, uri string) {
   353  	// register known namespaces using their standard prefix
   354  	if ns := NsRegistry.GetPrefix(uri); len(ns) > 0 {
   355  		d.intNsMap[uri], _ = NsRegistry.GetNamespace(ns)
   356  		return
   357  	}
   358  
   359  	// keep track of unknown namespaces using their in-document prefix
   360  	if _, ok := d.extNsMap[uri]; !ok {
   361  		d.extNsMap[uri] = &Namespace{prefix, uri, emptyFactory}
   362  	}
   363  }
   364  
   365  func (d Decoder) _findNsByURI(uri string) *Namespace {
   366  	if v, ok := d.intNsMap[uri]; ok {
   367  		return v
   368  	}
   369  	if v, ok := d.extNsMap[uri]; ok {
   370  		return v
   371  	}
   372  	return nil
   373  }
   374  
   375  func (d Decoder) _findNsByPrefix(pre string) *Namespace {
   376  	for _, v := range d.intNsMap {
   377  		if v.GetName() == pre {
   378  			return v
   379  		}
   380  	}
   381  	for _, v := range d.extNsMap {
   382  		if v.GetName() == pre {
   383  			return v
   384  		}
   385  	}
   386  	return nil
   387  }
   388  
   389  func (d Decoder) findNs(n xml.Name) *Namespace {
   390  	var ns *Namespace
   391  	if len(n.Space) > 0 {
   392  		ns = d._findNsByURI(n.Space)
   393  	}
   394  	if ns == nil {
   395  		ns = d._findNsByPrefix(getPrefix(n.Local))
   396  	}
   397  	return ns
   398  }
   399  
   400  func (d *Decoder) findStructField(val reflect.Value, name string) (*fieldInfo, reflect.Value) {
   401  	typ := val.Type()
   402  	tinfo, err := getTypeInfo(typ, "xmp")
   403  	if err != nil {
   404  		return nil, reflect.Value{}
   405  	}
   406  
   407  	var finfo *fieldInfo
   408  	any := -1
   409  	// pick the correct field based on name, flags and version
   410  	for i, v := range tinfo.fields {
   411  		// version must always match
   412  		if !d.version.Between(v.minVersion, v.maxVersion) {
   413  			continue
   414  		}
   415  
   416  		// save `any` field in case
   417  		if v.flags&fAny > 0 {
   418  			any = i
   419  		}
   420  
   421  		// field name must match
   422  		if v.name != name {
   423  			continue
   424  		}
   425  
   426  		finfo = &v
   427  		break
   428  	}
   429  
   430  	if finfo == nil && any >= 0 {
   431  		finfo = &tinfo.fields[any]
   432  	}
   433  
   434  	// nothing found
   435  	if finfo == nil {
   436  		return nil, reflect.Value{}
   437  	}
   438  
   439  	// allocate memory for pointer values in structs
   440  	v := finfo.value(val)
   441  	if v.Type().Kind() == reflect.Ptr && v.IsNil() && v.CanSet() {
   442  		v.Set(reflect.New(v.Type().Elem()))
   443  	}
   444  
   445  	return finfo, v
   446  }
   447  
   448  func (d *Decoder) lookupNode(ctx *NodeList, name xml.Name) (*Node, error) {
   449  	// check namespace has been registered (i.e. exists in document)
   450  	ns := d.findNs(name)
   451  	if ns == nil {
   452  		return nil, &UnknownNamespaceError{name}
   453  	}
   454  
   455  	// pick or create the XMP model for the current namespace
   456  	node := ctx.FindNode(ns)
   457  	if node == nil {
   458  		model := ns.NewModel()
   459  		if model != nil {
   460  			modelNs := model.Namespaces()
   461  			if len(modelNs) == 0 {
   462  				return nil, fmt.Errorf("xmp: model '%v' must implement at least one namespace", reflect.TypeOf(model))
   463  			}
   464  			node = NewNode(modelNs[0].XMLName(""))
   465  			node.Model = model
   466  		} else {
   467  			node = NewNode(ns.XMLName(""))
   468  		}
   469  		*ctx = append(*ctx, node)
   470  	}
   471  	return node, nil
   472  }
   473  
   474  func setValue(dst reflect.Value, src string) error {
   475  	dst0 := dst
   476  	if dst.Kind() == reflect.Ptr {
   477  		if dst.IsNil() {
   478  			dst.Set(reflect.New(dst.Type().Elem()))
   479  		}
   480  		dst = dst.Elem()
   481  	}
   482  
   483  	switch dst.Kind() {
   484  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   485  		i, err := strconv.ParseInt(src, 10, dst.Type().Bits())
   486  		if err != nil {
   487  			return err
   488  		}
   489  		dst.SetInt(i)
   490  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   491  		i, err := strconv.ParseUint(src, 10, dst.Type().Bits())
   492  		if err != nil {
   493  			return err
   494  		}
   495  		dst.SetUint(i)
   496  	case reflect.Float32, reflect.Float64:
   497  		i, err := strconv.ParseFloat(src, dst.Type().Bits())
   498  		if err != nil {
   499  			return err
   500  		}
   501  		dst.SetFloat(i)
   502  	case reflect.Bool:
   503  		i, err := strconv.ParseBool(strings.TrimSpace(src))
   504  		if err != nil {
   505  			return err
   506  		}
   507  		dst.SetBool(i)
   508  	case reflect.String:
   509  		dst.SetString(strings.TrimSpace(src))
   510  	case reflect.Slice:
   511  		// make sure it's a byte slice
   512  		if dst.Type().Elem().Kind() == reflect.Uint8 {
   513  			dst.SetBytes([]byte(src))
   514  		}
   515  	default:
   516  		return fmt.Errorf("xmp: no method for unmarshalling type %s", dst0.Type().String())
   517  	}
   518  	return nil
   519  }
   520  
   521  func skipField(n xml.Name) bool {
   522  	if n.Space == "xmlns" {
   523  		return true
   524  	}
   525  
   526  	switch n.Local {
   527  	case "rdf:parseType", "rdf:type", "xml:lang":
   528  		return true
   529  	default:
   530  		return false
   531  	}
   532  }