github.com/anycable/anycable-go@v1.5.1/utils/message_verifier.go (about)

     1  package utils
     2  
     3  import (
     4  	"crypto/hmac"
     5  	"crypto/sha256"
     6  	"crypto/subtle"
     7  	"encoding/base64"
     8  	"encoding/json"
     9  	"errors"
    10  	"fmt"
    11  	"strings"
    12  
    13  	"github.com/joomcode/errorx"
    14  )
    15  
    16  type MessageVerifier struct {
    17  	key []byte
    18  }
    19  
    20  func NewMessageVerifier(key string) *MessageVerifier {
    21  	return &MessageVerifier{key: []byte(key)}
    22  }
    23  
    24  func (m *MessageVerifier) Generate(payload interface{}) (string, error) {
    25  	payloadJson, err := json.Marshal(payload)
    26  
    27  	if err != nil {
    28  		return "", err
    29  	}
    30  
    31  	encoded := base64.StdEncoding.EncodeToString(payloadJson)
    32  
    33  	signature, err := m.Sign([]byte(encoded))
    34  
    35  	if err != nil {
    36  		return "", err
    37  	}
    38  
    39  	signed := encoded + "--" + string(signature)
    40  	return signed, nil
    41  }
    42  
    43  func (m *MessageVerifier) Verified(msg string) (interface{}, error) {
    44  	if err := m.Validate(msg); err != nil {
    45  		return "", errorx.Decorate(err, "failed to verify message")
    46  	}
    47  
    48  	parts := strings.Split(msg, "--")
    49  	data := parts[0]
    50  
    51  	jsonStr, err := base64.StdEncoding.DecodeString(data)
    52  
    53  	if err != nil {
    54  		return "", err
    55  	}
    56  
    57  	var result interface{}
    58  
    59  	if err = json.Unmarshal(jsonStr, &result); err != nil {
    60  		return "", err
    61  	}
    62  
    63  	return result, nil
    64  }
    65  
    66  // https://github.com/rails/rails/blob/061bf3156fb90ac6b8ec255dfa39492cf22d7b13/activesupport/lib/active_support/message_verifier.rb#L122
    67  func (m *MessageVerifier) Validate(msg string) error {
    68  	if msg == "" {
    69  		return errors.New("message is empty")
    70  	}
    71  
    72  	parts := strings.Split(msg, "--")
    73  
    74  	if len(parts) != 2 {
    75  		return fmt.Errorf("message must contain 2 parts, got %d", len(parts))
    76  	}
    77  
    78  	data := []byte(parts[0])
    79  	digest := []byte(parts[1])
    80  
    81  	if m.VerifySignature(data, digest) {
    82  		return nil
    83  	} else {
    84  		return errors.New("invalid signature")
    85  	}
    86  }
    87  
    88  func (m *MessageVerifier) Sign(payload []byte) ([]byte, error) {
    89  	digest := hmac.New(sha256.New, m.key)
    90  	_, err := digest.Write(payload)
    91  
    92  	if err != nil {
    93  		return nil, errorx.Decorate(err, "failed to sign payload")
    94  	}
    95  
    96  	return []byte(fmt.Sprintf("%x", digest.Sum(nil))), nil
    97  }
    98  
    99  func (m *MessageVerifier) VerifySignature(payload []byte, digest []byte) bool {
   100  	h := hmac.New(sha256.New, m.key)
   101  	h.Write(payload)
   102  
   103  	actual := []byte(fmt.Sprintf("%x", h.Sum(nil)))
   104  
   105  	return subtle.ConstantTimeEq(int32(len(actual)), int32(len(digest))) == 1 &&
   106  		subtle.ConstantTimeCompare(actual, digest) == 1
   107  }