gitee.com/mirrors_u-root/u-root@v7.0.0+incompatible/cmds/core/wget/wget_test.go (about)

     1  // Copyright 2017 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // A parity test can be run:
     6  //     go test
     7  //     EXECPATH="wget -O -" go test
     8  package main
     9  
import (
	"fmt"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"testing"

	"github.com/u-root/u-root/pkg/testutil"
)
    21  
    22  const content = "Very simple web server"
    23  
    24  type handler struct{}
    25  
    26  func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    27  	switch r.URL.Path {
    28  	case "/200":
    29  		w.WriteHeader(200)
    30  		w.Write([]byte(content))
    31  	case "/302":
    32  		http.Redirect(w, r, "/200", http.StatusFound /* 302 */)
    33  	case "/500":
    34  		w.WriteHeader(500)
    35  		w.Write([]byte(content))
    36  	default:
    37  		w.WriteHeader(404)
    38  		w.Write([]byte(content))
    39  	}
    40  }
    41  
    42  var tests = []struct {
    43  	flags   []string // in, %[1]d is the server's port, %[2] is an unopen port
    44  	url     string   // in
    45  	content string   // out
    46  	retCode int      // out
    47  }{
    48  	{
    49  		// basic
    50  		flags:   []string{},
    51  		url:     "http://localhost:%[1]d/200",
    52  		content: content,
    53  		retCode: 0,
    54  	}, {
    55  		// ipv4
    56  		flags:   []string{},
    57  		url:     "http://127.0.0.1:%[1]d/200",
    58  		content: content,
    59  		retCode: 0,
    60  	}, /*{ TODO: travis does not support ipv6
    61  		// ipv6
    62  		flags:   []string{},
    63  		url:     "http://[::1]:%[1]d/200",
    64  		content:  content,
    65  		retCode: 0,
    66  	},*/{
    67  		// redirect
    68  		flags:   []string{},
    69  		url:     "http://localhost:%[1]d/302",
    70  		content: "",
    71  		retCode: 0,
    72  	}, {
    73  		// 4xx error
    74  		flags:   []string{},
    75  		url:     "http://localhost:%[1]d/404",
    76  		content: "",
    77  		retCode: 1,
    78  	}, {
    79  		// 5xx error
    80  		flags:   []string{},
    81  		url:     "http://localhost:%[1]d/500",
    82  		content: "",
    83  		retCode: 1,
    84  	}, {
    85  		// no server
    86  		flags:   []string{},
    87  		url:     "http://localhost:%[2]d/200",
    88  		content: "",
    89  		retCode: 1,
    90  	}, {
    91  		// output file
    92  		flags:   []string{"-O", "/dev/null"},
    93  		url:     "http://localhost:%[1]d/200",
    94  		content: "",
    95  		retCode: 0,
    96  	},
    97  }
    98  
    99  func getFreePort(t *testing.T) int {
   100  	l, err := net.Listen("tcp", ":0")
   101  	if err != nil {
   102  		t.Fatalf("Cannot create free port: %v", err)
   103  	}
   104  	l.Close()
   105  	return l.Addr().(*net.TCPAddr).Port
   106  }
   107  
   108  // TestWget implements a table-driven test.
   109  func TestWget(t *testing.T) {
   110  	// Start a webserver on a free port.
   111  	unusedPort := getFreePort(t)
   112  
   113  	l, err := net.Listen("tcp", ":0")
   114  	if err != nil {
   115  		t.Fatalf("Cannot create free port: %v", err)
   116  	}
   117  	port := l.Addr().(*net.TCPAddr).Port
   118  
   119  	h := handler{}
   120  	go func() {
   121  		log.Fatal(http.Serve(l, h))
   122  	}()
   123  
   124  	for i, tt := range tests {
   125  		args := append(tt.flags, fmt.Sprintf(tt.url, port, unusedPort))
   126  		output, err := testutil.Command(t, args...).CombinedOutput()
   127  
   128  		// Check return code.
   129  		if err := testutil.IsExitCode(err, tt.retCode); err != nil {
   130  			t.Errorf("exit code: %v, output: %s", err, string(output))
   131  		}
   132  
   133  		if tt.content != "" {
   134  			fileName := filepath.Base(tt.url)
   135  			content, err := ioutil.ReadFile(fileName)
   136  			if err != nil {
   137  				t.Errorf("%d. File %s was not created: %v", i, fileName, err)
   138  			}
   139  
   140  			// Check content.
   141  			if string(content) != tt.content {
   142  				t.Errorf("%d. Want:\n%#v\nGot:\n%#v", i, tt.content, string(content))
   143  			}
   144  		}
   145  	}
   146  }
   147  
// TestMain is the entry point of the test binary. It hands m and this
// package's main function to testutil.Run instead of calling m.Run directly.
// NOTE(review): presumably testutil.Run lets the test binary act as the wget
// command itself when testutil.Command re-executes it — confirm against
// pkg/testutil.
func TestMain(m *testing.M) {
	testutil.Run(m, main)
}