git.lukeshu.com/go/lowmemjson@v0.3.9-0.20230723050957-72f6d13f6fb2/encode.go (about)

     1  // Copyright (C) 2022-2023  Luke Shumaker <lukeshu@lukeshu.com>
     2  //
     3  // SPDX-License-Identifier: GPL-2.0-or-later
     4  
     5  package lowmemjson
     6  
     7  import (
     8  	"bytes"
     9  	"encoding"
    10  	"encoding/base64"
    11  	"fmt"
    12  	"io"
    13  	"reflect"
    14  	"sort"
    15  	"strconv"
    16  	"strings"
    17  	"unsafe"
    18  
    19  	"git.lukeshu.com/go/lowmemjson/internal/jsonstring"
    20  	"git.lukeshu.com/go/lowmemjson/internal/jsonstruct"
    21  )
    22  
    23  // Encodable is the interface implemented by types that can encode
    24  // themselves to JSON.  Encodable is a low-memory-overhead replacement
    25  // for the json.Marshaler interface.
    26  //
    27  // The io.Writer passed to EncodeJSON returns an error if invalid JSON
    28  // is written to it.
    29  type Encodable interface {
    30  	EncodeJSON(w io.Writer) error
    31  }
    32  
    33  // An Encoder encodes and writes values to a stream of JSON elements.
    34  //
    35  // Encoder is analogous to, and has a similar API to the standar
    36  // library's encoding/json.Encoder.  Differences are that rather than
    37  // having .SetEscapeHTML and .SetIndent methods, the io.Writer passed
    38  // to it may be a *ReEncoder that has these settings (and more).  If
    39  // something more similar to a json.Encoder is desired,
    40  // lowmemjson/compat/json.Encoder offers those .SetEscapeHTML and
    41  // .SetIndent methods.
    42  type Encoder struct {
    43  	w      *ReEncoder
    44  	isRoot bool
    45  }
    46  
    47  // NewEncoder returns a new Encoder that writes to w.
    48  //
    49  // If w is an *ReEncoder, then the inner backslash-escaping of
    50  // double-encoded ",string" tagged string values obeys the
    51  // *ReEncoder's BackslashEscape policy.
    52  //
    53  // An Encoder tends to make many small writes; if w.Write calls are
    54  // syscalls, then you may want to wrap w in a bufio.Writer.
    55  func NewEncoder(w io.Writer) *Encoder {
    56  	re, ok := w.(*ReEncoder)
    57  	if !ok {
    58  		re = NewReEncoder(w, ReEncoderConfig{
    59  			AllowMultipleValues: true,
    60  		})
    61  	}
    62  	return &Encoder{
    63  		w:      re,
    64  		isRoot: re.par.StackIsEmpty(),
    65  	}
    66  }
    67  
    68  // Encode encodes obj to JSON and writes that JSON to the Encoder's
    69  // output stream.
    70  //
    71  // See the [documentation for encoding/json.Marshal] for details about
    72  // the conversion Go values to JSON; Encode behaves identically to
    73  // that, with the exception that in addition to the json.Marshaler
    74  // interface it also checks for the Encodable interface.
    75  //
    76  // Unlike encoding/json.Encoder.Encode, lowmemjson.Encoder.Encode does
    77  // not buffer its output; if a encode-error is encountered, lowmemjson
    78  // may write partial output, whereas encodin/json would not have
    79  // written anything.
    80  //
    81  // [documentation for encoding/json.Marshal]: https://pkg.go.dev/encoding/json@go1.20#Marshal
    82  func (enc *Encoder) Encode(obj any) (err error) {
    83  	if enc.isRoot {
    84  		enc.w.par.Reset()
    85  	}
    86  	escaper := enc.w.esc
    87  	if escaper == nil {
    88  		escaper = EscapeDefault
    89  	}
    90  	if err := encode(enc.w, reflect.ValueOf(obj), escaper, enc.w.utf, false, 0, map[any]struct{}{}); err != nil {
    91  		if rwe, ok := err.(*ReEncodeWriteError); ok {
    92  			err = &EncodeWriteError{
    93  				Err:    rwe.Err,
    94  				Offset: rwe.Offset,
    95  			}
    96  		}
    97  		return err
    98  	}
    99  	if enc.isRoot {
   100  		return enc.w.Close()
   101  	}
   102  	return nil
   103  }
   104  
   105  func discardInt(_ int, err error) error {
   106  	return err
   107  }
   108  
   109  const startDetectingCyclesAfter = 1000
   110  
   111  func encode(w *ReEncoder, val reflect.Value, escaper BackslashEscaper, utf InvalidUTF8Mode, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) error {
   112  	if !val.IsValid() {
   113  		return discardInt(w.WriteString("null"))
   114  	}
   115  	switch {
   116  
   117  	case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(encodableType):
   118  		val = val.Addr()
   119  		fallthrough
   120  	case val.Type().Implements(encodableType):
   121  		if val.Kind() == reflect.Pointer && val.IsNil() {
   122  			return discardInt(w.WriteString("null"))
   123  		}
   124  		obj, ok := val.Interface().(Encodable)
   125  		if !ok {
   126  			return discardInt(w.WriteString("null"))
   127  		}
   128  		w.pushWriteBarrier()
   129  		if err := obj.EncodeJSON(w); err != nil {
   130  			return &EncodeMethodError{
   131  				Type:       val.Type(),
   132  				SourceFunc: "EncodeJSON",
   133  				Err:        err,
   134  			}
   135  		}
   136  		if err := w.Close(); err != nil {
   137  			return &EncodeMethodError{
   138  				Type:       val.Type(),
   139  				SourceFunc: "EncodeJSON",
   140  				Err:        err,
   141  			}
   142  		}
   143  		w.popWriteBarrier()
   144  
   145  	case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(jsonMarshalerType):
   146  		val = val.Addr()
   147  		fallthrough
   148  	case val.Type().Implements(jsonMarshalerType):
   149  		if val.Kind() == reflect.Pointer && val.IsNil() {
   150  			return discardInt(w.WriteString("null"))
   151  		}
   152  		obj, ok := val.Interface().(jsonMarshaler)
   153  		if !ok {
   154  			return discardInt(w.WriteString("null"))
   155  		}
   156  		dat, err := obj.MarshalJSON()
   157  		if err != nil {
   158  			return &EncodeMethodError{
   159  				Type:       val.Type(),
   160  				SourceFunc: "MarshalJSON",
   161  				Err:        err,
   162  			}
   163  		}
   164  		w.pushWriteBarrier()
   165  		if _, err := w.Write(dat); err != nil {
   166  			return &EncodeMethodError{
   167  				Type:       val.Type(),
   168  				SourceFunc: "MarshalJSON",
   169  				Err:        err,
   170  			}
   171  		}
   172  		if err := w.Close(); err != nil {
   173  			return &EncodeMethodError{
   174  				Type:       val.Type(),
   175  				SourceFunc: "MarshalJSON",
   176  				Err:        err,
   177  			}
   178  		}
   179  		w.popWriteBarrier()
   180  
   181  	case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(textMarshalerType):
   182  		val = val.Addr()
   183  		fallthrough
   184  	case val.Type().Implements(textMarshalerType):
   185  		if val.Kind() == reflect.Pointer && val.IsNil() {
   186  			return discardInt(w.WriteString("null"))
   187  		}
   188  		obj, ok := val.Interface().(encoding.TextMarshaler)
   189  		if !ok {
   190  			return discardInt(w.WriteString("null"))
   191  		}
   192  		text, err := obj.MarshalText()
   193  		if err != nil {
   194  			return &EncodeMethodError{
   195  				Type:       val.Type(),
   196  				SourceFunc: "MarshalText",
   197  				Err:        err,
   198  			}
   199  		}
   200  		if err := jsonstring.EncodeStringFromBytes(w, escaper, utf, val, text); err != nil {
   201  			return err
   202  		}
   203  	default:
   204  		switch val.Kind() {
   205  		case reflect.Bool:
   206  			if quote {
   207  				if err := w.WriteByte('"'); err != nil {
   208  					return err
   209  				}
   210  			}
   211  			if val.Bool() {
   212  				if _, err := w.WriteString("true"); err != nil {
   213  					return err
   214  				}
   215  			} else {
   216  				if _, err := w.WriteString("false"); err != nil {
   217  					return err
   218  				}
   219  			}
   220  			if quote {
   221  				if err := w.WriteByte('"'); err != nil {
   222  					return err
   223  				}
   224  			}
   225  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   226  			if quote {
   227  				if err := w.WriteByte('"'); err != nil {
   228  					return err
   229  				}
   230  			}
   231  			// MaxInt64  =  9223372036854775807
   232  			// MinInt64  =  -9223372036854775808
   233  			//              0        1         2
   234  			//              12345678901234567890
   235  			var buf [20]byte
   236  			if _, err := w.Write(strconv.AppendInt(buf[:0], val.Int(), 10)); err != nil {
   237  				return err
   238  			}
   239  			if quote {
   240  				if err := w.WriteByte('"'); err != nil {
   241  					return err
   242  				}
   243  			}
   244  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   245  			if quote {
   246  				if err := w.WriteByte('"'); err != nil {
   247  					return err
   248  				}
   249  			}
   250  			// MaxUint64 = 18446744073709551615
   251  			//             0        1         2
   252  			//             12345678901234567890
   253  			var buf [20]byte
   254  			if _, err := w.Write(strconv.AppendUint(buf[:0], val.Uint(), 10)); err != nil {
   255  				return err
   256  			}
   257  			if quote {
   258  				if err := w.WriteByte('"'); err != nil {
   259  					return err
   260  				}
   261  			}
   262  		case reflect.Float32:
   263  			if quote {
   264  				if err := w.WriteByte('"'); err != nil {
   265  					return err
   266  				}
   267  			}
   268  			if err := encodeFloat(w, 32, val); err != nil {
   269  				return err
   270  			}
   271  			if quote {
   272  				if err := w.WriteByte('"'); err != nil {
   273  					return err
   274  				}
   275  			}
   276  		case reflect.Float64:
   277  			if quote {
   278  				if err := w.WriteByte('"'); err != nil {
   279  					return err
   280  				}
   281  			}
   282  			if err := encodeFloat(w, 64, val); err != nil {
   283  				return err
   284  			}
   285  			if quote {
   286  				if err := w.WriteByte('"'); err != nil {
   287  					return err
   288  				}
   289  			}
   290  		case reflect.String:
   291  			if val.Type() == numberType {
   292  				numStr := val.String()
   293  				if numStr == "" {
   294  					numStr = "0"
   295  				}
   296  				if quote {
   297  					if err := w.WriteByte('"'); err != nil {
   298  						return err
   299  					}
   300  				}
   301  				if _, err := w.WriteString(numStr); err != nil {
   302  					return err
   303  				}
   304  				if quote {
   305  					if err := w.WriteByte('"'); err != nil {
   306  						return err
   307  					}
   308  				}
   309  			} else {
   310  				if quote {
   311  					var buf bytes.Buffer
   312  					if err := jsonstring.EncodeStringFromString(&buf, escaper, utf, val, val.String()); err != nil {
   313  						return err
   314  					}
   315  					if err := jsonstring.EncodeStringFromBytes(w, escaper, utf, val, buf.Bytes()); err != nil {
   316  						return err
   317  					}
   318  				} else {
   319  					if err := jsonstring.EncodeStringFromString(w, escaper, utf, val, val.String()); err != nil {
   320  						return err
   321  					}
   322  				}
   323  			}
   324  		case reflect.Interface:
   325  			if val.IsNil() {
   326  				if _, err := w.WriteString("null"); err != nil {
   327  					return err
   328  				}
   329  			} else {
   330  				if err := encode(w, val.Elem(), escaper, utf, quote, cycleDepth, cycleSeen); err != nil {
   331  					return err
   332  				}
   333  			}
   334  		case reflect.Struct:
   335  			if err := w.WriteByte('{'); err != nil {
   336  				return err
   337  			}
   338  			empty := true
   339  			for _, field := range jsonstruct.IndexStruct(val.Type()).ByPos {
   340  				fVal, err := val.FieldByIndexErr(field.Path)
   341  				if err != nil {
   342  					continue
   343  				}
   344  				if field.OmitEmpty && isEmptyValue(fVal) {
   345  					continue
   346  				}
   347  				if !empty {
   348  					if err := w.WriteByte(','); err != nil {
   349  						return err
   350  					}
   351  				}
   352  				empty = false
   353  				if err := jsonstring.EncodeStringFromString(w, escaper, utf, val, field.Name); err != nil {
   354  					return err
   355  				}
   356  				if err := w.WriteByte(':'); err != nil {
   357  					return err
   358  				}
   359  				if err := encode(w, fVal, escaper, utf, field.Quote, cycleDepth, cycleSeen); err != nil {
   360  					return err
   361  				}
   362  			}
   363  			if err := w.WriteByte('}'); err != nil {
   364  				return err
   365  			}
   366  		case reflect.Map:
   367  			if val.IsNil() {
   368  				return discardInt(w.WriteString("null"))
   369  			}
   370  			if val.Len() == 0 {
   371  				return discardInt(w.WriteString("{}"))
   372  			}
   373  			if cycleDepth++; cycleDepth > startDetectingCyclesAfter {
   374  				ptr := val.UnsafePointer()
   375  				if _, seen := cycleSeen[ptr]; seen {
   376  					return &EncodeValueError{
   377  						Value: val,
   378  						Str:   fmt.Sprintf("encountered a cycle via %s", val.Type()),
   379  					}
   380  				}
   381  				cycleSeen[ptr] = struct{}{}
   382  				defer delete(cycleSeen, ptr)
   383  			}
   384  			if err := w.WriteByte('{'); err != nil {
   385  				return err
   386  			}
   387  
   388  			var kBuf strings.Builder
   389  			kEnc := NewReEncoder(&kBuf, ReEncoderConfig{
   390  				AllowMultipleValues: true,
   391  
   392  				Compact: true,
   393  
   394  				BackslashEscape: escaper,
   395  				InvalidUTF8:     utf,
   396  			})
   397  
   398  			type kv struct {
   399  				KStr string
   400  				K    reflect.Value
   401  				V    reflect.Value
   402  			}
   403  			kvs := make([]kv, val.Len())
   404  			iter := val.MapRange()
   405  			for i := 0; iter.Next(); i++ {
   406  				if err := encode(kEnc, iter.Key(), escaper, utf, false, cycleDepth, cycleSeen); err != nil {
   407  					return err
   408  				}
   409  				if err := kEnc.Close(); err != nil {
   410  					return err
   411  				}
   412  				kStr := strings.Trim(kBuf.String(), "\n")
   413  				kBuf.Reset()
   414  				if kStr == "null" {
   415  					kStr = ""
   416  				}
   417  
   418  				// TODO(lukeshu): Have kEnc look at the first byte, and feed directly to a decoder,
   419  				// instead of needing to buffer the whole thing twice.
   420  				if strings.HasPrefix(kStr, `"`) {
   421  					if err := DecodeString(strings.NewReader(kStr), &kBuf); err != nil {
   422  						return err
   423  					}
   424  					kStr = kBuf.String()
   425  					kBuf.Reset()
   426  				}
   427  				kvs[i].KStr = kStr
   428  				kvs[i].K = iter.Key()
   429  				kvs[i].V = iter.Value()
   430  			}
   431  			sort.Slice(kvs, func(i, j int) bool {
   432  				return kvs[i].KStr < kvs[j].KStr
   433  			})
   434  
   435  			for i, kv := range kvs {
   436  				if i > 0 {
   437  					if err := w.WriteByte(','); err != nil {
   438  						return err
   439  					}
   440  				}
   441  				if err := jsonstring.EncodeStringFromString(w, escaper, utf, kv.K, kv.KStr); err != nil {
   442  					return err
   443  				}
   444  				if err := w.WriteByte(':'); err != nil {
   445  					return err
   446  				}
   447  				if err := encode(w, kv.V, escaper, utf, false, cycleDepth, cycleSeen); err != nil {
   448  					return err
   449  				}
   450  			}
   451  			if err := w.WriteByte('}'); err != nil {
   452  				return err
   453  			}
   454  		case reflect.Slice:
   455  			switch {
   456  			case val.IsNil():
   457  				if _, err := w.WriteString("null"); err != nil {
   458  					return err
   459  				}
   460  			case val.Type().Elem().Kind() == reflect.Uint8 && !(false ||
   461  				val.Type().Elem().Implements(encodableType) ||
   462  				reflect.PointerTo(val.Type().Elem()).Implements(encodableType) ||
   463  				val.Type().Elem().Implements(jsonMarshalerType) ||
   464  				reflect.PointerTo(val.Type().Elem()).Implements(jsonMarshalerType) ||
   465  				val.Type().Elem().Implements(textMarshalerType) ||
   466  				reflect.PointerTo(val.Type().Elem()).Implements(textMarshalerType)):
   467  				if err := w.WriteByte('"'); err != nil {
   468  					return err
   469  				}
   470  				enc := base64.NewEncoder(base64.StdEncoding, w)
   471  				if val.CanConvert(byteSliceType) {
   472  					if _, err := enc.Write(val.Convert(byteSliceType).Interface().([]byte)); err != nil {
   473  						return err
   474  					}
   475  				} else {
   476  					// TODO: Surely there's a better way.
   477  					for i, n := 0, val.Len(); i < n; i++ {
   478  						var buf [1]byte
   479  						buf[0] = val.Index(i).Convert(byteType).Interface().(byte)
   480  						if _, err := enc.Write(buf[:]); err != nil {
   481  							return err
   482  						}
   483  					}
   484  				}
   485  				if err := enc.Close(); err != nil {
   486  					return err
   487  				}
   488  				if err := w.WriteByte('"'); err != nil {
   489  					return err
   490  				}
   491  			default:
   492  				if cycleDepth++; cycleDepth > startDetectingCyclesAfter {
   493  					// For slices, val.UnsafePointer() doesn't return a pointer to the slice header
   494  					// or anything like that, it returns a pointer *to the first element in the
   495  					// slice*.  That means that the pointer isn't enough to uniquely identify the
   496  					// slice!  So we pair the pointer with the length of the slice, which is
   497  					// sufficient.
   498  					ptr := struct {
   499  						ptr unsafe.Pointer
   500  						len int
   501  					}{val.UnsafePointer(), val.Len()}
   502  					if _, seen := cycleSeen[ptr]; seen {
   503  						return &EncodeValueError{
   504  							Value: val,
   505  							Str:   fmt.Sprintf("encountered a cycle via %s", val.Type()),
   506  						}
   507  					}
   508  					cycleSeen[ptr] = struct{}{}
   509  					defer delete(cycleSeen, ptr)
   510  				}
   511  				if err := encodeArray(w, val, escaper, utf, cycleDepth, cycleSeen); err != nil {
   512  					return err
   513  				}
   514  			}
   515  		case reflect.Array:
   516  			if err := encodeArray(w, val, escaper, utf, cycleDepth, cycleSeen); err != nil {
   517  				return err
   518  			}
   519  		case reflect.Pointer:
   520  			if val.IsNil() {
   521  				if _, err := w.WriteString("null"); err != nil {
   522  					return err
   523  				}
   524  			} else {
   525  				if cycleDepth++; cycleDepth > startDetectingCyclesAfter {
   526  					ptr := val.UnsafePointer()
   527  					if _, seen := cycleSeen[ptr]; seen {
   528  						return &EncodeValueError{
   529  							Value: val,
   530  							Str:   fmt.Sprintf("encountered a cycle via %s", val.Type()),
   531  						}
   532  					}
   533  					cycleSeen[ptr] = struct{}{}
   534  					defer delete(cycleSeen, ptr)
   535  				}
   536  				if err := encode(w, val.Elem(), escaper, utf, quote, cycleDepth, cycleSeen); err != nil {
   537  					return err
   538  				}
   539  			}
   540  		default:
   541  			return &EncodeTypeError{
   542  				Type: val.Type(),
   543  			}
   544  		}
   545  	}
   546  	return nil
   547  }
   548  
   549  func encodeArray(w *ReEncoder, val reflect.Value, escaper BackslashEscaper, utf InvalidUTF8Mode, cycleDepth uint, cycleSeen map[any]struct{}) error {
   550  	if err := w.WriteByte('['); err != nil {
   551  		return err
   552  	}
   553  	n := val.Len()
   554  	for i := 0; i < n; i++ {
   555  		if i > 0 {
   556  			if err := w.WriteByte(','); err != nil {
   557  				return err
   558  			}
   559  		}
   560  		if err := encode(w, val.Index(i), escaper, utf, false, cycleDepth, cycleSeen); err != nil {
   561  			return err
   562  		}
   563  	}
   564  	if err := w.WriteByte(']'); err != nil {
   565  		return err
   566  	}
   567  	return nil
   568  }