github.com/google/trillian-examples@v0.0.0-20240520080811-0d40d35cef0e/binary_transparency/firmware/cmd/ft_personality/internal/http/server_test.go (about)

     1  // Copyright 2020 Google LLC. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package http
    16  
    17  import (
    18  	"encoding/base64"
    19  	"encoding/json"
    20  	"errors"
    21  	"fmt"
    22  	"io"
    23  	"net/http"
    24  	"net/http/httptest"
    25  	"strings"
    26  	"testing"
    27  
    28  	gomock "github.com/golang/mock/gomock"
    29  	"github.com/google/go-cmp/cmp"
    30  	"github.com/google/trillian-examples/binary_transparency/firmware/api"
    31  	"github.com/google/trillian-examples/binary_transparency/firmware/internal/crypto"
    32  	"github.com/google/trillian/types"
    33  	"github.com/gorilla/mux"
    34  	"golang.org/x/mod/sumdb/note"
    35  	"google.golang.org/grpc/codes"
    36  	"google.golang.org/grpc/status"
    37  )
    38  
    39  func TestRoot(t *testing.T) {
    40  	testSigner, _ := note.NewSigner(crypto.TestFTPersonalityPriv)
    41  	testVerifier, _ := note.NewVerifier(crypto.TestFTPersonalityPub)
    42  
    43  	for _, test := range []struct {
    44  		desc     string
    45  		root     types.LogRootV1
    46  		wantBody string
    47  	}{
    48  		{
    49  			desc:     "valid 1",
    50  			root:     types.LogRootV1{TreeSize: 1, TimestampNanos: 123, RootHash: []byte{0x12, 0x34}},
    51  			wantBody: "Firmware Transparency Log\n1\nEjQ=\n123\n",
    52  		}, {
    53  			desc:     "valid 2",
    54  			root:     types.LogRootV1{TreeSize: 10, TimestampNanos: 1230, RootHash: []byte{0x34, 0x12}},
    55  			wantBody: "Firmware Transparency Log\n10\nNBI=\n1230\n",
    56  		},
    57  	} {
    58  		t.Run(test.desc, func(t *testing.T) {
    59  			ctrl := gomock.NewController(t)
    60  			mt := NewMockTrillian(ctrl)
    61  			server := NewServer(mt, FakeCAS{}, testSigner)
    62  
    63  			mt.EXPECT().Root().Return(&test.root)
    64  
    65  			ts := httptest.NewServer(http.HandlerFunc(server.getRoot))
    66  			defer ts.Close()
    67  
    68  			client := ts.Client()
    69  			resp, err := client.Get(ts.URL)
    70  			if err != nil {
    71  				t.Errorf("error response: %v", err)
    72  			}
    73  			if resp.StatusCode != http.StatusOK {
    74  				t.Errorf("status code not OK: %v", resp.StatusCode)
    75  			}
    76  			body, err := io.ReadAll(resp.Body)
    77  			if err != nil {
    78  				t.Errorf("failed to read body: %v", err)
    79  			}
    80  			got, err := note.Open(body, note.VerifierList(testVerifier))
    81  			if err != nil {
    82  				t.Fatalf("Failed to open returned checkpoint: %q", err)
    83  			}
    84  			if got := got.Text; got != test.wantBody {
    85  				t.Errorf("got '%s' want '%s'", got, test.wantBody)
    86  			}
    87  		})
    88  	}
    89  }
    90  
    91  func b64Decode(t *testing.T, b64 string) []byte {
    92  	t.Helper()
    93  	st, err := base64.StdEncoding.DecodeString(b64)
    94  	if err != nil {
    95  		t.Fatalf("b64 decoding failed: %v", err)
    96  	}
    97  	return st
    98  }
    99  func TestAddFirmware(t *testing.T) {
   100  	testSigner, _ := note.NewSigner(crypto.TestFTPersonalityPriv)
   101  	st := b64Decode(t, "eyJEZXZpY2VJRCI6IlRhbGtpZVRvYXN0ZXIiLCJGaXJtd2FyZVJldmlzaW9uIjoxLCJGaXJtd2FyZUltYWdlU0hBNTEyIjoiMTRxN0JVSnphR1g1UndSU0ZnbkNNTnJBT2k4Mm5RUTZ3aExXa3p1UlFRNEdPWjQzK2NYTWlFTnFNWE56TU1ISTdNc3NMNTgzVFdMM0ZrTXFNdFVQckE9PSIsIkV4cGVjdGVkRmlybXdhcmVNZWFzdXJlbWVudCI6IiIsIkJ1aWxkVGltZXN0YW1wIjoiMjAyMC0xMS0xN1QxMzozMDoxNFoifQ==")
   102  	sign, err := crypto.Publisher.SignMessage(api.FirmwareMetadataType, st)
   103  	if err != nil {
   104  		t.Fatalf("signing failed, bailing out!: %v", err)
   105  	}
   106  	statement := api.SignedStatement{Type: api.FirmwareMetadataType, Statement: st, Signature: sign}
   107  	js, err := json.Marshal(statement)
   108  	if err != nil {
   109  		t.Fatalf("marshaling failed, bailing out!: %v", err)
   110  	}
   111  
   112  	s := string(js)
   113  	for _, test := range []struct {
   114  		desc             string
   115  		body             string
   116  		trillianErr      error
   117  		wantTrillianCall bool
   118  		wantManifest     string
   119  		wantStatus       int
   120  	}{
   121  		{
   122  			desc:       "malformed request",
   123  			body:       "garbage",
   124  			wantStatus: http.StatusBadRequest,
   125  		}, {
   126  			desc: "valid request",
   127  			body: strings.Join([]string{"--mimeisfunlolol",
   128  				"Content-Type: application/json",
   129  				"",
   130  				s,
   131  				"--mimeisfunlolol",
   132  				"Content-Type: application/octet-stream",
   133  				"",
   134  				"hi",
   135  				"",
   136  				"--mimeisfunlolol--",
   137  				"",
   138  			}, "\n"),
   139  			wantTrillianCall: true,
   140  			wantManifest:     s,
   141  			wantStatus:       http.StatusOK,
   142  		}, {
   143  			desc: "firmware image does not match manifest",
   144  			body: strings.Join([]string{"--mimeisfunlolol",
   145  				"Content-Type: application/json",
   146  				"",
   147  				s,
   148  				"--mimeisfunlolol",
   149  				"Content-Type: application/octet-stream",
   150  				"",
   151  				"THIS HAS A DIFFERENT HASH THAN EXPECTED",
   152  				"",
   153  				"--mimeisfunlolol--",
   154  				"",
   155  			}, "\n"),
   156  			wantManifest: s,
   157  			wantStatus:   http.StatusBadRequest,
   158  		}, {
   159  			desc: "valid request but trillian failure",
   160  			body: strings.Join([]string{"--mimeisfunlolol",
   161  				"Content-Type: application/json",
   162  				"",
   163  				s,
   164  				"--mimeisfunlolol",
   165  				"Content-Type: application/octet-stream",
   166  				"",
   167  				"hi",
   168  				"",
   169  				"--mimeisfunlolol--",
   170  				"",
   171  			}, "\n"),
   172  			wantTrillianCall: true,
   173  			wantManifest:     s,
   174  			trillianErr:      errors.New("boom"),
   175  			wantStatus:       http.StatusInternalServerError,
   176  		},
   177  	} {
   178  		t.Run(test.desc, func(t *testing.T) {
   179  			ctrl := gomock.NewController(t)
   180  			mt := NewMockTrillian(ctrl)
   181  			server := NewServer(mt, FakeCAS{}, testSigner)
   182  
   183  			if test.wantTrillianCall {
   184  				mt.EXPECT().AddSignedStatement(gomock.Any(), gomock.Eq([]byte(test.wantManifest))).
   185  					Return(test.trillianErr)
   186  			}
   187  
   188  			r := mux.NewRouter()
   189  			server.RegisterHandlers(r)
   190  			ts := httptest.NewServer(r)
   191  			defer ts.Close()
   192  
   193  			client := ts.Client()
   194  			url := fmt.Sprintf("%s/%s", ts.URL, api.HTTPAddFirmware)
   195  			resp, err := client.Post(url, "multipart/form-data; boundary=mimeisfunlolol", strings.NewReader(test.body))
   196  			if err != nil {
   197  				t.Errorf("error response: %v", err)
   198  			}
   199  			if got, want := resp.StatusCode, test.wantStatus; got != want {
   200  				body, _ := io.ReadAll(resp.Body)
   201  				t.Errorf("status code got != want (%d, %d): %q", got, want, body)
   202  			}
   203  		})
   204  	}
   205  }
   206  
   207  func TestGetConsistency(t *testing.T) {
   208  	testSigner, _ := note.NewSigner(crypto.TestFTPersonalityPriv)
   209  	root := types.LogRootV1{TreeSize: 24, TimestampNanos: 123, RootHash: []byte{0x12, 0x34}}
   210  	for _, test := range []struct {
   211  		desc             string
   212  		from, to         int
   213  		wantFrom, wantTo uint64
   214  		trillianProof    [][]byte
   215  		trillianErr      error
   216  		wantStatus       int
   217  		wantBody         string
   218  	}{
   219  		{
   220  			desc:          "valid request",
   221  			from:          1,
   222  			to:            24,
   223  			wantFrom:      1,
   224  			wantTo:        24,
   225  			trillianProof: [][]byte{[]byte("pr"), []byte("oo"), []byte("f!")},
   226  			wantStatus:    http.StatusOK,
   227  			wantBody:      `{"Proof":["cHI=","b28=","ZiE="]}`,
   228  		}, {
   229  			desc:       "ToSize bigger than tree size",
   230  			from:       1,
   231  			to:         25,
   232  			wantStatus: http.StatusBadRequest,
   233  		}, {
   234  			desc:       "FromSize too large",
   235  			from:       15,
   236  			to:         12,
   237  			wantStatus: http.StatusBadRequest,
   238  		}, {
   239  			desc:        "valid request but trillian failure",
   240  			from:        11,
   241  			to:          15,
   242  			wantFrom:    11,
   243  			wantTo:      15,
   244  			trillianErr: errors.New("boom"),
   245  			wantStatus:  http.StatusInternalServerError,
   246  		},
   247  	} {
   248  		t.Run(test.desc, func(t *testing.T) {
   249  			ctrl := gomock.NewController(t)
   250  			mt := NewMockTrillian(ctrl)
   251  			server := NewServer(mt, FakeCAS{}, testSigner)
   252  			mt.EXPECT().Root().AnyTimes().
   253  				Return(&root)
   254  
   255  			if test.trillianProof != nil || test.trillianErr != nil {
   256  				mt.EXPECT().ConsistencyProof(gomock.Any(), gomock.Eq(test.wantFrom), gomock.Eq(test.wantTo)).
   257  					Return(test.trillianProof, test.trillianErr)
   258  			}
   259  
   260  			r := mux.NewRouter()
   261  			server.RegisterHandlers(r)
   262  			ts := httptest.NewServer(r)
   263  			defer ts.Close()
   264  			url := fmt.Sprintf("%s/%s/from/%d/to/%d", ts.URL, api.HTTPGetConsistency, test.from, test.to)
   265  
   266  			client := ts.Client()
   267  			resp, err := client.Get(url)
   268  			if err != nil {
   269  				t.Errorf("error response: %v", err)
   270  			}
   271  			if got, want := resp.StatusCode, test.wantStatus; got != want {
   272  				t.Errorf("status code got != want (%d, %d)", got, want)
   273  			}
   274  			if len(test.wantBody) > 0 {
   275  				body, err := io.ReadAll(resp.Body)
   276  				if err != nil {
   277  					t.Errorf("failed to read body: %v", err)
   278  				}
   279  				if got, want := string(body), test.wantBody; got != test.wantBody {
   280  					t.Errorf("got '%s' want '%s'", got, want)
   281  				}
   282  			}
   283  		})
   284  	}
   285  }
   286  
   287  func TestGetManifestEntries(t *testing.T) {
   288  	testSigner, _ := note.NewSigner(crypto.TestFTPersonalityPriv)
   289  	root := types.LogRootV1{TreeSize: 24, TimestampNanos: 123, RootHash: []byte{0x12, 0x34}}
   290  	for _, test := range []struct {
   291  		desc                string
   292  		index               int
   293  		treeSize            int
   294  		wantIndex, wantSize uint64
   295  		trillianData        []byte
   296  		trillianProof       [][]byte
   297  		trillianErr         error
   298  		wantStatus          int
   299  		wantBody            string
   300  	}{
   301  		{
   302  			desc:          "valid request",
   303  			index:         1,
   304  			treeSize:      24,
   305  			wantIndex:     1,
   306  			wantSize:      24,
   307  			trillianData:  []byte("leafdata"),
   308  			trillianProof: [][]byte{[]byte("pr"), []byte("oo"), []byte("f!")},
   309  			wantStatus:    http.StatusOK,
   310  			wantBody:      `{"Value":"bGVhZmRhdGE=","LeafIndex":1,"Proof":["cHI=","b28=","ZiE="]}`,
   311  		}, {
   312  			desc:       "TreeSize bigger than golden tree size",
   313  			index:      1,
   314  			treeSize:   29,
   315  			wantStatus: http.StatusBadRequest,
   316  		}, {
   317  			desc:       "LeafIndex larger than tree size",
   318  			index:      1,
   319  			treeSize:   0,
   320  			wantStatus: http.StatusBadRequest,
   321  		}, {
   322  			desc:       "LeafIndex equal to tree size",
   323  			index:      4,
   324  			treeSize:   4,
   325  			wantStatus: http.StatusBadRequest,
   326  		}, {
   327  			desc:        "valid request but trillian failure",
   328  			index:       1,
   329  			treeSize:    24,
   330  			wantIndex:   1,
   331  			wantSize:    24,
   332  			trillianErr: errors.New("boom"),
   333  			wantStatus:  http.StatusInternalServerError,
   334  		},
   335  	} {
   336  		t.Run(test.desc, func(t *testing.T) {
   337  			ctrl := gomock.NewController(t)
   338  			mt := NewMockTrillian(ctrl)
   339  			server := NewServer(mt, FakeCAS{}, testSigner)
   340  
   341  			mt.EXPECT().Root().AnyTimes().
   342  				Return(&root)
   343  
   344  			if test.trillianData != nil || test.trillianErr != nil {
   345  				mt.EXPECT().FirmwareManifestAtIndex(gomock.Any(), gomock.Eq(test.wantIndex), gomock.Eq(test.wantSize)).
   346  					Return(test.trillianData, test.trillianProof, test.trillianErr)
   347  			}
   348  
   349  			r := mux.NewRouter()
   350  			server.RegisterHandlers(r)
   351  			ts := httptest.NewServer(r)
   352  			defer ts.Close()
   353  			url := fmt.Sprintf("%s/%s/at/%d/in-tree-of/%d", ts.URL, api.HTTPGetManifestEntryAndProof, test.index, test.treeSize)
   354  
   355  			client := ts.Client()
   356  			resp, err := client.Get(url)
   357  			if err != nil {
   358  				t.Errorf("error response: %v", err)
   359  			}
   360  			if got, want := resp.StatusCode, test.wantStatus; got != want {
   361  				t.Errorf("status code got != want (%d, %d)", got, want)
   362  			}
   363  			if len(test.wantBody) > 0 {
   364  				body, err := io.ReadAll(resp.Body)
   365  				if err != nil {
   366  					t.Errorf("failed to read body: %v", err)
   367  				}
   368  				if got, want := string(body), test.wantBody; got != test.wantBody {
   369  					t.Errorf("got '%s' want '%s'", got, want)
   370  				}
   371  			}
   372  		})
   373  	}
   374  }
   375  
   376  func TestGetInclusionProofByHash(t *testing.T) {
   377  	testSigner, _ := note.NewSigner(crypto.TestFTPersonalityPriv)
   378  	root := types.LogRootV1{TreeSize: 24, TimestampNanos: 123, RootHash: []byte{0x12, 0x34}}
   379  	for _, test := range []struct {
   380  		desc          string
   381  		hash          []byte
   382  		treeSize      int
   383  		trillianIndex uint64
   384  		trillianProof [][]byte
   385  		trillianErr   error
   386  		wantStatus    int
   387  	}{
   388  		{
   389  			desc:          "valid request",
   390  			hash:          []byte("a good leaf hash"),
   391  			treeSize:      24,
   392  			trillianProof: [][]byte{[]byte("pr"), []byte("oo"), []byte("f!")},
   393  			trillianIndex: 4,
   394  			wantStatus:    http.StatusOK,
   395  		}, {
   396  			desc:       "TreeSize bigger than golden tree size",
   397  			hash:       []byte("a good leaf hash"),
   398  			treeSize:   29,
   399  			wantStatus: http.StatusBadRequest,
   400  		}, {
   401  			desc:        "unknown leafhash",
   402  			hash:        []byte("made up leaf hash"),
   403  			treeSize:    24,
   404  			trillianErr: status.Error(codes.NotFound, "never heard of it, mate"),
   405  			wantStatus:  http.StatusNotFound,
   406  		}, {
   407  			desc:        "valid request but trillian failure",
   408  			hash:        []byte("a good leaf hash"),
   409  			treeSize:    24,
   410  			trillianErr: errors.New("boom"),
   411  			wantStatus:  http.StatusInternalServerError,
   412  		},
   413  	} {
   414  		t.Run(test.desc, func(t *testing.T) {
   415  			ctrl := gomock.NewController(t)
   416  			mt := NewMockTrillian(ctrl)
   417  			server := NewServer(mt, FakeCAS{}, testSigner)
   418  
   419  			mt.EXPECT().Root().AnyTimes().
   420  				Return(&root)
   421  
   422  			if test.trillianProof != nil || test.trillianErr != nil {
   423  				mt.EXPECT().InclusionProofByHash(gomock.Any(), gomock.Eq(test.hash), gomock.Eq(root.TreeSize)).
   424  					Return(test.trillianIndex, test.trillianProof, test.trillianErr)
   425  			}
   426  
   427  			r := mux.NewRouter()
   428  			server.RegisterHandlers(r)
   429  			ts := httptest.NewServer(r)
   430  			defer ts.Close()
   431  			url := fmt.Sprintf("%s/%s/for-leaf-hash/%s/in-tree-of/%d", ts.URL, api.HTTPGetInclusion, base64.URLEncoding.EncodeToString(test.hash), test.treeSize)
   432  
   433  			client := ts.Client()
   434  			resp, err := client.Get(url)
   435  			if err != nil {
   436  				t.Errorf("error response: %v", err)
   437  			}
   438  			if got, want := resp.StatusCode, test.wantStatus; got != want {
   439  				t.Errorf("status code got != want (%d, %d)", got, want)
   440  			}
   441  			if test.wantStatus == http.StatusOK {
   442  				// If we're expecting a good response then check that all values got passed through ok
   443  				body, err := io.ReadAll(resp.Body)
   444  				if err != nil {
   445  					t.Fatalf("failed to read body: %v", err)
   446  				}
   447  				wantBody := api.InclusionProof{
   448  					LeafIndex: test.trillianIndex,
   449  					Proof:     test.trillianProof,
   450  				}
   451  				var gotBody api.InclusionProof
   452  				if err := json.Unmarshal(body, &gotBody); err != nil {
   453  					t.Fatalf("got invalid json response: %q", err)
   454  				}
   455  				if diff := cmp.Diff(gotBody, wantBody); len(diff) > 0 {
   456  					t.Errorf("got response with diff %q", diff)
   457  				}
   458  			}
   459  		})
   460  	}
   461  }
   462  
   463  type FakeCAS map[string][]byte
   464  
   465  func (f FakeCAS) Store(key, image []byte) error {
   466  	f[string(key)] = image
   467  	return nil
   468  }
   469  
   470  func (f FakeCAS) Retrieve(key []byte) ([]byte, error) {
   471  	if image, ok := f[string(key)]; ok {
   472  		return image, nil
   473  	}
   474  	return nil, errors.New("nope")
   475  }