github.com/dolthub/go-mysql-server@v0.18.0/server/extension.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package server 16 17 import ( 18 "sort" 19 20 "github.com/dolthub/vitess/go/mysql" 21 "github.com/dolthub/vitess/go/sqltypes" 22 querypb "github.com/dolthub/vitess/go/vt/proto/query" 23 "github.com/dolthub/vitess/go/vt/sqlparser" 24 ast "github.com/dolthub/vitess/go/vt/sqlparser" 25 26 sqle "github.com/dolthub/go-mysql-server" 27 ) 28 29 func Intercept(h Interceptor) { 30 inters = append(inters, h) 31 sort.Slice(inters, func(i, j int) bool { return inters[i].Priority() < inters[j].Priority() }) 32 } 33 34 func WithChain() Option { 35 return func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler) { 36 f := DefaultProtocolListenerFunc 37 DefaultProtocolListenerFunc = func(cfg mysql.ListenerConfig) (ProtocolListener, error) { 38 cfg.Handler = buildChain(cfg.Handler) 39 return f(cfg) 40 } 41 } 42 } 43 44 var inters []Interceptor 45 46 func buildChain(h mysql.Handler) mysql.Handler { 47 var last Chain = h 48 for i := len(inters) - 1; i >= 0; i-- { 49 filter := inters[i] 50 next := last 51 last = &chainInterceptor{i: filter, c: next} 52 } 53 return &interceptorHandler{h: h, c: last} 54 } 55 56 type Interceptor interface { 57 58 // Priority returns the priority of the interceptor. 59 Priority() int 60 61 // Query is called when a connection receives a query. 62 // Note the contents of the query slice may change after 63 // the first call to callback. So the Handler should not 64 // hang on to the byte slice. 65 Query(chain Chain, c *mysql.Conn, query string, callback func(res *sqltypes.Result, more bool) error) error 66 67 // ParsedQuery is called when a connection receives a 68 // query that has already been parsed. Note the contents 69 // of the query slice may change after the first call to 70 // callback. So the Handler should not hang on to the byte 71 // slice. 72 ParsedQuery(chain Chain, c *mysql.Conn, query string, parsed sqlparser.Statement, callback func(res *sqltypes.Result, more bool) error) error 73 74 // MultiQuery is called when a connection receives a query and the 75 // client supports MULTI_STATEMENT. It should process the first 76 // statement in |query| and return the remainder. It will be called 77 // multiple times until the remainder is |""|. 78 MultiQuery(chain Chain, c *mysql.Conn, query string, callback func(res *sqltypes.Result, more bool) error) (string, error) 79 80 // Prepare is called when a connection receives a prepared 81 // statement query. 82 Prepare(chain Chain, c *mysql.Conn, query string, prepare *mysql.PrepareData) ([]*querypb.Field, error) 83 84 // StmtExecute is called when a connection receives a statement 85 // execute query. 86 StmtExecute(chain Chain, c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error 87 } 88 89 type Chain interface { 90 91 // ComQuery is called when a connection receives a query. 92 // Note the contents of the query slice may change after 93 // the first call to callback. So the Handler should not 94 // hang on to the byte slice. 95 ComQuery(c *mysql.Conn, query string, callback mysql.ResultSpoolFn) error 96 97 // ComMultiQuery is called when a connection receives a query and the 98 // client supports MULTI_STATEMENT. It should process the first 99 // statement in |query| and return the remainder. It will be called 100 // multiple times until the remainder is |""|. 101 ComMultiQuery(c *mysql.Conn, query string, callback mysql.ResultSpoolFn) (string, error) 102 103 // ComPrepare is called when a connection receives a prepared 104 // statement query. 105 ComPrepare(c *mysql.Conn, query string, prepare *mysql.PrepareData) ([]*querypb.Field, error) 106 107 // ComStmtExecute is called when a connection receives a statement 108 // execute query. 109 ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error 110 } 111 112 type chainInterceptor struct { 113 i Interceptor 114 c Chain 115 } 116 117 func (ci *chainInterceptor) ComQuery(c *mysql.Conn, query string, callback mysql.ResultSpoolFn) error { 118 return ci.i.Query(ci.c, c, query, callback) 119 } 120 121 func (ci *chainInterceptor) ComMultiQuery(c *mysql.Conn, query string, callback mysql.ResultSpoolFn) (string, error) { 122 return ci.i.MultiQuery(ci.c, c, query, callback) 123 } 124 125 func (ci *chainInterceptor) ComPrepare(c *mysql.Conn, query string, prepare *mysql.PrepareData) ([]*querypb.Field, error) { 126 return ci.i.Prepare(ci.c, c, query, prepare) 127 } 128 129 func (ci *chainInterceptor) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { 130 return ci.i.StmtExecute(ci.c, c, prepare, callback) 131 } 132 133 type interceptorHandler struct { 134 c Chain 135 h mysql.Handler 136 } 137 138 func (ih *interceptorHandler) NewConnection(c *mysql.Conn) { 139 ih.h.NewConnection(c) 140 } 141 142 func (ih *interceptorHandler) ConnectionClosed(c *mysql.Conn) { 143 ih.h.ConnectionClosed(c) 144 } 145 146 func (ih *interceptorHandler) ComInitDB(c *mysql.Conn, schemaName string) error { 147 return ih.h.ComInitDB(c, schemaName) 148 } 149 150 func (ih *interceptorHandler) ComQuery(c *mysql.Conn, query string, callback mysql.ResultSpoolFn) error { 151 return ih.c.ComQuery(c, query, callback) 152 } 153 154 func (ih *interceptorHandler) ComMultiQuery(c *mysql.Conn, query string, callback mysql.ResultSpoolFn) (string, error) { 155 return ih.c.ComMultiQuery(c, query, callback) 156 } 157 158 func (ih *interceptorHandler) ComPrepare(c *mysql.Conn, query string, prepare *mysql.PrepareData) ([]*querypb.Field, error) { 159 return ih.c.ComPrepare(c, query, prepare) 160 } 161 162 func (ih *interceptorHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { 163 return ih.c.ComStmtExecute(c, prepare, callback) 164 } 165 166 func (ih *interceptorHandler) WarningCount(c *mysql.Conn) uint16 { 167 return ih.h.WarningCount(c) 168 } 169 170 func (ih *interceptorHandler) ComResetConnection(c *mysql.Conn) error { 171 return ih.h.ComResetConnection(c) 172 } 173 174 func (ih *interceptorHandler) ParserOptionsForConnection(c *mysql.Conn) (ast.ParserOptions, error) { 175 return ih.h.ParserOptionsForConnection(c) 176 }