github.com/cornelk/go-cloud@v0.17.1/internal/testing/setup/setup.go (about)

     1  // Copyright 2019 The Go Cloud Development Kit Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package setup // import "github.com/cornelk/go-cloud/internal/testing/setup"
    16  
    17  import (
    18  	"context"
    19  	"flag"
    20  	"io/ioutil"
    21  	"net/http"
    22  	"os"
    23  	"path/filepath"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/aws/aws-sdk-go/aws"
    28  	awscreds "github.com/aws/aws-sdk-go/aws/credentials"
    29  	"github.com/aws/aws-sdk-go/aws/session"
    30  	"github.com/cornelk/go-cloud/gcp"
    31  	"github.com/cornelk/go-cloud/internal/useragent"
    32  
    33  	"github.com/google/go-replayers/grpcreplay"
    34  	"github.com/google/go-replayers/httpreplay"
    35  	hrgoog "github.com/google/go-replayers/httpreplay/google"
    36  	"golang.org/x/oauth2/google"
    37  	"google.golang.org/api/option"
    38  	"google.golang.org/grpc"
    39  	grpccreds "google.golang.org/grpc/credentials"
    40  	"google.golang.org/grpc/credentials/oauth"
    41  
    42  	"github.com/Azure/azure-pipeline-go/pipeline"
    43  	"github.com/Azure/azure-storage-blob-go/azblob"
    44  )
    45  
    46  // Record is true iff the tests are being run in "record" mode.
    47  var Record = flag.Bool("record", false, "whether to run tests against cloud resources and record the interactions")
    48  
    49  // FakeGCPCredentials gets fake GCP credentials.
    50  func FakeGCPCredentials(ctx context.Context) (*google.Credentials, error) {
    51  	return google.CredentialsFromJSON(ctx, []byte(`{"type": "service_account", "project_id": "my-project-id"}`))
    52  }
    53  
    54  func awsSession(region string, client *http.Client) (*session.Session, error) {
    55  	// Provide fake creds if running in replay mode.
    56  	var creds *awscreds.Credentials
    57  	if !*Record {
    58  		creds = awscreds.NewStaticCredentials("FAKE_ID", "FAKE_SECRET", "FAKE_TOKEN")
    59  	}
    60  	return session.NewSession(&aws.Config{
    61  		HTTPClient:  client,
    62  		Region:      aws.String(region),
    63  		Credentials: creds,
    64  		MaxRetries:  aws.Int(0),
    65  	})
    66  }
    67  
    68  // NewRecordReplayClient creates a new http.Client for tests. This client's
    69  // activity is being either recorded to files (when *Record is set) or replayed
    70  // from files. rf is a modifier function that will be invoked with the address
    71  // of the httpreplay.Recorder object used to obtain the client; this function
    72  // can mutate the recorder to add service-specific header filters, for example.
    73  // An initState is returned for tests that need a state to have deterministic
    74  // results, for example, a seed to generate random sequences.
    75  func NewRecordReplayClient(ctx context.Context, t *testing.T, rf func(r *httpreplay.Recorder)) (c *http.Client, cleanup func(), initState int64) {
    76  	httpreplay.DebugHeaders()
    77  	path := filepath.Join("testdata", t.Name()+".replay")
    78  	if *Record {
    79  		t.Logf("Recording into golden file %s", path)
    80  		if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
    81  			t.Fatal(err)
    82  		}
    83  		state := time.Now()
    84  		b, _ := state.MarshalBinary()
    85  		rec, err := httpreplay.NewRecorder(path, b)
    86  		if err != nil {
    87  			t.Fatal(err)
    88  		}
    89  		rf(rec)
    90  		cleanup = func() {
    91  			if err := rec.Close(); err != nil {
    92  				t.Fatal(err)
    93  			}
    94  		}
    95  
    96  		return rec.Client(), cleanup, state.UnixNano()
    97  	}
    98  	t.Logf("Replaying from golden file %s", path)
    99  	rep, err := httpreplay.NewReplayer(path)
   100  	if err != nil {
   101  		t.Fatal(err)
   102  	}
   103  	recState := new(time.Time)
   104  	if err := recState.UnmarshalBinary(rep.Initial()); err != nil {
   105  		t.Fatal(err)
   106  	}
   107  	return rep.Client(), func() { rep.Close() }, recState.UnixNano()
   108  }
   109  
   110  // NewAWSSession creates a new session for testing against AWS.
   111  // If the test is in --record mode, the test will call out to AWS, and the
   112  // results are recorded in a replay file.
   113  // Otherwise, the session reads a replay file and runs the test as a replay,
   114  // which never makes an outgoing HTTP call and uses fake credentials.
   115  // An initState is returned for tests that need a state to have deterministic
   116  // results, for example, a seed to generate random sequences.
   117  func NewAWSSession(ctx context.Context, t *testing.T, region string) (sess *session.Session,
   118  	rt http.RoundTripper, cleanup func(), initState int64) {
   119  	client, cleanup, state := NewRecordReplayClient(ctx, t, func(r *httpreplay.Recorder) {
   120  		r.RemoveQueryParams("X-Amz-Credential", "X-Amz-Signature", "X-Amz-Security-Token")
   121  		r.RemoveRequestHeaders("Authorization", "Duration", "X-Amz-Security-Token")
   122  		r.ClearHeaders("X-Amz-Date")
   123  		r.ClearQueryParams("X-Amz-Date")
   124  		r.ClearHeaders("User-Agent") // AWS includes the Go version
   125  	})
   126  	sess, err := awsSession(region, client)
   127  	if err != nil {
   128  		t.Fatal(err)
   129  	}
   130  	return sess, client.Transport, cleanup, state
   131  }
   132  
   133  // NewGCPClient creates a new HTTPClient for testing against GCP.
   134  // If the test is in --record mode, the client will call out to GCP, and the
   135  // results are recorded in a replay file.
   136  // Otherwise, the session reads a replay file and runs the test as a replay,
   137  // which never makes an outgoing HTTP call and uses fake credentials.
   138  func NewGCPClient(ctx context.Context, t *testing.T) (client *gcp.HTTPClient, rt http.RoundTripper, done func()) {
   139  	c, cleanup, _ := NewRecordReplayClient(ctx, t, func(r *httpreplay.Recorder) {
   140  		r.ClearQueryParams("Expires")
   141  		r.ClearQueryParams("Signature")
   142  		r.ClearHeaders("Expires")
   143  		r.ClearHeaders("Signature")
   144  	})
   145  	transport := c.Transport
   146  	if *Record {
   147  		creds, err := gcp.DefaultCredentials(ctx)
   148  		if err != nil {
   149  			t.Fatalf("failed to get default credentials: %v", err)
   150  		}
   151  		c, err = hrgoog.RecordClient(ctx, c, option.WithTokenSource(gcp.CredentialsTokenSource(creds)))
   152  		if err != nil {
   153  			t.Fatal(err)
   154  		}
   155  	}
   156  	return &gcp.HTTPClient{Client: *c}, transport, cleanup
   157  }
   158  
   159  // NewGCPgRPCConn creates a new connection for testing against GCP via gRPC.
   160  // If the test is in --record mode, the client will call out to GCP, and the
   161  // results are recorded in a replay file.
   162  // Otherwise, the session reads a replay file and runs the test as a replay,
   163  // which never makes an outgoing RPC and uses fake credentials.
   164  func NewGCPgRPCConn(ctx context.Context, t *testing.T, endPoint, api string) (*grpc.ClientConn, func()) {
   165  	filename := t.Name() + ".replay"
   166  	if *Record {
   167  		opts, done := newGCPRecordDialOptions(t, filename)
   168  		opts = append(opts, useragent.GRPCDialOption(api))
   169  		// Add credentials for real RPCs.
   170  		creds, err := gcp.DefaultCredentials(ctx)
   171  		if err != nil {
   172  			t.Fatal(err)
   173  		}
   174  		opts = append(opts, grpc.WithTransportCredentials(grpccreds.NewClientTLSFromCert(nil, "")))
   175  		opts = append(opts, grpc.WithPerRPCCredentials(oauth.TokenSource{TokenSource: gcp.CredentialsTokenSource(creds)}))
   176  		conn, err := grpc.DialContext(ctx, endPoint, opts...)
   177  		if err != nil {
   178  			t.Fatal(err)
   179  		}
   180  		return conn, done
   181  	}
   182  	rep, done := newGCPReplayer(t, filename)
   183  	conn, err := rep.Connection()
   184  	if err != nil {
   185  		t.Fatal(err)
   186  	}
   187  	return conn, done
   188  }
   189  
   190  // contentTypeInjectPolicy and contentTypeInjector are somewhat of a hack to
   191  // overcome an impedance mismatch between the Azure pipeline library and
   192  // httpreplay - the tool we use to record/replay HTTP traffic for tests.
   193  // azure-pipeline-go does not set the Content-Type header in its requests,
   194  // setting X-Ms-Blob-Content-Type instead; however, httpreplay expects
   195  // Content-Type to be non-empty in some cases. This injector makes sure that
   196  // the content type is copied into the right header when that is originally
   197  // empty. It's only used for testing.
   198  type contentTypeInjectPolicy struct {
   199  	node pipeline.Policy
   200  }
   201  
   202  func (p *contentTypeInjectPolicy) Do(ctx context.Context, request pipeline.Request) (pipeline.Response, error) {
   203  	if len(request.Header.Get("Content-Type")) == 0 {
   204  		cType := request.Header.Get("X-Ms-Blob-Content-Type")
   205  		request.Header.Set("Content-Type", cType)
   206  	}
   207  	response, err := p.node.Do(ctx, request)
   208  	return response, err
   209  }
   210  
   211  type contentTypeInjector struct {
   212  }
   213  
   214  func (f contentTypeInjector) New(node pipeline.Policy, opts *pipeline.PolicyOptions) pipeline.Policy {
   215  	return &contentTypeInjectPolicy{node: node}
   216  }
   217  
   218  // NewAzureTestPipeline creates a new connection for testing against Azure Blob.
   219  func NewAzureTestPipeline(ctx context.Context, t *testing.T, api string, credential azblob.Credential, accountName string) (pipeline.Pipeline, func(), *http.Client) {
   220  	client, done, _ := NewRecordReplayClient(ctx, t, func(r *httpreplay.Recorder) {
   221  		r.RemoveQueryParams("se", "sig")
   222  		r.RemoveQueryParams("X-Ms-Date")
   223  		r.ClearHeaders("X-Ms-Date")
   224  		r.ClearHeaders("User-Agent") // includes the full Go version
   225  	})
   226  	f := []pipeline.Factory{
   227  		// Sets User-Agent for recorder.
   228  		azblob.NewTelemetryPolicyFactory(azblob.TelemetryOptions{
   229  			Value: useragent.AzureUserAgentPrefix(api),
   230  		}),
   231  		contentTypeInjector{},
   232  		credential,
   233  		pipeline.MethodFactoryMarker(),
   234  	}
   235  	// Create a pipeline that uses client to make requests.
   236  	p := pipeline.NewPipeline(f, pipeline.Options{
   237  		HTTPSender: pipeline.FactoryFunc(func(next pipeline.Policy, po *pipeline.PolicyOptions) pipeline.PolicyFunc {
   238  			return func(ctx context.Context, request pipeline.Request) (pipeline.Response, error) {
   239  				r, err := client.Do(request.WithContext(ctx))
   240  				if err != nil {
   241  					err = pipeline.NewError(err, "HTTP request failed")
   242  				}
   243  				return pipeline.NewHTTPResponse(r), err
   244  			}
   245  		}),
   246  	})
   247  
   248  	return p, done, client
   249  }
   250  
   251  // NewAzureKeyVaultTestClient creates a *http.Client for Azure KeyVault test
   252  // recordings.
   253  func NewAzureKeyVaultTestClient(ctx context.Context, t *testing.T) (*http.Client, func()) {
   254  	client, cleanup, _ := NewRecordReplayClient(ctx, t, func(r *httpreplay.Recorder) {
   255  		r.RemoveQueryParams("se", "sig")
   256  		r.RemoveQueryParams("X-Ms-Date")
   257  		r.ClearHeaders("X-Ms-Date")
   258  		r.ClearHeaders("User-Agent") // includes the full Go version
   259  	})
   260  	return client, cleanup
   261  }
   262  
   263  // FakeGCPDefaultCredentials sets up the environment with fake GCP credentials.
   264  // It returns a cleanup function.
   265  func FakeGCPDefaultCredentials(t *testing.T) func() {
   266  	const envVar = "GOOGLE_APPLICATION_CREDENTIALS"
   267  	jsonCred := []byte(`{"client_id": "foo.apps.googleusercontent.com", "client_secret": "bar", "refresh_token": "baz", "type": "authorized_user"}`)
   268  	f, err := ioutil.TempFile("", "fake-gcp-creds")
   269  	if err != nil {
   270  		t.Fatal(err)
   271  	}
   272  	if err := ioutil.WriteFile(f.Name(), jsonCred, 0666); err != nil {
   273  		t.Fatal(err)
   274  	}
   275  	oldEnvVal := os.Getenv(envVar)
   276  	os.Setenv(envVar, f.Name())
   277  	return func() {
   278  		os.Remove(f.Name())
   279  		os.Setenv(envVar, oldEnvVal)
   280  	}
   281  }
   282  
   283  // newGCPRecordDialOptions return grpc.DialOptions that are to be appended to a
   284  // GRPC dial request. These options allow a recorder to intercept RPCs and save
   285  // RPCs to the file at filename, or read the RPCs from the file and return them.
   286  func newGCPRecordDialOptions(t *testing.T, filename string) (opts []grpc.DialOption, done func()) {
   287  	path := filepath.Join("testdata", filename)
   288  	t.Logf("Recording into golden file %s", path)
   289  	r, err := grpcreplay.NewRecorder(path, nil)
   290  	if err != nil {
   291  		t.Fatal(err)
   292  	}
   293  	opts = r.DialOptions()
   294  	done = func() {
   295  		if err := r.Close(); err != nil {
   296  			t.Errorf("unable to close recorder: %v", err)
   297  		}
   298  	}
   299  	return opts, done
   300  }
   301  
   302  // newGCPReplayer returns a Replayer for GCP gRPC connections, as well as a function
   303  // to call when done with the Replayer.
   304  func newGCPReplayer(t *testing.T, filename string) (*grpcreplay.Replayer, func()) {
   305  	path := filepath.Join("testdata", filename)
   306  	t.Logf("Replaying from golden file %s", path)
   307  	r, err := grpcreplay.NewReplayer(path, nil)
   308  	if err != nil {
   309  		t.Fatal(err)
   310  	}
   311  	done := func() {
   312  		if err := r.Close(); err != nil {
   313  			t.Errorf("unable to close recorder: %v", err)
   314  		}
   315  	}
   316  	return r, done
   317  }
   318  
   319  // HasDockerTestEnvironment returns true when either:
   320  // 1) Not on Travis.
   321  // 2) On Travis Linux environment, where Docker is available.
   322  func HasDockerTestEnvironment() bool {
   323  	s := os.Getenv("TRAVIS_OS_NAME")
   324  	return s == "" || s == "linux"
   325  }