vitess.io/vitess@v0.16.2/go/mysql/mysql_fuzzer.go (about) 1 //go:build gofuzz 2 // +build gofuzz 3 4 /* 5 Copyright 2021 The Vitess Authors. 6 7 Licensed under the Apache License, Version 2.0 (the "License"); 8 you may not use this file except in compliance with the License. 9 You may obtain a copy of the License at 10 11 http://www.apache.org/licenses/LICENSE-2.0 12 13 Unless required by applicable law or agreed to in writing, software 14 distributed under the License is distributed on an "AS IS" BASIS, 15 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 See the License for the specific language governing permissions and 17 limitations under the License. 18 */ 19 20 package mysql 21 22 import ( 23 "context" 24 "crypto/tls" 25 "fmt" 26 "net" 27 "os" 28 "path" 29 "sync" 30 "time" 31 32 gofuzzheaders "github.com/AdaLogics/go-fuzz-headers" 33 34 "vitess.io/vitess/go/sqltypes" 35 querypb "vitess.io/vitess/go/vt/proto/query" 36 "vitess.io/vitess/go/vt/tlstest" 37 "vitess.io/vitess/go/vt/vttls" 38 ) 39 40 func createFuzzingSocketPair() (net.Listener, *Conn, *Conn) { 41 // Create a listener. 42 listener, err := net.Listen("tcp", "127.0.0.1:") 43 if err != nil { 44 fmt.Println("We got an error early on") 45 return nil, nil, nil 46 } 47 addr := listener.Addr().String() 48 listener.(*net.TCPListener).SetDeadline(time.Now().Add(10 * time.Second)) 49 50 // Dial a client, Accept a server. 51 wg := sync.WaitGroup{} 52 53 var clientConn net.Conn 54 var clientErr error 55 wg.Add(1) 56 go func() { 57 defer wg.Done() 58 clientConn, clientErr = net.DialTimeout("tcp", addr, 10*time.Second) 59 }() 60 61 var serverConn net.Conn 62 var serverErr error 63 wg.Add(1) 64 go func() { 65 defer wg.Done() 66 serverConn, serverErr = listener.Accept() 67 }() 68 69 wg.Wait() 70 71 if clientErr != nil { 72 return nil, nil, nil 73 } 74 if serverErr != nil { 75 return nil, nil, nil 76 } 77 78 // Create a Conn on both sides. 79 cConn := newConn(clientConn) 80 sConn := newConn(serverConn) 81 82 return listener, sConn, cConn 83 } 84 85 type fuzztestRun struct { 86 UnimplementedHandler 87 } 88 89 func (t fuzztestRun) ComQuery(c *Conn, query string, callback func(*sqltypes.Result) error) error { 90 return nil 91 } 92 93 func (t fuzztestRun) ComPrepare(c *Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) { 94 return nil, nil 95 } 96 97 func (t fuzztestRun) ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error { 98 return nil 99 } 100 101 func (t fuzztestRun) WarningCount(c *Conn) uint16 { 102 return 0 103 } 104 105 var _ Handler = (*fuzztestRun)(nil) 106 107 type fuzztestConn struct { 108 writeToPass []bool 109 pos int 110 queryPacket []byte 111 } 112 113 func (t fuzztestConn) Read(b []byte) (n int, err error) { 114 for i := 0; i < len(b) && i < len(t.queryPacket); i++ { 115 b[i] = t.queryPacket[i] 116 } 117 return len(b), nil 118 } 119 120 func (t fuzztestConn) Write(b []byte) (n int, err error) { 121 t.pos = t.pos + 1 122 if t.writeToPass[t.pos] { 123 return 0, nil 124 } 125 return 0, fmt.Errorf("error in writing to connection") 126 } 127 128 func (t fuzztestConn) Close() error { 129 panic("implement me") 130 } 131 132 func (t fuzztestConn) LocalAddr() net.Addr { 133 panic("implement me") 134 } 135 136 func (t fuzztestConn) RemoteAddr() net.Addr { 137 return fuzzmockAddress{s: "a"} 138 } 139 140 func (t fuzztestConn) SetDeadline(t1 time.Time) error { 141 panic("implement me") 142 } 143 144 func (t fuzztestConn) SetReadDeadline(t1 time.Time) error { 145 panic("implement me") 146 } 147 148 func (t fuzztestConn) SetWriteDeadline(t1 time.Time) error { 149 panic("implement me") 150 } 151 152 var _ net.Conn = (*fuzztestConn)(nil) 153 154 type fuzzmockAddress struct { 155 s string 156 } 157 158 func (m fuzzmockAddress) Network() string { 159 return m.s 160 } 161 162 func (m fuzzmockAddress) String() string { 163 return m.s 164 } 165 166 var _ net.Addr = (*fuzzmockAddress)(nil) 167 168 // Fuzzers begin here: 169 func FuzzWritePacket(data []byte) int { 170 if len(data) < 10 { 171 return -1 172 } 173 listener, sConn, cConn := createFuzzingSocketPair() 174 defer func() { 175 listener.Close() 176 sConn.Close() 177 cConn.Close() 178 }() 179 180 err := cConn.writePacket(data) 181 if err != nil { 182 return 0 183 } 184 _, err = sConn.ReadPacket() 185 if err != nil { 186 return 0 187 } 188 return 1 189 } 190 191 func FuzzHandleNextCommand(data []byte) int { 192 if len(data) < 10 { 193 return -1 194 } 195 sConn := newConn(fuzztestConn{ 196 writeToPass: []bool{false}, 197 pos: -1, 198 queryPacket: data, 199 }) 200 sConn.PrepareData = map[uint32]*PrepareData{} 201 202 handler := &fuzztestRun{} 203 _ = sConn.handleNextCommand(handler) 204 return 1 205 } 206 207 func FuzzReadQueryResults(data []byte) int { 208 listener, sConn, cConn := createFuzzingSocketPair() 209 defer func() { 210 listener.Close() 211 sConn.Close() 212 cConn.Close() 213 }() 214 err := cConn.WriteComQuery(string(data)) 215 if err != nil { 216 return 0 217 } 218 handler := &fuzztestRun{} 219 _ = sConn.handleNextCommand(handler) 220 _, _, _, err = cConn.ReadQueryResult(100, true) 221 if err != nil { 222 return 0 223 } 224 return 1 225 } 226 227 type fuzzTestHandler struct { 228 UnimplementedHandler 229 230 mu sync.Mutex 231 lastConn *Conn 232 result *sqltypes.Result 233 err error 234 warnings uint16 235 } 236 237 func (th *fuzzTestHandler) LastConn() *Conn { 238 th.mu.Lock() 239 defer th.mu.Unlock() 240 return th.lastConn 241 } 242 243 func (th *fuzzTestHandler) Result() *sqltypes.Result { 244 th.mu.Lock() 245 defer th.mu.Unlock() 246 return th.result 247 } 248 249 func (th *fuzzTestHandler) SetErr(err error) { 250 th.mu.Lock() 251 defer th.mu.Unlock() 252 th.err = err 253 } 254 255 func (th *fuzzTestHandler) Err() error { 256 th.mu.Lock() 257 defer th.mu.Unlock() 258 return th.err 259 } 260 261 func (th *fuzzTestHandler) SetWarnings(count uint16) { 262 th.mu.Lock() 263 defer th.mu.Unlock() 264 th.warnings = count 265 } 266 267 func (th *fuzzTestHandler) NewConnection(c *Conn) { 268 th.mu.Lock() 269 defer th.mu.Unlock() 270 th.lastConn = c 271 } 272 273 func (th *fuzzTestHandler) ComQuery(c *Conn, query string, callback func(*sqltypes.Result) error) error { 274 275 return nil 276 } 277 278 func (th *fuzzTestHandler) ComPrepare(c *Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) { 279 return nil, nil 280 } 281 282 func (th *fuzzTestHandler) ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error { 283 return nil 284 } 285 286 func (th *fuzzTestHandler) ComResetConnection(c *Conn) { 287 288 } 289 290 func (th *fuzzTestHandler) WarningCount(c *Conn) uint16 { 291 th.mu.Lock() 292 defer th.mu.Unlock() 293 return th.warnings 294 } 295 296 func (c *Conn) writeFuzzedPacket(packet []byte) { 297 c.sequence = 0 298 data, pos := c.startEphemeralPacketWithHeader(len(packet) + 1) 299 copy(data[pos:], packet) 300 _ = c.writeEphemeralPacket() 301 } 302 303 func FuzzTLSServer(data []byte) int { 304 if len(data) < 40 { 305 return -1 306 } 307 // totalQueries is the number of queries the fuzzer 308 // makes in each fuzz iteration 309 totalQueries := 20 310 var queries [][]byte 311 c := gofuzzheaders.NewConsumer(data) 312 for i := 0; i < totalQueries; i++ { 313 query, err := c.GetBytes() 314 if err != nil { 315 return -1 316 } 317 if len(query) < 40 { 318 continue 319 } 320 queries = append(queries, query) 321 } 322 323 th := &fuzzTestHandler{} 324 325 authServer := NewAuthServerStatic("", "", 0) 326 authServer.entries["user1"] = []*AuthServerStaticEntry{{ 327 Password: "password1", 328 }} 329 defer authServer.close() 330 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false) 331 if err != nil { 332 return -1 333 } 334 defer l.Close() 335 336 host := l.Addr().(*net.TCPAddr).IP.String() 337 port := l.Addr().(*net.TCPAddr).Port 338 root, err := os.MkdirTemp("", "TestTLSServer") 339 if err != nil { 340 return -1 341 } 342 defer os.RemoveAll(root) 343 tlstest.CreateCA(root) 344 tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") 345 tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") 346 347 serverConfig, err := vttls.ServerConfig( 348 path.Join(root, "server-cert.pem"), 349 path.Join(root, "server-key.pem"), 350 path.Join(root, "ca-cert.pem"), 351 "", 352 "", 353 tls.VersionTLS12) 354 if err != nil { 355 return -1 356 } 357 l.TLSConfig.Store(serverConfig) 358 go l.Accept() 359 360 connCountByTLSVer.ResetAll() 361 // Setup the right parameters. 362 params := &ConnParams{ 363 Host: host, 364 Port: port, 365 Uname: "user1", 366 Pass: "password1", 367 // SSL flags. 368 SslMode: vttls.VerifyIdentity, 369 SslCa: path.Join(root, "ca-cert.pem"), 370 SslCert: path.Join(root, "client-cert.pem"), 371 SslKey: path.Join(root, "client-key.pem"), 372 ServerName: "server.example.com", 373 } 374 conn, err := Connect(context.Background(), params) 375 if err != nil { 376 return -1 377 } 378 379 for i := 0; i < len(queries); i++ { 380 conn.writeFuzzedPacket(queries[i]) 381 } 382 return 1 383 }