github.com/webdestroya/awsmocker@v0.2.6/mocker.go (about)

     1  package awsmocker
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"os"
     9  	"path"
    10  	"time"
    11  )
    12  
    13  const (
    14  	envAwsCaBundle       = "AWS_CA_BUNDLE"
    15  	envAwsAccessKey      = "AWS_ACCESS_KEY_ID"
    16  	envAwsSecretKey      = "AWS_SECRET_ACCESS_KEY"
    17  	envAwsSessionToken   = "AWS_SESSION_TOKEN"
    18  	envAwsEc2MetaDisable = "AWS_EC2_METADATA_DISABLED"
    19  	envAwsContCredUri    = "AWS_CONTAINER_CREDENTIALS_FULL_URI"
    20  	envAwsContCredRelUri = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"
    21  	envAwsContAuthToken  = "AWS_CONTAINER_AUTHORIZATION_TOKEN"
    22  	envAwsConfigFile     = "AWS_CONFIG_FILE"
    23  	envAwsSharedCredFile = "AWS_SHARED_CREDENTIALS_FILE"
    24  	envAwsWebIdentTFile  = "AWS_WEB_IDENTITY_TOKEN_FILE"
    25  	envAwsDefaultRegion  = "AWS_DEFAULT_REGION"
    26  
    27  	// AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE
    28  	// AWS_EC2_METADATA_SERVICE_ENDPOINT
    29  )
    30  
// mocker holds the state of one active AWS mock bound to a single test: the
// local proxy server, the configured endpoints, behaviour flags, and the
// environment values it overrode so they can be put back afterwards.
type mocker struct {
	// t is the test this mocker is bound to. Start uses it for Cleanup and
	// TempDir; failures and log lines are reported through its Errorf/Logf.
	t          TestingT
	// NOTE(review): timeout is not read anywhere in this file — confirm
	// where (or whether) it is used.
	timeout    time.Duration
	// httpServer is the proxy server created by Start; it is nil until Start
	// reaches that point. Shutdown closes it.
	httpServer *httptest.Server

	// verbose enables Logf output and the per-request dump in ServeHTTP.
	verbose      bool
	// debugTraffic makes handleRequest dump every received request.
	debugTraffic bool

	// usingAwsConfig: when set, Start does not write a CA bundle or set
	// AWS_CA_BUNDLE.
	usingAwsConfig     bool
	// doNotOverrideCreds: when set, Start leaves the credential and
	// config/credential-file environment variables untouched.
	doNotOverrideCreds bool
	// doNotFailUnhandled: when set, a request that matches no mock does not
	// fail the test (the client still receives an error response).
	doNotFailUnhandled bool

	// originalEnv maps each environment variable overridden via setEnv to its
	// previous value, or to nil if it was previously unset; revertEnv
	// restores the environment from it.
	originalEnv map[string]*string

	// mocks are tried in order by handleRequest; the first match answers.
	mocks []*MockedEndpoint
}
    47  
    48  func (m *mocker) init() {
    49  	m.originalEnv = make(map[string]*string, 10)
    50  }
    51  
    52  // Overrides an environment variable and then adds it to the stack to undo later
    53  func (m *mocker) setEnv(k string, v any) {
    54  	val, ok := os.LookupEnv(k)
    55  	if ok {
    56  		m.originalEnv[k] = &val
    57  	} else {
    58  		m.originalEnv[k] = nil
    59  	}
    60  
    61  	switch nval := v.(type) {
    62  	case string:
    63  		err := os.Setenv(k, nval)
    64  		if err != nil {
    65  			m.t.Errorf("Unable to set env var '%s': %s", k, err)
    66  		}
    67  	case nil:
    68  		err := os.Unsetenv(k)
    69  		if err != nil {
    70  			m.t.Errorf("Unable to unset env var '%s': %s", k, err)
    71  		}
    72  	default:
    73  		panic("WRONG ENV VAR VALUE TYPE: must be nil or a string")
    74  	}
    75  }
    76  
    77  func (m *mocker) revertEnv() {
    78  	for k, v := range m.originalEnv {
    79  		if v == nil {
    80  			_ = os.Unsetenv(k)
    81  		} else {
    82  			_ = os.Setenv(k, *v)
    83  		}
    84  	}
    85  }
    86  
    87  func (m *mocker) Start() {
    88  	// reset Go's proxy cache
    89  	resetProxyConfig()
    90  
    91  	m.init()
    92  
    93  	m.t.Cleanup(m.Shutdown)
    94  
    95  	for i := range m.mocks {
    96  		m.mocks[i].prep()
    97  	}
    98  
    99  	// if we are using aws config, then we don't need this
   100  	if !m.usingAwsConfig {
   101  		caBundlePath := path.Join(m.t.TempDir(), "awsmockcabundle.pem")
   102  		err := writeCABundle(caBundlePath)
   103  		if err != nil {
   104  			m.t.Errorf("Failed to write CA Bundle: %s", err)
   105  		}
   106  		m.setEnv(envAwsCaBundle, caBundlePath)
   107  	}
   108  
   109  	ts := httptest.NewServer(m)
   110  	m.httpServer = ts
   111  
   112  	m.setEnv("HTTP_PROXY", ts.URL)
   113  	m.setEnv("http_proxy", ts.URL)
   114  	m.setEnv("HTTPS_PROXY", ts.URL)
   115  	m.setEnv("https_proxy", ts.URL)
   116  
   117  	// m.setEnv(envAwsEc2MetaDisable, "true")
   118  	m.setEnv(envAwsDefaultRegion, DefaultRegion)
   119  
   120  	if !m.doNotOverrideCreds {
   121  		m.setEnv(envAwsAccessKey, "fakekey")
   122  		m.setEnv(envAwsSecretKey, "fakesecret")
   123  		m.setEnv(envAwsSessionToken, "faketoken")
   124  		m.setEnv(envAwsConfigFile, "fakeconffile")
   125  		m.setEnv(envAwsSharedCredFile, "fakesharedfile")
   126  	}
   127  
   128  }
   129  
   130  func (m *mocker) Shutdown() {
   131  	m.httpServer.Close()
   132  
   133  	m.revertEnv()
   134  
   135  	// reset Go's proxy cache
   136  	if !m.usingAwsConfig {
   137  		resetProxyConfig()
   138  	}
   139  }
   140  
   141  func (m *mocker) Logf(format string, args ...any) {
   142  	if !m.verbose {
   143  		return
   144  	}
   145  	m.printf(format, args...)
   146  }
   147  func (m *mocker) Warnf(format string, args ...any) {
   148  	m.printf("WARN: "+format, args...)
   149  }
   150  
   151  func (m *mocker) printf(format string, args ...any) {
   152  	m.t.Logf("[AWSMOCKER] "+format, args...)
   153  }
   154  
   155  func (m *mocker) handleRequest(req *http.Request) (*http.Request, *http.Response) {
   156  	recvReq := newReceivedRequest(req)
   157  
   158  	// if recvReq.invalid {
   159  	// 	recvReq.DebugDump()
   160  	// 	m.t.Errorf("You provided an invalid request")
   161  	// 	return req, generateErrorStruct(http.StatusNotImplemented, "AccessDenied", "You provided a bad or invalid request").getResponse(recvReq).toHttpResponse(req)
   162  	// }
   163  
   164  	if m.debugTraffic {
   165  		recvReq.DebugDump()
   166  	}
   167  
   168  	for _, mockEndpoint := range m.mocks {
   169  		if mockEndpoint.matchRequest(recvReq) {
   170  			// increment it's matcher count
   171  			mockEndpoint.Request.incMatchCount()
   172  
   173  			// build the response
   174  			return req, mockEndpoint.getResponse(recvReq).toHttpResponse(req)
   175  		}
   176  	}
   177  
   178  	if !m.doNotFailUnhandled {
   179  		m.t.Errorf("No matching request mock was found for this request: %s", recvReq.Inspect())
   180  	}
   181  
   182  	return req, generateErrorStruct(http.StatusNotImplemented, "AccessDenied", "No matching request mock was found for this").getResponse(recvReq).toHttpResponse(req)
   183  }
   184  
   185  func (m *mocker) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   186  	hostname := r.URL.Hostname()
   187  
   188  	if m.verbose {
   189  		buf := new(bytes.Buffer)
   190  		fmt.Fprintln(buf, "AWSMocker Proxy Request:")
   191  		fmt.Fprintf(buf, "%s %s [%s]\n", r.Method, r.RequestURI, r.Proto)
   192  		fmt.Fprintf(buf, "Host: %s --- Raw: %s\n", hostname, r.Host)
   193  		for k, vlist := range r.Header {
   194  			for _, v := range vlist {
   195  				fmt.Fprintf(buf, "%s: %s\n", k, v)
   196  			}
   197  		}
   198  		m.Logf(buf.String())
   199  	}
   200  
   201  	if r.Method == "CONNECT" {
   202  		m.handleHttps(w, r)
   203  		return
   204  	}
   205  
   206  	if !r.URL.IsAbs() {
   207  		handleNonProxyRequest.ServeHTTP(w, r)
   208  		return
   209  	}
   210  
   211  	// Must be an HTTP call
   212  	m.handleHttp(w, r)
   213  
   214  }