github.com/anacrolix/torrent@v1.61.0/bencode/encode.go (about)

     1  package bencode
     2  
     3  import (
     4  	"io"
     5  	"math/big"
     6  	"reflect"
     7  	"runtime"
     8  	"sort"
     9  	"strconv"
    10  	"sync"
    11  
    12  	"github.com/anacrolix/missinggo"
    13  )
    14  
    15  func isEmptyValue(v reflect.Value) bool {
    16  	return missinggo.IsEmptyValue(v)
    17  }
    18  
// Encoder writes bencode encodings of Go values to an io.Writer.
//
// w receives every byte of output; a failed write aborts the encoding (the
// error is raised by the write method and turned back into a return value by
// Encode). scratch is reused as temporary storage when formatting integers
// and when chunking strings, to avoid allocating on those paths. Because
// scratch is shared mutable state with no locking, a single Encoder should
// not be used from multiple goroutines at once.
type Encoder struct {
	w       io.Writer
	scratch [64]byte
}
    23  
    24  func (e *Encoder) Encode(v interface{}) (err error) {
    25  	if v == nil {
    26  		return
    27  	}
    28  	defer func() {
    29  		if e := recover(); e != nil {
    30  			if _, ok := e.(runtime.Error); ok {
    31  				panic(e)
    32  			}
    33  			var ok bool
    34  			err, ok = e.(error)
    35  			if !ok {
    36  				panic(e)
    37  			}
    38  		}
    39  	}()
    40  	e.reflectValue(reflect.ValueOf(v))
    41  	return nil
    42  }
    43  
    44  type stringValues []reflect.Value
    45  
    46  func (sv stringValues) Len() int           { return len(sv) }
    47  func (sv stringValues) Swap(i, j int)      { sv[i], sv[j] = sv[j], sv[i] }
    48  func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
    49  func (sv stringValues) get(i int) string   { return sv[i].String() }
    50  
    51  func (e *Encoder) write(s []byte) {
    52  	_, err := e.w.Write(s)
    53  	if err != nil {
    54  		panic(err)
    55  	}
    56  }
    57  
    58  func (e *Encoder) writeString(s string) {
    59  	for s != "" {
    60  		n := copy(e.scratch[:], s)
    61  		s = s[n:]
    62  		e.write(e.scratch[:n])
    63  	}
    64  }
    65  
    66  func (e *Encoder) reflectString(s string) {
    67  	e.writeStringPrefix(int64(len(s)))
    68  	e.writeString(s)
    69  }
    70  
    71  func (e *Encoder) writeStringPrefix(l int64) {
    72  	b := strconv.AppendInt(e.scratch[:0], l, 10)
    73  	e.write(b)
    74  	e.writeString(":")
    75  }
    76  
    77  func (e *Encoder) reflectByteSlice(s []byte) {
    78  	e.writeStringPrefix(int64(len(s)))
    79  	e.write(s)
    80  }
    81  
    82  // Returns true if the value implements Marshaler interface and marshaling was
    83  // done successfully.
    84  func (e *Encoder) reflectMarshaler(v reflect.Value) bool {
    85  	if !v.Type().Implements(marshalerType) {
    86  		if v.Kind() != reflect.Ptr && v.CanAddr() && v.Addr().Type().Implements(marshalerType) {
    87  			v = v.Addr()
    88  		} else {
    89  			return false
    90  		}
    91  	}
    92  	m := v.Interface().(Marshaler)
    93  	data, err := m.MarshalBencode()
    94  	if err != nil {
    95  		panic(&MarshalerError{v.Type(), err})
    96  	}
    97  	e.write(data)
    98  	return true
    99  }
   100  
   101  var bigIntType = reflect.TypeOf((*big.Int)(nil)).Elem()
   102  
   103  func (e *Encoder) reflectValue(v reflect.Value) {
   104  	if e.reflectMarshaler(v) {
   105  		return
   106  	}
   107  
   108  	if v.Type() == bigIntType {
   109  		e.writeString("i")
   110  		bi := v.Interface().(big.Int)
   111  		e.writeString(bi.String())
   112  		e.writeString("e")
   113  		return
   114  	}
   115  
   116  	switch v.Kind() {
   117  	case reflect.Bool:
   118  		if v.Bool() {
   119  			e.writeString("i1e")
   120  		} else {
   121  			e.writeString("i0e")
   122  		}
   123  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   124  		e.writeString("i")
   125  		b := strconv.AppendInt(e.scratch[:0], v.Int(), 10)
   126  		e.write(b)
   127  		e.writeString("e")
   128  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   129  		e.writeString("i")
   130  		b := strconv.AppendUint(e.scratch[:0], v.Uint(), 10)
   131  		e.write(b)
   132  		e.writeString("e")
   133  	case reflect.String:
   134  		e.reflectString(v.String())
   135  	case reflect.Struct:
   136  		e.writeString("d")
   137  		for _, ef := range getEncodeFields(v.Type()) {
   138  			fieldValue := ef.i(v)
   139  			if !fieldValue.IsValid() {
   140  				continue
   141  			}
   142  			if ef.omitEmpty && isEmptyValue(fieldValue) {
   143  				continue
   144  			}
   145  			e.reflectString(ef.tag)
   146  			e.reflectValue(fieldValue)
   147  		}
   148  		e.writeString("e")
   149  	case reflect.Map:
   150  		if v.Type().Key().Kind() != reflect.String {
   151  			panic(&MarshalTypeError{v.Type()})
   152  		}
   153  		if v.IsNil() {
   154  			e.writeString("de")
   155  			break
   156  		}
   157  		e.writeString("d")
   158  		sv := stringValues(v.MapKeys())
   159  		sort.Sort(sv)
   160  		for _, key := range sv {
   161  			e.reflectString(key.String())
   162  			e.reflectValue(v.MapIndex(key))
   163  		}
   164  		e.writeString("e")
   165  	case reflect.Slice, reflect.Array:
   166  		e.reflectSequence(v)
   167  	case reflect.Interface:
   168  		e.reflectValue(v.Elem())
   169  	case reflect.Ptr:
   170  		if v.IsNil() {
   171  			v = reflect.Zero(v.Type().Elem())
   172  		} else {
   173  			v = v.Elem()
   174  		}
   175  		e.reflectValue(v)
   176  	default:
   177  		panic(&MarshalTypeError{v.Type()})
   178  	}
   179  }
   180  
   181  func (e *Encoder) reflectSequence(v reflect.Value) {
   182  	// Use bencode string-type
   183  	if v.Type().Elem().Kind() == reflect.Uint8 {
   184  		if v.Kind() != reflect.Slice {
   185  			// Can't use []byte optimization
   186  			if !v.CanAddr() {
   187  				e.writeStringPrefix(int64(v.Len()))
   188  				for i := 0; i < v.Len(); i++ {
   189  					var b [1]byte
   190  					b[0] = byte(v.Index(i).Uint())
   191  					e.write(b[:])
   192  				}
   193  				return
   194  			}
   195  			v = v.Slice(0, v.Len())
   196  		}
   197  		s := v.Bytes()
   198  		e.reflectByteSlice(s)
   199  		return
   200  	}
   201  	if v.IsNil() {
   202  		e.writeString("le")
   203  		return
   204  	}
   205  	e.writeString("l")
   206  	for i, n := 0, v.Len(); i < n; i++ {
   207  		e.reflectValue(v.Index(i))
   208  	}
   209  	e.writeString("e")
   210  }
   211  
// encodeField describes one dictionary entry produced when encoding a struct.
//
// i extracts the entry's value from the struct value. For fields promoted
// from a nil embedded pointer it returns the zero reflect.Value, which the
// struct encoder skips. tag is the dictionary key: the Go field name, unless
// the struct tag supplies a key. omitEmpty requests that the entry be skipped
// when isEmptyValue reports the value as empty.
type encodeField struct {
	i         func(v reflect.Value) reflect.Value
	tag       string
	omitEmpty bool
}
   217  
   218  type encodeFieldsSortType []encodeField
   219  
   220  func (ef encodeFieldsSortType) Len() int           { return len(ef) }
   221  func (ef encodeFieldsSortType) Swap(i, j int)      { ef[i], ef[j] = ef[j], ef[i] }
   222  func (ef encodeFieldsSortType) Less(i, j int) bool { return ef[i].tag < ef[j].tag }
   223  
   224  var (
   225  	typeCacheLock     sync.RWMutex
   226  	encodeFieldsCache = make(map[reflect.Type][]encodeField)
   227  )
   228  
   229  func getEncodeFields(t reflect.Type) []encodeField {
   230  	typeCacheLock.RLock()
   231  	fs, ok := encodeFieldsCache[t]
   232  	typeCacheLock.RUnlock()
   233  	if ok {
   234  		return fs
   235  	}
   236  	fs = makeEncodeFields(t)
   237  	typeCacheLock.Lock()
   238  	defer typeCacheLock.Unlock()
   239  	encodeFieldsCache[t] = fs
   240  	return fs
   241  }
   242  
   243  func makeEncodeFields(t reflect.Type) (fs []encodeField) {
   244  	for _i, n := 0, t.NumField(); _i < n; _i++ {
   245  		i := _i
   246  		f := t.Field(i)
   247  		if f.PkgPath != "" {
   248  			continue
   249  		}
   250  		if f.Anonymous {
   251  			t := f.Type
   252  			if t.Kind() == reflect.Ptr {
   253  				t = t.Elem()
   254  			}
   255  			anonEFs := makeEncodeFields(t)
   256  			for aefi := range anonEFs {
   257  				anonEF := anonEFs[aefi]
   258  				bottomField := anonEF
   259  				bottomField.i = func(v reflect.Value) reflect.Value {
   260  					v = v.Field(i)
   261  					if v.Kind() == reflect.Ptr {
   262  						if v.IsNil() {
   263  							// This will skip serializing this value.
   264  							return reflect.Value{}
   265  						}
   266  						v = v.Elem()
   267  					}
   268  					return anonEF.i(v)
   269  				}
   270  				fs = append(fs, bottomField)
   271  			}
   272  			continue
   273  		}
   274  		var ef encodeField
   275  		ef.i = func(v reflect.Value) reflect.Value {
   276  			return v.Field(i)
   277  		}
   278  		ef.tag = f.Name
   279  
   280  		tv := getTag(f.Tag)
   281  		if tv.Ignore() {
   282  			continue
   283  		}
   284  		if tv.Key() != "" {
   285  			ef.tag = tv.Key()
   286  		}
   287  		ef.omitEmpty = tv.OmitEmpty()
   288  		fs = append(fs, ef)
   289  	}
   290  	fss := encodeFieldsSortType(fs)
   291  	sort.Sort(fss)
   292  	return fs
   293  }