github.com/cilki/sh@v2.6.4+incompatible/shell/source_test.go (about)

     1  // Copyright (c) 2018, Daniel Martí <mvdan@mvdan.cc>
     2  // See LICENSE for licensing information
     3  
     4  package shell
     5  
import (
	"context"
	"fmt"
	"io/ioutil"
	"os"
	"reflect"
	"strings"
	"testing"
	"time"

	"mvdan.cc/sh/expand"
	"mvdan.cc/sh/syntax"

	"github.com/kr/pretty"
)
    20  
    21  var mapTests = []struct {
    22  	in   string
    23  	want map[string]expand.Variable
    24  }{
    25  	{
    26  		"a=x; b=y",
    27  		map[string]expand.Variable{
    28  			"a": {Value: "x"},
    29  			"b": {Value: "y"},
    30  		},
    31  	},
    32  	{
    33  		"a=x; a=y; X=(a b c)",
    34  		map[string]expand.Variable{
    35  			"a": {Value: "y"},
    36  			"X": {Value: []string{"a", "b", "c"}},
    37  		},
    38  	},
    39  	{
    40  		"a=$(echo foo | sed 's/o/a/g')",
    41  		map[string]expand.Variable{
    42  			"a": {Value: "faa"},
    43  		},
    44  	},
    45  }
    46  
    47  var errTests = []struct {
    48  	in   string
    49  	want string
    50  }{
    51  	{
    52  		"a=b; exit 1",
    53  		"exit status 1",
    54  	},
    55  }
    56  
    57  func TestSourceNode(t *testing.T) {
    58  	for i := range mapTests {
    59  		t.Run(fmt.Sprintf("%02d", i), func(t *testing.T) {
    60  			tc := mapTests[i]
    61  			t.Parallel()
    62  			p := syntax.NewParser()
    63  			file, err := p.Parse(strings.NewReader(tc.in), "")
    64  			if err != nil {
    65  				t.Fatal(err)
    66  			}
    67  			got, err := SourceNode(context.Background(), file)
    68  			if err != nil {
    69  				t.Fatal(err)
    70  			}
    71  			if !reflect.DeepEqual(tc.want, got) {
    72  				t.Fatal(strings.Join(pretty.Diff(tc.want, got), "\n"))
    73  			}
    74  		})
    75  	}
    76  }
    77  
    78  func TestSourceNodeErr(t *testing.T) {
    79  	for i := range errTests {
    80  		t.Run(fmt.Sprintf("%02d", i), func(t *testing.T) {
    81  			tc := errTests[i]
    82  			t.Parallel()
    83  			p := syntax.NewParser()
    84  			file, err := p.Parse(strings.NewReader(tc.in), "")
    85  			if err != nil {
    86  				t.Fatal(err)
    87  			}
    88  			_, err = SourceNode(context.Background(), file)
    89  			if err == nil {
    90  				t.Fatal("wanted non-nil error")
    91  			}
    92  			if !strings.Contains(err.Error(), tc.want) {
    93  				t.Fatalf("error %q does not match %q", err, tc.want)
    94  			}
    95  		})
    96  	}
    97  }
    98  
    99  func TestSourceFileContext(t *testing.T) {
   100  	t.Parallel()
   101  	tf, err := ioutil.TempFile("", "sh-shell")
   102  	if err != nil {
   103  		t.Fatal(err)
   104  	}
   105  	defer os.Remove(tf.Name())
   106  	const src = "cat" // block forever
   107  	if _, err := tf.WriteString(src); err != nil {
   108  		t.Fatal(err)
   109  	}
   110  	if err := tf.Close(); err != nil {
   111  		t.Fatal(err)
   112  	}
   113  
   114  	ctx, cancel := context.WithCancel(context.Background())
   115  	cancel()
   116  	errc := make(chan error, 1)
   117  	go func() {
   118  		_, err := SourceFile(ctx, tf.Name())
   119  		errc <- err
   120  	}()
   121  	err = <-errc
   122  	want := "context canceled"
   123  	if err == nil || !strings.Contains(err.Error(), want) {
   124  		t.Fatalf("error %q does not match %q", err, want)
   125  	}
   126  }