diff options
Diffstat (limited to 'atomictemp')
-rw-r--r-- | atomictemp/atomictemp.go | 48 | ||||
-rw-r--r-- | atomictemp/atomictemp_test.go | 68 |
2 files changed, 116 insertions, 0 deletions
diff --git a/atomictemp/atomictemp.go b/atomictemp/atomictemp.go new file mode 100644 index 0000000..b56d095 --- /dev/null +++ b/atomictemp/atomictemp.go @@ -0,0 +1,48 @@ +// Create a temporary file in the directory of the given destination +// file, then rename the temporary file over the destination file if +// writes to the temporary file complete without error. +// +// Used to atomically update files. +// +// Note: This is not guaranteed and depends on the sync semantics of the +// underlying filesystem. +package atomictemp + +import ( + "io" + "os" + "path/filepath" +) + +// Create a temporary file in the directory of the given destination +// file, then rename the temporary file over the destination file if +// writes to the temporary file complete without error. +// +// Used to atomically update files. +// +// Note: This is not guaranteed and depends on the sync semantics of the +// underlying filesystem. +func Create(dstPath string, fn func(io.Writer) error) error { + // open temp output file + f, err := os.CreateTemp(filepath.Dir(dstPath), "") + if err != nil { + return err + } + + // build absolute path to temporary file + // tmpPath := filepath.Join(filepath.Dir(dstPath), f.Name()) + defer os.Remove(f.Name()) + + // invoke callback + if err = fn(f); err != nil { + return err + } + + // close temp file + if err = f.Close(); err != nil { + return err + } + + // rename to destination file, return result + return os.Rename(f.Name(), dstPath) +} diff --git a/atomictemp/atomictemp_test.go b/atomictemp/atomictemp_test.go new file mode 100644 index 0000000..401c730 --- /dev/null +++ b/atomictemp/atomictemp_test.go @@ -0,0 +1,68 @@ +package atomictemp + +import ( + "errors" + "io" + "os" + "path/filepath" + "testing" +) + +func TestCreate(t *testing.T) { + dir, err := os.MkdirTemp("", "") + if err != nil { + t.Error(err) + } + + // test good file + goodTests := []string { "foo", "bar", "baz" } + for _, test := range(goodTests) { + t.Run(test, func(t *testing.T) { + // build destination path + path := filepath.Join(dir, test) + + // create temp file + err := Create(path, func(f io.Writer) error { + _, err := f.Write([]byte(test)) + return err + }) + + if err != nil { + t.Error(err) + return + } + + if got, err := os.ReadFile(path); err != nil { + t.Error(err) + } else if string(got) != test { + t.Errorf("got \"%s\", exp \"%s\"", string(got), test) + } + }) + } + + t.Run("badDir", func(t *testing.T) { + // build nonsense path to destination file + badPath := filepath.Join(dir, "does/not/exist") + + err := Create(badPath, func(_ io.Writer) error { + return nil + }) + + if err == nil { + t.Errorf("got success, exp error") + } + }) + + t.Run("badFunc", func(t *testing.T) { + // build path + path := filepath.Join(dir, "badFunc") + + err := Create(path, func(_ io.Writer) error { + return errors.New("ack!") + }) + + if err == nil { + t.Errorf("got success, exp error") + } + }) +} |