github.com/lingyao2333/mo-zero@v1.4.1/rest/internal/security/contentsecurity_test.go (about)

     1  package security
     2  
     3  import (
     4  	"crypto/hmac"
     5  	"crypto/md5"
     6  	"crypto/sha256"
     7  	"encoding/base64"
     8  	"fmt"
     9  	"io"
    10  	"log"
    11  	"net/http"
    12  	"os"
    13  	"strconv"
    14  	"strings"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/lingyao2333/mo-zero/core/codec"
    19  	"github.com/lingyao2333/mo-zero/core/fs"
    20  	"github.com/lingyao2333/mo-zero/rest/httpx"
    21  	"github.com/stretchr/testify/assert"
    22  )
    23  
const (
	// pubKey is a PEM-encoded ("BEGIN PUBLIC KEY") RSA public key. The test
	// encrypts the X-Content-Security secret with it, and its fingerprint is
	// used as the key identifier in the header. Test-only material.
	pubKey = `-----BEGIN PUBLIC KEY-----
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCyeDYV2ieOtNDi6tuNtAbmUjN9
pTHluAU5yiKEz8826QohcxqUKP3hybZBcm60p+rUxMAJFBJ8Dt+UJ6sEMzrf1rOF
YOImVvORkXjpFU7sCJkhnLMs/kxtRzcZJG6ADUlG4GDCNcZpY/qELEvwgm2kCcHi
tGC2mO8opFFFHTR0aQIDAQAB
-----END PUBLIC KEY-----`
	// priKey is the PEM-encoded ("BEGIN RSA PRIVATE KEY") private key matching
	// pubKey. The test writes it to a temp file and loads it with
	// codec.NewRsaDecrypter to decrypt the secret. Never reuse outside tests.
	priKey = `-----BEGIN RSA PRIVATE KEY-----
MIICXQIBAAKBgQCyeDYV2ieOtNDi6tuNtAbmUjN9pTHluAU5yiKEz8826QohcxqU
KP3hybZBcm60p+rUxMAJFBJ8Dt+UJ6sEMzrf1rOFYOImVvORkXjpFU7sCJkhnLMs
/kxtRzcZJG6ADUlG4GDCNcZpY/qELEvwgm2kCcHitGC2mO8opFFFHTR0aQIDAQAB
AoGAcENv+jT9VyZkk6karLuG75DbtPiaN5+XIfAF4Ld76FWVOs9V88cJVON20xpx
ixBphqexCMToj8MnXuHJEN5M9H15XXx/9IuiMm3FOw0i6o0+4V8XwHr47siT6T+r
HuZEyXER/2qrm0nxyC17TXtd/+TtpfQWSbivl6xcAEo9RRECQQDj6OR6AbMQAIDn
v+AhP/y7duDZimWJIuMwhigA1T2qDbtOoAEcjv3DB1dAswJ7clcnkxI9a6/0RDF9
0IEHUcX9AkEAyHdcegWiayEnbatxWcNWm1/5jFnCN+GTRRFrOhBCyFr2ZdjFV4T+
acGtG6omXWaZJy1GZz6pybOGy93NwLB93QJARKMJ0/iZDbOpHqI5hKn5mhd2Je25
IHDCTQXKHF4cAQ+7njUvwIMLx2V5kIGYuMa5mrB/KMI6rmyvHv3hLewhnQJBAMMb
cPUOENMllINnzk2oEd3tXiscnSvYL4aUeoErnGP2LERZ40/YD+mMZ9g6FVboaX04
0oHf+k5mnXZD7WJyJD0CQQDJ2HyFbNaUUHK+lcifCibfzKTgmnNh9ZpePFumgJzI
EfFE5H+nzsbbry2XgJbWzRNvuFTOLWn4zM+aFyy9WvbO
-----END RSA PRIVATE KEY-----`
	// body is the request payload. Its hex SHA-256 digest is one of the
	// fields covered by the request signature.
	body = "hello world!"
)

// key is the 32-byte secret. The test embeds it base64-encoded in the key=
// field of the encrypted secret and uses it as the HMAC-SHA256 signing key.
var key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
    50  
    51  func TestContentSecurity(t *testing.T) {
    52  	tests := []struct {
    53  		name        string
    54  		mode        string
    55  		extraKey    string
    56  		extraSecret string
    57  		extraTime   string
    58  		err         error
    59  		code        int
    60  	}{
    61  		{
    62  			name: "encrypted",
    63  			mode: "1",
    64  		},
    65  		{
    66  			name: "unencrypted",
    67  			mode: "0",
    68  		},
    69  		{
    70  			name: "bad content type",
    71  			mode: "a",
    72  			err:  ErrInvalidContentType,
    73  		},
    74  		{
    75  			name:        "bad secret",
    76  			mode:        "1",
    77  			extraSecret: "any",
    78  			err:         ErrInvalidSecret,
    79  		},
    80  		{
    81  			name:     "bad key",
    82  			mode:     "1",
    83  			extraKey: "any",
    84  			err:      ErrInvalidKey,
    85  		},
    86  		{
    87  			name:      "bad time",
    88  			mode:      "1",
    89  			extraTime: "any",
    90  			code:      httpx.CodeSignatureInvalidHeader,
    91  		},
    92  	}
    93  
    94  	for _, test := range tests {
    95  		test := test
    96  		t.Run(test.name, func(t *testing.T) {
    97  			t.Parallel()
    98  
    99  			r, err := http.NewRequest(http.MethodPost, "http://localhost:3333/a/b?c=first&d=second",
   100  				strings.NewReader(body))
   101  			assert.Nil(t, err)
   102  
   103  			timestamp := time.Now().Unix()
   104  			sha := sha256.New()
   105  			sha.Write([]byte(body))
   106  			bodySign := fmt.Sprintf("%x", sha.Sum(nil))
   107  			contentOfSign := strings.Join([]string{
   108  				strconv.FormatInt(timestamp, 10),
   109  				http.MethodPost,
   110  				r.URL.Path,
   111  				r.URL.RawQuery,
   112  				bodySign,
   113  			}, "\n")
   114  			sign := hs256(key, contentOfSign)
   115  			content := strings.Join([]string{
   116  				"version=v1",
   117  				"type=" + test.mode,
   118  				fmt.Sprintf("key=%s", base64.StdEncoding.EncodeToString(key)) + test.extraKey,
   119  				"time=" + strconv.FormatInt(timestamp, 10) + test.extraTime,
   120  			}, "; ")
   121  
   122  			encrypter, err := codec.NewRsaEncrypter([]byte(pubKey))
   123  			if err != nil {
   124  				log.Fatal(err)
   125  			}
   126  
   127  			output, err := encrypter.Encrypt([]byte(content))
   128  			if err != nil {
   129  				log.Fatal(err)
   130  			}
   131  
   132  			encryptedContent := base64.StdEncoding.EncodeToString(output)
   133  			r.Header.Set("X-Content-Security", strings.Join([]string{
   134  				fmt.Sprintf("key=%s", fingerprint(pubKey)),
   135  				"secret=" + encryptedContent + test.extraSecret,
   136  				"signature=" + sign,
   137  			}, "; "))
   138  
   139  			file, err := fs.TempFilenameWithText(priKey)
   140  			assert.Nil(t, err)
   141  			defer os.Remove(file)
   142  
   143  			dec, err := codec.NewRsaDecrypter(file)
   144  			assert.Nil(t, err)
   145  
   146  			header, err := ParseContentSecurity(map[string]codec.RsaDecrypter{
   147  				fingerprint(pubKey): dec,
   148  			}, r)
   149  			assert.Equal(t, test.err, err)
   150  			if err != nil {
   151  				return
   152  			}
   153  
   154  			encrypted := test.mode != "0"
   155  			assert.Equal(t, encrypted, header.Encrypted())
   156  			assert.Equal(t, test.code, VerifySignature(r, header, time.Minute))
   157  		})
   158  	}
   159  }
   160  
   161  func fingerprint(key string) string {
   162  	h := md5.New()
   163  	io.WriteString(h, key)
   164  	return base64.StdEncoding.EncodeToString(h.Sum(nil))
   165  }
   166  
   167  func hs256(key []byte, body string) string {
   168  	h := hmac.New(sha256.New, key)
   169  	io.WriteString(h, body)
   170  	return base64.StdEncoding.EncodeToString(h.Sum(nil))
   171  }