github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/examples/trustedtypes/server_test.go (about)

     1  // Copyright 2022 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package main
    16  
    17  import (
    18  	"fmt"
    19  	"net/http/httptest"
    20  	"testing"
    21  
    22  	"github.com/google/safehtml/template"
    23  
    24  	"github.com/google/go-safeweb/safehttp"
    25  	"github.com/google/safehtml"
    26  )
    27  
    28  func TestNewMuxConfig(t *testing.T) {
    29  	cf, addr := newServeMuxConfig()
    30  	mux := cf.Mux()
    31  	h := safehttp.HandlerFunc(func(w safehttp.ResponseWriter, r *safehttp.IncomingRequest) safehttp.Result {
    32  		return w.Write(safehtml.HTMLEscaped("<h1>Hello World!</h1>"))
    33  	})
    34  	mux.Handle("/", safehttp.MethodGet, h)
    35  
    36  	req := httptest.NewRequest(safehttp.MethodGet, fmt.Sprintf("http://%s/", addr), nil)
    37  	rw := httptest.NewRecorder()
    38  	mux.ServeHTTP(rw, req)
    39  
    40  	body := rw.Body.String()
    41  	if want := "&lt;h1&gt;Hello World!&lt;/h1&gt;"; body != want {
    42  		t.Errorf("body got: %q want: %q", body, want)
    43  	}
    44  }
    45  
    46  func Test_loadTemplate(t *testing.T) {
    47  	tests := []struct {
    48  		name    string
    49  		src     string
    50  		wantErr bool
    51  	}{
    52  		{
    53  			name:    "existing template, no error",
    54  			src:     "safe.html",
    55  			wantErr: false,
    56  		},
    57  		{
    58  			name:    "missing template, error",
    59  			src:     "not_existing.html",
    60  			wantErr: true,
    61  		},
    62  		{
    63  			name:    "invalid source name, error",
    64  			src:     "../../hidden.html",
    65  			wantErr: true,
    66  		},
    67  	}
    68  	for _, tt := range tests {
    69  		t.Run(tt.name, func(t *testing.T) {
    70  			defer func() {
    71  				if r := recover(); r != nil && !tt.wantErr {
    72  					t.Errorf("unexpected panic: %v", r)
    73  				}
    74  			}()
    75  			_, err := loadTemplate(tt.src)
    76  			if err != nil && !tt.wantErr {
    77  				t.Errorf("loadTemplate() error = %v, wantErr %v", err, tt.wantErr)
    78  				return
    79  			}
    80  		})
    81  	}
    82  }
    83  
    84  func TestHandleTemplate(t *testing.T) {
    85  	mux := safehttp.NewServeMuxConfig(nil).Mux()
    86  	safeTmpl := template.Must(template.New("standard").Parse(`<h1>Hi there!</h1>`))
    87  	mux.Handle("/", safehttp.MethodGet, safehttp.HandlerFunc(handleTemplate(safeTmpl)))
    88  
    89  	req := httptest.NewRequest(safehttp.MethodGet, "/spaghetti", nil)
    90  	rw := httptest.NewRecorder()
    91  	mux.ServeHTTP(rw, req)
    92  
    93  	if body, want := rw.Body.String(), "<h1>Hi there!</h1>"; body != want {
    94  		t.Errorf("handleTemplate() got %q, want %q", body, want)
    95  	}
    96  }