github.com/aacfactory/fns-contrib/databases/sql@v1.2.84/dac/specifications/specification.go (about)

     1  package specifications
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"github.com/aacfactory/errors"
     7  	"github.com/valyala/bytebufferpool"
     8  	"golang.org/x/sync/singleflight"
     9  	"reflect"
    10  	"sync"
    11  )
    12  
    13  type Specification struct {
    14  	Key       string
    15  	Schema    string
    16  	Name      string
    17  	View      bool
    18  	ViewBase  *Specification
    19  	Type      reflect.Type
    20  	Columns   []*Column
    21  	Conflicts []string
    22  }
    23  
    24  func (spec *Specification) Instance() (v any) {
    25  	return reflect.Zero(spec.Type).Interface()
    26  }
    27  
    28  func (spec *Specification) ConflictColumns() (columns []*Column, err error) {
    29  	for _, conflict := range spec.Conflicts {
    30  		column, has := spec.ColumnByField(conflict)
    31  		if !has {
    32  			err = errors.Warning(fmt.Sprintf("sql: %s field was not found", conflict))
    33  			return
    34  		}
    35  		columns = append(columns, column)
    36  	}
    37  	return
    38  }
    39  
    40  func (spec *Specification) ColumnByField(fieldName string) (column *Column, has bool) {
    41  	for _, c := range spec.Columns {
    42  		if c.Field == fieldName {
    43  			column = c
    44  			has = true
    45  			break
    46  		}
    47  	}
    48  	return
    49  }
    50  
    51  func (spec *Specification) Pk() (v *Column, has bool) {
    52  	for _, column := range spec.Columns {
    53  		if column.Kind == Pk {
    54  			v = column
    55  			break
    56  		}
    57  	}
    58  	has = v != nil
    59  	return
    60  }
    61  
    62  func (spec *Specification) AuditCreation() (by *Column, at *Column, has bool) {
    63  	n := 0
    64  	for _, column := range spec.Columns {
    65  		if column.Kind == Acb {
    66  			by = column
    67  			n++
    68  			continue
    69  		}
    70  		if column.Kind == Act {
    71  			at = column
    72  			n++
    73  			continue
    74  		}
    75  		if n == 2 {
    76  			break
    77  		}
    78  	}
    79  	has = n > 0
    80  	return
    81  }
    82  
    83  func (spec *Specification) AuditModification() (by *Column, at *Column, has bool) {
    84  	n := 0
    85  	for _, column := range spec.Columns {
    86  		if column.Kind == Amb {
    87  			by = column
    88  			n++
    89  			continue
    90  		}
    91  		if column.Kind == Amt {
    92  			at = column
    93  			n++
    94  			continue
    95  		}
    96  		if n == 2 {
    97  			break
    98  		}
    99  	}
   100  	has = n > 0
   101  	return
   102  }
   103  
   104  func (spec *Specification) AuditDeletion() (by *Column, at *Column, has bool) {
   105  	n := 0
   106  	for _, column := range spec.Columns {
   107  		if column.Kind == Adb {
   108  			by = column
   109  			n++
   110  			continue
   111  		}
   112  		if column.Kind == Adt {
   113  			at = column
   114  			n++
   115  			continue
   116  		}
   117  		if n == 2 {
   118  			break
   119  		}
   120  	}
   121  	has = n > 0
   122  	return
   123  }
   124  
   125  func (spec *Specification) AuditVersion() (v *Column, has bool) {
   126  	for _, column := range spec.Columns {
   127  		if column.Kind == Aol {
   128  			v = column
   129  			break
   130  		}
   131  	}
   132  	has = v != nil
   133  	return
   134  }
   135  
   136  func (spec *Specification) String() (s string) {
   137  	buf := bytebufferpool.Get()
   138  	defer bytebufferpool.Put(buf)
   139  	_, _ = buf.WriteString(fmt.Sprintf("Specification: %s\n", spec.Key))
   140  	_, _ = buf.WriteString(fmt.Sprintf("  schema: %s\n", spec.Schema))
   141  	_, _ = buf.WriteString(fmt.Sprintf("  name: %s\n", spec.Name))
   142  	_, _ = buf.WriteString(fmt.Sprintf("  view: %v", spec.View))
   143  	if spec.ViewBase == nil {
   144  		_, _ = buf.WriteString(" pure")
   145  	} else {
   146  		_, _ = buf.WriteString(fmt.Sprintf(" base(%s)", spec.ViewBase.Key))
   147  	}
   148  	_, _ = buf.WriteString("\n")
   149  	_, _ = buf.WriteString(fmt.Sprintf("  columns: %v\n", len(spec.Columns)))
   150  	for _, column := range spec.Columns {
   151  		_, _ = buf.WriteString(fmt.Sprintf("    %s\n", column.String()))
   152  	}
   153  	_, _ = buf.WriteString(fmt.Sprintf("  conflicts: %+v\n", spec.Conflicts))
   154  	s = buf.String()
   155  	return
   156  }
   157  
   158  var (
   159  	tables = sync.Map{}
   160  	dict   = NewDict()
   161  	group  = singleflight.Group{}
   162  )
   163  
   164  func GetSpecification(ctx context.Context, e any) (spec *Specification, err error) {
   165  	rt := reflect.TypeOf(e)
   166  	key := fmt.Sprintf("%s.%s", rt.PkgPath(), rt.Name())
   167  
   168  	scanned, has := tables.Load(key)
   169  	if has {
   170  		spec, has = scanned.(*Specification)
   171  		if !has {
   172  			err = errors.Warning("sql: get specification failed").WithCause(fmt.Errorf("stored specification is invalid type"))
   173  			return
   174  		}
   175  		return
   176  	}
   177  
   178  	ctxKey := fmt.Sprintf("@fns:sql:dac:scan:%s", key)
   179  
   180  	processing := ctx.Value(ctxKey)
   181  	if processing != nil {
   182  		spec, has = processing.(*Specification)
   183  		if !has {
   184  			err = errors.Warning("sql: get specification failed").WithCause(fmt.Errorf("processing specification is invalid type"))
   185  			return
   186  		}
   187  		return
   188  	}
   189  
   190  	scanned, err, _ = group.Do(key, func() (v interface{}, err error) {
   191  		current := &Specification{}
   192  		ctx = context.WithValue(ctx, ctxKey, current)
   193  		var result *Specification
   194  		var scanErr error
   195  		if MaybeTable(e) {
   196  			result, scanErr = ScanTable(ctx, e)
   197  		} else if MaybeView(e) {
   198  			result, scanErr = ScanView(ctx, e)
   199  		} else {
   200  			err = errors.Warning("sql: get specification failed").WithCause(fmt.Errorf("invalid type"))
   201  			return
   202  		}
   203  		if scanErr != nil {
   204  			err = scanErr
   205  			return
   206  		}
   207  		reflect.ValueOf(current).Elem().Set(reflect.ValueOf(result).Elem())
   208  		v = current
   209  		tables.Store(key, v)
   210  		return
   211  	})
   212  	group.Forget(key)
   213  	if err != nil {
   214  		err = errors.Warning("sql: get table specification failed").WithCause(err)
   215  		return
   216  	}
   217  
   218  	spec = scanned.(*Specification)
   219  	return
   220  }
   221  
   222  func ScanTable(ctx context.Context, table any) (spec *Specification, err error) {
   223  	rv := reflect.Indirect(reflect.ValueOf(table))
   224  	rt := rv.Type()
   225  	key := fmt.Sprintf("%s.%s", rt.PkgPath(), rt.Name())
   226  	info, infoErr := GetTableInfo(table)
   227  	if infoErr != nil {
   228  		err = errors.Warning("sql: scan table failed").
   229  			WithCause(infoErr).
   230  			WithMeta("struct", key)
   231  		return
   232  	}
   233  	name := info.name
   234  	if name == "" {
   235  		err = errors.Warning("sql: scan table failed").
   236  			WithCause(fmt.Errorf("table name is required")).
   237  			WithMeta("struct", rt.String())
   238  		return
   239  	}
   240  	schema := info.schema
   241  	conflicts := info.conflicts
   242  
   243  	columns, columnsErr := scanTableFields(ctx, fmt.Sprintf("%s.%s", rt.PkgPath(), rt.Name()), rt)
   244  	if columnsErr != nil {
   245  		err = errors.Warning("sql: scan table failed").
   246  			WithCause(columnsErr).
   247  			WithMeta("struct", reflect.TypeOf(table).String())
   248  		return
   249  	}
   250  
   251  	spec = &Specification{
   252  		Key:       key,
   253  		Schema:    schema,
   254  		Name:      name,
   255  		View:      false,
   256  		Type:      rt,
   257  		Columns:   columns,
   258  		Conflicts: conflicts,
   259  	}
   260  
   261  	tableNames := make([]string, 0, 1)
   262  	if schema != "" {
   263  		tableNames = append(tableNames, schema)
   264  	}
   265  	tableNames = append(tableNames, name)
   266  	dict.Set(fmt.Sprintf("%s.%s", rt.PkgPath(), rt.Name()), tableNames...)
   267  
   268  	return
   269  }
   270  
   271  func scanTableFields(ctx context.Context, key string, rt reflect.Type) (columns []*Column, err error) {
   272  	fields := rt.NumField()
   273  	if fields == 0 {
   274  		err = errors.Warning("has not field")
   275  		return
   276  	}
   277  	for i := 0; i < fields; i++ {
   278  		field := rt.Field(i)
   279  		if !field.IsExported() {
   280  			continue
   281  		}
   282  
   283  		if field.Anonymous {
   284  			if field.Type.Kind() == reflect.Ptr {
   285  				err = errors.Warning("type of anonymous field can not be ptr").WithMeta("field", field.Name)
   286  				return
   287  			}
   288  			anonymous, anonymousErr := scanTableFields(ctx, key, field.Type)
   289  			if anonymousErr != nil {
   290  				if err != nil {
   291  					err = errors.Warning("sql: scan table field failed").
   292  						WithCause(anonymousErr).
   293  						WithMeta("field", field.Name)
   294  					return
   295  				}
   296  			}
   297  			for _, column := range anonymous {
   298  				column.FieldIdx = append(column.FieldIdx, i)
   299  				columns = append(columns, column)
   300  			}
   301  			continue
   302  		}
   303  		column, columnErr := newColumn(ctx, field, []int{i})
   304  		if columnErr != nil {
   305  			err = errors.Warning("sql: scan table field failed").
   306  				WithCause(columnErr).
   307  				WithMeta("field", field.Name)
   308  			return
   309  		}
   310  		if column != nil {
   311  			columns = append(columns, column)
   312  			dict.Set(fmt.Sprintf("%s:%s", key, column.Field), column.Name)
   313  		}
   314  	}
   315  
   316  	return
   317  }