github.com/cloudflare/circl@v1.5.0/abe/cpabe/tkn20/tkn20_test.go (about)

     1  package tkn20
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"encoding/json"
     7  	"fmt"
     8  	"os"
     9  	"testing"
    10  )
    11  
// TestCase is the shape the tests in this file decode the entries of
// testdata/policies.json into.
type TestCase struct {
	// Policy is the textual policy that the tests parse with Policy.FromString.
	Policy string
	// Success tells the tests whether decryption is expected to succeed for
	// the combination of Policy and Attrs.
	Success bool
	// Attrs is the map handed to Attributes.FromMap; it is read from the
	// "attributes" JSON key.
	Attrs map[string]string `json:"attributes"`
}
    17  
    18  func TestConcurrentDecryption(t *testing.T) {
    19  	var tests []TestCase
    20  	buf, _ := os.ReadFile("testdata/policies.json")
    21  	err := json.Unmarshal(buf, &tests)
    22  	if err != nil {
    23  		t.Fatal(err)
    24  	}
    25  	msg := []byte("must have the precious")
    26  	for i, test := range tests {
    27  		t.Run(fmt.Sprintf("TestConcurrentDecryption:#%d", i), func(t *testing.T) {
    28  			pk, msk, err := Setup(rand.Reader)
    29  			if err != nil {
    30  				t.Fatal(err)
    31  			}
    32  			policy := Policy{}
    33  			err = policy.FromString(test.Policy)
    34  			if err != nil {
    35  				t.Fatal(err)
    36  			}
    37  			ct, err := pk.Encrypt(rand.Reader, policy, msg)
    38  			if err != nil {
    39  				t.Fatalf("encryption failed: %s", err)
    40  			}
    41  			attrs := Attributes{}
    42  			attrs.FromMap(test.Attrs)
    43  			sk, err := msk.KeyGen(rand.Reader, attrs)
    44  			if err != nil {
    45  				t.Fatalf("key generation failed: %s", err)
    46  			}
    47  			checkResults := func(ct []byte, sk AttributeKey, i int) {
    48  				pt, err := sk.Decrypt(ct)
    49  				if tests[i].Success {
    50  					if err != nil {
    51  						t.Errorf("decryption failed: %s", err)
    52  					}
    53  					if !bytes.Equal(pt, msg) {
    54  						t.Errorf("expected %v, received %v", pt, msg)
    55  					}
    56  				} else {
    57  					if err == nil {
    58  						t.Errorf("decryption should have failed")
    59  					}
    60  				}
    61  			}
    62  			go checkResults(ct, sk, i)
    63  			go checkResults(ct, sk, i)
    64  		})
    65  	}
    66  }
    67  
    68  func TestEndToEndEncryption(t *testing.T) {
    69  	var tests []TestCase
    70  	buf, _ := os.ReadFile("testdata/policies.json")
    71  	err := json.Unmarshal(buf, &tests)
    72  	if err != nil {
    73  		t.Fatal(err)
    74  	}
    75  	msg := []byte("must have the precious")
    76  	for i, test := range tests {
    77  		t.Run(fmt.Sprintf("TestEndToEndEncryption:#%d", i), func(t *testing.T) {
    78  			pk, msk, err := Setup(rand.Reader)
    79  			if err != nil {
    80  				t.Fatal(err)
    81  			}
    82  			policy := Policy{}
    83  			err = policy.FromString(test.Policy)
    84  			if err != nil {
    85  				t.Fatal(err)
    86  			}
    87  			ct, err := pk.Encrypt(rand.Reader, policy, msg)
    88  			if err != nil {
    89  				t.Fatalf("encryption failed: %s", err)
    90  			}
    91  			attrs := Attributes{}
    92  			attrs.FromMap(test.Attrs)
    93  			sk, err := msk.KeyGen(rand.Reader, attrs)
    94  			if err != nil {
    95  				t.Fatalf("key generation failed: %s", err)
    96  			}
    97  			npol := &Policy{}
    98  			if err = npol.ExtractFromCiphertext(ct); err != nil {
    99  				t.Fatalf("extraction failed: %s", err)
   100  			}
   101  			strpol := npol.String()
   102  			npol2 := &Policy{}
   103  			if err = npol2.FromString(strpol); err != nil {
   104  				t.Fatalf("string %s didn't parse: %s", strpol, err)
   105  			}
   106  			sat := policy.Satisfaction(attrs)
   107  			if sat != npol.Satisfaction(attrs) {
   108  				t.Fatalf("extracted policy doesn't match original")
   109  			}
   110  			if sat != npol2.Satisfaction(attrs) {
   111  				t.Fatalf("round tripped policy doesn't match original")
   112  			}
   113  			ctSat := attrs.CouldDecrypt(ct)
   114  			pt, err := sk.Decrypt(ct)
   115  			if test.Success {
   116  				// test case should succeed
   117  				if !sat {
   118  					t.Fatalf("satisfaction failed")
   119  				}
   120  				if !ctSat {
   121  					t.Fatalf("ciphertext satisfaction failed")
   122  				}
   123  				if err != nil {
   124  					t.Fatalf("decryption failed: %s", err)
   125  				}
   126  				if !bytes.Equal(pt, msg) {
   127  					t.Fatalf("expected %v, received %v", pt, msg)
   128  				}
   129  			} else {
   130  				// test case should fail
   131  				if sat {
   132  					t.Fatal("satisfaction should have failed")
   133  				}
   134  				if ctSat {
   135  					t.Fatal("ciphertext satisfaction should have failed")
   136  				}
   137  				if err == nil {
   138  					t.Fatal("decryption should have failed")
   139  				}
   140  			}
   141  		})
   142  	}
   143  }
   144  
   145  func TestMarshal(t *testing.T) {
   146  	pk, msk, err := Setup(rand.Reader)
   147  	if err != nil {
   148  		t.Fatal(err)
   149  	}
   150  
   151  	data, err := pk.MarshalBinary()
   152  	if err != nil {
   153  		t.Fatal(err)
   154  	}
   155  	b := &PublicKey{}
   156  	err = b.UnmarshalBinary(data)
   157  	if err != nil {
   158  		t.Fatal(err)
   159  	}
   160  	if !pk.Equal(b) {
   161  		t.Fatal("PublicKey: failure to roundtrip")
   162  	}
   163  
   164  	data, err = msk.MarshalBinary()
   165  	if err != nil {
   166  		t.Fatal(err)
   167  	}
   168  	c := &SystemSecretKey{}
   169  	err = c.UnmarshalBinary(data)
   170  	if err != nil {
   171  		t.Fatal(err)
   172  	}
   173  	if !msk.Equal(c) {
   174  		t.Fatal("MasterSecretKey: failure to roundtrip")
   175  	}
   176  
   177  	attrs := Attributes{}
   178  	attrs.FromMap(map[string]string{"occupation": "doctor", "country": "US", "age": "16"})
   179  	sk, err := msk.KeyGen(rand.Reader, attrs)
   180  	if err != nil {
   181  		t.Fatal(err)
   182  	}
   183  
   184  	data, err = sk.MarshalBinary()
   185  	if err != nil {
   186  		t.Fatal(err)
   187  	}
   188  	d := AttributeKey{} // don't use pointer to verify unmarshal works with both pointer and not
   189  	err = d.UnmarshalBinary(data)
   190  	if err != nil {
   191  		t.Fatal(err)
   192  	}
   193  	if !sk.Equal(&d) {
   194  		t.Fatal("SecretKey: failure to roundtrip")
   195  	}
   196  }
   197  
   198  func TestPolicyMethods(t *testing.T) {
   199  	policyStr := "(season: fall or season: winter) or (region: alaska and season: summer)"
   200  	policy := Policy{}
   201  	err := policy.FromString(policyStr)
   202  	if err != nil {
   203  		t.Fatal(err)
   204  	}
   205  	expected := map[string][]string{
   206  		"season": {"fall", "winter", "summer"},
   207  		"region": {"alaska"},
   208  	}
   209  	received := policy.ExtractAttributeValuePairs()
   210  	if len(expected) != len(received) {
   211  		t.Fatal("diff lengths")
   212  	}
   213  	for k, vs := range expected {
   214  		vs2, ok := received[k]
   215  		if !ok {
   216  			t.Fatalf("key %s not found in received map", k)
   217  		}
   218  		if len(vs) != len(vs2) {
   219  			t.Fatalf("expected len: %d, received len: %d, for key %s", len(vs), len(vs2), k)
   220  		}
   221  		// compare each value for given key, order doesn't matter
   222  		for _, v := range vs {
   223  			flag := false
   224  			for _, v2 := range vs2 {
   225  				if v == v2 {
   226  					flag = true
   227  					break
   228  				}
   229  			}
   230  			if !flag {
   231  				t.Fatalf("expected and received values differ")
   232  			}
   233  		}
   234  	}
   235  }