github.com/cloud-foundations/dominator@v0.0.0-20221004181915-6e4fee580046/lib/srpc/server.go (about)

     1  package srpc
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"errors"
     9  	"io"
    10  	"log"
    11  	"net/http"
    12  	"path/filepath"
    13  	"reflect"
    14  	"sort"
    15  	"strings"
    16  	"sync"
    17  	"time"
    18  
    19  	"github.com/Cloud-Foundations/Dominator/lib/net"
    20  	"github.com/Cloud-Foundations/Dominator/lib/x509util"
    21  	"github.com/Cloud-Foundations/tricorder/go/tricorder"
    22  	"github.com/Cloud-Foundations/tricorder/go/tricorder/units"
    23  )
    24  
// Protocol strings, HTTP endpoint paths and the method classification values
// produced by getMethod.
const (
	connectString   = "200 Connected to Go SRPC"
	rpcPath         = "/_goSRPC_/"      // Legacy endpoint. GOB coder.
	tlsRpcPath      = "/_go_TLS_SRPC_/" // Legacy endpoint. GOB coder.
	jsonRpcPath     = "/_SRPC_/unsecured/JSON"
	jsonTlsRpcPath  = "/_SRPC_/TLS/JSON"
	listMethodsPath = rpcPath + "listMethods"

	// iota counts every constant specification in this block, so
	// methodTypeRaw is 6 (not 0), followed by 7 and 8. The code in this file
	// only compares these values with each other.
	methodTypeRaw = iota
	methodTypeCoder
	methodTypeRequestReply
)
    37  
// methodWrapper describes one callable method of a registered receiver.
//
// methodType is one of the methodType* constants and selects how _call invokes
// fn, the bound method value obtained through reflection. public is set by
// registerName for methods listed in the receiver options; findMethod lets
// such methods through even when access was not granted. requestType and
// responseType are only filled in for request-reply methods. The counters and
// distributions are exported by registerMetrics; the two "RR" distributions
// are only created for request-reply methods.
type methodWrapper struct {
	methodType                    int
	public                        bool
	fn                            reflect.Value
	requestType                   reflect.Type
	responseType                  reflect.Type
	failedCallsDistribution       *tricorder.CumulativeDistribution
	failedRRCallsDistribution     *tricorder.CumulativeDistribution
	numDeniedCalls                uint64
	numPermittedCalls             uint64
	successfulCallsDistribution   *tricorder.CumulativeDistribution
	successfulRRCallsDistribution *tricorder.CumulativeDistribution
}
    51  
// receiverType holds everything registered under one service name.
//
// methods maps a method name to its wrapper. blockMethod is called by
// findMethod before a method is handed out; the function it returns is stored
// as the connection's release notifier. grantMethod is consulted by findMethod
// when the connection's own permitted-method list does not match.
type receiverType struct {
	methods     map[string]*methodWrapper
	blockMethod func(methodName string,
		authInfo *AuthInformation) (func(), error)
	grantMethod func(serviceMethod string, authInfo *AuthInformation) bool
}
    58  
// Package-level server state.
//
// defaultGrantMethod is the fallback granter called by defaultMethodGranter;
// as initialised here it denies every request. receivers maps a service name
// to its registered receiver: it is written by registerName and read by
// findMethod and listMethodsHttpHandler. serverMetricsDir and bucketer are set
// up by registerServerMetrics. serverMetricsMutex is held by httpHandler while
// it updates the three connection counters.
var (
	defaultGrantMethod = func(serviceMethod string,
		authInfo *AuthInformation) bool {
		return false
	}
	receivers                    map[string]receiverType = make(map[string]receiverType)
	serverMetricsDir             *tricorder.DirectorySpec
	bucketer                     *tricorder.Bucketer
	serverMetricsMutex           sync.Mutex
	numServerConnections         uint64
	numOpenServerConnections     uint64
	numRejectedServerConnections uint64
)
    72  
// Precompute some reflect types. The types cannot be used directly because
// reflect.TypeOf takes an empty interface value, so each one is obtained from
// a typed nil pointer followed by Elem. Note that typeOfConn is therefore the
// type *Conn, while the others are the named types themselves.
var typeOfConn = reflect.TypeOf((**Conn)(nil)).Elem()
var typeOfDecoder = reflect.TypeOf((*Decoder)(nil)).Elem()
var typeOfEncoder = reflect.TypeOf((*Encoder)(nil)).Elem()
var typeOfError = reflect.TypeOf((*error)(nil)).Elem()
    79  
    80  func init() {
    81  	http.HandleFunc(rpcPath, gobUnsecuredHttpHandler)
    82  	http.HandleFunc(tlsRpcPath, gobTlsHttpHandler)
    83  	http.HandleFunc(jsonRpcPath, jsonUnsecuredHttpHandler)
    84  	http.HandleFunc(jsonTlsRpcPath, jsonTlsHttpHandler)
    85  	http.HandleFunc(listMethodsPath, listMethodsHttpHandler)
    86  	registerServerMetrics()
    87  }
    88  
    89  func registerServerMetrics() {
    90  	var err error
    91  	serverMetricsDir, err = tricorder.RegisterDirectory("srpc/server")
    92  	if err != nil {
    93  		panic(err)
    94  	}
    95  	err = serverMetricsDir.RegisterMetric("num-connections",
    96  		&numServerConnections, units.None, "number of connection attempts")
    97  	if err != nil {
    98  		panic(err)
    99  	}
   100  	err = serverMetricsDir.RegisterMetric("num-open-connections",
   101  		&numOpenServerConnections, units.None, "number of open connections")
   102  	if err != nil {
   103  		panic(err)
   104  	}
   105  	err = serverMetricsDir.RegisterMetric("num-rejected-connections",
   106  		&numRejectedServerConnections, units.None,
   107  		"number of rejected connections")
   108  	if err != nil {
   109  		panic(err)
   110  	}
   111  	bucketer = tricorder.NewGeometricBucketer(0.1, 1e5)
   112  }
   113  
   114  func defaultMethodBlocker(methodName string,
   115  	authInfo *AuthInformation) (func(), error) {
   116  	return nil, nil
   117  }
   118  
   119  func defaultMethodGranter(serviceMethod string,
   120  	authInfo *AuthInformation) bool {
   121  	return defaultGrantMethod(serviceMethod, authInfo)
   122  }
   123  
   124  func registerName(name string, rcvr interface{},
   125  	options ReceiverOptions) error {
   126  	receiver := receiverType{methods: make(map[string]*methodWrapper)}
   127  	typeOfReceiver := reflect.TypeOf(rcvr)
   128  	valueOfReceiver := reflect.ValueOf(rcvr)
   129  	receiverMetricsDir, err := serverMetricsDir.RegisterDirectory(name)
   130  	if err != nil {
   131  		return err
   132  	}
   133  	publicMethods := make(map[string]struct{}, len(options.PublicMethods))
   134  	for _, methodName := range options.PublicMethods {
   135  		publicMethods[methodName] = struct{}{}
   136  	}
   137  	for index := 0; index < typeOfReceiver.NumMethod(); index++ {
   138  		method := typeOfReceiver.Method(index)
   139  		if method.PkgPath != "" { // Method must be exported.
   140  			continue
   141  		}
   142  		methodType := method.Type
   143  		mVal := getMethod(methodType, valueOfReceiver.Method(index))
   144  		if mVal == nil {
   145  			continue
   146  		}
   147  		receiver.methods[method.Name] = mVal
   148  		if _, ok := publicMethods[method.Name]; ok {
   149  			mVal.public = true
   150  		}
   151  		dir, err := receiverMetricsDir.RegisterDirectory(method.Name)
   152  		if err != nil {
   153  			return err
   154  		}
   155  		if err := mVal.registerMetrics(dir); err != nil {
   156  			return err
   157  		}
   158  	}
   159  	if blocker, ok := rcvr.(MethodBlocker); ok {
   160  		receiver.blockMethod = blocker.BlockMethod
   161  	} else {
   162  		receiver.blockMethod = defaultMethodBlocker
   163  	}
   164  	if granter, ok := rcvr.(MethodGranter); ok {
   165  		receiver.grantMethod = granter.GrantMethod
   166  	} else {
   167  		receiver.grantMethod = defaultMethodGranter
   168  	}
   169  	receivers[name] = receiver
   170  	return nil
   171  }
   172  
   173  func getMethod(methodType reflect.Type, fn reflect.Value) *methodWrapper {
   174  	if methodType.NumOut() != 1 {
   175  		return nil
   176  	}
   177  	if methodType.Out(0) != typeOfError {
   178  		return nil
   179  	}
   180  	if methodType.NumIn() == 2 {
   181  		// Method needs two ins: receiver, *Conn.
   182  		if methodType.In(1) != typeOfConn {
   183  			return nil
   184  		}
   185  		return &methodWrapper{methodType: methodTypeRaw, fn: fn}
   186  	}
   187  	if methodType.NumIn() == 4 {
   188  		if methodType.In(1) != typeOfConn {
   189  			return nil
   190  		}
   191  		// Coder Method needs four ins: receiver, *Conn, Decoder, Encoder.
   192  		if methodType.In(2) == typeOfDecoder &&
   193  			methodType.In(3) == typeOfEncoder {
   194  			return &methodWrapper{
   195  				methodType: methodTypeCoder,
   196  				fn:         fn,
   197  			}
   198  		}
   199  		// RequestReply Method needs four ins: receiver, *Conn, request, *reply.
   200  		if methodType.In(3).Kind() == reflect.Ptr {
   201  			return &methodWrapper{
   202  				methodType:   methodTypeRequestReply,
   203  				fn:           fn,
   204  				requestType:  methodType.In(2),
   205  				responseType: methodType.In(3).Elem(),
   206  			}
   207  		}
   208  	}
   209  	return nil
   210  }
   211  
   212  func (m *methodWrapper) registerMetrics(dir *tricorder.DirectorySpec) error {
   213  	m.failedCallsDistribution = bucketer.NewCumulativeDistribution()
   214  	err := dir.RegisterMetric("failed-call-durations",
   215  		m.failedCallsDistribution, units.Millisecond,
   216  		"duration of failed calls")
   217  	if err != nil {
   218  		return err
   219  	}
   220  	err = dir.RegisterMetric("num-denied-calls", &m.numDeniedCalls,
   221  		units.None, "number of denied calls to method")
   222  	if err != nil {
   223  		return err
   224  	}
   225  	err = dir.RegisterMetric("num-permitted-calls", &m.numPermittedCalls,
   226  		units.None, "number of permitted calls to method")
   227  	if err != nil {
   228  		return err
   229  	}
   230  	m.successfulCallsDistribution = bucketer.NewCumulativeDistribution()
   231  	err = dir.RegisterMetric("successful-call-durations",
   232  		m.successfulCallsDistribution, units.Millisecond,
   233  		"duration of successful calls")
   234  	if err != nil {
   235  		return err
   236  	}
   237  	if m.methodType != methodTypeRequestReply {
   238  		return nil
   239  	}
   240  	m.failedRRCallsDistribution = bucketer.NewCumulativeDistribution()
   241  	err = dir.RegisterMetric("failed-request-reply-call-durations",
   242  		m.failedRRCallsDistribution, units.Millisecond,
   243  		"duration of failed request-reply calls")
   244  	if err != nil {
   245  		return err
   246  	}
   247  	m.successfulRRCallsDistribution = bucketer.NewCumulativeDistribution()
   248  	err = dir.RegisterMetric("successful-request-reply-call-durations",
   249  		m.successfulRRCallsDistribution, units.Millisecond,
   250  		"duration of successful request-reply calls")
   251  	if err != nil {
   252  		return err
   253  	}
   254  	return nil
   255  }
   256  
   257  func gobTlsHttpHandler(w http.ResponseWriter, req *http.Request) {
   258  	httpHandler(w, req, true, &gobCoder{})
   259  }
   260  
   261  func gobUnsecuredHttpHandler(w http.ResponseWriter, req *http.Request) {
   262  	httpHandler(w, req, false, &gobCoder{})
   263  }
   264  
   265  func jsonTlsHttpHandler(w http.ResponseWriter, req *http.Request) {
   266  	httpHandler(w, req, true, &jsonCoder{})
   267  }
   268  
   269  func jsonUnsecuredHttpHandler(w http.ResponseWriter, req *http.Request) {
   270  	httpHandler(w, req, false, &jsonCoder{})
   271  }
   272  
   273  func httpHandler(w http.ResponseWriter, req *http.Request, doTls bool,
   274  	makeCoder coderMaker) {
   275  	serverMetricsMutex.Lock()
   276  	numServerConnections++
   277  	serverMetricsMutex.Unlock()
   278  	if doTls && serverTlsConfig == nil {
   279  		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
   280  		w.WriteHeader(http.StatusNotFound)
   281  		return
   282  	}
   283  	if (tlsRequired && !doTls) || req.Method != "CONNECT" {
   284  		serverMetricsMutex.Lock()
   285  		numRejectedServerConnections++
   286  		serverMetricsMutex.Unlock()
   287  		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
   288  		w.WriteHeader(http.StatusMethodNotAllowed)
   289  		return
   290  	}
   291  	if tlsRequired && req.TLS != nil {
   292  		if serverTlsConfig == nil ||
   293  			!checkVerifiedChains(req.TLS.VerifiedChains,
   294  				serverTlsConfig.ClientCAs) {
   295  			serverMetricsMutex.Lock()
   296  			numRejectedServerConnections++
   297  			serverMetricsMutex.Unlock()
   298  			w.Header().Set("Content-Type", "text/plain; charset=utf-8")
   299  			w.WriteHeader(http.StatusUnauthorized)
   300  			return
   301  		}
   302  	}
   303  	hijacker, ok := w.(http.Hijacker)
   304  	if !ok {
   305  		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
   306  		w.WriteHeader(http.StatusInternalServerError)
   307  		log.Println("not a hijacker ", req.RemoteAddr)
   308  		return
   309  	}
   310  	unsecuredConn, bufrw, err := hijacker.Hijack()
   311  	if err != nil {
   312  		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
   313  		w.WriteHeader(http.StatusInternalServerError)
   314  		log.Println("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
   315  		return
   316  	}
   317  	connToClose := unsecuredConn
   318  	defer func() {
   319  		connToClose.Close()
   320  	}()
   321  	if tcpConn, ok := unsecuredConn.(net.TCPConn); ok {
   322  		if err := tcpConn.SetKeepAlive(true); err != nil {
   323  			log.Println("error setting keepalive: ", err.Error())
   324  			return
   325  		}
   326  		if err := tcpConn.SetKeepAlivePeriod(time.Minute * 5); err != nil {
   327  			log.Println("error setting keepalive period: ", err.Error())
   328  			return
   329  		}
   330  	} else {
   331  		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
   332  		w.WriteHeader(http.StatusNotAcceptable)
   333  		log.Println("non-TCP connection")
   334  		return
   335  	}
   336  	_, err = io.WriteString(unsecuredConn, "HTTP/1.0 "+connectString+"\n\n")
   337  	if err != nil {
   338  		log.Println("error writing connect message: ", err.Error())
   339  		return
   340  	}
   341  	myConn := &Conn{remoteAddr: req.RemoteAddr}
   342  	if doTls {
   343  		var tlsConn *tls.Conn
   344  		if req.TLS == nil {
   345  			tlsConn = tls.Server(unsecuredConn, serverTlsConfig)
   346  			connToClose = tlsConn
   347  			if err := tlsConn.Handshake(); err != nil {
   348  				serverMetricsMutex.Lock()
   349  				numRejectedServerConnections++
   350  				serverMetricsMutex.Unlock()
   351  				log.Println(err)
   352  				return
   353  			}
   354  		} else {
   355  			if tlsConn, ok = unsecuredConn.(*tls.Conn); !ok {
   356  				log.Println("not really a TLS connection")
   357  				return
   358  			}
   359  		}
   360  		myConn.isEncrypted = true
   361  		myConn.username, myConn.permittedMethods, myConn.groupList, err =
   362  			getAuth(tlsConn.ConnectionState())
   363  		if err != nil {
   364  			log.Println(err)
   365  			return
   366  		}
   367  		myConn.ReadWriter = bufio.NewReadWriter(bufio.NewReader(tlsConn),
   368  			bufio.NewWriter(tlsConn))
   369  	} else {
   370  		myConn.ReadWriter = bufrw
   371  	}
   372  	serverMetricsMutex.Lock()
   373  	numOpenServerConnections++
   374  	serverMetricsMutex.Unlock()
   375  	handleConnection(myConn, makeCoder)
   376  	serverMetricsMutex.Lock()
   377  	numOpenServerConnections--
   378  	serverMetricsMutex.Unlock()
   379  }
   380  
   381  func checkVerifiedChains(verifiedChains [][]*x509.Certificate,
   382  	certPool *x509.CertPool) bool {
   383  	for _, vChain := range verifiedChains {
   384  		vSubject := vChain[0].RawIssuer
   385  		for _, cSubject := range certPool.Subjects() {
   386  			if bytes.Compare(vSubject, cSubject) == 0 {
   387  				return true
   388  			}
   389  		}
   390  	}
   391  	return false
   392  }
   393  
   394  func getAuth(state tls.ConnectionState) (string, map[string]struct{},
   395  	map[string]struct{}, error) {
   396  	var username string
   397  	permittedMethods := make(map[string]struct{})
   398  	trustCertMethods := false
   399  	if fullAuthCaCertPool == nil ||
   400  		checkVerifiedChains(state.VerifiedChains, fullAuthCaCertPool) {
   401  		trustCertMethods = true
   402  	}
   403  	var groupList map[string]struct{}
   404  	for _, certChain := range state.VerifiedChains {
   405  		for _, cert := range certChain {
   406  			var err error
   407  			if username == "" {
   408  				username, err = x509util.GetUsername(cert)
   409  				if err != nil {
   410  					return "", nil, nil, err
   411  				}
   412  			}
   413  			if len(groupList) < 1 {
   414  				groupList, err = x509util.GetGroupList(cert)
   415  				if err != nil {
   416  					return "", nil, nil, err
   417  				}
   418  			}
   419  			if trustCertMethods {
   420  				pms, err := x509util.GetPermittedMethods(cert)
   421  				if err != nil {
   422  					return "", nil, nil, err
   423  				}
   424  				for method := range pms {
   425  					permittedMethods[method] = struct{}{}
   426  				}
   427  			}
   428  		}
   429  	}
   430  	return username, permittedMethods, groupList, nil
   431  }
   432  
   433  func handleConnection(conn *Conn, makeCoder coderMaker) {
   434  	defer conn.callReleaseNotifier()
   435  	defer conn.Flush()
   436  	for ; ; conn.Flush() {
   437  		conn.callReleaseNotifier()
   438  		serviceMethod, err := conn.ReadString('\n')
   439  		if err == io.EOF || err == io.ErrUnexpectedEOF {
   440  			return
   441  		}
   442  		if err != nil {
   443  			log.Println(err)
   444  			if _, err := conn.WriteString(err.Error() + "\n"); err != nil {
   445  				log.Println(err)
   446  				return
   447  			}
   448  			continue
   449  		}
   450  		serviceMethod = strings.TrimSpace(serviceMethod)
   451  		if serviceMethod == "" {
   452  			// Received a "ping" request, send response.
   453  			if _, err := conn.WriteString("\n"); err != nil {
   454  				log.Println(err)
   455  				return
   456  			}
   457  			continue
   458  		}
   459  		method, err := conn.findMethod(serviceMethod)
   460  		if err != nil {
   461  			if _, err := conn.WriteString(err.Error() + "\n"); err != nil {
   462  				log.Println(err)
   463  				return
   464  			}
   465  			continue
   466  		}
   467  		// Method is OK to call. Tell client and then call method handler.
   468  		if _, err := conn.WriteString("\n"); err != nil {
   469  			log.Println(err)
   470  			return
   471  		}
   472  		if err := conn.Flush(); err != nil {
   473  			log.Println(err)
   474  			return
   475  		}
   476  		if err := method.call(conn, makeCoder); err != nil {
   477  			if err != ErrorCloseClient {
   478  				log.Println(err)
   479  			}
   480  			return
   481  		}
   482  	}
   483  }
   484  
   485  func (conn *Conn) callReleaseNotifier() {
   486  	if releaseNotifier := conn.releaseNotifier; releaseNotifier != nil {
   487  		releaseNotifier()
   488  	}
   489  	conn.releaseNotifier = nil
   490  }
   491  
   492  func (conn *Conn) findMethod(serviceMethod string) (*methodWrapper, error) {
   493  	splitServiceMethod := strings.Split(serviceMethod, ".")
   494  	if len(splitServiceMethod) != 2 {
   495  		return nil, errors.New("malformed Service.Method: " + serviceMethod)
   496  	}
   497  	serviceName := splitServiceMethod[0]
   498  	receiver, ok := receivers[serviceName]
   499  	if !ok {
   500  		return nil, errors.New("unknown service: " + serviceName)
   501  	}
   502  	methodName := splitServiceMethod[1]
   503  	method, ok := receiver.methods[methodName]
   504  	if !ok {
   505  		return nil, errors.New(serviceName + ": unknown method: " + methodName)
   506  	}
   507  	if conn.checkMethodAccess(serviceMethod) {
   508  		conn.haveMethodAccess = true
   509  	} else if receiver.grantMethod(serviceName, conn.GetAuthInformation()) {
   510  		conn.haveMethodAccess = true
   511  	} else {
   512  		conn.haveMethodAccess = false
   513  		if !method.public {
   514  			method.numDeniedCalls++
   515  			return nil, ErrorAccessToMethodDenied
   516  		}
   517  	}
   518  	authInfo := conn.GetAuthInformation()
   519  	if rn, err := receiver.blockMethod(methodName, authInfo); err != nil {
   520  		return nil, err
   521  	} else {
   522  		conn.releaseNotifier = rn
   523  	}
   524  	return method, nil
   525  }
   526  
   527  // Returns true if the method is permitted, else false if denied.
   528  func (conn *Conn) checkMethodAccess(serviceMethod string) bool {
   529  	if conn.permittedMethods == nil {
   530  		return true
   531  	}
   532  	for sm := range conn.permittedMethods {
   533  		if matched, _ := filepath.Match(sm, serviceMethod); matched {
   534  			return true
   535  		}
   536  	}
   537  	return false
   538  }
   539  
   540  func listMethodsHttpHandler(w http.ResponseWriter, req *http.Request) {
   541  	writer := bufio.NewWriter(w)
   542  	defer writer.Flush()
   543  	methods := make([]string, len(receivers))
   544  	for receiverName, receiver := range receivers {
   545  		for method := range receiver.methods {
   546  			methods = append(methods, receiverName+"."+method+"\n")
   547  		}
   548  	}
   549  	sort.Strings(methods)
   550  	for _, method := range methods {
   551  		writer.WriteString(method)
   552  	}
   553  }
   554  
   555  func (m *methodWrapper) call(conn *Conn, makeCoder coderMaker) error {
   556  	m.numPermittedCalls++
   557  	startTime := time.Now()
   558  	err := m._call(conn, makeCoder)
   559  	timeTaken := time.Since(startTime)
   560  	if err == nil {
   561  		m.successfulCallsDistribution.Add(timeTaken)
   562  	} else {
   563  		m.failedCallsDistribution.Add(timeTaken)
   564  	}
   565  	return err
   566  }
   567  
   568  func (m *methodWrapper) _call(conn *Conn, makeCoder coderMaker) error {
   569  	connValue := reflect.ValueOf(conn)
   570  	conn.Decoder = makeCoder.MakeDecoder(conn)
   571  	conn.Encoder = makeCoder.MakeEncoder(conn)
   572  	switch m.methodType {
   573  	case methodTypeRaw:
   574  		returnValues := m.fn.Call([]reflect.Value{connValue})
   575  		errInter := returnValues[0].Interface()
   576  		if errInter != nil {
   577  			return errInter.(error)
   578  		}
   579  		return nil
   580  	case methodTypeCoder:
   581  		returnValues := m.fn.Call([]reflect.Value{
   582  			connValue,
   583  			reflect.ValueOf(conn.Decoder),
   584  			reflect.ValueOf(conn.Encoder),
   585  		})
   586  		errInter := returnValues[0].Interface()
   587  		if errInter != nil {
   588  			return errInter.(error)
   589  		}
   590  		return nil
   591  	case methodTypeRequestReply:
   592  		request := reflect.New(m.requestType)
   593  		response := reflect.New(m.responseType)
   594  		if err := conn.Decode(request.Interface()); err != nil {
   595  			_, err = conn.WriteString(err.Error() + "\n")
   596  			return err
   597  		}
   598  		startTime := time.Now()
   599  		returnValues := m.fn.Call([]reflect.Value{connValue, request.Elem(),
   600  			response})
   601  		timeTaken := time.Since(startTime)
   602  		errInter := returnValues[0].Interface()
   603  		if errInter != nil {
   604  			m.failedRRCallsDistribution.Add(timeTaken)
   605  			err := errInter.(error)
   606  			_, err = conn.WriteString(err.Error() + "\n")
   607  			return err
   608  		}
   609  		m.successfulRRCallsDistribution.Add(timeTaken)
   610  		if _, err := conn.WriteString("\n"); err != nil {
   611  			return err
   612  		}
   613  		return conn.Encode(response.Interface())
   614  	}
   615  	return errors.New("unknown method type")
   616  }