github.com/dolthub/go-mysql-server@v0.18.0/server/extension.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package server
    16  
    17  import (
    18  	"sort"
    19  
    20  	"github.com/dolthub/vitess/go/mysql"
    21  	"github.com/dolthub/vitess/go/sqltypes"
    22  	querypb "github.com/dolthub/vitess/go/vt/proto/query"
    23  	"github.com/dolthub/vitess/go/vt/sqlparser"
    24  	ast "github.com/dolthub/vitess/go/vt/sqlparser"
    25  
    26  	sqle "github.com/dolthub/go-mysql-server"
    27  )
    28  
    29  func Intercept(h Interceptor) {
    30  	inters = append(inters, h)
    31  	sort.Slice(inters, func(i, j int) bool { return inters[i].Priority() < inters[j].Priority() })
    32  }
    33  
    34  func WithChain() Option {
    35  	return func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler) {
    36  		f := DefaultProtocolListenerFunc
    37  		DefaultProtocolListenerFunc = func(cfg mysql.ListenerConfig) (ProtocolListener, error) {
    38  			cfg.Handler = buildChain(cfg.Handler)
    39  			return f(cfg)
    40  		}
    41  	}
    42  }
    43  
    44  var inters []Interceptor
    45  
    46  func buildChain(h mysql.Handler) mysql.Handler {
    47  	var last Chain = h
    48  	for i := len(inters) - 1; i >= 0; i-- {
    49  		filter := inters[i]
    50  		next := last
    51  		last = &chainInterceptor{i: filter, c: next}
    52  	}
    53  	return &interceptorHandler{h: h, c: last}
    54  }
    55  
    56  type Interceptor interface {
    57  
    58  	// Priority returns the priority of the interceptor.
    59  	Priority() int
    60  
    61  	// Query is called when a connection receives a query.
    62  	// Note the contents of the query slice may change after
    63  	// the first call to callback. So the Handler should not
    64  	// hang on to the byte slice.
    65  	Query(chain Chain, c *mysql.Conn, query string, callback func(res *sqltypes.Result, more bool) error) error
    66  
    67  	// ParsedQuery is called when a connection receives a
    68  	// query that has already been parsed. Note the contents
    69  	// of the query slice may change after the first call to
    70  	// callback. So the Handler should not hang on to the byte
    71  	// slice.
    72  	ParsedQuery(chain Chain, c *mysql.Conn, query string, parsed sqlparser.Statement, callback func(res *sqltypes.Result, more bool) error) error
    73  
    74  	// MultiQuery is called when a connection receives a query and the
    75  	// client supports MULTI_STATEMENT. It should process the first
    76  	// statement in |query| and return the remainder. It will be called
    77  	// multiple times until the remainder is |""|.
    78  	MultiQuery(chain Chain, c *mysql.Conn, query string, callback func(res *sqltypes.Result, more bool) error) (string, error)
    79  
    80  	// Prepare is called when a connection receives a prepared
    81  	// statement query.
    82  	Prepare(chain Chain, c *mysql.Conn, query string, prepare *mysql.PrepareData) ([]*querypb.Field, error)
    83  
    84  	// StmtExecute is called when a connection receives a statement
    85  	// execute query.
    86  	StmtExecute(chain Chain, c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error
    87  }
    88  
    89  type Chain interface {
    90  
    91  	// ComQuery is called when a connection receives a query.
    92  	// Note the contents of the query slice may change after
    93  	// the first call to callback. So the Handler should not
    94  	// hang on to the byte slice.
    95  	ComQuery(c *mysql.Conn, query string, callback mysql.ResultSpoolFn) error
    96  
    97  	// ComMultiQuery is called when a connection receives a query and the
    98  	// client supports MULTI_STATEMENT. It should process the first
    99  	// statement in |query| and return the remainder. It will be called
   100  	// multiple times until the remainder is |""|.
   101  	ComMultiQuery(c *mysql.Conn, query string, callback mysql.ResultSpoolFn) (string, error)
   102  
   103  	// ComPrepare is called when a connection receives a prepared
   104  	// statement query.
   105  	ComPrepare(c *mysql.Conn, query string, prepare *mysql.PrepareData) ([]*querypb.Field, error)
   106  
   107  	// ComStmtExecute is called when a connection receives a statement
   108  	// execute query.
   109  	ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error
   110  }
   111  
   112  type chainInterceptor struct {
   113  	i Interceptor
   114  	c Chain
   115  }
   116  
   117  func (ci *chainInterceptor) ComQuery(c *mysql.Conn, query string, callback mysql.ResultSpoolFn) error {
   118  	return ci.i.Query(ci.c, c, query, callback)
   119  }
   120  
   121  func (ci *chainInterceptor) ComMultiQuery(c *mysql.Conn, query string, callback mysql.ResultSpoolFn) (string, error) {
   122  	return ci.i.MultiQuery(ci.c, c, query, callback)
   123  }
   124  
   125  func (ci *chainInterceptor) ComPrepare(c *mysql.Conn, query string, prepare *mysql.PrepareData) ([]*querypb.Field, error) {
   126  	return ci.i.Prepare(ci.c, c, query, prepare)
   127  }
   128  
   129  func (ci *chainInterceptor) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error {
   130  	return ci.i.StmtExecute(ci.c, c, prepare, callback)
   131  }
   132  
   133  type interceptorHandler struct {
   134  	c Chain
   135  	h mysql.Handler
   136  }
   137  
   138  func (ih *interceptorHandler) NewConnection(c *mysql.Conn) {
   139  	ih.h.NewConnection(c)
   140  }
   141  
   142  func (ih *interceptorHandler) ConnectionClosed(c *mysql.Conn) {
   143  	ih.h.ConnectionClosed(c)
   144  }
   145  
   146  func (ih *interceptorHandler) ComInitDB(c *mysql.Conn, schemaName string) error {
   147  	return ih.h.ComInitDB(c, schemaName)
   148  }
   149  
   150  func (ih *interceptorHandler) ComQuery(c *mysql.Conn, query string, callback mysql.ResultSpoolFn) error {
   151  	return ih.c.ComQuery(c, query, callback)
   152  }
   153  
   154  func (ih *interceptorHandler) ComMultiQuery(c *mysql.Conn, query string, callback mysql.ResultSpoolFn) (string, error) {
   155  	return ih.c.ComMultiQuery(c, query, callback)
   156  }
   157  
   158  func (ih *interceptorHandler) ComPrepare(c *mysql.Conn, query string, prepare *mysql.PrepareData) ([]*querypb.Field, error) {
   159  	return ih.c.ComPrepare(c, query, prepare)
   160  }
   161  
   162  func (ih *interceptorHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error {
   163  	return ih.c.ComStmtExecute(c, prepare, callback)
   164  }
   165  
   166  func (ih *interceptorHandler) WarningCount(c *mysql.Conn) uint16 {
   167  	return ih.h.WarningCount(c)
   168  }
   169  
   170  func (ih *interceptorHandler) ComResetConnection(c *mysql.Conn) error {
   171  	return ih.h.ComResetConnection(c)
   172  }
   173  
   174  func (ih *interceptorHandler) ParserOptionsForConnection(c *mysql.Conn) (ast.ParserOptions, error) {
   175  	return ih.h.ParserOptionsForConnection(c)
   176  }