github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/staticcheck/fakexml/marshal.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // This file contains a modified copy of the encoding/xml encoder. 6 // All dynamic behavior has been removed, and reflecttion has been replaced with go/types. 7 // This allows us to statically find unmarshable types 8 // with the same rules for tags, shadowing and addressability as encoding/xml. 9 // This is used for SA1026 and SA5008. 10 11 // NOTE(dh): we do not check CanInterface in various places, which means we'll accept more marshaler implementations than encoding/xml does. This will lead to a small amount of false negatives. 12 13 package fakexml 14 15 import ( 16 "fmt" 17 "go/types" 18 19 "github.com/amarpal/go-tools/go/types/typeutil" 20 "github.com/amarpal/go-tools/knowledge" 21 "github.com/amarpal/go-tools/staticcheck/fakereflect" 22 ) 23 24 func Marshal(v types.Type) error { 25 return NewEncoder().Encode(v) 26 } 27 28 type Encoder struct { 29 // TODO we track addressable and non-addressable instances separately out of an abundance of caution. We don't know 30 // if this is actually required for correctness. 31 seenCanAddr typeutil.Map[struct{}] 32 seenCantAddr typeutil.Map[struct{}] 33 } 34 35 func NewEncoder() *Encoder { 36 e := &Encoder{} 37 return e 38 } 39 40 func (enc *Encoder) Encode(v types.Type) error { 41 rv := fakereflect.TypeAndCanAddr{Type: v} 42 return enc.marshalValue(rv, nil, nil, "x") 43 } 44 45 func implementsMarshaler(v fakereflect.TypeAndCanAddr) bool { 46 t := v.Type 47 obj, _, _ := types.LookupFieldOrMethod(t, false, nil, "MarshalXML") 48 if obj == nil { 49 return false 50 } 51 fn, ok := obj.(*types.Func) 52 if !ok { 53 return false 54 } 55 params := fn.Type().(*types.Signature).Params() 56 if params.Len() != 2 { 57 return false 58 } 59 if !typeutil.IsType(params.At(0).Type(), "*encoding/xml.Encoder") { 60 return false 61 } 62 if !typeutil.IsType(params.At(1).Type(), "encoding/xml.StartElement") { 63 return false 64 } 65 rets := fn.Type().(*types.Signature).Results() 66 if rets.Len() != 1 { 67 return false 68 } 69 if !typeutil.IsType(rets.At(0).Type(), "error") { 70 return false 71 } 72 return true 73 } 74 75 func implementsMarshalerAttr(v fakereflect.TypeAndCanAddr) bool { 76 t := v.Type 77 obj, _, _ := types.LookupFieldOrMethod(t, false, nil, "MarshalXMLAttr") 78 if obj == nil { 79 return false 80 } 81 fn, ok := obj.(*types.Func) 82 if !ok { 83 return false 84 } 85 params := fn.Type().(*types.Signature).Params() 86 if params.Len() != 1 { 87 return false 88 } 89 if !typeutil.IsType(params.At(0).Type(), "encoding/xml.Name") { 90 return false 91 } 92 rets := fn.Type().(*types.Signature).Results() 93 if rets.Len() != 2 { 94 return false 95 } 96 if !typeutil.IsType(rets.At(0).Type(), "encoding/xml.Attr") { 97 return false 98 } 99 if !typeutil.IsType(rets.At(1).Type(), "error") { 100 return false 101 } 102 return true 103 } 104 105 type CyclicTypeError struct { 106 Type types.Type 107 Path string 108 } 109 110 func (err *CyclicTypeError) Error() string { 111 return "cyclic type" 112 } 113 114 // marshalValue writes one or more XML elements representing val. 115 // If val was obtained from a struct field, finfo must have its details. 116 func (e *Encoder) marshalValue(val fakereflect.TypeAndCanAddr, finfo *fieldInfo, startTemplate *StartElement, stack string) error { 117 var m *typeutil.Map[struct{}] 118 if val.CanAddr() { 119 m = &e.seenCanAddr 120 } else { 121 m = &e.seenCantAddr 122 } 123 if _, ok := m.At(val.Type); ok { 124 return nil 125 } 126 m.Set(val.Type, struct{}{}) 127 128 // Drill into interfaces and pointers. 129 seen := map[fakereflect.TypeAndCanAddr]struct{}{} 130 for val.IsInterface() || val.IsPtr() { 131 if val.IsInterface() { 132 return nil 133 } 134 val = val.Elem() 135 if _, ok := seen[val]; ok { 136 // Loop in type graph, e.g. 'type P *P' 137 return &CyclicTypeError{val.Type, stack} 138 } 139 seen[val] = struct{}{} 140 } 141 142 // Check for marshaler. 143 if implementsMarshaler(val) { 144 return nil 145 } 146 if val.CanAddr() { 147 pv := fakereflect.PtrTo(val) 148 if implementsMarshaler(pv) { 149 return nil 150 } 151 } 152 153 // Check for text marshaler. 154 if val.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) { 155 return nil 156 } 157 if val.CanAddr() { 158 pv := fakereflect.PtrTo(val) 159 if pv.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) { 160 return nil 161 } 162 } 163 164 // Slices and arrays iterate over the elements. They do not have an enclosing tag. 165 if (val.IsSlice() || val.IsArray()) && !isByteArray(val) && !isByteSlice(val) { 166 if err := e.marshalValue(val.Elem(), finfo, startTemplate, stack+"[0]"); err != nil { 167 return err 168 } 169 return nil 170 } 171 172 tinfo, err := getTypeInfo(val) 173 if err != nil { 174 return err 175 } 176 177 // Create start element. 178 // Precedence for the XML element name is: 179 // 0. startTemplate 180 // 1. XMLName field in underlying struct; 181 // 2. field name/tag in the struct field; and 182 // 3. type name 183 var start StartElement 184 185 if startTemplate != nil { 186 start.Name = startTemplate.Name 187 start.Attr = append(start.Attr, startTemplate.Attr...) 188 } else if tinfo.xmlname != nil { 189 xmlname := tinfo.xmlname 190 if xmlname.name != "" { 191 start.Name.Space, start.Name.Local = xmlname.xmlns, xmlname.name 192 } 193 } 194 195 // Attributes 196 for i := range tinfo.fields { 197 finfo := &tinfo.fields[i] 198 if finfo.flags&fAttr == 0 { 199 continue 200 } 201 fv := finfo.value(val) 202 203 name := Name{Space: finfo.xmlns, Local: finfo.name} 204 if err := e.marshalAttr(&start, name, fv, stack+pathByIndex(val, finfo.idx)); err != nil { 205 return err 206 } 207 } 208 209 if val.IsStruct() { 210 return e.marshalStruct(tinfo, val, stack) 211 } else { 212 return e.marshalSimple(val, stack) 213 } 214 } 215 216 func isSlice(v fakereflect.TypeAndCanAddr) bool { 217 _, ok := v.Type.Underlying().(*types.Slice) 218 return ok 219 } 220 221 func isByteSlice(v fakereflect.TypeAndCanAddr) bool { 222 slice, ok := v.Type.Underlying().(*types.Slice) 223 if !ok { 224 return false 225 } 226 basic, ok := slice.Elem().Underlying().(*types.Basic) 227 if !ok { 228 return false 229 } 230 return basic.Kind() == types.Uint8 231 } 232 233 func isByteArray(v fakereflect.TypeAndCanAddr) bool { 234 slice, ok := v.Type.Underlying().(*types.Array) 235 if !ok { 236 return false 237 } 238 basic, ok := slice.Elem().Underlying().(*types.Basic) 239 if !ok { 240 return false 241 } 242 return basic.Kind() == types.Uint8 243 } 244 245 // marshalAttr marshals an attribute with the given name and value, adding to start.Attr. 246 func (e *Encoder) marshalAttr(start *StartElement, name Name, val fakereflect.TypeAndCanAddr, stack string) error { 247 if implementsMarshalerAttr(val) { 248 return nil 249 } 250 251 if val.CanAddr() { 252 pv := fakereflect.PtrTo(val) 253 if implementsMarshalerAttr(pv) { 254 return nil 255 } 256 } 257 258 if val.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) { 259 return nil 260 } 261 262 if val.CanAddr() { 263 pv := fakereflect.PtrTo(val) 264 if pv.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) { 265 return nil 266 } 267 } 268 269 // Dereference or skip nil pointer 270 if val.IsPtr() { 271 val = val.Elem() 272 } 273 274 // Walk slices. 275 if isSlice(val) && !isByteSlice(val) { 276 if err := e.marshalAttr(start, name, val.Elem(), stack+"[0]"); err != nil { 277 return err 278 } 279 return nil 280 } 281 282 if typeutil.IsType(val.Type, "encoding/xml.Attr") { 283 return nil 284 } 285 286 return e.marshalSimple(val, stack) 287 } 288 289 func (e *Encoder) marshalSimple(val fakereflect.TypeAndCanAddr, stack string) error { 290 switch val.Type.Underlying().(type) { 291 case *types.Basic, *types.Interface: 292 return nil 293 case *types.Slice, *types.Array: 294 basic, ok := val.Elem().Type.Underlying().(*types.Basic) 295 if !ok || basic.Kind() != types.Uint8 { 296 return &UnsupportedTypeError{val.Type, stack} 297 } 298 return nil 299 default: 300 return &UnsupportedTypeError{val.Type, stack} 301 } 302 } 303 304 func indirect(vf fakereflect.TypeAndCanAddr) fakereflect.TypeAndCanAddr { 305 for vf.IsPtr() { 306 vf = vf.Elem() 307 } 308 return vf 309 } 310 311 func pathByIndex(t fakereflect.TypeAndCanAddr, index []int) string { 312 path := "" 313 for _, i := range index { 314 if t.IsPtr() { 315 t = t.Elem() 316 } 317 path += "." + t.Field(i).Name 318 t = t.Field(i).Type 319 } 320 return path 321 } 322 323 func (e *Encoder) marshalStruct(tinfo *typeInfo, val fakereflect.TypeAndCanAddr, stack string) error { 324 for i := range tinfo.fields { 325 finfo := &tinfo.fields[i] 326 if finfo.flags&fAttr != 0 { 327 continue 328 } 329 vf := finfo.value(val) 330 331 switch finfo.flags & fMode { 332 case fCDATA, fCharData: 333 if vf.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) { 334 continue 335 } 336 if vf.CanAddr() { 337 pv := fakereflect.PtrTo(vf) 338 if pv.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) { 339 continue 340 } 341 } 342 continue 343 344 case fComment: 345 vf = indirect(vf) 346 if !(isByteSlice(vf) || isByteArray(vf)) { 347 return fmt.Errorf("xml: bad type for comment field of %s", val) 348 } 349 continue 350 351 case fInnerXML: 352 vf = indirect(vf) 353 if typeutil.IsType(vf.Type, "[]byte") || typeutil.IsType(vf.Type, "string") { 354 continue 355 } 356 357 case fElement, fElement | fAny: 358 } 359 if err := e.marshalValue(vf, finfo, nil, stack+pathByIndex(val, finfo.idx)); err != nil { 360 return err 361 } 362 } 363 return nil 364 } 365 366 // UnsupportedTypeError is returned when Marshal encounters a type 367 // that cannot be converted into XML. 368 type UnsupportedTypeError struct { 369 Type types.Type 370 Path string 371 } 372 373 func (e *UnsupportedTypeError) Error() string { 374 return fmt.Sprintf("xml: unsupported type %s, via %s ", e.Type, e.Path) 375 }