github.com/Axway/agent-sdk@v1.1.101/pkg/util/dialer.go (about)

     1  package util
     2  
     3  import (
     4  	"bufio"
     5  	"context"
     6  	"encoding/base64"
     7  	"fmt"
     8  	"net"
     9  	"net/http"
    10  	"net/url"
    11  	"time"
    12  
    13  	"github.com/Axway/agent-sdk/pkg/util/log"
    14  	"golang.org/x/net/proxy"
    15  )
    16  
    17  const (
    18  	// DefaultKeepAliveInterval - default duration to send keep alive pings
    19  	DefaultKeepAliveInterval = 50 * time.Second
    20  	// DefaultKeepAliveTimeout - default keepalive timeout
    21  	DefaultKeepAliveTimeout = 10 * time.Second
    22  )
    23  
    24  // Dialer - interface for http dialer for proxy and single entry point
    25  type Dialer interface {
    26  	// Dial - interface used by libbeat for tcp network dial
    27  	Dial(network string, addr string) (net.Conn, error)
    28  	// DialContext - interface used by http transport
    29  	DialContext(ctx context.Context, network string, addr string) (net.Conn, error)
    30  	// GetProxyScheme() string
    31  	GetProxyScheme() string
    32  }
    33  
    34  type dialer struct {
    35  	singleEntryHostMap map[string]string
    36  	proxyScheme        string
    37  	proxyAddress       string
    38  	userName           string
    39  	password           string
    40  }
    41  
    42  // NewDialer - creates a new dialer
    43  func NewDialer(proxyURL *url.URL, singleEntryHostMap map[string]string) Dialer {
    44  	dialer := &dialer{
    45  		singleEntryHostMap: singleEntryHostMap,
    46  	}
    47  	if proxyURL != nil {
    48  		dialer.proxyScheme = proxyURL.Scheme
    49  		dialer.proxyAddress = proxyURL.Host
    50  		if user := proxyURL.User; user != nil {
    51  			dialer.userName = user.Username()
    52  			dialer.password, _ = user.Password()
    53  		}
    54  	}
    55  	if dialer.singleEntryHostMap == nil {
    56  		dialer.singleEntryHostMap = map[string]string{}
    57  	}
    58  	return dialer
    59  }
    60  
    61  // Dial- manages the connections to proxy and single entry point for tcp transports
    62  func (d *dialer) Dial(network string, addr string) (net.Conn, error) {
    63  	conn, err := d.DialContext(context.Background(), network, addr)
    64  	if err == nil && len(d.singleEntryHostMap) > 0 && addr != conn.RemoteAddr().String() {
    65  		log.Tracef("routing the traffic for %s via %s", addr, conn.RemoteAddr().String())
    66  	}
    67  	return conn, err
    68  }
    69  
    70  // DialContext - manages the connections to proxy and single entry point
    71  func (d *dialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
    72  	originalAddr := addr
    73  	singleEntryHost, ok := d.singleEntryHostMap[addr]
    74  	if ok {
    75  		addr = singleEntryHost
    76  	}
    77  	if d.proxyAddress != "" {
    78  		switch d.proxyScheme {
    79  		case "socks5", "socks5h":
    80  			return d.socksConnect(network, originalAddr, singleEntryHost)
    81  		case "http", "https":
    82  		default:
    83  			return nil, fmt.Errorf("could not setup proxy, unsupported proxy scheme %s", d.proxyScheme)
    84  		}
    85  		addr = d.proxyAddress
    86  	}
    87  	conn, err := (&net.Dialer{
    88  		Timeout:   DefaultKeepAliveTimeout,
    89  		KeepAlive: DefaultKeepAliveInterval,
    90  		DualStack: true}).DialContext(ctx, network, addr)
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  	if d.proxyAddress != "" {
    95  		switch d.proxyScheme {
    96  		case "http", "https":
    97  			err = d.httpConnect(ctx, conn, originalAddr, singleEntryHost)
    98  			if err != nil {
    99  				conn.Close()
   100  				return nil, err
   101  			}
   102  		}
   103  	}
   104  	return conn, nil
   105  }
   106  
   107  func (d *dialer) GetProxyScheme() string {
   108  	if d.proxyAddress != "" {
   109  		return d.proxyScheme
   110  	}
   111  	return ""
   112  }
   113  
   114  func (d *dialer) socksConnect(network, addr, singleEntryHost string) (net.Conn, error) {
   115  	var auth *proxy.Auth
   116  	if d.userName != "" {
   117  		auth = new(proxy.Auth)
   118  		auth.User = d.userName
   119  		if d.password != "" {
   120  			auth.Password = d.password
   121  		}
   122  	}
   123  	socksDialer, err := proxy.SOCKS5(network, d.proxyAddress, auth, nil)
   124  	if err != nil {
   125  		return nil, err
   126  	}
   127  	targetAddr := addr
   128  	if singleEntryHost != "" {
   129  		targetAddr = singleEntryHost
   130  	}
   131  	return socksDialer.Dial(network, targetAddr)
   132  }
   133  
   134  func (d *dialer) httpConnect(ctx context.Context, conn net.Conn, targetAddr, sniHost string) error {
   135  	req := d.createConnectRequest(ctx, targetAddr, sniHost)
   136  	if err := req.Write(conn); err != nil {
   137  		return err
   138  	}
   139  
   140  	r := bufio.NewReader(conn)
   141  	resp, err := http.ReadResponse(r, req)
   142  	if err != nil {
   143  		return err
   144  	}
   145  	defer resp.Body.Close()
   146  	if resp.StatusCode != http.StatusOK {
   147  		return fmt.Errorf("failed to connect proxy, status : %s", resp.Status)
   148  	}
   149  	return nil
   150  }
   151  
   152  func (d *dialer) createConnectRequest(ctx context.Context, targetAddress, sniHost string) *http.Request {
   153  	req := &http.Request{
   154  		Method: http.MethodConnect,
   155  		URL:    &url.URL{Opaque: targetAddress},
   156  		Host:   targetAddress,
   157  	}
   158  	if sniHost != "" {
   159  		req.URL = &url.URL{Opaque: sniHost}
   160  	}
   161  
   162  	if d.userName != "" {
   163  		token := base64.StdEncoding.EncodeToString([]byte(d.userName + ":" + d.password))
   164  		req.Header = map[string][]string{
   165  			"Proxy-Authorization": {"Basic " + token},
   166  		}
   167  	}
   168  	return req.WithContext(ctx)
   169  }