github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/grpc/test/rawConnWrapper.go (about)

     1  /*
     2   * Copyright 2018 gRPC authors.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package test
    18  
    19  import (
    20  	"bytes"
    21  	"fmt"
    22  	"io"
    23  	"net"
    24  	"strings"
    25  	"sync"
    26  	"time"
    27  
    28  	"github.com/hxx258456/ccgo/net/http2"
    29  	"github.com/hxx258456/ccgo/net/http2/hpack"
    30  )
    31  
// listenerWrapper is a net.Listener that additionally remembers the most
// recently accepted connection, wrapped in a rawConnWrapper, so that it can be
// fetched later through getLastConn.
type listenerWrapper struct {
	// Listener is the embedded listener that Accept delegates to.
	net.Listener
	// mu is held by Accept and getLastConn while they write or read rcw.
	mu  sync.Mutex
	// rcw wraps the connection returned by the latest successful Accept; it is
	// nil until a connection has been accepted.
	rcw *rawConnWrapper
}
    37  
    38  func listenWithConnControl(network, address string) (net.Listener, error) {
    39  	l, err := net.Listen(network, address)
    40  	if err != nil {
    41  		return nil, err
    42  	}
    43  	return &listenerWrapper{Listener: l}, nil
    44  }
    45  
    46  // Accept blocks until Dial is called, then returns a net.Conn for the server
    47  // half of the connection.
    48  func (l *listenerWrapper) Accept() (net.Conn, error) {
    49  	c, err := l.Listener.Accept()
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  	l.mu.Lock()
    54  	l.rcw = newRawConnWrapperFromConn(c)
    55  	l.mu.Unlock()
    56  	return c, nil
    57  }
    58  
    59  func (l *listenerWrapper) getLastConn() *rawConnWrapper {
    60  	l.mu.Lock()
    61  	defer l.mu.Unlock()
    62  	return l.rcw
    63  }
    64  
// dialerWrapper provides a dial function (see dialer) that records the
// connection it creates so that the raw connection can be driven later through
// getRawConnWrapper.
type dialerWrapper struct {
	// c is the connection recorded by dialer.
	c   net.Conn
	// rcw is the rawConnWrapper recorded by dialer for that connection.
	rcw *rawConnWrapper
}
    69  
    70  func (d *dialerWrapper) dialer(target string, t time.Duration) (net.Conn, error) {
    71  	c, err := net.DialTimeout("tcp", target, t)
    72  	d.c = c
    73  	d.rcw = newRawConnWrapperFromConn(c)
    74  	return c, err
    75  }
    76  
    77  func (d *dialerWrapper) getRawConnWrapper() *rawConnWrapper {
    78  	return d.rcw
    79  }
    80  
// rawConnWrapper bundles a raw connection with the HTTP/2 framing and HPACK
// encoding state needed to read and write individual frames on it by hand.
// Instances are built by newRawConnWrapperFromConn.
type rawConnWrapper struct {
	// cc is the underlying connection; Close closes it.
	cc io.ReadWriteCloser
	// fr is the HTTP/2 framer that reads from and writes to cc.
	fr *http2.Framer

	// writing headers:
	// headerBuf receives the output of hpackEnc; the encode* helpers reset it
	// and return its contents.
	headerBuf bytes.Buffer
	// hpackEnc is the HPACK encoder writing into headerBuf.
	hpackEnc  *hpack.Encoder

	// reading frames:
	// frc and frErrc are created with capacity 1 by the constructor.
	// NOTE(review): no code in this part of the file sends to or receives from
	// them — confirm whether they are used elsewhere.
	frc    chan http2.Frame
	frErrc chan error
}
    93  
    94  func newRawConnWrapperFromConn(cc io.ReadWriteCloser) *rawConnWrapper {
    95  	rcw := &rawConnWrapper{
    96  		cc:     cc,
    97  		frc:    make(chan http2.Frame, 1),
    98  		frErrc: make(chan error, 1),
    99  	}
   100  	rcw.hpackEnc = hpack.NewEncoder(&rcw.headerBuf)
   101  	rcw.fr = http2.NewFramer(cc, cc)
   102  	rcw.fr.ReadMetaHeaders = hpack.NewDecoder(4096 /*initialHeaderTableSize*/, nil)
   103  
   104  	return rcw
   105  }
   106  
   107  func (rcw *rawConnWrapper) Close() error {
   108  	return rcw.cc.Close()
   109  }
   110  
   111  func (rcw *rawConnWrapper) encodeHeaderField(k, v string) error {
   112  	err := rcw.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
   113  	if err != nil {
   114  		return fmt.Errorf("HPACK encoding error for %q/%q: %v", k, v, err)
   115  	}
   116  	return nil
   117  }
   118  
   119  // encodeRawHeader is for usage on both client and server side to construct header based on the input
   120  // key, value pairs.
   121  func (rcw *rawConnWrapper) encodeRawHeader(headers ...string) []byte {
   122  	if len(headers)%2 == 1 {
   123  		panic("odd number of kv args")
   124  	}
   125  
   126  	rcw.headerBuf.Reset()
   127  
   128  	pseudoCount := map[string]int{}
   129  	var keys []string
   130  	vals := map[string][]string{}
   131  
   132  	for len(headers) > 0 {
   133  		k, v := headers[0], headers[1]
   134  		headers = headers[2:]
   135  		if _, ok := vals[k]; !ok {
   136  			keys = append(keys, k)
   137  		}
   138  		if strings.HasPrefix(k, ":") {
   139  			pseudoCount[k]++
   140  			if pseudoCount[k] == 1 {
   141  				vals[k] = []string{v}
   142  			} else {
   143  				// Allows testing of invalid headers w/ dup pseudo fields.
   144  				vals[k] = append(vals[k], v)
   145  			}
   146  		} else {
   147  			vals[k] = append(vals[k], v)
   148  		}
   149  	}
   150  	for _, k := range keys {
   151  		for _, v := range vals[k] {
   152  			rcw.encodeHeaderField(k, v)
   153  		}
   154  	}
   155  	return rcw.headerBuf.Bytes()
   156  }
   157  
   158  // encodeHeader is for usage on client side to write request header.
   159  //
   160  // encodeHeader encodes headers and returns their HPACK bytes. headers
   161  // must contain an even number of key/value pairs.  There may be
   162  // multiple pairs for keys (e.g. "cookie").  The :method, :path, and
   163  // :scheme headers default to GET, / and https.
   164  func (rcw *rawConnWrapper) encodeHeader(headers ...string) []byte {
   165  	if len(headers)%2 == 1 {
   166  		panic("odd number of kv args")
   167  	}
   168  
   169  	rcw.headerBuf.Reset()
   170  
   171  	if len(headers) == 0 {
   172  		// Fast path, mostly for benchmarks, so test code doesn't pollute
   173  		// profiles when we're looking to improve server allocations.
   174  		rcw.encodeHeaderField(":method", "GET")
   175  		rcw.encodeHeaderField(":path", "/")
   176  		rcw.encodeHeaderField(":scheme", "https")
   177  		return rcw.headerBuf.Bytes()
   178  	}
   179  
   180  	if len(headers) == 2 && headers[0] == ":method" {
   181  		// Another fast path for benchmarks.
   182  		rcw.encodeHeaderField(":method", headers[1])
   183  		rcw.encodeHeaderField(":path", "/")
   184  		rcw.encodeHeaderField(":scheme", "https")
   185  		return rcw.headerBuf.Bytes()
   186  	}
   187  
   188  	pseudoCount := map[string]int{}
   189  	keys := []string{":method", ":path", ":scheme"}
   190  	vals := map[string][]string{
   191  		":method": {"GET"},
   192  		":path":   {"/"},
   193  		":scheme": {"https"},
   194  	}
   195  	for len(headers) > 0 {
   196  		k, v := headers[0], headers[1]
   197  		headers = headers[2:]
   198  		if _, ok := vals[k]; !ok {
   199  			keys = append(keys, k)
   200  		}
   201  		if strings.HasPrefix(k, ":") {
   202  			pseudoCount[k]++
   203  			if pseudoCount[k] == 1 {
   204  				vals[k] = []string{v}
   205  			} else {
   206  				// Allows testing of invalid headers w/ dup pseudo fields.
   207  				vals[k] = append(vals[k], v)
   208  			}
   209  		} else {
   210  			vals[k] = append(vals[k], v)
   211  		}
   212  	}
   213  	for _, k := range keys {
   214  		for _, v := range vals[k] {
   215  			rcw.encodeHeaderField(k, v)
   216  		}
   217  	}
   218  	return rcw.headerBuf.Bytes()
   219  }
   220  
   221  func (rcw *rawConnWrapper) writeHeaders(p http2.HeadersFrameParam) error {
   222  	if err := rcw.fr.WriteHeaders(p); err != nil {
   223  		return fmt.Errorf("error writing HEADERS: %v", err)
   224  	}
   225  	return nil
   226  }
   227  
   228  func (rcw *rawConnWrapper) writeRSTStream(streamID uint32, code http2.ErrCode) error {
   229  	if err := rcw.fr.WriteRSTStream(streamID, code); err != nil {
   230  		return fmt.Errorf("error writing RST_STREAM: %v", err)
   231  	}
   232  	return nil
   233  }
   234  
   235  func (rcw *rawConnWrapper) writeGoAway(maxStreamID uint32, code http2.ErrCode, debugData []byte) error {
   236  	if err := rcw.fr.WriteGoAway(maxStreamID, code, debugData); err != nil {
   237  		return fmt.Errorf("error writing GoAway: %v", err)
   238  	}
   239  	return nil
   240  }
   241  
   242  func (rcw *rawConnWrapper) writeRawFrame(t http2.FrameType, flags http2.Flags, streamID uint32, payload []byte) error {
   243  	if err := rcw.fr.WriteRawFrame(t, flags, streamID, payload); err != nil {
   244  		return fmt.Errorf("error writing Raw Frame: %v", err)
   245  	}
   246  	return nil
   247  }