github.com/dolthub/go-mysql-server@v0.18.0/sql/types/json_encode.go (about) 1 package types 2 3 import ( 4 "encoding/json" 5 "fmt" 6 "io" 7 "reflect" 8 "sort" 9 "strconv" 10 "time" 11 ) 12 13 var isEscaped = [256]bool{} 14 var escapeSeq = [256][]byte{} 15 16 func init() { 17 isEscaped[uint8('\b')] = true 18 isEscaped[uint8('\f')] = true 19 isEscaped[uint8('\n')] = true 20 isEscaped[uint8('\r')] = true 21 isEscaped[uint8('\t')] = true 22 isEscaped[uint8('"')] = true 23 isEscaped[uint8('\\')] = true 24 25 escapeSeq[uint8('\b')] = []byte("\\b") 26 escapeSeq[uint8('\f')] = []byte("\\f") 27 escapeSeq[uint8('\n')] = []byte("\\n") 28 escapeSeq[uint8('\r')] = []byte("\\r") 29 escapeSeq[uint8('\t')] = []byte("\\t") 30 escapeSeq[uint8('"')] = []byte("\\\"") 31 escapeSeq[uint8('\\')] = []byte("\\\\") 32 } 33 34 type NoCopyBuilder struct { 35 buffers [][]byte 36 curr int 37 lastAlloc int64 38 totalSize int64 39 } 40 41 func NewNoCopyBuilder(initialAlloc int64) *NoCopyBuilder { 42 return &NoCopyBuilder{ 43 buffers: [][]byte{make([]byte, 0, initialAlloc)}, 44 lastAlloc: initialAlloc, 45 } 46 } 47 48 func (b *NoCopyBuilder) Write(p []byte) (int, error) { 49 currBuff := b.buffers[b.curr] 50 51 toWrite := len(p) 52 sourcePos := 0 53 destPos := len(currBuff) 54 space := cap(currBuff) - destPos 55 56 if space > 0 { 57 firstWrite := toWrite 58 if firstWrite > space { 59 firstWrite = space 60 } 61 62 currBuff = currBuff[:destPos+firstWrite] 63 n := copy(currBuff[destPos:], p[sourcePos:firstWrite]) 64 b.buffers[b.curr] = currBuff 65 66 if n != firstWrite { 67 return -1, fmt.Errorf("failed to copy %d bytes to buffer", firstWrite) 68 } 69 70 toWrite -= firstWrite 71 sourcePos += firstWrite 72 } 73 74 if toWrite > 0 { 75 toAlloc := b.lastAlloc * 2 76 if toAlloc < int64(toWrite) { 77 toAlloc = int64(toWrite * 2) 78 } 79 80 newBuff := make([]byte, toWrite, toAlloc) 81 b.buffers = append(b.buffers, newBuff) 82 b.curr++ 83 b.lastAlloc = toAlloc 84 85 n := copy(newBuff, p[sourcePos:]) 86 if n != toWrite { 87 return -1, fmt.Errorf("failed to copy %d bytes to buffer", toWrite) 88 } 89 } 90 91 b.totalSize += int64(len(p)) 92 return len(p), nil 93 } 94 95 func (b *NoCopyBuilder) Bytes() []byte { 96 if len(b.buffers) == 1 { 97 return b.buffers[0] 98 } 99 100 res := make([]byte, b.totalSize) 101 pos := 0 102 for _, buff := range b.buffers { 103 n := copy(res[pos:], buff) 104 if n != len(buff) { 105 panic(fmt.Errorf("failed to copy %d bytes to buffer", len(buff))) 106 } 107 pos += len(buff) 108 } 109 110 return res 111 } 112 113 func (b *NoCopyBuilder) String() string { 114 return string(b.Bytes()) 115 } 116 117 func WriteStrings(wr io.Writer, strs ...string) (int, error) { 118 var totalN int 119 for _, str := range strs { 120 n, err := wr.Write([]byte(str)) 121 if err != nil { 122 return totalN, err 123 } 124 125 totalN += n 126 } 127 128 return totalN, nil 129 } 130 131 // marshalToMySqlString is a helper function to marshal a JSONDocument to a string that is 132 // compatible with MySQL's JSON output, including spaces. 133 func marshalToMySqlString(val interface{}) (string, error) { 134 b := NewNoCopyBuilder(1024) 135 err := writeMarshalledValue(b, val) 136 if err != nil { 137 return "", err 138 } 139 140 return b.String(), nil 141 } 142 143 func sortKeys[T any](m map[string]T) []string { 144 var keys []string 145 for k := range m { 146 keys = append(keys, k) 147 } 148 149 sort.Slice(keys, func(i, j int) bool { 150 if len(keys[i]) != len(keys[j]) { 151 return len(keys[i]) < len(keys[j]) 152 } 153 return keys[i] < keys[j] 154 }) 155 return keys 156 } 157 158 func writeMarshalledValue(writer io.Writer, val interface{}) error { 159 switch val := val.(type) { 160 case []interface{}: 161 writer.Write([]byte{'['}) 162 for i, v := range val { 163 err := writeMarshalledValue(writer, v) 164 if err != nil { 165 return err 166 } 167 168 if i != len(val)-1 { 169 writer.Write([]byte{',', ' '}) 170 } 171 } 172 writer.Write([]byte{']'}) 173 return nil 174 175 case map[string]string: 176 keys := sortKeys(val) 177 178 writer.Write([]byte{'{'}) 179 for i, k := range keys { 180 writer.Write([]byte{'"'}) 181 writer.Write([]byte(k)) 182 writer.Write([]byte(`": "`)) 183 writer.Write([]byte(val[k])) 184 writer.Write([]byte{'"'}) 185 186 if i != len(keys)-1 { 187 writer.Write([]byte{',', ' '}) 188 } 189 } 190 191 writer.Write([]byte{'}'}) 192 return nil 193 194 case map[string]interface{}: 195 keys := sortKeys(val) 196 197 writer.Write([]byte{'{'}) 198 for i, k := range keys { 199 writer.Write([]byte{'"'}) 200 writer.Write([]byte(k)) 201 writer.Write([]byte(`": `)) 202 err := writeMarshalledValue(writer, val[k]) 203 if err != nil { 204 return err 205 } 206 207 if i != len(keys)-1 { 208 writer.Write([]byte{',', ' '}) 209 } 210 } 211 212 writer.Write([]byte{'}'}) 213 return nil 214 215 case string: 216 writer.Write([]byte{'"'}) 217 // iterate over each rune in the string to escape any special characters 218 start := 0 219 for i, r := range val { 220 if r > '\\' { 221 continue 222 } 223 224 b := uint8(r) 225 if isEscaped[b] { 226 if start != i { 227 writer.Write([]byte(val[start:i])) 228 } 229 230 writer.Write(escapeSeq[b]) 231 start = i + 1 232 } 233 } 234 235 if start != len(val) { 236 writer.Write([]byte(val[start:])) 237 } 238 239 writer.Write([]byte{'"'}) 240 return nil 241 242 case float64: 243 // JSON doesn't distinguish between integers and floats, so we need to check if the float is an integer 244 if val == float64(int64(val)) { 245 _, err := writer.Write([]byte(strconv.FormatInt(int64(val), 10))) 246 return err 247 } 248 249 _, err := writer.Write([]byte(strconv.FormatFloat(val, 'f', -1, 64))) 250 return err 251 252 case float32: 253 // JSON doesn't distinguish between integers and floats, so we need to check if the float is an integer 254 if val == float32(int32(val)) { 255 _, err := writer.Write([]byte(strconv.FormatInt(int64(val), 10))) 256 return err 257 } 258 259 _, err := writer.Write([]byte(strconv.FormatFloat(float64(val), 'f', -1, 32))) 260 return err 261 262 case int64: 263 _, err := writer.Write([]byte(strconv.FormatInt(val, 10))) 264 return err 265 266 case int32: 267 _, err := writer.Write([]byte(strconv.FormatInt(int64(val), 10))) 268 return err 269 270 case int16: 271 _, err := writer.Write([]byte(strconv.FormatInt(int64(val), 10))) 272 return err 273 274 case int8: 275 _, err := writer.Write([]byte(strconv.FormatInt(int64(val), 10))) 276 return err 277 278 case int: 279 _, err := writer.Write([]byte(strconv.FormatInt(int64(val), 10))) 280 return err 281 282 case uint64: 283 _, err := writer.Write([]byte(strconv.FormatUint(val, 10))) 284 return err 285 286 case uint32: 287 _, err := writer.Write([]byte(strconv.FormatUint(uint64(val), 10))) 288 return err 289 290 case uint16: 291 _, err := writer.Write([]byte(strconv.FormatUint(uint64(val), 10))) 292 return err 293 294 case uint8: 295 _, err := writer.Write([]byte(strconv.FormatUint(uint64(val), 10))) 296 return err 297 298 case bool: 299 if val { 300 writer.Write([]byte("true")) 301 } else { 302 writer.Write([]byte("false")) 303 } 304 305 return nil 306 307 case nil: 308 writer.Write([]byte("null")) 309 return nil 310 311 case time.Time: 312 writer.Write([]byte{'"'}) 313 writer.Write([]byte(val.Format(time.RFC3339))) 314 writer.Write([]byte{'"'}) 315 return nil 316 case json.Marshaler: 317 bytes, err := val.MarshalJSON() 318 if err != nil { 319 return err 320 } 321 writer.Write(bytes) 322 return nil 323 default: 324 r := reflect.ValueOf(val) 325 switch r.Kind() { 326 case reflect.Slice, reflect.Array: 327 writer.Write([]byte{'['}) 328 for i := 0; i < r.Len(); i++ { 329 err := writeMarshalledValue(writer, r.Index(i).Interface()) 330 if err != nil { 331 return err 332 } 333 334 if i != r.Len()-1 { 335 writer.Write([]byte{',', ' '}) 336 } 337 } 338 writer.Write([]byte{']'}) 339 return nil 340 341 case reflect.Map: 342 interfMap := make(map[string]interface{}) 343 for _, k := range r.MapKeys() { 344 interfMap[k.String()] = r.MapIndex(k).Interface() 345 } 346 347 return writeMarshalledValue(writer, interfMap) 348 349 default: 350 return fmt.Errorf("unsupported type: %T", val) 351 } 352 } 353 }