github.com/apache/arrow/go/v7@v7.0.1/arrow/flight/flight_middleware_test.go (about)

     1  // Licensed to the Apache Software Foundation (ASF) under one
     2  // or more contributor license agreements.  See the NOTICE file
     3  // distributed with this work for additional information
     4  // regarding copyright ownership.  The ASF licenses this file
     5  // to you under the Apache License, Version 2.0 (the
     6  // "License"); you may not use this file except in compliance
     7  // with the License.  You may obtain a copy of the License at
     8  //
     9  // http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  
    17  package flight_test
    18  
    19  import (
    20  	"context"
    21  	"io"
    22  	sync "sync"
    23  	"testing"
    24  
    25  	"github.com/apache/arrow/go/v7/arrow/flight"
    26  	"github.com/apache/arrow/go/v7/arrow/internal/arrdata"
    27  	"github.com/stretchr/testify/assert"
    28  	"github.com/stretchr/testify/require"
    29  	"google.golang.org/grpc"
    30  	"google.golang.org/grpc/metadata"
    31  )
    32  
    33  type ServerMiddlewareAddHeader struct {
    34  	ctx context.Context
    35  }
    36  
    37  func (s *ServerMiddlewareAddHeader) StartCall(ctx context.Context) context.Context {
    38  	grpc.SetHeader(ctx, metadata.Pairs("foo", "bar"))
    39  	s.ctx = ctx
    40  
    41  	return nil
    42  }
    43  
    44  func (s *ServerMiddlewareAddHeader) CallCompleted(ctx context.Context, err error) {
    45  	if s.ctx != ctx {
    46  		panic("invalid context")
    47  	}
    48  
    49  	grpc.SetTrailer(ctx, metadata.Pairs("super", "duper"))
    50  
    51  	if err != nil {
    52  		panic("got error")
    53  	}
    54  }
    55  
    56  type ServerTraceMiddleware struct{}
    57  
    58  type tracetestKey struct{}
    59  
    60  func (s ServerTraceMiddleware) StartCall(ctx context.Context) context.Context {
    61  	return context.WithValue(ctx, tracetestKey{}, "foobar")
    62  }
    63  
    64  func (s ServerTraceMiddleware) CallCompleted(ctx context.Context, _ error) {
    65  	v := ctx.Value(tracetestKey{}).(string)
    66  	if v != "foobar" {
    67  		panic("missing value from context in middleware test")
    68  	}
    69  }
    70  
    71  type ServerExpectHeaderMiddleware struct{}
    72  
    73  func (s ServerExpectHeaderMiddleware) StartCall(ctx context.Context) context.Context {
    74  	md, ok := metadata.FromIncomingContext(ctx)
    75  	if !ok {
    76  		panic("missing metadata headers")
    77  	}
    78  
    79  	bar := md.Get("foo")
    80  	if len(bar) != 1 || bar[0] != "bar" {
    81  		panic("incorrect header received: " + bar[0])
    82  	}
    83  
    84  	return nil
    85  }
    86  
    87  func (s ServerExpectHeaderMiddleware) CallCompleted(context.Context, error) {}
    88  
    89  func TestServerStreamMiddleware(t *testing.T) {
    90  	s := flight.NewServerWithMiddleware(nil, []flight.ServerMiddleware{
    91  		flight.CreateServerMiddleware(&ServerMiddlewareAddHeader{}),
    92  		flight.CreateServerMiddleware(ServerTraceMiddleware{}),
    93  	})
    94  	s.Init("localhost:0")
    95  	f := &flightServer{}
    96  	s.RegisterFlightService(&flight.FlightServiceService{
    97  		ListFlights: f.ListFlights,
    98  	})
    99  
   100  	go s.Serve()
   101  	defer s.Shutdown()
   102  
   103  	client, err := flight.NewClientWithMiddleware(s.Addr().String(), nil, nil, grpc.WithInsecure())
   104  	require.NoError(t, err)
   105  	defer client.Close()
   106  
   107  	flightStream, err := client.ListFlights(context.Background(), &flight.Criteria{})
   108  	require.NoError(t, err)
   109  
   110  	md, err := flightStream.Header()
   111  	assert.NoError(t, err)
   112  	assert.Equal(t, []string{"bar"}, md.Get("foo"))
   113  
   114  	for {
   115  		info, err := flightStream.Recv()
   116  		if err != nil {
   117  			if err == io.EOF {
   118  				break
   119  			}
   120  			assert.NoError(t, err)
   121  		}
   122  
   123  		fname := info.GetFlightDescriptor().GetPath()[0]
   124  		recs, ok := arrdata.Records[fname]
   125  		assert.True(t, ok)
   126  
   127  		sc, err := flight.DeserializeSchema(info.GetSchema(), f.mem)
   128  		assert.NoError(t, err)
   129  
   130  		assert.True(t, recs[0].Schema().Equal(sc))
   131  	}
   132  
   133  	md = flightStream.Trailer()
   134  	assert.Equal(t, []string{"duper"}, md.Get("super"))
   135  }
   136  
   137  func TestServerUnaryMiddleware(t *testing.T) {
   138  	s := flight.NewServerWithMiddleware(nil, []flight.ServerMiddleware{
   139  		flight.CreateServerMiddleware(&ServerMiddlewareAddHeader{}),
   140  		flight.CreateServerMiddleware(ServerTraceMiddleware{}),
   141  	})
   142  	s.Init("localhost:0")
   143  	f := &flightServer{}
   144  	s.RegisterFlightService(&flight.FlightServiceService{
   145  		GetSchema: f.GetSchema,
   146  	})
   147  
   148  	go s.Serve()
   149  	defer s.Shutdown()
   150  
   151  	client, err := flight.NewClientWithMiddleware(s.Addr().String(), nil, nil, grpc.WithInsecure())
   152  	require.NoError(t, err)
   153  	defer client.Close()
   154  
   155  	for name, testrecs := range arrdata.Records {
   156  		t.Run("flight get schema: "+name, func(t *testing.T) {
   157  			var (
   158  				hdrMD     metadata.MD
   159  				trailerMD metadata.MD
   160  			)
   161  			res, err := client.GetSchema(context.Background(), &flight.FlightDescriptor{Path: []string{name}}, grpc.Header(&hdrMD), grpc.Trailer(&trailerMD))
   162  			if err != nil {
   163  				t.Fatal(err)
   164  			}
   165  
   166  			schema, err := flight.DeserializeSchema(res.GetSchema(), f.getmem())
   167  			if err != nil {
   168  				t.Fatal(err)
   169  			}
   170  
   171  			if !testrecs[0].Schema().Equal(schema) {
   172  				t.Fatalf("schema not match: \ngot = %#v\nwant = %#v\n", schema, testrecs[0].Schema())
   173  			}
   174  
   175  			assert.Equal(t, []string{"bar"}, hdrMD.Get("foo"))
   176  			assert.Equal(t, []string{"duper"}, trailerMD.Get("super"))
   177  		})
   178  	}
   179  }
   180  
   181  type ClientTestSendHeaderMiddleware struct {
   182  	ctx context.Context
   183  	md  metadata.MD
   184  	mx  sync.Mutex
   185  }
   186  
   187  func (c *ClientTestSendHeaderMiddleware) StartCall(ctx context.Context) context.Context {
   188  	c.ctx = context.WithValue(metadata.AppendToOutgoingContext(ctx, "foo", "bar"), tracetestKey{}, "super")
   189  	return c.ctx
   190  }
   191  
   192  func (c *ClientTestSendHeaderMiddleware) CallCompleted(ctx context.Context, err error) {
   193  	val := ctx.Value(tracetestKey{}).(string)
   194  	if val != "super" {
   195  		panic("invalid context client middleware")
   196  	}
   197  }
   198  
   199  func (c *ClientTestSendHeaderMiddleware) HeadersReceived(ctx context.Context, md metadata.MD) {
   200  	val := ctx.Value(tracetestKey{}).(string)
   201  	if val != "super" {
   202  		panic("invalid context client middleware")
   203  	}
   204  
   205  	c.mx.Lock()
   206  	defer c.mx.Unlock()
   207  	c.md = md
   208  }
   209  
   210  func TestClientStreamMiddleware(t *testing.T) {
   211  	s := flight.NewServerWithMiddleware(nil, []flight.ServerMiddleware{
   212  		flight.CreateServerMiddleware(&ServerExpectHeaderMiddleware{}),
   213  		flight.CreateServerMiddleware(&ServerMiddlewareAddHeader{}),
   214  	})
   215  	s.Init("localhost:0")
   216  	f := &flightServer{}
   217  	s.RegisterFlightService(&flight.FlightServiceService{
   218  		ListFlights: f.ListFlights,
   219  	})
   220  
   221  	go s.Serve()
   222  	defer s.Shutdown()
   223  
   224  	middleware := &ClientTestSendHeaderMiddleware{}
   225  	client, err := flight.NewClientWithMiddleware(s.Addr().String(), nil, []flight.ClientMiddleware{
   226  		flight.CreateClientMiddleware(middleware),
   227  	}, grpc.WithInsecure())
   228  	require.NoError(t, err)
   229  	defer client.Close()
   230  
   231  	flightStream, err := client.ListFlights(context.Background(), &flight.Criteria{})
   232  	require.NoError(t, err)
   233  
   234  	for {
   235  		info, err := flightStream.Recv()
   236  		if err != nil {
   237  			if err == io.EOF {
   238  				break
   239  			}
   240  			assert.NoError(t, err)
   241  		}
   242  
   243  		fname := info.GetFlightDescriptor().GetPath()[0]
   244  		recs, ok := arrdata.Records[fname]
   245  		assert.True(t, ok)
   246  
   247  		sc, err := flight.DeserializeSchema(info.GetSchema(), f.mem)
   248  		assert.NoError(t, err)
   249  
   250  		assert.True(t, recs[0].Schema().Equal(sc))
   251  	}
   252  
   253  	middleware.mx.Lock()
   254  	defer middleware.mx.Unlock()
   255  	assert.Equal(t, []string{"bar"}, middleware.md.Get("foo"))
   256  	assert.Equal(t, []string{"duper"}, middleware.md.Get("super"))
   257  }
   258  
   259  func TestClientUnaryMiddleware(t *testing.T) {
   260  	s := flight.NewServerWithMiddleware(nil, []flight.ServerMiddleware{
   261  		flight.CreateServerMiddleware(&ServerMiddlewareAddHeader{}),
   262  		flight.CreateServerMiddleware(ServerExpectHeaderMiddleware{}),
   263  	})
   264  	s.Init("localhost:0")
   265  	f := &flightServer{}
   266  	s.RegisterFlightService(&flight.FlightServiceService{
   267  		GetSchema: f.GetSchema,
   268  	})
   269  
   270  	go s.Serve()
   271  	defer s.Shutdown()
   272  
   273  	middle := &ClientTestSendHeaderMiddleware{}
   274  	client, err := flight.NewClientWithMiddleware(s.Addr().String(), nil, []flight.ClientMiddleware{
   275  		flight.CreateClientMiddleware(middle),
   276  	}, grpc.WithInsecure())
   277  
   278  	require.NoError(t, err)
   279  	defer client.Close()
   280  
   281  	for name, testrecs := range arrdata.Records {
   282  		t.Run("flight get schema: "+name, func(t *testing.T) {
   283  			res, err := client.GetSchema(context.Background(), &flight.FlightDescriptor{Path: []string{name}})
   284  			if err != nil {
   285  				t.Fatal(err)
   286  			}
   287  
   288  			schema, err := flight.DeserializeSchema(res.GetSchema(), f.getmem())
   289  			if err != nil {
   290  				t.Fatal(err)
   291  			}
   292  
   293  			if !testrecs[0].Schema().Equal(schema) {
   294  				t.Fatalf("schema not match: \ngot = %#v\nwant = %#v\n", schema, testrecs[0].Schema())
   295  			}
   296  
   297  			assert.Equal(t, []string{"bar"}, middle.md.Get("foo"))
   298  			assert.Equal(t, []string{"duper"}, middle.md.Get("super"))
   299  
   300  			middle.md = metadata.MD{}
   301  		})
   302  	}
   303  }