bitbucket.org/Aishee/synsec@v0.0.0-20210414005726-236fc01a153d/pkg/database/decisions.go (about)

     1  package database
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  	"time"
     7  
     8  	"strconv"
     9  
    10  	"bitbucket.org/Aishee/synsec/pkg/database/ent"
    11  	"bitbucket.org/Aishee/synsec/pkg/database/ent/decision"
    12  	"bitbucket.org/Aishee/synsec/pkg/types"
    13  	"github.com/pkg/errors"
    14  )
    15  
    16  func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string][]string) (*ent.DecisionQuery, error) {
    17  
    18  	//func BuildDecisionRequestWithFilter(query *ent.Query, filter map[string][]string) (*ent.DecisionQuery, error) {
    19  	var err error
    20  	var start_ip, start_sfx, end_ip, end_sfx int64
    21  	var ip_sz int
    22  	var contains bool = true
    23  	/*if contains is true, return bans that *contains* the given value (value is the inner)
    24  	  else, return bans that are *contained* by the given value (value is the outer)*/
    25  
    26  	/*the simulated filter is a bit different : if it's not present *or* set to false, specifically exclude records with simulated to true */
    27  	if v, ok := filter["simulated"]; ok {
    28  		if v[0] == "false" {
    29  			query = query.Where(decision.SimulatedEQ(false))
    30  		}
    31  		delete(filter, "simulated")
    32  	} else {
    33  		query = query.Where(decision.SimulatedEQ(false))
    34  	}
    35  
    36  	for param, value := range filter {
    37  		switch param {
    38  		case "contains":
    39  			contains, err = strconv.ParseBool(value[0])
    40  			if err != nil {
    41  				return nil, errors.Wrapf(InvalidFilter, "invalid contains value : %s", err)
    42  			}
    43  		case "scope":
    44  			var scope string = value[0]
    45  			if strings.ToLower(scope) == "ip" {
    46  				scope = types.Ip
    47  			} else if strings.ToLower(scope) == "range" {
    48  				scope = types.Range
    49  			}
    50  			query = query.Where(decision.ScopeEQ(scope))
    51  		case "value":
    52  			query = query.Where(decision.ValueEQ(value[0]))
    53  		case "type":
    54  			query = query.Where(decision.TypeEQ(value[0]))
    55  		case "ip", "range":
    56  			ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(value[0])
    57  			if err != nil {
    58  				return nil, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", value[0], err)
    59  			}
    60  		default:
    61  			return query, errors.Wrapf(InvalidFilter, "'%s' doesn't exist", param)
    62  		}
    63  	}
    64  
    65  	if ip_sz == 4 {
    66  
    67  		if contains { /*decision contains {start_ip,end_ip}*/
    68  			query = query.Where(decision.And(
    69  				decision.StartIPLTE(start_ip),
    70  				decision.EndIPGTE(end_ip),
    71  				decision.IPSizeEQ(int64(ip_sz)),
    72  			))
    73  		} else { /*decision is contained within {start_ip,end_ip}*/
    74  			query = query.Where(decision.And(
    75  				decision.StartIPGTE(start_ip),
    76  				decision.EndIPLTE(end_ip),
    77  				decision.IPSizeEQ(int64(ip_sz)),
    78  			))
    79  		}
    80  	} else if ip_sz == 16 {
    81  
    82  		if contains { /*decision contains {start_ip,end_ip}*/
    83  			query = query.Where(decision.And(
    84  				//matching addr size
    85  				decision.IPSizeEQ(int64(ip_sz)),
    86  				decision.Or(
    87  					//decision.start_ip < query.start_ip
    88  					decision.StartIPLT(start_ip),
    89  					decision.And(
    90  						//decision.start_ip == query.start_ip
    91  						decision.StartIPEQ(start_ip),
    92  						//decision.start_suffix <= query.start_suffix
    93  						decision.StartSuffixLTE(start_sfx),
    94  					)),
    95  				decision.Or(
    96  					//decision.end_ip > query.end_ip
    97  					decision.EndIPGT(end_ip),
    98  					decision.And(
    99  						//decision.end_ip == query.end_ip
   100  						decision.EndIPEQ(end_ip),
   101  						//decision.end_suffix >= query.end_suffix
   102  						decision.EndSuffixGTE(end_sfx),
   103  					),
   104  				),
   105  			))
   106  		} else { /*decision is contained {start_ip,end_ip}*/
   107  			query = query.Where(decision.And(
   108  				//matching addr size
   109  				decision.IPSizeEQ(int64(ip_sz)),
   110  				decision.Or(
   111  					//decision.start_ip > query.start_ip
   112  					decision.StartIPGT(start_ip),
   113  					decision.And(
   114  						//decision.start_ip == query.start_ip
   115  						decision.StartIPEQ(start_ip),
   116  						//decision.start_suffix >= query.start_suffix
   117  						decision.StartSuffixGTE(start_sfx),
   118  					)),
   119  				decision.Or(
   120  					//decision.end_ip < query.end_ip
   121  					decision.EndIPLT(end_ip),
   122  					decision.And(
   123  						//decision.end_ip == query.end_ip
   124  						decision.EndIPEQ(end_ip),
   125  						//decision.end_suffix <= query.end_suffix
   126  						decision.EndSuffixLTE(end_sfx),
   127  					),
   128  				),
   129  			))
   130  		}
   131  	} else if ip_sz != 0 {
   132  		return nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz)
   133  	}
   134  	return query, nil
   135  }
   136  
   137  func (c *Client) QueryDecisionWithFilter(filter map[string][]string) ([]*ent.Decision, error) {
   138  	var data []*ent.Decision
   139  	var err error
   140  
   141  	decisions := c.Ent.Decision.Query().
   142  		Where(decision.UntilGTE(time.Now()))
   143  
   144  	decisions, err = BuildDecisionRequestWithFilter(decisions, filter)
   145  	if err != nil {
   146  		return []*ent.Decision{}, err
   147  	}
   148  
   149  	err = decisions.Select(
   150  		decision.FieldID,
   151  		decision.FieldUntil,
   152  		decision.FieldScenario,
   153  		decision.FieldType,
   154  		decision.FieldStartIP,
   155  		decision.FieldEndIP,
   156  		decision.FieldValue,
   157  		decision.FieldScope,
   158  		decision.FieldOrigin,
   159  	).Scan(c.CTX, &data)
   160  	if err != nil {
   161  		c.Log.Warningf("QueryDecisionWithFilter : %s", err)
   162  		return []*ent.Decision{}, errors.Wrap(QueryFail, "query decision failed")
   163  	}
   164  
   165  	return data, nil
   166  }
   167  
   168  func (c *Client) QueryAllDecisions() ([]*ent.Decision, error) {
   169  	data, err := c.Ent.Decision.Query().Where(decision.UntilGT(time.Now())).All(c.CTX)
   170  	if err != nil {
   171  		c.Log.Warningf("QueryAllDecisions : %s", err)
   172  		return []*ent.Decision{}, errors.Wrap(QueryFail, "get all decisions")
   173  	}
   174  	return data, nil
   175  }
   176  
   177  func (c *Client) QueryExpiredDecisions() ([]*ent.Decision, error) {
   178  	data, err := c.Ent.Decision.Query().Where(decision.UntilLT(time.Now())).All(c.CTX)
   179  	if err != nil {
   180  		c.Log.Warningf("QueryExpiredDecisions : %s", err)
   181  		return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions")
   182  	}
   183  	return data, nil
   184  }
   185  
   186  func (c *Client) QueryExpiredDecisionsSince(since time.Time) ([]*ent.Decision, error) {
   187  	data, err := c.Ent.Decision.Query().Where(decision.UntilLT(time.Now())).Where(decision.UntilGT(since)).All(c.CTX)
   188  	if err != nil {
   189  		c.Log.Warningf("QueryExpiredDecisionsSince : %s", err)
   190  		return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions")
   191  	}
   192  	return data, nil
   193  }
   194  
   195  func (c *Client) QueryNewDecisionsSince(since time.Time) ([]*ent.Decision, error) {
   196  	data, err := c.Ent.Decision.Query().Where(decision.CreatedAtGT(since)).All(c.CTX)
   197  	if err != nil {
   198  		c.Log.Warningf("QueryNewDecisionsSince : %s", err)
   199  		return []*ent.Decision{}, errors.Wrapf(QueryFail, "new decisions since '%s'", since.String())
   200  	}
   201  	return data, nil
   202  }
   203  
   204  func (c *Client) DeleteDecisionById(decisionId int) error {
   205  	err := c.Ent.Decision.DeleteOneID(decisionId).Exec(c.CTX)
   206  	if err != nil {
   207  		c.Log.Warningf("DeleteDecisionById : %s", err)
   208  		return errors.Wrapf(DeleteFail, "decision with id '%d' doesn't exist", decisionId)
   209  	}
   210  	return nil
   211  }
   212  
   213  func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, error) {
   214  	var err error
   215  	var start_ip, start_sfx, end_ip, end_sfx int64
   216  	var ip_sz int
   217  	var contains bool = true
   218  	/*if contains is true, return bans that *contains* the given value (value is the inner)
   219  	  else, return bans that are *contained* by the given value (value is the outer) */
   220  
   221  	decisions := c.Ent.Decision.Delete()
   222  	for param, value := range filter {
   223  		switch param {
   224  		case "contains":
   225  			contains, err = strconv.ParseBool(value[0])
   226  			if err != nil {
   227  				return "0", errors.Wrapf(InvalidFilter, "invalid contains value : %s", err)
   228  			}
   229  		case "scope":
   230  			decisions = decisions.Where(decision.ScopeEQ(value[0]))
   231  		case "value":
   232  			decisions = decisions.Where(decision.ValueEQ(value[0]))
   233  		case "type":
   234  			decisions = decisions.Where(decision.TypeEQ(value[0]))
   235  		case "ip", "range":
   236  			ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(value[0])
   237  			if err != nil {
   238  				return "0", errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", value[0], err)
   239  			}
   240  		default:
   241  			return "0", errors.Wrap(InvalidFilter, fmt.Sprintf("'%s' doesn't exist", param))
   242  		}
   243  	}
   244  	if ip_sz == 4 {
   245  		if contains { /*decision contains {start_ip,end_ip}*/
   246  			decisions = decisions.Where(decision.And(
   247  				decision.StartIPLTE(start_ip),
   248  				decision.EndIPGTE(end_ip),
   249  				decision.IPSizeEQ(int64(ip_sz)),
   250  			))
   251  		} else { /*decision is contained within {start_ip,end_ip}*/
   252  			decisions = decisions.Where(decision.And(
   253  				decision.StartIPGTE(start_ip),
   254  				decision.EndIPLTE(end_ip),
   255  				decision.IPSizeEQ(int64(ip_sz)),
   256  			))
   257  		}
   258  	} else if ip_sz == 16 {
   259  		if contains { /*decision contains {start_ip,end_ip}*/
   260  			decisions = decisions.Where(decision.And(
   261  				//matching addr size
   262  				decision.IPSizeEQ(int64(ip_sz)),
   263  				decision.Or(
   264  					//decision.start_ip < query.start_ip
   265  					decision.StartIPLT(start_ip),
   266  					decision.And(
   267  						//decision.start_ip == query.start_ip
   268  						decision.StartIPEQ(start_ip),
   269  						//decision.start_suffix <= query.start_suffix
   270  						decision.StartSuffixLTE(start_sfx),
   271  					)),
   272  				decision.Or(
   273  					//decision.end_ip > query.end_ip
   274  					decision.EndIPGT(end_ip),
   275  					decision.And(
   276  						//decision.end_ip == query.end_ip
   277  						decision.EndIPEQ(end_ip),
   278  						//decision.end_suffix >= query.end_suffix
   279  						decision.EndSuffixGTE(end_sfx),
   280  					),
   281  				),
   282  			))
   283  		} else {
   284  			decisions = decisions.Where(decision.And(
   285  				//matching addr size
   286  				decision.IPSizeEQ(int64(ip_sz)),
   287  				decision.Or(
   288  					//decision.start_ip > query.start_ip
   289  					decision.StartIPGT(start_ip),
   290  					decision.And(
   291  						//decision.start_ip == query.start_ip
   292  						decision.StartIPEQ(start_ip),
   293  						//decision.start_suffix >= query.start_suffix
   294  						decision.StartSuffixGTE(start_sfx),
   295  					)),
   296  				decision.Or(
   297  					//decision.end_ip < query.end_ip
   298  					decision.EndIPLT(end_ip),
   299  					decision.And(
   300  						//decision.end_ip == query.end_ip
   301  						decision.EndIPEQ(end_ip),
   302  						//decision.end_suffix <= query.end_suffix
   303  						decision.EndSuffixLTE(end_sfx),
   304  					),
   305  				),
   306  			))
   307  		}
   308  	} else if ip_sz != 0 {
   309  		return "0", errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz)
   310  	}
   311  
   312  	nbDeleted, err := decisions.Exec(c.CTX)
   313  	if err != nil {
   314  		c.Log.Warningf("DeleteDecisionsWithFilter : %s", err)
   315  		return "0", errors.Wrap(DeleteFail, "decisions with provided filter")
   316  	}
   317  	return strconv.Itoa(nbDeleted), nil
   318  }
   319  
   320  // SoftDeleteDecisionsWithFilter udpate the expiration time to now() for the decisions matching the filter
   321  func (c *Client) SoftDeleteDecisionsWithFilter(filter map[string][]string) (string, error) {
   322  	var err error
   323  	var start_ip, start_sfx, end_ip, end_sfx int64
   324  	var ip_sz int
   325  	var contains bool = true
   326  	/*if contains is true, return bans that *contains* the given value (value is the inner)
   327  	  else, return bans that are *contained* by the given value (value is the outer)*/
   328  	decisions := c.Ent.Decision.Update().Where(decision.UntilGT(time.Now()))
   329  	for param, value := range filter {
   330  		switch param {
   331  		case "contains":
   332  			contains, err = strconv.ParseBool(value[0])
   333  			if err != nil {
   334  				return "0", errors.Wrapf(InvalidFilter, "invalid contains value : %s", err)
   335  			}
   336  		case "scope":
   337  			decisions = decisions.Where(decision.ScopeEQ(value[0]))
   338  		case "value":
   339  			decisions = decisions.Where(decision.ValueEQ(value[0]))
   340  		case "type":
   341  			decisions = decisions.Where(decision.TypeEQ(value[0]))
   342  		case "ip", "range":
   343  			ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(value[0])
   344  			if err != nil {
   345  				return "0", errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", value[0], err)
   346  			}
   347  		default:
   348  			return "0", errors.Wrapf(InvalidFilter, "'%s' doesn't exist", param)
   349  		}
   350  	}
   351  	if ip_sz == 4 {
   352  		if contains {
   353  			/*Decision contains {start_ip,end_ip}*/
   354  			decisions = decisions.Where(decision.And(
   355  				decision.StartIPLTE(start_ip),
   356  				decision.EndIPGTE(end_ip),
   357  				decision.IPSizeEQ(int64(ip_sz)),
   358  			))
   359  		} else {
   360  			/*Decision is contained within {start_ip,end_ip}*/
   361  			decisions = decisions.Where(decision.And(
   362  				decision.StartIPGTE(start_ip),
   363  				decision.EndIPLTE(end_ip),
   364  				decision.IPSizeEQ(int64(ip_sz)),
   365  			))
   366  		}
   367  	} else if ip_sz == 16 {
   368  		/*decision contains {start_ip,end_ip}*/
   369  		if contains {
   370  			decisions = decisions.Where(decision.And(
   371  				//matching addr size
   372  				decision.IPSizeEQ(int64(ip_sz)),
   373  				decision.Or(
   374  					//decision.start_ip < query.start_ip
   375  					decision.StartIPLT(start_ip),
   376  					decision.And(
   377  						//decision.start_ip == query.start_ip
   378  						decision.StartIPEQ(start_ip),
   379  						//decision.start_suffix <= query.start_suffix
   380  						decision.StartSuffixLTE(start_sfx),
   381  					)),
   382  				decision.Or(
   383  					//decision.end_ip > query.end_ip
   384  					decision.EndIPGT(end_ip),
   385  					decision.And(
   386  						//decision.end_ip == query.end_ip
   387  						decision.EndIPEQ(end_ip),
   388  						//decision.end_suffix >= query.end_suffix
   389  						decision.EndSuffixGTE(end_sfx),
   390  					),
   391  				),
   392  			))
   393  		} else {
   394  			/*decision is contained within {start_ip,end_ip}*/
   395  			decisions = decisions.Where(decision.And(
   396  				//matching addr size
   397  				decision.IPSizeEQ(int64(ip_sz)),
   398  				decision.Or(
   399  					//decision.start_ip > query.start_ip
   400  					decision.StartIPGT(start_ip),
   401  					decision.And(
   402  						//decision.start_ip == query.start_ip
   403  						decision.StartIPEQ(start_ip),
   404  						//decision.start_suffix >= query.start_suffix
   405  						decision.StartSuffixGTE(start_sfx),
   406  					)),
   407  				decision.Or(
   408  					//decision.end_ip < query.end_ip
   409  					decision.EndIPLT(end_ip),
   410  					decision.And(
   411  						//decision.end_ip == query.end_ip
   412  						decision.EndIPEQ(end_ip),
   413  						//decision.end_suffix <= query.end_suffix
   414  						decision.EndSuffixLTE(end_sfx),
   415  					),
   416  				),
   417  			))
   418  		}
   419  	} else if ip_sz != 0 {
   420  		return "0", errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz)
   421  	}
   422  	nbDeleted, err := decisions.SetUntil(time.Now()).Save(c.CTX)
   423  	if err != nil {
   424  		c.Log.Warningf("SoftDeleteDecisionsWithFilter : %s", err)
   425  		return "0", errors.Wrap(DeleteFail, "soft delete decisions with provided filter")
   426  	}
   427  	return strconv.Itoa(nbDeleted), nil
   428  }
   429  
   430  //SoftDeleteDecisionByID set the expiration of a decision to now()
   431  func (c *Client) SoftDeleteDecisionByID(decisionID int) error {
   432  	nbUpdated, err := c.Ent.Decision.Update().Where(decision.IDEQ(decisionID)).SetUntil(time.Now()).Save(c.CTX)
   433  	if err != nil || nbUpdated == 0 {
   434  		c.Log.Warningf("SoftDeleteDecisionByID : %v (nb soft deleted: %d)", err, nbUpdated)
   435  		return errors.Wrapf(DeleteFail, "decision with id '%d' doesn't exist", decisionID)
   436  	}
   437  
   438  	if nbUpdated == 0 {
   439  		return ItemNotFound
   440  	}
   441  	return nil
   442  }