github.com/deanMdreon/kafka-go@v0.4.32/protocol/prototest/reflect.go (about)

     1  package prototest
     2  
     3  import (
     4  	"errors"
     5  	"io"
     6  	"reflect"
     7  	"time"
     8  
     9  	"github.com/deanMdreon/kafka-go/protocol"
    10  )
    11  
    12  var (
    13  	recordReader = reflect.TypeOf((*protocol.RecordReader)(nil)).Elem()
    14  )
    15  
    16  func closeMessage(m protocol.Message) {
    17  	forEachField(reflect.ValueOf(m), func(v reflect.Value) {
    18  		if v.Type().Implements(recordReader) {
    19  			rr := v.Interface().(protocol.RecordReader)
    20  			for {
    21  				r, err := rr.ReadRecord()
    22  				if err != nil {
    23  					break
    24  				}
    25  				if r.Key != nil {
    26  					r.Key.Close()
    27  				}
    28  				if r.Value != nil {
    29  					r.Value.Close()
    30  				}
    31  			}
    32  		}
    33  	})
    34  }
    35  
    36  func load(v interface{}) (reset func()) {
    37  	return loadValue(reflect.ValueOf(v))
    38  }
    39  
    40  func loadValue(v reflect.Value) (reset func()) {
    41  	resets := []func(){}
    42  
    43  	forEachField(v, func(f reflect.Value) {
    44  		switch x := f.Interface().(type) {
    45  		case protocol.RecordReader:
    46  			records := loadRecords(x)
    47  			resetFunc := func() {
    48  				f.Set(reflect.ValueOf(protocol.NewRecordReader(makeRecords(records)...)))
    49  			}
    50  			resetFunc()
    51  			resets = append(resets, resetFunc)
    52  		}
    53  	})
    54  
    55  	return func() {
    56  		for _, f := range resets {
    57  			f()
    58  		}
    59  	}
    60  }
    61  
    62  func forEachField(v reflect.Value, do func(reflect.Value)) {
    63  	for v.Kind() == reflect.Ptr {
    64  		if v.IsNil() {
    65  			return
    66  		}
    67  		v = v.Elem()
    68  	}
    69  
    70  	switch v.Kind() {
    71  	case reflect.Slice:
    72  		for i, n := 0, v.Len(); i < n; i++ {
    73  			forEachField(v.Index(i), do)
    74  		}
    75  
    76  	case reflect.Struct:
    77  		for i, n := 0, v.NumField(); i < n; i++ {
    78  			forEachField(v.Field(i), do)
    79  		}
    80  
    81  	default:
    82  		do(v)
    83  	}
    84  }
    85  
    86  type memoryRecord struct {
    87  	offset  int64
    88  	time    time.Time
    89  	key     []byte
    90  	value   []byte
    91  	headers []protocol.Header
    92  }
    93  
    94  func (m *memoryRecord) Record() protocol.Record {
    95  	return protocol.Record{
    96  		Offset:  m.offset,
    97  		Time:    m.time,
    98  		Key:     protocol.NewBytes(m.key),
    99  		Value:   protocol.NewBytes(m.value),
   100  		Headers: m.headers,
   101  	}
   102  }
   103  
   104  func makeRecords(memoryRecords []memoryRecord) []protocol.Record {
   105  	records := make([]protocol.Record, len(memoryRecords))
   106  	for i, m := range memoryRecords {
   107  		records[i] = m.Record()
   108  	}
   109  	return records
   110  }
   111  
   112  func loadRecords(r protocol.RecordReader) []memoryRecord {
   113  	records := []memoryRecord{}
   114  
   115  	for {
   116  		rec, err := r.ReadRecord()
   117  		if err != nil {
   118  			if errors.Is(err, io.EOF) {
   119  				return records
   120  			}
   121  			panic(err)
   122  		}
   123  		records = append(records, memoryRecord{
   124  			offset:  rec.Offset,
   125  			time:    rec.Time,
   126  			key:     readAll(rec.Key),
   127  			value:   readAll(rec.Value),
   128  			headers: rec.Headers,
   129  		})
   130  	}
   131  }
   132  
   133  func readAll(bytes protocol.Bytes) []byte {
   134  	if bytes != nil {
   135  		defer bytes.Close()
   136  	}
   137  	b, err := protocol.ReadAll(bytes)
   138  	if err != nil {
   139  		panic(err)
   140  	}
   141  	return b
   142  }