github.com/hernad/nomad@v1.6.112/helper/escapingio/reader_test.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package escapingio
     5  
     6  import (
     7  	"bytes"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"math/rand"
    12  	"reflect"
    13  	"regexp"
    14  	"strings"
    15  	"sync"
    16  	"testing"
    17  	"testing/iotest"
    18  	"testing/quick"
    19  	"time"
    20  	"unicode"
    21  
    22  	"github.com/stretchr/testify/assert"
    23  	"github.com/stretchr/testify/require"
    24  )
    25  
    26  func TestEscapingReader_Static(t *testing.T) {
    27  	cases := []struct {
    28  		input    string
    29  		expected string
    30  		escaped  string
    31  	}{
    32  		{"hello", "hello", ""},
    33  		{"he\nllo", "he\nllo", ""},
    34  		{"he~.lo", "he~.lo", ""},
    35  		{"he\n~.rest", "he\nrest", "."},
    36  		{"he\n~.r\n~.est", "he\nr\nest", ".."},
    37  		{"he\n~~r\n~~est", "he\n~r\n~est", ""},
    38  		{"he\n~~r\n~.est", "he\n~r\nest", "."},
    39  		{"he\nr~~est", "he\nr~~est", ""},
    40  		{"he\nr\n~qest", "he\nr\n~qest", "q"},
    41  		{"he\nr\r~qe\r~.st", "he\nr\r~qe\rst", "q."},
    42  		{"~q", "~q", "q"},
    43  		{"~.", "", "."},
    44  		{"m~.", "m~.", ""},
    45  		{"\n~.", "\n", "."},
    46  		{"~", "~", ""},
    47  		{"\r~.", "\r", "."},
    48  		{"b\n~\n~.q", "b\n~\nq", "."},
    49  	}
    50  
    51  	for _, c := range cases {
    52  		t.Run("validate naive implementation", func(t *testing.T) {
    53  			h := &testHandler{}
    54  
    55  			processed := naiveEscapeCharacters(c.input, '~', h.handler)
    56  			require.Equal(t, c.expected, processed)
    57  			require.Equal(t, c.escaped, h.escaped())
    58  		})
    59  
    60  		t.Run("chunks at a time: "+c.input, func(t *testing.T) {
    61  			var found bytes.Buffer
    62  
    63  			input := strings.NewReader(c.input)
    64  
    65  			h := &testHandler{}
    66  
    67  			filter := NewReader(input, '~', h.handler)
    68  
    69  			_, err := io.Copy(&found, filter)
    70  			require.NoError(t, err)
    71  
    72  			require.Equal(t, c.expected, found.String())
    73  			require.Equal(t, c.escaped, h.escaped())
    74  		})
    75  
    76  		t.Run("1 byte at a time: "+c.input, func(t *testing.T) {
    77  			var found bytes.Buffer
    78  
    79  			input := iotest.OneByteReader(strings.NewReader(c.input))
    80  
    81  			h := &testHandler{}
    82  
    83  			filter := NewReader(input, '~', h.handler)
    84  			_, err := io.Copy(&found, filter)
    85  			require.NoError(t, err)
    86  
    87  			require.Equal(t, c.expected, found.String())
    88  			require.Equal(t, c.escaped, h.escaped())
    89  		})
    90  
    91  		t.Run("without reading: "+c.input, func(t *testing.T) {
    92  			input := strings.NewReader(c.input)
    93  
    94  			h := &testHandler{}
    95  
    96  			filter := NewReader(input, '~', h.handler)
    97  
    98  			// don't read to mimic a stalled reader
    99  			_ = filter
   100  
   101  			assertEventually(t, func() (bool, error) {
   102  				escaped := h.escaped()
   103  				if c.escaped == escaped {
   104  					return true, nil
   105  				}
   106  
   107  				return false, fmt.Errorf("expected %v but found %v", c.escaped, escaped)
   108  			})
   109  		})
   110  	}
   111  }
   112  
   113  // TestEscapingReader_EmitsPartialReads should emit partial results
   114  // if next character is not read
   115  func TestEscapingReader_FlushesPartialReads(t *testing.T) {
   116  	pr, pw := io.Pipe()
   117  
   118  	h := &testHandler{}
   119  	filter := NewReader(pr, '~', h.handler)
   120  
   121  	var lock sync.Mutex
   122  	var read bytes.Buffer
   123  
   124  	// helper for asserting reads
   125  	requireRead := func(expected *bytes.Buffer) {
   126  		readSoFar := ""
   127  
   128  		start := time.Now()
   129  		for time.Since(start) < 2*time.Second {
   130  			lock.Lock()
   131  			readSoFar = read.String()
   132  			lock.Unlock()
   133  
   134  			if readSoFar == expected.String() {
   135  				break
   136  			}
   137  
   138  			time.Sleep(50 * time.Millisecond)
   139  		}
   140  
   141  		require.Equal(t, expected.String(), readSoFar, "timed out without output")
   142  	}
   143  
   144  	var rerr error
   145  	var wg sync.WaitGroup
   146  	wg.Add(1)
   147  
   148  	// goroutine for reading partial data
   149  	go func() {
   150  		defer wg.Done()
   151  
   152  		buf := make([]byte, 1024)
   153  		for {
   154  			n, err := filter.Read(buf)
   155  			lock.Lock()
   156  			read.Write(buf[:n])
   157  			lock.Unlock()
   158  
   159  			if err != nil {
   160  				rerr = err
   161  				break
   162  			}
   163  		}
   164  	}()
   165  
   166  	expected := &bytes.Buffer{}
   167  
   168  	// test basic start and no new lines
   169  	pw.Write([]byte("first data"))
   170  	expected.WriteString("first data")
   171  	requireRead(expected)
   172  	require.Equal(t, "", h.escaped())
   173  
   174  	// test ~. appearing in middle of line but stop at new line
   175  	pw.Write([]byte("~.inmiddleappears\n"))
   176  	expected.WriteString("~.inmiddleappears\n")
   177  	requireRead(expected)
   178  	require.Equal(t, "", h.escaped())
   179  
   180  	// from here on we test \n~ at boundary
   181  
   182  	// ~~ after new line; and stop at \n~
   183  	pw.Write([]byte("~~second line\n~"))
   184  	expected.WriteString("~second line\n")
   185  	requireRead(expected)
   186  	require.Equal(t, "", h.escaped())
   187  
   188  	// . to be skipped; stop at \n~ again
   189  	pw.Write([]byte(".third line\n~"))
   190  	expected.WriteString("third line\n")
   191  	requireRead(expected)
   192  	require.Equal(t, ".", h.escaped())
   193  
   194  	// q to be emitted; stop at \n
   195  	pw.Write([]byte("qfourth line\n"))
   196  	expected.WriteString("~qfourth line\n")
   197  	requireRead(expected)
   198  	require.Equal(t, ".q", h.escaped())
   199  
   200  	// ~. to be skipped; stop at \n~
   201  	pw.Write([]byte("~.fifth line\n~"))
   202  	expected.WriteString("fifth line\n")
   203  	requireRead(expected)
   204  	require.Equal(t, ".q.", h.escaped())
   205  
   206  	// ~ alone after \n~ - should be emitted
   207  	pw.Write([]byte("~"))
   208  	expected.WriteString("~")
   209  	requireRead(expected)
   210  	require.Equal(t, ".q.", h.escaped())
   211  
   212  	// rest of line ending with \n~
   213  	pw.Write([]byte("rest of line\n~"))
   214  	expected.WriteString("rest of line\n")
   215  	requireRead(expected)
   216  	require.Equal(t, ".q.", h.escaped())
   217  
   218  	// m alone after \n~ - should be emitted with ~
   219  	pw.Write([]byte("m"))
   220  	expected.WriteString("~m")
   221  	requireRead(expected)
   222  	require.Equal(t, ".q.m", h.escaped())
   223  
   224  	// rest of line and end with \n
   225  	pw.Write([]byte("onemore line\n"))
   226  	expected.WriteString("onemore line\n")
   227  	requireRead(expected)
   228  	require.Equal(t, ".q.m", h.escaped())
   229  
   230  	// ~q to be emitted stop at \n~; last charcater
   231  	pw.Write([]byte("~qlast line\n~"))
   232  	expected.WriteString("~qlast line\n")
   233  	requireRead(expected)
   234  	require.Equal(t, ".q.mq", h.escaped())
   235  
   236  	// last ~ gets emitted and we preserve error
   237  	eerr := errors.New("my custom error")
   238  	pw.CloseWithError(eerr)
   239  	expected.WriteString("~")
   240  	requireRead(expected)
   241  	require.Equal(t, ".q.mq", h.escaped())
   242  
   243  	wg.Wait()
   244  	require.Error(t, rerr)
   245  	require.Equal(t, eerr, rerr)
   246  }
   247  
   248  func TestEscapingReader_Generated_EquivalentToNaive(t *testing.T) {
   249  	f := func(v readingInput) bool {
   250  		return checkEquivalenceToNaive(t, string(v))
   251  	}
   252  
   253  	require.NoError(t, quick.Check(f, &quick.Config{
   254  		MaxCountScale: 200,
   255  	}))
   256  }
   257  
   258  // testHandler is a conveneient struct for finding "escaped" ascii letters
   259  // in escaping reader.
   260  // We avoid complicated unicode characters that may cross byte boundary
   261  type testHandler struct {
   262  	l      sync.Mutex
   263  	result string
   264  }
   265  
   266  // handler is method to be passed to escaping io reader
   267  func (t *testHandler) handler(c byte) bool {
   268  	rc := rune(c)
   269  	simple := unicode.IsLetter(rc) ||
   270  		unicode.IsDigit(rc) ||
   271  		unicode.IsPunct(rc) ||
   272  		unicode.IsSymbol(rc)
   273  
   274  	if simple {
   275  		t.l.Lock()
   276  		t.result += string([]byte{c})
   277  		t.l.Unlock()
   278  	}
   279  	return c == '.'
   280  }
   281  
   282  // escaped returns all seen escaped characters so far
   283  func (t *testHandler) escaped() string {
   284  	t.l.Lock()
   285  	defer t.l.Unlock()
   286  
   287  	return t.result
   288  }
   289  
   290  // checkEquivalence returns true if parsing input with naive implementation
   291  // is equivalent to our reader
   292  func checkEquivalenceToNaive(t *testing.T, input string) bool {
   293  	nh := &testHandler{}
   294  	expected := naiveEscapeCharacters(input, '~', nh.handler)
   295  
   296  	foundH := &testHandler{}
   297  
   298  	var inputReader io.Reader = bytes.NewBufferString(input)
   299  	inputReader = &arbtiraryReader{
   300  		buf:         inputReader.(*bytes.Buffer),
   301  		maxReadOnce: 10,
   302  	}
   303  	filter := NewReader(inputReader, '~', foundH.handler)
   304  	var found bytes.Buffer
   305  	_, err := io.Copy(&found, filter)
   306  	if err != nil {
   307  		t.Logf("unexpected error while reading: %v", err)
   308  		return false
   309  	}
   310  
   311  	if nh.escaped() == foundH.escaped() && expected == found.String() {
   312  		return true
   313  	}
   314  
   315  	t.Logf("escaped differed=%v expected=%v found=%v", nh.escaped() != foundH.escaped(), nh.escaped(), foundH.escaped())
   316  	t.Logf("read  differed=%v expected=%s found=%v", expected != found.String(), expected, found.String())
   317  	return false
   318  
   319  }
   320  
   321  func TestEscapingReader_Generated_EquivalentToReadOnce(t *testing.T) {
   322  	f := func(v readingInput) bool {
   323  		return checkEquivalenceToNaive(t, string(v))
   324  	}
   325  
   326  	require.NoError(t, quick.Check(f, &quick.Config{
   327  		MaxCountScale: 200,
   328  	}))
   329  }
   330  
   331  // checkEquivalenceToReadOnce returns true if parsing input in a single
   332  // read matches multiple reads
   333  func checkEquivalenceToReadOnce(t *testing.T, input string) bool {
   334  	nh := &testHandler{}
   335  	var expected bytes.Buffer
   336  
   337  	// getting expected value from read all at once
   338  	{
   339  		buf := make([]byte, len(input)+5)
   340  		inputReader := NewReader(bytes.NewBufferString(input), '~', nh.handler)
   341  		_, err := io.CopyBuffer(&expected, inputReader, buf)
   342  		if err != nil {
   343  			t.Logf("unexpected error while reading: %v", err)
   344  			return false
   345  		}
   346  	}
   347  
   348  	foundH := &testHandler{}
   349  	var found bytes.Buffer
   350  
   351  	// getting found by using arbitrary reader
   352  	{
   353  		inputReader := &arbtiraryReader{
   354  			buf:         bytes.NewBufferString(input),
   355  			maxReadOnce: 10,
   356  		}
   357  		filter := NewReader(inputReader, '~', foundH.handler)
   358  		_, err := io.Copy(&found, filter)
   359  		if err != nil {
   360  			t.Logf("unexpected error while reading: %v", err)
   361  			return false
   362  		}
   363  	}
   364  
   365  	if nh.escaped() == foundH.escaped() && expected.String() == found.String() {
   366  		return true
   367  	}
   368  
   369  	t.Logf("escaped differed=%v expected=%v found=%v", nh.escaped() != foundH.escaped(), nh.escaped(), foundH.escaped())
   370  	t.Logf("read  differed=%v expected=%s found=%v", expected.String() != found.String(), expected.String(), found.String())
   371  	return false
   372  
   373  }
   374  
   375  // readingInput is a string with some quick generation capability to
   376  // inject some \n, \n~., \n~q in text
   377  type readingInput string
   378  
   379  func (i readingInput) Generate(rand *rand.Rand, size int) reflect.Value {
   380  	v, ok := quick.Value(reflect.TypeOf(""), rand)
   381  	if !ok {
   382  		panic("couldn't generate a string")
   383  	}
   384  
   385  	// inject some terminals
   386  	var b bytes.Buffer
   387  	injectProbabilistically := func() {
   388  		p := rand.Float32()
   389  		if p < 0.05 {
   390  			b.WriteString("\n~.")
   391  		} else if p < 0.10 {
   392  			b.WriteString("\n~q")
   393  		} else if p < 0.15 {
   394  			b.WriteString("\n")
   395  		} else if p < 0.2 {
   396  			b.WriteString("~")
   397  		} else if p < 0.25 {
   398  			b.WriteString("~~")
   399  		}
   400  	}
   401  
   402  	for _, c := range v.String() {
   403  		injectProbabilistically()
   404  		b.WriteRune(c)
   405  	}
   406  
   407  	injectProbabilistically()
   408  
   409  	return reflect.ValueOf(readingInput(b.String()))
   410  }
   411  
   412  // naiveEscapeCharacters is a simplified implementation that operates
   413  // on entire unchunked string.  Uses regexp implementation.
   414  //
   415  // It differs from the other implementation in handling unicode characters
   416  // proceeding `\n~`
   417  func naiveEscapeCharacters(input string, escapeChar byte, h Handler) string {
   418  	reg := regexp.MustCompile(fmt.Sprintf("(\n|\r)%c.", escapeChar))
   419  
   420  	// check first appearances
   421  	if len(input) > 1 && input[0] == escapeChar {
   422  		if input[1] == escapeChar {
   423  			input = input[1:]
   424  		} else if h(input[1]) {
   425  			input = input[2:]
   426  		} else {
   427  			// we are good
   428  		}
   429  
   430  	}
   431  
   432  	return reg.ReplaceAllStringFunc(input, func(match string) string {
   433  		// match can be more than three bytes because of unicode
   434  		if len(match) < 3 {
   435  			panic(fmt.Errorf("match is less than characters: %d %s", len(match), match))
   436  		}
   437  
   438  		c := match[2]
   439  
   440  		// ignore some unicode partial codes
   441  		ltr := len(match) > 3 ||
   442  			('a' <= c && c <= 'z') ||
   443  			('A' <= c && c <= 'Z') ||
   444  			('0' <= c && c <= '9') ||
   445  			(c == '~' || c == '.' || c == escapeChar)
   446  
   447  		if c == escapeChar {
   448  			return match[:2]
   449  		} else if ltr && h(c) {
   450  			return match[:1]
   451  		} else {
   452  			return match
   453  		}
   454  	})
   455  }
   456  
   457  // arbitraryReader is a reader that reads arbitrary length at a time
   458  // to simulate input being read in chunks.
   459  type arbtiraryReader struct {
   460  	buf         *bytes.Buffer
   461  	maxReadOnce int
   462  }
   463  
   464  func (r *arbtiraryReader) Read(buf []byte) (int, error) {
   465  	l := r.buf.Len()
   466  	if l == 0 || l == 1 {
   467  		return r.buf.Read(buf)
   468  	}
   469  
   470  	if l > r.maxReadOnce {
   471  		l = r.maxReadOnce
   472  	}
   473  	if l != 1 {
   474  		l = rand.Intn(l-1) + 1
   475  	}
   476  	if l > len(buf) {
   477  		l = len(buf)
   478  	}
   479  
   480  	return r.buf.Read(buf[:l])
   481  }
   482  
   483  func assertEventually(t *testing.T, testFn func() (bool, error)) {
   484  	start := time.Now()
   485  	var err error
   486  	var b bool
   487  	for {
   488  		if time.Since(start) > 2*time.Second {
   489  			assert.Fail(t, "timed out", "error: %v", err)
   490  		}
   491  
   492  		b, err = testFn()
   493  		if b {
   494  			return
   495  		}
   496  
   497  		time.Sleep(50 * time.Millisecond)
   498  	}
   499  }