github.com/readium/readium-lcp-server@v0.0.0-20240101192032-6e95190e99f1/lcpserver/lcpserver.go (about)

     1  // Copyright (c) 2022 Readium Foundation
     2  // Use of this source code is governed by a BSD-style license
     3  // that can be found in the LICENSE file exposed on Github (readium) in the project repository.
     4  
     5  package main
     6  
     7  import (
     8  	"crypto/tls"
     9  	"database/sql"
    10  	"fmt"
    11  	"log"
    12  	"os"
    13  	"os/signal"
    14  	"runtime"
    15  	"strconv"
    16  	"strings"
    17  	"syscall"
    18  
    19  	auth "github.com/abbot/go-http-auth"
    20  	_ "github.com/go-sql-driver/mysql"
    21  	_ "github.com/lib/pq"
    22  	_ "github.com/mattn/go-sqlite3"
    23  	_ "github.com/microsoft/go-mssqldb"
    24  
    25  	"github.com/readium/readium-lcp-server/config"
    26  	"github.com/readium/readium-lcp-server/index"
    27  	lcpserver "github.com/readium/readium-lcp-server/lcpserver/server"
    28  	"github.com/readium/readium-lcp-server/license"
    29  	"github.com/readium/readium-lcp-server/logging"
    30  	"github.com/readium/readium-lcp-server/pack"
    31  	"github.com/readium/readium-lcp-server/storage"
    32  )
    33  
    34  func main() {
    35  	var config_file, storagePath, certFile, privKeyFile string
    36  	var readonly bool = false
    37  	var err error
    38  
    39  	if config_file = os.Getenv("READIUM_LCPSERVER_CONFIG"); config_file == "" {
    40  		config_file = "config.yaml"
    41  	}
    42  	config.ReadConfig(config_file)
    43  	log.Println("Config from " + config_file)
    44  
    45  	// read only flag
    46  	readonly = config.Config.LcpServer.ReadOnly
    47  
    48  	// if the logging key is set, logs will be sent to a file and/or Slack channel for test purposes
    49  	err = logging.Init(config.Config.Logging)
    50  	if err != nil {
    51  		panic(err)
    52  	}
    53  
    54  	err = config.SetPublicUrls()
    55  	if err != nil {
    56  		log.Println("Error setting public urls: " + err.Error())
    57  		os.Exit(1)
    58  	}
    59  	if certFile = config.Config.Certificate.Cert; certFile == "" {
    60  		log.Println("Missing certificate in the configuration")
    61  		os.Exit(1)
    62  	}
    63  	if privKeyFile = config.Config.Certificate.PrivateKey; privKeyFile == "" {
    64  		log.Println("Missing private key in the configuration")
    65  		os.Exit(1)
    66  	}
    67  	cert, err := tls.LoadX509KeyPair(certFile, privKeyFile)
    68  	if err != nil {
    69  		log.Println("Error loading X509 cert: " + err.Error())
    70  		os.Exit(1)
    71  	}
    72  	/* this check is temporarily deactivated. It will be reactivated after a new LCP production lib has been distributed.
    73  	if config.Config.Profile != "basic" && !license.LCP_PRODUCTION_LIB {
    74  		log.Println("Can't run in production mode, not built with the proper lib")
    75  		os.Exit(1)
    76  	}
    77  	*/
    78  	if config.Config.Profile == "basic" {
    79  		log.Println("Server running in test mode")
    80  	} else {
    81  		log.Println("Server running in production mode, profile " + config.Config.Profile)
    82  	}
    83  
    84  	driver, cnxn := config.GetDatabase(config.Config.LcpServer.Database)
    85  	log.Println("Database driver " + driver)
    86  
    87  	db, err := sql.Open(driver, cnxn)
    88  	if err != nil {
    89  		log.Println("Error opening the sql db: " + err.Error())
    90  		os.Exit(1)
    91  	}
    92  
    93  	if driver == "sqlite3" && !strings.Contains(cnxn, "_journal") {
    94  		_, err = db.Exec("PRAGMA journal_mode = WAL")
    95  		if err != nil {
    96  			log.Println("Error journaling sqlite3: " + err.Error())
    97  			os.Exit(1)
    98  		}
    99  	}
   100  
   101  	idx, err := index.Open(db)
   102  	if err != nil {
   103  		log.Println("Error opening the index db: " + err.Error())
   104  		os.Exit(1)
   105  	}
   106  
   107  	lst, err := license.Open(db)
   108  	if err != nil {
   109  		log.Println("Error opening the license db: " + err.Error())
   110  		os.Exit(1)
   111  	}
   112  
   113  	err = license.CreateDefaultLinks()
   114  	if err != nil {
   115  		log.Println("Error setting default links: " + err.Error())
   116  		os.Exit(1)
   117  	}
   118  
   119  	var store storage.Store
   120  	if mode := config.Config.Storage.Mode; mode == "s3" {
   121  		s3Conf := s3ConfigFromYAML()
   122  		store, _ = storage.S3(s3Conf)
   123  	} else if config.Config.Storage.FileSystem.Directory != "" {
   124  		storagePath = config.Config.Storage.FileSystem.Directory
   125  		os.MkdirAll(storagePath, os.ModePerm) //ignore the error, the folder can already exist
   126  		store = storage.NewFileSystem(storagePath, config.Config.Storage.FileSystem.URL)
   127  		log.Println("Storage in", storagePath, " at URL", config.Config.Storage.FileSystem.URL)
   128  	} else {
   129  		store = storage.NoStorage()
   130  		log.Println("No storage created")
   131  	}
   132  
   133  	packager := pack.NewPackager(store, idx, 4)
   134  
   135  	authFile := config.Config.LcpServer.AuthFile
   136  	if authFile == "" {
   137  		log.Println("Missing passwords file")
   138  		os.Exit(1)
   139  
   140  	}
   141  	_, err = os.Stat(authFile)
   142  	if err != nil {
   143  		log.Println("Error reaching passwords file: " + err.Error())
   144  		os.Exit(1)
   145  	}
   146  	htpasswd := auth.HtpasswdFileProvider(authFile)
   147  	authenticator := auth.NewBasicAuthenticator("Readium License Content Protection Server", htpasswd)
   148  
   149  	HandleSignals()
   150  
   151  	parsedPort := strconv.Itoa(config.Config.LcpServer.Port)
   152  	s := lcpserver.New(":"+parsedPort, readonly, &idx, &store, &lst, &cert, packager, authenticator)
   153  	if readonly {
   154  		log.Println("License server running in readonly mode on port " + parsedPort)
   155  	} else {
   156  		log.Println("License server running on port " + parsedPort)
   157  	}
   158  	log.Println("Public base URL=" + config.Config.LcpServer.PublicBaseUrl)
   159  
   160  	if err := s.ListenAndServe(); err != nil {
   161  		log.Println("Error " + err.Error())
   162  	}
   163  
   164  }
   165  
   // HandleSignals installs process-wide handlers for SIGQUIT, SIGINT and SIGTERM.
//
// On SIGQUIT it prints the stack traces of all goroutines to stdout, truncated to
// a 1 MiB buffer, and keeps running. On SIGINT or SIGTERM it prints
// "Shutting down..." and exits with status 0. The handling goroutine runs for the
// rest of the process lifetime.
func HandleSignals() {
	// signal.Notify never blocks when delivering a signal, so the channel needs
	// buffer space: with an unbuffered channel a signal that arrives while the
	// goroutine below is not waiting on the receive would be silently dropped.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGQUIT, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		// Allocated once and reused for every SIGQUIT dump.
		stacktrace := make([]byte, 1<<20)
		for sig := range sigChan {
			switch sig {
			case syscall.SIGQUIT:
				length := runtime.Stack(stacktrace, true)
				fmt.Println(string(stacktrace[:length]))
			case syscall.SIGINT, syscall.SIGTERM:
				fmt.Println("Shutting down...")
				os.Exit(0)
			}
		}
	}()
}
   185  
   186  func s3ConfigFromYAML() storage.S3Config {
   187  	s3config := storage.S3Config{}
   188  
   189  	s3config.ID = config.Config.Storage.AccessId
   190  	s3config.Secret = config.Config.Storage.Secret
   191  	s3config.Token = config.Config.Storage.Token
   192  
   193  	s3config.Endpoint = config.Config.Storage.Endpoint
   194  	s3config.Bucket = config.Config.Storage.Bucket
   195  	s3config.Region = config.Config.Storage.Region
   196  
   197  	s3config.DisableSSL = config.Config.Storage.DisableSSL
   198  	s3config.ForcePathStyle = config.Config.Storage.PathStyle
   199  
   200  	return s3config
   201  }