github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/staticcheck/fakejson/encode.go (about)

     1  // Copyright 2010 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // This file contains a modified copy of the encoding/json encoder.
     6  // All dynamic behavior has been removed, and reflection has been replaced with go/types.
     7  // This allows us to statically find unmarshalable types
     8  // with the same rules for tags, shadowing and addressability as encoding/json.
     9  // This is used for SA1026.
    10  
    11  package fakejson
    12  
    13  import (
    14  	"go/types"
    15  	"sort"
    16  	"strings"
    17  	"unicode"
    18  
    19  	"github.com/amarpal/go-tools/go/types/typeutil"
    20  	"github.com/amarpal/go-tools/knowledge"
    21  	"github.com/amarpal/go-tools/staticcheck/fakereflect"
    22  	"golang.org/x/exp/typeparams"
    23  )
    24  
    25  // parseTag splits a struct field's json tag into its name and
    26  // comma-separated options.
    27  func parseTag(tag string) string {
    28  	if idx := strings.Index(tag, ","); idx != -1 {
    29  		return tag[:idx]
    30  	}
    31  	return tag
    32  }
    33  
    34  func Marshal(v types.Type) *UnsupportedTypeError {
    35  	enc := encoder{}
    36  	return enc.newTypeEncoder(fakereflect.TypeAndCanAddr{Type: v}, "x")
    37  }
    38  
// An UnsupportedTypeError is returned by Marshal when attempting
// to encode an unsupported value type.
type UnsupportedTypeError struct {
	// Type is the type that cannot be encoded. For a map with an unsupported
	// key type, this is the map type itself.
	Type types.Type
	// Path is the expression leading from the root value to Type. It is built
	// from "x", ".FieldName", "[0]" and "[k]" segments.
	Path string
}
    45  
// encoder carries the state of a single check: the sets of types that have
// already been visited. newTypeEncoder consults these sets so that each type
// is examined at most once per addressability, which also makes recursive
// types terminate.
type encoder struct {
	// TODO we track addressable and non-addressable instances separately out of an abundance of caution. We don't know
	// if this is actually required for correctness.
	seenCanAddr  typeutil.Map[struct{}]
	seenCantAddr typeutil.Map[struct{}]
}
    52  
    53  func (enc *encoder) newTypeEncoder(t fakereflect.TypeAndCanAddr, stack string) *UnsupportedTypeError {
    54  	var m *typeutil.Map[struct{}]
    55  	if t.CanAddr() {
    56  		m = &enc.seenCanAddr
    57  	} else {
    58  		m = &enc.seenCantAddr
    59  	}
    60  	if _, ok := m.At(t.Type); ok {
    61  		return nil
    62  	}
    63  	m.Set(t.Type, struct{}{})
    64  
    65  	if t.Implements(knowledge.Interfaces["encoding/json.Marshaler"]) {
    66  		return nil
    67  	}
    68  	if !t.IsPtr() && t.CanAddr() && fakereflect.PtrTo(t).Implements(knowledge.Interfaces["encoding/json.Marshaler"]) {
    69  		return nil
    70  	}
    71  	if t.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
    72  		return nil
    73  	}
    74  	if !t.IsPtr() && t.CanAddr() && fakereflect.PtrTo(t).Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
    75  		return nil
    76  	}
    77  
    78  	switch t.Type.Underlying().(type) {
    79  	case *types.Basic, *types.Interface:
    80  		return nil
    81  	case *types.Struct:
    82  		return enc.typeFields(t, stack)
    83  	case *types.Map:
    84  		return enc.newMapEncoder(t, stack)
    85  	case *types.Slice:
    86  		return enc.newSliceEncoder(t, stack)
    87  	case *types.Array:
    88  		return enc.newArrayEncoder(t, stack)
    89  	case *types.Pointer:
    90  		// we don't have to express the pointer dereference in the path; x.f is syntactic sugar for (*x).f
    91  		return enc.newTypeEncoder(t.Elem(), stack)
    92  	default:
    93  		return &UnsupportedTypeError{t.Type, stack}
    94  	}
    95  }
    96  
    97  func (enc *encoder) newMapEncoder(t fakereflect.TypeAndCanAddr, stack string) *UnsupportedTypeError {
    98  	if typeparams.IsTypeParam(t.Key().Type) {
    99  		// We don't know enough about the concrete instantiation to say much about the key. The only time we could make
   100  		// a definite "this key is bad" statement is if the type parameter is constrained by type terms, none of which
   101  		// are tilde terms, none of which are a basic type. In all other cases, the key might implement TextMarshaler.
   102  		// It doesn't seem worth checking for that one single case.
   103  		return enc.newTypeEncoder(t.Elem(), stack+"[k]")
   104  	}
   105  
   106  	switch t.Key().Type.Underlying().(type) {
   107  	case *types.Basic:
   108  	default:
   109  		if !t.Key().Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
   110  			return &UnsupportedTypeError{
   111  				Type: t.Type,
   112  				Path: stack,
   113  			}
   114  		}
   115  	}
   116  	return enc.newTypeEncoder(t.Elem(), stack+"[k]")
   117  }
   118  
   119  func (enc *encoder) newSliceEncoder(t fakereflect.TypeAndCanAddr, stack string) *UnsupportedTypeError {
   120  	// Byte slices get special treatment; arrays don't.
   121  	basic, ok := t.Elem().Type.Underlying().(*types.Basic)
   122  	if ok && basic.Kind() == types.Uint8 {
   123  		p := fakereflect.PtrTo(t.Elem())
   124  		if !p.Implements(knowledge.Interfaces["encoding/json.Marshaler"]) && !p.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
   125  			return nil
   126  		}
   127  	}
   128  	return enc.newArrayEncoder(t, stack)
   129  }
   130  
   131  func (enc *encoder) newArrayEncoder(t fakereflect.TypeAndCanAddr, stack string) *UnsupportedTypeError {
   132  	return enc.newTypeEncoder(t.Elem(), stack+"[0]")
   133  }
   134  
   135  func isValidTag(s string) bool {
   136  	if s == "" {
   137  		return false
   138  	}
   139  	for _, c := range s {
   140  		switch {
   141  		case strings.ContainsRune("!#$%&()*+-./:;<=>?@[]^_{|}~ ", c):
   142  			// Backslash and quote chars are reserved, but
   143  			// otherwise any punctuation chars are allowed
   144  			// in a tag name.
   145  		case !unicode.IsLetter(c) && !unicode.IsDigit(c):
   146  			return false
   147  		}
   148  	}
   149  	return true
   150  }
   151  
   152  func typeByIndex(t fakereflect.TypeAndCanAddr, index []int) fakereflect.TypeAndCanAddr {
   153  	for _, i := range index {
   154  		if t.IsPtr() {
   155  			t = t.Elem()
   156  		}
   157  		t = t.Field(i).Type
   158  	}
   159  	return t
   160  }
   161  
   162  func pathByIndex(t fakereflect.TypeAndCanAddr, index []int) string {
   163  	path := ""
   164  	for _, i := range index {
   165  		if t.IsPtr() {
   166  			t = t.Elem()
   167  		}
   168  		path += "." + t.Field(i).Name
   169  		t = t.Field(i).Type
   170  	}
   171  	return path
   172  }
   173  
// A field represents a single field found in a struct.
type field struct {
	// name is the key the field is known by: the name from the json tag if it
	// is valid, otherwise the Go field name. Queue entries for embedded
	// structs that are still to be explored carry the type name instead.
	name string

	// tag reports whether name came from the json struct tag.
	tag   bool
	// index is the sequence of field indices leading from the root struct,
	// through embedded structs, to this field.
	index []int
	// typ is the field's type, with the pointer stripped if the field's type
	// is an unnamed pointer type.
	typ   fakereflect.TypeAndCanAddr
}
   182  
   183  // byIndex sorts field by index sequence.
   184  type byIndex []field
   185  
   186  func (x byIndex) Len() int { return len(x) }
   187  
   188  func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
   189  
   190  func (x byIndex) Less(i, j int) bool {
   191  	for k, xik := range x[i].index {
   192  		if k >= len(x[j].index) {
   193  			return false
   194  		}
   195  		if xik != x[j].index[k] {
   196  			return xik < x[j].index[k]
   197  		}
   198  	}
   199  	return len(x[i].index) < len(x[j].index)
   200  }
   201  
// typeFields determines the set of fields that JSON should recognize for the
// struct type t and checks the type of each of them. It returns the first
// unsupported type found, or nil if every field is supported. stack is the
// path expression leading to t; field paths are appended to it.
//
// The field discovery is a breadth-first search over the set of structs to
// include - the top struct and then any reachable anonymous structs. Fields
// hidden by Go's embedding rules (as modified by JSON tags) are dropped before
// the remaining fields are checked.
func (enc *encoder) typeFields(t fakereflect.TypeAndCanAddr, stack string) *UnsupportedTypeError {
	// Anonymous fields to explore at the current level and the next.
	current := []field{}
	next := []field{{typ: t}}

	// Count of queued names for current level and the next.
	var count, nextCount map[fakereflect.TypeAndCanAddr]int

	// Types already visited at an earlier level.
	visited := map[fakereflect.TypeAndCanAddr]bool{}

	// Fields found.
	var fields []field

	for len(next) > 0 {
		// Advance one level: the queued structs become the current level, and
		// the old current slice is truncated and reused as the next queue.
		current, next = next, current[:0]
		count, nextCount = nextCount, map[fakereflect.TypeAndCanAddr]int{}

		for _, f := range current {
			if visited[f.typ] {
				continue
			}
			visited[f.typ] = true

			// Scan f.typ for fields to include.
			for i := 0; i < f.typ.NumField(); i++ {
				sf := f.typ.Field(i)
				if sf.Anonymous {
					t := sf.Type
					if t.IsPtr() {
						t = t.Elem()
					}
					if !sf.IsExported() && !t.IsStruct() {
						// Ignore embedded fields of unexported non-struct types.
						continue
					}
					// Do not ignore embedded fields of unexported struct types
					// since they may have exported fields.
				} else if !sf.IsExported() {
					// Ignore unexported non-embedded fields.
					continue
				}
				tag := sf.Tag.Get("json")
				if tag == "-" {
					// The field is explicitly excluded from JSON.
					continue
				}
				name := parseTag(tag)
				if !isValidTag(name) {
					// Treat a missing or invalid tag name as untagged.
					name = ""
				}
				// The field's index sequence is its parent's sequence plus i.
				index := make([]int, len(f.index)+1)
				copy(index, f.index)
				index[len(f.index)] = i

				ft := sf.Type
				if ft.Name() == "" && ft.IsPtr() {
					// Follow pointer.
					ft = ft.Elem()
				}

				// Record found field and index sequence.
				if name != "" || !sf.Anonymous || !ft.IsStruct() {
					tagged := name != ""
					if name == "" {
						name = sf.Name
					}
					field := field{
						name:  name,
						tag:   tagged,
						index: index,
						typ:   ft,
					}

					fields = append(fields, field)
					if count[f.typ] > 1 {
						// If there were multiple instances, add a second,
						// so that the annihilation code will see a duplicate.
						// It only cares about the distinction between 1 or 2,
						// so don't bother generating any more copies.
						fields = append(fields, fields[len(fields)-1])
					}
					continue
				}

				// Record new anonymous struct to explore in next round.
				nextCount[ft]++
				if nextCount[ft] == 1 {
					next = append(next, field{name: ft.Name(), index: index, typ: ft})
				}
			}
		}
	}

	sort.Slice(fields, func(i, j int) bool {
		x := fields
		// sort field by name, breaking ties with depth, then
		// breaking ties with "name came from json tag", then
		// breaking ties with index sequence.
		if x[i].name != x[j].name {
			return x[i].name < x[j].name
		}
		if len(x[i].index) != len(x[j].index) {
			return len(x[i].index) < len(x[j].index)
		}
		if x[i].tag != x[j].tag {
			return x[i].tag
		}
		return byIndex(x).Less(i, j)
	})

	// Delete all fields that are hidden by the Go rules for embedded fields,
	// except that fields with JSON tags are promoted.

	// The fields are sorted in primary order of name, secondary order
	// of field index length. Loop over names; for each name, delete
	// hidden fields by choosing the one dominant field that survives.
	// The survivors are written into the front of the same backing array.
	out := fields[:0]
	for advance, i := 0, 0; i < len(fields); i += advance {
		// One iteration per name.
		// Find the sequence of fields with the name of this first field.
		fi := fields[i]
		name := fi.name
		for advance = 1; i+advance < len(fields); advance++ {
			fj := fields[i+advance]
			if fj.name != name {
				break
			}
		}
		if advance == 1 { // Only one field with this name
			out = append(out, fi)
			continue
		}
		// Several fields share this name: keep the dominant one, or none at
		// all if dominantField reports a conflict.
		dominant, ok := dominantField(fields[i : i+advance])
		if ok {
			out = append(out, dominant)
		}
	}

	fields = out
	sort.Sort(byIndex(fields))

	// Check each surviving field in index order and stop at the first
	// unsupported type. The field's type and path are recomputed by walking
	// its index sequence from t.
	for i := range fields {
		f := &fields[i]
		err := enc.newTypeEncoder(typeByIndex(t, f.index), stack+pathByIndex(t, f.index))
		if err != nil {
			return err
		}
	}
	return nil
}
   355  
   356  // dominantField looks through the fields, all of which are known to
   357  // have the same name, to find the single field that dominates the
   358  // others using Go's embedding rules, modified by the presence of
   359  // JSON tags. If there are multiple top-level fields, the boolean
   360  // will be false: This condition is an error in Go and we skip all
   361  // the fields.
   362  func dominantField(fields []field) (field, bool) {
   363  	// The fields are sorted in increasing index-length order, then by presence of tag.
   364  	// That means that the first field is the dominant one. We need only check
   365  	// for error cases: two fields at top level, either both tagged or neither tagged.
   366  	if len(fields) > 1 && len(fields[0].index) == len(fields[1].index) && fields[0].tag == fields[1].tag {
   367  		return field{}, false
   368  	}
   369  	return fields[0], true
   370  }