google.golang.org/grpc@v1.62.1/test/servertester.go (about) 1 /* 2 * Copyright 2016 gRPC authors. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // Package test contains tests. 18 package test 19 20 import ( 21 "bytes" 22 "errors" 23 "io" 24 "strings" 25 "testing" 26 "time" 27 28 "golang.org/x/net/http2" 29 "golang.org/x/net/http2/hpack" 30 ) 31 32 // This is a subset of http2's serverTester type. 33 // 34 // serverTester wraps a io.ReadWriter (acting like the underlying 35 // network connection) and provides utility methods to read and write 36 // http2 frames. 37 // 38 // NOTE(bradfitz): this could eventually be exported somewhere. Others 39 // have asked for it too. For now I'm still experimenting with the 40 // API and don't feel like maintaining a stable testing API. 41 42 type serverTester struct { 43 cc io.ReadWriteCloser // client conn 44 t testing.TB 45 fr *http2.Framer 46 47 // writing headers: 48 headerBuf bytes.Buffer 49 hpackEnc *hpack.Encoder 50 51 // reading frames: 52 frc chan http2.Frame 53 frErrc chan error 54 } 55 56 func newServerTesterFromConn(t testing.TB, cc io.ReadWriteCloser) *serverTester { 57 st := &serverTester{ 58 t: t, 59 cc: cc, 60 frc: make(chan http2.Frame, 1), 61 frErrc: make(chan error, 1), 62 } 63 st.hpackEnc = hpack.NewEncoder(&st.headerBuf) 64 st.fr = http2.NewFramer(cc, cc) 65 st.fr.ReadMetaHeaders = hpack.NewDecoder(4096 /*initialHeaderTableSize*/, nil) 66 67 return st 68 } 69 70 func (st *serverTester) readFrame() (http2.Frame, error) { 71 go func() { 72 fr, err := st.fr.ReadFrame() 73 if err != nil { 74 st.frErrc <- err 75 } else { 76 st.frc <- fr 77 } 78 }() 79 t := time.NewTimer(2 * time.Second) 80 defer t.Stop() 81 select { 82 case f := <-st.frc: 83 return f, nil 84 case err := <-st.frErrc: 85 return nil, err 86 case <-t.C: 87 return nil, errors.New("timeout waiting for frame") 88 } 89 } 90 91 // greet initiates the client's HTTP/2 connection into a state where 92 // frames may be sent. 93 func (st *serverTester) greet() { 94 st.writePreface() 95 st.writeInitialSettings() 96 st.wantSettings() 97 st.writeSettingsAck() 98 for { 99 f, err := st.readFrame() 100 if err != nil { 101 st.t.Fatal(err) 102 } 103 switch f := f.(type) { 104 case *http2.WindowUpdateFrame: 105 // grpc's transport/http2_server sends this 106 // before the settings ack. The Go http2 107 // server uses a setting instead. 108 case *http2.SettingsFrame: 109 if f.IsAck() { 110 return 111 } 112 st.t.Fatalf("during greet, got non-ACK settings frame") 113 default: 114 st.t.Fatalf("during greet, unexpected frame type %T", f) 115 } 116 } 117 } 118 119 func (st *serverTester) writePreface() { 120 n, err := st.cc.Write([]byte(http2.ClientPreface)) 121 if err != nil { 122 st.t.Fatalf("Error writing client preface: %v", err) 123 } 124 if n != len(http2.ClientPreface) { 125 st.t.Fatalf("Writing client preface, wrote %d bytes; want %d", n, len(http2.ClientPreface)) 126 } 127 } 128 129 func (st *serverTester) writeInitialSettings() { 130 if err := st.fr.WriteSettings(); err != nil { 131 st.t.Fatalf("Error writing initial SETTINGS frame from client to server: %v", err) 132 } 133 } 134 135 func (st *serverTester) writeSettingsAck() { 136 if err := st.fr.WriteSettingsAck(); err != nil { 137 st.t.Fatalf("Error writing ACK of server's SETTINGS: %v", err) 138 } 139 } 140 141 func (st *serverTester) wantGoAway(errCode http2.ErrCode) *http2.GoAwayFrame { 142 f, err := st.readFrame() 143 if err != nil { 144 st.t.Fatalf("Error while expecting an RST frame: %v", err) 145 } 146 gaf, ok := f.(*http2.GoAwayFrame) 147 if !ok { 148 st.t.Fatalf("got a %T; want *http2.GoAwayFrame", f) 149 } 150 if gaf.ErrCode != errCode { 151 st.t.Fatalf("expected GOAWAY error code '%v', got '%v'", errCode.String(), gaf.ErrCode.String()) 152 } 153 return gaf 154 } 155 156 func (st *serverTester) wantPing() *http2.PingFrame { 157 f, err := st.readFrame() 158 if err != nil { 159 st.t.Fatalf("Error while expecting an RST frame: %v", err) 160 } 161 pf, ok := f.(*http2.PingFrame) 162 if !ok { 163 st.t.Fatalf("got a %T; want *http2.GoAwayFrame", f) 164 } 165 return pf 166 } 167 168 func (st *serverTester) wantRSTStream(errCode http2.ErrCode) *http2.RSTStreamFrame { 169 f, err := st.readFrame() 170 if err != nil { 171 st.t.Fatalf("Error while expecting an RST frame: %v", err) 172 } 173 rf, ok := f.(*http2.RSTStreamFrame) 174 if !ok { 175 st.t.Fatalf("got a %T; want *http2.RSTStreamFrame", f) 176 } 177 if rf.ErrCode != errCode { 178 st.t.Fatalf("expected RST error code '%v', got '%v'", errCode.String(), rf.ErrCode.String()) 179 } 180 return rf 181 } 182 183 func (st *serverTester) wantSettings() *http2.SettingsFrame { 184 f, err := st.readFrame() 185 if err != nil { 186 st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err) 187 } 188 sf, ok := f.(*http2.SettingsFrame) 189 if !ok { 190 st.t.Fatalf("got a %T; want *SettingsFrame", f) 191 } 192 return sf 193 } 194 195 // wait for any activity from the server 196 func (st *serverTester) wantAnyFrame() http2.Frame { 197 f, err := st.fr.ReadFrame() 198 if err != nil { 199 st.t.Fatal(err) 200 } 201 return f 202 } 203 204 func (st *serverTester) encodeHeaderField(k, v string) { 205 err := st.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v}) 206 if err != nil { 207 st.t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err) 208 } 209 } 210 211 // encodeHeader encodes headers and returns their HPACK bytes. headers 212 // must contain an even number of key/value pairs. There may be 213 // multiple pairs for keys (e.g. "cookie"). The :method, :path, and 214 // :scheme headers default to GET, / and https. 215 func (st *serverTester) encodeHeader(headers ...string) []byte { 216 if len(headers)%2 == 1 { 217 panic("odd number of kv args") 218 } 219 220 st.headerBuf.Reset() 221 222 if len(headers) == 0 { 223 // Fast path, mostly for benchmarks, so test code doesn't pollute 224 // profiles when we're looking to improve server allocations. 225 st.encodeHeaderField(":method", "GET") 226 st.encodeHeaderField(":path", "/") 227 st.encodeHeaderField(":scheme", "https") 228 return st.headerBuf.Bytes() 229 } 230 231 if len(headers) == 2 && headers[0] == ":method" { 232 // Another fast path for benchmarks. 233 st.encodeHeaderField(":method", headers[1]) 234 st.encodeHeaderField(":path", "/") 235 st.encodeHeaderField(":scheme", "https") 236 return st.headerBuf.Bytes() 237 } 238 239 pseudoCount := map[string]int{} 240 keys := []string{":method", ":path", ":scheme"} 241 vals := map[string][]string{ 242 ":method": {"GET"}, 243 ":path": {"/"}, 244 ":scheme": {"https"}, 245 } 246 for len(headers) > 0 { 247 k, v := headers[0], headers[1] 248 headers = headers[2:] 249 if _, ok := vals[k]; !ok { 250 keys = append(keys, k) 251 } 252 if strings.HasPrefix(k, ":") { 253 pseudoCount[k]++ 254 if pseudoCount[k] == 1 { 255 vals[k] = []string{v} 256 } else { 257 // Allows testing of invalid headers w/ dup pseudo fields. 258 vals[k] = append(vals[k], v) 259 } 260 } else { 261 vals[k] = append(vals[k], v) 262 } 263 } 264 for _, k := range keys { 265 for _, v := range vals[k] { 266 st.encodeHeaderField(k, v) 267 } 268 } 269 return st.headerBuf.Bytes() 270 } 271 272 func (st *serverTester) writeHeadersGRPC(streamID uint32, path string, endStream bool) { 273 st.writeHeaders(http2.HeadersFrameParam{ 274 StreamID: streamID, 275 BlockFragment: st.encodeHeader( 276 ":method", "POST", 277 ":path", path, 278 "content-type", "application/grpc", 279 "te", "trailers", 280 ), 281 EndStream: endStream, 282 EndHeaders: true, 283 }) 284 } 285 286 func (st *serverTester) writeHeaders(p http2.HeadersFrameParam) { 287 if err := st.fr.WriteHeaders(p); err != nil { 288 st.t.Fatalf("Error writing HEADERS: %v", err) 289 } 290 } 291 292 func (st *serverTester) writeData(streamID uint32, endStream bool, data []byte) { 293 if err := st.fr.WriteData(streamID, endStream, data); err != nil { 294 st.t.Fatalf("Error writing DATA: %v", err) 295 } 296 } 297 298 func (st *serverTester) writeRSTStream(streamID uint32, code http2.ErrCode) { 299 if err := st.fr.WriteRSTStream(streamID, code); err != nil { 300 st.t.Fatalf("Error writing RST_STREAM: %v", err) 301 } 302 } 303 304 func (st *serverTester) writePing(ack bool, data [8]byte) { 305 if err := st.fr.WritePing(ack, data); err != nil { 306 st.t.Fatalf("Error writing PING: %v", err) 307 } 308 }