github.com/argoproj/argo-cd@v1.8.7/util/grpc/grpc.go (about) 1 package grpc 2 3 import ( 4 "crypto/tls" 5 "net" 6 "runtime/debug" 7 "strings" 8 "time" 9 10 "github.com/sirupsen/logrus" 11 "golang.org/x/net/context" 12 "google.golang.org/grpc" 13 "google.golang.org/grpc/codes" 14 "google.golang.org/grpc/credentials" 15 "google.golang.org/grpc/keepalive" 16 "google.golang.org/grpc/status" 17 ) 18 19 // PanicLoggerUnaryServerInterceptor returns a new unary server interceptor for recovering from panics and returning error 20 func PanicLoggerUnaryServerInterceptor(log *logrus.Entry) grpc.UnaryServerInterceptor { 21 return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) { 22 defer func() { 23 if r := recover(); r != nil { 24 log.Errorf("Recovered from panic: %+v\n%s", r, debug.Stack()) 25 err = status.Errorf(codes.Internal, "%s", r) 26 } 27 }() 28 return handler(ctx, req) 29 } 30 } 31 32 // PanicLoggerStreamServerInterceptor returns a new streaming server interceptor for recovering from panics and returning error 33 func PanicLoggerStreamServerInterceptor(log *logrus.Entry) grpc.StreamServerInterceptor { 34 return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { 35 defer func() { 36 if r := recover(); r != nil { 37 log.Errorf("Recovered from panic: %+v\n%s", r, debug.Stack()) 38 err = status.Errorf(codes.Internal, "%s", r) 39 } 40 }() 41 return handler(srv, stream) 42 } 43 } 44 45 // BlockingDial is a helper method to dial the given address, using optional TLS credentials, 46 // and blocking until the returned connection is ready. If the given credentials are nil, the 47 // connection will be insecure (plain-text). 48 // Lifted from: https://github.com/fullstorydev/grpcurl/blob/master/grpcurl.go 49 func BlockingDial(ctx context.Context, network, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) { 50 // grpc.Dial doesn't provide any information on permanent connection errors (like 51 // TLS handshake failures). So in order to provide good error messages, we need a 52 // custom dialer that can provide that info. That means we manage the TLS handshake. 53 result := make(chan interface{}, 1) 54 writeResult := func(res interface{}) { 55 // non-blocking write: we only need the first result 56 select { 57 case result <- res: 58 default: 59 } 60 } 61 62 dialer := func(address string, timeout time.Duration) (net.Conn, error) { 63 ctx, cancel := context.WithTimeout(ctx, timeout) 64 defer cancel() 65 66 conn, err := (&net.Dialer{Cancel: ctx.Done()}).Dial(network, address) 67 if err != nil { 68 writeResult(err) 69 return nil, err 70 } 71 if creds != nil { 72 conn, _, err = creds.ClientHandshake(ctx, address, conn) 73 if err != nil { 74 writeResult(err) 75 return nil, err 76 } 77 } 78 return conn, nil 79 } 80 81 // Even with grpc.FailOnNonTempDialError, this call will usually timeout in 82 // the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to 83 // know when we're done. So we run it in a goroutine and then use result 84 // channel to either get the channel or fail-fast. 85 go func() { 86 opts = append(opts, 87 grpc.WithBlock(), 88 grpc.FailOnNonTempDialError(true), 89 grpc.WithDialer(dialer), 90 grpc.WithInsecure(), // we are handling TLS, so tell grpc not to 91 grpc.WithKeepaliveParams(keepalive.ClientParameters{Time: 10 * time.Second}), 92 ) 93 conn, err := grpc.DialContext(ctx, address, opts...) 94 var res interface{} 95 if err != nil { 96 res = err 97 } else { 98 res = conn 99 } 100 writeResult(res) 101 }() 102 103 select { 104 case res := <-result: 105 if conn, ok := res.(*grpc.ClientConn); ok { 106 return conn, nil 107 } 108 return nil, res.(error) 109 case <-ctx.Done(): 110 return nil, ctx.Err() 111 } 112 } 113 114 type TLSTestResult struct { 115 TLS bool 116 InsecureErr error 117 } 118 119 func TestTLS(address string) (*TLSTestResult, error) { 120 if parts := strings.Split(address, ":"); len(parts) == 1 { 121 // If port is unspecified, assume the most likely port 122 address += ":443" 123 } 124 var testResult TLSTestResult 125 var tlsConfig tls.Config 126 tlsConfig.InsecureSkipVerify = true 127 creds := credentials.NewTLS(&tlsConfig) 128 conn, err := BlockingDial(context.Background(), "tcp", address, creds) 129 if err == nil { 130 _ = conn.Close() 131 testResult.TLS = true 132 creds := credentials.NewTLS(&tls.Config{}) 133 conn, err := BlockingDial(context.Background(), "tcp", address, creds) 134 if err == nil { 135 _ = conn.Close() 136 } else { 137 // if connection was successful with InsecureSkipVerify true, but unsuccessful with 138 // InsecureSkipVerify false, it means server is not configured securely 139 testResult.InsecureErr = err 140 } 141 return &testResult, nil 142 } 143 // If we get here, we were unable to connect via TLS (even with InsecureSkipVerify: true) 144 // It may be because server is running without TLS, or because of real issues (e.g. connection 145 // refused). Test if server accepts plain-text connections 146 conn, err = BlockingDial(context.Background(), "tcp", address, nil) 147 if err == nil { 148 _ = conn.Close() 149 testResult.TLS = false 150 return &testResult, nil 151 } 152 return nil, err 153 } 154 155 func WithTimeout(duration time.Duration) grpc.UnaryClientInterceptor { 156 return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 157 clientDeadline := time.Now().Add(duration) 158 ctx, cancel := context.WithDeadline(ctx, clientDeadline) 159 defer cancel() 160 return invoker(ctx, method, req, reply, cc, opts...) 161 } 162 }