golang.org/x/tools/gopls@v0.15.3/internal/util/frob/frob.go (about)

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package frob is a fast restricted object encoder/decoder in the
     6  // spirit of encoding/gob.
     7  //
     8  // As with gob, types that recursively contain functions, channels,
     9  // and unsafe.Pointers cannot be encoded, but frob has these
    10  // additional restrictions:
    11  //
    12  //   - Interface values are not supported; this avoids the need for
    13  //     the encoding to describe types.
    14  //
    15  //   - Types that recursively contain private struct fields are not
    16  //     permitted.
    17  //
    18  //   - The encoding is unspecified and subject to change, so the encoder
    19  //     and decoder must exactly agree on their implementation and on the
    20  //     definitions of the target types.
    21  //
    22  //   - Lengths (of arrays, slices, and maps) are currently assumed to
    23  //     fit in 32 bits.
    24  //
    25  //   - There is no error handling. All errors are reported by panicking.
    26  //
    27  //   - Values are serialized as trees, not graphs, so shared subgraphs
    28  //     are encoded repeatedly.
    29  //
    30  //   - No attempt is made to detect cyclic data structures.
    31  package frob
    32  
    33  import (
    34  	"encoding/binary"
    35  	"fmt"
    36  	"math"
    37  	"reflect"
    38  	"sync"
    39  )
    40  
    41  // A Codec[T] is an immutable encoder and decoder for values of type T.
    42  type Codec[T any] struct{ frob *frob }
    43  
    44  // CodecFor[T] returns a codec for values of type T.
    45  // It panics if type T is unsuitable.
    46  func CodecFor[T any]() Codec[T] {
    47  	frobsMu.Lock()
    48  	defer frobsMu.Unlock()
    49  	return Codec[T]{frobFor(reflect.TypeOf((*T)(nil)).Elem())}
    50  }
    51  
    52  func (codec Codec[T]) Encode(v T) []byte          { return codec.frob.Encode(v) }
    53  func (codec Codec[T]) Decode(data []byte, ptr *T) { codec.frob.Decode(data, ptr) }
    54  
    55  var (
    56  	frobsMu sync.Mutex
    57  	frobs   = make(map[reflect.Type]*frob)
    58  )
    59  
    60  // A frob is an encoder/decoder for a specific type.
    61  type frob struct {
    62  	t     reflect.Type
    63  	kind  reflect.Kind
    64  	elems []*frob // elem (array/slice/ptr), key+value (map), fields (struct)
    65  }
    66  
    67  // frobFor returns the frob for a particular type.
    68  // Precondition: caller holds frobsMu.
    69  func frobFor(t reflect.Type) *frob {
    70  	fr, ok := frobs[t]
    71  	if !ok {
    72  		fr = &frob{t: t, kind: t.Kind()}
    73  		frobs[t] = fr
    74  
    75  		switch fr.kind {
    76  		case reflect.Bool,
    77  			reflect.Int,
    78  			reflect.Int8,
    79  			reflect.Int16,
    80  			reflect.Int32,
    81  			reflect.Int64,
    82  			reflect.Uint,
    83  			reflect.Uint8,
    84  			reflect.Uint16,
    85  			reflect.Uint32,
    86  			reflect.Uint64,
    87  			reflect.Uintptr,
    88  			reflect.Float32,
    89  			reflect.Float64,
    90  			reflect.Complex64,
    91  			reflect.Complex128,
    92  			reflect.String:
    93  
    94  		case reflect.Array,
    95  			reflect.Slice,
    96  			reflect.Ptr: // TODO(adonovan): after go1.18, use Pointer
    97  			fr.addElem(fr.t.Elem())
    98  
    99  		case reflect.Map:
   100  			fr.addElem(fr.t.Key())
   101  			fr.addElem(fr.t.Elem())
   102  
   103  		case reflect.Struct:
   104  			for i := 0; i < fr.t.NumField(); i++ {
   105  				field := fr.t.Field(i)
   106  				if field.PkgPath != "" {
   107  					panic(fmt.Sprintf("unexported field %v", field))
   108  				}
   109  				fr.addElem(field.Type)
   110  			}
   111  
   112  		default:
   113  			// chan, func, interface, unsafe.Pointer
   114  			panic(fmt.Sprintf("type %v is not supported by frob", fr.t))
   115  		}
   116  	}
   117  	return fr
   118  }
   119  
   120  func (fr *frob) addElem(t reflect.Type) {
   121  	fr.elems = append(fr.elems, frobFor(t))
   122  }
   123  
// magic is the 4-byte header that Encode writes at the start of every
// message and that Decode checks before decoding.
const magic = "frob"
   125  
   126  func (fr *frob) Encode(v any) []byte {
   127  	rv := reflect.ValueOf(v)
   128  	if rv.Type() != fr.t {
   129  		panic(fmt.Sprintf("got %v, want %v", rv.Type(), fr.t))
   130  	}
   131  	w := &writer{}
   132  	w.bytes([]byte(magic))
   133  	fr.encode(w, rv)
   134  	if uint64(len(w.data))>>32 != 0 {
   135  		panic("too large") // includes all cases where len doesn't fit in 32 bits
   136  	}
   137  	return w.data
   138  }
   139  
   140  // encode appends the encoding of value v, whose type must be fr.t.
   141  func (fr *frob) encode(out *writer, v reflect.Value) {
   142  	switch fr.kind {
   143  	case reflect.Bool:
   144  		var b byte
   145  		if v.Bool() {
   146  			b = 1
   147  		}
   148  		out.uint8(b)
   149  	case reflect.Int:
   150  		out.uint64(uint64(v.Int()))
   151  	case reflect.Int8:
   152  		out.uint8(uint8(v.Int()))
   153  	case reflect.Int16:
   154  		out.uint16(uint16(v.Int()))
   155  	case reflect.Int32:
   156  		out.uint32(uint32(v.Int()))
   157  	case reflect.Int64:
   158  		out.uint64(uint64(v.Int()))
   159  	case reflect.Uint:
   160  		out.uint64(v.Uint())
   161  	case reflect.Uint8:
   162  		out.uint8(uint8(v.Uint()))
   163  	case reflect.Uint16:
   164  		out.uint16(uint16(v.Uint()))
   165  	case reflect.Uint32:
   166  		out.uint32(uint32(v.Uint()))
   167  	case reflect.Uint64:
   168  		out.uint64(v.Uint())
   169  	case reflect.Uintptr:
   170  		out.uint64(v.Uint())
   171  	case reflect.Float32:
   172  		out.uint32(math.Float32bits(float32(v.Float())))
   173  	case reflect.Float64:
   174  		out.uint64(math.Float64bits(v.Float()))
   175  	case reflect.Complex64:
   176  		z := complex64(v.Complex())
   177  		out.uint32(math.Float32bits(real(z)))
   178  		out.uint32(math.Float32bits(imag(z)))
   179  	case reflect.Complex128:
   180  		z := v.Complex()
   181  		out.uint64(math.Float64bits(real(z)))
   182  		out.uint64(math.Float64bits(imag(z)))
   183  
   184  	case reflect.Array:
   185  		len := v.Type().Len()
   186  		elem := fr.elems[0]
   187  		for i := 0; i < len; i++ {
   188  			elem.encode(out, v.Index(i))
   189  		}
   190  
   191  	case reflect.Slice:
   192  		len := v.Len()
   193  		out.uint32(uint32(len))
   194  		if len > 0 {
   195  			elem := fr.elems[0]
   196  			if elem.kind == reflect.Uint8 {
   197  				// []byte fast path
   198  				out.bytes(v.Bytes())
   199  			} else {
   200  				for i := 0; i < len; i++ {
   201  					elem.encode(out, v.Index(i))
   202  				}
   203  			}
   204  		}
   205  
   206  	case reflect.Map:
   207  		len := v.Len()
   208  		out.uint32(uint32(len))
   209  		if len > 0 {
   210  			kfrob, vfrob := fr.elems[0], fr.elems[1]
   211  			for iter := v.MapRange(); iter.Next(); {
   212  				kfrob.encode(out, iter.Key())
   213  				vfrob.encode(out, iter.Value())
   214  			}
   215  		}
   216  
   217  	case reflect.Ptr: // TODO(adonovan): after go1.18, use Pointer
   218  		if v.IsNil() {
   219  			out.uint8(0)
   220  		} else {
   221  			out.uint8(1)
   222  			fr.elems[0].encode(out, v.Elem())
   223  		}
   224  
   225  	case reflect.String:
   226  		len := v.Len()
   227  		out.uint32(uint32(len))
   228  		if len > 0 {
   229  			out.data = append(out.data, v.String()...)
   230  		}
   231  
   232  	case reflect.Struct:
   233  		for i, elem := range fr.elems {
   234  			elem.encode(out, v.Field(i))
   235  		}
   236  
   237  	default:
   238  		panic(fr.t)
   239  	}
   240  }
   241  
   242  func (fr *frob) Decode(data []byte, ptr any) {
   243  	rv := reflect.ValueOf(ptr).Elem()
   244  	if rv.Type() != fr.t {
   245  		panic(fmt.Sprintf("got %v, want %v", rv.Type(), fr.t))
   246  	}
   247  	rd := &reader{data}
   248  	if string(rd.bytes(4)) != magic {
   249  		panic("not a frob-encoded message")
   250  	}
   251  	fr.decode(rd, rv)
   252  	if len(rd.data) > 0 {
   253  		panic("surplus bytes")
   254  	}
   255  }
   256  
   257  // decode reads from in, decodes a value, and sets addr to it.
   258  // addr must be a zero-initialized addressable variable of type fr.t.
   259  func (fr *frob) decode(in *reader, addr reflect.Value) {
   260  	switch fr.kind {
   261  	case reflect.Bool:
   262  		addr.SetBool(in.uint8() != 0)
   263  	case reflect.Int:
   264  		addr.SetInt(int64(in.uint64()))
   265  	case reflect.Int8:
   266  		addr.SetInt(int64(in.uint8()))
   267  	case reflect.Int16:
   268  		addr.SetInt(int64(in.uint16()))
   269  	case reflect.Int32:
   270  		addr.SetInt(int64(in.uint32()))
   271  	case reflect.Int64:
   272  		addr.SetInt(int64(in.uint64()))
   273  	case reflect.Uint:
   274  		addr.SetUint(in.uint64())
   275  	case reflect.Uint8:
   276  		addr.SetUint(uint64(in.uint8()))
   277  	case reflect.Uint16:
   278  		addr.SetUint(uint64(in.uint16()))
   279  	case reflect.Uint32:
   280  		addr.SetUint(uint64(in.uint32()))
   281  	case reflect.Uint64:
   282  		addr.SetUint(in.uint64())
   283  	case reflect.Uintptr:
   284  		addr.SetUint(in.uint64())
   285  	case reflect.Float32:
   286  		addr.SetFloat(float64(math.Float32frombits(in.uint32())))
   287  	case reflect.Float64:
   288  		addr.SetFloat(math.Float64frombits(in.uint64()))
   289  	case reflect.Complex64:
   290  		addr.SetComplex(complex128(complex(
   291  			math.Float32frombits(in.uint32()),
   292  			math.Float32frombits(in.uint32()),
   293  		)))
   294  	case reflect.Complex128:
   295  		addr.SetComplex(complex(
   296  			math.Float64frombits(in.uint64()),
   297  			math.Float64frombits(in.uint64()),
   298  		))
   299  
   300  	case reflect.Array:
   301  		len := fr.t.Len()
   302  		for i := 0; i < len; i++ {
   303  			fr.elems[0].decode(in, addr.Index(i))
   304  		}
   305  
   306  	case reflect.Slice:
   307  		len := int(in.uint32())
   308  		if len > 0 {
   309  			elem := fr.elems[0]
   310  			if elem.kind == reflect.Uint8 {
   311  				// []byte fast path
   312  				// (Not addr.SetBytes: we must make a copy.)
   313  				addr.Set(reflect.AppendSlice(addr, reflect.ValueOf(in.bytes(len))))
   314  			} else {
   315  				addr.Set(reflect.MakeSlice(fr.t, len, len))
   316  				for i := 0; i < len; i++ {
   317  					elem.decode(in, addr.Index(i))
   318  				}
   319  			}
   320  		}
   321  
   322  	case reflect.Map:
   323  		len := int(in.uint32())
   324  		if len > 0 {
   325  			m := reflect.MakeMapWithSize(fr.t, len)
   326  			addr.Set(m)
   327  			kfrob, vfrob := fr.elems[0], fr.elems[1]
   328  			k := reflect.New(kfrob.t).Elem()
   329  			v := reflect.New(vfrob.t).Elem()
   330  			kzero := reflect.Zero(kfrob.t)
   331  			vzero := reflect.Zero(vfrob.t)
   332  			for i := 0; i < len; i++ {
   333  				// TODO(adonovan): use SetZero from go1.20.
   334  				// k.SetZero()
   335  				// v.SetZero()
   336  				k.Set(kzero)
   337  				v.Set(vzero)
   338  				kfrob.decode(in, k)
   339  				vfrob.decode(in, v)
   340  				m.SetMapIndex(k, v)
   341  			}
   342  		}
   343  
   344  	case reflect.Ptr: // TODO(adonovan): after go1.18, use Pointer
   345  		isNil := in.uint8() == 0
   346  		if !isNil {
   347  			ptr := reflect.New(fr.elems[0].t)
   348  			addr.Set(ptr)
   349  			fr.elems[0].decode(in, ptr.Elem())
   350  		}
   351  
   352  	case reflect.String:
   353  		len := int(in.uint32())
   354  		if len > 0 {
   355  			addr.SetString(string(in.bytes(len)))
   356  		}
   357  
   358  	case reflect.Struct:
   359  		for i, elem := range fr.elems {
   360  			elem.decode(in, addr.Field(i))
   361  		}
   362  
   363  	default:
   364  		panic(fr.t)
   365  	}
   366  }
   367  
   368  var le = binary.LittleEndian
   369  
   370  type reader struct{ data []byte }
   371  
   372  func (r *reader) uint8() uint8 {
   373  	v := r.data[0]
   374  	r.data = r.data[1:]
   375  	return v
   376  }
   377  
   378  func (r *reader) uint16() uint16 {
   379  	v := le.Uint16(r.data)
   380  	r.data = r.data[2:]
   381  	return v
   382  }
   383  
   384  func (r *reader) uint32() uint32 {
   385  	v := le.Uint32(r.data)
   386  	r.data = r.data[4:]
   387  	return v
   388  }
   389  
   390  func (r *reader) uint64() uint64 {
   391  	v := le.Uint64(r.data)
   392  	r.data = r.data[8:]
   393  	return v
   394  }
   395  
   396  func (r *reader) bytes(n int) []byte {
   397  	v := r.data[:n]
   398  	r.data = r.data[n:]
   399  	return v
   400  }
   401  
   402  type writer struct{ data []byte }
   403  
   404  func (w *writer) uint8(v uint8)   { w.data = append(w.data, v) }
   405  func (w *writer) uint16(v uint16) { w.data = appendUint16(w.data, v) }
   406  func (w *writer) uint32(v uint32) { w.data = appendUint32(w.data, v) }
   407  func (w *writer) uint64(v uint64) { w.data = appendUint64(w.data, v) }
   408  func (w *writer) bytes(v []byte)  { w.data = append(w.data, v...) }
   409  
   410  // TODO(adonovan): delete these as in go1.19 they are methods on LittleEndian:
   411  
   412  func appendUint16(b []byte, v uint16) []byte {
   413  	return append(b,
   414  		byte(v),
   415  		byte(v>>8),
   416  	)
   417  }
   418  
   419  func appendUint32(b []byte, v uint32) []byte {
   420  	return append(b,
   421  		byte(v),
   422  		byte(v>>8),
   423  		byte(v>>16),
   424  		byte(v>>24),
   425  	)
   426  }
   427  
   428  func appendUint64(b []byte, v uint64) []byte {
   429  	return append(b,
   430  		byte(v),
   431  		byte(v>>8),
   432  		byte(v>>16),
   433  		byte(v>>24),
   434  		byte(v>>32),
   435  		byte(v>>40),
   436  		byte(v>>48),
   437  		byte(v>>56),
   438  	)
   439  }