github.com/99designs/gqlgen@v0.17.45/graphql/playground/helper_test.go (about)

     1  package playground
     2  
     3  import (
     4  	"crypto/sha256"
     5  	"encoding/base64"
     6  	"io"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"strings"
    10  	"testing"
    11  
    12  	"github.com/PuerkitoBio/goquery"
    13  	"github.com/stretchr/testify/assert"
    14  )
    15  
    16  func testResourceIntegrity(t *testing.T, handler func(title, endpoint string) http.HandlerFunc) {
    17  	recorder := httptest.NewRecorder()
    18  	request := httptest.NewRequest(http.MethodGet, "http://localhost:8080/", nil)
    19  	handler("example.org API", "/query").ServeHTTP(recorder, request)
    20  
    21  	res := recorder.Result()
    22  	defer assert.NoError(t, res.Body.Close())
    23  
    24  	assert.Equal(t, http.StatusOK, res.StatusCode)
    25  	assert.True(t, strings.HasPrefix(res.Header.Get("Content-Type"), "text/html"))
    26  
    27  	doc, err := goquery.NewDocumentFromReader(res.Body)
    28  	assert.NoError(t, err)
    29  	assert.NotNil(t, doc)
    30  
    31  	var baseUrl string
    32  	if base := doc.Find("base"); len(base.Nodes) != 0 {
    33  		if value, exists := base.Attr("href"); exists {
    34  			baseUrl = value
    35  		}
    36  	}
    37  
    38  	assertNodesIntegrity(t, baseUrl, doc, "script", "src", "integrity")
    39  	assertNodesIntegrity(t, baseUrl, doc, "link", "href", "integrity")
    40  }
    41  
    42  func assertNodesIntegrity(t *testing.T, baseUrl string, doc *goquery.Document, selector string, urlAttrKey, integrityAttrKey string) {
    43  	selection := doc.Find(selector)
    44  	for _, node := range selection.Nodes {
    45  		var url string
    46  		var integrity string
    47  		for _, attribute := range node.Attr {
    48  			if attribute.Key == urlAttrKey {
    49  				url = attribute.Val
    50  			} else if attribute.Key == integrityAttrKey {
    51  				integrity = attribute.Val
    52  			}
    53  		}
    54  
    55  		if len(integrity) != 0 {
    56  			assert.NotEmpty(t, url)
    57  		}
    58  
    59  		if len(url) != 0 && len(integrity) != 0 {
    60  			resp, err := http.Get(baseUrl + url)
    61  			assert.NoError(t, err)
    62  			hasher := sha256.New()
    63  			_, err = io.Copy(hasher, resp.Body)
    64  			assert.NoError(t, err)
    65  			assert.NoError(t, resp.Body.Close())
    66  			actual := "sha256-" + base64.StdEncoding.EncodeToString(hasher.Sum(nil))
    67  			assert.Equal(t, integrity, actual)
    68  		}
    69  	}
    70  }