github.com/cloud-foundations/dominator@v0.0.0-20221004181915-6e4fee580046/lib/srpc/clientProxy.go (about)

     1  package srpc
     2  
     3  import (
     4  	"fmt"
     5  	"net"
     6  	"net/url"
     7  	"time"
     8  
     9  	"github.com/Cloud-Foundations/Dominator/lib/errors"
    10  	"github.com/Cloud-Foundations/Dominator/lib/net/proxy"
    11  	proto "github.com/Cloud-Foundations/Dominator/proto/proxy"
    12  )
    13  
    14  var (
    15  	errorUnsupportedTransport = errors.New("unsupported transport")
    16  	errorNotImplemented       = errors.New("not implemented")
    17  )
    18  
    19  type fakeAddress struct{}
    20  
    21  type proxyConn struct {
    22  	client *Client
    23  	conn   *Conn
    24  }
    25  
    26  type proxyDialer struct {
    27  	dialer       *net.Dialer
    28  	proxyAddress string
    29  }
    30  
    31  func newProxyDialer(proxyURL string, dialer *net.Dialer) (Dialer, error) {
    32  	if proxyURL == "" {
    33  		return dialer, nil
    34  	}
    35  	if parsedProxy, err := url.Parse(proxyURL); err != nil {
    36  		return nil, err
    37  	} else {
    38  		switch parsedProxy.Scheme {
    39  		case "srpc":
    40  			return &proxyDialer{
    41  				dialer:       dialer,
    42  				proxyAddress: parsedProxy.Host,
    43  			}, nil
    44  		default:
    45  			return proxy.NewDialer(proxyURL, dialer)
    46  		}
    47  	}
    48  }
    49  
    50  func (fakeAddress) Network() string {
    51  	return "tcp"
    52  }
    53  
    54  func (fakeAddress) String() string {
    55  	return "not-implemented"
    56  }
    57  
    58  func (d *proxyDialer) Dial(network, address string) (net.Conn, error) {
    59  	switch network {
    60  	case "tcp":
    61  		return d.dialTCP(address)
    62  	case "udp":
    63  	}
    64  	return nil, errorUnsupportedTransport
    65  }
    66  
    67  func (d *proxyDialer) dialTCP(address string) (net.Conn, error) {
    68  	client, err := dialHTTP("tcp", d.proxyAddress, clientTlsConfig, d.dialer)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	defer func() {
    73  		if client != nil {
    74  			client.Close()
    75  		}
    76  	}()
    77  	conn, err := client.Call("Proxy.Connect")
    78  	if err != nil {
    79  		return nil, err
    80  	}
    81  	defer func() {
    82  		if conn != nil {
    83  			conn.Close()
    84  		}
    85  	}()
    86  	err = conn.Encode(proto.ConnectRequest{
    87  		Address: address,
    88  		Network: "tcp",
    89  		Timeout: d.dialer.Timeout,
    90  	})
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  	if err := conn.Flush(); err != nil {
    95  		return nil, err
    96  	}
    97  	var response proto.ConnectResponse
    98  	if err := conn.Decode(&response); err != nil {
    99  		return nil, fmt.Errorf("error decoding: %s", err)
   100  	}
   101  	if err := errors.New(response.Error); err != nil {
   102  		return nil, err
   103  	}
   104  	proxiedConn := proxyConn{
   105  		client: client,
   106  		conn:   conn,
   107  	}
   108  	client = nil
   109  	conn = nil
   110  	return &proxiedConn, nil
   111  }
   112  
   113  func (pc *proxyConn) Close() error {
   114  	err1 := pc.conn.Close()
   115  	err2 := pc.client.Close()
   116  	if err1 != nil {
   117  		return err1
   118  	}
   119  	return err2
   120  }
   121  
   122  func (pc *proxyConn) LocalAddr() net.Addr {
   123  	return fakeAddress{}
   124  }
   125  
   126  func (pc *proxyConn) Read(b []byte) (int, error) {
   127  	return pc.conn.Read(b)
   128  }
   129  
   130  func (pc *proxyConn) RemoteAddr() net.Addr {
   131  	return fakeAddress{}
   132  }
   133  
   134  func (pc *proxyConn) SetDeadline(t time.Time) error {
   135  	return errorNotImplemented
   136  }
   137  
   138  func (pc *proxyConn) SetReadDeadline(t time.Time) error {
   139  	return errorNotImplemented
   140  }
   141  
   142  func (pc *proxyConn) SetWriteDeadline(t time.Time) error {
   143  	return errorNotImplemented
   144  }
   145  
   146  func (pc *proxyConn) Write(b []byte) (int, error) {
   147  	if nWritten, err := pc.conn.Write(b); err != nil {
   148  		pc.conn.Flush()
   149  		return nWritten, err
   150  	} else {
   151  		return nWritten, pc.conn.Flush()
   152  	}
   153  }