vitess.io/vitess@v0.16.2/go/mysql/mysql_fuzzer.go (about)

     1  //go:build gofuzz
     2  // +build gofuzz
     3  
     4  /*
     5  Copyright 2021 The Vitess Authors.
     6  
     7  Licensed under the Apache License, Version 2.0 (the "License");
     8  you may not use this file except in compliance with the License.
     9  You may obtain a copy of the License at
    10  
    11      http://www.apache.org/licenses/LICENSE-2.0
    12  
    13  Unless required by applicable law or agreed to in writing, software
    14  distributed under the License is distributed on an "AS IS" BASIS,
    15  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    16  See the License for the specific language governing permissions and
    17  limitations under the License.
    18  */
    19  
    20  package mysql
    21  
    22  import (
    23  	"context"
    24  	"crypto/tls"
    25  	"fmt"
    26  	"net"
    27  	"os"
    28  	"path"
    29  	"sync"
    30  	"time"
    31  
    32  	gofuzzheaders "github.com/AdaLogics/go-fuzz-headers"
    33  
    34  	"vitess.io/vitess/go/sqltypes"
    35  	querypb "vitess.io/vitess/go/vt/proto/query"
    36  	"vitess.io/vitess/go/vt/tlstest"
    37  	"vitess.io/vitess/go/vt/vttls"
    38  )
    39  
    40  func createFuzzingSocketPair() (net.Listener, *Conn, *Conn) {
    41  	// Create a listener.
    42  	listener, err := net.Listen("tcp", "127.0.0.1:")
    43  	if err != nil {
    44  		fmt.Println("We got an error early on")
    45  		return nil, nil, nil
    46  	}
    47  	addr := listener.Addr().String()
    48  	listener.(*net.TCPListener).SetDeadline(time.Now().Add(10 * time.Second))
    49  
    50  	// Dial a client, Accept a server.
    51  	wg := sync.WaitGroup{}
    52  
    53  	var clientConn net.Conn
    54  	var clientErr error
    55  	wg.Add(1)
    56  	go func() {
    57  		defer wg.Done()
    58  		clientConn, clientErr = net.DialTimeout("tcp", addr, 10*time.Second)
    59  	}()
    60  
    61  	var serverConn net.Conn
    62  	var serverErr error
    63  	wg.Add(1)
    64  	go func() {
    65  		defer wg.Done()
    66  		serverConn, serverErr = listener.Accept()
    67  	}()
    68  
    69  	wg.Wait()
    70  
    71  	if clientErr != nil {
    72  		return nil, nil, nil
    73  	}
    74  	if serverErr != nil {
    75  		return nil, nil, nil
    76  	}
    77  
    78  	// Create a Conn on both sides.
    79  	cConn := newConn(clientConn)
    80  	sConn := newConn(serverConn)
    81  
    82  	return listener, sConn, cConn
    83  }
    84  
    85  type fuzztestRun struct {
    86  	UnimplementedHandler
    87  }
    88  
    89  func (t fuzztestRun) ComQuery(c *Conn, query string, callback func(*sqltypes.Result) error) error {
    90  	return nil
    91  }
    92  
    93  func (t fuzztestRun) ComPrepare(c *Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) {
    94  	return nil, nil
    95  }
    96  
    97  func (t fuzztestRun) ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error {
    98  	return nil
    99  }
   100  
   101  func (t fuzztestRun) WarningCount(c *Conn) uint16 {
   102  	return 0
   103  }
   104  
   105  var _ Handler = (*fuzztestRun)(nil)
   106  
   107  type fuzztestConn struct {
   108  	writeToPass []bool
   109  	pos         int
   110  	queryPacket []byte
   111  }
   112  
   113  func (t fuzztestConn) Read(b []byte) (n int, err error) {
   114  	for i := 0; i < len(b) && i < len(t.queryPacket); i++ {
   115  		b[i] = t.queryPacket[i]
   116  	}
   117  	return len(b), nil
   118  }
   119  
   120  func (t fuzztestConn) Write(b []byte) (n int, err error) {
   121  	t.pos = t.pos + 1
   122  	if t.writeToPass[t.pos] {
   123  		return 0, nil
   124  	}
   125  	return 0, fmt.Errorf("error in writing to connection")
   126  }
   127  
   128  func (t fuzztestConn) Close() error {
   129  	panic("implement me")
   130  }
   131  
   132  func (t fuzztestConn) LocalAddr() net.Addr {
   133  	panic("implement me")
   134  }
   135  
   136  func (t fuzztestConn) RemoteAddr() net.Addr {
   137  	return fuzzmockAddress{s: "a"}
   138  }
   139  
   140  func (t fuzztestConn) SetDeadline(t1 time.Time) error {
   141  	panic("implement me")
   142  }
   143  
   144  func (t fuzztestConn) SetReadDeadline(t1 time.Time) error {
   145  	panic("implement me")
   146  }
   147  
   148  func (t fuzztestConn) SetWriteDeadline(t1 time.Time) error {
   149  	panic("implement me")
   150  }
   151  
   152  var _ net.Conn = (*fuzztestConn)(nil)
   153  
   154  type fuzzmockAddress struct {
   155  	s string
   156  }
   157  
   158  func (m fuzzmockAddress) Network() string {
   159  	return m.s
   160  }
   161  
   162  func (m fuzzmockAddress) String() string {
   163  	return m.s
   164  }
   165  
   166  var _ net.Addr = (*fuzzmockAddress)(nil)
   167  
   168  // Fuzzers begin here:
   169  func FuzzWritePacket(data []byte) int {
   170  	if len(data) < 10 {
   171  		return -1
   172  	}
   173  	listener, sConn, cConn := createFuzzingSocketPair()
   174  	defer func() {
   175  		listener.Close()
   176  		sConn.Close()
   177  		cConn.Close()
   178  	}()
   179  
   180  	err := cConn.writePacket(data)
   181  	if err != nil {
   182  		return 0
   183  	}
   184  	_, err = sConn.ReadPacket()
   185  	if err != nil {
   186  		return 0
   187  	}
   188  	return 1
   189  }
   190  
   191  func FuzzHandleNextCommand(data []byte) int {
   192  	if len(data) < 10 {
   193  		return -1
   194  	}
   195  	sConn := newConn(fuzztestConn{
   196  		writeToPass: []bool{false},
   197  		pos:         -1,
   198  		queryPacket: data,
   199  	})
   200  	sConn.PrepareData = map[uint32]*PrepareData{}
   201  
   202  	handler := &fuzztestRun{}
   203  	_ = sConn.handleNextCommand(handler)
   204  	return 1
   205  }
   206  
   207  func FuzzReadQueryResults(data []byte) int {
   208  	listener, sConn, cConn := createFuzzingSocketPair()
   209  	defer func() {
   210  		listener.Close()
   211  		sConn.Close()
   212  		cConn.Close()
   213  	}()
   214  	err := cConn.WriteComQuery(string(data))
   215  	if err != nil {
   216  		return 0
   217  	}
   218  	handler := &fuzztestRun{}
   219  	_ = sConn.handleNextCommand(handler)
   220  	_, _, _, err = cConn.ReadQueryResult(100, true)
   221  	if err != nil {
   222  		return 0
   223  	}
   224  	return 1
   225  }
   226  
   227  type fuzzTestHandler struct {
   228  	UnimplementedHandler
   229  
   230  	mu       sync.Mutex
   231  	lastConn *Conn
   232  	result   *sqltypes.Result
   233  	err      error
   234  	warnings uint16
   235  }
   236  
   237  func (th *fuzzTestHandler) LastConn() *Conn {
   238  	th.mu.Lock()
   239  	defer th.mu.Unlock()
   240  	return th.lastConn
   241  }
   242  
   243  func (th *fuzzTestHandler) Result() *sqltypes.Result {
   244  	th.mu.Lock()
   245  	defer th.mu.Unlock()
   246  	return th.result
   247  }
   248  
   249  func (th *fuzzTestHandler) SetErr(err error) {
   250  	th.mu.Lock()
   251  	defer th.mu.Unlock()
   252  	th.err = err
   253  }
   254  
   255  func (th *fuzzTestHandler) Err() error {
   256  	th.mu.Lock()
   257  	defer th.mu.Unlock()
   258  	return th.err
   259  }
   260  
   261  func (th *fuzzTestHandler) SetWarnings(count uint16) {
   262  	th.mu.Lock()
   263  	defer th.mu.Unlock()
   264  	th.warnings = count
   265  }
   266  
   267  func (th *fuzzTestHandler) NewConnection(c *Conn) {
   268  	th.mu.Lock()
   269  	defer th.mu.Unlock()
   270  	th.lastConn = c
   271  }
   272  
   273  func (th *fuzzTestHandler) ComQuery(c *Conn, query string, callback func(*sqltypes.Result) error) error {
   274  
   275  	return nil
   276  }
   277  
   278  func (th *fuzzTestHandler) ComPrepare(c *Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) {
   279  	return nil, nil
   280  }
   281  
   282  func (th *fuzzTestHandler) ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error {
   283  	return nil
   284  }
   285  
   286  func (th *fuzzTestHandler) ComResetConnection(c *Conn) {
   287  
   288  }
   289  
   290  func (th *fuzzTestHandler) WarningCount(c *Conn) uint16 {
   291  	th.mu.Lock()
   292  	defer th.mu.Unlock()
   293  	return th.warnings
   294  }
   295  
   296  func (c *Conn) writeFuzzedPacket(packet []byte) {
   297  	c.sequence = 0
   298  	data, pos := c.startEphemeralPacketWithHeader(len(packet) + 1)
   299  	copy(data[pos:], packet)
   300  	_ = c.writeEphemeralPacket()
   301  }
   302  
   303  func FuzzTLSServer(data []byte) int {
   304  	if len(data) < 40 {
   305  		return -1
   306  	}
   307  	// totalQueries is the number of queries the fuzzer
   308  	// makes in each fuzz iteration
   309  	totalQueries := 20
   310  	var queries [][]byte
   311  	c := gofuzzheaders.NewConsumer(data)
   312  	for i := 0; i < totalQueries; i++ {
   313  		query, err := c.GetBytes()
   314  		if err != nil {
   315  			return -1
   316  		}
   317  		if len(query) < 40 {
   318  			continue
   319  		}
   320  		queries = append(queries, query)
   321  	}
   322  
   323  	th := &fuzzTestHandler{}
   324  
   325  	authServer := NewAuthServerStatic("", "", 0)
   326  	authServer.entries["user1"] = []*AuthServerStaticEntry{{
   327  		Password: "password1",
   328  	}}
   329  	defer authServer.close()
   330  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false)
   331  	if err != nil {
   332  		return -1
   333  	}
   334  	defer l.Close()
   335  
   336  	host := l.Addr().(*net.TCPAddr).IP.String()
   337  	port := l.Addr().(*net.TCPAddr).Port
   338  	root, err := os.MkdirTemp("", "TestTLSServer")
   339  	if err != nil {
   340  		return -1
   341  	}
   342  	defer os.RemoveAll(root)
   343  	tlstest.CreateCA(root)
   344  	tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
   345  	tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert")
   346  
   347  	serverConfig, err := vttls.ServerConfig(
   348  		path.Join(root, "server-cert.pem"),
   349  		path.Join(root, "server-key.pem"),
   350  		path.Join(root, "ca-cert.pem"),
   351  		"",
   352  		"",
   353  		tls.VersionTLS12)
   354  	if err != nil {
   355  		return -1
   356  	}
   357  	l.TLSConfig.Store(serverConfig)
   358  	go l.Accept()
   359  
   360  	connCountByTLSVer.ResetAll()
   361  	// Setup the right parameters.
   362  	params := &ConnParams{
   363  		Host:  host,
   364  		Port:  port,
   365  		Uname: "user1",
   366  		Pass:  "password1",
   367  		// SSL flags.
   368  		SslMode:    vttls.VerifyIdentity,
   369  		SslCa:      path.Join(root, "ca-cert.pem"),
   370  		SslCert:    path.Join(root, "client-cert.pem"),
   371  		SslKey:     path.Join(root, "client-key.pem"),
   372  		ServerName: "server.example.com",
   373  	}
   374  	conn, err := Connect(context.Background(), params)
   375  	if err != nil {
   376  		return -1
   377  	}
   378  
   379  	for i := 0; i < len(queries); i++ {
   380  		conn.writeFuzzedPacket(queries[i])
   381  	}
   382  	return 1
   383  }