github.com/go-kivik/kivik/v4@v4.3.2/mockdb/gen/render.go (about) 1 // Licensed under the Apache License, Version 2.0 (the "License"); you may not 2 // use this file except in compliance with the License. You may obtain a copy of 3 // the License at 4 // 5 // http://www.apache.org/licenses/LICENSE-2.0 6 // 7 // Unless required by applicable law or agreed to in writing, software 8 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 9 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 10 // License for the specific language governing permissions and limitations under 11 // the License. 12 13 package main 14 15 import ( 16 "bytes" 17 "fmt" 18 "os" 19 "reflect" 20 "strings" 21 "text/template" 22 ) 23 24 var tmpl *template.Template 25 26 func initTemplates(root string) { 27 var err error 28 tmpl, err = template.ParseGlob(root + "/*") 29 if err != nil { 30 panic(err) 31 } 32 } 33 34 func renderExpectationsGo(filename string, methods []*method) error { 35 file, err := os.Create(filename) 36 if err != nil { 37 return err 38 } 39 return tmpl.ExecuteTemplate(file, "expectations.go.tmpl", methods) 40 } 41 42 func renderClientGo(filename string, methods []*method) error { 43 file, err := os.Create(filename) 44 if err != nil { 45 return err 46 } 47 return tmpl.ExecuteTemplate(file, "client.go.tmpl", methods) 48 } 49 50 func renderMockGo(filename string, methods []*method) error { 51 file, err := os.Create(filename) 52 if err != nil { 53 return err 54 } 55 return tmpl.ExecuteTemplate(file, "mock.go.tmpl", methods) 56 } 57 58 func renderDriverMethod(m *method) (string, error) { 59 buf := &bytes.Buffer{} 60 err := tmpl.ExecuteTemplate(buf, "drivermethod.tmpl", m) 61 return buf.String(), err 62 } 63 64 func renderExpectedType(m *method) (string, error) { 65 buf := &bytes.Buffer{} 66 err := tmpl.ExecuteTemplate(buf, "expectedtype.tmpl", m) 67 return buf.String(), err 68 } 69 70 func (m *method) DriverArgs() string { 71 const extraCount = 2 72 args := make([]string, 0, len(m.Accepts)+extraCount) 73 if m.AcceptsContext { 74 args = append(args, "ctx context.Context") 75 } 76 for i, arg := range m.Accepts { 77 args = append(args, fmt.Sprintf("arg%d %s", i, typeName(arg))) 78 } 79 if m.AcceptsOptions { 80 args = append(args, "options driver.Options") 81 } 82 return strings.Join(args, ", ") 83 } 84 85 func (m *method) ReturnArgs() string { 86 args := make([]string, 0, len(m.Returns)+1) 87 for _, arg := range m.Returns { 88 args = append(args, arg.String()) 89 } 90 if m.ReturnsError { 91 args = append(args, "error") 92 } 93 if len(args) > 1 { 94 return `(` + strings.Join(args, ", ") + `)` 95 } 96 return args[0] 97 } 98 99 func (m *method) VariableDefinitions() string { 100 result := make([]string, 0, len(m.Accepts)+len(m.Returns)) 101 for i, arg := range m.Accepts { 102 result = append(result, fmt.Sprintf("\targ%d %s\n", i, typeName(arg))) 103 } 104 for i, ret := range m.Returns { 105 name := typeName(ret) 106 switch name { 107 case "driver.DB": // nolint: goconst 108 name = "*DB" 109 case "driver.Replication": // nolint: goconst 110 name = "*Replication" 111 case "[]driver.Replication": // nolint: goconst 112 name = "[]*Replication" 113 } 114 result = append(result, fmt.Sprintf("\tret%d %s\n", i, name)) 115 } 116 return strings.Join(result, "") 117 } 118 119 func (m *method) inputVars() []string { 120 args := make([]string, 0, len(m.Accepts)+1) 121 for i := range m.Accepts { 122 args = append(args, fmt.Sprintf("arg%d", i)) 123 } 124 if m.AcceptsOptions { 125 args = append(args, "options") 126 } 127 return args 128 } 129 130 func (m *method) ExpectedVariables() string { 131 args := []string{} 132 if m.DBMethod { 133 args = append(args, "db") 134 } 135 args = append(args, m.inputVars()...) 136 return alignVars(0, args) 137 } 138 139 func (m *method) InputVariables() string { 140 result := make([]string, len(m.Accepts)+1) 141 var common []string 142 if m.DBMethod { 143 common = append(common, "\t\t\tdb: db.DB,\n") 144 } 145 for i := range m.Accepts { 146 result = append(result, fmt.Sprintf("\t\targ%d: arg%d,\n", i, i)) 147 } 148 if m.AcceptsOptions { 149 common = append(common, "\t\t\toptions: options,\n") 150 } 151 if len(common) > 0 { 152 result = append(result, fmt.Sprintf("\t\tcommonExpectation: commonExpectation{\n%s\t\t},\n", 153 strings.Join(common, ""))) 154 } 155 return strings.Join(result, "") 156 } 157 158 func (m *method) Variables(indent int) string { 159 args := m.inputVars() 160 for i := range m.Returns { 161 args = append(args, fmt.Sprintf("ret%d", i)) 162 } 163 return alignVars(indent, args) 164 } 165 166 func alignVars(indent int, args []string) string { 167 var maxLen int 168 for _, arg := range args { 169 if l := len(arg); l > maxLen { 170 maxLen = l 171 } 172 } 173 final := make([]string, len(args)) 174 for i, arg := range args { 175 final[i] = fmt.Sprintf("%s%*s %s,", strings.Repeat("\t", indent), -(maxLen + 1), arg+":", arg) 176 } 177 return strings.Join(final, "\n") 178 } 179 180 func (m *method) ZeroReturns() string { 181 args := make([]string, 0, len(m.Returns)) 182 for _, arg := range m.Returns { 183 args = append(args, zeroValue(arg)) 184 } 185 args = append(args, "err") 186 return strings.Join(args, ", ") 187 } 188 189 func zeroValue(t reflect.Type) string { 190 z := fmt.Sprintf("%#v", reflect.Zero(t).Interface()) 191 if strings.HasSuffix(z, "(nil)") { 192 return "nil" 193 } 194 if z == "<nil>" { 195 return "nil" 196 } 197 return z 198 } 199 200 func (m *method) ExpectedReturns() string { 201 args := make([]string, 0, len(m.Returns)) 202 for i, arg := range m.Returns { 203 switch arg.String() { 204 case "driver.Rows": 205 args = append(args, fmt.Sprintf("&driverRows{Context: ctx, Rows: coalesceRows(expected.ret%d)}", i)) 206 case "driver.Changes": 207 args = append(args, fmt.Sprintf("&driverChanges{Context: ctx, Changes: coalesceChanges(expected.ret%d)}", i)) 208 case "driver.DB": 209 args = append(args, fmt.Sprintf("&driverDB{DB: expected.ret%d}", i)) 210 case "driver.DBUpdates": 211 args = append(args, fmt.Sprintf("&driverDBUpdates{Context:ctx, Updates: coalesceDBUpdates(expected.ret%d)}", i)) 212 case "driver.Replication": 213 args = append(args, fmt.Sprintf("&driverReplication{Replication: expected.ret%d}", i)) 214 case "[]driver.Replication": 215 args = append(args, fmt.Sprintf("driverReplications(expected.ret%d)", i)) 216 default: 217 args = append(args, fmt.Sprintf("expected.ret%d", i)) 218 } 219 } 220 if m.AcceptsContext { 221 args = append(args, "expected.wait(ctx)") 222 } else { 223 args = append(args, "expected.err") 224 } 225 return strings.Join(args, ", ") 226 } 227 228 func (m *method) ReturnTypes() string { 229 args := make([]string, len(m.Returns)) 230 for i, ret := range m.Returns { 231 name := typeName(ret) 232 switch name { 233 case "driver.DB": 234 name = "*DB" 235 case "driver.Replication": 236 name = "*Replication" 237 case "[]driver.Replication": 238 name = "[]*Replication" 239 } 240 args[i] = fmt.Sprintf("ret%d %s", i, name) 241 } 242 return strings.Join(args, ", ") 243 } 244 245 func typeName(t reflect.Type) string { 246 name := t.String() 247 switch name { 248 case "interface {}": 249 return "interface{}" 250 case "driver.Rows": 251 return "*Rows" 252 case "driver.Changes": 253 return "*Changes" 254 case "driver.DBUpdates": 255 return "*Updates" 256 } 257 return name 258 } 259 260 func (m *method) SetExpectations() string { 261 var args []string 262 if m.DBMethod { 263 args = append(args, "commonExpectation: commonExpectation{db: db},\n") 264 } 265 if m.Name == "DB" { 266 args = append(args, "ret0: &DB{},\n") 267 } 268 for i, ret := range m.Returns { 269 var zero string 270 switch ret.String() { 271 case "*kivik.Rows": 272 zero = "&Rows{}" 273 case "*kivik.QueryPlan": 274 zero = "&driver.QueryPlan{}" 275 case "*kivik.PurgeResult": 276 zero = "&driver.PurgeResult{}" 277 case "*kivik.DBUpdates": 278 zero = "&Updates{}" 279 } 280 if zero != "" { 281 args = append(args, fmt.Sprintf("ret%d: %s,\n", i, zero)) 282 } 283 } 284 return strings.Join(args, "") 285 } 286 287 func (m *method) MetExpectations() string { 288 if len(m.Accepts) == 0 { 289 return "" 290 } 291 args := make([]string, 0, len(m.Accepts)+1) 292 args = append(args, fmt.Sprintf("\texp := ex.(*Expected%s)", m.Name)) 293 var check string 294 for i, arg := range m.Accepts { 295 switch arg.String() { 296 case "string": 297 check = `exp.arg%[1]d != "" && exp.arg%[1]d != e.arg%[1]d` 298 case "int": 299 check = "exp.arg%[1]d != 0 && exp.arg%[1]d != e.arg%[1]d" 300 case "interface {}": 301 check = "exp.arg%[1]d != nil && !jsonMeets(exp.arg%[1]d, e.arg%[1]d)" 302 default: 303 check = "exp.arg%[1]d != nil && !reflect.DeepEqual(exp.arg%[1]d, e.arg%[1]d)" 304 } 305 args = append(args, fmt.Sprintf("if "+check+" {\n\t\treturn false\n\t}", i)) 306 } 307 return strings.Join(args, "\n") 308 } 309 310 func (m *method) MethodArgs() string { 311 str := make([]string, 0, len(m.Accepts)+1) 312 def := make([]string, 0, len(m.Accepts)+1) 313 const maxVarLen = 3 314 vars := make([]string, 0, maxVarLen) 315 var args, mid []string 316 prefix := "" 317 if m.DBMethod { 318 prefix = "DB(%s)." 319 args = append(args, "e.dbo().name") 320 } 321 if m.AcceptsContext { 322 vars = append(vars, "ctx") 323 } 324 var lines []string 325 for i, acc := range m.Accepts { 326 str = append(str, fmt.Sprintf("arg%d", i)) 327 def = append(def, `"?"`) 328 vars = append(vars, "%s") 329 switch acc.String() { 330 case "string": 331 mid = append(mid, fmt.Sprintf(` if e.arg%[1]d != "" { arg%[1]d = fmt.Sprintf("%%q", e.arg%[1]d)}`, i)) 332 case "int": 333 mid = append(mid, fmt.Sprintf(` if e.arg%[1]d != 0 { arg%[1]d = fmt.Sprintf("%%q", e.arg%[1]d)}`, i)) 334 default: 335 mid = append(mid, fmt.Sprintf(` if e.arg%[1]d != nil { arg%[1]d = fmt.Sprintf("%%v", e.arg%[1]d) }`, i)) 336 } 337 } 338 if m.AcceptsOptions { 339 str = append(str, "options") 340 def = append(def, `formatOptions(e.options)`) 341 vars = append(vars, "%s") 342 } 343 if len(str) > 0 { 344 lines = append(lines, fmt.Sprintf("\t%s := %s", strings.Join(str, ", "), strings.Join(def, ", "))) 345 } 346 lines = append(lines, mid...) 347 lines = append(lines, fmt.Sprintf("\treturn fmt.Sprintf(\"%s%s(%s)\", %s)", prefix, m.Name, strings.Join(vars, ", "), strings.Join(append(args, str...), ", "))) 348 return strings.Join(lines, "\n") 349 } 350 351 // CallbackType returns the type definition for a callback for this method. 352 func (m *method) CallbackTypes() string { 353 const extraCount = 2 354 inputs := make([]string, 0, len(m.Accepts)+extraCount) 355 if m.AcceptsContext { 356 inputs = append(inputs, "context.Context") 357 } 358 for _, arg := range m.Accepts { 359 inputs = append(inputs, typeName(arg)) 360 } 361 if m.AcceptsOptions { 362 inputs = append(inputs, "driver.Options") 363 } 364 return strings.Join(inputs, ", ") 365 } 366 367 // CallbackArgs returns the list of arguments to be passed to the callback 368 func (m *method) CallbackArgs() string { 369 const extraCount = 2 370 args := make([]string, 0, len(m.Accepts)+extraCount) 371 if m.AcceptsContext { 372 args = append(args, "ctx") 373 } 374 for i := range m.Accepts { 375 args = append(args, fmt.Sprintf("arg%d", i)) 376 } 377 if m.AcceptsOptions { 378 args = append(args, "options") 379 } 380 return strings.Join(args, ", ") 381 } 382 383 func (m *method) CallbackReturns() string { 384 args := make([]string, 0, len(m.Returns)+1) 385 for _, ret := range m.Returns { 386 args = append(args, ret.String()) 387 } 388 if m.ReturnsError { 389 args = append(args, "error") 390 } 391 if len(args) > 1 { 392 return "(" + strings.Join(args, ", ") + ")" 393 } 394 return strings.Join(args, ", ") 395 }