
     1  package grpc
     3  import (
     4  	"crypto/tls"
     5  	"net"
     6  	"runtime/debug"
     7  	"strings"
     8  	"time"
    10  	""
    11  	""
    12  	""
    13  	""
    14  	""
    15  	""
    16  	""
    17  )
    19  // PanicLoggerUnaryServerInterceptor returns a new unary server interceptor for recovering from panics and returning error
    20  func PanicLoggerUnaryServerInterceptor(log *logrus.Entry) grpc.UnaryServerInterceptor {
    21  	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) {
    22  		defer func() {
    23  			if r := recover(); r != nil {
    24  				log.Errorf("Recovered from panic: %+v\n%s", r, debug.Stack())
    25  				err = status.Errorf(codes.Internal, "%s", r)
    26  			}
    27  		}()
    28  		return handler(ctx, req)
    29  	}
    30  }
    32  // PanicLoggerStreamServerInterceptor returns a new streaming server interceptor for recovering from panics and returning error
    33  func PanicLoggerStreamServerInterceptor(log *logrus.Entry) grpc.StreamServerInterceptor {
    34  	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) {
    35  		defer func() {
    36  			if r := recover(); r != nil {
    37  				log.Errorf("Recovered from panic: %+v\n%s", r, debug.Stack())
    38  				err = status.Errorf(codes.Internal, "%s", r)
    39  			}
    40  		}()
    41  		return handler(srv, stream)
    42  	}
    43  }
    45  // BlockingDial is a helper method to dial the given address, using optional TLS credentials,
    46  // and blocking until the returned connection is ready. If the given credentials are nil, the
    47  // connection will be insecure (plain-text).
    48  // Lifted from:
    49  func BlockingDial(ctx context.Context, network, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
    50  	// grpc.Dial doesn't provide any information on permanent connection errors (like
    51  	// TLS handshake failures). So in order to provide good error messages, we need a
    52  	// custom dialer that can provide that info. That means we manage the TLS handshake.
    53  	result := make(chan interface{}, 1)
    54  	writeResult := func(res interface{}) {
    55  		// non-blocking write: we only need the first result
    56  		select {
    57  		case result <- res:
    58  		default:
    59  		}
    60  	}
    62  	dialer := func(address string, timeout time.Duration) (net.Conn, error) {
    63  		ctx, cancel := context.WithTimeout(ctx, timeout)
    64  		defer cancel()
    66  		conn, err := (&net.Dialer{Cancel: ctx.Done()}).Dial(network, address)
    67  		if err != nil {
    68  			writeResult(err)
    69  			return nil, err
    70  		}
    71  		if creds != nil {
    72  			conn, _, err = creds.ClientHandshake(ctx, address, conn)
    73  			if err != nil {
    74  				writeResult(err)
    75  				return nil, err
    76  			}
    77  		}
    78  		return conn, nil
    79  	}
    81  	// Even with grpc.FailOnNonTempDialError, this call will usually timeout in
    82  	// the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to
    83  	// know when we're done. So we run it in a goroutine and then use result
    84  	// channel to either get the channel or fail-fast.
    85  	go func() {
    86  		opts = append(opts,
    87  			grpc.WithBlock(),
    88  			grpc.FailOnNonTempDialError(true),
    89  			grpc.WithDialer(dialer),
    90  			grpc.WithInsecure(), // we are handling TLS, so tell grpc not to
    91  			grpc.WithKeepaliveParams(keepalive.ClientParameters{Time: 10 * time.Second}),
    92  		)
    93  		conn, err := grpc.DialContext(ctx, address, opts...)
    94  		var res interface{}
    95  		if err != nil {
    96  			res = err
    97  		} else {
    98  			res = conn
    99  		}
   100  		writeResult(res)
   101  	}()
   103  	select {
   104  	case res := <-result:
   105  		if conn, ok := res.(*grpc.ClientConn); ok {
   106  			return conn, nil
   107  		}
   108  		return nil, res.(error)
   109  	case <-ctx.Done():
   110  		return nil, ctx.Err()
   111  	}
   112  }
// TLSTestResult reports what TestTLS learned about a server.
type TLSTestResult struct {
	// TLS is true when a TLS connection (with certificate verification skipped)
	// succeeded, and false when only a plain-text connection succeeded.
	TLS         bool
	// InsecureErr holds the error from the second, verified TLS dial when the
	// unverified TLS dial succeeded but the verified one failed; nil otherwise.
	InsecureErr error
}
   119  func TestTLS(address string) (*TLSTestResult, error) {
   120  	if parts := strings.Split(address, ":"); len(parts) == 1 {
   121  		// If port is unspecified, assume the most likely port
   122  		address += ":443"
   123  	}
   124  	var testResult TLSTestResult
   125  	var tlsConfig tls.Config
   126  	tlsConfig.InsecureSkipVerify = true
   127  	creds := credentials.NewTLS(&tlsConfig)
   128  	conn, err := BlockingDial(context.Background(), "tcp", address, creds)
   129  	if err == nil {
   130  		_ = conn.Close()
   131  		testResult.TLS = true
   132  		creds := credentials.NewTLS(&tls.Config{})
   133  		conn, err := BlockingDial(context.Background(), "tcp", address, creds)
   134  		if err == nil {
   135  			_ = conn.Close()
   136  		} else {
   137  			// if connection was successful with InsecureSkipVerify true, but unsuccessful with
   138  			// InsecureSkipVerify false, it means server is not configured securely
   139  			testResult.InsecureErr = err
   140  		}
   141  		return &testResult, nil
   142  	}
   143  	// If we get here, we were unable to connect via TLS (even with InsecureSkipVerify: true)
   144  	// It may be because server is running without TLS, or because of real issues (e.g. connection
   145  	// refused). Test if server accepts plain-text connections
   146  	conn, err = BlockingDial(context.Background(), "tcp", address, nil)
   147  	if err == nil {
   148  		_ = conn.Close()
   149  		testResult.TLS = false
   150  		return &testResult, nil
   151  	}
   152  	return nil, err
   153  }
   155  func WithTimeout(duration time.Duration) grpc.UnaryClientInterceptor {
   156  	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
   157  		clientDeadline := time.Now().Add(duration)
   158  		ctx, cancel := context.WithDeadline(ctx, clientDeadline)
   159  		defer cancel()
   160  		return invoker(ctx, method, req, reply, cc, opts...)
   161  	}
   162  }