github.com/ActiveState/cli@v0.0.0-20240508170324-6801f60cd051/cmd/state-svc/internal/messages/messages.go (about)

     1  package messages
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"os"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/ActiveState/cli/internal/config"
    11  	"github.com/ActiveState/cli/internal/constants"
    12  	"github.com/ActiveState/cli/internal/errs"
    13  	"github.com/ActiveState/cli/internal/fileutils"
    14  	"github.com/ActiveState/cli/internal/graph"
    15  	"github.com/ActiveState/cli/internal/httputil"
    16  	"github.com/ActiveState/cli/internal/logging"
    17  	"github.com/ActiveState/cli/internal/poller"
    18  	"github.com/ActiveState/cli/internal/strutils"
    19  	auth "github.com/ActiveState/cli/pkg/platform/authentication"
    20  	"github.com/ActiveState/cli/pkg/sysinfo"
    21  	"github.com/blang/semver"
    22  )
    23  
    24  const ConfigKeyLastReport = "messages.last_reported"
    25  
    26  type Messages struct {
    27  	cfg        *config.Instance
    28  	auth       *auth.Auth
    29  	baseParams *ConditionParams
    30  	poll       *poller.Poller
    31  	checkMutex sync.Mutex
    32  }
    33  
    34  func New(cfg *config.Instance, auth *auth.Auth) (*Messages, error) {
    35  	osVersion, err := sysinfo.OSVersion()
    36  	if err != nil {
    37  		return nil, errs.Wrap(err, "Could not get OS version")
    38  	}
    39  
    40  	stateVersion, err := semver.Parse(constants.Version)
    41  	if err != nil {
    42  		return nil, errs.Wrap(err, "Could not parse state version")
    43  	}
    44  
    45  	poll := poller.New(1*time.Hour, func() (interface{}, error) {
    46  		resp, err := fetch()
    47  		return resp, err
    48  	})
    49  
    50  	return &Messages{
    51  		baseParams: &ConditionParams{
    52  			OS:           sysinfo.OS().String(),
    53  			OSVersion:    NewVersionFromSysinfo(osVersion),
    54  			StateChannel: constants.ChannelName,
    55  			StateVersion: NewVersionFromSemver(stateVersion),
    56  		},
    57  		cfg:  cfg,
    58  		auth: auth,
    59  		poll: poll,
    60  	}, nil
    61  }
    62  
    63  func (m *Messages) Close() error {
    64  	m.poll.Close()
    65  	return nil
    66  }
    67  
    68  func (m *Messages) Check(command string, flags []string) ([]*graph.MessageInfo, error) {
    69  	// Prevent multiple checks at the same time, which could lead to the same message showing multiple times
    70  	m.checkMutex.Lock()
    71  	defer m.checkMutex.Unlock()
    72  
    73  	cacheValue := m.poll.ValueFromCache()
    74  	if cacheValue == nil {
    75  		return []*graph.MessageInfo{}, nil
    76  	}
    77  	allMessages := cacheValue.([]*graph.MessageInfo)
    78  
    79  	conditionParams := *m.baseParams // copy
    80  	conditionParams.UserEmail = m.auth.Email()
    81  	conditionParams.UserName = m.auth.WhoAmI()
    82  	conditionParams.Command = command
    83  	conditionParams.Flags = flags
    84  
    85  	if id := m.auth.UserID(); id != nil {
    86  		conditionParams.UserID = id.String()
    87  	}
    88  
    89  	logging.Debug("Checking %d messages with params: %#v", len(allMessages), conditionParams)
    90  
    91  	lastReportMap := m.cfg.GetStringMap(ConfigKeyLastReport)
    92  	msgs, err := check(&conditionParams, allMessages, lastReportMap, time.Now())
    93  	if err != nil {
    94  		return nil, errs.Wrap(err, "Could not check messages")
    95  	}
    96  	for _, msg := range msgs {
    97  		lastReportMap[msg.ID] = time.Now().Format(time.RFC3339)
    98  	}
    99  	if err := m.cfg.Set(ConfigKeyLastReport, lastReportMap); err != nil {
   100  		return nil, errs.Wrap(err, "Could not save last reported messages")
   101  	}
   102  
   103  	return msgs, nil
   104  }
   105  
   106  func check(params *ConditionParams, messages []*graph.MessageInfo, lastReportMap map[string]interface{}, baseTime time.Time) ([]*graph.MessageInfo, error) {
   107  	funcMap := conditionFuncMap()
   108  	filteredMessages := []*graph.MessageInfo{}
   109  	for _, message := range messages {
   110  		logging.Debug("Checking message %s", message.ID)
   111  		// Ensure we don't show the same message too often
   112  		if lastReport, ok := lastReportMap[message.ID]; ok {
   113  			lastReportTime, err := time.Parse(time.RFC3339, lastReport.(string))
   114  			if err != nil {
   115  				return nil, errs.New("Could not parse last reported time for message %s as it's not a valid RFC3339 value: %v", message.ID, lastReport)
   116  			}
   117  
   118  			lastReportTimeAgo := baseTime.Sub(lastReportTime)
   119  			showMessage, err := repeatValid(message.Repeat, lastReportTimeAgo)
   120  			if err != nil {
   121  				return nil, errs.Wrap(err, "Could not validate repeat for message %s", message.ID)
   122  			}
   123  
   124  			if !showMessage {
   125  				logging.Debug("Skipping message %s as it was shown %s ago", message.ID, lastReportTimeAgo)
   126  				continue
   127  			}
   128  		}
   129  
   130  		// Validate the conditional
   131  		if message.Condition != "" {
   132  			result, err := strutils.ParseTemplate(fmt.Sprintf(`{{%s}}`, message.Condition), params, funcMap)
   133  			if err != nil {
   134  				return nil, errs.Wrap(err, "Could not parse condition template for message %s", message.ID)
   135  			}
   136  			if result == "true" {
   137  				logging.Debug("Including message %s as condition %s evaluated to %s", message.ID, message.Condition, result)
   138  				filteredMessages = append(filteredMessages, message)
   139  			} else {
   140  				logging.Debug("Skipping message %s as condition %s evaluated to %s", message.ID, message.Condition, result)
   141  			}
   142  		} else {
   143  			logging.Debug("Including message %s as it has no condition", message.ID)
   144  			filteredMessages = append(filteredMessages, message)
   145  		}
   146  	}
   147  
   148  	return filteredMessages, nil
   149  }
   150  
   151  func fetch() ([]*graph.MessageInfo, error) {
   152  	var body []byte
   153  	var err error
   154  
   155  	if v := os.Getenv(constants.MessagesOverrideEnvVarName); v != "" {
   156  		body, err = fileutils.ReadFile(v)
   157  		if err != nil {
   158  			return nil, errs.Wrap(err, "Could not read messages override file")
   159  		}
   160  	} else {
   161  		body, err = httputil.Get(constants.MessagesInfoURL)
   162  		if err != nil {
   163  			return nil, errs.Wrap(err, "Could not fetch messages information")
   164  		}
   165  	}
   166  
   167  	var messages []*graph.MessageInfo
   168  	if err := json.Unmarshal(body, &messages); err != nil {
   169  		return nil, errs.Wrap(err, "Could not unmarshall messages information")
   170  	}
   171  
   172  	// Set defaults
   173  	for _, message := range messages {
   174  		if message.Placement == "" {
   175  			message.Placement = graph.MessagePlacementTypeBeforeCmd
   176  		}
   177  		if message.Interrupt == "" {
   178  			message.Interrupt = graph.MessageInterruptTypeDisabled
   179  		}
   180  		if message.Repeat == "" {
   181  			message.Repeat = graph.MessageRepeatTypeDisabled
   182  		}
   183  	}
   184  
   185  	return messages, nil
   186  }
   187  
   188  func repeatValid(repeatType graph.MessageRepeatType, lastReportTimeAgo time.Duration) (bool, error) {
   189  	switch repeatType {
   190  	case graph.MessageRepeatTypeConstantly:
   191  		return true, nil
   192  	case graph.MessageRepeatTypeDisabled:
   193  		return false, nil
   194  	case graph.MessageRepeatTypeHourly:
   195  		return lastReportTimeAgo >= time.Hour, nil
   196  	case graph.MessageRepeatTypeDaily:
   197  		return lastReportTimeAgo >= 24*time.Hour, nil
   198  	case graph.MessageRepeatTypeWeekly:
   199  		return lastReportTimeAgo >= 7*24*time.Hour, nil
   200  	case graph.MessageRepeatTypeMonthly:
   201  		return lastReportTimeAgo >= 30*24*time.Hour, nil
   202  	default:
   203  		return false, errs.New("Unknown repeat type: %s", repeatType)
   204  	}
   205  }