github.com/pion/dtls/v2@v2.2.12/certificate_test.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  package dtls
     5  
     6  import (
     7  	"crypto/tls"
     8  	"reflect"
     9  	"testing"
    10  
    11  	"github.com/pion/dtls/v2/pkg/crypto/selfsign"
    12  )
    13  
    14  func TestGetCertificate(t *testing.T) {
    15  	certificateWildcard, err := selfsign.GenerateSelfSignedWithDNS("*.test.test")
    16  	if err != nil {
    17  		t.Fatal(err)
    18  	}
    19  
    20  	certificateTest, err := selfsign.GenerateSelfSignedWithDNS("test.test", "www.test.test", "pop.test.test")
    21  	if err != nil {
    22  		t.Fatal(err)
    23  	}
    24  
    25  	certificateRandom, err := selfsign.GenerateSelfSigned()
    26  	if err != nil {
    27  		t.Fatal(err)
    28  	}
    29  
    30  	testCases := []struct {
    31  		localCertificates   []tls.Certificate
    32  		desc                string
    33  		serverName          string
    34  		expectedCertificate tls.Certificate
    35  		getCertificate      func(info *ClientHelloInfo) (*tls.Certificate, error)
    36  	}{
    37  		{
    38  			desc: "Simple match in CN",
    39  			localCertificates: []tls.Certificate{
    40  				certificateRandom,
    41  				certificateTest,
    42  				certificateWildcard,
    43  			},
    44  			serverName:          "test.test",
    45  			expectedCertificate: certificateTest,
    46  		},
    47  		{
    48  			desc: "Simple match in SANs",
    49  			localCertificates: []tls.Certificate{
    50  				certificateRandom,
    51  				certificateTest,
    52  				certificateWildcard,
    53  			},
    54  			serverName:          "www.test.test",
    55  			expectedCertificate: certificateTest,
    56  		},
    57  
    58  		{
    59  			desc: "Wildcard match",
    60  			localCertificates: []tls.Certificate{
    61  				certificateRandom,
    62  				certificateTest,
    63  				certificateWildcard,
    64  			},
    65  			serverName:          "foo.test.test",
    66  			expectedCertificate: certificateWildcard,
    67  		},
    68  		{
    69  			desc: "No match return first",
    70  			localCertificates: []tls.Certificate{
    71  				certificateRandom,
    72  				certificateTest,
    73  				certificateWildcard,
    74  			},
    75  			serverName:          "foo.bar",
    76  			expectedCertificate: certificateRandom,
    77  		},
    78  		{
    79  			desc: "Get certificate from callback",
    80  			getCertificate: func(info *ClientHelloInfo) (*tls.Certificate, error) {
    81  				return &certificateTest, nil
    82  			},
    83  			expectedCertificate: certificateTest,
    84  		},
    85  	}
    86  
    87  	for _, test := range testCases {
    88  		test := test
    89  		t.Run(test.desc, func(t *testing.T) {
    90  			cfg := &handshakeConfig{
    91  				localCertificates:   test.localCertificates,
    92  				localGetCertificate: test.getCertificate,
    93  			}
    94  			cert, err := cfg.getCertificate(&ClientHelloInfo{ServerName: test.serverName})
    95  			if err != nil {
    96  				t.Fatal(err)
    97  			}
    98  
    99  			if !reflect.DeepEqual(cert.Leaf, test.expectedCertificate.Leaf) {
   100  				t.Fatalf("Certificate does not match: expected(%v) actual(%v)", test.expectedCertificate.Leaf, cert.Leaf)
   101  			}
   102  		})
   103  	}
   104  }