github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/importers/imports.go (about) 1 // Package importers helps with dynamic imports for templating 2 package importers 3 4 import ( 5 "bytes" 6 "fmt" 7 "sort" 8 "strings" 9 10 "github.com/spf13/cast" 11 12 "github.com/friendsofgo/errors" 13 "github.com/volatiletech/strmangle" 14 ) 15 16 // Collection of imports for various templating purposes 17 // Drivers add to any and all of these, and is completely responsible 18 // for populating BasedOnType. 19 type Collection struct { 20 All Set `toml:"all" json:"all,omitempty"` 21 Test Set `toml:"test" json:"test,omitempty"` 22 23 Singleton Map `toml:"singleton" json:"singleton,omitempty"` 24 TestSingleton Map `toml:"test_singleton" json:"test_singleton,omitempty"` 25 26 BasedOnType Map `toml:"based_on_type" json:"based_on_type,omitempty"` 27 } 28 29 // Set defines the optional standard imports and 30 // thirdParty imports (from github for example) 31 type Set struct { 32 Standard List `toml:"standard"` 33 ThirdParty List `toml:"third_party"` 34 } 35 36 // Format the set into Go syntax (compatible with go imports) 37 func (s Set) Format() []byte { 38 stdlen, thirdlen := len(s.Standard), len(s.ThirdParty) 39 if stdlen+thirdlen < 1 { 40 return []byte{} 41 } 42 43 if stdlen+thirdlen == 1 { 44 var imp string 45 if stdlen == 1 { 46 imp = s.Standard[0] 47 } else { 48 imp = s.ThirdParty[0] 49 } 50 return []byte(fmt.Sprintf("import %s", imp)) 51 } 52 53 buf := &bytes.Buffer{} 54 buf.WriteString("import (") 55 for _, std := range s.Standard { 56 fmt.Fprintf(buf, "\n\t%s", std) 57 } 58 if stdlen != 0 && thirdlen != 0 { 59 buf.WriteString("\n") 60 } 61 for _, third := range s.ThirdParty { 62 fmt.Fprintf(buf, "\n\t%s", third) 63 } 64 buf.WriteString("\n)\n") 65 66 return buf.Bytes() 67 } 68 69 // SetFromInterface creates a set from a theoretical map[string]interface{}. 70 // This is to load from a loosely defined configuration file. 71 func SetFromInterface(intf interface{}) (Set, error) { 72 s := Set{} 73 74 setIntf, ok := intf.(map[string]interface{}) 75 if !ok { 76 return s, errors.New("import set should be map[string]interface{}") 77 } 78 79 standardIntf, ok := setIntf["standard"] 80 if ok { 81 standardsIntf, ok := standardIntf.([]interface{}) 82 if !ok { 83 return s, errors.New("import set standards must be an slice") 84 } 85 86 s.Standard = List{} 87 for i, intf := range standardsIntf { 88 str, ok := intf.(string) 89 if !ok { 90 return s, errors.Errorf("import set standard slice element %d (%+v) must be string", i, s) 91 } 92 s.Standard = append(s.Standard, str) 93 } 94 } 95 96 thirdPartyIntf, ok := setIntf["third_party"] 97 if ok { 98 thirdPartysIntf, ok := thirdPartyIntf.([]interface{}) 99 if !ok { 100 return s, errors.New("import set third_party must be an slice") 101 } 102 103 s.ThirdParty = List{} 104 for i, intf := range thirdPartysIntf { 105 str, ok := intf.(string) 106 if !ok { 107 return s, errors.Errorf("import set third party slice element %d (%+v) must be string", i, intf) 108 } 109 s.ThirdParty = append(s.ThirdParty, str) 110 } 111 } 112 113 return s, nil 114 } 115 116 // Map of file/type -> imports 117 // Map's consumers do not understand windows paths. Always specify paths 118 // using forward slash (/). 119 type Map map[string]Set 120 121 // MapFromInterface creates a Map from a theoretical map[string]interface{} 122 // or []map[string]interface{} 123 // This is to load from a loosely defined configuration file. 124 func MapFromInterface(intf interface{}) (Map, error) { 125 m := Map{} 126 127 iter := func(i interface{}, fn func(string, interface{}) error) error { 128 switch toIter := intf.(type) { 129 case []interface{}: 130 for _, intf := range toIter { 131 obj := cast.ToStringMap(intf) 132 name := obj["name"].(string) 133 if err := fn(name, intf); err != nil { 134 return err 135 } 136 } 137 case map[string]interface{}: 138 for k, v := range toIter { 139 if err := fn(k, v); err != nil { 140 return err 141 } 142 } 143 default: 144 panic("import map should be map[string]interface or []map[string]interface{}") 145 } 146 147 return nil 148 } 149 150 err := iter(intf, func(name string, value interface{}) error { 151 s, err := SetFromInterface(value) 152 if err != nil { 153 return err 154 } 155 156 m[name] = s 157 return nil 158 }) 159 160 if err != nil { 161 return nil, err 162 } 163 164 return m, nil 165 } 166 167 // List of imports 168 type List []string 169 170 // Len implements sort.Interface.Len 171 func (l List) Len() int { 172 return len(l) 173 } 174 175 // Swap implements sort.Interface.Swap 176 func (l List) Swap(i, j int) { 177 l[i], l[j] = l[j], l[i] 178 } 179 180 // Less implements sort.Interface.Less 181 func (l List) Less(i, j int) bool { 182 res := strings.Compare(strings.TrimLeft(l[i], "_ "), strings.TrimLeft(l[j], "_ ")) 183 if res <= 0 { 184 return true 185 } 186 187 return false 188 } 189 190 // NewDefaultImports returns a default Imports struct. 191 func NewDefaultImports() Collection { 192 var col Collection 193 194 col.All = Set{ 195 Standard: List{ 196 `"database/sql"`, 197 `"fmt"`, 198 `"reflect"`, 199 `"strings"`, 200 `"sync"`, 201 `"time"`, 202 }, 203 ThirdParty: List{ 204 `"github.com/friendsofgo/errors"`, 205 `"github.com/volatiletech/sqlboiler/v4/boil"`, 206 `"github.com/volatiletech/sqlboiler/v4/queries"`, 207 `"github.com/volatiletech/sqlboiler/v4/queries/qm"`, 208 `"github.com/volatiletech/sqlboiler/v4/queries/qmhelper"`, 209 `"github.com/volatiletech/strmangle"`, 210 }, 211 } 212 213 col.Singleton = Map{ 214 "boil_queries": { 215 ThirdParty: List{ 216 `"github.com/volatiletech/sqlboiler/v4/drivers"`, 217 `"github.com/volatiletech/sqlboiler/v4/queries"`, 218 `"github.com/volatiletech/sqlboiler/v4/queries/qm"`, 219 }, 220 }, 221 "boil_types": { 222 Standard: List{ 223 `"strconv"`, 224 }, 225 ThirdParty: List{ 226 `"github.com/friendsofgo/errors"`, 227 `"github.com/volatiletech/sqlboiler/v4/boil"`, 228 `"github.com/volatiletech/strmangle"`, 229 }, 230 }, 231 } 232 233 col.Test = Set{ 234 Standard: List{ 235 `"bytes"`, 236 `"reflect"`, 237 `"testing"`, 238 }, 239 ThirdParty: List{ 240 `"github.com/volatiletech/sqlboiler/v4/boil"`, 241 `"github.com/volatiletech/sqlboiler/v4/queries"`, 242 `"github.com/volatiletech/randomize"`, 243 `"github.com/volatiletech/strmangle"`, 244 }, 245 } 246 247 col.TestSingleton = Map{ 248 "boil_main_test": { 249 Standard: List{ 250 `"database/sql"`, 251 `"flag"`, 252 `"fmt"`, 253 `"math/rand"`, 254 `"os"`, 255 `"path/filepath"`, 256 `"strings"`, 257 `"testing"`, 258 `"time"`, 259 }, 260 ThirdParty: List{ 261 `"github.com/spf13/viper"`, 262 `"github.com/volatiletech/sqlboiler/v4/boil"`, 263 }, 264 }, 265 "boil_queries_test": { 266 Standard: List{ 267 `"bytes"`, 268 `"fmt"`, 269 `"io"`, 270 `"math/rand"`, 271 `"regexp"`, 272 }, 273 ThirdParty: List{ 274 `"github.com/volatiletech/sqlboiler/v4/boil"`, 275 }, 276 }, 277 "boil_suites_test": { 278 Standard: List{ 279 `"testing"`, 280 }, 281 }, 282 } 283 284 return col 285 } 286 287 // NullableEnumImports returns imports collection for nullable enum types. 288 func NullableEnumImports() Collection { 289 var col Collection 290 291 col.Singleton = Map{ 292 "boil_types": { 293 Standard: List{ 294 `"bytes"`, 295 `"database/sql/driver"`, 296 `"encoding/json"`, 297 }, 298 ThirdParty: List{ 299 `"github.com/volatiletech/null/v8"`, 300 `"github.com/volatiletech/null/v8/convert"`, 301 }, 302 }, 303 } 304 305 return col 306 } 307 308 // AddTypeImports takes a set of imports 'a', a type -> import mapping 'typeMap' 309 // and a set of column types that are currently in use and produces a new set 310 // including both the old standard/third party, as well as the imports required 311 // for the types in use. 312 func AddTypeImports(a Set, typeMap map[string]Set, columnTypes []string) Set { 313 tmpImp := Set{ 314 Standard: make(List, len(a.Standard)), 315 ThirdParty: make(List, len(a.ThirdParty)), 316 } 317 318 copy(tmpImp.Standard, a.Standard) 319 copy(tmpImp.ThirdParty, a.ThirdParty) 320 321 for _, typ := range columnTypes { 322 for key, imp := range typeMap { 323 if typ == key { 324 tmpImp.Standard = append(tmpImp.Standard, imp.Standard...) 325 tmpImp.ThirdParty = append(tmpImp.ThirdParty, imp.ThirdParty...) 326 } 327 } 328 } 329 330 tmpImp.Standard = strmangle.RemoveDuplicates(tmpImp.Standard) 331 tmpImp.ThirdParty = strmangle.RemoveDuplicates(tmpImp.ThirdParty) 332 333 sort.Sort(tmpImp.Standard) 334 sort.Sort(tmpImp.ThirdParty) 335 336 return tmpImp 337 } 338 339 // Merge takes two collections and creates a new one 340 // with the de-duplication contents of both. 341 func Merge(a, b Collection) Collection { 342 var c Collection 343 344 c.All = mergeSet(a.All, b.All) 345 c.Test = mergeSet(a.Test, b.Test) 346 347 c.Singleton = mergeMap(a.Singleton, b.Singleton) 348 c.TestSingleton = mergeMap(a.TestSingleton, b.TestSingleton) 349 350 c.BasedOnType = mergeMap(a.BasedOnType, b.BasedOnType) 351 352 return c 353 } 354 355 func mergeSet(a, b Set) Set { 356 var c Set 357 358 c.Standard = strmangle.RemoveDuplicates(combineStringSlices(a.Standard, b.Standard)) 359 c.ThirdParty = strmangle.RemoveDuplicates(combineStringSlices(a.ThirdParty, b.ThirdParty)) 360 361 sort.Sort(c.Standard) 362 sort.Sort(c.ThirdParty) 363 364 return c 365 } 366 367 func mergeMap(a, b Map) Map { 368 m := make(Map) 369 370 for k, v := range a { 371 m[k] = v 372 } 373 374 for k, toMerge := range b { 375 exist, ok := m[k] 376 if !ok { 377 m[k] = toMerge 378 } 379 380 m[k] = mergeSet(exist, toMerge) 381 } 382 383 return m 384 } 385 386 func combineStringSlices(a, b []string) []string { 387 c := make([]string, len(a)+len(b)) 388 if len(a) > 0 { 389 copy(c, a) 390 } 391 if len(b) > 0 { 392 copy(c[len(a):], b) 393 } 394 395 return c 396 }