github.com/profzone/eden-framework@v1.0.10/pkg/courier/transport_grpc/encode_decode.go (about)

     1  package transport_grpc
     2  
import (
	"context"
	"reflect"
	"strings"

	"github.com/google/uuid"
	"github.com/sirupsen/logrus"
	"github.com/vmihailenco/msgpack"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"

	logContext "github.com/profzone/eden-framework/pkg/context"
	"github.com/profzone/eden-framework/pkg/courier"
	"github.com/profzone/eden-framework/pkg/courier/httpx"
	"github.com/profzone/eden-framework/pkg/courier/status_error"
	"github.com/profzone/eden-framework/pkg/duration"
)
    20  
    21  type DecodeStreamFunc func(c context.Context, data []byte) (request interface{}, err error)
    22  
    23  var (
    24  	ContextKeyServerName = uuid.New().String()
    25  )
    26  
    27  func ContextWithServiceName(ctx context.Context, serverName string) context.Context {
    28  	return context.WithValue(ctx, ContextKeyServerName, serverName)
    29  }
    30  
    31  func CreateStreamHandler(s *ServeGRPC, ops ...courier.IOperator) grpc.StreamHandler {
    32  	opMetas := courier.ToOperatorMetaList(ops...)
    33  
    34  	return func(_ interface{}, stream grpc.ServerStream) (err error) {
    35  		ctx := stream.Context()
    36  		ctx = ContextWithServiceName(ctx, s.Name)
    37  
    38  		reqID := getRequestID(ctx)
    39  
    40  		if reqID == "" {
    41  			reqID = uuid.New().String()
    42  		}
    43  
    44  		logContext.SetLogID(reqID)
    45  		defer logContext.Close()
    46  
    47  		d := duration.NewDuration()
    48  
    49  		defer func() {
    50  			fields := logrus.Fields{
    51  				"tag":       "access",
    52  				"log_id":    reqID,
    53  				"remote_ip": ClientIP(ctx),
    54  				"method":    "/" + opMetas[len(opMetas)-1].Type.Name(),
    55  			}
    56  
    57  			fields["request_time"] = d.Get()
    58  
    59  			logger := logrus.WithFields(fields)
    60  
    61  			if err != nil {
    62  				statusErr := status_error.FromError(err)
    63  				if statusErr.Status() >= 500 {
    64  					logger.Errorf(err.Error())
    65  				} else {
    66  					logger.Warnf(err.Error())
    67  				}
    68  			} else {
    69  				logger.Infof("")
    70  			}
    71  		}()
    72  
    73  		opDecode := createGRPCStreamDecoder(receiveMsgData(stream))
    74  
    75  		for _, opMeta := range opMetas {
    76  			op, decodeErr := courier.NewOperatorBy(opMeta.Type, opMeta.Operator, opDecode)
    77  			if decodeErr != nil {
    78  				err = passErr(ctx, decodeErr)
    79  				return
    80  			}
    81  
    82  			response, endpointErr := op.Output(ctx)
    83  			if endpointErr != nil {
    84  				err = passErr(ctx, endpointErr)
    85  				return
    86  			}
    87  
    88  			if !opMeta.IsLast {
    89  				// set result in context with key of operator name
    90  				ctx = context.WithValue(ctx, opMeta.ContextKey, response)
    91  				continue
    92  			}
    93  
    94  			encodeErr := sendMsg(ctx, stream, response)
    95  			if encodeErr != nil {
    96  				err = passErr(ctx, encodeErr)
    97  				return
    98  			}
    99  		}
   100  		return
   101  	}
   102  }
   103  
   104  func createGRPCStreamDecoder(data []byte) courier.OperatorDecoder {
   105  	return func(op courier.IOperator, rv reflect.Value) (err error) {
   106  		err = msgpack.Unmarshal(data, op)
   107  		if err != nil {
   108  			err = status_error.InvalidStruct.StatusError().WithDesc(err.Error())
   109  			return
   110  		}
   111  		return
   112  	}
   113  }
   114  
   115  func getRequestID(ctx context.Context) string {
   116  	md, ok := metadata.FromIncomingContext(ctx)
   117  	if ok {
   118  		if values, ok := md[httpx.HeaderRequestID]; ok {
   119  			if len(values) > 0 {
   120  				return values[0]
   121  			}
   122  		}
   123  	}
   124  	return ""
   125  }
   126  
   127  func GetFieldDisplayName(field reflect.StructField) string {
   128  	pathName := field.Name
   129  	jsonName, _ := field.Tag.Lookup("json")
   130  	if jsonName != "" {
   131  		pathName = jsonName
   132  	}
   133  	return pathName
   134  }
   135  
   136  func sendMsg(_ context.Context, stream grpc.ServerStream, response interface{}) (err error) {
   137  	md := metadata.Pairs(httpx.HeaderRequestID, logContext.GetLogID())
   138  	if canMeta, ok := response.(courier.IMeta); ok {
   139  		md = metadata.Join(md, metadata.MD(canMeta.Meta()))
   140  	}
   141  	if err = stream.SetHeader(md); err != nil {
   142  		return
   143  	}
   144  	return stream.SendMsg(response)
   145  }
   146  
   147  func passErr(ctx context.Context, err error) error {
   148  	if err == nil {
   149  		return err
   150  	}
   151  	if _, ok := status.FromError(err); !ok {
   152  		finalStatusErr := status_error.FromError(err)
   153  		err = status.Error(CodeFromHTTPStatus(finalStatusErr.Status()), finalStatusErr.WithSource(ctx.Value(ContextKeyServerName).(string)).String())
   154  	}
   155  
   156  	return err
   157  }
   158  
   159  func receiveMsgData(stream grpc.ServerStream) (data []byte) {
   160  	stream.RecvMsg(&data)
   161  	return
   162  }
   163  
   164  func MarshalOperator(stream grpc.ServerStream, operator courier.IOperator) error {
   165  	opDecode := createGRPCStreamDecoder(receiveMsgData(stream))
   166  	rv := reflect.Indirect(reflect.ValueOf(operator))
   167  	op, err := courier.NewOperatorBy(rv.Type(), operator, opDecode)
   168  	if err != nil {
   169  		return err
   170  	}
   171  	rv.Set(reflect.ValueOf(op).Elem())
   172  	return nil
   173  }