diff --git a/cmd/auth.go b/cmd/auth.go index 2b618bf..595262d 100644 --- a/cmd/auth.go +++ b/cmd/auth.go @@ -26,6 +26,7 @@ import ( "go.uploadedlobster.com/scotty/internal/auth" "go.uploadedlobster.com/scotty/internal/backends" "go.uploadedlobster.com/scotty/internal/cli" + "go.uploadedlobster.com/scotty/internal/config" "go.uploadedlobster.com/scotty/internal/models" "go.uploadedlobster.com/scotty/internal/storage" "golang.org/x/oauth2" @@ -76,7 +77,7 @@ var authCmd = &cobra.Command{ cobra.CheckErr(err) // Store the retrieved token in the database - db, err := storage.New(viper.GetString("database")) + db, err := storage.New(config.DatabasePath()) cobra.CheckErr(err) err = db.SetOAuth2Token(serviceName, tok) diff --git a/cmd/listens.go b/cmd/listens.go index 8bf3c3b..558259e 100644 --- a/cmd/listens.go +++ b/cmd/listens.go @@ -18,9 +18,9 @@ package cmd import ( "github.com/spf13/cobra" - "github.com/spf13/viper" "go.uploadedlobster.com/scotty/internal/backends" "go.uploadedlobster.com/scotty/internal/cli" + "go.uploadedlobster.com/scotty/internal/config" "go.uploadedlobster.com/scotty/internal/models" "go.uploadedlobster.com/scotty/internal/storage" ) @@ -31,7 +31,7 @@ var listensCmd = &cobra.Command{ Short: "Transfer listens between two services", Long: `Transfers listens between two configured services.`, Run: func(cmd *cobra.Command, args []string) { - db, err := storage.New(viper.GetString("database")) + db, err := storage.New(config.DatabasePath()) cobra.CheckErr(err) c, err := cli.NewTransferCmd[ models.ListensExport, diff --git a/cmd/loves.go b/cmd/loves.go index f44cb5b..a802c42 100644 --- a/cmd/loves.go +++ b/cmd/loves.go @@ -18,9 +18,9 @@ package cmd import ( "github.com/spf13/cobra" - "github.com/spf13/viper" "go.uploadedlobster.com/scotty/internal/backends" "go.uploadedlobster.com/scotty/internal/cli" + "go.uploadedlobster.com/scotty/internal/config" "go.uploadedlobster.com/scotty/internal/models" "go.uploadedlobster.com/scotty/internal/storage" ) @@ -31,7 +31,7 @@ var lovesCmd = &cobra.Command{ Short: "Transfer loves between two services", Long: `Transfers loves between two configured services.`, Run: func(cmd *cobra.Command, args []string) { - db, err := storage.New(viper.GetString("database")) + db, err := storage.New(config.DatabasePath()) cobra.CheckErr(err) c, err := cli.NewTransferCmd[ models.LovesExport, diff --git a/cmd/root.go b/cmd/root.go index 10dc528..8b4fdf8 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -19,10 +19,10 @@ package cmd import ( "fmt" "os" - "path" "github.com/spf13/cobra" "github.com/spf13/viper" + "go.uploadedlobster.com/scotty/internal/config" "go.uploadedlobster.com/scotty/internal/version" ) @@ -56,7 +56,7 @@ func init() { // Cobra supports persistent flags, which, if defined here, // will be global for your application. - configDir := defaultConfigDir() + configDir := config.DefaultConfigDir() rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", fmt.Sprintf("config file (default is %s/scotty.yaml)", configDir)) @@ -65,27 +65,12 @@ func init() { // rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle") } -func defaultConfigDir() string { - configDir, err := os.UserConfigDir() - cobra.CheckErr(err) - return path.Join(configDir, version.AppName) -} - // initConfig reads in config file and ENV variables if set. func initConfig() { - if cfgFile != "" { - // Use config file from the flag. - viper.SetConfigFile(cfgFile) - } else { - viper.AddConfigPath(defaultConfigDir()) - viper.SetConfigType("toml") - viper.SetConfigName(version.AppName) - } - - viper.AutomaticEnv() // read in environment variables that match - // If a config file is found, read it in. - if err := viper.ReadInConfig(); err == nil { - fmt.Fprintln(os.Stderr, "Using config file:", viper.ConfigFileUsed()) + if err := config.InitConfig(cfgFile); err != nil { + fmt.Fprintln(os.Stderr, "Failed reading config:", err) + } else { + fmt.Println("Using config file:", viper.ConfigFileUsed()) } } diff --git a/internal/cli/common.go b/internal/cli/common.go index a26253d..e7fe19d 100644 --- a/internal/cli/common.go +++ b/internal/cli/common.go @@ -30,7 +30,7 @@ func GetConfigFromFlag(cmd *cobra.Command, flagName string) (string, *viper.Vipe config = servicesConfig.Sub(configName) } if config == nil { - cobra.CheckErr(fmt.Sprintf("Invalid source configuration \"%s\"", configName)) + cobra.CheckErr(fmt.Sprintf("invalid configuration \"%s\"", configName)) } return configName, config } diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..ece3544 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,87 @@ +/* +Copyright © 2023 Philipp Wolfer + +Scotty is free software: you can redistribute it and/or modify it under the +terms of the GNU General Public License as published by the Free Software +Foundation, either version 3 of the License, or (at your option) any later version. + +Scotty is distributed in the hope that it will be useful, but WITHOUT ANY +WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. See the GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along with +Scotty. If not, see . +*/ + +package config + +import ( + "os" + "path" + "path/filepath" + + "github.com/spf13/cobra" + "github.com/spf13/viper" + "go.uploadedlobster.com/scotty/internal/version" +) + +const ( + defaultDatabase = "scotty.sqlite3" + defaultOAuthHost = "127.0.0.1:2369" +) + +func DefaultConfigDir() string { + configDir, err := os.UserConfigDir() + cobra.CheckErr(err) + return path.Join(configDir, version.AppName) +} + +// initConfig reads in config file and ENV variables if set. +func InitConfig(cfgFile string) error { + configDir := DefaultConfigDir() + if cfgFile != "" { + // Use given config file + viper.SetConfigFile(cfgFile) + } else { + viper.AddConfigPath(configDir) + viper.SetConfigType("toml") + viper.SetConfigName(version.AppName) + viper.SetConfigPermissions(0640) + } + + setDefaults() + + // Create global config if it does not exist + if viper.ConfigFileUsed() == "" && cfgFile == "" { + if err := os.MkdirAll(configDir, 0750); err == nil { + viper.SafeWriteConfig() + } + } + + // read in environment variables that match + viper.AutomaticEnv() + + // If a config file is found, read it in. + return viper.ReadInConfig() +} + +func DatabasePath() string { + path := viper.GetString("database") + if filepath.IsAbs(path) { + return path + } + + return filepath.Join(getConfigDir(), path) +} + +func setDefaults() { + viper.SetDefault("database", defaultDatabase) + viper.SetDefault("oauth-host", defaultOAuthHost) + + // Always configure the dump backend as a default service + viper.SetDefault("service.dump.backend", "dump") +} + +func getConfigDir() string { + return filepath.Dir(viper.ConfigFileUsed()) +}