github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/staticcheck/fakexml/marshal.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
// This file contains a modified copy of the encoding/xml encoder.
// All dynamic behavior has been removed, and reflection has been replaced with go/types.
// This allows us to statically find types that cannot be marshaled,
// using the same rules for tags, shadowing and addressability as encoding/xml.
// This is used for SA1026 and SA5008.
    10  
// NOTE(dh): we do not check CanInterface in various places, which means we'll accept more marshaler implementations than encoding/xml does. This will lead to a small number of false negatives.
    12  
    13  package fakexml
    14  
    15  import (
    16  	"fmt"
    17  	"go/types"
    18  
    19  	"github.com/amarpal/go-tools/go/types/typeutil"
    20  	"github.com/amarpal/go-tools/knowledge"
    21  	"github.com/amarpal/go-tools/staticcheck/fakereflect"
    22  )
    23  
    24  func Marshal(v types.Type) error {
    25  	return NewEncoder().Encode(v)
    26  }
    27  
    28  type Encoder struct {
    29  	// TODO we track addressable and non-addressable instances separately out of an abundance of caution. We don't know
    30  	// if this is actually required for correctness.
    31  	seenCanAddr  typeutil.Map[struct{}]
    32  	seenCantAddr typeutil.Map[struct{}]
    33  }
    34  
    35  func NewEncoder() *Encoder {
    36  	e := &Encoder{}
    37  	return e
    38  }
    39  
    40  func (enc *Encoder) Encode(v types.Type) error {
    41  	rv := fakereflect.TypeAndCanAddr{Type: v}
    42  	return enc.marshalValue(rv, nil, nil, "x")
    43  }
    44  
    45  func implementsMarshaler(v fakereflect.TypeAndCanAddr) bool {
    46  	t := v.Type
    47  	obj, _, _ := types.LookupFieldOrMethod(t, false, nil, "MarshalXML")
    48  	if obj == nil {
    49  		return false
    50  	}
    51  	fn, ok := obj.(*types.Func)
    52  	if !ok {
    53  		return false
    54  	}
    55  	params := fn.Type().(*types.Signature).Params()
    56  	if params.Len() != 2 {
    57  		return false
    58  	}
    59  	if !typeutil.IsType(params.At(0).Type(), "*encoding/xml.Encoder") {
    60  		return false
    61  	}
    62  	if !typeutil.IsType(params.At(1).Type(), "encoding/xml.StartElement") {
    63  		return false
    64  	}
    65  	rets := fn.Type().(*types.Signature).Results()
    66  	if rets.Len() != 1 {
    67  		return false
    68  	}
    69  	if !typeutil.IsType(rets.At(0).Type(), "error") {
    70  		return false
    71  	}
    72  	return true
    73  }
    74  
    75  func implementsMarshalerAttr(v fakereflect.TypeAndCanAddr) bool {
    76  	t := v.Type
    77  	obj, _, _ := types.LookupFieldOrMethod(t, false, nil, "MarshalXMLAttr")
    78  	if obj == nil {
    79  		return false
    80  	}
    81  	fn, ok := obj.(*types.Func)
    82  	if !ok {
    83  		return false
    84  	}
    85  	params := fn.Type().(*types.Signature).Params()
    86  	if params.Len() != 1 {
    87  		return false
    88  	}
    89  	if !typeutil.IsType(params.At(0).Type(), "encoding/xml.Name") {
    90  		return false
    91  	}
    92  	rets := fn.Type().(*types.Signature).Results()
    93  	if rets.Len() != 2 {
    94  		return false
    95  	}
    96  	if !typeutil.IsType(rets.At(0).Type(), "encoding/xml.Attr") {
    97  		return false
    98  	}
    99  	if !typeutil.IsType(rets.At(1).Type(), "error") {
   100  		return false
   101  	}
   102  	return true
   103  }
   104  
   105  type CyclicTypeError struct {
   106  	Type types.Type
   107  	Path string
   108  }
   109  
   110  func (err *CyclicTypeError) Error() string {
   111  	return "cyclic type"
   112  }
   113  
// marshalValue checks whether val (a type together with its addressability)
// could be marshaled by encoding/xml as one or more XML elements. Nothing is
// written; it returns nil if no problem was found and the first error
// otherwise.
//
// If val was obtained from a struct field, finfo must have its details.
// startTemplate is forwarded to the element check of slices and arrays.
// stack is the access path from the root value to val (e.g. "x.F[0]") and is
// used as the Path of returned errors.
func (e *Encoder) marshalValue(val fakereflect.TypeAndCanAddr, finfo *fieldInfo, startTemplate *StartElement, stack string) error {
	// Each type is checked at most once per Encoder and addressability. This
	// also stops the recursion through self-referential types, because the
	// second visit returns nil immediately.
	var m *typeutil.Map[struct{}]
	if val.CanAddr() {
		m = &e.seenCanAddr
	} else {
		m = &e.seenCantAddr
	}
	if _, ok := m.At(val.Type); ok {
		return nil
	}
	m.Set(val.Type, struct{}{})

	// Drill into interfaces and pointers.
	// An interface's dynamic type is unknown statically, so it is accepted.
	seen := map[fakereflect.TypeAndCanAddr]struct{}{}
	for val.IsInterface() || val.IsPtr() {
		if val.IsInterface() {
			return nil
		}
		val = val.Elem()
		if _, ok := seen[val]; ok {
			// Loop in type graph, e.g. 'type P *P'
			return &CyclicTypeError{val.Type, stack}
		}
		seen[val] = struct{}{}
	}

	// Check for marshaler.
	// Addressable values may also use methods declared on the pointer type.
	if implementsMarshaler(val) {
		return nil
	}
	if val.CanAddr() {
		pv := fakereflect.PtrTo(val)
		if implementsMarshaler(pv) {
			return nil
		}
	}

	// Check for text marshaler.
	if val.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
		return nil
	}
	if val.CanAddr() {
		pv := fakereflect.PtrTo(val)
		if pv.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
			return nil
		}
	}

	// Slices and arrays iterate over the elements. They do not have an enclosing tag.
	// All elements share one type, so only the element type is checked, once,
	// under the path suffix "[0]". Byte slices and arrays are handled below as
	// simple values instead.
	if (val.IsSlice() || val.IsArray()) && !isByteArray(val) && !isByteSlice(val) {
		if err := e.marshalValue(val.Elem(), finfo, startTemplate, stack+"[0]"); err != nil {
			return err
		}
		return nil
	}

	tinfo, err := getTypeInfo(val)
	if err != nil {
		return err
	}

	// Create start element.
	// Precedence for the XML element name is:
	// 0. startTemplate
	// 1. XMLName field in underlying struct;
	// 2. field name/tag in the struct field; and
	// 3. type name
	//
	// NOTE(review): only cases 0 and 1 are reproduced below, and start is only
	// passed to marshalAttr, which does not read it. It appears to be kept for
	// structural parity with encoding/xml.
	var start StartElement

	if startTemplate != nil {
		start.Name = startTemplate.Name
		start.Attr = append(start.Attr, startTemplate.Attr...)
	} else if tinfo.xmlname != nil {
		xmlname := tinfo.xmlname
		if xmlname.name != "" {
			start.Name.Space, start.Name.Local = xmlname.xmlns, xmlname.name
		}
	}

	// Attributes
	// Fields tagged as attributes are checked here; marshalStruct skips them.
	for i := range tinfo.fields {
		finfo := &tinfo.fields[i]
		if finfo.flags&fAttr == 0 {
			continue
		}
		fv := finfo.value(val)

		name := Name{Space: finfo.xmlns, Local: finfo.name}
		if err := e.marshalAttr(&start, name, fv, stack+pathByIndex(val, finfo.idx)); err != nil {
			return err
		}
	}

	// Structs check their element fields; every other type must be a simple
	// value that encoding/xml can render as character data.
	if val.IsStruct() {
		return e.marshalStruct(tinfo, val, stack)
	} else {
		return e.marshalSimple(val, stack)
	}
}
   215  
   216  func isSlice(v fakereflect.TypeAndCanAddr) bool {
   217  	_, ok := v.Type.Underlying().(*types.Slice)
   218  	return ok
   219  }
   220  
   221  func isByteSlice(v fakereflect.TypeAndCanAddr) bool {
   222  	slice, ok := v.Type.Underlying().(*types.Slice)
   223  	if !ok {
   224  		return false
   225  	}
   226  	basic, ok := slice.Elem().Underlying().(*types.Basic)
   227  	if !ok {
   228  		return false
   229  	}
   230  	return basic.Kind() == types.Uint8
   231  }
   232  
   233  func isByteArray(v fakereflect.TypeAndCanAddr) bool {
   234  	slice, ok := v.Type.Underlying().(*types.Array)
   235  	if !ok {
   236  		return false
   237  	}
   238  	basic, ok := slice.Elem().Underlying().(*types.Basic)
   239  	if !ok {
   240  		return false
   241  	}
   242  	return basic.Kind() == types.Uint8
   243  }
   244  
   245  // marshalAttr marshals an attribute with the given name and value, adding to start.Attr.
   246  func (e *Encoder) marshalAttr(start *StartElement, name Name, val fakereflect.TypeAndCanAddr, stack string) error {
   247  	if implementsMarshalerAttr(val) {
   248  		return nil
   249  	}
   250  
   251  	if val.CanAddr() {
   252  		pv := fakereflect.PtrTo(val)
   253  		if implementsMarshalerAttr(pv) {
   254  			return nil
   255  		}
   256  	}
   257  
   258  	if val.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
   259  		return nil
   260  	}
   261  
   262  	if val.CanAddr() {
   263  		pv := fakereflect.PtrTo(val)
   264  		if pv.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
   265  			return nil
   266  		}
   267  	}
   268  
   269  	// Dereference or skip nil pointer
   270  	if val.IsPtr() {
   271  		val = val.Elem()
   272  	}
   273  
   274  	// Walk slices.
   275  	if isSlice(val) && !isByteSlice(val) {
   276  		if err := e.marshalAttr(start, name, val.Elem(), stack+"[0]"); err != nil {
   277  			return err
   278  		}
   279  		return nil
   280  	}
   281  
   282  	if typeutil.IsType(val.Type, "encoding/xml.Attr") {
   283  		return nil
   284  	}
   285  
   286  	return e.marshalSimple(val, stack)
   287  }
   288  
   289  func (e *Encoder) marshalSimple(val fakereflect.TypeAndCanAddr, stack string) error {
   290  	switch val.Type.Underlying().(type) {
   291  	case *types.Basic, *types.Interface:
   292  		return nil
   293  	case *types.Slice, *types.Array:
   294  		basic, ok := val.Elem().Type.Underlying().(*types.Basic)
   295  		if !ok || basic.Kind() != types.Uint8 {
   296  			return &UnsupportedTypeError{val.Type, stack}
   297  		}
   298  		return nil
   299  	default:
   300  		return &UnsupportedTypeError{val.Type, stack}
   301  	}
   302  }
   303  
   304  func indirect(vf fakereflect.TypeAndCanAddr) fakereflect.TypeAndCanAddr {
   305  	for vf.IsPtr() {
   306  		vf = vf.Elem()
   307  	}
   308  	return vf
   309  }
   310  
   311  func pathByIndex(t fakereflect.TypeAndCanAddr, index []int) string {
   312  	path := ""
   313  	for _, i := range index {
   314  		if t.IsPtr() {
   315  			t = t.Elem()
   316  		}
   317  		path += "." + t.Field(i).Name
   318  		t = t.Field(i).Type
   319  	}
   320  	return path
   321  }
   322  
   323  func (e *Encoder) marshalStruct(tinfo *typeInfo, val fakereflect.TypeAndCanAddr, stack string) error {
   324  	for i := range tinfo.fields {
   325  		finfo := &tinfo.fields[i]
   326  		if finfo.flags&fAttr != 0 {
   327  			continue
   328  		}
   329  		vf := finfo.value(val)
   330  
   331  		switch finfo.flags & fMode {
   332  		case fCDATA, fCharData:
   333  			if vf.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
   334  				continue
   335  			}
   336  			if vf.CanAddr() {
   337  				pv := fakereflect.PtrTo(vf)
   338  				if pv.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
   339  					continue
   340  				}
   341  			}
   342  			continue
   343  
   344  		case fComment:
   345  			vf = indirect(vf)
   346  			if !(isByteSlice(vf) || isByteArray(vf)) {
   347  				return fmt.Errorf("xml: bad type for comment field of %s", val)
   348  			}
   349  			continue
   350  
   351  		case fInnerXML:
   352  			vf = indirect(vf)
   353  			if typeutil.IsType(vf.Type, "[]byte") || typeutil.IsType(vf.Type, "string") {
   354  				continue
   355  			}
   356  
   357  		case fElement, fElement | fAny:
   358  		}
   359  		if err := e.marshalValue(vf, finfo, nil, stack+pathByIndex(val, finfo.idx)); err != nil {
   360  			return err
   361  		}
   362  	}
   363  	return nil
   364  }
   365  
   366  // UnsupportedTypeError is returned when Marshal encounters a type
   367  // that cannot be converted into XML.
   368  type UnsupportedTypeError struct {
   369  	Type types.Type
   370  	Path string
   371  }
   372  
   373  func (e *UnsupportedTypeError) Error() string {
   374  	return fmt.Sprintf("xml: unsupported type %s, via %s ", e.Type, e.Path)
   375  }