github.com/blend/go-sdk@v1.20220411.3/proxyprotocol/proxy_protocol.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package proxyprotocol
     9  
    10  // Taken from https://github.com/armon/go-proxyproto
    11  // The MIT License (MIT)
    12  
    13  // Copyright (c) 2014 Armon Dadgar
    14  
    15  // Permission is hereby granted, free of charge, to any person obtaining a copy
    16  // of this software and associated documentation files (the "Software"), to deal
    17  // in the Software without restriction, including without limitation the rights
    18  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    19  // copies of the Software, and to permit persons to whom the Software is
    20  // furnished to do so, subject to the following conditions:
    21  
    22  // The above copyright notice and this permission notice shall be included in all
    23  // copies or substantial portions of the Software.
    24  
    25  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    26  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    27  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    28  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    29  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    30  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    31  // SOFTWARE.
    32  
    33  import (
    34  	"bufio"
    35  	"bytes"
    36  	"fmt"
    37  	"io"
    38  	"log"
    39  	"net"
    40  	"strconv"
    41  	"strings"
    42  	"sync"
    43  	"time"
    44  
    45  	"github.com/blend/go-sdk/ex"
    46  )
    47  
    48  var (
    49  	// prefix is the string we look for at the start of a connection
    50  	// to check if this connection is using the proxy protocol
    51  	prefix    = []byte("PROXY ")
    52  	prefixLen = len(prefix)
    53  
    54  	// ErrInvalidUpstream is a common error.
    55  	ErrInvalidUpstream ex.Class = "upstream connection address not trusted for PROXY information"
    56  )
    57  
    58  // SourceChecker can be used to decide whether to trust the PROXY info or pass
    59  // the original connection address through. If set, the connecting address is
    60  // passed in as an argument. If the function returns an error due to the source
    61  // being disallowed, it should return ErrInvalidUpstream.
    62  //
    63  // Behavior is as follows:
    64  // * If error is not nil, the call to Accept() will fail. If the reason for
    65  // triggering this failure is due to a disallowed source, it should return
    66  // ErrInvalidUpstream.
    67  // * If bool is true, the PROXY-set address is used.
    68  // * If bool is false, the connection's remote address is used, rather than the
    69  // address claimed in the PROXY info.
    70  type SourceChecker func(net.Addr) (bool, error)
    71  
    72  // Listener is used to wrap an underlying listener,
    73  // whose connections may be using the HAProxy Proxy Protocol (version 1).
    74  // If the connection is using the protocol, the RemoteAddr() will return
    75  // the correct client address.
    76  //
    77  // Optionally define ProxyHeaderTimeout to set a maximum time to
    78  // receive the Proxy Protocol Header. Zero means no timeout.
    79  type Listener struct {
    80  	Listener           net.Listener
    81  	ProxyHeaderTimeout time.Duration
    82  	SourceCheck        SourceChecker
    83  }
    84  
    85  // Conn is used to wrap and underlying connection which
    86  // may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
    87  // return the address of the client instead of the proxy address.
    88  type Conn struct {
    89  	bufReader          *bufio.Reader
    90  	conn               net.Conn
    91  	dstAddr            *net.TCPAddr
    92  	srcAddr            *net.TCPAddr
    93  	useConnRemoteAddr  bool
    94  	once               sync.Once
    95  	proxyHeaderTimeout time.Duration
    96  }
    97  
    98  // Accept waits for and returns the next connection to the listener.
    99  func (p *Listener) Accept() (net.Conn, error) {
   100  	// Get the underlying connection
   101  	conn, err := p.Listener.Accept()
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  	var useConnRemoteAddr bool
   106  	if p.SourceCheck != nil {
   107  		allowed, err := p.SourceCheck(conn.RemoteAddr())
   108  		if err != nil {
   109  			return nil, err
   110  		}
   111  		if !allowed {
   112  			useConnRemoteAddr = true
   113  		}
   114  	}
   115  	newConn := NewConn(conn, p.ProxyHeaderTimeout)
   116  	newConn.useConnRemoteAddr = useConnRemoteAddr
   117  	return newConn, nil
   118  }
   119  
   120  // Close closes the underlying listener.
   121  func (p *Listener) Close() error {
   122  	return p.Listener.Close()
   123  }
   124  
   125  // Addr returns the underlying listener's network address.
   126  func (p *Listener) Addr() net.Addr {
   127  	return p.Listener.Addr()
   128  }
   129  
   130  // NewConn is used to wrap a net.Conn that may be speaking
   131  // the proxy protocol into a proxyproto.Conn
   132  func NewConn(conn net.Conn, timeout time.Duration) *Conn {
   133  	pConn := &Conn{
   134  		bufReader:          bufio.NewReader(conn),
   135  		conn:               conn,
   136  		proxyHeaderTimeout: timeout,
   137  	}
   138  	return pConn
   139  }
   140  
   141  // Read is check for the proxy protocol header when doing
   142  // the initial scan. If there is an error parsing the header,
   143  // it is returned and the socket is closed.
   144  func (p *Conn) Read(b []byte) (int, error) {
   145  	var err error
   146  	p.once.Do(func() { err = p.checkPrefix() })
   147  	if err != nil {
   148  		return 0, err
   149  	}
   150  	return p.bufReader.Read(b)
   151  }
   152  
   153  func (p *Conn) Write(b []byte) (int, error) {
   154  	return p.conn.Write(b)
   155  }
   156  
   157  // Close closes the underlying connection.
   158  func (p *Conn) Close() error {
   159  	return p.conn.Close()
   160  }
   161  
   162  // LocalAddr returns the local address of the underlying connection.
   163  func (p *Conn) LocalAddr() net.Addr {
   164  	return p.conn.LocalAddr()
   165  }
   166  
   167  // RemoteAddr returns the address of the client if the proxy
   168  // protocol is being used, otherwise just returns the address of
   169  // the socket peer. If there is an error parsing the header, the
   170  // address of the client is not returned, and the socket is closed.
   171  // Once implication of this is that the call could block if the
   172  // client is slow. Using a Deadline is recommended if this is called
   173  // before Read()
   174  func (p *Conn) RemoteAddr() net.Addr {
   175  	p.once.Do(func() {
   176  		if err := p.checkPrefix(); err != nil && err != io.EOF {
   177  			log.Printf("[ERR] Failed to read proxy prefix: %v", err)
   178  			p.Close()
   179  			p.bufReader = bufio.NewReader(p.conn)
   180  		}
   181  	})
   182  	if p.srcAddr != nil && !p.useConnRemoteAddr {
   183  		return p.srcAddr
   184  	}
   185  	return p.conn.RemoteAddr()
   186  }
   187  
   188  // SetDeadline sets a field.
   189  func (p *Conn) SetDeadline(t time.Time) error {
   190  	return p.conn.SetDeadline(t)
   191  }
   192  
   193  // SetReadDeadline reads a field.
   194  func (p *Conn) SetReadDeadline(t time.Time) error {
   195  	return p.conn.SetReadDeadline(t)
   196  }
   197  
   198  // SetWriteDeadline sets a field.
   199  func (p *Conn) SetWriteDeadline(t time.Time) error {
   200  	return p.conn.SetWriteDeadline(t)
   201  }
   202  
   203  func (p *Conn) checkPrefix() error {
   204  	if p.proxyHeaderTimeout != 0 {
   205  		readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
   206  		_ = p.conn.SetReadDeadline(readDeadLine)
   207  		defer func() { _ = p.conn.SetReadDeadline(time.Time{}) }()
   208  	}
   209  
   210  	// Incrementally check each byte of the prefix
   211  	for i := 1; i <= prefixLen; i++ {
   212  		inp, err := p.bufReader.Peek(i)
   213  
   214  		if err != nil {
   215  			if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
   216  				return nil
   217  			}
   218  			return err
   219  		}
   220  
   221  		// Check for a prefix mis-match, quit early
   222  		if !bytes.Equal(inp, prefix[:i]) {
   223  			return nil
   224  		}
   225  	}
   226  
   227  	// Read the header line
   228  	header, err := p.bufReader.ReadString('\n')
   229  	if err != nil {
   230  		p.conn.Close()
   231  		return err
   232  	}
   233  
   234  	// Strip the carriage return and new line
   235  	header = header[:len(header)-2]
   236  
   237  	// Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
   238  	parts := strings.Split(header, " ")
   239  	if len(parts) != 6 {
   240  		p.conn.Close()
   241  		return fmt.Errorf("invalid header line: %s", header)
   242  	}
   243  
   244  	// Verify the type is known
   245  	switch parts[1] {
   246  	case "TCP4":
   247  	case "TCP6":
   248  	default:
   249  		p.conn.Close()
   250  		return fmt.Errorf("unhandled address type: %s", parts[1])
   251  	}
   252  
   253  	// Parse out the source address
   254  	ip := net.ParseIP(parts[2])
   255  	if ip == nil {
   256  		p.conn.Close()
   257  		return fmt.Errorf("invalid source ip: %s", parts[2])
   258  	}
   259  	port, err := strconv.Atoi(parts[4])
   260  	if err != nil {
   261  		p.conn.Close()
   262  		return fmt.Errorf("invalid source port: %s", parts[4])
   263  	}
   264  	p.srcAddr = &net.TCPAddr{IP: ip, Port: port}
   265  
   266  	// Parse out the destination address
   267  	ip = net.ParseIP(parts[3])
   268  	if ip == nil {
   269  		p.conn.Close()
   270  		return fmt.Errorf("invalid destination ip: %s", parts[3])
   271  	}
   272  	port, err = strconv.Atoi(parts[5])
   273  	if err != nil {
   274  		p.conn.Close()
   275  		return fmt.Errorf("invalid destination port: %s", parts[5])
   276  	}
   277  	p.dstAddr = &net.TCPAddr{IP: ip, Port: port}
   278  
   279  	return nil
   280  }