github.com/cloud-foundations/dominator@v0.0.0-20221004181915-6e4fee580046/lib/srpc/server.go (about) 1 package srpc 2 3 import ( 4 "bufio" 5 "bytes" 6 "crypto/tls" 7 "crypto/x509" 8 "errors" 9 "io" 10 "log" 11 "net/http" 12 "path/filepath" 13 "reflect" 14 "sort" 15 "strings" 16 "sync" 17 "time" 18 19 "github.com/Cloud-Foundations/Dominator/lib/net" 20 "github.com/Cloud-Foundations/Dominator/lib/x509util" 21 "github.com/Cloud-Foundations/tricorder/go/tricorder" 22 "github.com/Cloud-Foundations/tricorder/go/tricorder/units" 23 ) 24 25 const ( 26 connectString = "200 Connected to Go SRPC" 27 rpcPath = "/_goSRPC_/" // Legacy endpoint. GOB coder. 28 tlsRpcPath = "/_go_TLS_SRPC_/" // Legacy endpoint. GOB coder. 29 jsonRpcPath = "/_SRPC_/unsecured/JSON" 30 jsonTlsRpcPath = "/_SRPC_/TLS/JSON" 31 listMethodsPath = rpcPath + "listMethods" 32 33 methodTypeRaw = iota 34 methodTypeCoder 35 methodTypeRequestReply 36 ) 37 38 type methodWrapper struct { 39 methodType int 40 public bool 41 fn reflect.Value 42 requestType reflect.Type 43 responseType reflect.Type 44 failedCallsDistribution *tricorder.CumulativeDistribution 45 failedRRCallsDistribution *tricorder.CumulativeDistribution 46 numDeniedCalls uint64 47 numPermittedCalls uint64 48 successfulCallsDistribution *tricorder.CumulativeDistribution 49 successfulRRCallsDistribution *tricorder.CumulativeDistribution 50 } 51 52 type receiverType struct { 53 methods map[string]*methodWrapper 54 blockMethod func(methodName string, 55 authInfo *AuthInformation) (func(), error) 56 grantMethod func(serviceMethod string, authInfo *AuthInformation) bool 57 } 58 59 var ( 60 defaultGrantMethod = func(serviceMethod string, 61 authInfo *AuthInformation) bool { 62 return false 63 } 64 receivers map[string]receiverType = make(map[string]receiverType) 65 serverMetricsDir *tricorder.DirectorySpec 66 bucketer *tricorder.Bucketer 67 serverMetricsMutex sync.Mutex 68 numServerConnections uint64 69 numOpenServerConnections uint64 70 numRejectedServerConnections uint64 71 ) 72 73 // Precompute some reflect types. Can't use the types directly because Typeof 74 // takes an empty interface value. This is annoying. 75 var typeOfConn = reflect.TypeOf((**Conn)(nil)).Elem() 76 var typeOfDecoder = reflect.TypeOf((*Decoder)(nil)).Elem() 77 var typeOfEncoder = reflect.TypeOf((*Encoder)(nil)).Elem() 78 var typeOfError = reflect.TypeOf((*error)(nil)).Elem() 79 80 func init() { 81 http.HandleFunc(rpcPath, gobUnsecuredHttpHandler) 82 http.HandleFunc(tlsRpcPath, gobTlsHttpHandler) 83 http.HandleFunc(jsonRpcPath, jsonUnsecuredHttpHandler) 84 http.HandleFunc(jsonTlsRpcPath, jsonTlsHttpHandler) 85 http.HandleFunc(listMethodsPath, listMethodsHttpHandler) 86 registerServerMetrics() 87 } 88 89 func registerServerMetrics() { 90 var err error 91 serverMetricsDir, err = tricorder.RegisterDirectory("srpc/server") 92 if err != nil { 93 panic(err) 94 } 95 err = serverMetricsDir.RegisterMetric("num-connections", 96 &numServerConnections, units.None, "number of connection attempts") 97 if err != nil { 98 panic(err) 99 } 100 err = serverMetricsDir.RegisterMetric("num-open-connections", 101 &numOpenServerConnections, units.None, "number of open connections") 102 if err != nil { 103 panic(err) 104 } 105 err = serverMetricsDir.RegisterMetric("num-rejected-connections", 106 &numRejectedServerConnections, units.None, 107 "number of rejected connections") 108 if err != nil { 109 panic(err) 110 } 111 bucketer = tricorder.NewGeometricBucketer(0.1, 1e5) 112 } 113 114 func defaultMethodBlocker(methodName string, 115 authInfo *AuthInformation) (func(), error) { 116 return nil, nil 117 } 118 119 func defaultMethodGranter(serviceMethod string, 120 authInfo *AuthInformation) bool { 121 return defaultGrantMethod(serviceMethod, authInfo) 122 } 123 124 func registerName(name string, rcvr interface{}, 125 options ReceiverOptions) error { 126 receiver := receiverType{methods: make(map[string]*methodWrapper)} 127 typeOfReceiver := reflect.TypeOf(rcvr) 128 valueOfReceiver := reflect.ValueOf(rcvr) 129 receiverMetricsDir, err := serverMetricsDir.RegisterDirectory(name) 130 if err != nil { 131 return err 132 } 133 publicMethods := make(map[string]struct{}, len(options.PublicMethods)) 134 for _, methodName := range options.PublicMethods { 135 publicMethods[methodName] = struct{}{} 136 } 137 for index := 0; index < typeOfReceiver.NumMethod(); index++ { 138 method := typeOfReceiver.Method(index) 139 if method.PkgPath != "" { // Method must be exported. 140 continue 141 } 142 methodType := method.Type 143 mVal := getMethod(methodType, valueOfReceiver.Method(index)) 144 if mVal == nil { 145 continue 146 } 147 receiver.methods[method.Name] = mVal 148 if _, ok := publicMethods[method.Name]; ok { 149 mVal.public = true 150 } 151 dir, err := receiverMetricsDir.RegisterDirectory(method.Name) 152 if err != nil { 153 return err 154 } 155 if err := mVal.registerMetrics(dir); err != nil { 156 return err 157 } 158 } 159 if blocker, ok := rcvr.(MethodBlocker); ok { 160 receiver.blockMethod = blocker.BlockMethod 161 } else { 162 receiver.blockMethod = defaultMethodBlocker 163 } 164 if granter, ok := rcvr.(MethodGranter); ok { 165 receiver.grantMethod = granter.GrantMethod 166 } else { 167 receiver.grantMethod = defaultMethodGranter 168 } 169 receivers[name] = receiver 170 return nil 171 } 172 173 func getMethod(methodType reflect.Type, fn reflect.Value) *methodWrapper { 174 if methodType.NumOut() != 1 { 175 return nil 176 } 177 if methodType.Out(0) != typeOfError { 178 return nil 179 } 180 if methodType.NumIn() == 2 { 181 // Method needs two ins: receiver, *Conn. 182 if methodType.In(1) != typeOfConn { 183 return nil 184 } 185 return &methodWrapper{methodType: methodTypeRaw, fn: fn} 186 } 187 if methodType.NumIn() == 4 { 188 if methodType.In(1) != typeOfConn { 189 return nil 190 } 191 // Coder Method needs four ins: receiver, *Conn, Decoder, Encoder. 192 if methodType.In(2) == typeOfDecoder && 193 methodType.In(3) == typeOfEncoder { 194 return &methodWrapper{ 195 methodType: methodTypeCoder, 196 fn: fn, 197 } 198 } 199 // RequestReply Method needs four ins: receiver, *Conn, request, *reply. 200 if methodType.In(3).Kind() == reflect.Ptr { 201 return &methodWrapper{ 202 methodType: methodTypeRequestReply, 203 fn: fn, 204 requestType: methodType.In(2), 205 responseType: methodType.In(3).Elem(), 206 } 207 } 208 } 209 return nil 210 } 211 212 func (m *methodWrapper) registerMetrics(dir *tricorder.DirectorySpec) error { 213 m.failedCallsDistribution = bucketer.NewCumulativeDistribution() 214 err := dir.RegisterMetric("failed-call-durations", 215 m.failedCallsDistribution, units.Millisecond, 216 "duration of failed calls") 217 if err != nil { 218 return err 219 } 220 err = dir.RegisterMetric("num-denied-calls", &m.numDeniedCalls, 221 units.None, "number of denied calls to method") 222 if err != nil { 223 return err 224 } 225 err = dir.RegisterMetric("num-permitted-calls", &m.numPermittedCalls, 226 units.None, "number of permitted calls to method") 227 if err != nil { 228 return err 229 } 230 m.successfulCallsDistribution = bucketer.NewCumulativeDistribution() 231 err = dir.RegisterMetric("successful-call-durations", 232 m.successfulCallsDistribution, units.Millisecond, 233 "duration of successful calls") 234 if err != nil { 235 return err 236 } 237 if m.methodType != methodTypeRequestReply { 238 return nil 239 } 240 m.failedRRCallsDistribution = bucketer.NewCumulativeDistribution() 241 err = dir.RegisterMetric("failed-request-reply-call-durations", 242 m.failedRRCallsDistribution, units.Millisecond, 243 "duration of failed request-reply calls") 244 if err != nil { 245 return err 246 } 247 m.successfulRRCallsDistribution = bucketer.NewCumulativeDistribution() 248 err = dir.RegisterMetric("successful-request-reply-call-durations", 249 m.successfulRRCallsDistribution, units.Millisecond, 250 "duration of successful request-reply calls") 251 if err != nil { 252 return err 253 } 254 return nil 255 } 256 257 func gobTlsHttpHandler(w http.ResponseWriter, req *http.Request) { 258 httpHandler(w, req, true, &gobCoder{}) 259 } 260 261 func gobUnsecuredHttpHandler(w http.ResponseWriter, req *http.Request) { 262 httpHandler(w, req, false, &gobCoder{}) 263 } 264 265 func jsonTlsHttpHandler(w http.ResponseWriter, req *http.Request) { 266 httpHandler(w, req, true, &jsonCoder{}) 267 } 268 269 func jsonUnsecuredHttpHandler(w http.ResponseWriter, req *http.Request) { 270 httpHandler(w, req, false, &jsonCoder{}) 271 } 272 273 func httpHandler(w http.ResponseWriter, req *http.Request, doTls bool, 274 makeCoder coderMaker) { 275 serverMetricsMutex.Lock() 276 numServerConnections++ 277 serverMetricsMutex.Unlock() 278 if doTls && serverTlsConfig == nil { 279 w.Header().Set("Content-Type", "text/plain; charset=utf-8") 280 w.WriteHeader(http.StatusNotFound) 281 return 282 } 283 if (tlsRequired && !doTls) || req.Method != "CONNECT" { 284 serverMetricsMutex.Lock() 285 numRejectedServerConnections++ 286 serverMetricsMutex.Unlock() 287 w.Header().Set("Content-Type", "text/plain; charset=utf-8") 288 w.WriteHeader(http.StatusMethodNotAllowed) 289 return 290 } 291 if tlsRequired && req.TLS != nil { 292 if serverTlsConfig == nil || 293 !checkVerifiedChains(req.TLS.VerifiedChains, 294 serverTlsConfig.ClientCAs) { 295 serverMetricsMutex.Lock() 296 numRejectedServerConnections++ 297 serverMetricsMutex.Unlock() 298 w.Header().Set("Content-Type", "text/plain; charset=utf-8") 299 w.WriteHeader(http.StatusUnauthorized) 300 return 301 } 302 } 303 hijacker, ok := w.(http.Hijacker) 304 if !ok { 305 w.Header().Set("Content-Type", "text/plain; charset=utf-8") 306 w.WriteHeader(http.StatusInternalServerError) 307 log.Println("not a hijacker ", req.RemoteAddr) 308 return 309 } 310 unsecuredConn, bufrw, err := hijacker.Hijack() 311 if err != nil { 312 w.Header().Set("Content-Type", "text/plain; charset=utf-8") 313 w.WriteHeader(http.StatusInternalServerError) 314 log.Println("rpc hijacking ", req.RemoteAddr, ": ", err.Error()) 315 return 316 } 317 connToClose := unsecuredConn 318 defer func() { 319 connToClose.Close() 320 }() 321 if tcpConn, ok := unsecuredConn.(net.TCPConn); ok { 322 if err := tcpConn.SetKeepAlive(true); err != nil { 323 log.Println("error setting keepalive: ", err.Error()) 324 return 325 } 326 if err := tcpConn.SetKeepAlivePeriod(time.Minute * 5); err != nil { 327 log.Println("error setting keepalive period: ", err.Error()) 328 return 329 } 330 } else { 331 w.Header().Set("Content-Type", "text/plain; charset=utf-8") 332 w.WriteHeader(http.StatusNotAcceptable) 333 log.Println("non-TCP connection") 334 return 335 } 336 _, err = io.WriteString(unsecuredConn, "HTTP/1.0 "+connectString+"\n\n") 337 if err != nil { 338 log.Println("error writing connect message: ", err.Error()) 339 return 340 } 341 myConn := &Conn{remoteAddr: req.RemoteAddr} 342 if doTls { 343 var tlsConn *tls.Conn 344 if req.TLS == nil { 345 tlsConn = tls.Server(unsecuredConn, serverTlsConfig) 346 connToClose = tlsConn 347 if err := tlsConn.Handshake(); err != nil { 348 serverMetricsMutex.Lock() 349 numRejectedServerConnections++ 350 serverMetricsMutex.Unlock() 351 log.Println(err) 352 return 353 } 354 } else { 355 if tlsConn, ok = unsecuredConn.(*tls.Conn); !ok { 356 log.Println("not really a TLS connection") 357 return 358 } 359 } 360 myConn.isEncrypted = true 361 myConn.username, myConn.permittedMethods, myConn.groupList, err = 362 getAuth(tlsConn.ConnectionState()) 363 if err != nil { 364 log.Println(err) 365 return 366 } 367 myConn.ReadWriter = bufio.NewReadWriter(bufio.NewReader(tlsConn), 368 bufio.NewWriter(tlsConn)) 369 } else { 370 myConn.ReadWriter = bufrw 371 } 372 serverMetricsMutex.Lock() 373 numOpenServerConnections++ 374 serverMetricsMutex.Unlock() 375 handleConnection(myConn, makeCoder) 376 serverMetricsMutex.Lock() 377 numOpenServerConnections-- 378 serverMetricsMutex.Unlock() 379 } 380 381 func checkVerifiedChains(verifiedChains [][]*x509.Certificate, 382 certPool *x509.CertPool) bool { 383 for _, vChain := range verifiedChains { 384 vSubject := vChain[0].RawIssuer 385 for _, cSubject := range certPool.Subjects() { 386 if bytes.Compare(vSubject, cSubject) == 0 { 387 return true 388 } 389 } 390 } 391 return false 392 } 393 394 func getAuth(state tls.ConnectionState) (string, map[string]struct{}, 395 map[string]struct{}, error) { 396 var username string 397 permittedMethods := make(map[string]struct{}) 398 trustCertMethods := false 399 if fullAuthCaCertPool == nil || 400 checkVerifiedChains(state.VerifiedChains, fullAuthCaCertPool) { 401 trustCertMethods = true 402 } 403 var groupList map[string]struct{} 404 for _, certChain := range state.VerifiedChains { 405 for _, cert := range certChain { 406 var err error 407 if username == "" { 408 username, err = x509util.GetUsername(cert) 409 if err != nil { 410 return "", nil, nil, err 411 } 412 } 413 if len(groupList) < 1 { 414 groupList, err = x509util.GetGroupList(cert) 415 if err != nil { 416 return "", nil, nil, err 417 } 418 } 419 if trustCertMethods { 420 pms, err := x509util.GetPermittedMethods(cert) 421 if err != nil { 422 return "", nil, nil, err 423 } 424 for method := range pms { 425 permittedMethods[method] = struct{}{} 426 } 427 } 428 } 429 } 430 return username, permittedMethods, groupList, nil 431 } 432 433 func handleConnection(conn *Conn, makeCoder coderMaker) { 434 defer conn.callReleaseNotifier() 435 defer conn.Flush() 436 for ; ; conn.Flush() { 437 conn.callReleaseNotifier() 438 serviceMethod, err := conn.ReadString('\n') 439 if err == io.EOF || err == io.ErrUnexpectedEOF { 440 return 441 } 442 if err != nil { 443 log.Println(err) 444 if _, err := conn.WriteString(err.Error() + "\n"); err != nil { 445 log.Println(err) 446 return 447 } 448 continue 449 } 450 serviceMethod = strings.TrimSpace(serviceMethod) 451 if serviceMethod == "" { 452 // Received a "ping" request, send response. 453 if _, err := conn.WriteString("\n"); err != nil { 454 log.Println(err) 455 return 456 } 457 continue 458 } 459 method, err := conn.findMethod(serviceMethod) 460 if err != nil { 461 if _, err := conn.WriteString(err.Error() + "\n"); err != nil { 462 log.Println(err) 463 return 464 } 465 continue 466 } 467 // Method is OK to call. Tell client and then call method handler. 468 if _, err := conn.WriteString("\n"); err != nil { 469 log.Println(err) 470 return 471 } 472 if err := conn.Flush(); err != nil { 473 log.Println(err) 474 return 475 } 476 if err := method.call(conn, makeCoder); err != nil { 477 if err != ErrorCloseClient { 478 log.Println(err) 479 } 480 return 481 } 482 } 483 } 484 485 func (conn *Conn) callReleaseNotifier() { 486 if releaseNotifier := conn.releaseNotifier; releaseNotifier != nil { 487 releaseNotifier() 488 } 489 conn.releaseNotifier = nil 490 } 491 492 func (conn *Conn) findMethod(serviceMethod string) (*methodWrapper, error) { 493 splitServiceMethod := strings.Split(serviceMethod, ".") 494 if len(splitServiceMethod) != 2 { 495 return nil, errors.New("malformed Service.Method: " + serviceMethod) 496 } 497 serviceName := splitServiceMethod[0] 498 receiver, ok := receivers[serviceName] 499 if !ok { 500 return nil, errors.New("unknown service: " + serviceName) 501 } 502 methodName := splitServiceMethod[1] 503 method, ok := receiver.methods[methodName] 504 if !ok { 505 return nil, errors.New(serviceName + ": unknown method: " + methodName) 506 } 507 if conn.checkMethodAccess(serviceMethod) { 508 conn.haveMethodAccess = true 509 } else if receiver.grantMethod(serviceName, conn.GetAuthInformation()) { 510 conn.haveMethodAccess = true 511 } else { 512 conn.haveMethodAccess = false 513 if !method.public { 514 method.numDeniedCalls++ 515 return nil, ErrorAccessToMethodDenied 516 } 517 } 518 authInfo := conn.GetAuthInformation() 519 if rn, err := receiver.blockMethod(methodName, authInfo); err != nil { 520 return nil, err 521 } else { 522 conn.releaseNotifier = rn 523 } 524 return method, nil 525 } 526 527 // Returns true if the method is permitted, else false if denied. 528 func (conn *Conn) checkMethodAccess(serviceMethod string) bool { 529 if conn.permittedMethods == nil { 530 return true 531 } 532 for sm := range conn.permittedMethods { 533 if matched, _ := filepath.Match(sm, serviceMethod); matched { 534 return true 535 } 536 } 537 return false 538 } 539 540 func listMethodsHttpHandler(w http.ResponseWriter, req *http.Request) { 541 writer := bufio.NewWriter(w) 542 defer writer.Flush() 543 methods := make([]string, len(receivers)) 544 for receiverName, receiver := range receivers { 545 for method := range receiver.methods { 546 methods = append(methods, receiverName+"."+method+"\n") 547 } 548 } 549 sort.Strings(methods) 550 for _, method := range methods { 551 writer.WriteString(method) 552 } 553 } 554 555 func (m *methodWrapper) call(conn *Conn, makeCoder coderMaker) error { 556 m.numPermittedCalls++ 557 startTime := time.Now() 558 err := m._call(conn, makeCoder) 559 timeTaken := time.Since(startTime) 560 if err == nil { 561 m.successfulCallsDistribution.Add(timeTaken) 562 } else { 563 m.failedCallsDistribution.Add(timeTaken) 564 } 565 return err 566 } 567 568 func (m *methodWrapper) _call(conn *Conn, makeCoder coderMaker) error { 569 connValue := reflect.ValueOf(conn) 570 conn.Decoder = makeCoder.MakeDecoder(conn) 571 conn.Encoder = makeCoder.MakeEncoder(conn) 572 switch m.methodType { 573 case methodTypeRaw: 574 returnValues := m.fn.Call([]reflect.Value{connValue}) 575 errInter := returnValues[0].Interface() 576 if errInter != nil { 577 return errInter.(error) 578 } 579 return nil 580 case methodTypeCoder: 581 returnValues := m.fn.Call([]reflect.Value{ 582 connValue, 583 reflect.ValueOf(conn.Decoder), 584 reflect.ValueOf(conn.Encoder), 585 }) 586 errInter := returnValues[0].Interface() 587 if errInter != nil { 588 return errInter.(error) 589 } 590 return nil 591 case methodTypeRequestReply: 592 request := reflect.New(m.requestType) 593 response := reflect.New(m.responseType) 594 if err := conn.Decode(request.Interface()); err != nil { 595 _, err = conn.WriteString(err.Error() + "\n") 596 return err 597 } 598 startTime := time.Now() 599 returnValues := m.fn.Call([]reflect.Value{connValue, request.Elem(), 600 response}) 601 timeTaken := time.Since(startTime) 602 errInter := returnValues[0].Interface() 603 if errInter != nil { 604 m.failedRRCallsDistribution.Add(timeTaken) 605 err := errInter.(error) 606 _, err = conn.WriteString(err.Error() + "\n") 607 return err 608 } 609 m.successfulRRCallsDistribution.Add(timeTaken) 610 if _, err := conn.WriteString("\n"); err != nil { 611 return err 612 } 613 return conn.Encode(response.Interface()) 614 } 615 return errors.New("unknown method type") 616 }