github.com/hduhelp/go-zero@v1.4.3/gateway/server.go (about)

     1  package gateway
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net/http"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/fullstorydev/grpcurl"
    11  	"github.com/golang/protobuf/jsonpb"
    12  	"github.com/hduhelp/go-zero/core/logx"
    13  	"github.com/hduhelp/go-zero/core/mr"
    14  	"github.com/hduhelp/go-zero/gateway/internal"
    15  	"github.com/hduhelp/go-zero/rest"
    16  	"github.com/hduhelp/go-zero/rest/httpx"
    17  	"github.com/hduhelp/go-zero/zrpc"
    18  	"github.com/jhump/protoreflect/grpcreflect"
    19  	"google.golang.org/grpc/codes"
    20  	"google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
    21  )
    22  
    23  type (
    24  	// Server is a gateway server.
    25  	Server struct {
    26  		*rest.Server
    27  		upstreams     []Upstream
    28  		timeout       time.Duration
    29  		processHeader func(http.Header) []string
    30  	}
    31  
    32  	// Option defines the method to customize Server.
    33  	Option func(svr *Server)
    34  )
    35  
    36  // MustNewServer creates a new gateway server.
    37  func MustNewServer(c GatewayConf, opts ...Option) *Server {
    38  	svr := &Server{
    39  		Server:    rest.MustNewServer(c.RestConf),
    40  		upstreams: c.Upstreams,
    41  		timeout:   c.Timeout,
    42  	}
    43  	for _, opt := range opts {
    44  		opt(svr)
    45  	}
    46  
    47  	return svr
    48  }
    49  
    50  // Start starts the gateway server.
    51  func (s *Server) Start() {
    52  	logx.Must(s.build())
    53  	s.Server.Start()
    54  }
    55  
    56  // Stop stops the gateway server.
    57  func (s *Server) Stop() {
    58  	s.Server.Stop()
    59  }
    60  
    61  func (s *Server) build() error {
    62  	if err := s.ensureUpstreamNames(); err != nil {
    63  		return err
    64  	}
    65  
    66  	return mr.MapReduceVoid(func(source chan<- interface{}) {
    67  		for _, up := range s.upstreams {
    68  			source <- up
    69  		}
    70  	}, func(item interface{}, writer mr.Writer, cancel func(error)) {
    71  		up := item.(Upstream)
    72  		cli := zrpc.MustNewClient(up.Grpc)
    73  		source, err := s.createDescriptorSource(cli, up)
    74  		if err != nil {
    75  			cancel(fmt.Errorf("%s: %w", up.Name, err))
    76  			return
    77  		}
    78  
    79  		methods, err := internal.GetMethods(source)
    80  		if err != nil {
    81  			cancel(fmt.Errorf("%s: %w", up.Name, err))
    82  			return
    83  		}
    84  
    85  		resolver := grpcurl.AnyResolverFromDescriptorSource(source)
    86  		for _, m := range methods {
    87  			if len(m.HttpMethod) > 0 && len(m.HttpPath) > 0 {
    88  				writer.Write(rest.Route{
    89  					Method:  m.HttpMethod,
    90  					Path:    m.HttpPath,
    91  					Handler: s.buildHandler(source, resolver, cli, m.RpcPath),
    92  				})
    93  			}
    94  		}
    95  
    96  		methodSet := make(map[string]struct{})
    97  		for _, m := range methods {
    98  			methodSet[m.RpcPath] = struct{}{}
    99  		}
   100  		for _, m := range up.Mappings {
   101  			if _, ok := methodSet[m.RpcPath]; !ok {
   102  				cancel(fmt.Errorf("%s: rpc method %s not found", up.Name, m.RpcPath))
   103  				return
   104  			}
   105  
   106  			writer.Write(rest.Route{
   107  				Method:  strings.ToUpper(m.Method),
   108  				Path:    m.Path,
   109  				Handler: s.buildHandler(source, resolver, cli, m.RpcPath),
   110  			})
   111  		}
   112  	}, func(pipe <-chan interface{}, cancel func(error)) {
   113  		for item := range pipe {
   114  			route := item.(rest.Route)
   115  			s.Server.AddRoute(route)
   116  		}
   117  	})
   118  }
   119  
   120  func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.AnyResolver,
   121  	cli zrpc.Client, rpcPath string) func(http.ResponseWriter, *http.Request) {
   122  	return func(w http.ResponseWriter, r *http.Request) {
   123  		parser, err := internal.NewRequestParser(r, resolver)
   124  		if err != nil {
   125  			httpx.Error(w, err)
   126  			return
   127  		}
   128  
   129  		timeout := internal.GetTimeout(r.Header, s.timeout)
   130  		ctx, can := context.WithTimeout(r.Context(), timeout)
   131  		defer can()
   132  
   133  		w.Header().Set(httpx.ContentType, httpx.JsonContentType)
   134  		handler := internal.NewEventHandler(w, resolver)
   135  		if err := grpcurl.InvokeRPC(ctx, source, cli.Conn(), rpcPath, s.prepareMetadata(r.Header),
   136  			handler, parser.Next); err != nil {
   137  			httpx.Error(w, err)
   138  		}
   139  
   140  		st := handler.Status
   141  		if st.Code() != codes.OK {
   142  			httpx.Error(w, st.Err())
   143  		}
   144  	}
   145  }
   146  
   147  func (s *Server) createDescriptorSource(cli zrpc.Client, up Upstream) (grpcurl.DescriptorSource, error) {
   148  	var source grpcurl.DescriptorSource
   149  	var err error
   150  
   151  	if len(up.ProtoSets) > 0 {
   152  		source, err = grpcurl.DescriptorSourceFromProtoSets(up.ProtoSets...)
   153  		if err != nil {
   154  			return nil, err
   155  		}
   156  	} else {
   157  		refCli := grpc_reflection_v1alpha.NewServerReflectionClient(cli.Conn())
   158  		client := grpcreflect.NewClient(context.Background(), refCli)
   159  		source = grpcurl.DescriptorSourceFromServer(context.Background(), client)
   160  	}
   161  
   162  	return source, nil
   163  }
   164  
   165  func (s *Server) ensureUpstreamNames() error {
   166  	for _, up := range s.upstreams {
   167  		target, err := up.Grpc.BuildTarget()
   168  		if err != nil {
   169  			return err
   170  		}
   171  
   172  		up.Name = target
   173  	}
   174  
   175  	return nil
   176  }
   177  
   178  func (s *Server) prepareMetadata(header http.Header) []string {
   179  	vals := internal.ProcessHeaders(header)
   180  	if s.processHeader != nil {
   181  		vals = append(vals, s.processHeader(header)...)
   182  	}
   183  
   184  	return vals
   185  }
   186  
   187  // WithHeaderProcessor sets a processor to process request headers.
   188  // The returned headers are used as metadata to invoke the RPC.
   189  func WithHeaderProcessor(processHeader func(http.Header) []string) func(*Server) {
   190  	return func(s *Server) {
   191  		s.processHeader = processHeader
   192  	}
   193  }