github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/allegrosql/server/util.go (about)

     1  // Copyright 2020 The Go-MyALLEGROSQL-Driver Authors. All rights reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the Mozilla Public
     4  // License, v. 2.0. If a copy of the MPL was not distributed with this file,
     5  // You can obtain one at http://mozilla.org/MPL/2.0/.
     6  
     7  // The MIT License (MIT)
     8  //
     9  // Copyright (c) 2020 wandoulabs
    10  // Copyright (c) 2020 siddontang
    11  //
    12  // Permission is hereby granted, free of charge, to any person obtaining a copy of
    13  // this software and associated documentation files (the "Software"), to deal in
    14  // the Software without restriction, including without limitation the rights to
    15  // use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
    16  // the Software, and to permit persons to whom the Software is furnished to do so,
    17  // subject to the following conditions:
    18  //
    19  // The above copyright notice and this permission notice shall be included in all
    20  // copies or substantial portions of the Software.
    21  
    22  // Copyright 2020 WHTCORPS INC, Inc.
    23  //
    24  // Licensed under the Apache License, Version 2.0 (the "License");
    25  // you may not use this file except in compliance with the License.
    26  // You may obtain a copy of the License at
    27  //
    28  //     http://www.apache.org/licenses/LICENSE-2.0
    29  //
    30  // Unless required by applicable law or agreed to in writing, software
    31  // distributed under the License is distributed on an "AS IS" BASIS,
    32  // See the License for the specific language governing permissions and
    33  // limitations under the License.
    34  
    35  package server
    36  
    37  import (
    38  	"bytes"
    39  	"encoding/binary"
    40  	"io"
    41  	"math"
    42  	"net/http"
    43  	"strconv"
    44  	"time"
    45  
    46  	"github.com/whtcorpsinc/BerolinaSQL/allegrosql"
    47  	"github.com/whtcorpsinc/milevadb/config"
    48  	"github.com/whtcorpsinc/milevadb/types"
    49  	"github.com/whtcorpsinc/milevadb/soliton/chunk"
    50  	"github.com/whtcorpsinc/milevadb/soliton/replog"
    51  )
    52  
    53  func parseNullTermString(b []byte) (str []byte, remain []byte) {
    54  	off := bytes.IndexByte(b, 0)
    55  	if off == -1 {
    56  		return nil, b
    57  	}
    58  	return b[:off], b[off+1:]
    59  }
    60  
    61  func parseLengthEncodedInt(b []byte) (num uint64, isNull bool, n int) {
    62  	switch b[0] {
    63  	// 251: NULL
    64  	case 0xfb:
    65  		n = 1
    66  		isNull = true
    67  		return
    68  
    69  	// 252: value of following 2
    70  	case 0xfc:
    71  		num = uint64(b[1]) | uint64(b[2])<<8
    72  		n = 3
    73  		return
    74  
    75  	// 253: value of following 3
    76  	case 0xfd:
    77  		num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16
    78  		n = 4
    79  		return
    80  
    81  	// 254: value of following 8
    82  	case 0xfe:
    83  		num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
    84  			uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
    85  			uint64(b[7])<<48 | uint64(b[8])<<56
    86  		n = 9
    87  		return
    88  	}
    89  
    90  	// https://dev.allegrosql.com/doc/internals/en/integer.html#length-encoded-integer: If the first byte of a packet is a length-encoded integer and its byte value is 0xfe, you must check the length of the packet to verify that it has enough space for a 8-byte integer.
    91  	// TODO: 0xff is undefined
    92  
    93  	// 0-250: value of first byte
    94  	num = uint64(b[0])
    95  	n = 1
    96  	return
    97  }
    98  
    99  func dumpLengthEncodedInt(buffer []byte, n uint64) []byte {
   100  	switch {
   101  	case n <= 250:
   102  		return append(buffer, byte(n))
   103  
   104  	case n <= 0xffff:
   105  		return append(buffer, 0xfc, byte(n), byte(n>>8))
   106  
   107  	case n <= 0xffffff:
   108  		return append(buffer, 0xfd, byte(n), byte(n>>8), byte(n>>16))
   109  
   110  	case n <= 0xffffffffffffffff:
   111  		return append(buffer, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24),
   112  			byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
   113  	}
   114  
   115  	return buffer
   116  }
   117  
   118  func parseLengthEncodedBytes(b []byte) ([]byte, bool, int, error) {
   119  	// Get length
   120  	num, isNull, n := parseLengthEncodedInt(b)
   121  	if num < 1 {
   122  		return nil, isNull, n, nil
   123  	}
   124  
   125  	n += int(num)
   126  
   127  	// Check data length
   128  	if len(b) >= n {
   129  		return b[n-int(num) : n], false, n, nil
   130  	}
   131  
   132  	return nil, false, n, io.EOF
   133  }
   134  
   135  func dumpLengthEncodedString(buffer []byte, bytes []byte) []byte {
   136  	buffer = dumpLengthEncodedInt(buffer, uint64(len(bytes)))
   137  	buffer = append(buffer, bytes...)
   138  	return buffer
   139  }
   140  
   141  func dumpUint16(buffer []byte, n uint16) []byte {
   142  	buffer = append(buffer, byte(n))
   143  	buffer = append(buffer, byte(n>>8))
   144  	return buffer
   145  }
   146  
   147  func dumpUint32(buffer []byte, n uint32) []byte {
   148  	buffer = append(buffer, byte(n))
   149  	buffer = append(buffer, byte(n>>8))
   150  	buffer = append(buffer, byte(n>>16))
   151  	buffer = append(buffer, byte(n>>24))
   152  	return buffer
   153  }
   154  
   155  func dumpUint64(buffer []byte, n uint64) []byte {
   156  	buffer = append(buffer, byte(n))
   157  	buffer = append(buffer, byte(n>>8))
   158  	buffer = append(buffer, byte(n>>16))
   159  	buffer = append(buffer, byte(n>>24))
   160  	buffer = append(buffer, byte(n>>32))
   161  	buffer = append(buffer, byte(n>>40))
   162  	buffer = append(buffer, byte(n>>48))
   163  	buffer = append(buffer, byte(n>>56))
   164  	return buffer
   165  }
   166  
   167  func dumpBinaryTime(dur time.Duration) (data []byte) {
   168  	if dur == 0 {
   169  		return []byte{0}
   170  	}
   171  	data = make([]byte, 13)
   172  	data[0] = 12
   173  	if dur < 0 {
   174  		data[1] = 1
   175  		dur = -dur
   176  	}
   177  	days := dur / (24 * time.Hour)
   178  	dur -= days * 24 * time.Hour
   179  	data[2] = byte(days)
   180  	hours := dur / time.Hour
   181  	dur -= hours * time.Hour
   182  	data[6] = byte(hours)
   183  	minutes := dur / time.Minute
   184  	dur -= minutes * time.Minute
   185  	data[7] = byte(minutes)
   186  	seconds := dur / time.Second
   187  	dur -= seconds * time.Second
   188  	data[8] = byte(seconds)
   189  	if dur == 0 {
   190  		data[0] = 8
   191  		return data[:9]
   192  	}
   193  	binary.LittleEndian.PutUint32(data[9:13], uint32(dur/time.Microsecond))
   194  	return
   195  }
   196  
   197  func dumpBinaryDateTime(data []byte, t types.Time) []byte {
   198  	year, mon, day := t.Year(), t.Month(), t.Day()
   199  	switch t.Type() {
   200  	case allegrosql.TypeTimestamp, allegrosql.TypeDatetime:
   201  		if t.IsZero() {
   202  			data = append(data, 0)
   203  		} else {
   204  			data = append(data, 11)
   205  			data = dumpUint16(data, uint16(year))
   206  			data = append(data, byte(mon), byte(day), byte(t.Hour()), byte(t.Minute()), byte(t.Second()))
   207  			data = dumpUint32(data, uint32(t.Microsecond()))
   208  		}
   209  	case allegrosql.TypeDate:
   210  		if t.IsZero() {
   211  			data = append(data, 0)
   212  		} else {
   213  			data = append(data, 4)
   214  			data = dumpUint16(data, uint16(year)) //year
   215  			data = append(data, byte(mon), byte(day))
   216  		}
   217  	}
   218  	return data
   219  }
   220  
   221  func dumpBinaryRow(buffer []byte, defCausumns []*DeferredCausetInfo, event chunk.Row) ([]byte, error) {
   222  	buffer = append(buffer, allegrosql.OKHeader)
   223  	nullBitmapOff := len(buffer)
   224  	numBytes4Null := (len(defCausumns) + 7 + 2) / 8
   225  	for i := 0; i < numBytes4Null; i++ {
   226  		buffer = append(buffer, 0)
   227  	}
   228  	for i := range defCausumns {
   229  		if event.IsNull(i) {
   230  			bytePos := (i + 2) / 8
   231  			bitPos := byte((i + 2) % 8)
   232  			buffer[nullBitmapOff+bytePos] |= 1 << bitPos
   233  			continue
   234  		}
   235  		switch defCausumns[i].Type {
   236  		case allegrosql.TypeTiny:
   237  			buffer = append(buffer, byte(event.GetInt64(i)))
   238  		case allegrosql.TypeShort, allegrosql.TypeYear:
   239  			buffer = dumpUint16(buffer, uint16(event.GetInt64(i)))
   240  		case allegrosql.TypeInt24, allegrosql.TypeLong:
   241  			buffer = dumpUint32(buffer, uint32(event.GetInt64(i)))
   242  		case allegrosql.TypeLonglong:
   243  			buffer = dumpUint64(buffer, event.GetUint64(i))
   244  		case allegrosql.TypeFloat:
   245  			buffer = dumpUint32(buffer, math.Float32bits(event.GetFloat32(i)))
   246  		case allegrosql.TypeDouble:
   247  			buffer = dumpUint64(buffer, math.Float64bits(event.GetFloat64(i)))
   248  		case allegrosql.TypeNewDecimal:
   249  			buffer = dumpLengthEncodedString(buffer, replog.Slice(event.GetMyDecimal(i).String()))
   250  		case allegrosql.TypeString, allegrosql.TypeVarString, allegrosql.TypeVarchar, allegrosql.TypeBit,
   251  			allegrosql.TypeTinyBlob, allegrosql.TypeMediumBlob, allegrosql.TypeLongBlob, allegrosql.TypeBlob:
   252  			buffer = dumpLengthEncodedString(buffer, event.GetBytes(i))
   253  		case allegrosql.TypeDate, allegrosql.TypeDatetime, allegrosql.TypeTimestamp:
   254  			buffer = dumpBinaryDateTime(buffer, event.GetTime(i))
   255  		case allegrosql.TypeDuration:
   256  			buffer = append(buffer, dumpBinaryTime(event.GetDuration(i, 0).Duration)...)
   257  		case allegrosql.TypeEnum:
   258  			buffer = dumpLengthEncodedString(buffer, replog.Slice(event.GetEnum(i).String()))
   259  		case allegrosql.TypeSet:
   260  			buffer = dumpLengthEncodedString(buffer, replog.Slice(event.GetSet(i).String()))
   261  		case allegrosql.TypeJSON:
   262  			buffer = dumpLengthEncodedString(buffer, replog.Slice(event.GetJSON(i).String()))
   263  		default:
   264  			return nil, errInvalidType.GenWithStack("invalid type %v", defCausumns[i].Type)
   265  		}
   266  	}
   267  	return buffer, nil
   268  }
   269  
   270  func dumpTextRow(buffer []byte, defCausumns []*DeferredCausetInfo, event chunk.Row) ([]byte, error) {
   271  	tmp := make([]byte, 0, 20)
   272  	for i, defCaus := range defCausumns {
   273  		if event.IsNull(i) {
   274  			buffer = append(buffer, 0xfb)
   275  			continue
   276  		}
   277  		switch defCaus.Type {
   278  		case allegrosql.TypeTiny, allegrosql.TypeShort, allegrosql.TypeInt24, allegrosql.TypeLong:
   279  			tmp = strconv.AppendInt(tmp[:0], event.GetInt64(i), 10)
   280  			buffer = dumpLengthEncodedString(buffer, tmp)
   281  		case allegrosql.TypeYear:
   282  			year := event.GetInt64(i)
   283  			tmp = tmp[:0]
   284  			if year == 0 {
   285  				tmp = append(tmp, '0', '0', '0', '0')
   286  			} else {
   287  				tmp = strconv.AppendInt(tmp, year, 10)
   288  			}
   289  			buffer = dumpLengthEncodedString(buffer, tmp)
   290  		case allegrosql.TypeLonglong:
   291  			if allegrosql.HasUnsignedFlag(uint(defCausumns[i].Flag)) {
   292  				tmp = strconv.AppendUint(tmp[:0], event.GetUint64(i), 10)
   293  			} else {
   294  				tmp = strconv.AppendInt(tmp[:0], event.GetInt64(i), 10)
   295  			}
   296  			buffer = dumpLengthEncodedString(buffer, tmp)
   297  		case allegrosql.TypeFloat:
   298  			prec := -1
   299  			if defCausumns[i].Decimal > 0 && int(defCaus.Decimal) != allegrosql.NotFixedDec {
   300  				prec = int(defCaus.Decimal)
   301  			}
   302  			tmp = appendFormatFloat(tmp[:0], float64(event.GetFloat32(i)), prec, 32)
   303  			buffer = dumpLengthEncodedString(buffer, tmp)
   304  		case allegrosql.TypeDouble:
   305  			prec := types.UnspecifiedLength
   306  			if defCaus.Decimal > 0 && int(defCaus.Decimal) != allegrosql.NotFixedDec {
   307  				prec = int(defCaus.Decimal)
   308  			}
   309  			tmp = appendFormatFloat(tmp[:0], event.GetFloat64(i), prec, 64)
   310  			buffer = dumpLengthEncodedString(buffer, tmp)
   311  		case allegrosql.TypeNewDecimal:
   312  			buffer = dumpLengthEncodedString(buffer, replog.Slice(event.GetMyDecimal(i).String()))
   313  		case allegrosql.TypeString, allegrosql.TypeVarString, allegrosql.TypeVarchar, allegrosql.TypeBit,
   314  			allegrosql.TypeTinyBlob, allegrosql.TypeMediumBlob, allegrosql.TypeLongBlob, allegrosql.TypeBlob:
   315  			buffer = dumpLengthEncodedString(buffer, event.GetBytes(i))
   316  		case allegrosql.TypeDate, allegrosql.TypeDatetime, allegrosql.TypeTimestamp:
   317  			buffer = dumpLengthEncodedString(buffer, replog.Slice(event.GetTime(i).String()))
   318  		case allegrosql.TypeDuration:
   319  			dur := event.GetDuration(i, int(defCaus.Decimal))
   320  			buffer = dumpLengthEncodedString(buffer, replog.Slice(dur.String()))
   321  		case allegrosql.TypeEnum:
   322  			buffer = dumpLengthEncodedString(buffer, replog.Slice(event.GetEnum(i).String()))
   323  		case allegrosql.TypeSet:
   324  			buffer = dumpLengthEncodedString(buffer, replog.Slice(event.GetSet(i).String()))
   325  		case allegrosql.TypeJSON:
   326  			buffer = dumpLengthEncodedString(buffer, replog.Slice(event.GetJSON(i).String()))
   327  		default:
   328  			return nil, errInvalidType.GenWithStack("invalid type %v", defCausumns[i].Type)
   329  		}
   330  	}
   331  	return buffer, nil
   332  }
   333  
   334  func lengthEncodedIntSize(n uint64) int {
   335  	switch {
   336  	case n <= 250:
   337  		return 1
   338  
   339  	case n <= 0xffff:
   340  		return 3
   341  
   342  	case n <= 0xffffff:
   343  		return 4
   344  	}
   345  
   346  	return 9
   347  }
   348  
   349  const (
   350  	expFormatBig   = 1e15
   351  	expFormatSmall = 1e-15
   352  )
   353  
   354  func appendFormatFloat(in []byte, fVal float64, prec, bitSize int) []byte {
   355  	absVal := math.Abs(fVal)
   356  	var out []byte
   357  	if prec == types.UnspecifiedLength && (absVal >= expFormatBig || (absVal != 0 && absVal < expFormatSmall)) {
   358  		out = strconv.AppendFloat(in, fVal, 'e', prec, bitSize)
   359  		valStr := out[len(in):]
   360  		// remove the '+' from the string for compatibility.
   361  		plusPos := bytes.IndexByte(valStr, '+')
   362  		if plusPos > 0 {
   363  			plusPosInOut := len(in) + plusPos
   364  			out = append(out[:plusPosInOut], out[plusPosInOut+1:]...)
   365  		}
   366  	} else {
   367  		out = strconv.AppendFloat(in, fVal, 'f', prec, bitSize)
   368  	}
   369  	return out
   370  }
   371  
   372  // CorsHandler adds Cors Header if `cors` config is set.
   373  type CorsHandler struct {
   374  	handler http.Handler
   375  	cfg     *config.Config
   376  }
   377  
   378  func (h CorsHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
   379  	if h.cfg.Cors != "" {
   380  		w.Header().Set("Access-Control-Allow-Origin", h.cfg.Cors)
   381  		w.Header().Set("Access-Control-Allow-Methods", "GET")
   382  	}
   383  	h.handler.ServeHTTP(w, req)
   384  }