github.com/apache/arrow/go/v7@v7.0.1/arrow/flight/flight_middleware_test.go (about) 1 // Licensed to the Apache Software Foundation (ASF) under one 2 // or more contributor license agreements. See the NOTICE file 3 // distributed with this work for additional information 4 // regarding copyright ownership. The ASF licenses this file 5 // to you under the Apache License, Version 2.0 (the 6 // "License"); you may not use this file except in compliance 7 // with the License. You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 17 package flight_test 18 19 import ( 20 "context" 21 "io" 22 sync "sync" 23 "testing" 24 25 "github.com/apache/arrow/go/v7/arrow/flight" 26 "github.com/apache/arrow/go/v7/arrow/internal/arrdata" 27 "github.com/stretchr/testify/assert" 28 "github.com/stretchr/testify/require" 29 "google.golang.org/grpc" 30 "google.golang.org/grpc/metadata" 31 ) 32 33 type ServerMiddlewareAddHeader struct { 34 ctx context.Context 35 } 36 37 func (s *ServerMiddlewareAddHeader) StartCall(ctx context.Context) context.Context { 38 grpc.SetHeader(ctx, metadata.Pairs("foo", "bar")) 39 s.ctx = ctx 40 41 return nil 42 } 43 44 func (s *ServerMiddlewareAddHeader) CallCompleted(ctx context.Context, err error) { 45 if s.ctx != ctx { 46 panic("invalid context") 47 } 48 49 grpc.SetTrailer(ctx, metadata.Pairs("super", "duper")) 50 51 if err != nil { 52 panic("got error") 53 } 54 } 55 56 type ServerTraceMiddleware struct{} 57 58 type tracetestKey struct{} 59 60 func (s ServerTraceMiddleware) StartCall(ctx context.Context) context.Context { 61 return context.WithValue(ctx, tracetestKey{}, "foobar") 62 } 63 64 func (s ServerTraceMiddleware) CallCompleted(ctx context.Context, _ error) { 65 v := ctx.Value(tracetestKey{}).(string) 66 if v != "foobar" { 67 panic("missing value from context in middleware test") 68 } 69 } 70 71 type ServerExpectHeaderMiddleware struct{} 72 73 func (s ServerExpectHeaderMiddleware) StartCall(ctx context.Context) context.Context { 74 md, ok := metadata.FromIncomingContext(ctx) 75 if !ok { 76 panic("missing metadata headers") 77 } 78 79 bar := md.Get("foo") 80 if len(bar) != 1 || bar[0] != "bar" { 81 panic("incorrect header received: " + bar[0]) 82 } 83 84 return nil 85 } 86 87 func (s ServerExpectHeaderMiddleware) CallCompleted(context.Context, error) {} 88 89 func TestServerStreamMiddleware(t *testing.T) { 90 s := flight.NewServerWithMiddleware(nil, []flight.ServerMiddleware{ 91 flight.CreateServerMiddleware(&ServerMiddlewareAddHeader{}), 92 flight.CreateServerMiddleware(ServerTraceMiddleware{}), 93 }) 94 s.Init("localhost:0") 95 f := &flightServer{} 96 s.RegisterFlightService(&flight.FlightServiceService{ 97 ListFlights: f.ListFlights, 98 }) 99 100 go s.Serve() 101 defer s.Shutdown() 102 103 client, err := flight.NewClientWithMiddleware(s.Addr().String(), nil, nil, grpc.WithInsecure()) 104 require.NoError(t, err) 105 defer client.Close() 106 107 flightStream, err := client.ListFlights(context.Background(), &flight.Criteria{}) 108 require.NoError(t, err) 109 110 md, err := flightStream.Header() 111 assert.NoError(t, err) 112 assert.Equal(t, []string{"bar"}, md.Get("foo")) 113 114 for { 115 info, err := flightStream.Recv() 116 if err != nil { 117 if err == io.EOF { 118 break 119 } 120 assert.NoError(t, err) 121 } 122 123 fname := info.GetFlightDescriptor().GetPath()[0] 124 recs, ok := arrdata.Records[fname] 125 assert.True(t, ok) 126 127 sc, err := flight.DeserializeSchema(info.GetSchema(), f.mem) 128 assert.NoError(t, err) 129 130 assert.True(t, recs[0].Schema().Equal(sc)) 131 } 132 133 md = flightStream.Trailer() 134 assert.Equal(t, []string{"duper"}, md.Get("super")) 135 } 136 137 func TestServerUnaryMiddleware(t *testing.T) { 138 s := flight.NewServerWithMiddleware(nil, []flight.ServerMiddleware{ 139 flight.CreateServerMiddleware(&ServerMiddlewareAddHeader{}), 140 flight.CreateServerMiddleware(ServerTraceMiddleware{}), 141 }) 142 s.Init("localhost:0") 143 f := &flightServer{} 144 s.RegisterFlightService(&flight.FlightServiceService{ 145 GetSchema: f.GetSchema, 146 }) 147 148 go s.Serve() 149 defer s.Shutdown() 150 151 client, err := flight.NewClientWithMiddleware(s.Addr().String(), nil, nil, grpc.WithInsecure()) 152 require.NoError(t, err) 153 defer client.Close() 154 155 for name, testrecs := range arrdata.Records { 156 t.Run("flight get schema: "+name, func(t *testing.T) { 157 var ( 158 hdrMD metadata.MD 159 trailerMD metadata.MD 160 ) 161 res, err := client.GetSchema(context.Background(), &flight.FlightDescriptor{Path: []string{name}}, grpc.Header(&hdrMD), grpc.Trailer(&trailerMD)) 162 if err != nil { 163 t.Fatal(err) 164 } 165 166 schema, err := flight.DeserializeSchema(res.GetSchema(), f.getmem()) 167 if err != nil { 168 t.Fatal(err) 169 } 170 171 if !testrecs[0].Schema().Equal(schema) { 172 t.Fatalf("schema not match: \ngot = %#v\nwant = %#v\n", schema, testrecs[0].Schema()) 173 } 174 175 assert.Equal(t, []string{"bar"}, hdrMD.Get("foo")) 176 assert.Equal(t, []string{"duper"}, trailerMD.Get("super")) 177 }) 178 } 179 } 180 181 type ClientTestSendHeaderMiddleware struct { 182 ctx context.Context 183 md metadata.MD 184 mx sync.Mutex 185 } 186 187 func (c *ClientTestSendHeaderMiddleware) StartCall(ctx context.Context) context.Context { 188 c.ctx = context.WithValue(metadata.AppendToOutgoingContext(ctx, "foo", "bar"), tracetestKey{}, "super") 189 return c.ctx 190 } 191 192 func (c *ClientTestSendHeaderMiddleware) CallCompleted(ctx context.Context, err error) { 193 val := ctx.Value(tracetestKey{}).(string) 194 if val != "super" { 195 panic("invalid context client middleware") 196 } 197 } 198 199 func (c *ClientTestSendHeaderMiddleware) HeadersReceived(ctx context.Context, md metadata.MD) { 200 val := ctx.Value(tracetestKey{}).(string) 201 if val != "super" { 202 panic("invalid context client middleware") 203 } 204 205 c.mx.Lock() 206 defer c.mx.Unlock() 207 c.md = md 208 } 209 210 func TestClientStreamMiddleware(t *testing.T) { 211 s := flight.NewServerWithMiddleware(nil, []flight.ServerMiddleware{ 212 flight.CreateServerMiddleware(&ServerExpectHeaderMiddleware{}), 213 flight.CreateServerMiddleware(&ServerMiddlewareAddHeader{}), 214 }) 215 s.Init("localhost:0") 216 f := &flightServer{} 217 s.RegisterFlightService(&flight.FlightServiceService{ 218 ListFlights: f.ListFlights, 219 }) 220 221 go s.Serve() 222 defer s.Shutdown() 223 224 middleware := &ClientTestSendHeaderMiddleware{} 225 client, err := flight.NewClientWithMiddleware(s.Addr().String(), nil, []flight.ClientMiddleware{ 226 flight.CreateClientMiddleware(middleware), 227 }, grpc.WithInsecure()) 228 require.NoError(t, err) 229 defer client.Close() 230 231 flightStream, err := client.ListFlights(context.Background(), &flight.Criteria{}) 232 require.NoError(t, err) 233 234 for { 235 info, err := flightStream.Recv() 236 if err != nil { 237 if err == io.EOF { 238 break 239 } 240 assert.NoError(t, err) 241 } 242 243 fname := info.GetFlightDescriptor().GetPath()[0] 244 recs, ok := arrdata.Records[fname] 245 assert.True(t, ok) 246 247 sc, err := flight.DeserializeSchema(info.GetSchema(), f.mem) 248 assert.NoError(t, err) 249 250 assert.True(t, recs[0].Schema().Equal(sc)) 251 } 252 253 middleware.mx.Lock() 254 defer middleware.mx.Unlock() 255 assert.Equal(t, []string{"bar"}, middleware.md.Get("foo")) 256 assert.Equal(t, []string{"duper"}, middleware.md.Get("super")) 257 } 258 259 func TestClientUnaryMiddleware(t *testing.T) { 260 s := flight.NewServerWithMiddleware(nil, []flight.ServerMiddleware{ 261 flight.CreateServerMiddleware(&ServerMiddlewareAddHeader{}), 262 flight.CreateServerMiddleware(ServerExpectHeaderMiddleware{}), 263 }) 264 s.Init("localhost:0") 265 f := &flightServer{} 266 s.RegisterFlightService(&flight.FlightServiceService{ 267 GetSchema: f.GetSchema, 268 }) 269 270 go s.Serve() 271 defer s.Shutdown() 272 273 middle := &ClientTestSendHeaderMiddleware{} 274 client, err := flight.NewClientWithMiddleware(s.Addr().String(), nil, []flight.ClientMiddleware{ 275 flight.CreateClientMiddleware(middle), 276 }, grpc.WithInsecure()) 277 278 require.NoError(t, err) 279 defer client.Close() 280 281 for name, testrecs := range arrdata.Records { 282 t.Run("flight get schema: "+name, func(t *testing.T) { 283 res, err := client.GetSchema(context.Background(), &flight.FlightDescriptor{Path: []string{name}}) 284 if err != nil { 285 t.Fatal(err) 286 } 287 288 schema, err := flight.DeserializeSchema(res.GetSchema(), f.getmem()) 289 if err != nil { 290 t.Fatal(err) 291 } 292 293 if !testrecs[0].Schema().Equal(schema) { 294 t.Fatalf("schema not match: \ngot = %#v\nwant = %#v\n", schema, testrecs[0].Schema()) 295 } 296 297 assert.Equal(t, []string{"bar"}, middle.md.Get("foo")) 298 assert.Equal(t, []string{"duper"}, middle.md.Get("super")) 299 300 middle.md = metadata.MD{} 301 }) 302 } 303 }