git.zd.zone/hrpc/hrpc@v0.0.12/tracer/tracer.go (about)

     1  package tracer
     2  
     3  import (
     4  	"context"
     5  	"crypto/sha256"
     6  	"fmt"
     7  	"math/rand"
     8  	"runtime"
     9  	"time"
    10  
    11  	"git.zd.zone/hrpc/hrpc/codec"
    12  	"git.zd.zone/hrpc/hrpc/log"
    13  	"git.zd.zone/hrpc/hrpc/utils/uniqueid"
    14  	"google.golang.org/grpc"
    15  	"google.golang.org/grpc/codes"
    16  	"google.golang.org/grpc/metadata"
    17  	"google.golang.org/grpc/status"
    18  )
    19  
    20  func prefix(s string) string {
    21  	seed := time.Now().UnixNano()
    22  	rand.Seed(seed)
    23  	h := sha256.New()
    24  	h.Write([]byte(fmt.Sprintf("%s-%v-%v", s, rand.Intn(99999), seed)))
    25  	return fmt.Sprintf("%x", h.Sum(nil))
    26  }
    27  
    28  // NewID generates a random trace id in string
    29  func NewID(serverName string) string {
    30  	return prefix(serverName) + "." + uniqueid.String()
    31  }
    32  
    33  // AddTraceID will add an unique id to the ctx
    34  func AddTraceID(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    35  	msg := codec.Message(ctx)
    36  	traceid := NewID(msg.ServerName())
    37  	if msg.TraceID() == "" {
    38  		msg.WithTraceID(traceid)
    39  	}
    40  	// TEMP!
    41  	// ctx, cancel := context.WithTimeout(msg.Context(), msg.RequestTimeout())
    42  	// defer cancel()
    43  
    44  	v := msg.Metadata()
    45  	ctx = metadata.NewOutgoingContext(
    46  		ctx, v,
    47  	)
    48  
    49  	v1 := make(chan interface{}, 1)
    50  	v2 := make(chan error, 1)
    51  	v3 := make(chan time.Duration, 1)
    52  
    53  	go func(ctx context.Context) {
    54  		defer func() {
    55  			if err := recover(); err != nil {
    56  				stackSlice := make([]byte, 512)
    57  				s := runtime.Stack(stackSlice, false)
    58  				log.WithFields(
    59  					ctx,
    60  					"stack", string(stackSlice[0:s]),
    61  				).Error(err)
    62  			}
    63  		}()
    64  		now := time.Now()
    65  		resp, err := handler(ctx, req)
    66  		v1 <- resp
    67  		v2 <- err
    68  		v3 <- time.Since(now)
    69  	}(ctx)
    70  
    71  	select {
    72  	case <-ctx.Done():
    73  		msg.WithRequestTimeout(
    74  			time.Duration(0),
    75  		)
    76  		log.WithFields(ctx).Info("timeout")
    77  		return nil, status.Error(codes.DeadlineExceeded, "deadline exceeded")
    78  	case d := <-v3:
    79  		left := time.Duration(
    80  			(msg.RequestTimeout().Milliseconds() - d.Milliseconds()) * int64(time.Millisecond),
    81  		)
    82  		msg.WithRequestTimeout(left)
    83  		log.WithFields(ctx, "used", d.String(), "left", left.String()).Info()
    84  		return <-v1, <-v2
    85  	}
    86  }