github.com/dolthub/go-mysql-server@v0.18.0/sql/types/json_encode.go (about)

     1  package types
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"io"
     7  	"reflect"
     8  	"sort"
     9  	"strconv"
    10  	"time"
    11  )
    12  
    13  var isEscaped = [256]bool{}
    14  var escapeSeq = [256][]byte{}
    15  
    16  func init() {
    17  	isEscaped[uint8('\b')] = true
    18  	isEscaped[uint8('\f')] = true
    19  	isEscaped[uint8('\n')] = true
    20  	isEscaped[uint8('\r')] = true
    21  	isEscaped[uint8('\t')] = true
    22  	isEscaped[uint8('"')] = true
    23  	isEscaped[uint8('\\')] = true
    24  
    25  	escapeSeq[uint8('\b')] = []byte("\\b")
    26  	escapeSeq[uint8('\f')] = []byte("\\f")
    27  	escapeSeq[uint8('\n')] = []byte("\\n")
    28  	escapeSeq[uint8('\r')] = []byte("\\r")
    29  	escapeSeq[uint8('\t')] = []byte("\\t")
    30  	escapeSeq[uint8('"')] = []byte("\\\"")
    31  	escapeSeq[uint8('\\')] = []byte("\\\\")
    32  }
    33  
    34  type NoCopyBuilder struct {
    35  	buffers   [][]byte
    36  	curr      int
    37  	lastAlloc int64
    38  	totalSize int64
    39  }
    40  
    41  func NewNoCopyBuilder(initialAlloc int64) *NoCopyBuilder {
    42  	return &NoCopyBuilder{
    43  		buffers:   [][]byte{make([]byte, 0, initialAlloc)},
    44  		lastAlloc: initialAlloc,
    45  	}
    46  }
    47  
    48  func (b *NoCopyBuilder) Write(p []byte) (int, error) {
    49  	currBuff := b.buffers[b.curr]
    50  
    51  	toWrite := len(p)
    52  	sourcePos := 0
    53  	destPos := len(currBuff)
    54  	space := cap(currBuff) - destPos
    55  
    56  	if space > 0 {
    57  		firstWrite := toWrite
    58  		if firstWrite > space {
    59  			firstWrite = space
    60  		}
    61  
    62  		currBuff = currBuff[:destPos+firstWrite]
    63  		n := copy(currBuff[destPos:], p[sourcePos:firstWrite])
    64  		b.buffers[b.curr] = currBuff
    65  
    66  		if n != firstWrite {
    67  			return -1, fmt.Errorf("failed to copy %d bytes to buffer", firstWrite)
    68  		}
    69  
    70  		toWrite -= firstWrite
    71  		sourcePos += firstWrite
    72  	}
    73  
    74  	if toWrite > 0 {
    75  		toAlloc := b.lastAlloc * 2
    76  		if toAlloc < int64(toWrite) {
    77  			toAlloc = int64(toWrite * 2)
    78  		}
    79  
    80  		newBuff := make([]byte, toWrite, toAlloc)
    81  		b.buffers = append(b.buffers, newBuff)
    82  		b.curr++
    83  		b.lastAlloc = toAlloc
    84  
    85  		n := copy(newBuff, p[sourcePos:])
    86  		if n != toWrite {
    87  			return -1, fmt.Errorf("failed to copy %d bytes to buffer", toWrite)
    88  		}
    89  	}
    90  
    91  	b.totalSize += int64(len(p))
    92  	return len(p), nil
    93  }
    94  
    95  func (b *NoCopyBuilder) Bytes() []byte {
    96  	if len(b.buffers) == 1 {
    97  		return b.buffers[0]
    98  	}
    99  
   100  	res := make([]byte, b.totalSize)
   101  	pos := 0
   102  	for _, buff := range b.buffers {
   103  		n := copy(res[pos:], buff)
   104  		if n != len(buff) {
   105  			panic(fmt.Errorf("failed to copy %d bytes to buffer", len(buff)))
   106  		}
   107  		pos += len(buff)
   108  	}
   109  
   110  	return res
   111  }
   112  
   113  func (b *NoCopyBuilder) String() string {
   114  	return string(b.Bytes())
   115  }
   116  
   117  func WriteStrings(wr io.Writer, strs ...string) (int, error) {
   118  	var totalN int
   119  	for _, str := range strs {
   120  		n, err := wr.Write([]byte(str))
   121  		if err != nil {
   122  			return totalN, err
   123  		}
   124  
   125  		totalN += n
   126  	}
   127  
   128  	return totalN, nil
   129  }
   130  
   131  // marshalToMySqlString is a helper function to marshal a JSONDocument to a string that is
   132  // compatible with MySQL's JSON output, including spaces.
   133  func marshalToMySqlString(val interface{}) (string, error) {
   134  	b := NewNoCopyBuilder(1024)
   135  	err := writeMarshalledValue(b, val)
   136  	if err != nil {
   137  		return "", err
   138  	}
   139  
   140  	return b.String(), nil
   141  }
   142  
   143  func sortKeys[T any](m map[string]T) []string {
   144  	var keys []string
   145  	for k := range m {
   146  		keys = append(keys, k)
   147  	}
   148  
   149  	sort.Slice(keys, func(i, j int) bool {
   150  		if len(keys[i]) != len(keys[j]) {
   151  			return len(keys[i]) < len(keys[j])
   152  		}
   153  		return keys[i] < keys[j]
   154  	})
   155  	return keys
   156  }
   157  
   158  func writeMarshalledValue(writer io.Writer, val interface{}) error {
   159  	switch val := val.(type) {
   160  	case []interface{}:
   161  		writer.Write([]byte{'['})
   162  		for i, v := range val {
   163  			err := writeMarshalledValue(writer, v)
   164  			if err != nil {
   165  				return err
   166  			}
   167  
   168  			if i != len(val)-1 {
   169  				writer.Write([]byte{',', ' '})
   170  			}
   171  		}
   172  		writer.Write([]byte{']'})
   173  		return nil
   174  
   175  	case map[string]string:
   176  		keys := sortKeys(val)
   177  
   178  		writer.Write([]byte{'{'})
   179  		for i, k := range keys {
   180  			writer.Write([]byte{'"'})
   181  			writer.Write([]byte(k))
   182  			writer.Write([]byte(`": "`))
   183  			writer.Write([]byte(val[k]))
   184  			writer.Write([]byte{'"'})
   185  
   186  			if i != len(keys)-1 {
   187  				writer.Write([]byte{',', ' '})
   188  			}
   189  		}
   190  
   191  		writer.Write([]byte{'}'})
   192  		return nil
   193  
   194  	case map[string]interface{}:
   195  		keys := sortKeys(val)
   196  
   197  		writer.Write([]byte{'{'})
   198  		for i, k := range keys {
   199  			writer.Write([]byte{'"'})
   200  			writer.Write([]byte(k))
   201  			writer.Write([]byte(`": `))
   202  			err := writeMarshalledValue(writer, val[k])
   203  			if err != nil {
   204  				return err
   205  			}
   206  
   207  			if i != len(keys)-1 {
   208  				writer.Write([]byte{',', ' '})
   209  			}
   210  		}
   211  
   212  		writer.Write([]byte{'}'})
   213  		return nil
   214  
   215  	case string:
   216  		writer.Write([]byte{'"'})
   217  		// iterate over each rune in the string to escape any special characters
   218  		start := 0
   219  		for i, r := range val {
   220  			if r > '\\' {
   221  				continue
   222  			}
   223  
   224  			b := uint8(r)
   225  			if isEscaped[b] {
   226  				if start != i {
   227  					writer.Write([]byte(val[start:i]))
   228  				}
   229  
   230  				writer.Write(escapeSeq[b])
   231  				start = i + 1
   232  			}
   233  		}
   234  
   235  		if start != len(val) {
   236  			writer.Write([]byte(val[start:]))
   237  		}
   238  
   239  		writer.Write([]byte{'"'})
   240  		return nil
   241  
   242  	case float64:
   243  		// JSON doesn't distinguish between integers and floats, so we need to check if the float is an integer
   244  		if val == float64(int64(val)) {
   245  			_, err := writer.Write([]byte(strconv.FormatInt(int64(val), 10)))
   246  			return err
   247  		}
   248  
   249  		_, err := writer.Write([]byte(strconv.FormatFloat(val, 'f', -1, 64)))
   250  		return err
   251  
   252  	case float32:
   253  		// JSON doesn't distinguish between integers and floats, so we need to check if the float is an integer
   254  		if val == float32(int32(val)) {
   255  			_, err := writer.Write([]byte(strconv.FormatInt(int64(val), 10)))
   256  			return err
   257  		}
   258  
   259  		_, err := writer.Write([]byte(strconv.FormatFloat(float64(val), 'f', -1, 32)))
   260  		return err
   261  
   262  	case int64:
   263  		_, err := writer.Write([]byte(strconv.FormatInt(val, 10)))
   264  		return err
   265  
   266  	case int32:
   267  		_, err := writer.Write([]byte(strconv.FormatInt(int64(val), 10)))
   268  		return err
   269  
   270  	case int16:
   271  		_, err := writer.Write([]byte(strconv.FormatInt(int64(val), 10)))
   272  		return err
   273  
   274  	case int8:
   275  		_, err := writer.Write([]byte(strconv.FormatInt(int64(val), 10)))
   276  		return err
   277  
   278  	case int:
   279  		_, err := writer.Write([]byte(strconv.FormatInt(int64(val), 10)))
   280  		return err
   281  
   282  	case uint64:
   283  		_, err := writer.Write([]byte(strconv.FormatUint(val, 10)))
   284  		return err
   285  
   286  	case uint32:
   287  		_, err := writer.Write([]byte(strconv.FormatUint(uint64(val), 10)))
   288  		return err
   289  
   290  	case uint16:
   291  		_, err := writer.Write([]byte(strconv.FormatUint(uint64(val), 10)))
   292  		return err
   293  
   294  	case uint8:
   295  		_, err := writer.Write([]byte(strconv.FormatUint(uint64(val), 10)))
   296  		return err
   297  
   298  	case bool:
   299  		if val {
   300  			writer.Write([]byte("true"))
   301  		} else {
   302  			writer.Write([]byte("false"))
   303  		}
   304  
   305  		return nil
   306  
   307  	case nil:
   308  		writer.Write([]byte("null"))
   309  		return nil
   310  
   311  	case time.Time:
   312  		writer.Write([]byte{'"'})
   313  		writer.Write([]byte(val.Format(time.RFC3339)))
   314  		writer.Write([]byte{'"'})
   315  		return nil
   316  	case json.Marshaler:
   317  		bytes, err := val.MarshalJSON()
   318  		if err != nil {
   319  			return err
   320  		}
   321  		writer.Write(bytes)
   322  		return nil
   323  	default:
   324  		r := reflect.ValueOf(val)
   325  		switch r.Kind() {
   326  		case reflect.Slice, reflect.Array:
   327  			writer.Write([]byte{'['})
   328  			for i := 0; i < r.Len(); i++ {
   329  				err := writeMarshalledValue(writer, r.Index(i).Interface())
   330  				if err != nil {
   331  					return err
   332  				}
   333  
   334  				if i != r.Len()-1 {
   335  					writer.Write([]byte{',', ' '})
   336  				}
   337  			}
   338  			writer.Write([]byte{']'})
   339  			return nil
   340  
   341  		case reflect.Map:
   342  			interfMap := make(map[string]interface{})
   343  			for _, k := range r.MapKeys() {
   344  				interfMap[k.String()] = r.MapIndex(k).Interface()
   345  			}
   346  
   347  			return writeMarshalledValue(writer, interfMap)
   348  
   349  		default:
   350  			return fmt.Errorf("unsupported type: %T", val)
   351  		}
   352  	}
   353  }