github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/helper/escapingio/reader_test.go (about)

     1  package escapingio
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"math/rand"
     9  	"reflect"
    10  	"regexp"
    11  	"strings"
    12  	"sync"
    13  	"testing"
    14  	"testing/iotest"
    15  	"testing/quick"
    16  	"time"
    17  	"unicode"
    18  
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  )
    22  
    23  func TestEscapingReader_Static(t *testing.T) {
    24  	cases := []struct {
    25  		input    string
    26  		expected string
    27  		escaped  string
    28  	}{
    29  		{"hello", "hello", ""},
    30  		{"he\nllo", "he\nllo", ""},
    31  		{"he~.lo", "he~.lo", ""},
    32  		{"he\n~.rest", "he\nrest", "."},
    33  		{"he\n~.r\n~.est", "he\nr\nest", ".."},
    34  		{"he\n~~r\n~~est", "he\n~r\n~est", ""},
    35  		{"he\n~~r\n~.est", "he\n~r\nest", "."},
    36  		{"he\nr~~est", "he\nr~~est", ""},
    37  		{"he\nr\n~qest", "he\nr\n~qest", "q"},
    38  		{"he\nr\r~qe\r~.st", "he\nr\r~qe\rst", "q."},
    39  		{"~q", "~q", "q"},
    40  		{"~.", "", "."},
    41  		{"m~.", "m~.", ""},
    42  		{"\n~.", "\n", "."},
    43  		{"~", "~", ""},
    44  		{"\r~.", "\r", "."},
    45  		{"b\n~\n~.q", "b\n~\nq", "."},
    46  	}
    47  
    48  	for _, c := range cases {
    49  		t.Run("validate naive implementation", func(t *testing.T) {
    50  			h := &testHandler{}
    51  
    52  			processed := naiveEscapeCharacters(c.input, '~', h.handler)
    53  			require.Equal(t, c.expected, processed)
    54  			require.Equal(t, c.escaped, h.escaped())
    55  		})
    56  
    57  		t.Run("chunks at a time: "+c.input, func(t *testing.T) {
    58  			var found bytes.Buffer
    59  
    60  			input := strings.NewReader(c.input)
    61  
    62  			h := &testHandler{}
    63  
    64  			filter := NewReader(input, '~', h.handler)
    65  
    66  			_, err := io.Copy(&found, filter)
    67  			require.NoError(t, err)
    68  
    69  			require.Equal(t, c.expected, found.String())
    70  			require.Equal(t, c.escaped, h.escaped())
    71  		})
    72  
    73  		t.Run("1 byte at a time: "+c.input, func(t *testing.T) {
    74  			var found bytes.Buffer
    75  
    76  			input := iotest.OneByteReader(strings.NewReader(c.input))
    77  
    78  			h := &testHandler{}
    79  
    80  			filter := NewReader(input, '~', h.handler)
    81  			_, err := io.Copy(&found, filter)
    82  			require.NoError(t, err)
    83  
    84  			require.Equal(t, c.expected, found.String())
    85  			require.Equal(t, c.escaped, h.escaped())
    86  		})
    87  
    88  		t.Run("without reading: "+c.input, func(t *testing.T) {
    89  			input := strings.NewReader(c.input)
    90  
    91  			h := &testHandler{}
    92  
    93  			filter := NewReader(input, '~', h.handler)
    94  
    95  			// don't read to mimic a stalled reader
    96  			_ = filter
    97  
    98  			assertEventually(t, func() (bool, error) {
    99  				escaped := h.escaped()
   100  				if c.escaped == escaped {
   101  					return true, nil
   102  				}
   103  
   104  				return false, fmt.Errorf("expected %v but found %v", c.escaped, escaped)
   105  			})
   106  		})
   107  	}
   108  }
   109  
   110  // TestEscapingReader_EmitsPartialReads should emit partial results
   111  // if next character is not read
   112  func TestEscapingReader_FlushesPartialReads(t *testing.T) {
   113  	pr, pw := io.Pipe()
   114  
   115  	h := &testHandler{}
   116  	filter := NewReader(pr, '~', h.handler)
   117  
   118  	var lock sync.Mutex
   119  	var read bytes.Buffer
   120  
   121  	// helper for asserting reads
   122  	requireRead := func(expected *bytes.Buffer) {
   123  		readSoFar := ""
   124  
   125  		start := time.Now()
   126  		for time.Since(start) < 2*time.Second {
   127  			lock.Lock()
   128  			readSoFar = read.String()
   129  			lock.Unlock()
   130  
   131  			if readSoFar == expected.String() {
   132  				break
   133  			}
   134  
   135  			time.Sleep(50 * time.Millisecond)
   136  		}
   137  
   138  		require.Equal(t, expected.String(), readSoFar, "timed out without output")
   139  	}
   140  
   141  	var rerr error
   142  	var wg sync.WaitGroup
   143  	wg.Add(1)
   144  
   145  	// goroutine for reading partial data
   146  	go func() {
   147  		defer wg.Done()
   148  
   149  		buf := make([]byte, 1024)
   150  		for {
   151  			n, err := filter.Read(buf)
   152  			lock.Lock()
   153  			read.Write(buf[:n])
   154  			lock.Unlock()
   155  
   156  			if err != nil {
   157  				rerr = err
   158  				break
   159  			}
   160  		}
   161  	}()
   162  
   163  	expected := &bytes.Buffer{}
   164  
   165  	// test basic start and no new lines
   166  	pw.Write([]byte("first data"))
   167  	expected.WriteString("first data")
   168  	requireRead(expected)
   169  	require.Equal(t, "", h.escaped())
   170  
   171  	// test ~. appearing in middle of line but stop at new line
   172  	pw.Write([]byte("~.inmiddleappears\n"))
   173  	expected.WriteString("~.inmiddleappears\n")
   174  	requireRead(expected)
   175  	require.Equal(t, "", h.escaped())
   176  
   177  	// from here on we test \n~ at boundary
   178  
   179  	// ~~ after new line; and stop at \n~
   180  	pw.Write([]byte("~~second line\n~"))
   181  	expected.WriteString("~second line\n")
   182  	requireRead(expected)
   183  	require.Equal(t, "", h.escaped())
   184  
   185  	// . to be skipped; stop at \n~ again
   186  	pw.Write([]byte(".third line\n~"))
   187  	expected.WriteString("third line\n")
   188  	requireRead(expected)
   189  	require.Equal(t, ".", h.escaped())
   190  
   191  	// q to be emitted; stop at \n
   192  	pw.Write([]byte("qfourth line\n"))
   193  	expected.WriteString("~qfourth line\n")
   194  	requireRead(expected)
   195  	require.Equal(t, ".q", h.escaped())
   196  
   197  	// ~. to be skipped; stop at \n~
   198  	pw.Write([]byte("~.fifth line\n~"))
   199  	expected.WriteString("fifth line\n")
   200  	requireRead(expected)
   201  	require.Equal(t, ".q.", h.escaped())
   202  
   203  	// ~ alone after \n~ - should be emitted
   204  	pw.Write([]byte("~"))
   205  	expected.WriteString("~")
   206  	requireRead(expected)
   207  	require.Equal(t, ".q.", h.escaped())
   208  
   209  	// rest of line ending with \n~
   210  	pw.Write([]byte("rest of line\n~"))
   211  	expected.WriteString("rest of line\n")
   212  	requireRead(expected)
   213  	require.Equal(t, ".q.", h.escaped())
   214  
   215  	// m alone after \n~ - should be emitted with ~
   216  	pw.Write([]byte("m"))
   217  	expected.WriteString("~m")
   218  	requireRead(expected)
   219  	require.Equal(t, ".q.m", h.escaped())
   220  
   221  	// rest of line and end with \n
   222  	pw.Write([]byte("onemore line\n"))
   223  	expected.WriteString("onemore line\n")
   224  	requireRead(expected)
   225  	require.Equal(t, ".q.m", h.escaped())
   226  
   227  	// ~q to be emitted stop at \n~; last charcater
   228  	pw.Write([]byte("~qlast line\n~"))
   229  	expected.WriteString("~qlast line\n")
   230  	requireRead(expected)
   231  	require.Equal(t, ".q.mq", h.escaped())
   232  
   233  	// last ~ gets emitted and we preserve error
   234  	eerr := errors.New("my custom error")
   235  	pw.CloseWithError(eerr)
   236  	expected.WriteString("~")
   237  	requireRead(expected)
   238  	require.Equal(t, ".q.mq", h.escaped())
   239  
   240  	wg.Wait()
   241  	require.Error(t, rerr)
   242  	require.Equal(t, eerr, rerr)
   243  }
   244  
   245  func TestEscapingReader_Generated_EquivalentToNaive(t *testing.T) {
   246  	f := func(v readingInput) bool {
   247  		return checkEquivalenceToNaive(t, string(v))
   248  	}
   249  
   250  	require.NoError(t, quick.Check(f, &quick.Config{
   251  		MaxCountScale: 200,
   252  	}))
   253  }
   254  
   255  // testHandler is a conveneient struct for finding "escaped" ascii letters
   256  // in escaping reader.
   257  // We avoid complicated unicode characters that may cross byte boundary
   258  type testHandler struct {
   259  	l      sync.Mutex
   260  	result string
   261  }
   262  
   263  // handler is method to be passed to escaping io reader
   264  func (t *testHandler) handler(c byte) bool {
   265  	rc := rune(c)
   266  	simple := unicode.IsLetter(rc) ||
   267  		unicode.IsDigit(rc) ||
   268  		unicode.IsPunct(rc) ||
   269  		unicode.IsSymbol(rc)
   270  
   271  	if simple {
   272  		t.l.Lock()
   273  		t.result += string([]byte{c})
   274  		t.l.Unlock()
   275  	}
   276  	return c == '.'
   277  }
   278  
   279  // escaped returns all seen escaped characters so far
   280  func (t *testHandler) escaped() string {
   281  	t.l.Lock()
   282  	defer t.l.Unlock()
   283  
   284  	return t.result
   285  }
   286  
   287  // checkEquivalence returns true if parsing input with naive implementation
   288  // is equivalent to our reader
   289  func checkEquivalenceToNaive(t *testing.T, input string) bool {
   290  	nh := &testHandler{}
   291  	expected := naiveEscapeCharacters(input, '~', nh.handler)
   292  
   293  	foundH := &testHandler{}
   294  
   295  	var inputReader io.Reader = bytes.NewBufferString(input)
   296  	inputReader = &arbtiraryReader{
   297  		buf:         inputReader.(*bytes.Buffer),
   298  		maxReadOnce: 10,
   299  	}
   300  	filter := NewReader(inputReader, '~', foundH.handler)
   301  	var found bytes.Buffer
   302  	_, err := io.Copy(&found, filter)
   303  	if err != nil {
   304  		t.Logf("unexpected error while reading: %v", err)
   305  		return false
   306  	}
   307  
   308  	if nh.escaped() == foundH.escaped() && expected == found.String() {
   309  		return true
   310  	}
   311  
   312  	t.Logf("escaped differed=%v expected=%v found=%v", nh.escaped() != foundH.escaped(), nh.escaped(), foundH.escaped())
   313  	t.Logf("read  differed=%v expected=%s found=%v", expected != found.String(), expected, found.String())
   314  	return false
   315  
   316  }
   317  
   318  func TestEscapingReader_Generated_EquivalentToReadOnce(t *testing.T) {
   319  	f := func(v readingInput) bool {
   320  		return checkEquivalenceToNaive(t, string(v))
   321  	}
   322  
   323  	require.NoError(t, quick.Check(f, &quick.Config{
   324  		MaxCountScale: 200,
   325  	}))
   326  }
   327  
   328  // checkEquivalenceToReadOnce returns true if parsing input in a single
   329  // read matches multiple reads
   330  func checkEquivalenceToReadOnce(t *testing.T, input string) bool {
   331  	nh := &testHandler{}
   332  	var expected bytes.Buffer
   333  
   334  	// getting expected value from read all at once
   335  	{
   336  		buf := make([]byte, len(input)+5)
   337  		inputReader := NewReader(bytes.NewBufferString(input), '~', nh.handler)
   338  		_, err := io.CopyBuffer(&expected, inputReader, buf)
   339  		if err != nil {
   340  			t.Logf("unexpected error while reading: %v", err)
   341  			return false
   342  		}
   343  	}
   344  
   345  	foundH := &testHandler{}
   346  	var found bytes.Buffer
   347  
   348  	// getting found by using arbitrary reader
   349  	{
   350  		inputReader := &arbtiraryReader{
   351  			buf:         bytes.NewBufferString(input),
   352  			maxReadOnce: 10,
   353  		}
   354  		filter := NewReader(inputReader, '~', foundH.handler)
   355  		_, err := io.Copy(&found, filter)
   356  		if err != nil {
   357  			t.Logf("unexpected error while reading: %v", err)
   358  			return false
   359  		}
   360  	}
   361  
   362  	if nh.escaped() == foundH.escaped() && expected.String() == found.String() {
   363  		return true
   364  	}
   365  
   366  	t.Logf("escaped differed=%v expected=%v found=%v", nh.escaped() != foundH.escaped(), nh.escaped(), foundH.escaped())
   367  	t.Logf("read  differed=%v expected=%s found=%v", expected.String() != found.String(), expected.String(), found.String())
   368  	return false
   369  
   370  }
   371  
   372  // readingInput is a string with some quick generation capability to
   373  // inject some \n, \n~., \n~q in text
   374  type readingInput string
   375  
   376  func (i readingInput) Generate(rand *rand.Rand, size int) reflect.Value {
   377  	v, ok := quick.Value(reflect.TypeOf(""), rand)
   378  	if !ok {
   379  		panic("couldn't generate a string")
   380  	}
   381  
   382  	// inject some terminals
   383  	var b bytes.Buffer
   384  	injectProbabilistically := func() {
   385  		p := rand.Float32()
   386  		if p < 0.05 {
   387  			b.WriteString("\n~.")
   388  		} else if p < 0.10 {
   389  			b.WriteString("\n~q")
   390  		} else if p < 0.15 {
   391  			b.WriteString("\n")
   392  		} else if p < 0.2 {
   393  			b.WriteString("~")
   394  		} else if p < 0.25 {
   395  			b.WriteString("~~")
   396  		}
   397  	}
   398  
   399  	for _, c := range v.String() {
   400  		injectProbabilistically()
   401  		b.WriteRune(c)
   402  	}
   403  
   404  	injectProbabilistically()
   405  
   406  	return reflect.ValueOf(readingInput(b.String()))
   407  }
   408  
   409  // naiveEscapeCharacters is a simplified implementation that operates
   410  // on entire unchunked string.  Uses regexp implementation.
   411  //
   412  // It differs from the other implementation in handling unicode characters
   413  // proceeding `\n~`
   414  func naiveEscapeCharacters(input string, escapeChar byte, h Handler) string {
   415  	reg := regexp.MustCompile(fmt.Sprintf("(\n|\r)%c.", escapeChar))
   416  
   417  	// check first appearances
   418  	if len(input) > 1 && input[0] == escapeChar {
   419  		if input[1] == escapeChar {
   420  			input = input[1:]
   421  		} else if h(input[1]) {
   422  			input = input[2:]
   423  		} else {
   424  			// we are good
   425  		}
   426  
   427  	}
   428  
   429  	return reg.ReplaceAllStringFunc(input, func(match string) string {
   430  		// match can be more than three bytes because of unicode
   431  		if len(match) < 3 {
   432  			panic(fmt.Errorf("match is less than characters: %d %s", len(match), match))
   433  		}
   434  
   435  		c := match[2]
   436  
   437  		// ignore some unicode partial codes
   438  		ltr := len(match) > 3 ||
   439  			('a' <= c && c <= 'z') ||
   440  			('A' <= c && c <= 'Z') ||
   441  			('0' <= c && c <= '9') ||
   442  			(c == '~' || c == '.' || c == escapeChar)
   443  
   444  		if c == escapeChar {
   445  			return match[:2]
   446  		} else if ltr && h(c) {
   447  			return match[:1]
   448  		} else {
   449  			return match
   450  		}
   451  	})
   452  }
   453  
   454  // arbitraryReader is a reader that reads arbitrary length at a time
   455  // to simulate input being read in chunks.
   456  type arbtiraryReader struct {
   457  	buf         *bytes.Buffer
   458  	maxReadOnce int
   459  }
   460  
   461  func (r *arbtiraryReader) Read(buf []byte) (int, error) {
   462  	l := r.buf.Len()
   463  	if l == 0 || l == 1 {
   464  		return r.buf.Read(buf)
   465  	}
   466  
   467  	if l > r.maxReadOnce {
   468  		l = r.maxReadOnce
   469  	}
   470  	if l != 1 {
   471  		l = rand.Intn(l-1) + 1
   472  	}
   473  	if l > len(buf) {
   474  		l = len(buf)
   475  	}
   476  
   477  	return r.buf.Read(buf[:l])
   478  }
   479  
   480  func assertEventually(t *testing.T, testFn func() (bool, error)) {
   481  	start := time.Now()
   482  	var err error
   483  	var b bool
   484  	for {
   485  		if time.Since(start) > 2*time.Second {
   486  			assert.Fail(t, "timed out", "error: %v", err)
   487  		}
   488  
   489  		b, err = testFn()
   490  		if b {
   491  			return
   492  		}
   493  
   494  		time.Sleep(50 * time.Millisecond)
   495  	}
   496  }