github.com/segmentio/kafka-go@v0.4.48-0.20240318174348-3f6244eb34fd/protocol/prototest/reflect.go (about)

     1  package prototest
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"io"
     7  	"reflect"
     8  	"time"
     9  
    10  	"github.com/segmentio/kafka-go/protocol"
    11  )
    12  
    13  var (
    14  	recordReader = reflect.TypeOf((*protocol.RecordReader)(nil)).Elem()
    15  )
    16  
    17  func closeMessage(m protocol.Message) {
    18  	forEachField(reflect.ValueOf(m), func(v reflect.Value) {
    19  		if v.Type().Implements(recordReader) {
    20  			rr := v.Interface().(protocol.RecordReader)
    21  			for {
    22  				r, err := rr.ReadRecord()
    23  				if err != nil {
    24  					break
    25  				}
    26  				if r.Key != nil {
    27  					r.Key.Close()
    28  				}
    29  				if r.Value != nil {
    30  					r.Value.Close()
    31  				}
    32  			}
    33  		}
    34  	})
    35  }
    36  
    37  func load(v interface{}) (reset func()) {
    38  	return loadValue(reflect.ValueOf(v))
    39  }
    40  
    41  func loadValue(v reflect.Value) (reset func()) {
    42  	resets := []func(){}
    43  
    44  	forEachField(v, func(f reflect.Value) {
    45  		switch x := f.Interface().(type) {
    46  		case protocol.RecordReader:
    47  			records := loadRecords(x)
    48  			resetFunc := func() {
    49  				f.Set(reflect.ValueOf(protocol.NewRecordReader(makeRecords(records)...)))
    50  			}
    51  			resetFunc()
    52  			resets = append(resets, resetFunc)
    53  		case io.Reader:
    54  			buf, _ := io.ReadAll(x)
    55  			resetFunc := func() {
    56  				f.Set(reflect.ValueOf(bytes.NewBuffer(buf)))
    57  			}
    58  			resetFunc()
    59  			resets = append(resets, resetFunc)
    60  		}
    61  	})
    62  
    63  	return func() {
    64  		for _, f := range resets {
    65  			f()
    66  		}
    67  	}
    68  }
    69  
    70  func forEachField(v reflect.Value, do func(reflect.Value)) {
    71  	for v.Kind() == reflect.Ptr {
    72  		if v.IsNil() {
    73  			return
    74  		}
    75  		v = v.Elem()
    76  	}
    77  
    78  	switch v.Kind() {
    79  	case reflect.Slice:
    80  		for i, n := 0, v.Len(); i < n; i++ {
    81  			forEachField(v.Index(i), do)
    82  		}
    83  
    84  	case reflect.Struct:
    85  		for i, n := 0, v.NumField(); i < n; i++ {
    86  			forEachField(v.Field(i), do)
    87  		}
    88  
    89  	default:
    90  		do(v)
    91  	}
    92  }
    93  
    94  type memoryRecord struct {
    95  	offset  int64
    96  	time    time.Time
    97  	key     []byte
    98  	value   []byte
    99  	headers []protocol.Header
   100  }
   101  
   102  func (m *memoryRecord) Record() protocol.Record {
   103  	return protocol.Record{
   104  		Offset:  m.offset,
   105  		Time:    m.time,
   106  		Key:     protocol.NewBytes(m.key),
   107  		Value:   protocol.NewBytes(m.value),
   108  		Headers: m.headers,
   109  	}
   110  }
   111  
   112  func makeRecords(memoryRecords []memoryRecord) []protocol.Record {
   113  	records := make([]protocol.Record, len(memoryRecords))
   114  	for i, m := range memoryRecords {
   115  		records[i] = m.Record()
   116  	}
   117  	return records
   118  }
   119  
   120  func loadRecords(r protocol.RecordReader) []memoryRecord {
   121  	records := []memoryRecord{}
   122  
   123  	for {
   124  		rec, err := r.ReadRecord()
   125  		if err != nil {
   126  			if errors.Is(err, io.EOF) {
   127  				return records
   128  			}
   129  			panic(err)
   130  		}
   131  		records = append(records, memoryRecord{
   132  			offset:  rec.Offset,
   133  			time:    rec.Time,
   134  			key:     readAll(rec.Key),
   135  			value:   readAll(rec.Value),
   136  			headers: rec.Headers,
   137  		})
   138  	}
   139  }
   140  
   141  func readAll(bytes protocol.Bytes) []byte {
   142  	if bytes != nil {
   143  		defer bytes.Close()
   144  	}
   145  	b, err := protocol.ReadAll(bytes)
   146  	if err != nil {
   147  		panic(err)
   148  	}
   149  	return b
   150  }