github.com/avenga/couper@v1.12.2/handler/transport/tls_test.go (about)

     1  package transport_test
     2  
     3  import (
     4  	"crypto/tls"
     5  	"os"
     6  	"reflect"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/avenga/couper/config"
    11  	"github.com/avenga/couper/errors"
    12  	"github.com/avenga/couper/handler/transport"
    13  	"github.com/avenga/couper/internal/test"
    14  	"github.com/avenga/couper/server"
    15  )
    16  
    17  func TestReadCertificates(t *testing.T) {
    18  	helper := test.New(t)
    19  
    20  	now := time.Now()
    21  	selfSignedCert, serr := server.NewCertificate(time.Hour, nil, &now)
    22  	helper.Must(serr)
    23  	t.Logf("generated certificates in %s", time.Since(now).String())
    24  
    25  	// obsolete for this test
    26  	selfSignedCert.CA.PrivateKey = nil
    27  	clientCertOnly := *selfSignedCert.Client
    28  	clientCertOnly.PrivateKey = nil
    29  
    30  	for _, format := range []string{"DER", "PEM"} {
    31  		var caCertBytes, certBytes []byte
    32  		switch format {
    33  		case "DER":
    34  			caCertBytes = selfSignedCert.CA.Certificate[0]
    35  			certBytes = selfSignedCert.Client.Certificate[0]
    36  		case "PEM":
    37  			caCertBytes = selfSignedCert.CACertificate.Certificate
    38  			certBytes = selfSignedCert.ClientCertificate.Certificate
    39  		}
    40  
    41  		pattern := "couper_test_tls_read_certs_" + format
    42  		tmpCaCertFile, ferr := os.CreateTemp("", pattern)
    43  		helper.Must(ferr)
    44  		_, ferr = tmpCaCertFile.Write(caCertBytes)
    45  		helper.Must(ferr)
    46  		helper.Must(tmpCaCertFile.Close())
    47  		defer os.Remove(tmpCaCertFile.Name())
    48  
    49  		tmpCertFile, ferr := os.CreateTemp("", pattern)
    50  		helper.Must(ferr)
    51  		_, ferr = tmpCertFile.Write(certBytes)
    52  		helper.Must(ferr)
    53  		helper.Must(tmpCertFile.Close())
    54  		defer os.Remove(tmpCertFile.Name())
    55  
    56  		tests := []struct {
    57  			name       string
    58  			conf       config.BackendTLS
    59  			wantSrv    tls.Certificate
    60  			wantClient tls.Certificate
    61  			wantErr    bool
    62  		}{
    63  			{"empty attributes", config.BackendTLS{}, tls.Certificate{}, tls.Certificate{}, false},
    64  			{"server ca file", config.BackendTLS{ServerCertificateFile: tmpCaCertFile.Name()}, *selfSignedCert.CA, tls.Certificate{}, false},
    65  			{"server ca value", config.BackendTLS{ServerCertificate: string(caCertBytes)}, *selfSignedCert.CA, tls.Certificate{}, false},
    66  			{"server ca file + value", config.BackendTLS{ServerCertificateFile: tmpCaCertFile.Name(), ServerCertificate: string(caCertBytes)}, tls.Certificate{}, tls.Certificate{}, true},
    67  			// TODO: testCase with combined crt+key PEM file
    68  			{"client ca file /w malformed key", config.BackendTLS{ClientCertificateFile: tmpCertFile.Name(), ClientPrivateKey: "malformed"}, tls.Certificate{}, clientCertOnly, true},
    69  			{"client ca file /w key", config.BackendTLS{ClientCertificateFile: tmpCertFile.Name(), ClientPrivateKey: string(selfSignedCert.ClientPrivateKey)}, tls.Certificate{}, *selfSignedCert.Client, false},
    70  			{"client ca value /w key", config.BackendTLS{ClientCertificate: string(certBytes), ClientPrivateKey: string(selfSignedCert.ClientPrivateKey)}, tls.Certificate{}, *selfSignedCert.Client, false},
    71  			{"client ca file /wo key", config.BackendTLS{ClientCertificateFile: tmpCertFile.Name()}, tls.Certificate{}, tls.Certificate{}, true},
    72  			{"client ca value /wo key", config.BackendTLS{ClientCertificate: string(certBytes)}, tls.Certificate{}, tls.Certificate{}, true},
    73  			{"client ca file + value", config.BackendTLS{ClientCertificateFile: tmpCertFile.Name(), ClientCertificate: string(certBytes)}, tls.Certificate{}, tls.Certificate{}, true},
    74  		}
    75  		for _, tt := range tests {
    76  			t.Run(format+"/"+tt.name, func(t *testing.T) {
    77  				gotSrv, gotClient, err := transport.ReadCertificates(&tt.conf)
    78  				if (err != nil) != tt.wantErr {
    79  					msg := "<nil>"
    80  					if lerr, ok := err.(errors.GoError); ok {
    81  						msg = lerr.LogError()
    82  					} else if err != nil {
    83  						msg = err.Error()
    84  					}
    85  					t.Errorf("ReadCertificates() error = %v, wantErr %v", msg, tt.wantErr)
    86  					return
    87  				}
    88  				if !reflect.DeepEqual(gotSrv, tt.wantSrv) {
    89  					t.Errorf("ReadCertificates():\n\tgotSrv  %v\n\t\twant %v", gotSrv, tt.wantSrv)
    90  				}
    91  				if !reflect.DeepEqual(gotClient, tt.wantClient) {
    92  					t.Errorf("ReadCertificates():\n\tgotClient %v\n\t\t want %v", gotClient, tt.wantClient)
    93  				}
    94  			})
    95  		}
    96  	}
    97  }