github.com/Cloud-Foundations/Dominator@v0.3.4/lib/srpc/client.go (about)

     1  package srpc
     2  
     3  import (
     4  	"bufio"
     5  	"crypto/tls"
     6  	"errors"
     7  	"io"
     8  	"net"
     9  	"net/http"
    10  	"os"
    11  	"strings"
    12  	"sync"
    13  	"time"
    14  
    15  	libnet "github.com/Cloud-Foundations/Dominator/lib/net"
    16  	"github.com/Cloud-Foundations/tricorder/go/tricorder"
    17  	"github.com/Cloud-Foundations/tricorder/go/tricorder/units"
    18  )
    19  
// endpointType describes one SRPC endpoint a client may try to connect to.
type endpointType struct {
	// coderMaker builds the per-call encoder/decoder for this endpoint.
	coderMaker coderMaker
	// path is the HTTP path sent in the CONNECT request.
	path string
	// tls requests a TLS handshake after a successful CONNECT (unless the
	// whole connection is already TLS), and marks the client as encrypted.
	tls bool
}
    25  
var (
	// attemptTransportUpgrade enables the transport upgrade attempt made in
	// newClient.
	attemptTransportUpgrade = true // Changed by tests.
	// clientMetricsDir is the tricorder directory ("srpc/client") that holds
	// the client connection metrics; set by registerClientMetrics.
	clientMetricsDir *tricorder.DirectorySpec
	// clientMetricsMutex is held while updating the two counters below.
	clientMetricsMutex sync.Mutex
	// numInUseClientConnections backs the "num-in-use-connections" metric.
	numInUseClientConnections uint64
	// numOpenClientConnections backs the "num-open-connections" metric.
	numOpenClientConnections uint64
)
    33  
// init registers the SRPC client metrics when the package is loaded. A
// registration failure panics (see registerClientMetrics).
func init() {
	registerClientMetrics()
}
    37  
    38  func registerClientMetrics() {
    39  	var err error
    40  	clientMetricsDir, err = tricorder.RegisterDirectory("srpc/client")
    41  	if err != nil {
    42  		panic(err)
    43  	}
    44  	err = clientMetricsDir.RegisterMetric("num-in-use-connections",
    45  		&numInUseClientConnections, units.None,
    46  		"number of connections in use")
    47  	if err != nil {
    48  		panic(err)
    49  	}
    50  	err = clientMetricsDir.RegisterMetric("num-open-connections",
    51  		&numOpenClientConnections, units.None, "number of open connections")
    52  	if err != nil {
    53  		panic(err)
    54  	}
    55  }
    56  
    57  func dial(network, address string, dialer Dialer) (net.Conn, error) {
    58  	hostPort := strings.SplitN(address, ":", 2)
    59  	address = strings.SplitN(hostPort[0], "*", 2)[0] + ":" + hostPort[1]
    60  	conn, err := dialer.Dial(network, address)
    61  	if err != nil {
    62  		if strings.Contains(err.Error(), ErrorConnectionRefused.Error()) {
    63  			return nil, ErrorConnectionRefused
    64  		}
    65  		if strings.Contains(err.Error(), ErrorNoRouteToHost.Error()) {
    66  			return nil, ErrorNoRouteToHost
    67  		}
    68  		return nil, err
    69  	}
    70  	if tcpConn, ok := conn.(libnet.TCPConn); ok {
    71  		if err := tcpConn.SetKeepAlive(true); err != nil {
    72  			conn.Close()
    73  			return nil, err
    74  		}
    75  		if err := tcpConn.SetKeepAlivePeriod(time.Minute * 5); err != nil {
    76  			conn.Close()
    77  			return nil, err
    78  		}
    79  	}
    80  	return conn, nil
    81  }
    82  
    83  func dialHTTP(network, address string, tlsConfig *tls.Config,
    84  	dialer Dialer) (*Client, error) {
    85  	if *srpcProxy == "" {
    86  		return dialHTTPDirect(network, address, tlsConfig, dialer)
    87  	}
    88  	var err error
    89  	if d, ok := dialer.(*net.Dialer); ok {
    90  		dialer, err = newProxyDialer(*srpcProxy, d)
    91  	} else {
    92  		dialer, err = newProxyDialer(*srpcProxy, &net.Dialer{})
    93  	}
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  	return dialHTTPDirect(network, address, tlsConfig, dialer)
    98  }
    99  
   100  func dialHTTPDirect(network, address string, tlsConfig *tls.Config,
   101  	dialer Dialer) (*Client, error) {
   102  	insecureEndpoints := []endpointType{
   103  		{&gobCoder{}, rpcPath, false},
   104  		{&jsonCoder{}, jsonRpcPath, false},
   105  	}
   106  	secureEndpoints := []endpointType{
   107  		{&gobCoder{}, tlsRpcPath, true},
   108  		{&jsonCoder{}, jsonTlsRpcPath, true},
   109  	}
   110  	if tlsConfig == nil {
   111  		return dialHTTPEndpoints(network, address, nil, false, dialer,
   112  			insecureEndpoints)
   113  	} else {
   114  		var endpoints []endpointType
   115  		endpoints = append(endpoints, secureEndpoints...)
   116  		if tlsConfig.InsecureSkipVerify { // Don't have to trust server.
   117  			endpoints = append(endpoints, insecureEndpoints...)
   118  		}
   119  		client, err := dialHTTPEndpoints(network, address, tlsConfig, false,
   120  			dialer, endpoints)
   121  		if err != nil &&
   122  			strings.Contains(err.Error(), "malformed HTTP response") {
   123  			// The server may do TLS on all connections: try that.
   124  			return dialHTTPEndpoints(network, address, tlsConfig, true, dialer,
   125  				secureEndpoints)
   126  		}
   127  		return client, err
   128  	}
   129  }
   130  
// dialHTTPEndpoint makes a single connection attempt to one SRPC endpoint.
//
// It dials address and, when fullTLS is set (the server does TLS on all
// connections), performs the TLS handshake first. It then sends an HTTP
// CONNECT for endpoint.path. For a TLS endpoint without fullTLS, the TLS
// handshake is performed after the CONNECT succeeds. Handshake errors whose
// text contains ErrorBadCertificate are reported as that sentinel; CONNECT
// errors come from doHTTPConnect (e.g. ErrorNoSrpcEndpoint).
//
// On any failure the connection is closed by the deferred cleanup. On
// success the raw and data connections are handed to newClient.
func dialHTTPEndpoint(network, address string, tlsConfig *tls.Config,
	fullTLS bool, dialer Dialer, endpoint endpointType) (*Client, error) {
	unsecuredConn, err := dial(network, address, dialer)
	if err != nil {
		return nil, err
	}
	// dataConn is the connection SRPC traffic flows over; it is upgraded to
	// the TLS connection once a handshake succeeds.
	dataConn := unsecuredConn
	doClose := true
	defer func() {
		if doClose {
			dataConn.Close()
		}
	}()
	if fullTLS {
		// TLS before the CONNECT: the HTTP exchange itself is encrypted.
		tlsConn := tls.Client(unsecuredConn, tlsConfig)
		if err := tlsConn.Handshake(); err != nil {
			if strings.Contains(err.Error(), ErrorBadCertificate.Error()) {
				return nil, ErrorBadCertificate
			}
			return nil, err
		}
		dataConn = tlsConn
	}
	if err := doHTTPConnect(dataConn, endpoint.path); err != nil {
		return nil, err
	}
	if endpoint.tls && !fullTLS {
		// TLS after a plaintext CONNECT.
		tlsConn := tls.Client(unsecuredConn, tlsConfig)
		if err := tlsConn.Handshake(); err != nil {
			if strings.Contains(err.Error(), ErrorBadCertificate.Error()) {
				return nil, ErrorBadCertificate
			}
			return nil, err
		}
		dataConn = tlsConn
	}
	// From here newClient owns the connection; it closes the client itself
	// if its upgrade attempt fails.
	doClose = false
	return newClient(unsecuredConn, dataConn, endpoint.tls, endpoint.coderMaker)
}
   170  
   171  func dialHTTPEndpoints(network, address string, tlsConfig *tls.Config,
   172  	fullTLS bool, dialer Dialer, endpoints []endpointType) (*Client, error) {
   173  	for _, endpoint := range endpoints {
   174  		client, err := dialHTTPEndpoint(network, address, tlsConfig, fullTLS,
   175  			dialer, endpoint)
   176  		if err == nil {
   177  			return client, nil
   178  		}
   179  		if err != ErrorNoSrpcEndpoint {
   180  			return nil, err
   181  		}
   182  	}
   183  	return nil, ErrorNoSrpcEndpoint
   184  }
   185  
   186  func doHTTPConnect(conn net.Conn, path string) error {
   187  	var query string
   188  	if *srpcClientDoNotUseMethodPowers {
   189  		query = "?" + doNotUseMethodPowers + "=true"
   190  	}
   191  	io.WriteString(conn, "CONNECT "+path+query+" HTTP/1.0\n\n")
   192  	// Require successful HTTP response before switching to SRPC protocol.
   193  	resp, err := http.ReadResponse(bufio.NewReader(conn),
   194  		&http.Request{Method: "CONNECT"})
   195  	if err != nil {
   196  		return err
   197  	}
   198  	if resp.StatusCode == http.StatusNotFound {
   199  		return ErrorNoSrpcEndpoint
   200  	}
   201  	if resp.StatusCode == http.StatusUnauthorized {
   202  		return ErrorBadCertificate
   203  	}
   204  	if resp.StatusCode == http.StatusMethodNotAllowed {
   205  		return ErrorMissingCertificate
   206  	}
   207  	if resp.StatusCode != http.StatusOK || resp.Status != connectString {
   208  		return errors.New("unexpected HTTP response: " + resp.Status)
   209  	}
   210  	return nil
   211  }
   212  
   213  func getEarliestClientCertExpiration() time.Time {
   214  	var earliest time.Time
   215  	if clientTlsConfig == nil {
   216  		return earliest
   217  	}
   218  	for _, cert := range clientTlsConfig.Certificates {
   219  		if cert.Leaf != nil && !cert.Leaf.NotAfter.IsZero() {
   220  			if earliest.IsZero() {
   221  				earliest = cert.Leaf.NotAfter
   222  			} else if cert.Leaf.NotAfter.Before(earliest) {
   223  				earliest = cert.Leaf.NotAfter
   224  			}
   225  		}
   226  	}
   227  	return earliest
   228  }
   229  
// newClient wraps an established connection in a *Client and counts it in
// the open-connection metric.
//
// rawConn is the transport connection. Its local and remote addresses are
// recorded, and if it implements libnet.TCPConn it is retained (for
// keep-alive control) and the connection type is reported as "TCP".
// dataConn carries the SRPC traffic and is wrapped in a buffered
// reader/writer. When isEncrypted is set, "/TLS" is appended to the
// connection type. makeCoder builds each call's encoder/decoder.
//
// When attemptTransportUpgrade is set and no SRPC proxy is configured,
// localAttemptUpgradeToUnix is invoked. NOTE(review): that method is defined
// elsewhere; presumably it switches to a local Unix-domain transport when
// possible — confirm. If it fails, the client is closed and the error is
// returned.
func newClient(rawConn, dataConn net.Conn, isEncrypted bool,
	makeCoder coderMaker) (*Client, error) {
	// Counted before any failure path. NOTE(review): the failure path relies
	// on client.Close() undoing this increment; confirm Close calls close().
	clientMetricsMutex.Lock()
	numOpenClientConnections++
	clientMetricsMutex.Unlock()
	client := &Client{
		bufrw: bufio.NewReadWriter(bufio.NewReader(dataConn),
			bufio.NewWriter(dataConn)),
		conn:        dataConn,
		connType:    "unknown",
		localAddr:   rawConn.LocalAddr().String(),
		isEncrypted: isEncrypted,
		makeCoder:   makeCoder,
		remoteAddr:  rawConn.RemoteAddr().String(),
	}
	if tcpConn, ok := rawConn.(libnet.TCPConn); ok {
		client.tcpConn = tcpConn
		client.connType = "TCP"
	}
	if isEncrypted {
		client.connType += "/TLS"
	}
	if attemptTransportUpgrade && *srpcProxy == "" {
		oldBufrw := client.bufrw
		if _, err := client.localAttemptUpgradeToUnix(); err != nil {
			client.Close()
			return nil, err
		}
		// If the upgrade replaced client.conn but left the buffered
		// reader/writer bound to the old connection, rebuild it over the
		// new one.
		if client.conn != dataConn && client.bufrw == oldBufrw {
			logger.Debugf(0,
				"transport type: %s did not replace buffer, fixing\n",
				client.connType)
			client.bufrw = bufio.NewReadWriter(bufio.NewReader(client.conn),
				bufio.NewWriter(client.conn))
		}
	}
	logger.Debugf(0, "made %s connection to: %s\n",
		client.connType, client.remoteAddr)
	return client, nil
}
   270  
   271  func newFakeClient(options FakeClientOptions) *Client {
   272  	return &Client{fakeClientOptions: &options}
   273  }
   274  
   275  func (client *Client) call(serviceMethod string) (*Conn, error) {
   276  	if client.conn == nil {
   277  		panic("cannot call Client after Close()")
   278  	}
   279  	if client.resource != nil && !client.resource.inUse {
   280  		panic("cannot call Client after Close() or Put()")
   281  	}
   282  	client.callLock.Lock()
   283  	conn, err := client.callWithLock(serviceMethod)
   284  	if err != nil {
   285  		client.callLock.Unlock()
   286  	}
   287  	return conn, err
   288  }
   289  
   290  func (client *Client) callWithLock(serviceMethod string) (*Conn, error) {
   291  	_, err := client.bufrw.WriteString(serviceMethod + "\n")
   292  	if err != nil {
   293  		return nil, err
   294  	}
   295  	if err = client.bufrw.Flush(); err != nil {
   296  		return nil, err
   297  	}
   298  	resp, err := client.bufrw.ReadString('\n')
   299  	if err != nil {
   300  		return nil, err
   301  	}
   302  	if resp != "\n" {
   303  		resp := resp[:len(resp)-1]
   304  		if resp == ErrorAccessToMethodDenied.Error() {
   305  			return nil, ErrorAccessToMethodDenied
   306  		}
   307  		return nil, errors.New(resp)
   308  	}
   309  	conn := &Conn{
   310  		Decoder:     client.makeCoder.MakeDecoder(client.bufrw),
   311  		Encoder:     client.makeCoder.MakeEncoder(client.bufrw),
   312  		parent:      client,
   313  		isEncrypted: client.isEncrypted,
   314  		ReadWriter:  client.bufrw,
   315  	}
   316  	return conn, nil
   317  }
   318  
// close shuts the client down. It is a no-op for fake clients and returns
// os.ErrClosed if the client was already closed. Buffered output is flushed
// first; the flush error is ignored.
//
// Without an attached resource, the connection is closed directly, the
// open-connection metric is decremented and the Close error is returned.
// With a resource (NOTE(review): presumably a pooled connection — confirm),
// the resource is released instead. The in-use metric is decremented if the
// client was still marked in use, the open-connection metric is decremented,
// and resource.closeError is returned.
func (client *Client) close() error {
	if client.fakeClientOptions != nil {
		return nil
	}
	if client.conn == nil {
		return os.ErrClosed
	}
	client.bufrw.Flush()
	if client.resource == nil {
		clientMetricsMutex.Lock()
		numOpenClientConnections--
		clientMetricsMutex.Unlock()
		// Clear client.conn before closing so later calls see it as closed.
		conn := client.conn
		client.conn = nil
		return conn.Close()
	}
	client.resource.resource.Release()
	client.conn = nil
	clientMetricsMutex.Lock()
	if client.resource.inUse {
		numInUseClientConnections--
		client.resource.inUse = false
	}
	numOpenClientConnections--
	clientMetricsMutex.Unlock()
	return client.resource.closeError
}
   346  
   347  func (client *Client) ping() error {
   348  	conn, err := client.call("")
   349  	if err != nil {
   350  		return err
   351  	}
   352  	conn.Close()
   353  	return nil
   354  }
   355  
   356  func (client *Client) requestReply(serviceMethod string, request interface{},
   357  	reply interface{}) error {
   358  	conn, err := client.Call(serviceMethod)
   359  	if err != nil {
   360  		return err
   361  	}
   362  	defer conn.Close()
   363  	return conn.requestReply(request, reply)
   364  }
   365  
   366  func (conn *Conn) requestReply(request interface{}, reply interface{}) error {
   367  	if err := conn.Encode(request); err != nil {
   368  		return err
   369  	}
   370  	if err := conn.Flush(); err != nil {
   371  		return err
   372  	}
   373  	str, err := conn.ReadString('\n')
   374  	if err != nil {
   375  		return err
   376  	}
   377  	if str != "\n" {
   378  		return errors.New(str[:len(str)-1])
   379  	}
   380  	return conn.Decode(reply)
   381  }
   382  
   383  func (client *Client) setKeepAlive(keepalive bool) error {
   384  	if client.tcpConn == nil {
   385  		return nil
   386  	}
   387  	return client.tcpConn.SetKeepAlive(keepalive)
   388  }
   389  
   390  func (client *Client) setKeepAlivePeriod(d time.Duration) error {
   391  	if client.tcpConn == nil {
   392  		return nil
   393  	}
   394  	return client.tcpConn.SetKeepAlivePeriod(d)
   395  }