github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/staticcheck/fakejson/encode.go (about) 1 // Copyright 2010 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // This file contains a modified copy of the encoding/json encoder. 6 // All dynamic behavior has been removed, and reflecttion has been replaced with go/types. 7 // This allows us to statically find unmarshable types 8 // with the same rules for tags, shadowing and addressability as encoding/json. 9 // This is used for SA1026. 10 11 package fakejson 12 13 import ( 14 "go/types" 15 "sort" 16 "strings" 17 "unicode" 18 19 "github.com/amarpal/go-tools/go/types/typeutil" 20 "github.com/amarpal/go-tools/knowledge" 21 "github.com/amarpal/go-tools/staticcheck/fakereflect" 22 "golang.org/x/exp/typeparams" 23 ) 24 25 // parseTag splits a struct field's json tag into its name and 26 // comma-separated options. 27 func parseTag(tag string) string { 28 if idx := strings.Index(tag, ","); idx != -1 { 29 return tag[:idx] 30 } 31 return tag 32 } 33 34 func Marshal(v types.Type) *UnsupportedTypeError { 35 enc := encoder{} 36 return enc.newTypeEncoder(fakereflect.TypeAndCanAddr{Type: v}, "x") 37 } 38 39 // An UnsupportedTypeError is returned by Marshal when attempting 40 // to encode an unsupported value type. 41 type UnsupportedTypeError struct { 42 Type types.Type 43 Path string 44 } 45 46 type encoder struct { 47 // TODO we track addressable and non-addressable instances separately out of an abundance of caution. We don't know 48 // if this is actually required for correctness. 49 seenCanAddr typeutil.Map[struct{}] 50 seenCantAddr typeutil.Map[struct{}] 51 } 52 53 func (enc *encoder) newTypeEncoder(t fakereflect.TypeAndCanAddr, stack string) *UnsupportedTypeError { 54 var m *typeutil.Map[struct{}] 55 if t.CanAddr() { 56 m = &enc.seenCanAddr 57 } else { 58 m = &enc.seenCantAddr 59 } 60 if _, ok := m.At(t.Type); ok { 61 return nil 62 } 63 m.Set(t.Type, struct{}{}) 64 65 if t.Implements(knowledge.Interfaces["encoding/json.Marshaler"]) { 66 return nil 67 } 68 if !t.IsPtr() && t.CanAddr() && fakereflect.PtrTo(t).Implements(knowledge.Interfaces["encoding/json.Marshaler"]) { 69 return nil 70 } 71 if t.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) { 72 return nil 73 } 74 if !t.IsPtr() && t.CanAddr() && fakereflect.PtrTo(t).Implements(knowledge.Interfaces["encoding.TextMarshaler"]) { 75 return nil 76 } 77 78 switch t.Type.Underlying().(type) { 79 case *types.Basic, *types.Interface: 80 return nil 81 case *types.Struct: 82 return enc.typeFields(t, stack) 83 case *types.Map: 84 return enc.newMapEncoder(t, stack) 85 case *types.Slice: 86 return enc.newSliceEncoder(t, stack) 87 case *types.Array: 88 return enc.newArrayEncoder(t, stack) 89 case *types.Pointer: 90 // we don't have to express the pointer dereference in the path; x.f is syntactic sugar for (*x).f 91 return enc.newTypeEncoder(t.Elem(), stack) 92 default: 93 return &UnsupportedTypeError{t.Type, stack} 94 } 95 } 96 97 func (enc *encoder) newMapEncoder(t fakereflect.TypeAndCanAddr, stack string) *UnsupportedTypeError { 98 if typeparams.IsTypeParam(t.Key().Type) { 99 // We don't know enough about the concrete instantiation to say much about the key. The only time we could make 100 // a definite "this key is bad" statement is if the type parameter is constrained by type terms, none of which 101 // are tilde terms, none of which are a basic type. In all other cases, the key might implement TextMarshaler. 102 // It doesn't seem worth checking for that one single case. 103 return enc.newTypeEncoder(t.Elem(), stack+"[k]") 104 } 105 106 switch t.Key().Type.Underlying().(type) { 107 case *types.Basic: 108 default: 109 if !t.Key().Implements(knowledge.Interfaces["encoding.TextMarshaler"]) { 110 return &UnsupportedTypeError{ 111 Type: t.Type, 112 Path: stack, 113 } 114 } 115 } 116 return enc.newTypeEncoder(t.Elem(), stack+"[k]") 117 } 118 119 func (enc *encoder) newSliceEncoder(t fakereflect.TypeAndCanAddr, stack string) *UnsupportedTypeError { 120 // Byte slices get special treatment; arrays don't. 121 basic, ok := t.Elem().Type.Underlying().(*types.Basic) 122 if ok && basic.Kind() == types.Uint8 { 123 p := fakereflect.PtrTo(t.Elem()) 124 if !p.Implements(knowledge.Interfaces["encoding/json.Marshaler"]) && !p.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) { 125 return nil 126 } 127 } 128 return enc.newArrayEncoder(t, stack) 129 } 130 131 func (enc *encoder) newArrayEncoder(t fakereflect.TypeAndCanAddr, stack string) *UnsupportedTypeError { 132 return enc.newTypeEncoder(t.Elem(), stack+"[0]") 133 } 134 135 func isValidTag(s string) bool { 136 if s == "" { 137 return false 138 } 139 for _, c := range s { 140 switch { 141 case strings.ContainsRune("!#$%&()*+-./:;<=>?@[]^_{|}~ ", c): 142 // Backslash and quote chars are reserved, but 143 // otherwise any punctuation chars are allowed 144 // in a tag name. 145 case !unicode.IsLetter(c) && !unicode.IsDigit(c): 146 return false 147 } 148 } 149 return true 150 } 151 152 func typeByIndex(t fakereflect.TypeAndCanAddr, index []int) fakereflect.TypeAndCanAddr { 153 for _, i := range index { 154 if t.IsPtr() { 155 t = t.Elem() 156 } 157 t = t.Field(i).Type 158 } 159 return t 160 } 161 162 func pathByIndex(t fakereflect.TypeAndCanAddr, index []int) string { 163 path := "" 164 for _, i := range index { 165 if t.IsPtr() { 166 t = t.Elem() 167 } 168 path += "." + t.Field(i).Name 169 t = t.Field(i).Type 170 } 171 return path 172 } 173 174 // A field represents a single field found in a struct. 175 type field struct { 176 name string 177 178 tag bool 179 index []int 180 typ fakereflect.TypeAndCanAddr 181 } 182 183 // byIndex sorts field by index sequence. 184 type byIndex []field 185 186 func (x byIndex) Len() int { return len(x) } 187 188 func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] } 189 190 func (x byIndex) Less(i, j int) bool { 191 for k, xik := range x[i].index { 192 if k >= len(x[j].index) { 193 return false 194 } 195 if xik != x[j].index[k] { 196 return xik < x[j].index[k] 197 } 198 } 199 return len(x[i].index) < len(x[j].index) 200 } 201 202 // typeFields returns a list of fields that JSON should recognize for the given type. 203 // The algorithm is breadth-first search over the set of structs to include - the top struct 204 // and then any reachable anonymous structs. 205 func (enc *encoder) typeFields(t fakereflect.TypeAndCanAddr, stack string) *UnsupportedTypeError { 206 // Anonymous fields to explore at the current level and the next. 207 current := []field{} 208 next := []field{{typ: t}} 209 210 // Count of queued names for current level and the next. 211 var count, nextCount map[fakereflect.TypeAndCanAddr]int 212 213 // Types already visited at an earlier level. 214 visited := map[fakereflect.TypeAndCanAddr]bool{} 215 216 // Fields found. 217 var fields []field 218 219 for len(next) > 0 { 220 current, next = next, current[:0] 221 count, nextCount = nextCount, map[fakereflect.TypeAndCanAddr]int{} 222 223 for _, f := range current { 224 if visited[f.typ] { 225 continue 226 } 227 visited[f.typ] = true 228 229 // Scan f.typ for fields to include. 230 for i := 0; i < f.typ.NumField(); i++ { 231 sf := f.typ.Field(i) 232 if sf.Anonymous { 233 t := sf.Type 234 if t.IsPtr() { 235 t = t.Elem() 236 } 237 if !sf.IsExported() && !t.IsStruct() { 238 // Ignore embedded fields of unexported non-struct types. 239 continue 240 } 241 // Do not ignore embedded fields of unexported struct types 242 // since they may have exported fields. 243 } else if !sf.IsExported() { 244 // Ignore unexported non-embedded fields. 245 continue 246 } 247 tag := sf.Tag.Get("json") 248 if tag == "-" { 249 continue 250 } 251 name := parseTag(tag) 252 if !isValidTag(name) { 253 name = "" 254 } 255 index := make([]int, len(f.index)+1) 256 copy(index, f.index) 257 index[len(f.index)] = i 258 259 ft := sf.Type 260 if ft.Name() == "" && ft.IsPtr() { 261 // Follow pointer. 262 ft = ft.Elem() 263 } 264 265 // Record found field and index sequence. 266 if name != "" || !sf.Anonymous || !ft.IsStruct() { 267 tagged := name != "" 268 if name == "" { 269 name = sf.Name 270 } 271 field := field{ 272 name: name, 273 tag: tagged, 274 index: index, 275 typ: ft, 276 } 277 278 fields = append(fields, field) 279 if count[f.typ] > 1 { 280 // If there were multiple instances, add a second, 281 // so that the annihilation code will see a duplicate. 282 // It only cares about the distinction between 1 or 2, 283 // so don't bother generating any more copies. 284 fields = append(fields, fields[len(fields)-1]) 285 } 286 continue 287 } 288 289 // Record new anonymous struct to explore in next round. 290 nextCount[ft]++ 291 if nextCount[ft] == 1 { 292 next = append(next, field{name: ft.Name(), index: index, typ: ft}) 293 } 294 } 295 } 296 } 297 298 sort.Slice(fields, func(i, j int) bool { 299 x := fields 300 // sort field by name, breaking ties with depth, then 301 // breaking ties with "name came from json tag", then 302 // breaking ties with index sequence. 303 if x[i].name != x[j].name { 304 return x[i].name < x[j].name 305 } 306 if len(x[i].index) != len(x[j].index) { 307 return len(x[i].index) < len(x[j].index) 308 } 309 if x[i].tag != x[j].tag { 310 return x[i].tag 311 } 312 return byIndex(x).Less(i, j) 313 }) 314 315 // Delete all fields that are hidden by the Go rules for embedded fields, 316 // except that fields with JSON tags are promoted. 317 318 // The fields are sorted in primary order of name, secondary order 319 // of field index length. Loop over names; for each name, delete 320 // hidden fields by choosing the one dominant field that survives. 321 out := fields[:0] 322 for advance, i := 0, 0; i < len(fields); i += advance { 323 // One iteration per name. 324 // Find the sequence of fields with the name of this first field. 325 fi := fields[i] 326 name := fi.name 327 for advance = 1; i+advance < len(fields); advance++ { 328 fj := fields[i+advance] 329 if fj.name != name { 330 break 331 } 332 } 333 if advance == 1 { // Only one field with this name 334 out = append(out, fi) 335 continue 336 } 337 dominant, ok := dominantField(fields[i : i+advance]) 338 if ok { 339 out = append(out, dominant) 340 } 341 } 342 343 fields = out 344 sort.Sort(byIndex(fields)) 345 346 for i := range fields { 347 f := &fields[i] 348 err := enc.newTypeEncoder(typeByIndex(t, f.index), stack+pathByIndex(t, f.index)) 349 if err != nil { 350 return err 351 } 352 } 353 return nil 354 } 355 356 // dominantField looks through the fields, all of which are known to 357 // have the same name, to find the single field that dominates the 358 // others using Go's embedding rules, modified by the presence of 359 // JSON tags. If there are multiple top-level fields, the boolean 360 // will be false: This condition is an error in Go and we skip all 361 // the fields. 362 func dominantField(fields []field) (field, bool) { 363 // The fields are sorted in increasing index-length order, then by presence of tag. 364 // That means that the first field is the dominant one. We need only check 365 // for error cases: two fields at top level, either both tagged or neither tagged. 366 if len(fields) > 1 && len(fields[0].index) == len(fields[1].index) && fields[0].tag == fields[1].tag { 367 return field{}, false 368 } 369 return fields[0], true 370 }