github.com/dolthub/go-mysql-server@v0.18.0/server/listener.go (about) 1 // Copyright 2020-2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package server 16 17 import ( 18 "errors" 19 "fmt" 20 "net" 21 "runtime" 22 "sync" 23 "syscall" 24 25 "golang.org/x/sync/errgroup" 26 ) 27 28 var UnixSocketInUseError = errors.New("bind address at given unix socket path is already in use") 29 30 // connRes represents a connection made to a listener and an error result 31 type connRes struct { 32 conn net.Conn 33 err error 34 } 35 36 // Listener implements a single listener with two net.Listener, 37 // one for TCP socket and another for unix socket connections. 38 type Listener struct { 39 // netListener is a tcp socket listener 40 netListener net.Listener 41 // unixListener is a unix socket listener 42 unixListener net.Listener 43 eg *errgroup.Group 44 // channel to receive connections on either listener 45 conns chan connRes 46 // channel to close both listener 47 shutdown chan struct{} 48 once *sync.Once 49 } 50 51 // NewListener creates a new Listener. 52 // 'protocol' takes "tcp" and 'address' takes "host:port" information for TCP socket connection. 53 // For unix socket connection, 'unixSocketPath' takes a path for the unix socket file. 54 // If 'unixSocketPath' is empty, no need to create the second listener. 55 func NewListener(protocol, address string, unixSocketPath string) (*Listener, error) { 56 netl, err := newNetListener(protocol, address) 57 if err != nil { 58 return nil, err 59 } 60 61 var unixl net.Listener 62 var unixSocketInUse error 63 if unixSocketPath != "" { 64 if runtime.GOOS == "windows" { 65 return nil, fmt.Errorf("unable to create unix socket listener on Windows") 66 } 67 unixListener, err := net.ListenUnix("unix", &net.UnixAddr{Name: unixSocketPath, Net: "unix"}) 68 if err == nil { 69 unixl = unixListener 70 } else if errors.Is(err, syscall.EADDRINUSE) { 71 // we continue if error is unix socket bind address is already in use 72 // we return UnixSocketInUseError error to track the error back to where server gets started and add warning 73 unixSocketInUse = UnixSocketInUseError 74 } else { 75 return nil, err 76 } 77 } 78 79 l := &Listener{ 80 netListener: netl, 81 unixListener: unixl, 82 conns: make(chan connRes), 83 eg: new(errgroup.Group), 84 shutdown: make(chan struct{}), 85 once: &sync.Once{}, 86 } 87 l.eg.Go(func() error { 88 for { 89 conn, err := l.netListener.Accept() 90 // connection can be closed already from the other goroutine 91 if errors.Is(err, net.ErrClosed) { 92 return nil 93 } 94 95 select { 96 case <-l.shutdown: 97 conn.Close() 98 return nil 99 case l.conns <- connRes{conn, err}: 100 } 101 } 102 }) 103 104 if l.unixListener != nil { 105 l.eg.Go(func() error { 106 for { 107 conn, err := l.unixListener.Accept() 108 // connection can be closed already from the other goroutine 109 if errors.Is(err, net.ErrClosed) { 110 return nil 111 } 112 113 select { 114 case <-l.shutdown: 115 conn.Close() 116 return nil 117 case l.conns <- connRes{conn, err}: 118 } 119 } 120 }) 121 } 122 123 return l, unixSocketInUse 124 } 125 126 func (l *Listener) Accept() (net.Conn, error) { 127 cr, ok := <-l.conns 128 if !ok { 129 return nil, net.ErrClosed 130 } 131 return cr.conn, cr.err 132 } 133 134 func (l *Listener) Close() error { 135 err := l.netListener.Close() 136 if err != nil && !errors.Is(err, net.ErrClosed) { 137 return err 138 } 139 if l.unixListener != nil { 140 err = l.unixListener.Close() 141 if err != nil && !errors.Is(err, net.ErrClosed) { 142 return err 143 } 144 } 145 l.once.Do(func() { 146 close(l.shutdown) 147 close(l.conns) 148 }) 149 return l.eg.Wait() 150 } 151 152 func (l *Listener) Addr() net.Addr { 153 return l.netListener.Addr() 154 }