github.com/google/trillian-examples@v0.0.0-20240520080811-0d40d35cef0e/binary_transparency/firmware/internal/client/client_test.go (about)

     1  // Copyright 2020 Google LLC. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package client_test
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"io"
    21  	"mime"
    22  	"mime/multipart"
    23  	"net/http"
    24  	"net/http/httptest"
    25  	"net/url"
    26  	"strings"
    27  	"testing"
    28  
    29  	"github.com/google/go-cmp/cmp"
    30  	"github.com/google/trillian-examples/binary_transparency/firmware/api"
    31  	"github.com/google/trillian-examples/binary_transparency/firmware/internal/client"
    32  	"github.com/google/trillian-examples/binary_transparency/firmware/internal/crypto"
    33  	"github.com/transparency-dev/formats/log"
    34  	"golang.org/x/mod/sumdb/note"
    35  )
    36  
    37  func mustSignCPNote(t *testing.T, b string) []byte {
    38  	t.Helper()
    39  	s, err := note.NewSigner(crypto.TestFTPersonalityPriv)
    40  	if err != nil {
    41  		t.Fatalf("failed to create signer: %q", err)
    42  	}
    43  	n, err := note.Sign(&note.Note{Text: b}, s)
    44  	if err != nil {
    45  		t.Fatalf("failed to sign note: %q", err)
    46  	}
    47  	return n
    48  }
    49  
    50  func mustGetLogSigVerifier(t *testing.T) note.Verifier {
    51  	t.Helper()
    52  	v, err := note.NewVerifier(crypto.TestFTPersonalityPub)
    53  	if err != nil {
    54  		t.Fatalf("failed to create verifier: %q", err)
    55  	}
    56  	return v
    57  }
    58  
    59  func TestPublish(t *testing.T) {
    60  	for _, test := range []struct {
    61  		desc     string
    62  		manifest []byte
    63  		image    []byte
    64  		wantErr  bool
    65  	}{
    66  		{
    67  			desc:     "valid",
    68  			manifest: []byte("Boo!"),
    69  		}, {
    70  			desc:     "log server fails",
    71  			manifest: []byte("Boo!"),
    72  			wantErr:  true,
    73  		},
    74  	} {
    75  		t.Run(test.desc, func(t *testing.T) {
    76  			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    77  				// Check for path prefix, trimming off leading / since it's not present in the
    78  				// const.
    79  				// TODO Add an index to the test case to improve coverage
    80  				if !strings.HasPrefix(r.URL.Path[1:], api.HTTPAddFirmware) {
    81  					t.Fatalf("Got unexpected HTTP request on %q", r.URL.Path)
    82  				}
    83  
    84  				if test.wantErr {
    85  					http.Error(w, "BOOM", http.StatusInternalServerError)
    86  					return
    87  				}
    88  
    89  				meta, _, err := parseAddFirmwareRequest(r)
    90  				if err != nil {
    91  					t.Fatalf("Failed to read multipart body: %v", err)
    92  				}
    93  				if diff := cmp.Diff(meta, test.manifest); len(diff) != 0 {
    94  					t.Errorf("POSTed body with unexpected diff: %v", diff)
    95  				}
    96  			}))
    97  			defer ts.Close()
    98  
    99  			tsURL, err := url.Parse((ts.URL))
   100  			if err != nil {
   101  				t.Fatalf("Failed to parse test server URL: %v", err)
   102  			}
   103  			c := client.SubmitClient{ReadonlyClient: &client.ReadonlyClient{LogURL: tsURL}}
   104  			err = c.PublishFirmware(test.manifest, test.image)
   105  			switch {
   106  			case err != nil && !test.wantErr:
   107  				t.Fatalf("Got unexpected error %q", err)
   108  			case err == nil && test.wantErr:
   109  				t.Fatal("Got no error, but wanted error")
   110  			case err != nil && test.wantErr:
   111  				// expected error
   112  			default:
   113  			}
   114  		})
   115  	}
   116  }
   117  
   118  // parseAddFirmwareRequest returns the bytes for the SignedStatement, and the firmware image respectively.
   119  // TODO(mhutchinson): For now this is a copy of the server code. de-dupe this.
   120  func parseAddFirmwareRequest(r *http.Request) ([]byte, []byte, error) {
   121  	h := r.Header["Content-Type"]
   122  	if len(h) == 0 {
   123  		return nil, nil, fmt.Errorf("no content-type header")
   124  	}
   125  
   126  	mediaType, mediaParams, err := mime.ParseMediaType(h[0])
   127  	if err != nil {
   128  		return nil, nil, err
   129  	}
   130  	if !strings.HasPrefix(mediaType, "multipart/") {
   131  		return nil, nil, fmt.Errorf("expecting mime multipart body")
   132  	}
   133  	boundary := mediaParams["boundary"]
   134  	if len(boundary) == 0 {
   135  		return nil, nil, fmt.Errorf("invalid mime multipart header - no boundary specified")
   136  	}
   137  	mr := multipart.NewReader(r.Body, boundary)
   138  
   139  	// Get firmware statement (JSON)
   140  	p, err := mr.NextPart()
   141  	if err != nil {
   142  		return nil, nil, fmt.Errorf("failed to find firmware statement in request body: %v", err)
   143  	}
   144  	rawJSON, err := io.ReadAll(p)
   145  	if err != nil {
   146  		return nil, nil, fmt.Errorf("failed to read body of firmware statement: %v", err)
   147  	}
   148  
   149  	// Get firmware binary image
   150  	p, err = mr.NextPart()
   151  	if err != nil {
   152  		return nil, nil, fmt.Errorf("failed to find firmware image in request body: %v", err)
   153  	}
   154  	image, err := io.ReadAll(p)
   155  	if err != nil {
   156  		return nil, nil, fmt.Errorf("failed to read body of firmware image: %v", err)
   157  	}
   158  	return rawJSON, image, nil
   159  }
   160  
   161  func TestGetCheckpoint(t *testing.T) {
   162  	for _, test := range []struct {
   163  		desc    string
   164  		body    []byte
   165  		want    api.LogCheckpoint
   166  		wantErr bool
   167  	}{
   168  		{
   169  			desc: "valid 1",
   170  			body: mustSignCPNote(t, "Firmware Transparency Log\n1\nEjQ=\n123\n"),
   171  			want: api.LogCheckpoint{
   172  				Checkpoint: log.Checkpoint{
   173  					Origin: "Firmware Transparency Log",
   174  					Size:   1,
   175  					Hash:   []byte{0x12, 0x34},
   176  				},
   177  				TimestampNanos: 123,
   178  			},
   179  		}, {
   180  			desc: "valid 2",
   181  			body: mustSignCPNote(t, "Firmware Transparency Log\n10\nNBI=\n1230\n"),
   182  			want: api.LogCheckpoint{
   183  				Checkpoint: log.Checkpoint{
   184  					Origin: "Firmware Transparency Log",
   185  					Size:   10,
   186  					Hash:   []byte{0x34, 0x12},
   187  				},
   188  				TimestampNanos: 1230,
   189  			},
   190  		}, {
   191  			desc:    "garbage",
   192  			body:    []byte(`garbage`),
   193  			wantErr: true,
   194  		},
   195  	} {
   196  		t.Run(test.desc, func(t *testing.T) {
   197  			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   198  				if !strings.HasSuffix(r.URL.Path, api.HTTPGetRoot) {
   199  					t.Fatalf("Got unexpected HTTP request on %q", r.URL.Path)
   200  				}
   201  				fmt.Fprint(w, string(test.body))
   202  			}))
   203  			defer ts.Close()
   204  
   205  			tsURL, err := url.Parse((ts.URL))
   206  			if err != nil {
   207  				t.Fatalf("Failed to parse test server URL: %v", err)
   208  			}
   209  			c := client.ReadonlyClient{
   210  				LogURL:         tsURL,
   211  				LogSigVerifier: mustGetLogSigVerifier(t),
   212  			}
   213  			cp, err := c.GetCheckpoint()
   214  			switch {
   215  			case err != nil && !test.wantErr:
   216  				t.Fatalf("Got unexpected error %q", err)
   217  			case err == nil && test.wantErr:
   218  				t.Fatal("Got no error, but wanted error")
   219  			case err != nil && test.wantErr:
   220  				// expected error
   221  			default:
   222  				// Ignore the envelope data:
   223  				cp.Envelope = nil
   224  				if d := cmp.Diff(*cp, test.want); len(d) != 0 {
   225  					t.Fatalf("Got checkpoint with diff: %s", d)
   226  				}
   227  			}
   228  		})
   229  	}
   230  }
   231  
   232  func TestGetInclusion(t *testing.T) {
   233  	cp := api.LogCheckpoint{
   234  		Checkpoint: log.Checkpoint{
   235  			Size: 30,
   236  		},
   237  	}
   238  	for _, test := range []struct {
   239  		desc    string
   240  		body    string
   241  		want    api.InclusionProof
   242  		wantErr bool
   243  	}{
   244  		{
   245  			desc: "valid 1",
   246  			body: `{ "LeafIndex": 2, "Proof": ["qg==", "uw==", "zA=="]}`,
   247  			want: api.InclusionProof{LeafIndex: 2, Proof: [][]byte{{0xAA}, {0xBB}, {0xCC}}},
   248  		}, {
   249  			desc: "valid 2",
   250  			body: `{ "LeafIndex": 20, "Proof": ["3Q==", "7g=="]}`,
   251  			want: api.InclusionProof{LeafIndex: 20, Proof: [][]byte{{0xDD}, {0xEE}}},
   252  		}, {
   253  			desc:    "garbage",
   254  			body:    `garbage`,
   255  			wantErr: true,
   256  		},
   257  	} {
   258  		t.Run(test.desc, func(t *testing.T) {
   259  			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   260  				// Check for path prefix, trimming off leading / since it's not present in the
   261  				// const.
   262  				// TODO Add an index to the test case to improve coverage
   263  				if !strings.HasPrefix(r.URL.Path[1:], api.HTTPGetInclusion) {
   264  					t.Fatalf("Got unexpected HTTP request on %q", r.URL.Path)
   265  				}
   266  				fmt.Fprintln(w, test.body)
   267  			}))
   268  			defer ts.Close()
   269  
   270  			tsURL, err := url.Parse((ts.URL))
   271  			if err != nil {
   272  				t.Fatalf("Failed to parse test server URL: %v", err)
   273  			}
   274  			c := client.ReadonlyClient{LogURL: tsURL}
   275  			ip, err := c.GetInclusion([]byte{}, cp)
   276  			switch {
   277  			case err != nil && !test.wantErr:
   278  				t.Fatalf("Got unexpected error %q", err)
   279  			case err == nil && test.wantErr:
   280  				t.Fatal("Got no error, but wanted error")
   281  			case err != nil && test.wantErr:
   282  				// expected error
   283  			default:
   284  				if d := cmp.Diff(ip, test.want); len(d) != 0 {
   285  					t.Fatalf("Got checkpoint with diff: %s", d)
   286  				}
   287  			}
   288  		})
   289  	}
   290  }
   291  
   292  func TestGetManifestAndProof(t *testing.T) {
   293  	for _, test := range []struct {
   294  		desc    string
   295  		body    string
   296  		want    api.InclusionProof
   297  		wantErr bool
   298  	}{
   299  		{
   300  			desc: "valid 1",
   301  			body: `{ "Value":"EjQ=", "Proof": ["qg==", "uw==", "zA=="]}`,
   302  			want: api.InclusionProof{Value: []byte{0x12, 0x34}, Proof: [][]byte{{0xAA}, {0xBB}, {0xCC}}},
   303  		}, {
   304  			desc: "valid 2",
   305  			body: `{ "Value":"NBI=","Proof": ["3Q==", "7g=="]}`,
   306  			want: api.InclusionProof{Value: []byte{0x34, 0x12}, Proof: [][]byte{{0xDD}, {0xEE}}},
   307  		}, {
   308  			desc:    "garbage",
   309  			body:    `garbage`,
   310  			wantErr: true,
   311  		},
   312  	} {
   313  		t.Run(test.desc, func(t *testing.T) {
   314  			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   315  				// Check for path prefix, trimming off leading / since it's not present in the
   316  				// const.
   317  				// TODO Add an index to the test case to improve coverage
   318  				if !strings.HasPrefix(r.URL.Path[1:], api.HTTPGetManifestEntryAndProof) {
   319  					t.Fatalf("Got unexpected HTTP request on %q", r.URL.Path)
   320  				}
   321  				fmt.Fprintln(w, test.body)
   322  			}))
   323  			defer ts.Close()
   324  
   325  			tsURL, err := url.Parse((ts.URL))
   326  			if err != nil {
   327  				t.Fatalf("Failed to parse test server URL: %v", err)
   328  			}
   329  			c := client.ReadonlyClient{LogURL: tsURL}
   330  			ip, err := c.GetManifestEntryAndProof(api.GetFirmwareManifestRequest{Index: 0, TreeSize: 0})
   331  			switch {
   332  			case err != nil && !test.wantErr:
   333  				t.Fatalf("Got unexpected error %q", err)
   334  			case err == nil && test.wantErr:
   335  				t.Fatal("Got no error, but wanted error")
   336  			case err != nil && test.wantErr:
   337  				// expected error
   338  			default:
   339  				if d := cmp.Diff(*ip, test.want); len(d) != 0 {
   340  					t.Fatalf("Got response with diff: %s", d)
   341  				}
   342  			}
   343  		})
   344  	}
   345  }
   346  
   347  func TestGetConsistency(t *testing.T) {
   348  	for _, test := range []struct {
   349  		desc    string
   350  		body    string
   351  		From    uint64
   352  		To      uint64
   353  		want    api.ConsistencyProof
   354  		wantErr bool
   355  	}{
   356  		{
   357  			desc: "valid 1",
   358  			body: `{"Proof": ["qg==", "uw==", "zA=="]}`,
   359  			From: 0,
   360  			To:   1,
   361  			want: api.ConsistencyProof{Proof: [][]byte{{0xAA}, {0xBB}, {0xCC}}},
   362  		}, {
   363  			desc: "valid 2",
   364  			body: `{"Proof": ["3Q==", "7g=="]}`,
   365  			From: 1,
   366  			To:   2,
   367  			want: api.ConsistencyProof{Proof: [][]byte{{0xDD}, {0xEE}}},
   368  		}, {
   369  			desc:    "garbage",
   370  			body:    `garbage`,
   371  			wantErr: true,
   372  		},
   373  	} {
   374  		t.Run(test.desc, func(t *testing.T) {
   375  			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   376  				// Check for path prefix, trimming off leading / since it's not present in the
   377  				// const.
   378  				// TODO Add an index to the test case to improve coverage
   379  				if !strings.HasPrefix(r.URL.Path[1:], api.HTTPGetConsistency) {
   380  					t.Fatalf("Got unexpected HTTP request on %q", r.URL.Path)
   381  				}
   382  				fmt.Fprintln(w, test.body)
   383  			}))
   384  			defer ts.Close()
   385  
   386  			tsURL, err := url.Parse((ts.URL))
   387  			if err != nil {
   388  				t.Fatalf("Failed to parse test server URL: %v", err)
   389  			}
   390  			c := client.ReadonlyClient{LogURL: tsURL}
   391  			cp, err := c.GetConsistencyProof(api.GetConsistencyRequest{From: test.From, To: test.To})
   392  			switch {
   393  			case err != nil && !test.wantErr:
   394  				t.Fatalf("Got unexpected error %q", err)
   395  			case err == nil && test.wantErr:
   396  				t.Fatal("Got no error, but wanted error")
   397  			case err != nil && test.wantErr:
   398  				// expected error
   399  			default:
   400  				if d := cmp.Diff(*cp, test.want); len(d) != 0 {
   401  					t.Fatalf("Got response with diff: %s", d)
   402  				}
   403  			}
   404  		})
   405  	}
   406  }
   407  
   408  func TestGetFirmwareImage(t *testing.T) {
   409  	knownHash := []byte("knownhash")
   410  	for _, test := range []struct {
   411  		desc      string
   412  		hash      []byte
   413  		body      []byte
   414  		isUnknown bool
   415  		want      []byte
   416  		wantErr   bool
   417  	}{
   418  		{
   419  			desc: "valid",
   420  			hash: knownHash,
   421  			body: []byte("body"),
   422  			want: []byte("body"),
   423  		}, {
   424  			desc:      "not found",
   425  			hash:      []byte("never heard of it"),
   426  			isUnknown: true,
   427  			wantErr:   true,
   428  		},
   429  	} {
   430  		t.Run(test.desc, func(t *testing.T) {
   431  			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   432  				// Check for path prefix, trimming off leading / since it's not present in the
   433  				// const.
   434  				// TODO Add an index to the test case to improve coverage
   435  				if !strings.HasPrefix(r.URL.Path[1:], api.HTTPGetFirmwareImage) {
   436  					t.Fatalf("Got unexpected HTTP request on %q", r.URL.Path)
   437  				}
   438  				if test.isUnknown {
   439  					http.Error(w, "unknown", http.StatusNotFound)
   440  					return
   441  				}
   442  				if _, err := w.Write(test.body); err != nil {
   443  					t.Errorf("w.Write(): %v", err)
   444  				}
   445  			}))
   446  			defer ts.Close()
   447  
   448  			tsURL, err := url.Parse((ts.URL))
   449  			if err != nil {
   450  				t.Fatalf("Failed to parse test server URL: %v", err)
   451  			}
   452  			c := client.ReadonlyClient{LogURL: tsURL}
   453  			img, err := c.GetFirmwareImage(test.hash)
   454  			switch {
   455  			case err != nil && !test.wantErr:
   456  				t.Fatalf("Got unexpected error %q", err)
   457  			case err == nil && test.wantErr:
   458  				t.Fatal("Got no error, but wanted error")
   459  			case err != nil && test.wantErr:
   460  				// expected error
   461  			default:
   462  				if got, want := img, test.want; !bytes.Equal(got, want) {
   463  					t.Fatalf("Got body %q, want %q", got, want)
   464  				}
   465  			}
   466  		})
   467  	}
   468  }