go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/swarming/server/hmactoken/hmactoken_test.go (about)

     1  // Copyright 2023 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package hmactoken
    16  
    17  import (
    18  	"crypto/hmac"
    19  	"crypto/sha256"
    20  	"fmt"
    21  	"testing"
    22  
    23  	"google.golang.org/protobuf/proto"
    24  
    25  	"go.chromium.org/luci/server/secrets"
    26  
    27  	internalspb "go.chromium.org/luci/swarming/proto/internals"
    28  
    29  	. "github.com/smartystreets/goconvey/convey"
    30  	. "go.chromium.org/luci/common/testing/assertions"
    31  )
    32  
    33  func TestValidateToken(t *testing.T) {
    34  	t.Parallel()
    35  
    36  	Convey("With secret", t, func() {
    37  		s := NewStaticSecret(secrets.Secret{
    38  			Active:  []byte("secret"),
    39  			Passive: [][]byte{[]byte("also-secret")},
    40  		})
    41  
    42  		Convey("Good token", func() {
    43  			original := &internalspb.PollState{Id: "some-id"}
    44  
    45  			extracted := &internalspb.PollState{}
    46  			err := s.ValidateToken(genPollToken(
    47  				original,
    48  				internalspb.TaggedMessage_POLL_STATE,
    49  				[]byte("secret"),
    50  			), extracted)
    51  			So(err, ShouldBeNil)
    52  			So(extracted, ShouldResembleProto, original)
    53  
    54  			// Non-active secret is also OK.
    55  			extracted = &internalspb.PollState{}
    56  			err = s.ValidateToken(genPollToken(
    57  				original,
    58  				internalspb.TaggedMessage_POLL_STATE,
    59  				[]byte("also-secret"),
    60  			), extracted)
    61  			So(err, ShouldBeNil)
    62  			So(extracted, ShouldResembleProto, original)
    63  		})
    64  
    65  		Convey("Bad TaggedMessage proto", func() {
    66  			err := s.ValidateToken([]byte("what is this"), &internalspb.PollState{})
    67  			So(err, ShouldErrLike, "failed to deserialize TaggedMessage")
    68  		})
    69  
    70  		Convey("Wrong type", func() {
    71  			err := s.ValidateToken(genPollToken(
    72  				&internalspb.PollState{Id: "some-id"},
    73  				123,
    74  				[]byte("secret"),
    75  			), &internalspb.PollState{})
    76  			So(err, ShouldErrLike, "invalid payload type")
    77  		})
    78  
    79  		Convey("Bad MAC", func() {
    80  			err := s.ValidateToken(genPollToken(
    81  				&internalspb.PollState{Id: "some-id"},
    82  				internalspb.TaggedMessage_POLL_STATE,
    83  				[]byte("some-other-secret"),
    84  			), &internalspb.PollState{})
    85  			So(err, ShouldErrLike, "bad token HMAC")
    86  		})
    87  	})
    88  }
    89  
    90  func TestGenerateToken(t *testing.T) {
    91  	t.Parallel()
    92  
    93  	Convey("With secret", t, func() {
    94  		s := NewStaticSecret(secrets.Secret{
    95  			Active:  []byte("secret"),
    96  			Passive: [][]byte{[]byte("also-secret")},
    97  		})
    98  
    99  		Convey("PollState", func() {
   100  			original := &internalspb.PollState{Id: "testing"}
   101  			tok, err := s.GenerateToken(original)
   102  			So(err, ShouldBeNil)
   103  
   104  			decoded := &internalspb.PollState{}
   105  			So(s.ValidateToken(tok, decoded), ShouldBeNil)
   106  
   107  			So(decoded, ShouldResembleProto, original)
   108  		})
   109  
   110  		Convey("BotSession", func() {
   111  			original := &internalspb.BotSession{RbeBotSessionId: "testing"}
   112  			tok, err := s.GenerateToken(original)
   113  			So(err, ShouldBeNil)
   114  
   115  			decoded := &internalspb.BotSession{}
   116  			So(s.ValidateToken(tok, decoded), ShouldBeNil)
   117  
   118  			So(decoded, ShouldResembleProto, original)
   119  		})
   120  	})
   121  }
   122  
   123  func genPollToken(state *internalspb.PollState, typ internalspb.TaggedMessage_PayloadType, secret []byte) []byte {
   124  	payload, err := proto.Marshal(state)
   125  	if err != nil {
   126  		panic(err)
   127  	}
   128  
   129  	mac := hmac.New(sha256.New, secret)
   130  	_, _ = fmt.Fprintf(mac, "%d\n", typ)
   131  	_, _ = mac.Write(payload)
   132  	digest := mac.Sum(nil)
   133  
   134  	blob, err := proto.Marshal(&internalspb.TaggedMessage{
   135  		PayloadType: typ,
   136  		Payload:     payload,
   137  		HmacSha256:  digest,
   138  	})
   139  	if err != nil {
   140  		panic(err)
   141  	}
   142  	return blob
   143  }