github.com/xhebox/bstruct@v0.0.0-20221115052913-86d4d6d98866/builder.go (about)

     1  package bstruct
     2  
import (
	"fmt"
	"go/ast"
	"go/printer"
	"go/token"
	"io"
	"math"
	"sort"
	"unicode"
)
    12  
// defWrap is the default maximum length of each "//" line that commentGroup
// produces when splitting a comment. NewBuilder starts with it, and
// SetLineWrap falls back to it when given a non-positive width.
const defWrap = 77
    14  
// builtField holds one registered type together with the declarations that
// Process generates for it and that Print emits.
type builtField struct {
	// field describes the type's layout; it is the input to code generation.
	field *Field
	// typ is the generated "type <name> ..." declaration.
	typ   ast.GenDecl
	// enc is the generated Encode method.
	enc   ast.FuncDecl
	// dec is the generated Decode method.
	dec   ast.FuncDecl
	// extra holds additional generated declarations (getters/setters from
	// extraDecl).
	extra []ast.Decl
}
    22  
// Builder generates Go source for the types held in its types map.
//
// For each type it generates a type declaration, Encode/Decode methods and,
// optionally, per-field getters and setters. Configure it with the chainable
// Getter, Setter and SetLineWrap methods, then call Process followed by Print.
type Builder struct {
	// cnt numbers the temporary identifiers (v0, v1, ...) handed out by
	// newIdent; Process resets it to zero.
	cnt      uint
	// getter enables getter generation for struct fields (see extraDecl).
	getter   bool
	// setter enables setter generation for struct fields (see extraDecl).
	setter   bool
	// lineWrap is the maximum length of each "//" line built by commentGroup.
	lineWrap int
	// imports is the import declaration built by Process and emitted by Print.
	imports  *ast.GenDecl
	// types maps each generated type name to its description and generated
	// declarations. Entries are added elsewhere in the package (not visible
	// here); encField, decField and typWrap test membership to decide whether
	// to refer to a type by name.
	types    map[string]builtField
}
    31  
    32  func NewBuilder() *Builder {
    33  	return &Builder{
    34  		types:    make(map[string]builtField),
    35  		lineWrap: defWrap,
    36  	}
    37  }
    38  
    39  func (e *Builder) Getter(f bool) *Builder {
    40  	e.getter = f
    41  	return e
    42  }
    43  
    44  func (e *Builder) Setter(f bool) *Builder {
    45  	e.setter = f
    46  	return e
    47  }
    48  
    49  func (e *Builder) SetLineWrap(wrap int) *Builder {
    50  	if wrap <= 0 {
    51  		e.lineWrap = defWrap
    52  	}
    53  	return e
    54  }
    55  
// encPrim returns statements that serialize the value denoted by ptr to
// writer. It expands s's layout inline rather than calling a generated Encode
// method for s itself. Nested element/field values go through encField,
// which does call Encode for registered types.
//
// The generated code depends on s's kind:
//   - primitive: the raw bytes at &ptr are copied with
//     writer.Copy(unsafe.Pointer(&ptr), <kind size>).
//   - slice or string: the length is written with writer.WriteLen(len(ptr)).
//     Then, when the element kind is primitive, the backing memory is copied
//     in one writer.Copy from the reflect header's Data field; this is skipped
//     for an empty value. Otherwise each element is encoded in a range loop.
//   - struct: fields are encoded in s.strucFields order. An optional field is
//     preceded by its bool presence flag (named via newOpt), and its encoding
//     is wrapped in `if <flag> { ... }`.
//   - custom: the statements come from s.cusenc when it is set; otherwise
//     nothing is generated.
//
// Any other kind panics. Temporary identifiers come from e.newIdent, so the
// order of calls here determines the generated variable names.
//
// NOTE(review): reflect.StringHeader/SliceHeader are deprecated in recent Go
// releases in favour of unsafe.StringData/unsafe.SliceData. The emitted
// unsafe.Pointer(hdr.Data) conversion may also be reported by go vet's
// unsafeptr check.
func (e *Builder) encPrim(writer ast.Expr, ptr ast.Expr, s *Field) (stmts []ast.Stmt) {
	switch {
	case s.typ.IsPrimitive():
		stmts = append(stmts, newCallST(
			newSel(writer, "Copy"),
			unsafePtr(newPtr(ptr)),
			intLit(s.typ.Size()),
		))
	case s.typ.IsType(FieldSlice) || s.typ.IsType(FieldString):
		stmts = append(stmts,
			newCallST(
				newSel(writer, "WriteLen"),
				newLen(ptr),
			),
		)
		if s.sliceType.typ.IsPrimitive() {
			// Bulk path: reinterpret ptr as its reflect header and copy the
			// whole backing array in one call.
			hdr := e.newIdent()
			var bstmts []ast.Stmt
			if s.typ.IsType(FieldString) {
				bstmts = append(bstmts,
					newDef(hdr, newCall(&ast.ParenExpr{X: &ast.UnaryExpr{X: newSel("reflect", "StringHeader"), Op: token.MUL}}, unsafePtr(newPtr(ptr)))),
				)
			} else {
				bstmts = append(bstmts,
					newDef(hdr, newCall(&ast.ParenExpr{X: &ast.UnaryExpr{X: newSel("reflect", "SliceHeader"), Op: token.MUL}}, unsafePtr(newPtr(ptr)))),
				)
			}
			bstmts = append(bstmts,
				newCallST(
					newSel(writer, "Copy"),
					unsafePtr(newSel(hdr, "Data")),
					newMul(intLit(s.sliceType.typ.Size()), newLen(ptr)),
				),
			)
			// Guard the copy so an empty value never dereferences its Data field.
			stmts = append(stmts, &ast.IfStmt{
				Cond: &ast.BinaryExpr{X: newLen(ptr), Op: token.GTR, Y: intLit(0)},
				Body: &ast.BlockStmt{List: bstmts},
			})
		} else {
			// Element-wise path: `for i := range ptr { <encode ptr[i]> }`.
			i := e.newIdent()
			stmts = append(stmts, &ast.RangeStmt{
				Key:  i,
				Tok:  token.DEFINE,
				X:    ptr,
				Body: &ast.BlockStmt{List: e.encField(writer, newIdx(ptr, i), s.sliceType)},
			})
		}
	case s.typ.IsType(FieldStruct):
		for i := range s.strucFields {
			bstmts := e.encField(writer, newSel(ptr, s.strucFields[i].strucName), s.strucFields[i].Field)
			if s.strucFields[i].optional {
				// Write the presence flag, then the field only when the flag is set.
				hasField := newSel(ptr, newOpt(s.strucFields[i].strucName))
				stmts = append(stmts, e.encPrim(writer, hasField, New(FieldBool))...)
				stmts = append(stmts, &ast.IfStmt{
					Cond: hasField,
					Body: &ast.BlockStmt{List: bstmts},
				})
			} else {
				stmts = append(stmts, bstmts...)
			}
		}
	case s.typ.IsType(FieldCustom):
		if s.cusenc != nil {
			stmts = s.cusenc(writer, ptr, s)
		}
	default:
		panic("wth")
	}
	return
}
   126  
   127  func (e *Builder) encField(writer ast.Expr, ptr ast.Expr, s *Field) (stmts []ast.Stmt) {
   128  	if _, ok := e.types[s.typename]; ok {
   129  		stmts = append(stmts, newCallST(
   130  			newSel(ptr, "Encode"),
   131  			writer,
   132  		))
   133  		return
   134  	}
   135  
   136  	stmts = e.encPrim(writer, ptr, s)
   137  	return
   138  }
   139  
// decPrim returns statements that deserialize into the value denoted by ptr
// from reader, mirroring the layout written by encPrim. It expands s's layout
// inline. Nested element/field values go through decField, which calls
// Decode for registered types.
//
// The generated code depends on s's kind:
//   - primitive: the raw bytes are copied into &ptr with
//     reader.Copy(unsafe.Pointer(&ptr), <kind size>).
//   - slice or string: the length is read into a fresh temporary via
//     reader.ReadLen(). Only when it is > 0 is ptr populated:
//     for primitive elements, ptr's reflect header gets Len (and Cap, for
//     slices) set to the length and Data set to the result of
//     reader.Read(<elem size> * length). Presumably this aliases the reader's
//     buffer instead of copying; TODO confirm against the runtime package.
//     For other elements, ptr = make([]T, length) and each element is
//     decoded in a range loop.
//   - struct: fields are decoded in s.strucFields order. For an optional
//     field, its bool presence flag is read first and the field is decoded
//     only when the flag is true.
//   - custom: the statements come from s.cusdec when it is set; otherwise
//     nothing is generated.
//
// Any other kind panics. Temporary identifiers come from e.newIdent, so the
// order of calls here determines the generated variable names.
//
// NOTE(review): a zero length, or an absent optional field, leaves ptr
// untouched. Decoding into a reused, non-zero destination therefore keeps
// stale contents.
func (e *Builder) decPrim(reader ast.Expr, ptr ast.Expr, s *Field) (stmts []ast.Stmt) {
	switch {
	case s.typ.IsPrimitive():
		stmts = append(stmts, newCallST(
			newSel(reader, "Copy"),
			unsafePtr(newPtr(ptr)),
			intLit(s.typ.Size()),
		))
	case s.typ.IsType(FieldSlice) || s.typ.IsType(FieldString):
		length := e.newIdent()
		stmts = append(stmts, newDef(length, newCall(newSel(reader, "ReadLen"))))
		var bstmts []ast.Stmt
		if s.sliceType.typ.IsPrimitive() {
			// Bulk path: point ptr's reflect header at the bytes returned by
			// reader.Read instead of decoding element by element.
			hdr := e.newIdent()
			if s.typ.IsType(FieldSlice) {
				bstmts = append(bstmts,
					newDef(hdr, newCall(&ast.ParenExpr{X: &ast.UnaryExpr{X: newSel("reflect", "SliceHeader"), Op: token.MUL}}, unsafePtr(newPtr(ptr)))),
				)
			} else {
				bstmts = append(bstmts,
					newDef(hdr, newCall(&ast.ParenExpr{X: &ast.UnaryExpr{X: newSel("reflect", "StringHeader"), Op: token.MUL}}, unsafePtr(newPtr(ptr)))),
				)
			}
			bstmts = append(bstmts,
				newAssign(newSel(hdr, "Len"), length),
				newAssign(newSel(hdr, "Data"), newCall(
					newSel(reader, "Read"),
					newMul(intLit(s.sliceType.typ.Size()), length),
				)),
			)
			// Only slice headers have a Cap field.
			if s.typ.IsType(FieldSlice) {
				bstmts = append(bstmts,
					newAssign(newSel(hdr, "Cap"), length),
				)
			}
		} else {
			// Element-wise path: allocate, then `for i := range ptr { <decode ptr[i]> }`.
			i := e.newIdent()
			bstmts = append(bstmts,
				newAssign(ptr, newCall("make", &ast.ArrayType{Elt: e.typWrap(s.sliceType)}, length)),
				&ast.RangeStmt{
					Key:  newIdent(i),
					Tok:  token.DEFINE,
					X:    ptr,
					Body: &ast.BlockStmt{List: e.decField(reader, newIdx(ptr, i), s.sliceType)},
				},
			)
		}
		stmts = append(stmts, &ast.IfStmt{
			Cond: &ast.BinaryExpr{X: length, Op: token.GTR, Y: intLit(0)},
			Body: &ast.BlockStmt{List: bstmts},
		})
	case s.typ.IsType(FieldStruct):
		for i := range s.strucFields {
			bstmts := e.decField(reader, newSel(ptr, s.strucFields[i].strucName), s.strucFields[i].Field)
			if s.strucFields[i].optional {
				// Read the presence flag, then the field only when the flag is set.
				hasField := newSel(ptr, newOpt(s.strucFields[i].strucName))
				stmts = append(stmts, e.decPrim(reader, hasField, New(FieldBool))...)
				stmts = append(stmts, &ast.IfStmt{
					Cond: hasField,
					Body: &ast.BlockStmt{List: bstmts},
				})
			} else {
				stmts = append(stmts, bstmts...)
			}
		}
	case s.typ.IsType(FieldCustom):
		if s.cusdec != nil {
			stmts = s.cusdec(reader, ptr, s)
		}
	default:
		panic("wth")
	}
	return
}
   214  
   215  func (e *Builder) decField(reader, ptr ast.Expr, s *Field) (stmts []ast.Stmt) {
   216  	if _, ok := e.types[s.typename]; ok {
   217  		stmts = append(stmts, newCallST(
   218  			newSel(ptr, "Decode"),
   219  			reader,
   220  		))
   221  		return
   222  	}
   223  
   224  	stmts = e.decPrim(reader, ptr, s)
   225  	return
   226  }
   227  
   228  func (e *Builder) newIdent() *ast.Ident {
   229  	r := ast.NewIdent(fmt.Sprintf("v%d", e.cnt))
   230  	e.cnt++
   231  	return r
   232  }
   233  
   234  func (e *Builder) typPrim(s *Field) ast.Expr {
   235  	switch {
   236  	case s.typ.IsPrimitive() || s.typ.IsType(FieldString):
   237  		return newIdent(s.typ.String())
   238  	case s.typ.IsType(FieldSlice):
   239  		return &ast.ArrayType{
   240  			Elt: e.typWrap(s.sliceType),
   241  		}
   242  	case s.typ.IsType(FieldStruct):
   243  		var fields []*ast.Field
   244  		for _, field := range s.strucFields {
   245  			if field.optional {
   246  				fields = append(fields, &ast.Field{
   247  					Names: []*ast.Ident{ast.NewIdent(newOpt(field.strucName))},
   248  					Type:  newIdent(FieldBool.String()),
   249  				})
   250  			}
   251  			fields = append(fields, &ast.Field{
   252  				Names:   []*ast.Ident{ast.NewIdent(field.strucName)},
   253  				Comment: e.commentGroup(field.comment),
   254  				Type:    e.typWrap(field.Field),
   255  			})
   256  		}
   257  		return &ast.StructType{Fields: &ast.FieldList{List: fields}}
   258  	case s.typ.IsType(FieldCustom):
   259  		return s.custyp
   260  	default:
   261  		return newIdent("invalid")
   262  	}
   263  }
   264  
   265  func (e *Builder) typWrap(s *Field) ast.Expr {
   266  	if _, ok := e.types[s.typename]; ok {
   267  		return newIdent(s.typename)
   268  	}
   269  
   270  	return e.typPrim(s)
   271  }
   272  
   273  func (e *Builder) getFunc(el *Field, name string) (*ast.FuncDecl, ast.Expr) {
   274  	typ := &ast.UnaryExpr{X: newIdent(el.typename), Op: token.MUL}
   275  	idt := ast.NewIdent("v")
   276  	var val ast.Expr = idt
   277  	if t := el.typ; !t.IsType(FieldStruct) {
   278  		val = &ast.ParenExpr{X: &ast.UnaryExpr{X: val, Op: token.MUL}}
   279  	}
   280  	return &ast.FuncDecl{
   281  		Name: ast.NewIdent(name),
   282  		Recv: &ast.FieldList{List: []*ast.Field{
   283  			{Names: []*ast.Ident{idt}, Type: typ},
   284  		}},
   285  		Type: &ast.FuncType{
   286  			Params:  &ast.FieldList{List: []*ast.Field{}},
   287  			Results: &ast.FieldList{List: []*ast.Field{}},
   288  		},
   289  		Body: &ast.BlockStmt{List: []ast.Stmt{}},
   290  	}, val
   291  }
   292  
   293  func (e *Builder) getFieldGetter(p *Field, el StructField) ast.Decl {
   294  	typ := e.typPrim(el.Field)
   295  
   296  	var getterName string
   297  	if unicode.IsUpper(rune(el.strucName[0])) {
   298  		getterName = fmt.Sprintf("Get%s", el.strucName)
   299  	} else {
   300  		getterName = capitalize(el.strucName)
   301  	}
   302  	getter, val := e.getFunc(p, getterName)
   303  	getter.Type.Results.List = append(getter.Type.Results.List, &ast.Field{Type: typ})
   304  	getter.Body.List = append(getter.Body.List, &ast.ReturnStmt{Results: []ast.Expr{newSel(val, el.strucName)}})
   305  
   306  	return getter
   307  }
   308  
   309  func (e *Builder) getFieldSetter(p *Field, el StructField) ast.Decl {
   310  	typ := e.typPrim(el.Field)
   311  
   312  	setterName := fmt.Sprintf("Set%s", capitalize(el.strucName))
   313  	setter, val := e.getFunc(p, setterName)
   314  	sval := ast.NewIdent("i")
   315  	setter.Type.Params.List = append(setter.Type.Params.List, &ast.Field{Names: []*ast.Ident{sval}, Type: typ})
   316  	setter.Body.List = append(setter.Body.List, newAssign(newSel(val, el.strucName), sval))
   317  
   318  	return setter
   319  }
   320  
   321  func (e *Builder) Process() {
   322  	e.cnt = 0
   323  
   324  	e.imports = &ast.GenDecl{
   325  		Tok: token.IMPORT,
   326  	}
   327  	e.imports.Specs = append(e.imports.Specs,
   328  		&ast.ImportSpec{
   329  			Path: &ast.BasicLit{
   330  				Kind:  token.STRING,
   331  				Value: "\"unsafe\"",
   332  			},
   333  		},
   334  		&ast.ImportSpec{
   335  			Path: &ast.BasicLit{
   336  				Kind:  token.STRING,
   337  				Value: "\"reflect\"",
   338  			},
   339  		},
   340  		&ast.ImportSpec{
   341  			Path: &ast.BasicLit{
   342  				Kind:  token.STRING,
   343  				Value: "\"github.com/xhebox/bstruct\"",
   344  			},
   345  		},
   346  	)
   347  
   348  	for name, el := range e.types {
   349  		el.typ = ast.GenDecl{
   350  			Tok: token.TYPE,
   351  			Doc: e.commentGroup(el.field.comment),
   352  			Specs: []ast.Spec{
   353  				&ast.TypeSpec{
   354  					Name: ast.NewIdent(name),
   355  					Type: e.typPrim(el.field),
   356  				},
   357  			},
   358  		}
   359  
   360  		writer := ast.NewIdent("wt")
   361  		enc, val := e.getFunc(el.field, "Encode")
   362  		enc.Type.Params.List = append(enc.Type.Params.List, &ast.Field{Names: []*ast.Ident{writer}, Type: writerType})
   363  		enc.Body.List = append(enc.Body.List, e.encPrim(writer, val, el.field)...)
   364  		el.enc = *enc
   365  
   366  		reader := ast.NewIdent("rd")
   367  		dec, val := e.getFunc(el.field, "Decode")
   368  		dec.Type.Params.List = append(dec.Type.Params.List, &ast.Field{Names: []*ast.Ident{reader}, Type: readerType})
   369  		dec.Body.List = append(dec.Body.List, e.decPrim(reader, val, el.field)...)
   370  		el.dec = *dec
   371  
   372  		el.extra = e.extraDecl(el.field)
   373  
   374  		e.types[name] = el
   375  	}
   376  }
   377  
   378  func (e *Builder) extraDecl(el *Field) (decls []ast.Decl) {
   379  	if el.typ.IsType(FieldStruct) && e.getter {
   380  		for _, field := range el.strucFields {
   381  			decls = append(decls, e.getFieldGetter(el, field))
   382  		}
   383  	}
   384  
   385  	if el.typ.IsType(FieldStruct) && e.setter {
   386  		for _, field := range el.strucFields {
   387  			decls = append(decls, e.getFieldSetter(el, field))
   388  		}
   389  	}
   390  
   391  	return
   392  }
   393  
   394  func (e *Builder) Print(buf io.Writer, pak string) error {
   395  	ts := token.NewFileSet()
   396  	cfg := printer.Config{
   397  		Mode:     printer.TabIndent,
   398  		Tabwidth: 2,
   399  	}
   400  	file := &ast.File{
   401  		Name:  ast.NewIdent(pak),
   402  		Decls: []ast.Decl{e.imports},
   403  	}
   404  	for _, e := range e.types {
   405  		el := e
   406  		file.Decls = append(file.Decls, &el.typ, &el.enc, &el.dec)
   407  		file.Decls = append(file.Decls, el.extra...)
   408  	}
   409  	if err := cfg.Fprint(buf, ts, file); err != nil {
   410  		return err
   411  	}
   412  	if _, err := fmt.Fprintf(buf, "\n"); err != nil {
   413  		return err
   414  	}
   415  	return nil
   416  }
   417  
   418  func (e *Builder) commentGroup(comment string) *ast.CommentGroup {
   419  	if len(comment) == 0 {
   420  		return nil
   421  	}
   422  
   423  	strLength := len(comment)
   424  	splitedLength := int(math.Ceil(float64(strLength) / float64(e.lineWrap)))
   425  	splited := make([]*ast.Comment, splitedLength)
   426  	var start, stop int
   427  	for i := 0; i < splitedLength; i += 1 {
   428  		start = i * e.lineWrap
   429  		stop = start + e.lineWrap
   430  		if stop > strLength {
   431  			stop = strLength
   432  		}
   433  		splited[i] = &ast.Comment{Text: fmt.Sprintf("// %s", comment[start:stop])}
   434  	}
   435  	return &ast.CommentGroup{List: splited}
   436  }