dubbo.apache.org/dubbo-go/v3@v3.1.1/protocol/jsonrpc/server.go (about)

     1  /*
     2   * Licensed to the Apache Software Foundation (ASF) under one or more
     3   * contributor license agreements.  See the NOTICE file distributed with
     4   * this work for additional information regarding copyright ownership.
     5   * The ASF licenses this file to You under the Apache License, Version 2.0
     6   * (the "License"); you may not use this file except in compliance with
     7   * the License.  You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   */
    17  
    18  package jsonrpc
    19  
    20  import (
    21  	"bufio"
    22  	"bytes"
    23  	"context"
    24  	"io"
    25  	"net"
    26  	"net/http"
    27  	"runtime"
    28  	"runtime/debug"
    29  	"sync"
    30  	"time"
    31  
    32  	"github.com/dubbogo/gost/log/logger"
    33  	"github.com/opentracing/opentracing-go"
    34  
    35  	"dubbo.apache.org/dubbo-go/v3/common"
    36  	"dubbo.apache.org/dubbo-go/v3/common/constant"
    37  	"dubbo.apache.org/dubbo-go/v3/protocol/invocation"
    38  	perrors "github.com/pkg/errors"
    39  )
    40  
    41  // A value sent as a placeholder for the server's response value when the server
    42  // receives an invalid request. It is never decoded by the client since the Response
    43  // contains an error when it is used.
    44  var invalidRequest = struct{}{}
    45  
    46  const (
    47  	// DefaultMaxSleepTime max sleep interval in accept
    48  	DefaultMaxSleepTime = 1 * time.Second
    49  	// DefaultHTTPRspBufferSize ...
    50  	DefaultHTTPRspBufferSize = 1024
    51  	// PathPrefix ...
    52  	PathPrefix = byte('/')
    53  	// Max HTTP header size in Mib
    54  	MaxHeaderSize = 8 * 1024 * 1024
    55  )
    56  
    57  // Server is JSON RPC server wrapper
    58  type Server struct {
    59  	done chan struct{}
    60  	once sync.Once
    61  
    62  	sync.RWMutex
    63  	wg      sync.WaitGroup
    64  	timeout time.Duration
    65  }
    66  
    67  // NewServer creates new JSON RPC server.
    68  func NewServer() *Server {
    69  	return &Server{
    70  		done: make(chan struct{}),
    71  	}
    72  }
    73  
    74  func (s *Server) handlePkg(conn net.Conn) {
    75  	defer func() {
    76  		if r := recover(); r != nil {
    77  			logger.Warnf("connection{local:%v, remote:%v} panic error:%#v, debug stack:%s",
    78  				conn.LocalAddr(), conn.RemoteAddr(), r, string(debug.Stack()))
    79  		}
    80  
    81  		conn.Close()
    82  	}()
    83  
    84  	setTimeout := func(conn net.Conn, timeout time.Duration) {
    85  		t := time.Time{}
    86  		if timeout > time.Duration(0) {
    87  			t = time.Now().Add(timeout)
    88  		}
    89  
    90  		if err := conn.SetDeadline(t); err != nil {
    91  			logger.Error("connection.SetDeadline(t:%v) = error:%v", t, err)
    92  		}
    93  	}
    94  
    95  	sendErrorResp := func(header http.Header, body []byte) error {
    96  		rsp := &http.Response{
    97  			Header:        header,
    98  			StatusCode:    500,
    99  			ProtoMajor:    1,
   100  			ProtoMinor:    1,
   101  			ContentLength: int64(len(body)),
   102  			Body:          io.NopCloser(bytes.NewReader(body)),
   103  		}
   104  		rsp.Header.Del("Content-Type")
   105  		rsp.Header.Del("Content-Length")
   106  		rsp.Header.Del("Timeout")
   107  
   108  		rspBuf := bytes.NewBuffer(make([]byte, DefaultHTTPRspBufferSize))
   109  		rspBuf.Reset()
   110  		err := rsp.Write(rspBuf)
   111  		if err != nil {
   112  			return perrors.WithStack(err)
   113  		}
   114  		_, err = rspBuf.WriteTo(conn)
   115  		return perrors.WithStack(err)
   116  	}
   117  
   118  	for {
   119  		bufReader := bufio.NewReader(io.LimitReader(conn, MaxHeaderSize))
   120  		r, err := http.ReadRequest(bufReader)
   121  		if err != nil {
   122  			logger.Warnf("[ReadRequest] error: %v", err)
   123  			return
   124  		}
   125  
   126  		reqBody, err := io.ReadAll(r.Body)
   127  		r.Body.Close()
   128  		if err != nil {
   129  			return
   130  		}
   131  
   132  		reqHeader := make(map[string]string)
   133  		for k := range r.Header {
   134  			reqHeader[k] = r.Header.Get(k)
   135  		}
   136  		reqHeader["Path"] = r.URL.Path[1:] // to get service name
   137  		if r.URL.Path[0] != PathPrefix {
   138  			reqHeader["Path"] = r.URL.Path
   139  		}
   140  		reqHeader["HttpMethod"] = r.Method
   141  
   142  		httpTimeout := s.timeout
   143  		contentType := reqHeader["Content-Type"]
   144  		if contentType != "application/json" && contentType != "application/json-rpc" {
   145  			setTimeout(conn, httpTimeout)
   146  			r.Header.Set("Content-Type", "text/plain")
   147  			if errRsp := sendErrorResp(r.Header, []byte(perrors.WithStack(err).Error())); errRsp != nil {
   148  				logger.Warnf("sendErrorResp(header:%#v, error:%v) = error:%s",
   149  					r.Header, perrors.WithStack(err), errRsp)
   150  			}
   151  			return
   152  		}
   153  
   154  		ctx := context.Background()
   155  
   156  		spanCtx, err := opentracing.GlobalTracer().Extract(opentracing.HTTPHeaders,
   157  			opentracing.HTTPHeadersCarrier(r.Header))
   158  		if err == nil {
   159  			ctx = context.WithValue(ctx, constant.TracingRemoteSpanCtx, spanCtx)
   160  		}
   161  
   162  		if len(reqHeader["Timeout"]) > 0 {
   163  			timeout, err := time.ParseDuration(reqHeader["Timeout"])
   164  			if err == nil {
   165  				httpTimeout = timeout
   166  				var cancel context.CancelFunc
   167  				ctx, cancel = context.WithTimeout(ctx, httpTimeout)
   168  				defer cancel()
   169  			}
   170  			delete(reqHeader, "Timeout")
   171  		}
   172  		setTimeout(conn, httpTimeout)
   173  
   174  		if err := serveRequest(ctx, reqHeader, reqBody, conn); err != nil {
   175  			if errRsp := sendErrorResp(r.Header, []byte(perrors.WithStack(err).Error())); errRsp != nil {
   176  				logger.Warnf("sendErrorResp(header:%#v, error:%v) = error:%s",
   177  					r.Header, perrors.WithStack(err), errRsp)
   178  			}
   179  
   180  			logger.Infof("Unexpected error serving request, closing socket: %v", err)
   181  			return
   182  		}
   183  	}
   184  }
   185  
   186  func accept(listener net.Listener, fn func(net.Conn)) error {
   187  	var (
   188  		ok       bool
   189  		ne       net.Error
   190  		tmpDelay time.Duration
   191  	)
   192  
   193  	for {
   194  		c, err := listener.Accept()
   195  		if err != nil {
   196  			if ne, ok = err.(net.Error); ok && ne.Temporary() {
   197  				if tmpDelay != 0 {
   198  					tmpDelay <<= 1
   199  				} else {
   200  					tmpDelay = 5 * time.Millisecond
   201  				}
   202  				if tmpDelay > DefaultMaxSleepTime {
   203  					tmpDelay = DefaultMaxSleepTime
   204  				}
   205  				logger.Infof("http: Accept error: %v; retrying in %v\n", err, tmpDelay)
   206  				time.Sleep(tmpDelay)
   207  				continue
   208  			}
   209  			return perrors.WithStack(err)
   210  		}
   211  
   212  		go func() {
   213  			defer func() {
   214  				if r := recover(); r != nil {
   215  					const size = 64 << 10
   216  					buf := make([]byte, size)
   217  					buf = buf[:runtime.Stack(buf, false)]
   218  					logger.Errorf("http: panic serving %v: %v\n%s", c.RemoteAddr(), r, buf)
   219  					c.Close()
   220  				}
   221  			}()
   222  
   223  			fn(c)
   224  		}()
   225  	}
   226  }
   227  
   228  // Start JSON RPC server then ready for accept request.
   229  func (s *Server) Start(url *common.URL) {
   230  	listener, err := net.Listen("tcp", url.Location)
   231  	if err != nil {
   232  		logger.Errorf("jsonrpc server [%s] start failed: %v", url.Path, err)
   233  		return
   234  	}
   235  	logger.Infof("rpc server start to listen on %s", listener.Addr())
   236  
   237  	s.wg.Add(1)
   238  	go func() {
   239  		if err := accept(listener, func(conn net.Conn) { s.handlePkg(conn) }); err != nil {
   240  			logger.Error("accept() = error:%v", err)
   241  		}
   242  		s.wg.Done()
   243  	}()
   244  
   245  	s.wg.Add(1)
   246  	go func() { // Server done goroutine
   247  		var err error
   248  		<-s.done               // step1: block to wait for done channel(wait Server.Stop step2)
   249  		err = listener.Close() // step2: and then close listener
   250  		if err != nil {
   251  			logger.Warnf("listener{addr:%s}.Close() = error{%#v}", listener.Addr(), err)
   252  		}
   253  		s.wg.Done()
   254  	}()
   255  }
   256  
   257  // Stop JSON RPC server, just can be call once.
   258  func (s *Server) Stop() {
   259  	s.once.Do(func() {
   260  		close(s.done)
   261  		s.wg.Wait()
   262  	})
   263  }
   264  
   265  func serveRequest(ctx context.Context, header map[string]string, body []byte, conn net.Conn) error {
   266  	sendErrorResp := func(header map[string]string, body []byte) error {
   267  		rsp := &http.Response{
   268  			Header:        make(http.Header),
   269  			StatusCode:    500,
   270  			ProtoMajor:    1,
   271  			ProtoMinor:    1,
   272  			ContentLength: int64(len(body)),
   273  			Body:          io.NopCloser(bytes.NewReader(body)),
   274  		}
   275  		rsp.Header.Del("Content-Type")
   276  		rsp.Header.Del("Content-Length")
   277  		rsp.Header.Del("Timeout")
   278  		for k, v := range header {
   279  			rsp.Header.Set(k, v)
   280  		}
   281  
   282  		rspBuf := bytes.NewBuffer(make([]byte, DefaultHTTPRspBufferSize))
   283  		rspBuf.Reset()
   284  		err := rsp.Write(rspBuf)
   285  		if err != nil {
   286  			return perrors.WithStack(err)
   287  		}
   288  		_, err = rspBuf.WriteTo(conn)
   289  		return perrors.WithStack(err)
   290  	}
   291  
   292  	sendResp := func(header map[string]string, body []byte) error {
   293  		rsp := &http.Response{
   294  			Header:        make(http.Header),
   295  			StatusCode:    200,
   296  			ProtoMajor:    1,
   297  			ProtoMinor:    1,
   298  			ContentLength: int64(len(body)),
   299  			Body:          io.NopCloser(bytes.NewReader(body)),
   300  		}
   301  		rsp.Header.Del("Content-Type")
   302  		rsp.Header.Del("Content-Length")
   303  		rsp.Header.Del("Timeout")
   304  		for k, v := range header {
   305  			rsp.Header.Set(k, v)
   306  		}
   307  
   308  		rspBuf := bytes.NewBuffer(make([]byte, DefaultHTTPRspBufferSize))
   309  		rspBuf.Reset()
   310  		err := rsp.Write(rspBuf)
   311  		if err != nil {
   312  			return perrors.WithStack(err)
   313  		}
   314  		_, err = rspBuf.WriteTo(conn)
   315  		return perrors.WithStack(err)
   316  	}
   317  
   318  	// read request header
   319  	codec := newServerCodec()
   320  	err := codec.ReadHeader(header, body)
   321  	if err != nil {
   322  		if err == io.EOF || err == io.ErrUnexpectedEOF {
   323  			return perrors.WithStack(err)
   324  		}
   325  		return perrors.New("server cannot decode request: " + err.Error())
   326  	}
   327  
   328  	path := header["Path"]
   329  	methodName := codec.req.Method
   330  	if len(path) == 0 || len(methodName) == 0 {
   331  		return perrors.New("service/method request ill-formed: " + path + "/" + methodName)
   332  	}
   333  
   334  	// read body
   335  	var args []interface{}
   336  	if err = codec.ReadBody(&args); err != nil {
   337  		return perrors.WithStack(err)
   338  	}
   339  	logger.Debugf("args: %v", args)
   340  
   341  	// exporter invoke
   342  	exporter, _ := jsonrpcProtocol.ExporterMap().Load(path)
   343  	invoker := exporter.(*JsonrpcExporter).GetInvoker()
   344  	if invoker != nil {
   345  		result := invoker.Invoke(ctx, invocation.NewRPCInvocation(methodName, args, map[string]interface{}{
   346  			constant.PathKey:    path,
   347  			constant.VersionKey: codec.req.Version,
   348  		}))
   349  		if err := result.Error(); err != nil {
   350  			rspStream, codecErr := codec.Write(err.Error(), invalidRequest)
   351  			if codecErr != nil {
   352  				return perrors.WithStack(codecErr)
   353  			}
   354  			if errRsp := sendErrorResp(header, rspStream); errRsp != nil {
   355  				logger.Warnf("Exporter: sendErrorResp(header:%#v, error:%v) = error:%s",
   356  					header, err, errRsp)
   357  			}
   358  		} else {
   359  			res := result.Result()
   360  			rspStream, err := codec.Write("", res)
   361  			if err != nil {
   362  				return perrors.WithStack(err)
   363  			}
   364  			if errRsp := sendResp(header, rspStream); errRsp != nil {
   365  				logger.Warnf("Exporter: sendResp(header:%#v, error:%v) = error:%s",
   366  					header, err, errRsp)
   367  			}
   368  		}
   369  	}
   370  
   371  	return nil
   372  }