git.lukeshu.com/go/lowmemjson@v0.3.9-0.20230723050957-72f6d13f6fb2/compat/json/compat.go (about)

     1  // Copyright (C) 2022-2023  Luke Shumaker <lukeshu@lukeshu.com>
     2  //
     3  // SPDX-License-Identifier: GPL-2.0-or-later
     4  
     5  // Package json is a wrapper around lowmemjson that is a (mostly)
     6  // drop-in replacement for the standard library's encoding/json.
     7  package json
     8  
     9  import (
    10  	"bufio"
    11  	"bytes"
    12  	"encoding/json"
    13  	"errors"
    14  	"fmt"
    15  	"io"
    16  	"strconv"
    17  	"unicode/utf8"
    18  
    19  	"git.lukeshu.com/go/lowmemjson"
    20  	"git.lukeshu.com/go/lowmemjson/internal/jsonstring"
    21  )
    22  
    23  //nolint:stylecheck // ST1021 False positive; these aren't comments on individual types.
    24  type (
    25  	Number      = json.Number
    26  	RawMessage  = json.RawMessage
    27  	Marshaler   = json.Marshaler
    28  	Unmarshaler = json.Unmarshaler
    29  
    30  	// low-level decode errors.
    31  	UnmarshalTypeError = json.UnmarshalTypeError
    32  	// SyntaxError        = json.SyntaxError // Duplicated to access a private field.
    33  
    34  	// high-level decode errors.
    35  	InvalidUnmarshalError = json.InvalidUnmarshalError
    36  
    37  	// marshal errors.
    38  	UnsupportedTypeError  = json.UnsupportedTypeError
    39  	UnsupportedValueError = json.UnsupportedValueError
    40  	// MarshalerError        = json.MarshalerError // Duplicated to access a private field.
    41  )
    42  
    43  // Error conversion //////////////////////////////////////////////////
    44  
    45  func convertError(err error, isUnmarshal bool) error {
    46  	switch err := err.(type) {
    47  	case nil:
    48  		return nil
    49  	case *lowmemjson.DecodeArgumentError:
    50  		return err
    51  	case *lowmemjson.DecodeError:
    52  		switch suberr := err.Err.(type) {
    53  		case *lowmemjson.DecodeReadError:
    54  			return err
    55  		case *lowmemjson.DecodeSyntaxError:
    56  			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
    57  				if isUnmarshal {
    58  					return &SyntaxError{
    59  						msg:    "unexpected end of JSON input",
    60  						Offset: suberr.Offset,
    61  					}
    62  				}
    63  				return suberr.Err
    64  			}
    65  			return &SyntaxError{
    66  				msg:    suberr.Err.Error(),
    67  				Offset: suberr.Offset + 1,
    68  			}
    69  		case *lowmemjson.DecodeTypeError:
    70  			switch subsuberr := suberr.Err.(type) {
    71  			case *UnmarshalTypeError:
    72  				// Populate the .Struct and .Field members.
    73  				subsuberr.Struct = err.FieldParent
    74  				subsuberr.Field = err.FieldName
    75  				return subsuberr
    76  			default:
    77  				switch {
    78  				case errors.Is(err, lowmemjson.ErrDecodeNonEmptyInterface),
    79  					errors.Is(err, strconv.ErrSyntax),
    80  					errors.Is(err, strconv.ErrRange):
    81  					return &UnmarshalTypeError{
    82  						Value:  suberr.JSONType,
    83  						Type:   suberr.GoType,
    84  						Offset: suberr.Offset,
    85  						Struct: err.FieldParent,
    86  						Field:  err.FieldName,
    87  					}
    88  				default:
    89  					return subsuberr
    90  				}
    91  			case nil, *lowmemjson.DecodeArgumentError:
    92  				return &UnmarshalTypeError{
    93  					Value:  suberr.JSONType,
    94  					Type:   suberr.GoType,
    95  					Offset: suberr.Offset,
    96  					Struct: err.FieldParent,
    97  					Field:  err.FieldName,
    98  				}
    99  			}
   100  		default:
   101  			panic(fmt.Errorf("should not happen: unexpected lowmemjson.DecodeError sub-type: %T: %w", suberr, err))
   102  		}
   103  	case *lowmemjson.EncodeWriteError:
   104  		return err
   105  	case *lowmemjson.EncodeTypeError:
   106  		return err
   107  	case *lowmemjson.EncodeValueError:
   108  		return err
   109  	case *lowmemjson.EncodeMethodError:
   110  		return &MarshalerError{
   111  			Type:       err.Type,
   112  			Err:        err.Err,
   113  			sourceFunc: err.SourceFunc,
   114  		}
   115  	case *lowmemjson.ReEncodeWriteError:
   116  		return err
   117  	case *lowmemjson.ReEncodeSyntaxError:
   118  		ret := &SyntaxError{
   119  			msg:    err.Err.Error(),
   120  			Offset: err.Offset + 1,
   121  		}
   122  		if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
   123  			ret.msg = "unexpected end of JSON input"
   124  		}
   125  		return ret
   126  	default:
   127  		panic(fmt.Errorf("should not happen: unexpected lowmemjson error type: %T: %w", err, err))
   128  	}
   129  }
   130  
   131  // Encode wrappers ///////////////////////////////////////////////////
   132  
   133  func marshal(v any, cfg lowmemjson.ReEncoderConfig) ([]byte, error) {
   134  	var buf bytes.Buffer
   135  	if err := convertError(lowmemjson.NewEncoder(lowmemjson.NewReEncoder(&buf, cfg)).Encode(v), false); err != nil {
   136  		return nil, err
   137  	}
   138  	return buf.Bytes(), nil
   139  }
   140  
   141  func MarshalIndent(v any, prefix, indent string) ([]byte, error) {
   142  	return marshal(v, lowmemjson.ReEncoderConfig{
   143  		Indent: indent,
   144  		Prefix: prefix,
   145  	})
   146  }
   147  
   148  func Marshal(v any) ([]byte, error) {
   149  	return marshal(v, lowmemjson.ReEncoderConfig{
   150  		Compact: true,
   151  	})
   152  }
   153  
   154  type Encoder struct {
   155  	out io.Writer
   156  	buf bytes.Buffer
   157  
   158  	cfg lowmemjson.ReEncoderConfig
   159  
   160  	encoder   *lowmemjson.Encoder
   161  	formatter *lowmemjson.ReEncoder
   162  }
   163  
   164  func NewEncoder(w io.Writer) *Encoder {
   165  	ret := &Encoder{
   166  		out: w,
   167  
   168  		cfg: lowmemjson.ReEncoderConfig{
   169  			AllowMultipleValues: true,
   170  
   171  			Compact:               true,
   172  			ForceTrailingNewlines: true,
   173  		},
   174  	}
   175  	ret.refreshConfig()
   176  	return ret
   177  }
   178  
   179  func (enc *Encoder) refreshConfig() {
   180  	enc.formatter = lowmemjson.NewReEncoder(&enc.buf, enc.cfg)
   181  	enc.encoder = lowmemjson.NewEncoder(enc.formatter)
   182  }
   183  
   184  func (enc *Encoder) Encode(v any) error {
   185  	if err := convertError(enc.encoder.Encode(v), false); err != nil {
   186  		enc.buf.Reset()
   187  		return err
   188  	}
   189  	if _, err := enc.buf.WriteTo(enc.out); err != nil {
   190  		return err
   191  	}
   192  	return nil
   193  }
   194  
   195  func (enc *Encoder) SetEscapeHTML(on bool) {
   196  	if on {
   197  		enc.cfg.BackslashEscape = lowmemjson.EscapeDefault
   198  	} else {
   199  		enc.cfg.BackslashEscape = lowmemjson.EscapeDefaultNonHTMLSafe
   200  	}
   201  	enc.refreshConfig()
   202  }
   203  
   204  func (enc *Encoder) SetIndent(prefix, indent string) {
   205  	enc.cfg.Compact = prefix == "" && indent == ""
   206  	enc.cfg.Prefix = prefix
   207  	enc.cfg.Indent = indent
   208  	enc.refreshConfig()
   209  }
   210  
   211  // ReEncode wrappers /////////////////////////////////////////////////
   212  
   213  func HTMLEscape(dst *bytes.Buffer, src []byte) {
   214  	for n := 0; n < len(src); {
   215  		c, size := utf8.DecodeRune(src[n:])
   216  		if c == utf8.RuneError && size == 1 {
   217  			dst.WriteByte(src[n])
   218  		} else {
   219  			mode := lowmemjson.EscapeHTMLSafe(c, lowmemjson.BackslashEscapeNone)
   220  			switch mode {
   221  			case lowmemjson.BackslashEscapeNone:
   222  				dst.WriteRune(c)
   223  			case lowmemjson.BackslashEscapeUnicode:
   224  				_ = jsonstring.WriteStringUnicodeEscape(dst, c, mode)
   225  			default:
   226  				panic(fmt.Errorf("lowmemjson.EscapeHTMLSafe returned an unexpected escape mode=%d", mode))
   227  			}
   228  		}
   229  		n += size
   230  	}
   231  }
   232  
   233  func reencode(dst io.Writer, src []byte, cfg lowmemjson.ReEncoderConfig) error {
   234  	formatter := lowmemjson.NewReEncoder(dst, cfg)
   235  	_, err := formatter.Write(src)
   236  	if err == nil {
   237  		err = formatter.Close()
   238  	}
   239  	return convertError(err, false)
   240  }
   241  
   242  func Compact(dst *bytes.Buffer, src []byte) error {
   243  	start := dst.Len()
   244  	err := reencode(dst, src, lowmemjson.ReEncoderConfig{
   245  		Compact:         true,
   246  		InvalidUTF8:     lowmemjson.InvalidUTF8Preserve,
   247  		BackslashEscape: lowmemjson.EscapePreserve,
   248  	})
   249  	if err != nil {
   250  		dst.Truncate(start)
   251  	}
   252  	return err
   253  }
   254  
   255  func isSpace(c byte) bool {
   256  	switch c {
   257  	case 0x0020, 0x000A, 0x000D, 0x0009:
   258  		return true
   259  	default:
   260  		return false
   261  	}
   262  }
   263  
   264  func Indent(dst *bytes.Buffer, src []byte, prefix, indent string) error {
   265  	start := dst.Len()
   266  	err := reencode(dst, src, lowmemjson.ReEncoderConfig{
   267  		Indent:          indent,
   268  		Prefix:          prefix,
   269  		InvalidUTF8:     lowmemjson.InvalidUTF8Preserve,
   270  		BackslashEscape: lowmemjson.EscapePreserve,
   271  	})
   272  	if err != nil {
   273  		dst.Truncate(start)
   274  		return err
   275  	}
   276  
   277  	// Preserve trailing whitespace.
   278  	lastNonWS := len(src) - 1
   279  	for ; lastNonWS >= 0 && isSpace(src[lastNonWS]); lastNonWS-- {
   280  	}
   281  	if _, err := dst.Write(src[lastNonWS+1:]); err != nil {
   282  		return err
   283  	}
   284  
   285  	return nil
   286  }
   287  
   288  func Valid(data []byte) bool {
   289  	formatter := lowmemjson.NewReEncoder(io.Discard, lowmemjson.ReEncoderConfig{
   290  		Compact:     true,
   291  		InvalidUTF8: lowmemjson.InvalidUTF8Error,
   292  	})
   293  	if _, err := formatter.Write(data); err != nil {
   294  		return false
   295  	}
   296  	if err := formatter.Close(); err != nil {
   297  		return false
   298  	}
   299  	return true
   300  }
   301  
   302  // Decode wrappers ///////////////////////////////////////////////////
   303  
   304  type decodeValidator struct{}
   305  
   306  func (*decodeValidator) DecodeJSON(r io.RuneScanner) error {
   307  	for {
   308  		if _, _, err := r.ReadRune(); err != nil {
   309  
   310  			if err == io.EOF {
   311  				return nil
   312  			}
   313  			return err
   314  		}
   315  	}
   316  }
   317  
   318  var _ lowmemjson.Decodable = (*decodeValidator)(nil)
   319  
   320  func Unmarshal(data []byte, ptr any) error {
   321  	if err := convertError(lowmemjson.NewDecoder(bytes.NewReader(data)).DecodeThenEOF(&decodeValidator{}), true); err != nil {
   322  		return err
   323  	}
   324  	if err := convertError(lowmemjson.NewDecoder(bytes.NewReader(data)).DecodeThenEOF(ptr), true); err != nil {
   325  		return err
   326  	}
   327  	return nil
   328  }
   329  
   330  type teeRuneScanner struct {
   331  	src interface {
   332  		io.RuneScanner
   333  		io.ByteScanner
   334  	}
   335  	dst      *bytes.Buffer
   336  	lastSize int
   337  }
   338  
   339  func (tee *teeRuneScanner) ReadRune() (r rune, size int, err error) {
   340  	r, size, err = tee.src.ReadRune()
   341  	if err == nil {
   342  		if r == utf8.RuneError && size == 1 {
   343  			_ = tee.src.UnreadRune()
   344  			b, _ := tee.src.ReadByte()
   345  			_ = tee.dst.WriteByte(b)
   346  		} else {
   347  			_, _ = tee.dst.WriteRune(r)
   348  		}
   349  	}
   350  	tee.lastSize = size
   351  	return
   352  }
   353  
   354  func (tee *teeRuneScanner) UnreadRune() error {
   355  	if tee.lastSize == 0 {
   356  		return lowmemjson.ErrInvalidUnreadRune
   357  	}
   358  	_ = tee.src.UnreadRune()
   359  	tee.dst.Truncate(tee.dst.Len() - tee.lastSize)
   360  	tee.lastSize = 0
   361  	return nil
   362  }
   363  
   364  func (tee *teeRuneScanner) ReadByte() (b byte, err error) {
   365  	b, err = tee.src.ReadByte()
   366  	if err == nil {
   367  		_ = tee.dst.WriteByte(b)
   368  		tee.lastSize = 1
   369  	}
   370  	return
   371  }
   372  
   373  func (tee *teeRuneScanner) UnreadByte() error {
   374  	if tee.lastSize != 1 {
   375  		return lowmemjson.ErrInvalidUnreadRune
   376  	}
   377  	_ = tee.src.UnreadByte()
   378  	tee.dst.Truncate(tee.dst.Len() - tee.lastSize)
   379  	tee.lastSize = 0
   380  	return nil
   381  }
   382  
   383  type Decoder struct {
   384  	validatorBuf *bufio.Reader
   385  	validator    *lowmemjson.Decoder
   386  
   387  	decoderBuf bytes.Buffer
   388  	*lowmemjson.Decoder
   389  }
   390  
   391  func NewDecoder(r io.Reader) *Decoder {
   392  	br, ok := r.(*bufio.Reader)
   393  	if !ok {
   394  		br = bufio.NewReader(r)
   395  	}
   396  	ret := &Decoder{
   397  		validatorBuf: br,
   398  	}
   399  	ret.validator = lowmemjson.NewDecoder(&teeRuneScanner{
   400  		src: ret.validatorBuf,
   401  		dst: &ret.decoderBuf,
   402  	})
   403  	ret.Decoder = lowmemjson.NewDecoder(&ret.decoderBuf)
   404  	return ret
   405  }
   406  
   407  func (dec *Decoder) Decode(ptr any) error {
   408  	if err := convertError(dec.validator.Decode(&decodeValidator{}), false); err != nil {
   409  		return err
   410  	}
   411  	if err := convertError(dec.Decoder.Decode(ptr), false); err != nil {
   412  		return err
   413  	}
   414  	return nil
   415  }
   416  
   417  func (dec *Decoder) Buffered() io.Reader {
   418  	dat, _ := dec.validatorBuf.Peek(dec.validatorBuf.Buffered())
   419  	return bytes.NewReader(dat)
   420  }
   421  
   422  // func (dec *Decoder) Token() (Token, error)