github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/netpoll/http_client_handler.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package netpoll
    18  
    19  import (
    20  	"bufio"
    21  	"context"
    22  	"errors"
    23  	"fmt"
    24  	"io/ioutil"
    25  	"net"
    26  	"net/http"
    27  	"path"
    28  	"strconv"
    29  	"strings"
    30  
    31  	"github.com/cloudwego/netpoll"
    32  
    33  	"github.com/cloudwego/kitex/pkg/kerrors"
    34  	"github.com/cloudwego/kitex/pkg/klog"
    35  	"github.com/cloudwego/kitex/pkg/remote"
    36  	"github.com/cloudwego/kitex/pkg/remote/trans"
    37  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    38  	"github.com/cloudwego/kitex/pkg/stats"
    39  )
    40  
    41  func newHTTPCliTransHandler(opt *remote.ClientOption, ext trans.Extension) (remote.ClientTransHandler, error) {
    42  	return &httpCliTransHandler{
    43  		opt:   opt,
    44  		codec: opt.Codec,
    45  		ext:   ext,
    46  	}, nil
    47  }
    48  
    49  type httpCliTransHandler struct {
    50  	opt       *remote.ClientOption
    51  	codec     remote.Codec
    52  	transPipe *remote.TransPipeline
    53  	ext       trans.Extension
    54  }
    55  
    56  // Write implements the remote.ClientTransHandler interface.
    57  func (t *httpCliTransHandler) Write(ctx context.Context, conn net.Conn, sendMsg remote.Message) (nctx context.Context, err error) {
    58  	var bufWriter remote.ByteBuffer
    59  	ri := sendMsg.RPCInfo()
    60  	rpcinfo.Record(ctx, ri, stats.WriteStart, nil)
    61  	defer func() {
    62  		t.ext.ReleaseBuffer(bufWriter, err)
    63  		rpcinfo.Record(ctx, ri, stats.WriteFinish, err)
    64  	}()
    65  
    66  	bufWriter = t.ext.NewWriteByteBuffer(ctx, conn, sendMsg)
    67  	buffer := netpoll.NewLinkBuffer(0)
    68  	bodyReaderWriter := NewReaderWriterByteBuffer(buffer)
    69  	defer bodyReaderWriter.Release(err)
    70  	if err != nil {
    71  		return ctx, err
    72  	}
    73  	err = t.codec.Encode(ctx, sendMsg, bodyReaderWriter)
    74  	if err != nil {
    75  		return ctx, err
    76  	}
    77  	err = bodyReaderWriter.Flush()
    78  	if err != nil {
    79  		return ctx, err
    80  	}
    81  	var url string
    82  	if hu, ok := ri.To().Tag(rpcinfo.HTTPURL); ok {
    83  		url = hu
    84  	} else {
    85  		url = path.Join("/", ri.Invocation().MethodName())
    86  	}
    87  	req, err := http.NewRequest("POST", url, netpoll.NewIOReader(buffer))
    88  	if err != nil {
    89  		return ctx, err
    90  	}
    91  	err = addMetaInfo(sendMsg, req.Header)
    92  	if err != nil {
    93  		return ctx, err
    94  	}
    95  	err = req.Write(bufWriter)
    96  	if err != nil {
    97  		return ctx, err
    98  	}
    99  	return ctx, bufWriter.Flush()
   100  }
   101  
   102  // Read implements the remote.ClientTransHandler interface. Read is blocked.
   103  func (t *httpCliTransHandler) Read(ctx context.Context, conn net.Conn, msg remote.Message) (nctx context.Context, err error) {
   104  	var bufReader remote.ByteBuffer
   105  	rpcinfo.Record(ctx, msg.RPCInfo(), stats.ReadStart, nil)
   106  	defer func() {
   107  		t.ext.ReleaseBuffer(bufReader, err)
   108  		rpcinfo.Record(ctx, msg.RPCInfo(), stats.ReadFinish, err)
   109  	}()
   110  
   111  	t.ext.SetReadTimeout(ctx, conn, msg.RPCInfo().Config(), remote.Client)
   112  	bufReader = t.ext.NewReadByteBuffer(ctx, conn, msg)
   113  	bodyReader, err := getBodyBufReader(bufReader)
   114  	if err != nil {
   115  		return ctx, fmt.Errorf("get body bufreader error:%w", err)
   116  	}
   117  	err = t.codec.Decode(ctx, msg, bodyReader)
   118  	if err != nil {
   119  		return ctx, err
   120  	}
   121  	if left := bufReader.ReadableLen(); left > 0 {
   122  		bufReader.Skip(left)
   123  	}
   124  	return ctx, nil
   125  }
   126  
   127  // OnMessage implements the remote.ClientTransHandler interface.
   128  func (t *httpCliTransHandler) OnMessage(ctx context.Context, args, result remote.Message) (context.Context, error) {
   129  	// do nothing
   130  	return ctx, nil
   131  }
   132  
   133  // OnInactive implements the remote.ClientTransHandler interface.
   134  // This is called when connection is closed.
   135  func (t *httpCliTransHandler) OnInactive(ctx context.Context, conn net.Conn) {
   136  	// ineffective now and do nothing
   137  }
   138  
   139  // OnError implements the remote.ClientTransHandler interface.
   140  // This is called when panic happens.
   141  func (t *httpCliTransHandler) OnError(ctx context.Context, err error, conn net.Conn) {
   142  	if pe, ok := err.(*kerrors.DetailedError); ok {
   143  		klog.CtxErrorf(ctx, "KITEX: send http request error, remote=%s, error=%s\nstack=%s", conn.RemoteAddr(), err.Error(), pe.Stack())
   144  	} else {
   145  		klog.CtxErrorf(ctx, "KITEX: send http request error, remote=%s, error=%s", conn.RemoteAddr(), err.Error())
   146  	}
   147  }
   148  
   149  // SetPipeline implements the remote.ClientTransHandler interface.
   150  func (t *httpCliTransHandler) SetPipeline(p *remote.TransPipeline) {
   151  	t.transPipe = p
   152  }
   153  
   154  func addMetaInfo(msg remote.Message, h http.Header) error {
   155  	meta, ok := msg.Tags()[rpcinfo.HTTPHeader]
   156  	if !ok {
   157  		return nil
   158  	}
   159  	if header, ok := meta.(http.Header); ok {
   160  		for k, v := range header {
   161  			h[k] = v
   162  		}
   163  	} else {
   164  		return errors.New("http header in rpcinfo type assertion failed")
   165  	}
   166  	return nil
   167  }
   168  
   169  func readLine(buffer remote.ByteBuffer) ([]byte, error) {
   170  	var buf []byte
   171  	for {
   172  		buf0, err := buffer.Next(1)
   173  		if err != nil {
   174  			return nil, err
   175  		}
   176  		if buf0[0] == '\r' {
   177  			buf1, err := buffer.Peek(1)
   178  			if err != nil {
   179  				return nil, err
   180  			}
   181  			if buf1[0] == '\n' {
   182  				err = buffer.Skip(1)
   183  				if err != nil {
   184  					return nil, err
   185  				}
   186  				return buf, nil
   187  			}
   188  		} else {
   189  			buf = append(buf, buf0[0])
   190  		}
   191  	}
   192  }
   193  
   194  // return n bytes skipped ('\r\n' not included)
   195  func skipLine(buffer remote.ByteBuffer) (n int, err error) {
   196  	for {
   197  		buf0, err := buffer.Next(1)
   198  		if err != nil {
   199  			return n, err
   200  		}
   201  		if buf0[0] == '\r' {
   202  			buf1, err := buffer.Peek(1)
   203  			if err != nil {
   204  				return n, err
   205  			}
   206  			if buf1[0] == '\n' {
   207  				err = buffer.Skip(1)
   208  				if err != nil {
   209  					return n, err
   210  				}
   211  				return n, err
   212  			}
   213  		} else {
   214  			n++
   215  		}
   216  	}
   217  }
   218  
   219  func parseHTTPResponseHead(line string) (protoMajor, protoMinor, statusCodeInt int, err error) {
   220  	var proto, status, statusCode string
   221  	i := strings.IndexByte(line, ' ')
   222  	if i == -1 {
   223  		return 0, 0, 0, errors.New("malformed HTTP response: " + line)
   224  	}
   225  	proto = line[:i]
   226  	status = strings.TrimLeft(line[i+1:], " ")
   227  	statusCode = status
   228  	if i := strings.IndexByte(status, ' '); i != -1 {
   229  		statusCode = status[:i]
   230  	}
   231  	if len(statusCode) != 3 {
   232  		return 0, 0, 0, errors.New("malformed HTTP status code: " + statusCode)
   233  	}
   234  	statusCodeInt, err = strconv.Atoi(statusCode)
   235  	if err != nil || statusCodeInt < 0 {
   236  		return 0, 0, 0, errors.New("malformed HTTP status code: " + statusCode)
   237  	}
   238  	var ok bool
   239  	if protoMajor, protoMinor, ok = http.ParseHTTPVersion(proto); !ok {
   240  		return 0, 0, 0, errors.New("malformed HTTP version: " + proto)
   241  	}
   242  	return
   243  }
   244  
   245  func skipToBody(buffer remote.ByteBuffer) error {
   246  	head, err := readLine(buffer)
   247  	if err != nil {
   248  		return err
   249  	}
   250  	_, _, statusCode, err := parseHTTPResponseHead(string(head))
   251  	if err != nil {
   252  		return err
   253  	}
   254  	if statusCode != 200 {
   255  		s, err := buffer.ReadString(buffer.ReadableLen())
   256  		if err != nil {
   257  			return fmt.Errorf("http code: %d, read request error:\n%w", statusCode, err)
   258  		}
   259  		return fmt.Errorf("http code: %d, error:\n%s", statusCode, s)
   260  	}
   261  	for {
   262  		n, err := skipLine(buffer)
   263  		if err != nil {
   264  			return err
   265  		}
   266  		if n == 0 {
   267  			return nil
   268  		}
   269  	}
   270  }
   271  
   272  func getBodyBufReader(buf remote.ByteBuffer) (remote.ByteBuffer, error) {
   273  	br := bufio.NewReader(buf)
   274  	hr, err := http.ReadResponse(br, nil)
   275  	if err != nil {
   276  		return nil, fmt.Errorf("read http response error:%w", err)
   277  	}
   278  	if hr.StatusCode != http.StatusOK {
   279  		return nil, fmt.Errorf("http response not OK, StatusCode: %d", hr.StatusCode)
   280  	}
   281  	b, err := ioutil.ReadAll(hr.Body)
   282  	hr.Body.Close()
   283  	if err != nil {
   284  		return nil, fmt.Errorf("read http response body error:%w", err)
   285  	}
   286  	return remote.NewReaderBuffer(b), nil
   287  }