github.com/ManabuSeki/goa-v1@v1.4.3/goagen/codegen/finalizer.go (about) 1 package codegen 2 3 import ( 4 "bytes" 5 "fmt" 6 "text/template" 7 8 "github.com/goadesign/goa/design" 9 ) 10 11 // Finalizer is the code generator for the 'Finalize' type methods. 12 type Finalizer struct { 13 assignmentT *template.Template 14 arrayAssignmentT *template.Template 15 seen map[*design.AttributeDefinition]map[*design.AttributeDefinition]*bytes.Buffer 16 } 17 18 // NewFinalizer instantiates a finalize code generator. 19 func NewFinalizer() *Finalizer { 20 var ( 21 f = &Finalizer{seen: make(map[*design.AttributeDefinition]map[*design.AttributeDefinition]*bytes.Buffer)} 22 err error 23 ) 24 fm := template.FuncMap{ 25 "tabs": Tabs, 26 "goify": Goify, 27 "gotyperef": GoTypeRef, 28 "add": Add, 29 "finalizeCode": f.Code, 30 } 31 f.assignmentT, err = template.New("assignment").Funcs(fm).Parse(assignmentTmpl) 32 if err != nil { 33 panic(err) 34 } 35 f.arrayAssignmentT, err = template.New("arrAssignment").Funcs(fm).Parse(arrayAssignmentTmpl) 36 if err != nil { 37 panic(err) 38 } 39 return f 40 } 41 42 // Code produces Go code that sets the default values for fields recursively for the given 43 // attribute. 44 func (f *Finalizer) Code(att *design.AttributeDefinition, target string, depth int) string { 45 buf := f.recurse(att, att, target, depth) 46 return buf.String() 47 } 48 49 func (f *Finalizer) recurse(root, att *design.AttributeDefinition, target string, depth int) *bytes.Buffer { 50 var ( 51 buf = new(bytes.Buffer) 52 first = true 53 ) 54 55 if s, ok := f.seen[root]; ok { 56 if buf, ok := s[att]; ok { 57 return buf 58 } 59 s[att] = buf 60 } else { 61 f.seen[root] = map[*design.AttributeDefinition]*bytes.Buffer{att: buf} 62 } 63 64 if o := att.Type.ToObject(); o != nil { 65 o.IterateAttributes(func(n string, catt *design.AttributeDefinition) error { 66 if att.HasDefaultValue(n) { 67 data := map[string]interface{}{ 68 "target": target, 69 "field": n, 70 "catt": catt, 71 "depth": depth, 72 "isDatetime": catt.Type == design.DateTime, 73 "defaultVal": PrintVal(catt.Type, catt.DefaultValue), 74 } 75 if !first { 76 buf.WriteByte('\n') 77 } else { 78 first = false 79 } 80 buf.WriteString(RunTemplate(f.assignmentT, data)) 81 } 82 a := f.recurse(root, catt, fmt.Sprintf("%s.%s", target, Goify(n, true)), depth+1).String() 83 if a != "" { 84 if catt.Type.IsObject() { 85 a = fmt.Sprintf("%sif %s.%s != nil {\n%s\n%s}", 86 Tabs(depth), target, Goify(n, true), a, Tabs(depth)) 87 } 88 if !first { 89 buf.WriteByte('\n') 90 } else { 91 first = false 92 } 93 buf.WriteString(a) 94 } 95 return nil 96 }) 97 } else if a := att.Type.ToArray(); a != nil { 98 data := map[string]interface{}{ 99 "elemType": a.ElemType, 100 "target": target, 101 "depth": 1, 102 } 103 if as := RunTemplate(f.arrayAssignmentT, data); as != "" { 104 buf.WriteString(as) 105 } 106 } 107 return buf 108 } 109 110 // PrintVal prints the given value corresponding to the given data type. 111 // The value is already checked for the compatibility with the data type. 112 func PrintVal(t design.DataType, val interface{}) string { 113 switch { 114 case t.IsPrimitive(): 115 // For primitive types, simply print the value 116 s := fmt.Sprintf("%#v", val) 117 switch t { 118 case design.Number: 119 v := val 120 if i, ok := val.(int); ok { 121 v = float64(i) 122 } 123 s = fmt.Sprintf("%f", v) 124 case design.DateTime: 125 s = fmt.Sprintf("time.Parse(time.RFC3339, %s)", s) 126 } 127 return s 128 case t.IsHash(): 129 // The input is a hash 130 h := t.ToHash() 131 hval := val.(map[interface{}]interface{}) 132 if len(hval) == 0 { 133 return fmt.Sprintf("%s{}", GoTypeName(t, nil, 0, false)) 134 } 135 var buffer bytes.Buffer 136 buffer.WriteString(fmt.Sprintf("%s{", GoTypeName(t, nil, 0, false))) 137 for k, v := range hval { 138 buffer.WriteString(fmt.Sprintf("%s: %s, ", PrintVal(h.KeyType.Type, k), PrintVal(h.ElemType.Type, v))) 139 } 140 buffer.Truncate(buffer.Len() - 2) // remove ", " 141 buffer.WriteString("}") 142 return buffer.String() 143 case t.IsArray(): 144 // Input is an array 145 a := t.ToArray() 146 aval := val.([]interface{}) 147 if len(aval) == 0 { 148 return fmt.Sprintf("%s{}", GoTypeName(t, nil, 0, false)) 149 } 150 var buffer bytes.Buffer 151 buffer.WriteString(fmt.Sprintf("%s{", GoTypeName(t, nil, 0, false))) 152 for _, e := range aval { 153 buffer.WriteString(fmt.Sprintf("%s, ", PrintVal(a.ElemType.Type, e))) 154 } 155 buffer.Truncate(buffer.Len() - 2) // remove ", " 156 buffer.WriteString("}") 157 return buffer.String() 158 default: 159 // shouldn't happen as the value's compatibility is already checked. 160 panic("unknown type") 161 } 162 } 163 164 const ( 165 assignmentTmpl = `{{ if .catt.Type.IsPrimitive }}{{ $defaultName := (print "default" (goify .field true)) }}{{/* 166 */}}{{ tabs .depth }}var {{ $defaultName }}{{if .isDatetime}}, _{{end}} = {{ .defaultVal }} 167 {{ tabs .depth }}if {{ .target }}.{{ goify .field true }} == nil { 168 {{ tabs .depth }} {{ .target }}.{{ goify .field true }} = &{{ $defaultName }} 169 }{{ else }}{{ tabs .depth }}if {{ .target }}.{{ goify .field true }} == nil { 170 {{ tabs .depth }} {{ .target }}.{{ goify .field true }} = {{ .defaultVal }} 171 }{{ end }}` 172 173 arrayAssignmentTmpl = `{{ $a := finalizeCode .elemType "e" (add .depth 1) }}{{/* 174 */}}{{ if $a }}{{ tabs .depth }}for _, e := range {{ .target }} { 175 {{ $a }} 176 {{ tabs .depth }}}{{ end }}` 177 )