github.com/deanMdreon/kafka-go@v0.4.32/protocol/prototest/reflect.go (about) 1 package prototest 2 3 import ( 4 "errors" 5 "io" 6 "reflect" 7 "time" 8 9 "github.com/deanMdreon/kafka-go/protocol" 10 ) 11 12 var ( 13 recordReader = reflect.TypeOf((*protocol.RecordReader)(nil)).Elem() 14 ) 15 16 func closeMessage(m protocol.Message) { 17 forEachField(reflect.ValueOf(m), func(v reflect.Value) { 18 if v.Type().Implements(recordReader) { 19 rr := v.Interface().(protocol.RecordReader) 20 for { 21 r, err := rr.ReadRecord() 22 if err != nil { 23 break 24 } 25 if r.Key != nil { 26 r.Key.Close() 27 } 28 if r.Value != nil { 29 r.Value.Close() 30 } 31 } 32 } 33 }) 34 } 35 36 func load(v interface{}) (reset func()) { 37 return loadValue(reflect.ValueOf(v)) 38 } 39 40 func loadValue(v reflect.Value) (reset func()) { 41 resets := []func(){} 42 43 forEachField(v, func(f reflect.Value) { 44 switch x := f.Interface().(type) { 45 case protocol.RecordReader: 46 records := loadRecords(x) 47 resetFunc := func() { 48 f.Set(reflect.ValueOf(protocol.NewRecordReader(makeRecords(records)...))) 49 } 50 resetFunc() 51 resets = append(resets, resetFunc) 52 } 53 }) 54 55 return func() { 56 for _, f := range resets { 57 f() 58 } 59 } 60 } 61 62 func forEachField(v reflect.Value, do func(reflect.Value)) { 63 for v.Kind() == reflect.Ptr { 64 if v.IsNil() { 65 return 66 } 67 v = v.Elem() 68 } 69 70 switch v.Kind() { 71 case reflect.Slice: 72 for i, n := 0, v.Len(); i < n; i++ { 73 forEachField(v.Index(i), do) 74 } 75 76 case reflect.Struct: 77 for i, n := 0, v.NumField(); i < n; i++ { 78 forEachField(v.Field(i), do) 79 } 80 81 default: 82 do(v) 83 } 84 } 85 86 type memoryRecord struct { 87 offset int64 88 time time.Time 89 key []byte 90 value []byte 91 headers []protocol.Header 92 } 93 94 func (m *memoryRecord) Record() protocol.Record { 95 return protocol.Record{ 96 Offset: m.offset, 97 Time: m.time, 98 Key: protocol.NewBytes(m.key), 99 Value: protocol.NewBytes(m.value), 100 Headers: m.headers, 101 } 102 } 103 104 func makeRecords(memoryRecords []memoryRecord) []protocol.Record { 105 records := make([]protocol.Record, len(memoryRecords)) 106 for i, m := range memoryRecords { 107 records[i] = m.Record() 108 } 109 return records 110 } 111 112 func loadRecords(r protocol.RecordReader) []memoryRecord { 113 records := []memoryRecord{} 114 115 for { 116 rec, err := r.ReadRecord() 117 if err != nil { 118 if errors.Is(err, io.EOF) { 119 return records 120 } 121 panic(err) 122 } 123 records = append(records, memoryRecord{ 124 offset: rec.Offset, 125 time: rec.Time, 126 key: readAll(rec.Key), 127 value: readAll(rec.Value), 128 headers: rec.Headers, 129 }) 130 } 131 } 132 133 func readAll(bytes protocol.Bytes) []byte { 134 if bytes != nil { 135 defer bytes.Close() 136 } 137 b, err := protocol.ReadAll(bytes) 138 if err != nil { 139 panic(err) 140 } 141 return b 142 }