github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/inttesting/integrationtest/sigs.go (about)

     1  // Copyright 2017 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package integrationtest
    16  
    17  import (
    18  	"crypto"
    19  	"crypto/ecdsa"
    20  	"crypto/elliptic"
    21  	"crypto/rand"
    22  	"crypto/rsa"
    23  	"crypto/x509"
    24  	"crypto/x509/pkix"
    25  	"math/big"
    26  	"net"
    27  	"sync/atomic"
    28  	"testing"
    29  
    30  	log "github.com/golang/glog"
    31  	"github.com/google/fleetspeak/fleetspeak/src/client/signer"
    32  	"github.com/google/fleetspeak/fleetspeak/src/server/authorizer"
    33  
    34  	fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak"
    35  )
    36  
    37  type testAuthorizer struct {
    38  	authorizer.PermissiveAuthorizer
    39  	authCount   int64
    40  	t           *testing.T
    41  	expectCount int
    42  	expectValid map[int64]bool
    43  	expectType  map[int64]x509.SignatureAlgorithm
    44  }
    45  
    46  func (a *testAuthorizer) Allow4(_ net.Addr, _ authorizer.ContactInfo, _ authorizer.ClientInfo, sigs []authorizer.SignatureInfo) (accept bool, validationInfo *fspb.ValidationInfo) {
    47  	atomic.AddInt64(&a.authCount, 1)
    48  	if len(sigs) != a.expectCount {
    49  		a.t.Errorf("expected %d sigs, got %d", a.expectCount, len(sigs))
    50  	}
    51  	for _, s := range sigs {
    52  		serial := s.Certificate[0].SerialNumber.Int64()
    53  		if a.expectType[serial] != s.Algorithm {
    54  			a.t.Errorf("Expected %v to have signature of type %v, got %v", serial, a.expectType[serial], s.Algorithm)
    55  		}
    56  		if a.expectValid[serial] != s.Valid {
    57  			a.t.Errorf("Expected %v to have valid=%t got valid=%t", serial, a.expectValid[serial], s.Valid)
    58  		}
    59  	}
    60  	return true, &fspb.ValidationInfo{Tags: map[string]string{"result": "Valid"}}
    61  }
    62  
    63  type testSigner struct {
    64  	cert   *x509.Certificate
    65  	alg    x509.SignatureAlgorithm
    66  	hash   crypto.Hash
    67  	signer crypto.Signer
    68  	t      *testing.T
    69  }
    70  
    71  func (s *testSigner) SignContact(data []byte) *fspb.Signature {
    72  	h := s.hash.New()
    73  	h.Write(data)
    74  	hashed := h.Sum(nil)
    75  	sig, err := s.signer.Sign(rand.Reader, hashed, s.hash)
    76  	if err != nil {
    77  		log.Exitf("Unable to sign hashed of length %d with (%v): %v", len(hashed), s.hash, err)
    78  	}
    79  	return &fspb.Signature{
    80  		Certificate: [][]byte{s.cert.Raw},
    81  		Algorithm:   int32(s.alg),
    82  		Signature:   sig,
    83  	}
    84  }
    85  
    86  func makeAuthorizerSigners(t *testing.T) (*testAuthorizer, []signer.Signer) {
    87  	var sigs []signer.Signer
    88  	cases := []struct {
    89  		serial  int64
    90  		ka      x509.PublicKeyAlgorithm
    91  		hash    crypto.Hash
    92  		sa      x509.SignatureAlgorithm
    93  		invalid bool
    94  	}{
    95  		{
    96  			serial: 42,
    97  			ka:     x509.RSA,
    98  			hash:   crypto.SHA256,
    99  			sa:     x509.SHA256WithRSA,
   100  		},
   101  		{
   102  			serial: 49,
   103  			ka:     x509.ECDSA,
   104  			hash:   crypto.SHA256,
   105  			sa:     x509.ECDSAWithSHA256,
   106  		},
   107  		{
   108  			serial:  50,
   109  			ka:      x509.ECDSA,
   110  			hash:    crypto.SHA1, // broken sig - wrong hash used to generate it
   111  			sa:      x509.ECDSAWithSHA256,
   112  			invalid: true,
   113  		},
   114  	}
   115  	auth := testAuthorizer{
   116  		t:           t,
   117  		expectCount: len(cases),
   118  		expectValid: make(map[int64]bool),
   119  		expectType:  make(map[int64]x509.SignatureAlgorithm),
   120  	}
   121  	for _, c := range cases {
   122  		cert, key := makeCert(c.serial, c.ka)
   123  		sigs = append(sigs, &testSigner{
   124  			cert:   cert,
   125  			alg:    c.sa,
   126  			hash:   c.hash,
   127  			signer: key.(crypto.Signer),
   128  			t:      t,
   129  		})
   130  		auth.expectValid[c.serial] = !c.invalid
   131  		auth.expectType[c.serial] = c.sa
   132  	}
   133  	return &auth, sigs
   134  }
   135  
   136  func makeCert(serial int64, alg x509.PublicKeyAlgorithm) (*x509.Certificate, crypto.PrivateKey) {
   137  	var k crypto.PrivateKey
   138  	var pk crypto.PublicKey
   139  	switch alg {
   140  	case x509.RSA:
   141  		rk, err := rsa.GenerateKey(rand.Reader, 4096)
   142  		if err != nil {
   143  			log.Fatalf("Unable to generate RSA key: %v", err)
   144  		}
   145  		k = rk
   146  		pk = rk.Public()
   147  	case x509.ECDSA:
   148  		ek, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
   149  		if err != nil {
   150  			log.Fatalf("Unable to generate ECDSA key: %v", err)
   151  		}
   152  		k = ek
   153  		pk = ek.Public()
   154  	default:
   155  		log.Fatalf("Unknown public key algorithm type: %v", alg)
   156  	}
   157  
   158  	tmpl := x509.Certificate{
   159  		Subject:               pkix.Name{CommonName: "Test client"},
   160  		SerialNumber:          big.NewInt(serial),
   161  		BasicConstraintsValid: true,
   162  	}
   163  
   164  	b, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, pk, k)
   165  	if err != nil {
   166  		log.Fatalf("Unable to create x509 cert: %v", err)
   167  	}
   168  	res, err := x509.ParseCertificate(b)
   169  	if err != nil {
   170  		log.Fatalf("Unable to parse newly created cert: %v", err)
   171  	}
   172  	return res, k
   173  }