github.com/avenga/couper@v1.12.2/server/tls_test.go (about)

     1  package server_test
     2  
     3  import (
     4  	"crypto/tls"
     5  	"os"
     6  	"reflect"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/avenga/couper/config"
    11  	"github.com/avenga/couper/errors"
    12  	"github.com/avenga/couper/internal/test"
    13  	"github.com/avenga/couper/server"
    14  )
    15  
    16  func Test_LoadClientCertificate(t *testing.T) {
    17  	helper := test.New(t)
    18  
    19  	now := time.Now()
    20  	selfSignedCert, serr := server.NewCertificate(time.Hour, nil, &now)
    21  	helper.Must(serr)
    22  	t.Logf("generated certificates in %s", time.Since(now).String())
    23  
    24  	// obsolete for this test
    25  	selfSignedCert.ClientIntermediate.PrivateKey = nil
    26  	selfSignedCert.Client.PrivateKey = nil
    27  
    28  	for _, format := range []string{"DER", "PEM"} {
    29  		var caCertBytes, certBytes []byte
    30  		switch format {
    31  		case "DER":
    32  			caCertBytes = selfSignedCert.ClientIntermediate.Certificate[0]
    33  			certBytes = selfSignedCert.Client.Certificate[0]
    34  		case "PEM":
    35  			caCertBytes = selfSignedCert.ClientIntermediateCertificate.Certificate
    36  			certBytes = selfSignedCert.ClientCertificate.Certificate
    37  		}
    38  
    39  		pattern := "couper_test_tls_" + format
    40  		tmpCertFile, err := os.CreateTemp("", pattern)
    41  		helper.Must(err)
    42  		_, err = tmpCertFile.Write(caCertBytes)
    43  		helper.Must(err)
    44  		helper.Must(tmpCertFile.Close())
    45  		defer os.Remove(tmpCertFile.Name())
    46  
    47  		tmpLeafCertFile, err := os.CreateTemp("", pattern)
    48  		helper.Must(err)
    49  		_, err = tmpLeafCertFile.Write(certBytes)
    50  		helper.Must(err)
    51  		helper.Must(tmpLeafCertFile.Close())
    52  		defer os.Remove(tmpLeafCertFile.Name())
    53  
    54  		tests := []struct {
    55  			name         string
    56  			config       *config.ClientCertificate
    57  			wantCaCert   tls.Certificate
    58  			wantLeafCert tls.Certificate
    59  			wantErr      bool
    60  		}{
    61  			{"nil clientCertificate", nil, tls.Certificate{}, tls.Certificate{}, false},
    62  			{"empty clientCertificate", &config.ClientCertificate{}, tls.Certificate{}, tls.Certificate{}, true},
    63  			{"malformed clientCertificate value", &config.ClientCertificate{CA: "asdf"}, tls.Certificate{}, tls.Certificate{}, true},
    64  			{"clientCertificate CA value", &config.ClientCertificate{
    65  				CA: string(caCertBytes),
    66  			}, *selfSignedCert.ClientIntermediate, tls.Certificate{}, false},
    67  			{"clientCertificate CA /w Leaf value", &config.ClientCertificate{
    68  				CA:   string(caCertBytes),
    69  				Leaf: string(certBytes),
    70  			}, *selfSignedCert.ClientIntermediate, *selfSignedCert.Client, false},
    71  			{"clientCertificate /w Leaf value", &config.ClientCertificate{
    72  				Leaf: string(certBytes),
    73  			}, tls.Certificate{}, *selfSignedCert.Client, false},
    74  		}
    75  		for _, tt := range tests {
    76  			t.Run(format+"/"+tt.name, func(t *testing.T) {
    77  				gotCaCert, gotLeafCert, err := server.LoadClientCertificate(tt.config)
    78  				if (err != nil) != tt.wantErr {
    79  					msg := err.Error()
    80  					if cerr, ok := err.(errors.GoError); ok {
    81  						msg = cerr.LogError()
    82  					}
    83  					t.Errorf("LoadClientCertificate() error = %v, wantErr %v", msg, tt.wantErr)
    84  					return
    85  				}
    86  				if !reflect.DeepEqual(gotCaCert, tt.wantCaCert) {
    87  					t.Errorf("LoadClientCertificate() CA\n\tgot:\t%v\n\twant:\t%v\n", gotCaCert, tt.wantCaCert)
    88  				}
    89  				if !reflect.DeepEqual(gotLeafCert, tt.wantLeafCert) {
    90  					t.Errorf("LoadClientCertificate() Leaf\n\tgot:\t%v\n\twant:\t%v\n", gotLeafCert, tt.wantLeafCert)
    91  				}
    92  			})
    93  		}
    94  	}
    95  }
    96  
    97  func Test_LoadServerCertificate(t *testing.T) {
    98  	helper := test.New(t)
    99  
   100  	now := time.Now()
   101  	selfSignedCert, serr := server.NewCertificate(time.Hour, nil, &now)
   102  	helper.Must(serr)
   103  	t.Logf("generated certificates in %s", time.Since(now).String())
   104  
   105  	for _, format := range []string{"DER", "PEM"} {
   106  		var certBytes, privateKeyBytes []byte
   107  		switch format {
   108  		case "DER":
   109  			certBytes = selfSignedCert.Server.Certificate[0]
   110  			privateKeyBytes = selfSignedCert.ServerPrivateKey
   111  		case "PEM":
   112  			certBytes = selfSignedCert.ServerCertificate.Certificate
   113  			privateKeyBytes = selfSignedCert.ServerCertificate.PrivateKey
   114  		}
   115  
   116  		pattern := "couper_test_tls_" + format
   117  		tmpCertFile, err := os.CreateTemp("", pattern)
   118  		helper.Must(err)
   119  		_, err = tmpCertFile.Write(certBytes)
   120  		helper.Must(err)
   121  		helper.Must(tmpCertFile.Close())
   122  		defer os.Remove(tmpCertFile.Name())
   123  
   124  		tmpKeyFile, err := os.CreateTemp("", pattern)
   125  		helper.Must(err)
   126  		_, err = tmpKeyFile.Write(privateKeyBytes)
   127  		helper.Must(err)
   128  		helper.Must(tmpKeyFile.Close())
   129  		defer os.Remove(tmpKeyFile.Name())
   130  
   131  		tests := []struct {
   132  			name    string
   133  			config  *config.ServerCertificate
   134  			want    tls.Certificate
   135  			wantErr bool
   136  		}{
   137  			{"nil serverCertificate", nil, tls.Certificate{}, false},
   138  			{"empty serverCertificate", &config.ServerCertificate{}, tls.Certificate{}, true},
   139  			{"with serverCertificateValue", &config.ServerCertificate{
   140  				PublicKey:  string(certBytes),
   141  				PrivateKey: string(privateKeyBytes),
   142  			}, *selfSignedCert.Server, false},
   143  			{"with serverCertificateFile", &config.ServerCertificate{
   144  				PublicKeyFile:  tmpCertFile.Name(),
   145  				PrivateKeyFile: tmpKeyFile.Name(),
   146  			}, *selfSignedCert.Server, false},
   147  		}
   148  		for _, tt := range tests {
   149  			t.Run(format+"/"+tt.name, func(t *testing.T) {
   150  				got, terr := server.LoadServerCertificate(tt.config)
   151  				if (terr != nil) != tt.wantErr {
   152  					t.Errorf("LoadServerCertificate() error = %v, wantErr %v", terr, tt.wantErr)
   153  					return
   154  				}
   155  				if !reflect.DeepEqual(got, tt.want) {
   156  					t.Errorf("LoadServerCertificate()\n\tgot:\t%v\n\twant:\t%v\n", got, tt.want)
   157  				}
   158  			})
   159  		}
   160  	}
   161  }