github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/grpc/test/servertester.go (about)

     1  /*
     2   * Copyright 2016 gRPC authors.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  // Package test contains tests.
    18  package test
    19  
    20  import (
    21  	"bytes"
    22  	"errors"
    23  	"io"
    24  	"strings"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/hxx258456/ccgo/net/http2"
    29  	"github.com/hxx258456/ccgo/net/http2/hpack"
    30  )
    31  
    32  // This is a subset of http2's serverTester type.
    33  //
    34  // serverTester wraps a io.ReadWriter (acting like the underlying
    35  // network connection) and provides utility methods to read and write
    36  // http2 frames.
    37  //
    38  // NOTE(bradfitz): this could eventually be exported somewhere. Others
    39  // have asked for it too. For now I'm still experimenting with the
    40  // API and don't feel like maintaining a stable testing API.
    41  
    42  type serverTester struct {
    43  	cc io.ReadWriteCloser // client conn
    44  	t  testing.TB
    45  	fr *http2.Framer
    46  
    47  	// writing headers:
    48  	headerBuf bytes.Buffer
    49  	hpackEnc  *hpack.Encoder
    50  
    51  	// reading frames:
    52  	frc    chan http2.Frame
    53  	frErrc chan error
    54  }
    55  
    56  func newServerTesterFromConn(t testing.TB, cc io.ReadWriteCloser) *serverTester {
    57  	st := &serverTester{
    58  		t:      t,
    59  		cc:     cc,
    60  		frc:    make(chan http2.Frame, 1),
    61  		frErrc: make(chan error, 1),
    62  	}
    63  	st.hpackEnc = hpack.NewEncoder(&st.headerBuf)
    64  	st.fr = http2.NewFramer(cc, cc)
    65  	st.fr.ReadMetaHeaders = hpack.NewDecoder(4096 /*initialHeaderTableSize*/, nil)
    66  
    67  	return st
    68  }
    69  
    70  func (st *serverTester) readFrame() (http2.Frame, error) {
    71  	go func() {
    72  		fr, err := st.fr.ReadFrame()
    73  		if err != nil {
    74  			st.frErrc <- err
    75  		} else {
    76  			st.frc <- fr
    77  		}
    78  	}()
    79  	t := time.NewTimer(2 * time.Second)
    80  	defer t.Stop()
    81  	select {
    82  	case f := <-st.frc:
    83  		return f, nil
    84  	case err := <-st.frErrc:
    85  		return nil, err
    86  	case <-t.C:
    87  		return nil, errors.New("timeout waiting for frame")
    88  	}
    89  }
    90  
    91  // greet initiates the client's HTTP/2 connection into a state where
    92  // frames may be sent.
    93  func (st *serverTester) greet() {
    94  	st.writePreface()
    95  	st.writeInitialSettings()
    96  	st.wantSettings()
    97  	st.writeSettingsAck()
    98  	for {
    99  		f, err := st.readFrame()
   100  		if err != nil {
   101  			st.t.Fatal(err)
   102  		}
   103  		switch f := f.(type) {
   104  		case *http2.WindowUpdateFrame:
   105  			// grpc's transport/http2_server sends this
   106  			// before the settings ack. The Go http2
   107  			// server uses a setting instead.
   108  		case *http2.SettingsFrame:
   109  			if f.IsAck() {
   110  				return
   111  			}
   112  			st.t.Fatalf("during greet, got non-ACK settings frame")
   113  		default:
   114  			st.t.Fatalf("during greet, unexpected frame type %T", f)
   115  		}
   116  	}
   117  }
   118  
   119  func (st *serverTester) writePreface() {
   120  	n, err := st.cc.Write([]byte(http2.ClientPreface))
   121  	if err != nil {
   122  		st.t.Fatalf("Error writing client preface: %v", err)
   123  	}
   124  	if n != len(http2.ClientPreface) {
   125  		st.t.Fatalf("Writing client preface, wrote %d bytes; want %d", n, len(http2.ClientPreface))
   126  	}
   127  }
   128  
   129  func (st *serverTester) writeInitialSettings() {
   130  	if err := st.fr.WriteSettings(); err != nil {
   131  		st.t.Fatalf("Error writing initial SETTINGS frame from client to server: %v", err)
   132  	}
   133  }
   134  
   135  func (st *serverTester) writeSettingsAck() {
   136  	if err := st.fr.WriteSettingsAck(); err != nil {
   137  		st.t.Fatalf("Error writing ACK of server's SETTINGS: %v", err)
   138  	}
   139  }
   140  
   141  func (st *serverTester) wantRSTStream(errCode http2.ErrCode) *http2.RSTStreamFrame {
   142  	f, err := st.readFrame()
   143  	if err != nil {
   144  		st.t.Fatalf("Error while expecting an RST frame: %v", err)
   145  	}
   146  	sf, ok := f.(*http2.RSTStreamFrame)
   147  	if !ok {
   148  		st.t.Fatalf("got a %T; want *http2.RSTStreamFrame", f)
   149  	}
   150  	if sf.ErrCode != errCode {
   151  		st.t.Fatalf("expected RST error code '%v', got '%v'", errCode.String(), sf.ErrCode.String())
   152  	}
   153  	return sf
   154  }
   155  
   156  func (st *serverTester) wantSettings() *http2.SettingsFrame {
   157  	f, err := st.readFrame()
   158  	if err != nil {
   159  		st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err)
   160  	}
   161  	sf, ok := f.(*http2.SettingsFrame)
   162  	if !ok {
   163  		st.t.Fatalf("got a %T; want *SettingsFrame", f)
   164  	}
   165  	return sf
   166  }
   167  
   168  // wait for any activity from the server
   169  func (st *serverTester) wantAnyFrame() http2.Frame {
   170  	f, err := st.fr.ReadFrame()
   171  	if err != nil {
   172  		st.t.Fatal(err)
   173  	}
   174  	return f
   175  }
   176  
   177  func (st *serverTester) encodeHeaderField(k, v string) {
   178  	err := st.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
   179  	if err != nil {
   180  		st.t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
   181  	}
   182  }
   183  
   184  // encodeHeader encodes headers and returns their HPACK bytes. headers
   185  // must contain an even number of key/value pairs.  There may be
   186  // multiple pairs for keys (e.g. "cookie").  The :method, :path, and
   187  // :scheme headers default to GET, / and https.
   188  func (st *serverTester) encodeHeader(headers ...string) []byte {
   189  	if len(headers)%2 == 1 {
   190  		panic("odd number of kv args")
   191  	}
   192  
   193  	st.headerBuf.Reset()
   194  
   195  	if len(headers) == 0 {
   196  		// Fast path, mostly for benchmarks, so test code doesn't pollute
   197  		// profiles when we're looking to improve server allocations.
   198  		st.encodeHeaderField(":method", "GET")
   199  		st.encodeHeaderField(":path", "/")
   200  		st.encodeHeaderField(":scheme", "https")
   201  		return st.headerBuf.Bytes()
   202  	}
   203  
   204  	if len(headers) == 2 && headers[0] == ":method" {
   205  		// Another fast path for benchmarks.
   206  		st.encodeHeaderField(":method", headers[1])
   207  		st.encodeHeaderField(":path", "/")
   208  		st.encodeHeaderField(":scheme", "https")
   209  		return st.headerBuf.Bytes()
   210  	}
   211  
   212  	pseudoCount := map[string]int{}
   213  	keys := []string{":method", ":path", ":scheme"}
   214  	vals := map[string][]string{
   215  		":method": {"GET"},
   216  		":path":   {"/"},
   217  		":scheme": {"https"},
   218  	}
   219  	for len(headers) > 0 {
   220  		k, v := headers[0], headers[1]
   221  		headers = headers[2:]
   222  		if _, ok := vals[k]; !ok {
   223  			keys = append(keys, k)
   224  		}
   225  		if strings.HasPrefix(k, ":") {
   226  			pseudoCount[k]++
   227  			if pseudoCount[k] == 1 {
   228  				vals[k] = []string{v}
   229  			} else {
   230  				// Allows testing of invalid headers w/ dup pseudo fields.
   231  				vals[k] = append(vals[k], v)
   232  			}
   233  		} else {
   234  			vals[k] = append(vals[k], v)
   235  		}
   236  	}
   237  	for _, k := range keys {
   238  		for _, v := range vals[k] {
   239  			st.encodeHeaderField(k, v)
   240  		}
   241  	}
   242  	return st.headerBuf.Bytes()
   243  }
   244  
   245  func (st *serverTester) writeHeadersGRPC(streamID uint32, path string, endStream bool) {
   246  	st.writeHeaders(http2.HeadersFrameParam{
   247  		StreamID: streamID,
   248  		BlockFragment: st.encodeHeader(
   249  			":method", "POST",
   250  			":path", path,
   251  			"content-type", "application/grpc",
   252  			"te", "trailers",
   253  		),
   254  		EndStream:  endStream,
   255  		EndHeaders: true,
   256  	})
   257  }
   258  
   259  func (st *serverTester) writeHeaders(p http2.HeadersFrameParam) {
   260  	if err := st.fr.WriteHeaders(p); err != nil {
   261  		st.t.Fatalf("Error writing HEADERS: %v", err)
   262  	}
   263  }
   264  
   265  func (st *serverTester) writeData(streamID uint32, endStream bool, data []byte) {
   266  	if err := st.fr.WriteData(streamID, endStream, data); err != nil {
   267  		st.t.Fatalf("Error writing DATA: %v", err)
   268  	}
   269  }
   270  
   271  func (st *serverTester) writeRSTStream(streamID uint32, code http2.ErrCode) {
   272  	if err := st.fr.WriteRSTStream(streamID, code); err != nil {
   273  		st.t.Fatalf("Error writing RST_STREAM: %v", err)
   274  	}
   275  }