github.com/blend/go-sdk@v1.20220411.3/proxyprotocol/proxy_protocol.go (about) 1 /* 2 3 Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package proxyprotocol 9 10 // Taken from https://github.com/armon/go-proxyproto 11 // The MIT License (MIT) 12 13 // Copyright (c) 2014 Armon Dadgar 14 15 // Permission is hereby granted, free of charge, to any person obtaining a copy 16 // of this software and associated documentation files (the "Software"), to deal 17 // in the Software without restriction, including without limitation the rights 18 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 19 // copies of the Software, and to permit persons to whom the Software is 20 // furnished to do so, subject to the following conditions: 21 22 // The above copyright notice and this permission notice shall be included in all 23 // copies or substantial portions of the Software. 24 25 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 26 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 27 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 28 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 29 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 30 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 31 // SOFTWARE. 32 33 import ( 34 "bufio" 35 "bytes" 36 "fmt" 37 "io" 38 "log" 39 "net" 40 "strconv" 41 "strings" 42 "sync" 43 "time" 44 45 "github.com/blend/go-sdk/ex" 46 ) 47 48 var ( 49 // prefix is the string we look for at the start of a connection 50 // to check if this connection is using the proxy protocol 51 prefix = []byte("PROXY ") 52 prefixLen = len(prefix) 53 54 // ErrInvalidUpstream is a common error. 55 ErrInvalidUpstream ex.Class = "upstream connection address not trusted for PROXY information" 56 ) 57 58 // SourceChecker can be used to decide whether to trust the PROXY info or pass 59 // the original connection address through. If set, the connecting address is 60 // passed in as an argument. If the function returns an error due to the source 61 // being disallowed, it should return ErrInvalidUpstream. 62 // 63 // Behavior is as follows: 64 // * If error is not nil, the call to Accept() will fail. If the reason for 65 // triggering this failure is due to a disallowed source, it should return 66 // ErrInvalidUpstream. 67 // * If bool is true, the PROXY-set address is used. 68 // * If bool is false, the connection's remote address is used, rather than the 69 // address claimed in the PROXY info. 70 type SourceChecker func(net.Addr) (bool, error) 71 72 // Listener is used to wrap an underlying listener, 73 // whose connections may be using the HAProxy Proxy Protocol (version 1). 74 // If the connection is using the protocol, the RemoteAddr() will return 75 // the correct client address. 76 // 77 // Optionally define ProxyHeaderTimeout to set a maximum time to 78 // receive the Proxy Protocol Header. Zero means no timeout. 79 type Listener struct { 80 Listener net.Listener 81 ProxyHeaderTimeout time.Duration 82 SourceCheck SourceChecker 83 } 84 85 // Conn is used to wrap and underlying connection which 86 // may be speaking the Proxy Protocol. If it is, the RemoteAddr() will 87 // return the address of the client instead of the proxy address. 88 type Conn struct { 89 bufReader *bufio.Reader 90 conn net.Conn 91 dstAddr *net.TCPAddr 92 srcAddr *net.TCPAddr 93 useConnRemoteAddr bool 94 once sync.Once 95 proxyHeaderTimeout time.Duration 96 } 97 98 // Accept waits for and returns the next connection to the listener. 99 func (p *Listener) Accept() (net.Conn, error) { 100 // Get the underlying connection 101 conn, err := p.Listener.Accept() 102 if err != nil { 103 return nil, err 104 } 105 var useConnRemoteAddr bool 106 if p.SourceCheck != nil { 107 allowed, err := p.SourceCheck(conn.RemoteAddr()) 108 if err != nil { 109 return nil, err 110 } 111 if !allowed { 112 useConnRemoteAddr = true 113 } 114 } 115 newConn := NewConn(conn, p.ProxyHeaderTimeout) 116 newConn.useConnRemoteAddr = useConnRemoteAddr 117 return newConn, nil 118 } 119 120 // Close closes the underlying listener. 121 func (p *Listener) Close() error { 122 return p.Listener.Close() 123 } 124 125 // Addr returns the underlying listener's network address. 126 func (p *Listener) Addr() net.Addr { 127 return p.Listener.Addr() 128 } 129 130 // NewConn is used to wrap a net.Conn that may be speaking 131 // the proxy protocol into a proxyproto.Conn 132 func NewConn(conn net.Conn, timeout time.Duration) *Conn { 133 pConn := &Conn{ 134 bufReader: bufio.NewReader(conn), 135 conn: conn, 136 proxyHeaderTimeout: timeout, 137 } 138 return pConn 139 } 140 141 // Read is check for the proxy protocol header when doing 142 // the initial scan. If there is an error parsing the header, 143 // it is returned and the socket is closed. 144 func (p *Conn) Read(b []byte) (int, error) { 145 var err error 146 p.once.Do(func() { err = p.checkPrefix() }) 147 if err != nil { 148 return 0, err 149 } 150 return p.bufReader.Read(b) 151 } 152 153 func (p *Conn) Write(b []byte) (int, error) { 154 return p.conn.Write(b) 155 } 156 157 // Close closes the underlying connection. 158 func (p *Conn) Close() error { 159 return p.conn.Close() 160 } 161 162 // LocalAddr returns the local address of the underlying connection. 163 func (p *Conn) LocalAddr() net.Addr { 164 return p.conn.LocalAddr() 165 } 166 167 // RemoteAddr returns the address of the client if the proxy 168 // protocol is being used, otherwise just returns the address of 169 // the socket peer. If there is an error parsing the header, the 170 // address of the client is not returned, and the socket is closed. 171 // Once implication of this is that the call could block if the 172 // client is slow. Using a Deadline is recommended if this is called 173 // before Read() 174 func (p *Conn) RemoteAddr() net.Addr { 175 p.once.Do(func() { 176 if err := p.checkPrefix(); err != nil && err != io.EOF { 177 log.Printf("[ERR] Failed to read proxy prefix: %v", err) 178 p.Close() 179 p.bufReader = bufio.NewReader(p.conn) 180 } 181 }) 182 if p.srcAddr != nil && !p.useConnRemoteAddr { 183 return p.srcAddr 184 } 185 return p.conn.RemoteAddr() 186 } 187 188 // SetDeadline sets a field. 189 func (p *Conn) SetDeadline(t time.Time) error { 190 return p.conn.SetDeadline(t) 191 } 192 193 // SetReadDeadline reads a field. 194 func (p *Conn) SetReadDeadline(t time.Time) error { 195 return p.conn.SetReadDeadline(t) 196 } 197 198 // SetWriteDeadline sets a field. 199 func (p *Conn) SetWriteDeadline(t time.Time) error { 200 return p.conn.SetWriteDeadline(t) 201 } 202 203 func (p *Conn) checkPrefix() error { 204 if p.proxyHeaderTimeout != 0 { 205 readDeadLine := time.Now().Add(p.proxyHeaderTimeout) 206 _ = p.conn.SetReadDeadline(readDeadLine) 207 defer func() { _ = p.conn.SetReadDeadline(time.Time{}) }() 208 } 209 210 // Incrementally check each byte of the prefix 211 for i := 1; i <= prefixLen; i++ { 212 inp, err := p.bufReader.Peek(i) 213 214 if err != nil { 215 if neterr, ok := err.(net.Error); ok && neterr.Timeout() { 216 return nil 217 } 218 return err 219 } 220 221 // Check for a prefix mis-match, quit early 222 if !bytes.Equal(inp, prefix[:i]) { 223 return nil 224 } 225 } 226 227 // Read the header line 228 header, err := p.bufReader.ReadString('\n') 229 if err != nil { 230 p.conn.Close() 231 return err 232 } 233 234 // Strip the carriage return and new line 235 header = header[:len(header)-2] 236 237 // Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>) 238 parts := strings.Split(header, " ") 239 if len(parts) != 6 { 240 p.conn.Close() 241 return fmt.Errorf("invalid header line: %s", header) 242 } 243 244 // Verify the type is known 245 switch parts[1] { 246 case "TCP4": 247 case "TCP6": 248 default: 249 p.conn.Close() 250 return fmt.Errorf("unhandled address type: %s", parts[1]) 251 } 252 253 // Parse out the source address 254 ip := net.ParseIP(parts[2]) 255 if ip == nil { 256 p.conn.Close() 257 return fmt.Errorf("invalid source ip: %s", parts[2]) 258 } 259 port, err := strconv.Atoi(parts[4]) 260 if err != nil { 261 p.conn.Close() 262 return fmt.Errorf("invalid source port: %s", parts[4]) 263 } 264 p.srcAddr = &net.TCPAddr{IP: ip, Port: port} 265 266 // Parse out the destination address 267 ip = net.ParseIP(parts[3]) 268 if ip == nil { 269 p.conn.Close() 270 return fmt.Errorf("invalid destination ip: %s", parts[3]) 271 } 272 port, err = strconv.Atoi(parts[5]) 273 if err != nil { 274 p.conn.Close() 275 return fmt.Errorf("invalid destination port: %s", parts[5]) 276 } 277 p.dstAddr = &net.TCPAddr{IP: ip, Port: port} 278 279 return nil 280 }