github.com/blend/go-sdk@v1.20220411.3/proxyprotocol/dialer.go (about) 1 /* 2 3 Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package proxyprotocol 9 10 import ( 11 "context" 12 "net" 13 ) 14 15 // NewDialer returns a new proxy protocol dialer. 16 func NewDialer(opts ...DialerOption) *Dialer { 17 d := &Dialer{ 18 Dialer: new(net.Dialer), 19 HeaderProvider: func(_ context.Context, _ net.Conn) *Header { return nil }, 20 } 21 for _, opt := range opts { 22 opt(d) 23 } 24 return d 25 } 26 27 // OptDialerHeaderProvider sets the header provider. 28 func OptDialerHeaderProvider(provider func(context.Context, net.Conn) *Header) DialerOption { 29 return func(d *Dialer) { 30 d.HeaderProvider = provider 31 } 32 } 33 34 // OptDialerConstSourceAdddr sets the header provider to be a constant source. 35 func OptDialerConstSourceAdddr(addr net.Addr) DialerOption { 36 return func(d *Dialer) { 37 d.HeaderProvider = func(_ context.Context, conn net.Conn) *Header { 38 return &Header{ 39 Version: 1, 40 Command: ProtocolVersionAndCommandProxy, 41 TransportProtocol: AddressFamilyAndProtocolTCPv4, 42 SourceAddr: addr, 43 DestinationAddr: conn.RemoteAddr(), 44 } 45 } 46 } 47 } 48 49 // DialerOption mutates a dialer. 50 type DialerOption func(*Dialer) 51 52 // Dialer wraps a dialer with proxy protocol header injection. 53 type Dialer struct { 54 *net.Dialer 55 HeaderProvider func(context.Context, net.Conn) *Header 56 } 57 58 // Dial implements the dialer, calling `HeaderProvider` for a the context passed to it. 59 func (d *Dialer) Dial(network, addr string) (net.Conn, error) { 60 return d.DialContext(context.Background(), network, addr) 61 } 62 63 // DialContext implements the dialer, calling `HeaderProvider` for a the context passed to it. 64 func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { 65 conn, err := d.Dialer.DialContext(ctx, network, addr) 66 if err != nil { 67 return nil, err 68 } 69 70 header := d.HeaderProvider(ctx, conn) 71 if header == nil { 72 return conn, nil 73 } 74 _, err = header.WriteTo(conn) 75 if err != nil { 76 return nil, err 77 } 78 return conn, nil 79 }