github.com/cilki/sh@v2.6.4+incompatible/interp/module_test.go (about)

     1  // Copyright (c) 2017, Daniel Martí <mvdan@mvdan.cc>
     2  // See LICENSE for licensing information
     3  
     4  package interp
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"fmt"
    10  	"io"
    11  	"os"
    12  	"runtime"
    13  	"strings"
    14  	"sync"
    15  	"testing"
    16  	"time"
    17  
    18  	"mvdan.cc/sh/syntax"
    19  )
    20  
    21  var modCases = []struct {
    22  	name string
    23  	exec ModuleExec
    24  	open ModuleOpen
    25  	src  string
    26  	want string
    27  }{
    28  	{
    29  		name: "ExecBlacklist",
    30  		exec: func(ctx context.Context, path string, args []string) error {
    31  			if args[0] == "sleep" {
    32  				return fmt.Errorf("blacklisted: %s", args[0])
    33  			}
    34  			return DefaultExec(ctx, path, args)
    35  		},
    36  		src:  "echo foo; sleep 1",
    37  		want: "foo\nblacklisted: sleep",
    38  	},
    39  	{
    40  		name: "ExecWhitelist",
    41  		exec: func(ctx context.Context, path string, args []string) error {
    42  			switch args[0] {
    43  			case "sed", "grep":
    44  			default:
    45  				return fmt.Errorf("blacklisted: %s", args[0])
    46  			}
    47  			return DefaultExec(ctx, path, args)
    48  		},
    49  		src:  "a=$(echo foo | sed 's/o/a/g'); echo $a; $a args",
    50  		want: "faa\nblacklisted: faa",
    51  	},
    52  	{
    53  		name: "ExecSubshell",
    54  		exec: func(ctx context.Context, path string, args []string) error {
    55  			return fmt.Errorf("blacklisted: %s", args[0])
    56  		},
    57  		src:  "(malicious)",
    58  		want: "blacklisted: malicious",
    59  	},
    60  	{
    61  		name: "ExecPipe",
    62  		exec: func(ctx context.Context, path string, args []string) error {
    63  			return fmt.Errorf("blacklisted: %s", args[0])
    64  		},
    65  		src:  "malicious | echo foo",
    66  		want: "foo\nblacklisted: malicious",
    67  	},
    68  	{
    69  		name: "ExecCmdSubst",
    70  		exec: func(ctx context.Context, path string, args []string) error {
    71  			return fmt.Errorf("blacklisted: %s", args[0])
    72  		},
    73  		src:  "a=$(malicious)",
    74  		want: "blacklisted: malicious",
    75  	},
    76  	{
    77  		name: "ExecBackground",
    78  		exec: func(ctx context.Context, path string, args []string) error {
    79  			return fmt.Errorf("blacklisted: %s", args[0])
    80  		},
    81  		src:  "{ malicious; true; } & { malicious; true; } & wait",
    82  		want: "blacklisted: malicious",
    83  	},
    84  	{
    85  		name: "OpenForbidNonDev",
    86  		open: OpenDevImpls(func(ctx context.Context, path string, flags int, mode os.FileMode) (io.ReadWriteCloser, error) {
    87  			mc, _ := FromModuleContext(ctx)
    88  			return nil, fmt.Errorf("non-dev: %s", mc.UnixPath(path))
    89  		}),
    90  		src:  "echo foo >/dev/null; echo bar >/tmp/x",
    91  		want: "non-dev: /tmp/x",
    92  	},
    93  }
    94  
    95  func TestRunnerModules(t *testing.T) {
    96  	t.Parallel()
    97  	p := syntax.NewParser()
    98  	for _, tc := range modCases {
    99  		t.Run(tc.name, func(t *testing.T) {
   100  			file, err := p.Parse(strings.NewReader(tc.src), "")
   101  			if err != nil {
   102  				t.Fatalf("could not parse: %v", err)
   103  			}
   104  			var cb concBuffer
   105  			r, err := New(StdIO(nil, &cb, &cb),
   106  				Module(tc.exec), Module(tc.open))
   107  			if err != nil {
   108  				t.Fatal(err)
   109  			}
   110  			ctx := context.Background()
   111  			if err := r.Run(ctx, file); err != nil {
   112  				cb.WriteString(err.Error())
   113  			}
   114  			got := cb.String()
   115  			if got != tc.want {
   116  				t.Fatalf("want:\n%s\ngot:\n%s", tc.want, got)
   117  			}
   118  		})
   119  	}
   120  }
   121  
   122  func TestRunnerDefaultModules(t *testing.T) {
   123  	t.Parallel()
   124  	_, err := New(Module(DefaultOpen), Module(DefaultExec))
   125  	if err != nil {
   126  		t.Fatal(err)
   127  	}
   128  }
   129  
   130  type readyBuffer struct {
   131  	buf       bytes.Buffer
   132  	seenReady sync.WaitGroup
   133  }
   134  
   135  func (b *readyBuffer) Write(p []byte) (n int, err error) {
   136  	if string(p) == "ready\n" {
   137  		b.seenReady.Done()
   138  		return len(p), nil
   139  	}
   140  	return b.buf.Write(p)
   141  }
   142  
   143  func TestKillTimeout(t *testing.T) {
   144  	if testing.Short() {
   145  		t.Skip("sleeps and timeouts are slow")
   146  	}
   147  	if runtime.GOOS == "windows" {
   148  		t.Skip("skipping trap tests on windows")
   149  	}
   150  	tests := []struct {
   151  		src         string
   152  		want        string
   153  		killTimeout time.Duration
   154  		forcedKill  bool
   155  	}{
   156  		// killed immediately
   157  		{
   158  			`bash -c "trap 'echo trapped; exit 0' INT; echo ready; for i in {1..100}; do sleep 0.01; done"`,
   159  			"",
   160  			-1,
   161  			true,
   162  		},
   163  		// interrupted first, and stops itself in time
   164  		{
   165  			`bash -c "trap 'echo trapped; exit 0' INT; echo ready; for i in {1..100}; do sleep 0.01; done"`,
   166  			"trapped\n",
   167  			time.Second,
   168  			false,
   169  		},
   170  		// interrupted first, but does not stop itself in time
   171  		{
   172  			`bash -c "trap 'echo trapped; for i in {1..100}; do sleep 0.01; done' INT; echo ready; for i in {1..100}; do sleep 0.01; done"`,
   173  			"trapped\n",
   174  			20 * time.Millisecond,
   175  			true,
   176  		},
   177  	}
   178  
   179  	for i := range tests {
   180  		test := tests[i]
   181  		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
   182  			t.Parallel()
   183  			p := syntax.NewParser()
   184  			file, err := p.Parse(strings.NewReader(test.src), "")
   185  			if err != nil {
   186  				t.Errorf("could not parse: %v", err)
   187  			}
   188  			attempt := 0
   189  			for {
   190  				var rbuf readyBuffer
   191  				rbuf.seenReady.Add(1)
   192  				ctx, cancel := context.WithCancel(context.Background())
   193  				r, err := New(StdIO(nil, &rbuf, &rbuf))
   194  				if err != nil {
   195  					t.Fatal(err)
   196  				}
   197  				r.KillTimeout = test.killTimeout
   198  				go func() {
   199  					rbuf.seenReady.Wait()
   200  					cancel()
   201  				}()
   202  				err = r.Run(ctx, file)
   203  				if test.forcedKill {
   204  					if _, ok := err.(ExitStatus); ok || err == nil {
   205  						t.Error("command was not force-killed")
   206  					}
   207  				} else {
   208  					if err != nil && err != context.Canceled && err != context.DeadlineExceeded {
   209  						t.Errorf("execution errored: %v", err)
   210  					}
   211  				}
   212  				got := rbuf.buf.String()
   213  				if got != test.want {
   214  					if attempt < 3 && got == "" && test.killTimeout > 0 {
   215  						attempt++
   216  						test.killTimeout *= 2
   217  						continue
   218  					}
   219  					t.Fatalf("want:\n%s\ngot:\n%s", test.want, got)
   220  				}
   221  				break
   222  			}
   223  		})
   224  	}
   225  }