github.com/hduhelp/go-zero@v1.4.3/gateway/server.go (about) 1 package gateway 2 3 import ( 4 "context" 5 "fmt" 6 "net/http" 7 "strings" 8 "time" 9 10 "github.com/fullstorydev/grpcurl" 11 "github.com/golang/protobuf/jsonpb" 12 "github.com/hduhelp/go-zero/core/logx" 13 "github.com/hduhelp/go-zero/core/mr" 14 "github.com/hduhelp/go-zero/gateway/internal" 15 "github.com/hduhelp/go-zero/rest" 16 "github.com/hduhelp/go-zero/rest/httpx" 17 "github.com/hduhelp/go-zero/zrpc" 18 "github.com/jhump/protoreflect/grpcreflect" 19 "google.golang.org/grpc/codes" 20 "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" 21 ) 22 23 type ( 24 // Server is a gateway server. 25 Server struct { 26 *rest.Server 27 upstreams []Upstream 28 timeout time.Duration 29 processHeader func(http.Header) []string 30 } 31 32 // Option defines the method to customize Server. 33 Option func(svr *Server) 34 ) 35 36 // MustNewServer creates a new gateway server. 37 func MustNewServer(c GatewayConf, opts ...Option) *Server { 38 svr := &Server{ 39 Server: rest.MustNewServer(c.RestConf), 40 upstreams: c.Upstreams, 41 timeout: c.Timeout, 42 } 43 for _, opt := range opts { 44 opt(svr) 45 } 46 47 return svr 48 } 49 50 // Start starts the gateway server. 51 func (s *Server) Start() { 52 logx.Must(s.build()) 53 s.Server.Start() 54 } 55 56 // Stop stops the gateway server. 57 func (s *Server) Stop() { 58 s.Server.Stop() 59 } 60 61 func (s *Server) build() error { 62 if err := s.ensureUpstreamNames(); err != nil { 63 return err 64 } 65 66 return mr.MapReduceVoid(func(source chan<- interface{}) { 67 for _, up := range s.upstreams { 68 source <- up 69 } 70 }, func(item interface{}, writer mr.Writer, cancel func(error)) { 71 up := item.(Upstream) 72 cli := zrpc.MustNewClient(up.Grpc) 73 source, err := s.createDescriptorSource(cli, up) 74 if err != nil { 75 cancel(fmt.Errorf("%s: %w", up.Name, err)) 76 return 77 } 78 79 methods, err := internal.GetMethods(source) 80 if err != nil { 81 cancel(fmt.Errorf("%s: %w", up.Name, err)) 82 return 83 } 84 85 resolver := grpcurl.AnyResolverFromDescriptorSource(source) 86 for _, m := range methods { 87 if len(m.HttpMethod) > 0 && len(m.HttpPath) > 0 { 88 writer.Write(rest.Route{ 89 Method: m.HttpMethod, 90 Path: m.HttpPath, 91 Handler: s.buildHandler(source, resolver, cli, m.RpcPath), 92 }) 93 } 94 } 95 96 methodSet := make(map[string]struct{}) 97 for _, m := range methods { 98 methodSet[m.RpcPath] = struct{}{} 99 } 100 for _, m := range up.Mappings { 101 if _, ok := methodSet[m.RpcPath]; !ok { 102 cancel(fmt.Errorf("%s: rpc method %s not found", up.Name, m.RpcPath)) 103 return 104 } 105 106 writer.Write(rest.Route{ 107 Method: strings.ToUpper(m.Method), 108 Path: m.Path, 109 Handler: s.buildHandler(source, resolver, cli, m.RpcPath), 110 }) 111 } 112 }, func(pipe <-chan interface{}, cancel func(error)) { 113 for item := range pipe { 114 route := item.(rest.Route) 115 s.Server.AddRoute(route) 116 } 117 }) 118 } 119 120 func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.AnyResolver, 121 cli zrpc.Client, rpcPath string) func(http.ResponseWriter, *http.Request) { 122 return func(w http.ResponseWriter, r *http.Request) { 123 parser, err := internal.NewRequestParser(r, resolver) 124 if err != nil { 125 httpx.Error(w, err) 126 return 127 } 128 129 timeout := internal.GetTimeout(r.Header, s.timeout) 130 ctx, can := context.WithTimeout(r.Context(), timeout) 131 defer can() 132 133 w.Header().Set(httpx.ContentType, httpx.JsonContentType) 134 handler := internal.NewEventHandler(w, resolver) 135 if err := grpcurl.InvokeRPC(ctx, source, cli.Conn(), rpcPath, s.prepareMetadata(r.Header), 136 handler, parser.Next); err != nil { 137 httpx.Error(w, err) 138 } 139 140 st := handler.Status 141 if st.Code() != codes.OK { 142 httpx.Error(w, st.Err()) 143 } 144 } 145 } 146 147 func (s *Server) createDescriptorSource(cli zrpc.Client, up Upstream) (grpcurl.DescriptorSource, error) { 148 var source grpcurl.DescriptorSource 149 var err error 150 151 if len(up.ProtoSets) > 0 { 152 source, err = grpcurl.DescriptorSourceFromProtoSets(up.ProtoSets...) 153 if err != nil { 154 return nil, err 155 } 156 } else { 157 refCli := grpc_reflection_v1alpha.NewServerReflectionClient(cli.Conn()) 158 client := grpcreflect.NewClient(context.Background(), refCli) 159 source = grpcurl.DescriptorSourceFromServer(context.Background(), client) 160 } 161 162 return source, nil 163 } 164 165 func (s *Server) ensureUpstreamNames() error { 166 for _, up := range s.upstreams { 167 target, err := up.Grpc.BuildTarget() 168 if err != nil { 169 return err 170 } 171 172 up.Name = target 173 } 174 175 return nil 176 } 177 178 func (s *Server) prepareMetadata(header http.Header) []string { 179 vals := internal.ProcessHeaders(header) 180 if s.processHeader != nil { 181 vals = append(vals, s.processHeader(header)...) 182 } 183 184 return vals 185 } 186 187 // WithHeaderProcessor sets a processor to process request headers. 188 // The returned headers are used as metadata to invoke the RPC. 189 func WithHeaderProcessor(processHeader func(http.Header) []string) func(*Server) { 190 return func(s *Server) { 191 s.processHeader = processHeader 192 } 193 }