go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/proxyproto/listener.go (about) 1 /* 2 3 Copyright (c) 2023 - Present. Will Charczuk. All rights reserved. 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository. 5 6 */ 7 8 package proxyproto 9 10 // Taken from https://github.com/armon/go-proxyproto 11 // The MIT License (MIT) 12 13 // Copyright (c) 2014 Armon Dadgar 14 15 // Permission is hereby granted, free of charge, to any person obtaining a copy 16 // of this software and associated documentation files (the "Software"), to deal 17 // in the Software without restriction, including without limitation the rights 18 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 19 // copies of the Software, and to permit persons to whom the Software is 20 // furnished to do so, subject to the following conditions: 21 22 // The above copyright notice and this permission notice shall be included in all 23 // copies or substantial portions of the Software. 24 25 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 26 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 27 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 28 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 29 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 30 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 31 // SOFTWARE. 32 33 import ( 34 "bufio" 35 "bytes" 36 "errors" 37 "fmt" 38 "io" 39 "net" 40 "strconv" 41 "strings" 42 "sync" 43 "time" 44 ) 45 46 var ( 47 // prefix is the string we look for at the start of a connection 48 // to check if this connection is using the proxy protocol 49 prefix = []byte("PROXY ") 50 prefixLen = len(prefix) 51 52 // ErrInvalidUpstream is a common error. 53 ErrInvalidUpstream = errors.New("upstream connection address not trusted for PROXY information") 54 ) 55 56 // SourceChecker can be used to decide whether to trust the PROXY info or pass 57 // the original connection address through. If set, the connecting address is 58 // passed in as an argument. If the function returns an error due to the source 59 // being disallowed, it should return ErrInvalidUpstream. 60 // 61 // Behavior is as follows: 62 // * If error is not nil, the call to Accept() will fail. If the reason for 63 // triggering this failure is due to a disallowed source, it should return 64 // ErrInvalidUpstream. 65 // * If bool is true, the PROXY-set address is used. 66 // * If bool is false, the connection's remote address is used, rather than the 67 // address claimed in the PROXY info. 68 type SourceChecker func(net.Addr) (bool, error) 69 70 // Listener is used to wrap an underlying listener, 71 // whose connections may be using the HAProxy Proxy Protocol (version 1). 72 // If the connection is using the protocol, the RemoteAddr() will return 73 // the correct client address. 74 // 75 // Optionally define ProxyHeaderTimeout to set a maximum time to 76 // receive the Proxy Protocol Header. Zero means no timeout. 77 type Listener struct { 78 Listener net.Listener 79 ProxyHeaderTimeout time.Duration 80 SourceCheck SourceChecker 81 } 82 83 // Conn is used to wrap and underlying connection which 84 // may be speaking the Proxy Protocol. If it is, the RemoteAddr() will 85 // return the address of the client instead of the proxy address. 86 type Conn struct { 87 bufReader *bufio.Reader 88 conn net.Conn 89 dstAddr *net.TCPAddr 90 srcAddr *net.TCPAddr 91 useConnRemoteAddr bool 92 once sync.Once 93 proxyHeaderTimeout time.Duration 94 } 95 96 // Accept waits for and returns the next connection to the listener. 97 func (p *Listener) Accept() (net.Conn, error) { 98 // Get the underlying connection 99 conn, err := p.Listener.Accept() 100 if err != nil { 101 return nil, err 102 } 103 var useConnRemoteAddr bool 104 if p.SourceCheck != nil { 105 allowed, err := p.SourceCheck(conn.RemoteAddr()) 106 if err != nil { 107 return nil, err 108 } 109 if !allowed { 110 useConnRemoteAddr = true 111 } 112 } 113 newConn := NewConn(conn, p.ProxyHeaderTimeout) 114 newConn.useConnRemoteAddr = useConnRemoteAddr 115 return newConn, nil 116 } 117 118 // Close closes the underlying listener. 119 func (p *Listener) Close() error { 120 return p.Listener.Close() 121 } 122 123 // Addr returns the underlying listener's network address. 124 func (p *Listener) Addr() net.Addr { 125 return p.Listener.Addr() 126 } 127 128 // NewConn is used to wrap a net.Conn that may be speaking 129 // the proxy protocol into a proxyproto.Conn 130 func NewConn(conn net.Conn, timeout time.Duration) *Conn { 131 pConn := &Conn{ 132 bufReader: bufio.NewReader(conn), 133 conn: conn, 134 proxyHeaderTimeout: timeout, 135 } 136 return pConn 137 } 138 139 // Read is check for the proxy protocol header when doing 140 // the initial scan. If there is an error parsing the header, 141 // it is returned and the socket is closed. 142 func (p *Conn) Read(b []byte) (int, error) { 143 var err error 144 p.once.Do(func() { err = p.checkPrefix() }) 145 if err != nil { 146 return 0, err 147 } 148 return p.bufReader.Read(b) 149 } 150 151 func (p *Conn) Write(b []byte) (int, error) { 152 return p.conn.Write(b) 153 } 154 155 // Close closes the underlying connection. 156 func (p *Conn) Close() error { 157 return p.conn.Close() 158 } 159 160 // LocalAddr returns the local address of the underlying connection. 161 func (p *Conn) LocalAddr() net.Addr { 162 return p.conn.LocalAddr() 163 } 164 165 // RemoteAddr returns the address of the client if the proxy 166 // protocol is being used, otherwise just returns the address of 167 // the socket peer. If there is an error parsing the header, the 168 // address of the client is not returned, and the socket is closed. 169 // Once implication of this is that the call could block if the 170 // client is slow. Using a Deadline is recommended if this is called 171 // before Read() 172 func (p *Conn) RemoteAddr() net.Addr { 173 p.once.Do(func() { 174 if err := p.checkPrefix(); err != nil && err != io.EOF { 175 p.Close() 176 p.bufReader = bufio.NewReader(p.conn) 177 } 178 }) 179 if p.srcAddr != nil && !p.useConnRemoteAddr { 180 return p.srcAddr 181 } 182 return p.conn.RemoteAddr() 183 } 184 185 // SetDeadline sets a field. 186 func (p *Conn) SetDeadline(t time.Time) error { 187 return p.conn.SetDeadline(t) 188 } 189 190 // SetReadDeadline reads a field. 191 func (p *Conn) SetReadDeadline(t time.Time) error { 192 return p.conn.SetReadDeadline(t) 193 } 194 195 // SetWriteDeadline sets a field. 196 func (p *Conn) SetWriteDeadline(t time.Time) error { 197 return p.conn.SetWriteDeadline(t) 198 } 199 200 func (p *Conn) checkPrefix() error { 201 if p.proxyHeaderTimeout != 0 { 202 readDeadLine := time.Now().Add(p.proxyHeaderTimeout) 203 _ = p.conn.SetReadDeadline(readDeadLine) 204 defer func() { _ = p.conn.SetReadDeadline(time.Time{}) }() 205 } 206 207 // Incrementally check each byte of the prefix 208 for i := 1; i <= prefixLen; i++ { 209 inp, err := p.bufReader.Peek(i) 210 211 if err != nil { 212 if neterr, ok := err.(net.Error); ok && neterr.Timeout() { 213 return nil 214 } 215 return err 216 } 217 218 // Check for a prefix mis-match, quit early 219 if !bytes.Equal(inp, prefix[:i]) { 220 return nil 221 } 222 } 223 224 // Read the header line 225 header, err := p.bufReader.ReadString('\n') 226 if err != nil { 227 p.conn.Close() 228 return err 229 } 230 231 // Strip the carriage return and new line 232 header = header[:len(header)-2] 233 234 // Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>) 235 parts := strings.Split(header, " ") 236 if len(parts) != 6 { 237 p.conn.Close() 238 return fmt.Errorf("invalid header line: %s", header) 239 } 240 241 // Verify the type is known 242 switch parts[1] { 243 case "TCP4": 244 case "TCP6": 245 default: 246 p.conn.Close() 247 return fmt.Errorf("unhandled address type: %s", parts[1]) 248 } 249 250 // Parse out the source address 251 ip := net.ParseIP(parts[2]) 252 if ip == nil { 253 p.conn.Close() 254 return fmt.Errorf("invalid source ip: %s", parts[2]) 255 } 256 port, err := strconv.Atoi(parts[4]) 257 if err != nil { 258 p.conn.Close() 259 return fmt.Errorf("invalid source port: %s", parts[4]) 260 } 261 p.srcAddr = &net.TCPAddr{IP: ip, Port: port} 262 263 // Parse out the destination address 264 ip = net.ParseIP(parts[3]) 265 if ip == nil { 266 p.conn.Close() 267 return fmt.Errorf("invalid destination ip: %s", parts[3]) 268 } 269 port, err = strconv.Atoi(parts[5]) 270 if err != nil { 271 p.conn.Close() 272 return fmt.Errorf("invalid destination port: %s", parts[5]) 273 } 274 p.dstAddr = &net.TCPAddr{IP: ip, Port: port} 275 276 return nil 277 }