github.com/lingyao2333/mo-zero@v1.4.1/rest/handler/contentsecurityhandler_test.go (about)

     1  package handler
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/sha256"
     6  	"encoding/base64"
     7  	"fmt"
     8  	"io"
     9  	"log"
    10  	"net/http"
    11  	"net/http/httptest"
    12  	"net/url"
    13  	"os"
    14  	"strconv"
    15  	"strings"
    16  	"testing"
    17  	"time"
    18  
    19  	"github.com/lingyao2333/mo-zero/core/codec"
    20  	"github.com/lingyao2333/mo-zero/rest/httpx"
    21  	"github.com/stretchr/testify/assert"
    22  )
    23  
// timeDiff is two days: far outside the one-hour tolerance the tests pass to
// ContentSecurityHandler, so timestamps shifted by it are expected to be
// rejected (the table cases using it expect http.StatusForbidden).
const timeDiff = time.Hour * 2 * 24

var (
	// fingerprint is the id under which the test decrypter is registered and
	// which buildRequest sends as the "key=" attribute of the security header.
	fingerprint = "12345"
	// pubKey is the PEM public key buildRequest uses to RSA-encrypt the secret.
	// NOTE(review): test-only key material; never reuse outside tests.
	pubKey = []byte(`-----BEGIN PUBLIC KEY-----
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQD7bq4FLG0ctccbEFEsUBuRxkjE
eJ5U+0CAEjJk20V9/u2Fu76i1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVH
miYbRgh5Fy6336KepLCtCmV/r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwR
my47YlhspwszKdRP+wIDAQAB
-----END PUBLIC KEY-----`)
	// priKey is the private counterpart of pubKey; the tests write it to a
	// temp file and load it with codec.NewRsaDecrypter.
	priKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQD7bq4FLG0ctccbEFEsUBuRxkjEeJ5U+0CAEjJk20V9/u2Fu76i
1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVHmiYbRgh5Fy6336KepLCtCmV/
r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwRmy47YlhspwszKdRP+wIDAQAB
AoGBANs1qf7UtuSbD1ZnKX5K8V5s07CHwPMygw+lzc3k5ndtNUStZQ2vnAaBXHyH
Nm4lJ4AI2mhQ39jQB/1TyP1uAzvpLhT60fRybEq9zgJ/81Gm9bnaEpFJ9bP2bBrY
J0jbaTMfbzL/PJFl3J3RGMR40C76h5yRYSnOpMoMiKWnJqrhAkEA/zCOkR+34Pk0
Yo3sIP4ranY6AAvwacgNaui4ll5xeYwv3iLOQvPlpxIxFHKXEY0klNNyjjXqgYjP
cOenqtt6UwJBAPw7EYuteVHvHvQVuTbKAaYHcOrp4nFeZF3ndFfl0w2dwGhfzcXO
ROyd5dNQCuCWRo8JBpjG6PFyzezayF4KLrkCQCGditoxHG7FRRJKcbVy5dMzWbaR
3AyDLslLeK1OKZKCVffkC9mj+TeF3PM9mQrV1eDI7ckv7wE7PWA5E8wc90MCQEOV
MCZU3OTvRUPxbicYCUkLRV4sPNhTimD+21WR5vMHCb7trJ0Ln7wmsqXkFIYIve8l
Y/cblN7c/AAyvu0znUECQA318nPldsxR6+H8HTS3uEbkL4UJdjQJHsvTwKxAw5qc
moKExvRlN0zmGGuArKcqS38KG7PXZMrUv3FXPdp6BDQ=
-----END RSA PRIVATE KEY-----`)
	// key is the 32-byte symmetric key. buildRequest ships it base64-encoded
	// inside the RSA-encrypted secret, encrypts bodies with it
	// (codec.EcbEncrypt) and HMAC-signs requests with it.
	key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
)
    51  
    52  type requestSettings struct {
    53  	method      string
    54  	url         string
    55  	body        io.Reader
    56  	strict      bool
    57  	crypt       bool
    58  	requestUri  string
    59  	timestamp   int64
    60  	fingerprint string
    61  	missHeader  bool
    62  	signature   string
    63  }
    64  
    65  func init() {
    66  	log.SetOutput(io.Discard)
    67  }
    68  
    69  func TestContentSecurityHandler(t *testing.T) {
    70  	tests := []struct {
    71  		method      string
    72  		url         string
    73  		body        string
    74  		strict      bool
    75  		crypt       bool
    76  		requestUri  string
    77  		timestamp   int64
    78  		fingerprint string
    79  		missHeader  bool
    80  		signature   string
    81  		statusCode  int
    82  	}{
    83  		{
    84  			method: http.MethodGet,
    85  			url:    "http://localhost/a/b?c=d&e=f",
    86  			strict: true,
    87  			crypt:  false,
    88  		},
    89  		{
    90  			method: http.MethodPost,
    91  			url:    "http://localhost/a/b?c=d&e=f",
    92  			body:   "hello",
    93  			strict: true,
    94  			crypt:  false,
    95  		},
    96  		{
    97  			method: http.MethodGet,
    98  			url:    "http://localhost/a/b?c=d&e=f",
    99  			strict: true,
   100  			crypt:  true,
   101  		},
   102  		{
   103  			method: http.MethodPost,
   104  			url:    "http://localhost/a/b?c=d&e=f",
   105  			body:   "hello",
   106  			strict: true,
   107  			crypt:  true,
   108  		},
   109  		{
   110  			method:     http.MethodGet,
   111  			url:        "http://localhost/a/b?c=d&e=f",
   112  			strict:     true,
   113  			crypt:      true,
   114  			timestamp:  time.Now().Add(timeDiff).Unix(),
   115  			statusCode: http.StatusForbidden,
   116  		},
   117  		{
   118  			method:     http.MethodPost,
   119  			url:        "http://localhost/a/b?c=d&e=f",
   120  			body:       "hello",
   121  			strict:     true,
   122  			crypt:      true,
   123  			timestamp:  time.Now().Add(-timeDiff).Unix(),
   124  			statusCode: http.StatusForbidden,
   125  		},
   126  		{
   127  			method:     http.MethodPost,
   128  			url:        "http://remotehost/",
   129  			body:       "hello",
   130  			strict:     true,
   131  			crypt:      true,
   132  			requestUri: "http://localhost/a/b?c=d&e=f",
   133  		},
   134  		{
   135  			method:      http.MethodPost,
   136  			url:         "http://localhost/a/b?c=d&e=f",
   137  			body:        "hello",
   138  			strict:      false,
   139  			crypt:       true,
   140  			fingerprint: "badone",
   141  		},
   142  		{
   143  			method:      http.MethodPost,
   144  			url:         "http://localhost/a/b?c=d&e=f",
   145  			body:        "hello",
   146  			strict:      true,
   147  			crypt:       true,
   148  			timestamp:   time.Now().Add(-timeDiff).Unix(),
   149  			fingerprint: "badone",
   150  			statusCode:  http.StatusForbidden,
   151  		},
   152  		{
   153  			method:     http.MethodPost,
   154  			url:        "http://localhost/a/b?c=d&e=f",
   155  			body:       "hello",
   156  			strict:     true,
   157  			crypt:      true,
   158  			missHeader: true,
   159  			statusCode: http.StatusForbidden,
   160  		},
   161  		{
   162  			method: http.MethodHead,
   163  			url:    "http://localhost/a/b?c=d&e=f",
   164  			strict: true,
   165  			crypt:  false,
   166  		},
   167  		{
   168  			method:     http.MethodGet,
   169  			url:        "http://localhost/a/b?c=d&e=f",
   170  			strict:     true,
   171  			crypt:      false,
   172  			signature:  "badone",
   173  			statusCode: http.StatusForbidden,
   174  		},
   175  	}
   176  
   177  	for _, test := range tests {
   178  		t.Run(test.url, func(t *testing.T) {
   179  			if test.statusCode == 0 {
   180  				test.statusCode = http.StatusOK
   181  			}
   182  			if len(test.fingerprint) == 0 {
   183  				test.fingerprint = fingerprint
   184  			}
   185  			if test.timestamp == 0 {
   186  				test.timestamp = time.Now().Unix()
   187  			}
   188  
   189  			func() {
   190  				keyFile, err := createTempFile(priKey)
   191  				defer os.Remove(keyFile)
   192  
   193  				assert.Nil(t, err)
   194  				decrypter, err := codec.NewRsaDecrypter(keyFile)
   195  				assert.Nil(t, err)
   196  				contentSecurityHandler := ContentSecurityHandler(map[string]codec.RsaDecrypter{
   197  					fingerprint: decrypter,
   198  				}, time.Hour, test.strict)
   199  				handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   200  				}))
   201  
   202  				var reader io.Reader
   203  				if len(test.body) > 0 {
   204  					reader = strings.NewReader(test.body)
   205  				}
   206  				setting := requestSettings{
   207  					method:      test.method,
   208  					url:         test.url,
   209  					body:        reader,
   210  					strict:      test.strict,
   211  					crypt:       test.crypt,
   212  					requestUri:  test.requestUri,
   213  					timestamp:   test.timestamp,
   214  					fingerprint: test.fingerprint,
   215  					missHeader:  test.missHeader,
   216  					signature:   test.signature,
   217  				}
   218  				req, err := buildRequest(setting)
   219  				assert.Nil(t, err)
   220  				resp := httptest.NewRecorder()
   221  				handler.ServeHTTP(resp, req)
   222  				assert.Equal(t, test.statusCode, resp.Code)
   223  			}()
   224  		})
   225  	}
   226  }
   227  
   228  func TestContentSecurityHandler_UnsignedCallback(t *testing.T) {
   229  	keyFile, err := createTempFile(priKey)
   230  	defer os.Remove(keyFile)
   231  
   232  	assert.Nil(t, err)
   233  	decrypter, err := codec.NewRsaDecrypter(keyFile)
   234  	assert.Nil(t, err)
   235  	contentSecurityHandler := ContentSecurityHandler(
   236  		map[string]codec.RsaDecrypter{
   237  			fingerprint: decrypter,
   238  		},
   239  		time.Hour,
   240  		true,
   241  		func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
   242  			w.WriteHeader(http.StatusOK)
   243  		})
   244  	handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
   245  
   246  	setting := requestSettings{
   247  		method:    http.MethodGet,
   248  		url:       "http://localhost/a/b?c=d&e=f",
   249  		signature: "badone",
   250  	}
   251  	req, err := buildRequest(setting)
   252  	assert.Nil(t, err)
   253  	resp := httptest.NewRecorder()
   254  	handler.ServeHTTP(resp, req)
   255  	assert.Equal(t, http.StatusOK, resp.Code)
   256  }
   257  
   258  func TestContentSecurityHandler_UnsignedCallback_WrongTime(t *testing.T) {
   259  	keyFile, err := createTempFile(priKey)
   260  	defer os.Remove(keyFile)
   261  
   262  	assert.Nil(t, err)
   263  	decrypter, err := codec.NewRsaDecrypter(keyFile)
   264  	assert.Nil(t, err)
   265  	contentSecurityHandler := ContentSecurityHandler(
   266  		map[string]codec.RsaDecrypter{
   267  			fingerprint: decrypter,
   268  		},
   269  		time.Hour,
   270  		true,
   271  		func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
   272  			assert.Equal(t, httpx.CodeSignatureWrongTime, code)
   273  			w.WriteHeader(http.StatusOK)
   274  		})
   275  	handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
   276  
   277  	reader := strings.NewReader("hello")
   278  	setting := requestSettings{
   279  		method:      http.MethodPost,
   280  		url:         "http://localhost/a/b?c=d&e=f",
   281  		body:        reader,
   282  		strict:      true,
   283  		crypt:       true,
   284  		timestamp:   time.Now().Add(time.Hour * 24 * 365).Unix(),
   285  		fingerprint: fingerprint,
   286  	}
   287  	req, err := buildRequest(setting)
   288  	assert.Nil(t, err)
   289  	resp := httptest.NewRecorder()
   290  	handler.ServeHTTP(resp, req)
   291  	assert.Equal(t, http.StatusOK, resp.Code)
   292  }
   293  
   294  func buildRequest(rs requestSettings) (*http.Request, error) {
   295  	var bodyStr string
   296  	var err error
   297  
   298  	if rs.crypt && rs.body != nil {
   299  		var buf bytes.Buffer
   300  		io.Copy(&buf, rs.body)
   301  		bodyBytes, err := codec.EcbEncrypt(key, buf.Bytes())
   302  		if err != nil {
   303  			return nil, err
   304  		}
   305  		bodyStr = base64.StdEncoding.EncodeToString(bodyBytes)
   306  	}
   307  
   308  	r := httptest.NewRequest(rs.method, rs.url, strings.NewReader(bodyStr))
   309  	if len(rs.signature) == 0 {
   310  		sha := sha256.New()
   311  		sha.Write([]byte(bodyStr))
   312  		bodySign := fmt.Sprintf("%x", sha.Sum(nil))
   313  		var path string
   314  		var query string
   315  		if len(rs.requestUri) > 0 {
   316  			u, err := url.Parse(rs.requestUri)
   317  			if err != nil {
   318  				return nil, err
   319  			}
   320  
   321  			path = u.Path
   322  			query = u.RawQuery
   323  		} else {
   324  			path = r.URL.Path
   325  			query = r.URL.RawQuery
   326  		}
   327  		contentOfSign := strings.Join([]string{
   328  			strconv.FormatInt(rs.timestamp, 10),
   329  			rs.method,
   330  			path,
   331  			query,
   332  			bodySign,
   333  		}, "\n")
   334  		rs.signature = codec.HmacBase64([]byte(key), contentOfSign)
   335  	}
   336  
   337  	var mode string
   338  	if rs.crypt {
   339  		mode = "1"
   340  	} else {
   341  		mode = "0"
   342  	}
   343  	content := strings.Join([]string{
   344  		"version=v1",
   345  		"type=" + mode,
   346  		fmt.Sprintf("key=%s", base64.StdEncoding.EncodeToString(key)),
   347  		"time=" + strconv.FormatInt(rs.timestamp, 10),
   348  	}, "; ")
   349  
   350  	encrypter, err := codec.NewRsaEncrypter([]byte(pubKey))
   351  	if err != nil {
   352  		log.Fatal(err)
   353  	}
   354  
   355  	output, err := encrypter.Encrypt([]byte(content))
   356  	if err != nil {
   357  		log.Fatal(err)
   358  	}
   359  
   360  	encryptedContent := base64.StdEncoding.EncodeToString(output)
   361  	if !rs.missHeader {
   362  		r.Header.Set(httpx.ContentSecurity, strings.Join([]string{
   363  			fmt.Sprintf("key=%s", rs.fingerprint),
   364  			"secret=" + encryptedContent,
   365  			"signature=" + rs.signature,
   366  		}, "; "))
   367  	}
   368  	if len(rs.requestUri) > 0 {
   369  		r.Header.Set("X-Request-Uri", rs.requestUri)
   370  	}
   371  
   372  	return r, nil
   373  }
   374  
   375  func createTempFile(body []byte) (string, error) {
   376  	tmpFile, err := os.CreateTemp(os.TempDir(), "go-unit-*.tmp")
   377  	if err != nil {
   378  		return "", err
   379  	}
   380  
   381  	tmpFile.Close()
   382  	err = os.WriteFile(tmpFile.Name(), body, os.ModePerm)
   383  	if err != nil {
   384  		return "", err
   385  	}
   386  
   387  	return tmpFile.Name(), nil
   388  }