github.com/banmanh482/nomad@v0.11.8/helper/escapingio/reader_test.go (about)

     1  package escapingio
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"math/rand"
     9  	"reflect"
    10  	"regexp"
    11  	"strings"
    12  	"sync"
    13  	"testing"
    14  	"testing/iotest"
    15  	"testing/quick"
    16  	"time"
    17  	"unicode"
    18  
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  )
    22  
    23  func TestEscapingReader_Static(t *testing.T) {
    24  	cases := []struct {
    25  		input    string
    26  		expected string
    27  		escaped  string
    28  	}{
    29  		{"hello", "hello", ""},
    30  		{"he\nllo", "he\nllo", ""},
    31  		{"he~.lo", "he~.lo", ""},
    32  		{"he\n~.rest", "he\nrest", "."},
    33  		{"he\n~.r\n~.est", "he\nr\nest", ".."},
    34  		{"he\n~~r\n~~est", "he\n~r\n~est", ""},
    35  		{"he\n~~r\n~.est", "he\n~r\nest", "."},
    36  		{"he\nr~~est", "he\nr~~est", ""},
    37  		{"he\nr\n~qest", "he\nr\n~qest", "q"},
    38  		{"he\nr\r~qe\r~.st", "he\nr\r~qe\rst", "q."},
    39  		{"~q", "~q", "q"},
    40  		{"~.", "", "."},
    41  		{"m~.", "m~.", ""},
    42  		{"\n~.", "\n", "."},
    43  		{"~", "~", ""},
    44  		{"\r~.", "\r", "."},
    45  	}
    46  
    47  	for _, c := range cases {
    48  		t.Run("sanity check naive implementation", func(t *testing.T) {
    49  			h := &testHandler{}
    50  
    51  			processed := naiveEscapeCharacters(c.input, '~', h.handler)
    52  			require.Equal(t, c.expected, processed)
    53  			require.Equal(t, c.escaped, h.escaped())
    54  		})
    55  
    56  		t.Run("chunks at a time: "+c.input, func(t *testing.T) {
    57  			var found bytes.Buffer
    58  
    59  			input := strings.NewReader(c.input)
    60  
    61  			h := &testHandler{}
    62  
    63  			filter := NewReader(input, '~', h.handler)
    64  
    65  			_, err := io.Copy(&found, filter)
    66  			require.NoError(t, err)
    67  
    68  			require.Equal(t, c.expected, found.String())
    69  			require.Equal(t, c.escaped, h.escaped())
    70  		})
    71  
    72  		t.Run("1 byte at a time: "+c.input, func(t *testing.T) {
    73  			var found bytes.Buffer
    74  
    75  			input := iotest.OneByteReader(strings.NewReader(c.input))
    76  
    77  			h := &testHandler{}
    78  
    79  			filter := NewReader(input, '~', h.handler)
    80  			_, err := io.Copy(&found, filter)
    81  			require.NoError(t, err)
    82  
    83  			require.Equal(t, c.expected, found.String())
    84  			require.Equal(t, c.escaped, h.escaped())
    85  		})
    86  
    87  		t.Run("without reading: "+c.input, func(t *testing.T) {
    88  			input := strings.NewReader(c.input)
    89  
    90  			h := &testHandler{}
    91  
    92  			filter := NewReader(input, '~', h.handler)
    93  
    94  			// don't read to mimic a stalled reader
    95  			_ = filter
    96  
    97  			assertEventually(t, func() (bool, error) {
    98  				escaped := h.escaped()
    99  				if c.escaped == escaped {
   100  					return true, nil
   101  				}
   102  
   103  				return false, fmt.Errorf("expected %v but found %v", c.escaped, escaped)
   104  			})
   105  		})
   106  	}
   107  }
   108  
   109  // TestEscapingReader_EmitsPartialReads should emit partial results
   110  // if next character is not read
   111  func TestEscapingReader_FlushesPartialReads(t *testing.T) {
   112  	pr, pw := io.Pipe()
   113  
   114  	h := &testHandler{}
   115  	filter := NewReader(pr, '~', h.handler)
   116  
   117  	var lock sync.Mutex
   118  	var read bytes.Buffer
   119  
   120  	// helper for asserting reads
   121  	requireRead := func(expected *bytes.Buffer) {
   122  		readSoFar := ""
   123  
   124  		start := time.Now()
   125  		for time.Since(start) < 2*time.Second {
   126  			lock.Lock()
   127  			readSoFar = read.String()
   128  			lock.Unlock()
   129  
   130  			if readSoFar == expected.String() {
   131  				break
   132  			}
   133  
   134  			time.Sleep(50 * time.Millisecond)
   135  		}
   136  
   137  		require.Equal(t, expected.String(), readSoFar, "timed out without output")
   138  	}
   139  
   140  	var rerr error
   141  	var wg sync.WaitGroup
   142  	wg.Add(1)
   143  
   144  	// goroutine for reading partial data
   145  	go func() {
   146  		defer wg.Done()
   147  
   148  		buf := make([]byte, 1024)
   149  		for {
   150  			n, err := filter.Read(buf)
   151  			lock.Lock()
   152  			read.Write(buf[:n])
   153  			lock.Unlock()
   154  
   155  			if err != nil {
   156  				rerr = err
   157  				break
   158  			}
   159  		}
   160  	}()
   161  
   162  	expected := &bytes.Buffer{}
   163  
   164  	// test basic start and no new lines
   165  	pw.Write([]byte("first data"))
   166  	expected.WriteString("first data")
   167  	requireRead(expected)
   168  	require.Equal(t, "", h.escaped())
   169  
   170  	// test ~. appearing in middle of line but stop at new line
   171  	pw.Write([]byte("~.inmiddleappears\n"))
   172  	expected.WriteString("~.inmiddleappears\n")
   173  	requireRead(expected)
   174  	require.Equal(t, "", h.escaped())
   175  
   176  	// from here on we test \n~ at boundary
   177  
   178  	// ~~ after new line; and stop at \n~
   179  	pw.Write([]byte("~~second line\n~"))
   180  	expected.WriteString("~second line\n")
   181  	requireRead(expected)
   182  	require.Equal(t, "", h.escaped())
   183  
   184  	// . to be skipped; stop at \n~ again
   185  	pw.Write([]byte(".third line\n~"))
   186  	expected.WriteString("third line\n")
   187  	requireRead(expected)
   188  	require.Equal(t, ".", h.escaped())
   189  
   190  	// q to be emitted; stop at \n
   191  	pw.Write([]byte("qfourth line\n"))
   192  	expected.WriteString("~qfourth line\n")
   193  	requireRead(expected)
   194  	require.Equal(t, ".q", h.escaped())
   195  
   196  	// ~. to be skipped; stop at \n~
   197  	pw.Write([]byte("~.fifth line\n~"))
   198  	expected.WriteString("fifth line\n")
   199  	requireRead(expected)
   200  	require.Equal(t, ".q.", h.escaped())
   201  
   202  	// ~ alone after \n~ - should be emitted
   203  	pw.Write([]byte("~"))
   204  	expected.WriteString("~")
   205  	requireRead(expected)
   206  	require.Equal(t, ".q.", h.escaped())
   207  
   208  	// rest of line ending with \n~
   209  	pw.Write([]byte("rest of line\n~"))
   210  	expected.WriteString("rest of line\n")
   211  	requireRead(expected)
   212  	require.Equal(t, ".q.", h.escaped())
   213  
   214  	// m alone after \n~ - should be emitted with ~
   215  	pw.Write([]byte("m"))
   216  	expected.WriteString("~m")
   217  	requireRead(expected)
   218  	require.Equal(t, ".q.m", h.escaped())
   219  
   220  	// rest of line and end with \n
   221  	pw.Write([]byte("onemore line\n"))
   222  	expected.WriteString("onemore line\n")
   223  	requireRead(expected)
   224  	require.Equal(t, ".q.m", h.escaped())
   225  
   226  	// ~q to be emitted stop at \n~; last charcater
   227  	pw.Write([]byte("~qlast line\n~"))
   228  	expected.WriteString("~qlast line\n")
   229  	requireRead(expected)
   230  	require.Equal(t, ".q.mq", h.escaped())
   231  
   232  	// last ~ gets emitted and we preserve error
   233  	eerr := errors.New("my custom error")
   234  	pw.CloseWithError(eerr)
   235  	expected.WriteString("~")
   236  	requireRead(expected)
   237  	require.Equal(t, ".q.mq", h.escaped())
   238  
   239  	wg.Wait()
   240  	require.Error(t, rerr)
   241  	require.Equal(t, eerr, rerr)
   242  }
   243  
   244  func TestEscapingReader_Generated_EquivalentToNaive(t *testing.T) {
   245  	f := func(v readingInput) bool {
   246  		return checkEquivalenceToNaive(t, string(v))
   247  	}
   248  
   249  	require.NoError(t, quick.Check(f, &quick.Config{
   250  		MaxCountScale: 200,
   251  	}))
   252  }
   253  
   254  // testHandler is a conveneient struct for finding "escaped" ascii letters
   255  // in escaping reader.
   256  // We avoid complicated unicode characters that may cross byte boundary
   257  type testHandler struct {
   258  	l      sync.Mutex
   259  	result string
   260  }
   261  
   262  // handler is method to be passed to escaping io reader
   263  func (t *testHandler) handler(c byte) bool {
   264  	rc := rune(c)
   265  	simple := unicode.IsLetter(rc) ||
   266  		unicode.IsDigit(rc) ||
   267  		unicode.IsPunct(rc) ||
   268  		unicode.IsSymbol(rc)
   269  
   270  	if simple {
   271  		t.l.Lock()
   272  		t.result += string([]byte{c})
   273  		t.l.Unlock()
   274  	}
   275  	return c == '.'
   276  }
   277  
   278  // escaped returns all seen escaped characters so far
   279  func (t *testHandler) escaped() string {
   280  	t.l.Lock()
   281  	defer t.l.Unlock()
   282  
   283  	return t.result
   284  }
   285  
   286  // checkEquivalence returns true if parsing input with naive implementation
   287  // is equivalent to our reader
   288  func checkEquivalenceToNaive(t *testing.T, input string) bool {
   289  	nh := &testHandler{}
   290  	expected := naiveEscapeCharacters(input, '~', nh.handler)
   291  
   292  	foundH := &testHandler{}
   293  
   294  	var inputReader io.Reader = bytes.NewBufferString(input)
   295  	inputReader = &arbtiraryReader{
   296  		buf:         inputReader.(*bytes.Buffer),
   297  		maxReadOnce: 10,
   298  	}
   299  	filter := NewReader(inputReader, '~', foundH.handler)
   300  	var found bytes.Buffer
   301  	_, err := io.Copy(&found, filter)
   302  	if err != nil {
   303  		t.Logf("unexpected error while reading: %v", err)
   304  		return false
   305  	}
   306  
   307  	if nh.escaped() == foundH.escaped() && expected == found.String() {
   308  		return true
   309  	}
   310  
   311  	t.Logf("escaped differed=%v expected=%v found=%v", nh.escaped() != foundH.escaped(), nh.escaped(), foundH.escaped())
   312  	t.Logf("read  differed=%v expected=%s found=%v", expected != found.String(), expected, found.String())
   313  	return false
   314  
   315  }
   316  
   317  func TestEscapingReader_Generated_EquivalentToReadOnce(t *testing.T) {
   318  	f := func(v readingInput) bool {
   319  		return checkEquivalenceToNaive(t, string(v))
   320  	}
   321  
   322  	require.NoError(t, quick.Check(f, &quick.Config{
   323  		MaxCountScale: 200,
   324  	}))
   325  }
   326  
   327  // checkEquivalenceToReadOnce returns true if parsing input in a single
   328  // read matches multiple reads
   329  func checkEquivalenceToReadOnce(t *testing.T, input string) bool {
   330  	nh := &testHandler{}
   331  	var expected bytes.Buffer
   332  
   333  	// getting expected value from read all at once
   334  	{
   335  		buf := make([]byte, len(input)+5)
   336  		inputReader := NewReader(bytes.NewBufferString(input), '~', nh.handler)
   337  		_, err := io.CopyBuffer(&expected, inputReader, buf)
   338  		if err != nil {
   339  			t.Logf("unexpected error while reading: %v", err)
   340  			return false
   341  		}
   342  	}
   343  
   344  	foundH := &testHandler{}
   345  	var found bytes.Buffer
   346  
   347  	// getting found by using arbitrary reader
   348  	{
   349  		inputReader := &arbtiraryReader{
   350  			buf:         bytes.NewBufferString(input),
   351  			maxReadOnce: 10,
   352  		}
   353  		filter := NewReader(inputReader, '~', foundH.handler)
   354  		_, err := io.Copy(&found, filter)
   355  		if err != nil {
   356  			t.Logf("unexpected error while reading: %v", err)
   357  			return false
   358  		}
   359  	}
   360  
   361  	if nh.escaped() == foundH.escaped() && expected.String() == found.String() {
   362  		return true
   363  	}
   364  
   365  	t.Logf("escaped differed=%v expected=%v found=%v", nh.escaped() != foundH.escaped(), nh.escaped(), foundH.escaped())
   366  	t.Logf("read  differed=%v expected=%s found=%v", expected.String() != found.String(), expected.String(), found.String())
   367  	return false
   368  
   369  }
   370  
   371  // readingInput is a string with some quick generation capability to
   372  // inject some \n, \n~., \n~q in text
   373  type readingInput string
   374  
   375  func (i readingInput) Generate(rand *rand.Rand, size int) reflect.Value {
   376  	v, ok := quick.Value(reflect.TypeOf(""), rand)
   377  	if !ok {
   378  		panic("couldn't generate a string")
   379  	}
   380  
   381  	// inject some terminals
   382  	var b bytes.Buffer
   383  	injectProbabilistically := func() {
   384  		p := rand.Float32()
   385  		if p < 0.05 {
   386  			b.WriteString("\n~.")
   387  		} else if p < 0.10 {
   388  			b.WriteString("\n~q")
   389  		} else if p < 0.15 {
   390  			b.WriteString("\n")
   391  		} else if p < 0.2 {
   392  			b.WriteString("~")
   393  		} else if p < 0.25 {
   394  			b.WriteString("~~")
   395  		}
   396  	}
   397  
   398  	for _, c := range v.String() {
   399  		injectProbabilistically()
   400  		b.WriteRune(c)
   401  	}
   402  
   403  	injectProbabilistically()
   404  
   405  	return reflect.ValueOf(readingInput(b.String()))
   406  }
   407  
   408  // naiveEscapeCharacters is a simplified implementation that operates
   409  // on entire unchunked string.  Uses regexp implementation.
   410  //
   411  // It differs from the other implementation in handling unicode characters
   412  // proceeding `\n~`
   413  func naiveEscapeCharacters(input string, escapeChar byte, h Handler) string {
   414  	reg := regexp.MustCompile(fmt.Sprintf("(\n|\r)%c.", escapeChar))
   415  
   416  	// check first appearances
   417  	if len(input) > 1 && input[0] == escapeChar {
   418  		if input[1] == escapeChar {
   419  			input = input[1:]
   420  		} else if h(input[1]) {
   421  			input = input[2:]
   422  		} else {
   423  			// we are good
   424  		}
   425  
   426  	}
   427  
   428  	return reg.ReplaceAllStringFunc(input, func(match string) string {
   429  		// match can be more than three bytes because of unicode
   430  		if len(match) < 3 {
   431  			panic(fmt.Errorf("match is less than characters: %d %s", len(match), match))
   432  		}
   433  
   434  		c := match[2]
   435  
   436  		// ignore some unicode partial codes
   437  		ltr := len(match) > 3 ||
   438  			('a' <= c && c <= 'z') ||
   439  			('A' <= c && c <= 'Z') ||
   440  			('0' <= c && c <= '9') ||
   441  			(c == '~' || c == '.' || c == escapeChar)
   442  
   443  		if c == escapeChar {
   444  			return match[:2]
   445  		} else if ltr && h(c) {
   446  			return match[:1]
   447  		} else {
   448  			return match
   449  		}
   450  	})
   451  }
   452  
   453  // arbitraryReader is a reader that reads arbitrary length at a time
   454  // to simulate input being read in chunks.
   455  type arbtiraryReader struct {
   456  	buf         *bytes.Buffer
   457  	maxReadOnce int
   458  }
   459  
   460  func (r *arbtiraryReader) Read(buf []byte) (int, error) {
   461  	l := r.buf.Len()
   462  	if l == 0 || l == 1 {
   463  		return r.buf.Read(buf)
   464  	}
   465  
   466  	if l > r.maxReadOnce {
   467  		l = r.maxReadOnce
   468  	}
   469  	if l != 1 {
   470  		l = rand.Intn(l-1) + 1
   471  	}
   472  	if l > len(buf) {
   473  		l = len(buf)
   474  	}
   475  
   476  	return r.buf.Read(buf[:l])
   477  }
   478  
   479  func assertEventually(t *testing.T, testFn func() (bool, error)) {
   480  	start := time.Now()
   481  	var err error
   482  	var b bool
   483  	for {
   484  		if time.Since(start) > 2*time.Second {
   485  			assert.Fail(t, "timed out", "error: %v", err)
   486  		}
   487  
   488  		b, err = testFn()
   489  		if b {
   490  			return
   491  		}
   492  
   493  		time.Sleep(50 * time.Millisecond)
   494  	}
   495  }