bitbucket.org/Aishee/synsec@v0.0.0-20210414005726-236fc01a153d/pkg/apiclient/auth.go (about)

     1  package apiclient
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"time"
     7  
     8  	//"errors"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  	"net/http/httputil"
    13  	"net/url"
    14  
    15  	"bitbucket.org/Aishee/synsec/pkg/models"
    16  	"github.com/go-openapi/strfmt"
    17  	"github.com/pkg/errors"
    18  	log "github.com/sirupsen/logrus"
    19  	//"google.golang.org/appengine/log"
    20  )
    21  
    22  type APIKeyTransport struct {
    23  	APIKey string
    24  	// Transport is the underlying HTTP transport to use when making requests.
    25  	// It will default to http.DefaultTransport if nil.
    26  	Transport     http.RoundTripper
    27  	URL           *url.URL
    28  	VersionPrefix string
    29  	UserAgent     string
    30  }
    31  
    32  // RoundTrip implements the RoundTripper interface.
    33  func (t *APIKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
    34  	if t.APIKey == "" {
    35  		return nil, errors.New("APIKey is empty")
    36  	}
    37  
    38  	// We must make a copy of the Request so
    39  	// that we don't modify the Request we were given. This is required by the
    40  	// specification of http.RoundTripper.
    41  	req = cloneRequest(req)
    42  	req.Header.Add("X-Api-Key", t.APIKey)
    43  	if t.UserAgent != "" {
    44  		req.Header.Add("User-Agent", t.UserAgent)
    45  	}
    46  	log.Debugf("req-api: %s %s", req.Method, req.URL.String())
    47  	if log.GetLevel() >= log.TraceLevel {
    48  		dump, _ := httputil.DumpRequest(req, true)
    49  		log.Tracef("auth-api request: %s", string(dump))
    50  	}
    51  	// Make the HTTP request.
    52  	resp, err := t.transport().RoundTrip(req)
    53  	if err != nil {
    54  		log.Errorf("auth-api: auth with api key failed return nil response, error: %s", err)
    55  		return resp, err
    56  	}
    57  	if log.GetLevel() >= log.TraceLevel {
    58  		dump, _ := httputil.DumpResponse(resp, true)
    59  		log.Tracef("auth-api response: %s", string(dump))
    60  	}
    61  
    62  	log.Debugf("resp-api: http %d", resp.StatusCode)
    63  
    64  	return resp, err
    65  }
    66  
    67  func (t *APIKeyTransport) Client() *http.Client {
    68  	return &http.Client{Transport: t}
    69  }
    70  
    71  func (t *APIKeyTransport) transport() http.RoundTripper {
    72  	if t.Transport != nil {
    73  		return t.Transport
    74  	}
    75  	return http.DefaultTransport
    76  }
    77  
    78  type JWTTransport struct {
    79  	MachineID     *string
    80  	Password      *strfmt.Password
    81  	token         string
    82  	Expiration    time.Time
    83  	Scenarios     []string
    84  	URL           *url.URL
    85  	VersionPrefix string
    86  	UserAgent     string
    87  	// Transport is the underlying HTTP transport to use when making requests.
    88  	// It will default to http.DefaultTransport if nil.
    89  	Transport      http.RoundTripper
    90  	UpdateScenario func() ([]string, error)
    91  }
    92  
    93  func (t *JWTTransport) refreshJwtToken() error {
    94  	var err error
    95  	if t.UpdateScenario != nil {
    96  		t.Scenarios, err = t.UpdateScenario()
    97  		if err != nil {
    98  			return fmt.Errorf("can't update scenario list: %s", err)
    99  		}
   100  		log.Debugf("scenarios list updated for '%s'", *t.MachineID)
   101  	}
   102  
   103  	var auth = models.WatcherAuthRequest{
   104  		MachineID: t.MachineID,
   105  		Password:  t.Password,
   106  		Scenarios: t.Scenarios,
   107  	}
   108  
   109  	var response models.WatcherAuthResponse
   110  
   111  	/*
   112  		we don't use the main client, so let's build the body
   113  	*/
   114  	var buf io.ReadWriter
   115  	buf = &bytes.Buffer{}
   116  	enc := json.NewEncoder(buf)
   117  	enc.SetEscapeHTML(false)
   118  	err = enc.Encode(auth)
   119  	if err != nil {
   120  		return errors.Wrap(err, "could not encode jwt auth body")
   121  	}
   122  	req, err := http.NewRequest("POST", fmt.Sprintf("%s%s/watchers/login", t.URL, t.VersionPrefix), buf)
   123  	if err != nil {
   124  		return errors.Wrap(err, "could not create request")
   125  	}
   126  	req.Header.Add("Content-Type", "application/json")
   127  	client := &http.Client{}
   128  	if t.UserAgent != "" {
   129  		req.Header.Add("User-Agent", t.UserAgent)
   130  	}
   131  	if log.GetLevel() >= log.TraceLevel {
   132  		dump, _ := httputil.DumpRequest(req, true)
   133  		log.Tracef("auth-jwt request: %s", string(dump))
   134  	}
   135  
   136  	log.Debugf("auth-jwt(auth): %s %s", req.Method, req.URL.String())
   137  
   138  	resp, err := client.Do(req)
   139  	if err != nil {
   140  		return errors.Wrap(err, "could not get jwt token")
   141  	}
   142  	log.Debugf("auth-jwt : http %d", resp.StatusCode)
   143  
   144  	if log.GetLevel() >= log.TraceLevel {
   145  		dump, _ := httputil.DumpResponse(resp, true)
   146  		log.Tracef("auth-jwt response: %s", string(dump))
   147  	}
   148  
   149  	defer resp.Body.Close()
   150  
   151  	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
   152  		return fmt.Errorf("received response status %q when fetching %v", resp.Status, req.URL)
   153  	}
   154  
   155  	if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
   156  		return errors.Wrap(err, "unable to decode response")
   157  	}
   158  	if err := t.Expiration.UnmarshalText([]byte(response.Expire)); err != nil {
   159  		return errors.Wrap(err, "unable to parse jwt expiration")
   160  	}
   161  	t.token = response.Token
   162  
   163  	log.Debugf("token %s will expire on %s", t.token, t.Expiration.String())
   164  	return nil
   165  }
   166  
   167  // RoundTrip implements the RoundTripper interface.
   168  func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
   169  	if t.token == "" || t.Expiration.Add(-time.Minute).Before(time.Now()) {
   170  		if err := t.refreshJwtToken(); err != nil {
   171  			return nil, err
   172  		}
   173  	}
   174  
   175  	// We must make a copy of the Request so
   176  	// that we don't modify the Request we were given. This is required by the
   177  	// specification of http.RoundTripper.
   178  	req = cloneRequest(req)
   179  	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.token))
   180  	log.Debugf("req-jwt: %s %s", req.Method, req.URL.String())
   181  	if log.GetLevel() >= log.TraceLevel {
   182  		dump, _ := httputil.DumpRequest(req, true)
   183  		log.Tracef("req-jwt: %s", string(dump))
   184  	}
   185  	if t.UserAgent != "" {
   186  		req.Header.Add("User-Agent", t.UserAgent)
   187  	}
   188  	// Make the HTTP request.
   189  	resp, err := t.transport().RoundTrip(req)
   190  	if log.GetLevel() >= log.TraceLevel {
   191  		dump, _ := httputil.DumpResponse(resp, true)
   192  		log.Tracef("resp-jwt: %s (err:%v)", string(dump), err)
   193  	}
   194  	if err != nil || resp.StatusCode == 401 {
   195  		/*we had an error (network error for example, or 401 because token is refused), reset the token ?*/
   196  		t.token = ""
   197  		return resp, errors.Wrapf(err, "performing jwt auth")
   198  	}
   199  	log.Debugf("resp-jwt: %d", resp.StatusCode)
   200  	return resp, nil
   201  }
   202  
   203  func (t *JWTTransport) Client() *http.Client {
   204  	return &http.Client{Transport: t}
   205  }
   206  
   207  func (t *JWTTransport) transport() http.RoundTripper {
   208  	if t.Transport != nil {
   209  		return t.Transport
   210  	}
   211  	return http.DefaultTransport
   212  }
   213  
   214  // cloneRequest returns a clone of the provided *http.Request. The clone is a
   215  // shallow copy of the struct and its Header map.
   216  func cloneRequest(r *http.Request) *http.Request {
   217  	// shallow copy of the struct
   218  	r2 := new(http.Request)
   219  	*r2 = *r
   220  	// deep copy of the Header
   221  	r2.Header = make(http.Header, len(r.Header))
   222  	for k, s := range r.Header {
   223  		r2.Header[k] = append([]string(nil), s...)
   224  	}
   225  	return r2
   226  }