From 65ce6f20c13aa8886cb597cf4191622d9b90515d Mon Sep 17 00:00:00 2001 From: Yann Soubeyrand Date: Sun, 9 Feb 2025 11:34:53 +0100 Subject: [PATCH] fix: make WriteConfig and SafeWriteConfig work as expected MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a config file is set using SetConfigFile: – WriteConfig failed if the file did not already exist, – SafeWriteConfig did not use it. Fixes #430 Fixes #433 Signed-off-by: Yann Soubeyrand --- viper.go | 32 +++++++++++++++++---- viper_test.go | 77 +++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 97 insertions(+), 12 deletions(-) diff --git a/viper.go b/viper.go index f900e58b1..efd7c242e 100644 --- a/viper.go +++ b/viper.go @@ -69,6 +69,11 @@ func (str UnsupportedConfigError) Error() string { return fmt.Sprintf("Unsupported Config Type %q", string(str)) } +// ConfigFileError denotes failing to get configuration file. +type ConfigFileError struct { + error +} + // ConfigFileNotFoundError denotes failing to find configuration file. type ConfigFileNotFoundError struct { name, locations string @@ -1575,21 +1580,24 @@ func (v *Viper) MergeConfigMap(cfg map[string]any) error { func WriteConfig() error { return v.WriteConfig() } func (v *Viper) WriteConfig() error { - filename, err := v.getConfigFile() + _, err := v.getConfigFile() if err != nil { return err } - return v.writeConfig(filename, true) + + return v.WriteConfigAs(v.configFile) } // SafeWriteConfig writes current configuration to file only if the file does not exist. func SafeWriteConfig() error { return v.SafeWriteConfig() } func (v *Viper) SafeWriteConfig() error { - if len(v.configPaths) < 1 { - return errors.New("missing configuration for 'configPath'") + _, err := v.getConfigFile() + if err != nil { + return err } - return v.SafeWriteConfigAs(filepath.Join(v.configPaths[0], v.configName+"."+v.configType)) + + return v.SafeWriteConfigAs(v.configFile) } // WriteConfigAs writes current configuration to a given filename. @@ -2006,7 +2014,19 @@ func (v *Viper) getConfigFile() (string, error) { if v.configFile == "" { cf, err := v.findConfigFile() if err != nil { - return "", err + if _, ok := err.(ConfigFileNotFoundError); !ok { + return "", err + } + if len(v.configPaths) < 1 { + return "", ConfigFileError{errors.New("missing configuration for 'configPath'")} + } + if v.configName == "" { + return "", ConfigFileError{errors.New("missing configuration for 'configName'")} + } + if v.configType == "" { + return "", ConfigFileError{errors.New("missing configuration for 'configType'")} + } + cf = filepath.Join(v.configPaths[0], v.configName+"."+v.configType) } v.configFile = cf } diff --git a/viper_test.go b/viper_test.go index c0c0e7c06..6766d0864 100644 --- a/viper_test.go +++ b/viper_test.go @@ -1621,7 +1621,7 @@ func TestWrongDirsSearchNotFound(t *testing.T) { v.AddConfigPath(`thispathaintthere`) err := v.ReadInConfig() - assert.IsType(t, ConfigFileNotFoundError{"", ""}, err) + assert.ErrorAs(t, err, &ConfigFileError{}) // Even though config did not load and the error might have // been ignored by the client, the default still loads @@ -1639,7 +1639,7 @@ func TestWrongDirsSearchNotFoundForMerge(t *testing.T) { v.AddConfigPath(`thispathaintthere`) err := v.MergeInConfig() - assert.Equal(t, reflect.TypeOf(ConfigFileNotFoundError{"", ""}), reflect.TypeOf(err)) + assert.ErrorAs(t, err, &ConfigFileError{}) // Even though config did not load and the error might have // been ignored by the client, the default still loads @@ -1731,7 +1731,7 @@ var jsonWriteExpected = []byte(`{ // name: steve // `) -func TestWriteConfig(t *testing.T) { +func TestWriteConfigAs(t *testing.T) { fs := afero.NewMemMapFs() testCases := map[string]struct { configName string @@ -1809,7 +1809,7 @@ func TestWriteConfig(t *testing.T) { } } -func TestWriteConfigTOML(t *testing.T) { +func TestWriteConfigAsTOML(t *testing.T) { fs := afero.NewMemMapFs() testCases := map[string]struct { @@ -1860,7 +1860,7 @@ func TestWriteConfigTOML(t *testing.T) { } } -func TestWriteConfigDotEnv(t *testing.T) { +func TestWriteConfigAsDotEnv(t *testing.T) { fs := afero.NewMemMapFs() testCases := map[string]struct { configName string @@ -1909,6 +1909,56 @@ func TestWriteConfigDotEnv(t *testing.T) { } } +func TestWriteConfig(t *testing.T) { + v := New() + fs := afero.NewMemMapFs() + v.SetFs(fs) + v.AddConfigPath("/test") + v.SetConfigName("c") + v.SetConfigType("yaml") + require.NoError(t, v.ReadConfig(bytes.NewBuffer(yamlExample))) + require.NoError(t, v.WriteConfig()) + read, err := afero.ReadFile(fs, "/test/c.yaml") + require.NoError(t, err) + assert.Equal(t, yamlWriteExpected, read) +} + +func TestWriteConfigWithExplicitlySetFile(t *testing.T) { + v := New() + fs := afero.NewMemMapFs() + v.SetFs(fs) + v.AddConfigPath("/test1") + v.SetConfigName("c1") + v.SetConfigType("yaml") + v.SetConfigFile("/test2/c2.yaml") + require.NoError(t, v.ReadConfig(bytes.NewBuffer(yamlExample))) + require.NoError(t, v.WriteConfig()) + read, err := afero.ReadFile(fs, "/test2/c2.yaml") + require.NoError(t, err) + assert.Equal(t, yamlWriteExpected, read) +} + +func TestWriteConfigWithMissingConfigPath(t *testing.T) { + v := New() + fs := afero.NewMemMapFs() + v.SetFs(fs) + v.SetConfigName("c") + v.SetConfigType("yaml") + require.EqualError(t, v.WriteConfig(), "missing configuration for 'configPath'") +} + +func TestWriteConfigWithExistingFile(t *testing.T) { + v := New() + fs := afero.NewMemMapFs() + fs.Create("/test/c.yaml") + v.SetFs(fs) + v.AddConfigPath("/test") + v.SetConfigName("c") + v.SetConfigType("yaml") + err := v.WriteConfig() + require.NoError(t, err) +} + func TestSafeWriteConfig(t *testing.T) { v := New() fs := afero.NewMemMapFs() @@ -1923,6 +1973,21 @@ func TestSafeWriteConfig(t *testing.T) { assert.YAMLEq(t, string(yamlWriteExpected), string(read)) } +func TestSafeWriteConfigWithExplicitlySetFile(t *testing.T) { + v := New() + fs := afero.NewMemMapFs() + v.SetFs(fs) + v.AddConfigPath("/test1") + v.SetConfigName("c1") + v.SetConfigType("yaml") + v.SetConfigFile("/test2/c2.yaml") + require.NoError(t, v.ReadConfig(bytes.NewBuffer(yamlExample))) + require.NoError(t, v.SafeWriteConfig()) + read, err := afero.ReadFile(fs, "/test2/c2.yaml") + require.NoError(t, err) + assert.Equal(t, yamlWriteExpected, read) +} + func TestSafeWriteConfigWithMissingConfigPath(t *testing.T) { v := New() fs := afero.NewMemMapFs() @@ -1946,7 +2011,7 @@ func TestSafeWriteConfigWithExistingFile(t *testing.T) { assert.True(t, ok, "Expected ConfigFileAlreadyExistsError") } -func TestSafeWriteAsConfig(t *testing.T) { +func TestSafeWriteConfigAs(t *testing.T) { v := New() fs := afero.NewMemMapFs() v.SetFs(fs)