github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/allegrosql/server/milevadb_lock_mut.go (about)

     1  // Copyright 2020 WHTCORPS INC, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package server
    15  
    16  import (
    17  	"context"
    18  	"crypto/tls"
    19  	"sync/atomic"
    20  
    21  	"github.com/whtcorpsinc/BerolinaSQL/allegrosql"
    22  	"github.com/whtcorpsinc/BerolinaSQL/ast"
    23  	"github.com/whtcorpsinc/BerolinaSQL/charset"
    24  	"github.com/whtcorpsinc/BerolinaSQL/terror"
    25  	"github.com/whtcorpsinc/milevadb/causet/embedded"
    26  	"github.com/whtcorpsinc/milevadb/ekv"
    27  	"github.com/whtcorpsinc/milevadb/soliton/chunk"
    28  	"github.com/whtcorpsinc/milevadb/soliton/sqlexec"
    29  	"github.com/whtcorpsinc/milevadb/stochastik"
    30  	"github.com/whtcorpsinc/milevadb/types"
    31  )
    32  
    33  // MilevaDBDriver implements IDriver.
    34  type MilevaDBDriver struct {
    35  	causetstore ekv.CausetStorage
    36  }
    37  
    38  // NewMilevaDBDriver creates a new MilevaDBDriver.
    39  func NewMilevaDBDriver(causetstore ekv.CausetStorage) *MilevaDBDriver {
    40  	driver := &MilevaDBDriver{
    41  		causetstore: causetstore,
    42  	}
    43  	return driver
    44  }
    45  
    46  // MilevaDBContext implements QueryCtx.
    47  type MilevaDBContext struct {
    48  	stochastik.Stochastik
    49  	currentDB string
    50  	stmts     map[int]*MilevaDBStatement
    51  }
    52  
    53  // MilevaDBStatement implements PreparedStatement.
    54  type MilevaDBStatement struct {
    55  	id          uint32
    56  	numParams   int
    57  	boundParams [][]byte
    58  	paramsType  []byte
    59  	ctx         *MilevaDBContext
    60  	rs          ResultSet
    61  	allegrosql  string
    62  }
    63  
    64  // ID implements PreparedStatement ID method.
    65  func (ts *MilevaDBStatement) ID() int {
    66  	return int(ts.id)
    67  }
    68  
    69  // InterDircute implements PreparedStatement InterDircute method.
    70  func (ts *MilevaDBStatement) InterDircute(ctx context.Context, args []types.Causet) (rs ResultSet, err error) {
    71  	milevadbRecordset, err := ts.ctx.InterDircutePreparedStmt(ctx, ts.id, args)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  	if milevadbRecordset == nil {
    76  		return
    77  	}
    78  	rs = &milevadbResultSet{
    79  		recordSet:    milevadbRecordset,
    80  		preparedStmt: ts.ctx.GetStochastikVars().PreparedStmts[ts.id].(*embedded.CachedPrepareStmt),
    81  	}
    82  	return
    83  }
    84  
    85  // AppendParam implements PreparedStatement AppendParam method.
    86  func (ts *MilevaDBStatement) AppendParam(paramID int, data []byte) error {
    87  	if paramID >= len(ts.boundParams) {
    88  		return allegrosql.NewErr(allegrosql.ErrWrongArguments, "stmt_send_longdata")
    89  	}
    90  	// If len(data) is 0, append an empty byte slice to the end to distinguish no data and no parameter.
    91  	if len(data) == 0 {
    92  		ts.boundParams[paramID] = []byte{}
    93  	} else {
    94  		ts.boundParams[paramID] = append(ts.boundParams[paramID], data...)
    95  	}
    96  	return nil
    97  }
    98  
    99  // NumParams implements PreparedStatement NumParams method.
   100  func (ts *MilevaDBStatement) NumParams() int {
   101  	return ts.numParams
   102  }
   103  
   104  // BoundParams implements PreparedStatement BoundParams method.
   105  func (ts *MilevaDBStatement) BoundParams() [][]byte {
   106  	return ts.boundParams
   107  }
   108  
   109  // SetParamsType implements PreparedStatement SetParamsType method.
   110  func (ts *MilevaDBStatement) SetParamsType(paramsType []byte) {
   111  	ts.paramsType = paramsType
   112  }
   113  
   114  // GetParamsType implements PreparedStatement GetParamsType method.
   115  func (ts *MilevaDBStatement) GetParamsType() []byte {
   116  	return ts.paramsType
   117  }
   118  
   119  // StoreResultSet stores ResultSet for stmt fetching
   120  func (ts *MilevaDBStatement) StoreResultSet(rs ResultSet) {
   121  	// refer to https://dev.allegrosql.com/doc/refman/5.7/en/cursor-restrictions.html
   122  	// You can have open only a single cursor per prepared memex.
   123  	// closing previous ResultSet before associating a new ResultSet with this memex
   124  	// if it exists
   125  	if ts.rs != nil {
   126  		terror.Call(ts.rs.Close)
   127  	}
   128  	ts.rs = rs
   129  }
   130  
   131  // GetResultSet gets ResultSet associated this memex
   132  func (ts *MilevaDBStatement) GetResultSet() ResultSet {
   133  	return ts.rs
   134  }
   135  
   136  // Reset implements PreparedStatement Reset method.
   137  func (ts *MilevaDBStatement) Reset() {
   138  	for i := range ts.boundParams {
   139  		ts.boundParams[i] = nil
   140  	}
   141  
   142  	// closing previous ResultSet if it exists
   143  	if ts.rs != nil {
   144  		terror.Call(ts.rs.Close)
   145  		ts.rs = nil
   146  	}
   147  }
   148  
   149  // Close implements PreparedStatement Close method.
   150  func (ts *MilevaDBStatement) Close() error {
   151  	//TODO close at milevadb level
   152  	err := ts.ctx.DropPreparedStmt(ts.id)
   153  	if err != nil {
   154  		return err
   155  	}
   156  	delete(ts.ctx.stmts, int(ts.id))
   157  
   158  	// close ResultSet associated with this memex
   159  	if ts.rs != nil {
   160  		terror.Call(ts.rs.Close)
   161  	}
   162  	return nil
   163  }
   164  
   165  // OpenCtx implements IDriver.
   166  func (qd *MilevaDBDriver) OpenCtx(connID uint64, capability uint32, defCauslation uint8, dbname string, tlsState *tls.ConnectionState) (*MilevaDBContext, error) {
   167  	se, err := stochastik.CreateStochastik(qd.causetstore)
   168  	if err != nil {
   169  		return nil, err
   170  	}
   171  	se.SetTLSState(tlsState)
   172  	err = se.SetDefCauslation(int(defCauslation))
   173  	if err != nil {
   174  		return nil, err
   175  	}
   176  	se.SetClientCapability(capability)
   177  	se.SetConnectionID(connID)
   178  	tc := &MilevaDBContext{
   179  		Stochastik: se,
   180  		currentDB:  dbname,
   181  		stmts:      make(map[int]*MilevaDBStatement),
   182  	}
   183  	return tc, nil
   184  }
   185  
   186  // CurrentDB implements QueryCtx CurrentDB method.
   187  func (tc *MilevaDBContext) CurrentDB() string {
   188  	return tc.currentDB
   189  }
   190  
   191  // WarningCount implements QueryCtx WarningCount method.
   192  func (tc *MilevaDBContext) WarningCount() uint16 {
   193  	return tc.GetStochastikVars().StmtCtx.WarningCount()
   194  }
   195  
   196  // InterDircuteStmt implements QueryCtx interface.
   197  func (tc *MilevaDBContext) InterDircuteStmt(ctx context.Context, stmt ast.StmtNode) (ResultSet, error) {
   198  	rs, err := tc.Stochastik.InterDircuteStmt(ctx, stmt)
   199  	if err != nil {
   200  		return nil, err
   201  	}
   202  	if rs == nil {
   203  		return nil, nil
   204  	}
   205  	return &milevadbResultSet{
   206  		recordSet: rs,
   207  	}, nil
   208  }
   209  
   210  // Close implements QueryCtx Close method.
   211  func (tc *MilevaDBContext) Close() error {
   212  	// close PreparedStatement associated with this connection
   213  	for _, v := range tc.stmts {
   214  		terror.Call(v.Close)
   215  	}
   216  
   217  	tc.Stochastik.Close()
   218  	return nil
   219  }
   220  
   221  // FieldList implements QueryCtx FieldList method.
   222  func (tc *MilevaDBContext) FieldList(causet string) (defCausumns []*DeferredCausetInfo, err error) {
   223  	fields, err := tc.Stochastik.FieldList(causet)
   224  	if err != nil {
   225  		return nil, err
   226  	}
   227  	defCausumns = make([]*DeferredCausetInfo, 0, len(fields))
   228  	for _, f := range fields {
   229  		defCausumns = append(defCausumns, convertDeferredCausetInfo(f))
   230  	}
   231  	return defCausumns, nil
   232  }
   233  
   234  // GetStatement implements QueryCtx GetStatement method.
   235  func (tc *MilevaDBContext) GetStatement(stmtID int) PreparedStatement {
   236  	tcStmt := tc.stmts[stmtID]
   237  	if tcStmt != nil {
   238  		return tcStmt
   239  	}
   240  	return nil
   241  }
   242  
   243  // Prepare implements QueryCtx Prepare method.
   244  func (tc *MilevaDBContext) Prepare(allegrosql string) (memex PreparedStatement, defCausumns, params []*DeferredCausetInfo, err error) {
   245  	stmtID, paramCount, fields, err := tc.Stochastik.PrepareStmt(allegrosql)
   246  	if err != nil {
   247  		return
   248  	}
   249  	stmt := &MilevaDBStatement{
   250  		allegrosql:  allegrosql,
   251  		id:          stmtID,
   252  		numParams:   paramCount,
   253  		boundParams: make([][]byte, paramCount),
   254  		ctx:         tc,
   255  	}
   256  	memex = stmt
   257  	defCausumns = make([]*DeferredCausetInfo, len(fields))
   258  	for i := range fields {
   259  		defCausumns[i] = convertDeferredCausetInfo(fields[i])
   260  	}
   261  	params = make([]*DeferredCausetInfo, paramCount)
   262  	for i := range params {
   263  		params[i] = &DeferredCausetInfo{
   264  			Type: allegrosql.TypeBlob,
   265  		}
   266  	}
   267  	tc.stmts[int(stmtID)] = stmt
   268  	return
   269  }
   270  
   271  type milevadbResultSet struct {
   272  	recordSet    sqlexec.RecordSet
   273  	defCausumns  []*DeferredCausetInfo
   274  	rows         []chunk.Row
   275  	closed       int32
   276  	preparedStmt *embedded.CachedPrepareStmt
   277  }
   278  
   279  func (trs *milevadbResultSet) NewChunk() *chunk.Chunk {
   280  	return trs.recordSet.NewChunk()
   281  }
   282  
   283  func (trs *milevadbResultSet) Next(ctx context.Context, req *chunk.Chunk) error {
   284  	return trs.recordSet.Next(ctx, req)
   285  }
   286  
   287  func (trs *milevadbResultSet) StoreFetchedRows(rows []chunk.Row) {
   288  	trs.rows = rows
   289  }
   290  
   291  func (trs *milevadbResultSet) GetFetchedRows() []chunk.Row {
   292  	if trs.rows == nil {
   293  		trs.rows = make([]chunk.Row, 0, 1024)
   294  	}
   295  	return trs.rows
   296  }
   297  
   298  func (trs *milevadbResultSet) Close() error {
   299  	if !atomic.CompareAndSwapInt32(&trs.closed, 0, 1) {
   300  		return nil
   301  	}
   302  	err := trs.recordSet.Close()
   303  	trs.recordSet = nil
   304  	return err
   305  }
   306  
   307  // OnFetchReturned implements fetchNotifier#OnFetchReturned
   308  func (trs *milevadbResultSet) OnFetchReturned() {
   309  	if cl, ok := trs.recordSet.(fetchNotifier); ok {
   310  		cl.OnFetchReturned()
   311  	}
   312  }
   313  
   314  func (trs *milevadbResultSet) DeferredCausets() []*DeferredCausetInfo {
   315  	if trs.defCausumns != nil {
   316  		return trs.defCausumns
   317  	}
   318  	// for prepare memex, try to get cached defCausumnInfo array
   319  	if trs.preparedStmt != nil {
   320  		ps := trs.preparedStmt
   321  		if defCausInfos, ok := ps.DeferredCausetInfos.([]*DeferredCausetInfo); ok {
   322  			trs.defCausumns = defCausInfos
   323  		}
   324  	}
   325  	if trs.defCausumns == nil {
   326  		fields := trs.recordSet.Fields()
   327  		for _, v := range fields {
   328  			trs.defCausumns = append(trs.defCausumns, convertDeferredCausetInfo(v))
   329  		}
   330  		if trs.preparedStmt != nil {
   331  			// if DeferredCausetInfo struct has allocated object,
   332  			// here maybe we need deep copy DeferredCausetInfo to do caching
   333  			trs.preparedStmt.DeferredCausetInfos = trs.defCausumns
   334  		}
   335  	}
   336  	return trs.defCausumns
   337  }
   338  
   339  func convertDeferredCausetInfo(fld *ast.ResultField) (ci *DeferredCausetInfo) {
   340  	ci = &DeferredCausetInfo{
   341  		Name:    fld.DeferredCausetAsName.O,
   342  		OrgName: fld.DeferredCauset.Name.O,
   343  		Block:   fld.BlockAsName.O,
   344  		Schema:  fld.DBName.O,
   345  		Flag:    uint16(fld.DeferredCauset.Flag),
   346  		Charset: uint16(allegrosql.CharsetNameToID(fld.DeferredCauset.Charset)),
   347  		Type:    fld.DeferredCauset.Tp,
   348  	}
   349  
   350  	if fld.Block != nil {
   351  		ci.OrgBlock = fld.Block.Name.O
   352  	}
   353  	if fld.DeferredCauset.Flen == types.UnspecifiedLength {
   354  		ci.DeferredCausetLength = 0
   355  	} else {
   356  		ci.DeferredCausetLength = uint32(fld.DeferredCauset.Flen)
   357  	}
   358  	if fld.DeferredCauset.Tp == allegrosql.TypeNewDecimal {
   359  		// Consider the negative sign.
   360  		ci.DeferredCausetLength++
   361  		if fld.DeferredCauset.Decimal > int(types.DefaultFsp) {
   362  			// Consider the decimal point.
   363  			ci.DeferredCausetLength++
   364  		}
   365  	} else if types.IsString(fld.DeferredCauset.Tp) ||
   366  		fld.DeferredCauset.Tp == allegrosql.TypeEnum || fld.DeferredCauset.Tp == allegrosql.TypeSet { // issue #18870
   367  		// Fix issue #4540.
   368  		// The flen is a hint, not a precise value, so most client will not use the value.
   369  		// But we found in rare MyALLEGROSQL client, like Navicat for MyALLEGROSQL(version before 12) will truncate
   370  		// the `show create causet` result. To fix this case, we must use a large enough flen to prevent
   371  		// the truncation, in MyALLEGROSQL, it will multiply bytes length by a multiple based on character set.
   372  		// For examples:
   373  		// * latin, the multiple is 1
   374  		// * gb2312, the multiple is 2
   375  		// * Utf-8, the multiple is 3
   376  		// * utf8mb4, the multiple is 4
   377  		// We used to check non-string types to avoid the truncation problem in some MyALLEGROSQL
   378  		// client such as Navicat. Now we only allow string type enter this branch.
   379  		charsetDesc, err := charset.GetCharsetDesc(fld.DeferredCauset.Charset)
   380  		if err != nil {
   381  			ci.DeferredCausetLength = ci.DeferredCausetLength * 4
   382  		} else {
   383  			ci.DeferredCausetLength = ci.DeferredCausetLength * uint32(charsetDesc.Maxlen)
   384  		}
   385  	}
   386  
   387  	if fld.DeferredCauset.Decimal == types.UnspecifiedLength {
   388  		if fld.DeferredCauset.Tp == allegrosql.TypeDuration {
   389  			ci.Decimal = uint8(types.DefaultFsp)
   390  		} else {
   391  			ci.Decimal = allegrosql.NotFixedDec
   392  		}
   393  	} else {
   394  		ci.Decimal = uint8(fld.DeferredCauset.Decimal)
   395  	}
   396  
   397  	// Keep things compatible for old clients.
   398  	// Refer to allegrosql-server/allegrosql/protodefCaus.cc send_result_set_spacetimedata()
   399  	if ci.Type == allegrosql.TypeVarchar {
   400  		ci.Type = allegrosql.TypeVarString
   401  	}
   402  	return
   403  }