github.com/blend/go-sdk@v1.20220411.3/proxyprotocol/dialer.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package proxyprotocol
     9  
    10  import (
    11  	"context"
    12  	"net"
    13  )
    14  
    15  // NewDialer returns a new proxy protocol dialer.
    16  func NewDialer(opts ...DialerOption) *Dialer {
    17  	d := &Dialer{
    18  		Dialer:         new(net.Dialer),
    19  		HeaderProvider: func(_ context.Context, _ net.Conn) *Header { return nil },
    20  	}
    21  	for _, opt := range opts {
    22  		opt(d)
    23  	}
    24  	return d
    25  }
    26  
    27  // OptDialerHeaderProvider sets the header provider.
    28  func OptDialerHeaderProvider(provider func(context.Context, net.Conn) *Header) DialerOption {
    29  	return func(d *Dialer) {
    30  		d.HeaderProvider = provider
    31  	}
    32  }
    33  
    34  // OptDialerConstSourceAdddr sets the header provider to be a constant source.
    35  func OptDialerConstSourceAdddr(addr net.Addr) DialerOption {
    36  	return func(d *Dialer) {
    37  		d.HeaderProvider = func(_ context.Context, conn net.Conn) *Header {
    38  			return &Header{
    39  				Version:           1,
    40  				Command:           ProtocolVersionAndCommandProxy,
    41  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
    42  				SourceAddr:        addr,
    43  				DestinationAddr:   conn.RemoteAddr(),
    44  			}
    45  		}
    46  	}
    47  }
    48  
// DialerOption mutates a Dialer. NewDialer applies the options it is given
// one after another, in argument order.
type DialerOption func(*Dialer)
    51  
// Dialer wraps a net.Dialer with proxy protocol header injection: after a
// connection is established, the header returned by HeaderProvider (when
// non-nil) is written to it before the connection is returned to the caller.
type Dialer struct {
	// Dialer performs the underlying network dial.
	*net.Dialer
	// HeaderProvider builds the header for a freshly dialed connection. It is
	// called with the context given to DialContext (context.Background() when
	// called via Dial); returning nil skips writing a header.
	HeaderProvider func(context.Context, net.Conn) *Header
}
    57  
    58  // Dial implements the dialer, calling `HeaderProvider` for a the context passed to it.
    59  func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
    60  	return d.DialContext(context.Background(), network, addr)
    61  }
    62  
    63  // DialContext implements the dialer, calling `HeaderProvider` for a the context passed to it.
    64  func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
    65  	conn, err := d.Dialer.DialContext(ctx, network, addr)
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  
    70  	header := d.HeaderProvider(ctx, conn)
    71  	if header == nil {
    72  		return conn, nil
    73  	}
    74  	_, err = header.WriteTo(conn)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  	return conn, nil
    79  }