github.com/blixtra/rkt@v0.8.1-0.20160204105720-ab0d1add1a43/Godeps/_workspace/src/google.golang.org/grpc/server.go (about)

     1  /*
     2   *
     3   * Copyright 2014, Google Inc.
     4   * All rights reserved.
     5   *
     6   * Redistribution and use in source and binary forms, with or without
     7   * modification, are permitted provided that the following conditions are
     8   * met:
     9   *
    10   *     * Redistributions of source code must retain the above copyright
    11   * notice, this list of conditions and the following disclaimer.
    12   *     * Redistributions in binary form must reproduce the above
    13   * copyright notice, this list of conditions and the following disclaimer
    14   * in the documentation and/or other materials provided with the
    15   * distribution.
    16   *     * Neither the name of Google Inc. nor the names of its
    17   * contributors may be used to endorse or promote products derived from
    18   * this software without specific prior written permission.
    19   *
    20   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
    21   * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
    22   * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
    23   * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
    24   * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
    25   * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
    26   * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
    27   * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
    28   * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    29   * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    30   * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    31   *
    32   */
    33  
    34  package grpc
    35  
    36  import (
    37  	"errors"
    38  	"fmt"
    39  	"io"
    40  	"net"
    41  	"reflect"
    42  	"runtime"
    43  	"strings"
    44  	"sync"
    45  	"time"
    46  
    47  	"golang.org/x/net/context"
    48  	"golang.org/x/net/trace"
    49  	"google.golang.org/grpc/codes"
    50  	"google.golang.org/grpc/credentials"
    51  	"google.golang.org/grpc/grpclog"
    52  	"google.golang.org/grpc/metadata"
    53  	"google.golang.org/grpc/transport"
    54  )
    55  
// methodHandler is the signature of a unary method handler stored in
// MethodDesc.Handler. It is invoked with the registered service
// implementation (srv), the RPC's context, and dec, a callback that
// unmarshals the request message into the value passed to it. It returns the
// reply message, or an error that becomes the RPC status.
type methodHandler func(srv interface{}, ctx context.Context, dec func(interface{}) error) (interface{}, error)
    57  
// MethodDesc represents an RPC service's method specification.
// It describes a unary method of a service.
type MethodDesc struct {
	// MethodName is matched against the part of an incoming RPC's method
	// name after the last '/'.
	MethodName string
	// Handler is called to serve the method.
	Handler methodHandler
}
    63  
// ServiceDesc represents an RPC service's specification.
type ServiceDesc struct {
	// ServiceName is matched against the part of an incoming RPC's method
	// name before the last '/' (ignoring a leading '/').
	ServiceName string
	// The pointer to the service interface. Used to check whether the user
	// provided implementation satisfies the interface requirements.
	HandlerType interface{}
	// Methods lists the service's unary methods.
	Methods []MethodDesc
	// Streams lists the service's streaming methods.
	Streams []StreamDesc
}
    73  
// service consists of the information of the server serving this service and
// the methods in this service.
type service struct {
	server interface{}            // the server for service methods
	md     map[string]*MethodDesc // unary methods, keyed by MethodDesc.MethodName
	sd     map[string]*StreamDesc // streaming methods, keyed by StreamDesc.StreamName
}
    81  
// Server is a gRPC server to serve RPC requests.
type Server struct {
	opts options
	// mu guards lis, conns and events. register also holds it while writing
	// m. handleStream reads m without taking mu, which relies on every
	// RegisterService call completing before Serve is called.
	mu sync.Mutex
	// lis is the set of listeners being served; nil once Stop has run.
	lis map[net.Listener]bool
	// conns is the set of live server transports; nil once Stop has run.
	conns map[transport.ServerTransport]bool
	m     map[string]*service // service name -> service info
	// events is the server's trace event log. It is nil when EnableTracing
	// was false in NewServer, and it is set back to nil by Stop.
	events trace.EventLog
}
    91  
// options holds the server configuration that NewServer builds by applying
// each ServerOption in turn.
type options struct {
	// creds, when it implements credentials.TransportAuthenticator, is used
	// by Serve to run a handshake on every accepted connection.
	creds credentials.Credentials
	// codec marshals and unmarshals messages. NewServer defaults it to
	// protoCodec{} when no option sets it.
	codec Codec
	// maxConcurrentStreams is passed to transport.NewServerTransport for
	// every accepted connection.
	maxConcurrentStreams uint32
}
    97  
// A ServerOption sets options.
// NewServer applies the ServerOptions it is given in argument order, so a
// later option overrides an earlier one that sets the same field.
type ServerOption func(*options)
   100  
   101  // CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling.
   102  func CustomCodec(codec Codec) ServerOption {
   103  	return func(o *options) {
   104  		o.codec = codec
   105  	}
   106  }
   107  
   108  // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
   109  // of concurrent streams to each ServerTransport.
   110  func MaxConcurrentStreams(n uint32) ServerOption {
   111  	return func(o *options) {
   112  		o.maxConcurrentStreams = n
   113  	}
   114  }
   115  
   116  // Creds returns a ServerOption that sets credentials for server connections.
   117  func Creds(c credentials.Credentials) ServerOption {
   118  	return func(o *options) {
   119  		o.creds = c
   120  	}
   121  }
   122  
   123  // NewServer creates a gRPC server which has no service registered and has not
   124  // started to accept requests yet.
   125  func NewServer(opt ...ServerOption) *Server {
   126  	var opts options
   127  	for _, o := range opt {
   128  		o(&opts)
   129  	}
   130  	if opts.codec == nil {
   131  		// Set the default codec.
   132  		opts.codec = protoCodec{}
   133  	}
   134  	s := &Server{
   135  		lis:   make(map[net.Listener]bool),
   136  		opts:  opts,
   137  		conns: make(map[transport.ServerTransport]bool),
   138  		m:     make(map[string]*service),
   139  	}
   140  	if EnableTracing {
   141  		_, file, line, _ := runtime.Caller(1)
   142  		s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
   143  	}
   144  	return s
   145  }
   146  
   147  // printf records an event in s's event log, unless s has been stopped.
   148  // REQUIRES s.mu is held.
   149  func (s *Server) printf(format string, a ...interface{}) {
   150  	if s.events != nil {
   151  		s.events.Printf(format, a...)
   152  	}
   153  }
   154  
   155  // errorf records an error in s's event log, unless s has been stopped.
   156  // REQUIRES s.mu is held.
   157  func (s *Server) errorf(format string, a ...interface{}) {
   158  	if s.events != nil {
   159  		s.events.Errorf(format, a...)
   160  	}
   161  }
   162  
   163  // RegisterService register a service and its implementation to the gRPC
   164  // server. Called from the IDL generated code. This must be called before
   165  // invoking Serve.
   166  func (s *Server) RegisterService(sd *ServiceDesc, ss interface{}) {
   167  	ht := reflect.TypeOf(sd.HandlerType).Elem()
   168  	st := reflect.TypeOf(ss)
   169  	if !st.Implements(ht) {
   170  		grpclog.Fatalf("grpc: Server.RegisterService found the handler of type %v that does not satisfy %v", st, ht)
   171  	}
   172  	s.register(sd, ss)
   173  }
   174  
   175  func (s *Server) register(sd *ServiceDesc, ss interface{}) {
   176  	s.mu.Lock()
   177  	defer s.mu.Unlock()
   178  	s.printf("RegisterService(%q)", sd.ServiceName)
   179  	if _, ok := s.m[sd.ServiceName]; ok {
   180  		grpclog.Fatalf("grpc: Server.RegisterService found duplicate service registration for %q", sd.ServiceName)
   181  	}
   182  	srv := &service{
   183  		server: ss,
   184  		md:     make(map[string]*MethodDesc),
   185  		sd:     make(map[string]*StreamDesc),
   186  	}
   187  	for i := range sd.Methods {
   188  		d := &sd.Methods[i]
   189  		srv.md[d.MethodName] = d
   190  	}
   191  	for i := range sd.Streams {
   192  		d := &sd.Streams[i]
   193  		srv.sd[d.StreamName] = d
   194  	}
   195  	s.m[sd.ServiceName] = srv
   196  }
   197  
   198  var (
   199  	// ErrServerStopped indicates that the operation is now illegal because of
   200  	// the server being stopped.
   201  	ErrServerStopped = errors.New("grpc: the server has been stopped")
   202  )
   203  
   204  // Serve accepts incoming connections on the listener lis, creating a new
   205  // ServerTransport and service goroutine for each. The service goroutines
   206  // read gRPC request and then call the registered handlers to reply to them.
   207  // Service returns when lis.Accept fails.
   208  func (s *Server) Serve(lis net.Listener) error {
   209  	s.mu.Lock()
   210  	s.printf("serving")
   211  	if s.lis == nil {
   212  		s.mu.Unlock()
   213  		return ErrServerStopped
   214  	}
   215  	s.lis[lis] = true
   216  	s.mu.Unlock()
   217  	defer func() {
   218  		lis.Close()
   219  		s.mu.Lock()
   220  		delete(s.lis, lis)
   221  		s.mu.Unlock()
   222  	}()
   223  	for {
   224  		c, err := lis.Accept()
   225  		if err != nil {
   226  			s.mu.Lock()
   227  			s.printf("done serving; Accept = %v", err)
   228  			s.mu.Unlock()
   229  			return err
   230  		}
   231  		var authInfo credentials.AuthInfo
   232  		if creds, ok := s.opts.creds.(credentials.TransportAuthenticator); ok {
   233  			var conn net.Conn
   234  			conn, authInfo, err = creds.ServerHandshake(c)
   235  			if err != nil {
   236  				s.mu.Lock()
   237  				s.errorf("ServerHandshake(%q) failed: %v", c.RemoteAddr(), err)
   238  				s.mu.Unlock()
   239  				grpclog.Println("grpc: Server.Serve failed to complete security handshake.")
   240  				continue
   241  			}
   242  			c = conn
   243  		}
   244  		s.mu.Lock()
   245  		if s.conns == nil {
   246  			s.mu.Unlock()
   247  			c.Close()
   248  			return nil
   249  		}
   250  		st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams, authInfo)
   251  		if err != nil {
   252  			s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err)
   253  			s.mu.Unlock()
   254  			c.Close()
   255  			grpclog.Println("grpc: Server.Serve failed to create ServerTransport: ", err)
   256  			continue
   257  		}
   258  		s.conns[st] = true
   259  		s.mu.Unlock()
   260  
   261  		go func() {
   262  			var wg sync.WaitGroup
   263  			st.HandleStreams(func(stream *transport.Stream) {
   264  				var trInfo *traceInfo
   265  				if EnableTracing {
   266  					trInfo = &traceInfo{
   267  						tr: trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()),
   268  					}
   269  					trInfo.firstLine.client = false
   270  					trInfo.firstLine.remoteAddr = st.RemoteAddr()
   271  					stream.TraceContext(trInfo.tr)
   272  					if dl, ok := stream.Context().Deadline(); ok {
   273  						trInfo.firstLine.deadline = dl.Sub(time.Now())
   274  					}
   275  				}
   276  				wg.Add(1)
   277  				go func() {
   278  					s.handleStream(st, stream, trInfo)
   279  					wg.Done()
   280  				}()
   281  			})
   282  			wg.Wait()
   283  			s.mu.Lock()
   284  			delete(s.conns, st)
   285  			s.mu.Unlock()
   286  		}()
   287  	}
   288  }
   289  
   290  func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, pf payloadFormat, opts *transport.Options) error {
   291  	p, err := encode(s.opts.codec, msg, pf)
   292  	if err != nil {
   293  		// This typically indicates a fatal issue (e.g., memory
   294  		// corruption or hardware faults) the application program
   295  		// cannot handle.
   296  		//
   297  		// TODO(zhaoq): There exist other options also such as only closing the
   298  		// faulty stream locally and remotely (Other streams can keep going). Find
   299  		// the optimal option.
   300  		grpclog.Fatalf("grpc: Server failed to encode response %v", err)
   301  	}
   302  	return t.Write(stream, p, opts)
   303  }
   304  
   305  func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) {
   306  	if trInfo != nil {
   307  		defer trInfo.tr.Finish()
   308  		trInfo.firstLine.client = false
   309  		trInfo.tr.LazyLog(&trInfo.firstLine, false)
   310  		defer func() {
   311  			if err != nil && err != io.EOF {
   312  				trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
   313  				trInfo.tr.SetError()
   314  			}
   315  		}()
   316  	}
   317  	p := &parser{s: stream}
   318  	for {
   319  		pf, req, err := p.recvMsg()
   320  		if err == io.EOF {
   321  			// The entire stream is done (for unary RPC only).
   322  			return err
   323  		}
   324  		if err != nil {
   325  			switch err := err.(type) {
   326  			case transport.ConnectionError:
   327  				// Nothing to do here.
   328  			case transport.StreamError:
   329  				if err := t.WriteStatus(stream, err.Code, err.Desc); err != nil {
   330  					grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err)
   331  				}
   332  			default:
   333  				panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", err, err))
   334  			}
   335  			return err
   336  		}
   337  		switch pf {
   338  		case compressionNone:
   339  			statusCode := codes.OK
   340  			statusDesc := ""
   341  			df := func(v interface{}) error {
   342  				if err := s.opts.codec.Unmarshal(req, v); err != nil {
   343  					return err
   344  				}
   345  				if trInfo != nil {
   346  					trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true)
   347  				}
   348  				return nil
   349  			}
   350  			reply, appErr := md.Handler(srv.server, stream.Context(), df)
   351  			if appErr != nil {
   352  				if err, ok := appErr.(rpcError); ok {
   353  					statusCode = err.code
   354  					statusDesc = err.desc
   355  				} else {
   356  					statusCode = convertCode(appErr)
   357  					statusDesc = appErr.Error()
   358  				}
   359  				if trInfo != nil && statusCode != codes.OK {
   360  					trInfo.tr.LazyLog(stringer(statusDesc), true)
   361  					trInfo.tr.SetError()
   362  				}
   363  
   364  				if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil {
   365  					grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err)
   366  					return err
   367  				}
   368  				return nil
   369  			}
   370  			if trInfo != nil {
   371  				trInfo.tr.LazyLog(stringer("OK"), false)
   372  			}
   373  			opts := &transport.Options{
   374  				Last:  true,
   375  				Delay: false,
   376  			}
   377  			if err := s.sendResponse(t, stream, reply, compressionNone, opts); err != nil {
   378  				switch err := err.(type) {
   379  				case transport.ConnectionError:
   380  					// Nothing to do here.
   381  				case transport.StreamError:
   382  					statusCode = err.Code
   383  					statusDesc = err.Desc
   384  				default:
   385  					statusCode = codes.Unknown
   386  					statusDesc = err.Error()
   387  				}
   388  				return err
   389  			}
   390  			if trInfo != nil {
   391  				trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true)
   392  			}
   393  			return t.WriteStatus(stream, statusCode, statusDesc)
   394  		default:
   395  			panic(fmt.Sprintf("payload format to be supported: %d", pf))
   396  		}
   397  	}
   398  }
   399  
   400  func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) {
   401  	ss := &serverStream{
   402  		t:      t,
   403  		s:      stream,
   404  		p:      &parser{s: stream},
   405  		codec:  s.opts.codec,
   406  		trInfo: trInfo,
   407  	}
   408  	if trInfo != nil {
   409  		trInfo.tr.LazyLog(&trInfo.firstLine, false)
   410  		defer func() {
   411  			ss.mu.Lock()
   412  			if err != nil && err != io.EOF {
   413  				ss.trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
   414  				ss.trInfo.tr.SetError()
   415  			}
   416  			ss.trInfo.tr.Finish()
   417  			ss.trInfo.tr = nil
   418  			ss.mu.Unlock()
   419  		}()
   420  	}
   421  	if appErr := sd.Handler(srv.server, ss); appErr != nil {
   422  		if err, ok := appErr.(rpcError); ok {
   423  			ss.statusCode = err.code
   424  			ss.statusDesc = err.desc
   425  		} else {
   426  			ss.statusCode = convertCode(appErr)
   427  			ss.statusDesc = appErr.Error()
   428  		}
   429  	}
   430  	if trInfo != nil {
   431  		ss.mu.Lock()
   432  		if ss.statusCode != codes.OK {
   433  			ss.trInfo.tr.LazyLog(stringer(ss.statusDesc), true)
   434  			ss.trInfo.tr.SetError()
   435  		} else {
   436  			ss.trInfo.tr.LazyLog(stringer("OK"), false)
   437  		}
   438  		ss.mu.Unlock()
   439  	}
   440  	return t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc)
   441  
   442  }
   443  
   444  func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) {
   445  	sm := stream.Method()
   446  	if sm != "" && sm[0] == '/' {
   447  		sm = sm[1:]
   448  	}
   449  	pos := strings.LastIndex(sm, "/")
   450  	if pos == -1 {
   451  		if trInfo != nil {
   452  			trInfo.tr.LazyLog(&fmtStringer{"Malformed method name %q", []interface{}{sm}}, true)
   453  			trInfo.tr.SetError()
   454  		}
   455  		if err := t.WriteStatus(stream, codes.InvalidArgument, fmt.Sprintf("malformed method name: %q", stream.Method())); err != nil {
   456  			if trInfo != nil {
   457  				trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
   458  				trInfo.tr.SetError()
   459  			}
   460  			grpclog.Printf("grpc: Server.handleStream failed to write status: %v", err)
   461  		}
   462  		if trInfo != nil {
   463  			trInfo.tr.Finish()
   464  		}
   465  		return
   466  	}
   467  	service := sm[:pos]
   468  	method := sm[pos+1:]
   469  	srv, ok := s.m[service]
   470  	if !ok {
   471  		if trInfo != nil {
   472  			trInfo.tr.LazyLog(&fmtStringer{"Unknown service %v", []interface{}{service}}, true)
   473  			trInfo.tr.SetError()
   474  		}
   475  		if err := t.WriteStatus(stream, codes.Unimplemented, fmt.Sprintf("unknown service %v", service)); err != nil {
   476  			if trInfo != nil {
   477  				trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
   478  				trInfo.tr.SetError()
   479  			}
   480  			grpclog.Printf("grpc: Server.handleStream failed to write status: %v", err)
   481  		}
   482  		if trInfo != nil {
   483  			trInfo.tr.Finish()
   484  		}
   485  		return
   486  	}
   487  	// Unary RPC or Streaming RPC?
   488  	if md, ok := srv.md[method]; ok {
   489  		s.processUnaryRPC(t, stream, srv, md, trInfo)
   490  		return
   491  	}
   492  	if sd, ok := srv.sd[method]; ok {
   493  		s.processStreamingRPC(t, stream, srv, sd, trInfo)
   494  		return
   495  	}
   496  	if trInfo != nil {
   497  		trInfo.tr.LazyLog(&fmtStringer{"Unknown method %v", []interface{}{method}}, true)
   498  		trInfo.tr.SetError()
   499  	}
   500  	if err := t.WriteStatus(stream, codes.Unimplemented, fmt.Sprintf("unknown method %v", method)); err != nil {
   501  		if trInfo != nil {
   502  			trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
   503  			trInfo.tr.SetError()
   504  		}
   505  		grpclog.Printf("grpc: Server.handleStream failed to write status: %v", err)
   506  	}
   507  	if trInfo != nil {
   508  		trInfo.tr.Finish()
   509  	}
   510  }
   511  
   512  // Stop stops the gRPC server. Once Stop returns, the server stops accepting
   513  // connection requests and closes all the connected connections.
   514  func (s *Server) Stop() {
   515  	s.mu.Lock()
   516  	listeners := s.lis
   517  	s.lis = nil
   518  	cs := s.conns
   519  	s.conns = nil
   520  	s.mu.Unlock()
   521  	for lis := range listeners {
   522  		lis.Close()
   523  	}
   524  	for c := range cs {
   525  		c.Close()
   526  	}
   527  	s.mu.Lock()
   528  	if s.events != nil {
   529  		s.events.Finish()
   530  		s.events = nil
   531  	}
   532  	s.mu.Unlock()
   533  }
   534  
   535  // TestingCloseConns closes all exiting transports but keeps s.lis accepting new
   536  // connections. This is for test only now.
   537  func (s *Server) TestingCloseConns() {
   538  	s.mu.Lock()
   539  	for c := range s.conns {
   540  		c.Close()
   541  	}
   542  	s.conns = make(map[transport.ServerTransport]bool)
   543  	s.mu.Unlock()
   544  }
   545  
   546  // SendHeader sends header metadata. It may be called at most once from a unary
   547  // RPC handler. The ctx is the RPC handler's Context or one derived from it.
   548  func SendHeader(ctx context.Context, md metadata.MD) error {
   549  	if md.Len() == 0 {
   550  		return nil
   551  	}
   552  	stream, ok := transport.StreamFromContext(ctx)
   553  	if !ok {
   554  		return fmt.Errorf("grpc: failed to fetch the stream from the context %v", ctx)
   555  	}
   556  	t := stream.ServerTransport()
   557  	if t == nil {
   558  		grpclog.Fatalf("grpc: SendHeader: %v has no ServerTransport to send header metadata.", stream)
   559  	}
   560  	return t.WriteHeader(stream, md)
   561  }
   562  
   563  // SetTrailer sets the trailer metadata that will be sent when an RPC returns.
   564  // It may be called at most once from a unary RPC handler. The ctx is the RPC
   565  // handler's Context or one derived from it.
   566  func SetTrailer(ctx context.Context, md metadata.MD) error {
   567  	if md.Len() == 0 {
   568  		return nil
   569  	}
   570  	stream, ok := transport.StreamFromContext(ctx)
   571  	if !ok {
   572  		return fmt.Errorf("grpc: failed to fetch the stream from the context %v", ctx)
   573  	}
   574  	return stream.SetTrailer(md)
   575  }