github.com/creativeprojects/go-selfupdate@v1.2.0/validate_test.go (about)

     1  package selfupdate
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/ecdsa"
     6  	"crypto/x509"
     7  	"encoding/hex"
     8  	"encoding/pem"
     9  	"fmt"
    10  	"golang.org/x/crypto/openpgp"
    11  	"golang.org/x/crypto/openpgp/armor"
    12  	"os"
    13  	"testing"
    14  
    15  	"github.com/stretchr/testify/assert"
    16  	"github.com/stretchr/testify/require"
    17  )
    18  
    19  func TestValidatorAssetNames(t *testing.T) {
    20  	filename := "asset"
    21  	for _, test := range []struct {
    22  		validator      Validator
    23  		validationName string
    24  	}{
    25  		{
    26  			validator:      &SHAValidator{},
    27  			validationName: filename + ".sha256",
    28  		},
    29  		{
    30  			validator:      &ECDSAValidator{},
    31  			validationName: filename + ".sig",
    32  		},
    33  		{
    34  			validator:      &PGPValidator{},
    35  			validationName: filename + ".asc",
    36  		},
    37  		{
    38  			validator:      &PGPValidator{Binary: true},
    39  			validationName: filename + ".sig",
    40  		},
    41  		{
    42  			validator:      &ChecksumValidator{"funny_sha256"},
    43  			validationName: "funny_sha256",
    44  		},
    45  	} {
    46  		want := test.validationName
    47  		got := test.validator.GetValidationAssetName(filename)
    48  		if want != got {
    49  			t.Errorf("Wanted %q but got %q", want, got)
    50  		}
    51  	}
    52  }
    53  
    54  // ======= PatternValidator ================================================
    55  
    56  func TestPatternValidator(t *testing.T) {
    57  	data, err := os.ReadFile("testdata/foo.zip")
    58  	require.NoError(t, err)
    59  
    60  	hashData, err := os.ReadFile("testdata/foo.zip.sha256")
    61  	require.NoError(t, err)
    62  
    63  	t.Run("Mapping", func(t *testing.T) {
    64  		validator := new(PatternValidator).Add("foo.*", new(SHAValidator))
    65  		{
    66  			v, _ := validator.findValidator("foo.ext")
    67  			assert.IsType(t, &SHAValidator{}, v)
    68  		}
    69  
    70  		assert.True(t, validator.MustContinueValidation("foo.zip"))
    71  		assert.NoError(t, validator.Validate("foo.zip", data, hashData))
    72  		assert.Equal(t, "foo.zip.sha256", validator.GetValidationAssetName("foo.zip"))
    73  
    74  		assert.Error(t, validator.Validate("foo.zip", data, data))
    75  		assert.Error(t, validator.Validate("unmapped", data, hashData))
    76  	})
    77  
    78  	t.Run("MappingInvalidPanics", func(t *testing.T) {
    79  		assert.PanicsWithError(t, "failed adding \"\\\\\": syntax error in pattern", func() {
    80  			new(PatternValidator).Add("\\", new(SHAValidator))
    81  		})
    82  	})
    83  
    84  	t.Run("Skip", func(t *testing.T) {
    85  		validator := new(PatternValidator).SkipValidation("*.skipped")
    86  
    87  		assert.False(t, validator.MustContinueValidation("foo.skipped"))
    88  		assert.NoError(t, validator.Validate("foo.skipped", nil, nil))
    89  		assert.Equal(t, "foo.skipped", validator.GetValidationAssetName("foo.skipped"))
    90  	})
    91  
    92  	t.Run("Unmapped", func(t *testing.T) {
    93  		validator := new(PatternValidator)
    94  
    95  		assert.False(t, validator.MustContinueValidation("foo.zip"))
    96  		assert.ErrorIs(t, ErrValidatorNotFound, validator.Validate("foo.zip", data, hashData))
    97  		assert.Equal(t, "foo.zip", validator.GetValidationAssetName("foo.zip"))
    98  	})
    99  
   100  	t.Run("SupportsNesting", func(t *testing.T) {
   101  		nested := new(PatternValidator).Add("**/*.zip", new(SHAValidator))
   102  		validator := new(PatternValidator).Add("path/**", nested)
   103  		{
   104  			v, _ := validator.findValidator("path/foo")
   105  			assert.Equal(t, nested, v)
   106  		}
   107  
   108  		assert.True(t, validator.MustContinueValidation("path/foo.zip"))
   109  		assert.False(t, validator.MustContinueValidation("path/other"))
   110  		assert.NoError(t, validator.Validate("path/foo.zip", data, hashData))
   111  		assert.Error(t, validator.Validate("foo.zip", data, hashData))
   112  	})
   113  }
   114  
   115  // ======= SHAValidator ====================================================
   116  
   117  func TestSHAValidatorEmptyFile(t *testing.T) {
   118  	validator := &SHAValidator{}
   119  	data, err := os.ReadFile("testdata/foo.zip")
   120  	require.NoError(t, err)
   121  	err = validator.Validate("foo.zip", data, nil)
   122  	assert.EqualError(t, err, ErrIncorrectChecksumFile.Error())
   123  }
   124  
   125  func TestSHAValidatorInvalidFile(t *testing.T) {
   126  	validator := &SHAValidator{}
   127  	data, err := os.ReadFile("testdata/foo.zip")
   128  	require.NoError(t, err)
   129  	err = validator.Validate("foo.zip", data, []byte("blahblahblah\n"))
   130  	assert.EqualError(t, err, ErrIncorrectChecksumFile.Error())
   131  }
   132  
   133  func TestSHAValidator(t *testing.T) {
   134  	validator := &SHAValidator{}
   135  	data, err := os.ReadFile("testdata/foo.zip")
   136  	require.NoError(t, err)
   137  
   138  	hashData, err := os.ReadFile("testdata/foo.zip.sha256")
   139  	require.NoError(t, err)
   140  
   141  	err = validator.Validate("foo.zip", data, hashData)
   142  	assert.NoError(t, err)
   143  }
   144  
   145  func TestSHAValidatorFail(t *testing.T) {
   146  	validator := &SHAValidator{}
   147  	data, err := os.ReadFile("testdata/foo.zip")
   148  	require.NoError(t, err)
   149  
   150  	hashData, err := os.ReadFile("testdata/foo.zip.sha256")
   151  	require.NoError(t, err)
   152  
   153  	hashData[0] = '0'
   154  	err = validator.Validate("foo.zip", data, hashData)
   155  	assert.ErrorIs(t, err, ErrChecksumValidationFailed)
   156  }
   157  
   158  // ======= ECDSAValidator ====================================================
   159  
   160  func TestECDSAValidatorNoPublicKey(t *testing.T) {
   161  	validator := &ECDSAValidator{
   162  		PublicKey: nil,
   163  	}
   164  	data, err := os.ReadFile("testdata/foo.zip")
   165  	require.NoError(t, err)
   166  
   167  	signatureData, err := os.ReadFile("testdata/foo.zip.sig")
   168  	require.NoError(t, err)
   169  
   170  	err = validator.Validate("foo.zip", data, signatureData)
   171  	assert.EqualError(t, err, ErrECDSAValidationFailed.Error())
   172  }
   173  
   174  func TestECDSAValidatorEmptySignature(t *testing.T) {
   175  	validator := &ECDSAValidator{
   176  		PublicKey: getTestPublicKey(t),
   177  	}
   178  	data, err := os.ReadFile("testdata/foo.zip")
   179  	require.NoError(t, err)
   180  
   181  	err = validator.Validate("foo.zip", data, nil)
   182  	assert.EqualError(t, err, ErrInvalidECDSASignature.Error())
   183  }
   184  
   185  func TestECDSAValidator(t *testing.T) {
   186  	validator := &ECDSAValidator{
   187  		PublicKey: getTestPublicKey(t),
   188  	}
   189  	data, err := os.ReadFile("testdata/foo.zip")
   190  	require.NoError(t, err)
   191  
   192  	signatureData, err := os.ReadFile("testdata/foo.zip.sig")
   193  	require.NoError(t, err)
   194  
   195  	err = validator.Validate("foo.zip", data, signatureData)
   196  	assert.NoError(t, err)
   197  }
   198  
   199  func TestECDSAValidatorWithKeyFromPem(t *testing.T) {
   200  	pemData, err := os.ReadFile("testdata/Test.crt")
   201  	require.NoError(t, err)
   202  
   203  	validator := new(ECDSAValidator).WithPublicKey(pemData)
   204  	assert.True(t, getTestPublicKey(t).Equal(validator.PublicKey))
   205  
   206  	assert.PanicsWithError(t, "failed to decode PEM block", func() {
   207  		new(ECDSAValidator).WithPublicKey([]byte{})
   208  	})
   209  
   210  	assert.PanicsWithError(t, "failed to parse certificate in PEM block: x509: malformed certificate", func() {
   211  		new(ECDSAValidator).WithPublicKey([]byte(`
   212  -----BEGIN CERTIFICATE-----
   213  
   214  -----END CERTIFICATE-----
   215  `))
   216  	})
   217  }
   218  
   219  func TestECDSAValidatorFail(t *testing.T) {
   220  	validator := &ECDSAValidator{
   221  		PublicKey: getTestPublicKey(t),
   222  	}
   223  	data, err := os.ReadFile("testdata/foo.tar.xz")
   224  	require.NoError(t, err)
   225  
   226  	signatureData, err := os.ReadFile("testdata/foo.zip.sig")
   227  	require.NoError(t, err)
   228  
   229  	err = validator.Validate("foo.tar.xz", data, signatureData)
   230  	assert.EqualError(t, err, ErrECDSAValidationFailed.Error())
   231  }
   232  
   233  func getTestPublicKey(t *testing.T) *ecdsa.PublicKey {
   234  	pemData, err := os.ReadFile("testdata/Test.crt")
   235  	require.NoError(t, err)
   236  
   237  	block, _ := pem.Decode(pemData)
   238  	if block == nil || block.Type != "CERTIFICATE" {
   239  		t.Fatalf("failed to decode PEM block")
   240  	}
   241  
   242  	cert, err := x509.ParseCertificate(block.Bytes)
   243  	require.NoError(t, err)
   244  
   245  	pubKey, ok := cert.PublicKey.(*ecdsa.PublicKey)
   246  	if !ok {
   247  		t.Errorf("PublicKey is not ECDSA")
   248  	}
   249  	return pubKey
   250  }
   251  
   252  // ======= PGPValidator ======================================================
   253  
   254  func TestPGPValidator(t *testing.T) {
   255  	data, err := os.ReadFile("testdata/foo.zip")
   256  	require.NoError(t, err)
   257  
   258  	otherData, err := os.ReadFile("testdata/foo.tar.xz")
   259  	require.NoError(t, err)
   260  
   261  	keyRing, entity := getTestPGPKeyRing(t)
   262  	require.NotNil(t, keyRing)
   263  	require.NotNil(t, entity)
   264  
   265  	var signatureData []byte
   266  	{
   267  		signature := &bytes.Buffer{}
   268  		err = openpgp.ArmoredDetachSign(signature, entity, bytes.NewReader(data), nil)
   269  		require.NoError(t, err)
   270  		signatureData = signature.Bytes()
   271  	}
   272  
   273  	t.Run("NoPublicKey", func(t *testing.T) {
   274  		validator := new(PGPValidator)
   275  		err = validator.Validate("foo.zip", data, signatureData)
   276  		assert.ErrorIs(t, err, ErrPGPKeyRingNotSet)
   277  		err = validator.Validate("foo.zip", data, nil)
   278  		assert.ErrorIs(t, err, ErrPGPKeyRingNotSet)
   279  		err = validator.Validate("foo.zip", data, []byte{})
   280  		assert.ErrorIs(t, err, ErrPGPKeyRingNotSet)
   281  	})
   282  
   283  	t.Run("EmptySignature", func(t *testing.T) {
   284  		validator := new(PGPValidator).WithArmoredKeyRing(keyRing)
   285  		err = validator.Validate("foo.zip", data, nil)
   286  		assert.ErrorIs(t, err, ErrInvalidPGPSignature)
   287  		err = validator.Validate("foo.zip", data, []byte{})
   288  		assert.ErrorIs(t, err, ErrInvalidPGPSignature)
   289  	})
   290  
   291  	t.Run("InvalidSignature", func(t *testing.T) {
   292  		validator := new(PGPValidator).WithArmoredKeyRing(keyRing)
   293  		err = validator.Validate("foo.zip", data, []byte{0, 1, 2})
   294  		assert.ErrorIs(t, err, ErrInvalidPGPSignature)
   295  		err = validator.Validate("foo.zip", data, data)
   296  		assert.ErrorIs(t, err, ErrInvalidPGPSignature)
   297  	})
   298  
   299  	t.Run("ValidSignature", func(t *testing.T) {
   300  		validator := new(PGPValidator).WithArmoredKeyRing(keyRing)
   301  		err = validator.Validate("foo.zip", data, signatureData)
   302  		assert.NoError(t, err)
   303  	})
   304  
   305  	t.Run("Fail", func(t *testing.T) {
   306  		validator := new(PGPValidator).WithArmoredKeyRing(keyRing)
   307  		err = validator.Validate("foo.tar.xz", otherData, signatureData)
   308  		assert.EqualError(t, err, "openpgp: invalid signature: hash tag doesn't match")
   309  	})
   310  }
   311  
   312  func TestPGPValidatorWithArmoredKeyRing(t *testing.T) {
   313  	keyRing, entity := getTestPGPKeyRing(t)
   314  	validator := new(PGPValidator).WithArmoredKeyRing(keyRing)
   315  	assert.Equal(t, entity.PrimaryKey.KeyIdString(), validator.KeyRing[0].PrimaryKey.KeyIdString())
   316  
   317  	assert.PanicsWithError(t, "failed setting armored public key ring: openpgp: invalid argument: no armored data found", func() {
   318  		new(PGPValidator).WithArmoredKeyRing([]byte{})
   319  	})
   320  }
   321  
   322  func getTestPGPKeyRing(t *testing.T) (PGPKeyRing []byte, entity *openpgp.Entity) {
   323  	var err error
   324  	entity, err = openpgp.NewEntity("go-selfupdate", "", "info@go-selfupdate.local", nil)
   325  
   326  	buffer := &bytes.Buffer{}
   327  	if armoredWriter, err := armor.Encode(buffer, openpgp.PublicKeyType, nil); err == nil {
   328  		if err = entity.Serialize(armoredWriter); err == nil {
   329  			err = armoredWriter.Close()
   330  		}
   331  	}
   332  	require.NoError(t, err)
   333  	PGPKeyRing = buffer.Bytes()
   334  	return
   335  }
   336  
   337  // ======= ChecksumValidator ====================================================
   338  
   339  func TestChecksumValidatorEmptyFile(t *testing.T) {
   340  	data, err := os.ReadFile("testdata/foo.zip")
   341  	require.NoError(t, err)
   342  
   343  	validator := &ChecksumValidator{}
   344  	err = validator.Validate("foo.zip", data, nil)
   345  	assert.EqualError(t, err, ErrHashNotFound.Error())
   346  }
   347  
   348  func TestChecksumValidatorInvalidChecksumFile(t *testing.T) {
   349  	data, err := os.ReadFile("testdata/foo.zip")
   350  	require.NoError(t, err)
   351  
   352  	validator := &ChecksumValidator{}
   353  	err = validator.Validate("foo.zip", data, []byte("blahblahblah"))
   354  	assert.EqualError(t, err, ErrIncorrectChecksumFile.Error())
   355  }
   356  
   357  func TestChecksumValidatorWithUniqueLine(t *testing.T) {
   358  	data, err := os.ReadFile("testdata/foo.zip")
   359  	require.NoError(t, err)
   360  
   361  	hashData, err := os.ReadFile("testdata/foo.zip.sha256")
   362  	require.NoError(t, err)
   363  
   364  	validator := &ChecksumValidator{}
   365  	err = validator.Validate("foo.zip", data, hashData)
   366  	require.NoError(t, err)
   367  }
   368  
   369  func TestChecksumValidatorWillFailWithWrongHash(t *testing.T) {
   370  	data, err := os.ReadFile("testdata/foo.tar.xz")
   371  	require.NoError(t, err)
   372  
   373  	hashData, err := os.ReadFile("testdata/foo.zip.sha256")
   374  	require.NoError(t, err)
   375  
   376  	validator := &ChecksumValidator{}
   377  	err = validator.Validate("foo.zip", data, hashData)
   378  	assert.ErrorIs(t, err, ErrChecksumValidationFailed)
   379  }
   380  
   381  func TestChecksumNotFound(t *testing.T) {
   382  	data, err := os.ReadFile("testdata/bar-not-found.zip")
   383  	require.NoError(t, err)
   384  
   385  	hashData, err := os.ReadFile("testdata/SHA256SUM")
   386  	require.NoError(t, err)
   387  
   388  	validator := &ChecksumValidator{}
   389  	err = validator.Validate("bar-not-found.zip", data, hashData)
   390  	assert.EqualError(t, err, ErrHashNotFound.Error())
   391  }
   392  
   393  func TestChecksumValidatorSuccess(t *testing.T) {
   394  	data, err := os.ReadFile("testdata/foo.tar.xz")
   395  	require.NoError(t, err)
   396  
   397  	hashData, err := os.ReadFile("testdata/SHA256SUM")
   398  	require.NoError(t, err)
   399  
   400  	validator := &ChecksumValidator{"SHA256SUM"}
   401  	err = validator.Validate("foo.tar.xz", data, hashData)
   402  	assert.NoError(t, err)
   403  }
   404  
   405  // ======= Utilities =========================================================
   406  
   407  func TestHexStringEquals(t *testing.T) {
   408  	tests := []struct {
   409  		equal bool
   410  		size  int
   411  		a, b  string
   412  		err   error
   413  	}{
   414  		{true, 0, "", "", nil},
   415  		{true, 1, "b1", "b1", nil},
   416  		{true, 1, "b1", "B1", nil},
   417  		{true, 2, "b1AA", "B1aa", nil},
   418  		{false, 1, "", "", nil},
   419  		{false, 0, "b1", "b1", nil},
   420  		{false, 0, "b", "b", nil},
   421  		{false, 1, "b", "b", nil},
   422  		{false, 1, "b2", "b1", nil},
   423  		{false, 2, "b1", "b1", nil},
   424  		{false, 2, "b1AA", "aab1", nil},
   425  		{false, 3, "b1", "b1", nil},
   426  		{false, 2, "aaXX", "aaXX", hex.InvalidByteError('X')},
   427  	}
   428  	for i, test := range tests {
   429  		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
   430  			equal, err := hexStringEquals(test.size, test.a, test.b)
   431  			assert.Equal(t, test.equal, equal)
   432  			assert.ErrorIs(t, err, test.err)
   433  		})
   434  	}
   435  }
   436  
   437  func TestNewChecksumWithECDSAValidator(t *testing.T) {
   438  	pemData, err := os.ReadFile("testdata/Test.crt")
   439  	require.NoError(t, err)
   440  
   441  	validator := NewChecksumWithECDSAValidator("checksums", pemData)
   442  	assert.Implements(t, (*RecursiveValidator)(nil), validator)
   443  	assert.Equal(t, "checksums", validator.GetValidationAssetName("anything"))
   444  	assert.Equal(t, "checksums.sig", validator.GetValidationAssetName("checksums"))
   445  }
   446  
   447  func TestNewChecksumWithPGPValidator(t *testing.T) {
   448  	keyRing, _ := getTestPGPKeyRing(t)
   449  
   450  	validator := NewChecksumWithPGPValidator("checksums", keyRing)
   451  	assert.Implements(t, (*RecursiveValidator)(nil), validator)
   452  	assert.Equal(t, "checksums", validator.GetValidationAssetName("anything"))
   453  	assert.Equal(t, "checksums.asc", validator.GetValidationAssetName("checksums"))
   454  }