go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/proxyproto/listener.go (about)

     1  /*
     2  
     3  Copyright (c) 2023 - Present. Will Charczuk. All rights reserved.
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     5  
     6  */
     7  
     8  package proxyproto
     9  
    10  // Taken from https://github.com/armon/go-proxyproto
    11  // The MIT License (MIT)
    12  
    13  // Copyright (c) 2014 Armon Dadgar
    14  
    15  // Permission is hereby granted, free of charge, to any person obtaining a copy
    16  // of this software and associated documentation files (the "Software"), to deal
    17  // in the Software without restriction, including without limitation the rights
    18  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    19  // copies of the Software, and to permit persons to whom the Software is
    20  // furnished to do so, subject to the following conditions:
    21  
    22  // The above copyright notice and this permission notice shall be included in all
    23  // copies or substantial portions of the Software.
    24  
    25  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    26  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    27  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    28  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    29  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    30  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    31  // SOFTWARE.
    32  
    33  import (
    34  	"bufio"
    35  	"bytes"
    36  	"errors"
    37  	"fmt"
    38  	"io"
    39  	"net"
    40  	"strconv"
    41  	"strings"
    42  	"sync"
    43  	"time"
    44  )
    45  
    46  var (
    47  	// prefix is the string we look for at the start of a connection
    48  	// to check if this connection is using the proxy protocol
    49  	prefix    = []byte("PROXY ")
    50  	prefixLen = len(prefix)
    51  
    52  	// ErrInvalidUpstream is a common error.
    53  	ErrInvalidUpstream = errors.New("upstream connection address not trusted for PROXY information")
    54  )
    55  
// SourceChecker decides, per accepted connection, whether the address carried
// in a PROXY header should be trusted. Listener.Accept calls it with the
// connection's remote address (the immediate socket peer).
//
// The results are interpreted as follows:
//   - A non-nil error makes Accept fail with that error. If the failure is
//     because the source is disallowed, the error should be ErrInvalidUpstream.
//   - true means RemoteAddr reports the address from the PROXY header.
//   - false means RemoteAddr reports the connection's real remote address
//     rather than the address claimed in the PROXY header. The header is still
//     read and stripped from the stream either way.
type SourceChecker func(net.Addr) (bool, error)
    69  
// Listener wraps an underlying net.Listener whose connections may begin with
// a HAProxy PROXY protocol (version 1) header. Accept returns *Conn values;
// when a header is present (and trusted by SourceCheck, if set), their
// RemoteAddr reports the client address from the header.
//
// The zero value is not usable: the Listener field must be set, since Accept,
// Close and Addr all delegate to it.
type Listener struct {
	// Listener is the wrapped listener.
	Listener           net.Listener
	// ProxyHeaderTimeout is the maximum time to wait for the PROXY header on
	// each connection. Zero means no timeout.
	ProxyHeaderTimeout time.Duration
	// SourceCheck, if non-nil, is called by Accept for every connection to
	// decide whether the header's address should be used.
	SourceCheck        SourceChecker
}
    82  
    83  // Conn is used to wrap and underlying connection which
    84  // may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
    85  // return the address of the client instead of the proxy address.
    86  type Conn struct {
    87  	bufReader          *bufio.Reader
    88  	conn               net.Conn
    89  	dstAddr            *net.TCPAddr
    90  	srcAddr            *net.TCPAddr
    91  	useConnRemoteAddr  bool
    92  	once               sync.Once
    93  	proxyHeaderTimeout time.Duration
    94  }
    95  
    96  // Accept waits for and returns the next connection to the listener.
    97  func (p *Listener) Accept() (net.Conn, error) {
    98  	// Get the underlying connection
    99  	conn, err := p.Listener.Accept()
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  	var useConnRemoteAddr bool
   104  	if p.SourceCheck != nil {
   105  		allowed, err := p.SourceCheck(conn.RemoteAddr())
   106  		if err != nil {
   107  			return nil, err
   108  		}
   109  		if !allowed {
   110  			useConnRemoteAddr = true
   111  		}
   112  	}
   113  	newConn := NewConn(conn, p.ProxyHeaderTimeout)
   114  	newConn.useConnRemoteAddr = useConnRemoteAddr
   115  	return newConn, nil
   116  }
   117  
   118  // Close closes the underlying listener.
   119  func (p *Listener) Close() error {
   120  	return p.Listener.Close()
   121  }
   122  
   123  // Addr returns the underlying listener's network address.
   124  func (p *Listener) Addr() net.Addr {
   125  	return p.Listener.Addr()
   126  }
   127  
   128  // NewConn is used to wrap a net.Conn that may be speaking
   129  // the proxy protocol into a proxyproto.Conn
   130  func NewConn(conn net.Conn, timeout time.Duration) *Conn {
   131  	pConn := &Conn{
   132  		bufReader:          bufio.NewReader(conn),
   133  		conn:               conn,
   134  		proxyHeaderTimeout: timeout,
   135  	}
   136  	return pConn
   137  }
   138  
   139  // Read is check for the proxy protocol header when doing
   140  // the initial scan. If there is an error parsing the header,
   141  // it is returned and the socket is closed.
   142  func (p *Conn) Read(b []byte) (int, error) {
   143  	var err error
   144  	p.once.Do(func() { err = p.checkPrefix() })
   145  	if err != nil {
   146  		return 0, err
   147  	}
   148  	return p.bufReader.Read(b)
   149  }
   150  
   151  func (p *Conn) Write(b []byte) (int, error) {
   152  	return p.conn.Write(b)
   153  }
   154  
   155  // Close closes the underlying connection.
   156  func (p *Conn) Close() error {
   157  	return p.conn.Close()
   158  }
   159  
   160  // LocalAddr returns the local address of the underlying connection.
   161  func (p *Conn) LocalAddr() net.Addr {
   162  	return p.conn.LocalAddr()
   163  }
   164  
   165  // RemoteAddr returns the address of the client if the proxy
   166  // protocol is being used, otherwise just returns the address of
   167  // the socket peer. If there is an error parsing the header, the
   168  // address of the client is not returned, and the socket is closed.
   169  // Once implication of this is that the call could block if the
   170  // client is slow. Using a Deadline is recommended if this is called
   171  // before Read()
   172  func (p *Conn) RemoteAddr() net.Addr {
   173  	p.once.Do(func() {
   174  		if err := p.checkPrefix(); err != nil && err != io.EOF {
   175  			p.Close()
   176  			p.bufReader = bufio.NewReader(p.conn)
   177  		}
   178  	})
   179  	if p.srcAddr != nil && !p.useConnRemoteAddr {
   180  		return p.srcAddr
   181  	}
   182  	return p.conn.RemoteAddr()
   183  }
   184  
   185  // SetDeadline sets a field.
   186  func (p *Conn) SetDeadline(t time.Time) error {
   187  	return p.conn.SetDeadline(t)
   188  }
   189  
   190  // SetReadDeadline reads a field.
   191  func (p *Conn) SetReadDeadline(t time.Time) error {
   192  	return p.conn.SetReadDeadline(t)
   193  }
   194  
   195  // SetWriteDeadline sets a field.
   196  func (p *Conn) SetWriteDeadline(t time.Time) error {
   197  	return p.conn.SetWriteDeadline(t)
   198  }
   199  
   200  func (p *Conn) checkPrefix() error {
   201  	if p.proxyHeaderTimeout != 0 {
   202  		readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
   203  		_ = p.conn.SetReadDeadline(readDeadLine)
   204  		defer func() { _ = p.conn.SetReadDeadline(time.Time{}) }()
   205  	}
   206  
   207  	// Incrementally check each byte of the prefix
   208  	for i := 1; i <= prefixLen; i++ {
   209  		inp, err := p.bufReader.Peek(i)
   210  
   211  		if err != nil {
   212  			if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
   213  				return nil
   214  			}
   215  			return err
   216  		}
   217  
   218  		// Check for a prefix mis-match, quit early
   219  		if !bytes.Equal(inp, prefix[:i]) {
   220  			return nil
   221  		}
   222  	}
   223  
   224  	// Read the header line
   225  	header, err := p.bufReader.ReadString('\n')
   226  	if err != nil {
   227  		p.conn.Close()
   228  		return err
   229  	}
   230  
   231  	// Strip the carriage return and new line
   232  	header = header[:len(header)-2]
   233  
   234  	// Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
   235  	parts := strings.Split(header, " ")
   236  	if len(parts) != 6 {
   237  		p.conn.Close()
   238  		return fmt.Errorf("invalid header line: %s", header)
   239  	}
   240  
   241  	// Verify the type is known
   242  	switch parts[1] {
   243  	case "TCP4":
   244  	case "TCP6":
   245  	default:
   246  		p.conn.Close()
   247  		return fmt.Errorf("unhandled address type: %s", parts[1])
   248  	}
   249  
   250  	// Parse out the source address
   251  	ip := net.ParseIP(parts[2])
   252  	if ip == nil {
   253  		p.conn.Close()
   254  		return fmt.Errorf("invalid source ip: %s", parts[2])
   255  	}
   256  	port, err := strconv.Atoi(parts[4])
   257  	if err != nil {
   258  		p.conn.Close()
   259  		return fmt.Errorf("invalid source port: %s", parts[4])
   260  	}
   261  	p.srcAddr = &net.TCPAddr{IP: ip, Port: port}
   262  
   263  	// Parse out the destination address
   264  	ip = net.ParseIP(parts[3])
   265  	if ip == nil {
   266  		p.conn.Close()
   267  		return fmt.Errorf("invalid destination ip: %s", parts[3])
   268  	}
   269  	port, err = strconv.Atoi(parts[5])
   270  	if err != nil {
   271  		p.conn.Close()
   272  		return fmt.Errorf("invalid destination port: %s", parts[5])
   273  	}
   274  	p.dstAddr = &net.TCPAddr{IP: ip, Port: port}
   275  
   276  	return nil
   277  }