github.com/stripe/stripe-go/v76@v76.25.0/webhook/client_test.go (about)

     1  package webhook
     2  
     3  import (
     4  	"encoding/hex"
     5  	"fmt"
     6  	"strings"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/stripe/stripe-go/v76"
    11  )
    12  
    13  var testPayload = []byte(fmt.Sprintf(`{
    14    "id": "evt_test_webhook",
    15    "object": "event",
    16    "api_version": "%s"
    17  }`, stripe.APIVersion))
    18  var testPayloadWithAPIVersionMismatch = []byte(`{
    19  	"id": "evt_test_webhook",
    20  	"object": "event",
    21  	"api_version": "2020-01-01"
    22    }`)
    23  var testSecret = "whsec_test_secret"
    24  
    25  func newSignedPayload(options ...func(*SignedPayload)) *SignedPayload {
    26  	signedPayload := &SignedPayload{}
    27  	signedPayload.Timestamp = time.Now()
    28  	signedPayload.Payload = testPayload
    29  	signedPayload.Secret = testSecret
    30  	signedPayload.Scheme = "v1"
    31  
    32  	for _, opt := range options {
    33  		opt(signedPayload)
    34  	}
    35  
    36  	if signedPayload.Signature == nil {
    37  		signedPayload.Signature = ComputeSignature(signedPayload.Timestamp, signedPayload.Payload, signedPayload.Secret)
    38  	}
    39  	signedPayload.Header = generateHeader(*signedPayload)
    40  	return signedPayload
    41  }
    42  
    43  func (p *SignedPayload) hexSignature() string {
    44  	return hex.EncodeToString(p.Signature)
    45  }
    46  
    47  func TestTokenNew(t *testing.T) {
    48  	p := newSignedPayload()
    49  
    50  	evt, err := ConstructEvent(p.Payload, p.Header, p.Secret)
    51  	if err != nil {
    52  		t.Errorf("Error validating signature: %v", err)
    53  	} else if evt.ID != "evt_test_webhook" {
    54  		t.Errorf("Expected a parsed event matching the test Payload, got %v", evt)
    55  	}
    56  
    57  	p = newSignedPayload(func(p *SignedPayload) {
    58  		p.Payload = append(p.Payload, byte('['))
    59  	})
    60  	evt, err = ConstructEvent(p.Payload, p.Header, p.Secret)
    61  	if err == nil {
    62  		t.Errorf("Invalid JSON did not cause a parse error")
    63  	}
    64  
    65  	p = newSignedPayload()
    66  	err = ValidatePayload(p.Payload, "", p.Secret)
    67  	if err != ErrNotSigned {
    68  		t.Errorf("Expected ErrNotSigned from missing signature, got %v", err)
    69  	}
    70  	evt, err = ConstructEvent(p.Payload, "", p.Secret)
    71  	if err != ErrNotSigned {
    72  		t.Errorf("Expected ErrNotSigned from missing signature, got %v", err)
    73  	}
    74  
    75  	evt, err = ConstructEvent(p.Payload, "v1,t=1", p.Secret)
    76  	if err != ErrInvalidHeader {
    77  		t.Errorf("Expected ErrInvalidHeader from bad header format, got %v", err)
    78  	}
    79  
    80  	err = ValidatePayload(p.Payload, "t=", p.Secret)
    81  	if err != ErrInvalidHeader {
    82  		t.Errorf("Expected ErrInvalidHeader from bad header format, got %v", err)
    83  	}
    84  	evt, err = ConstructEvent(p.Payload, "t=", p.Secret)
    85  	if err != ErrInvalidHeader {
    86  		t.Errorf("Expected ErrInvalidHeader from bad header format, got %v", err)
    87  	}
    88  
    89  	err = ValidatePayload(p.Payload, p.Header+",v1=bad_signature", p.Secret)
    90  	if err != nil {
    91  		t.Errorf("Received unexpected %v error with an unreadable signature in the header (should be ignored)", err)
    92  	}
    93  	evt, err = ConstructEvent(p.Payload, p.Header+",v1=bad_signature", p.Secret)
    94  	if err != nil {
    95  		t.Errorf("Received unexpected %v error with an unreadable signature in the header (should be ignored)", err)
    96  	}
    97  
    98  	p = newSignedPayload(func(p *SignedPayload) {
    99  		p.Scheme = "v0"
   100  	})
   101  	err = ValidatePayload(p.Payload, p.Header, p.Secret)
   102  	if err != ErrNoValidSignature {
   103  		t.Errorf("Expected error from mismatched schema, got %v", err)
   104  	}
   105  	evt, err = ConstructEvent(p.Payload, p.Header, p.Secret)
   106  	if err != ErrNoValidSignature {
   107  		t.Errorf("Expected error from mismatched schema, got %v", err)
   108  	}
   109  
   110  	p = newSignedPayload(func(p *SignedPayload) {
   111  		p.Signature = []byte("deadbeef")
   112  	})
   113  	err = ValidatePayload(p.Payload, p.Header, p.Secret)
   114  	if err != ErrNoValidSignature {
   115  		t.Errorf("Expected error from fake signature, got %v", err)
   116  	}
   117  	evt, err = ConstructEvent(p.Payload, p.Header, p.Secret)
   118  	if err != ErrNoValidSignature {
   119  		t.Errorf("Expected error from fake signature, got %v", err)
   120  	}
   121  
   122  	p = newSignedPayload()
   123  	p2 := newSignedPayload(func(p *SignedPayload) {
   124  		p.Secret = testSecret + "_rolled_key"
   125  	})
   126  	headerWithRolledKey := p.Header + ",v1=" + p2.hexSignature()
   127  	if p.hexSignature() == p2.hexSignature() {
   128  		t.Errorf("Got the same signature with two different secret keys")
   129  	}
   130  
   131  	err = ValidatePayload(p.Payload, headerWithRolledKey, p.Secret)
   132  	if err != nil {
   133  		t.Errorf("Expected to be able to decode webhook with old key after rolling key, but got %v", err)
   134  	}
   135  	evt, err = ConstructEvent(p.Payload, headerWithRolledKey, p.Secret)
   136  	if err != nil {
   137  		t.Errorf("Expected to be able to decode webhook with old key after rolling key, but got %v", err)
   138  	}
   139  	err = ValidatePayload(p.Payload, headerWithRolledKey, p2.Secret)
   140  	if err != nil {
   141  		t.Errorf("Expected to be able to decode webhook with new key after rolling key, but got %v", err)
   142  	}
   143  	evt, err = ConstructEvent(p.Payload, headerWithRolledKey, p2.Secret)
   144  	if err != nil {
   145  		t.Errorf("Expected to be able to decode webhook with new key after rolling key, but got %v", err)
   146  	}
   147  
   148  	p = newSignedPayload(func(p *SignedPayload) {
   149  		p.Timestamp = time.Now().Add(-15 * time.Second)
   150  	})
   151  	err = ValidatePayloadWithTolerance(p.Payload, p.Header, p.Secret, 10*time.Second)
   152  	if err != ErrTooOld {
   153  		t.Errorf("Received %v error when validating timestamp outside of allowed timing window", err)
   154  	}
   155  	evt, err = ConstructEventWithTolerance(p.Payload, p.Header, p.Secret, 10*time.Second)
   156  	if err != ErrTooOld {
   157  		t.Errorf("Received %v error when validating timestamp outside of allowed timing window", err)
   158  	}
   159  
   160  	err = ValidatePayloadWithTolerance(p.Payload, p.Header, p.Secret, 20*time.Second)
   161  	if err != nil {
   162  		t.Errorf("Received %v error when validating timestamp inside allowed timing window", err)
   163  	}
   164  	evt, err = ConstructEventWithTolerance(p.Payload, p.Header, p.Secret, 20*time.Second)
   165  	if err != nil {
   166  		t.Errorf("Received %v error when validating timestamp inside allowed timing window", err)
   167  	}
   168  
   169  	p = newSignedPayload(func(p *SignedPayload) {
   170  		p.Timestamp = time.Unix(12345, 0)
   171  	})
   172  	err = ValidatePayloadIgnoringTolerance(p.Payload, p.Header, p.Secret)
   173  	if err != nil {
   174  		t.Errorf("Received %v error when timestamp outside window but no tolerance specified", err)
   175  	}
   176  	evt, err = ConstructEventIgnoringTolerance(p.Payload, p.Header, p.Secret)
   177  	if err != nil {
   178  		t.Errorf("Received %v error when timestamp outside window but no tolerance specified", err)
   179  	}
   180  }
   181  
   182  func TestConstructEvent_ErrorOnAPIVersionMismatch(t *testing.T) {
   183  	p := newSignedPayload(func(p *SignedPayload) {
   184  		p.Payload = testPayloadWithAPIVersionMismatch
   185  	})
   186  
   187  	_, err := ConstructEvent(p.Payload, p.Header, p.Secret)
   188  
   189  	if err == nil {
   190  		t.Errorf("Expected error due to API version mismatch.")
   191  	}
   192  
   193  	if !strings.Contains(err.Error(), "Received event with API version") {
   194  		t.Errorf("Expected API version mismatch error but received %v", err)
   195  	}
   196  }
   197  
   198  func TestConstructEventWithOptions_IgnoreAPIVersionMismatch(t *testing.T) {
   199  
   200  	p := newSignedPayload(func(p *SignedPayload) {
   201  		p.Payload = testPayloadWithAPIVersionMismatch
   202  	})
   203  
   204  	evt, err := ConstructEventWithOptions(p.Payload, p.Header, p.Secret, ConstructEventOptions{IgnoreAPIVersionMismatch: true})
   205  
   206  	if err != nil {
   207  		t.Errorf("Expected no error due ignoreAPIVersionMismatch.")
   208  	}
   209  
   210  	if evt.ID != "evt_test_webhook" {
   211  		t.Errorf("Expected a parsed event matching the test Payload, got %v", evt)
   212  	}
   213  }
   214  
   215  func TestConstructEventWithOptions_UsesDefaultToleranceWhenNoneProvided(t *testing.T) {
   216  
   217  	p := newSignedPayload(func(p *SignedPayload) {
   218  		// Get close to the default tolerance, but give wiggle room to avoid
   219  		// a flaky test.
   220  		p.Timestamp = time.Now().Add(-DefaultTolerance).Add(1 * time.Second)
   221  	})
   222  
   223  	_, err := ConstructEventWithOptions(p.Payload, p.Header, p.Secret, ConstructEventOptions{})
   224  
   225  	if err != nil {
   226  		t.Errorf("Expected no error due tolerance, but got %v.", err)
   227  	}
   228  
   229  	p = newSignedPayload(func(p *SignedPayload) {
   230  		p.Timestamp = time.Now().Add(-DefaultTolerance).Add(-1 * time.Millisecond)
   231  	})
   232  
   233  	_, err = ConstructEventWithOptions(p.Payload, p.Header, p.Secret, ConstructEventOptions{})
   234  
   235  	if err != ErrTooOld {
   236  		t.Errorf("Expected error due to being too old, but got %v.", err)
   237  	}
   238  }