github.com/segmentio/kafka-go@v0.4.48-0.20240318174348-3f6244eb34fd/protocol/prototest/reflect.go (about) 1 package prototest 2 3 import ( 4 "bytes" 5 "errors" 6 "io" 7 "reflect" 8 "time" 9 10 "github.com/segmentio/kafka-go/protocol" 11 ) 12 13 var ( 14 recordReader = reflect.TypeOf((*protocol.RecordReader)(nil)).Elem() 15 ) 16 17 func closeMessage(m protocol.Message) { 18 forEachField(reflect.ValueOf(m), func(v reflect.Value) { 19 if v.Type().Implements(recordReader) { 20 rr := v.Interface().(protocol.RecordReader) 21 for { 22 r, err := rr.ReadRecord() 23 if err != nil { 24 break 25 } 26 if r.Key != nil { 27 r.Key.Close() 28 } 29 if r.Value != nil { 30 r.Value.Close() 31 } 32 } 33 } 34 }) 35 } 36 37 func load(v interface{}) (reset func()) { 38 return loadValue(reflect.ValueOf(v)) 39 } 40 41 func loadValue(v reflect.Value) (reset func()) { 42 resets := []func(){} 43 44 forEachField(v, func(f reflect.Value) { 45 switch x := f.Interface().(type) { 46 case protocol.RecordReader: 47 records := loadRecords(x) 48 resetFunc := func() { 49 f.Set(reflect.ValueOf(protocol.NewRecordReader(makeRecords(records)...))) 50 } 51 resetFunc() 52 resets = append(resets, resetFunc) 53 case io.Reader: 54 buf, _ := io.ReadAll(x) 55 resetFunc := func() { 56 f.Set(reflect.ValueOf(bytes.NewBuffer(buf))) 57 } 58 resetFunc() 59 resets = append(resets, resetFunc) 60 } 61 }) 62 63 return func() { 64 for _, f := range resets { 65 f() 66 } 67 } 68 } 69 70 func forEachField(v reflect.Value, do func(reflect.Value)) { 71 for v.Kind() == reflect.Ptr { 72 if v.IsNil() { 73 return 74 } 75 v = v.Elem() 76 } 77 78 switch v.Kind() { 79 case reflect.Slice: 80 for i, n := 0, v.Len(); i < n; i++ { 81 forEachField(v.Index(i), do) 82 } 83 84 case reflect.Struct: 85 for i, n := 0, v.NumField(); i < n; i++ { 86 forEachField(v.Field(i), do) 87 } 88 89 default: 90 do(v) 91 } 92 } 93 94 type memoryRecord struct { 95 offset int64 96 time time.Time 97 key []byte 98 value []byte 99 headers []protocol.Header 100 } 101 102 func (m *memoryRecord) Record() protocol.Record { 103 return protocol.Record{ 104 Offset: m.offset, 105 Time: m.time, 106 Key: protocol.NewBytes(m.key), 107 Value: protocol.NewBytes(m.value), 108 Headers: m.headers, 109 } 110 } 111 112 func makeRecords(memoryRecords []memoryRecord) []protocol.Record { 113 records := make([]protocol.Record, len(memoryRecords)) 114 for i, m := range memoryRecords { 115 records[i] = m.Record() 116 } 117 return records 118 } 119 120 func loadRecords(r protocol.RecordReader) []memoryRecord { 121 records := []memoryRecord{} 122 123 for { 124 rec, err := r.ReadRecord() 125 if err != nil { 126 if errors.Is(err, io.EOF) { 127 return records 128 } 129 panic(err) 130 } 131 records = append(records, memoryRecord{ 132 offset: rec.Offset, 133 time: rec.Time, 134 key: readAll(rec.Key), 135 value: readAll(rec.Value), 136 headers: rec.Headers, 137 }) 138 } 139 } 140 141 func readAll(bytes protocol.Bytes) []byte { 142 if bytes != nil { 143 defer bytes.Close() 144 } 145 b, err := protocol.ReadAll(bytes) 146 if err != nil { 147 panic(err) 148 } 149 return b 150 }