aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--atomictemp/atomictemp.go48
-rw-r--r--atomictemp/atomictemp_test.go68
2 files changed, 116 insertions, 0 deletions
diff --git a/atomictemp/atomictemp.go b/atomictemp/atomictemp.go
new file mode 100644
index 0000000..b56d095
--- /dev/null
+++ b/atomictemp/atomictemp.go
@@ -0,0 +1,48 @@
+// Create a temporary file in the directory of the given destination
+// file, then rename the temporary file over the destination file if
+// writes to the temporary file complete without error.
+//
+// Used to atomically update files.
+//
+// Note: This is not guaranteed and depends on the sync semantics of the
+// underlying filesystem.
+package atomictemp
+
+import (
+ "io"
+ "os"
+ "path/filepath"
+)
+
+// Create a temporary file in the directory of the given destination
+// file, then rename the temporary file over the destination file if
+// writes to the temporary file complete without error.
+//
+// Used to atomically update files.
+//
+// Note: This is not guaranteed and depends on the sync semantics of the
+// underlying filesystem.
+func Create(dstPath string, fn func(io.Writer) error) error {
+ // open temp output file
+ f, err := os.CreateTemp(filepath.Dir(dstPath), "")
+ if err != nil {
+ return err
+ }
+
+ // build absolute path to temporary file
+ // tmpPath := filepath.Join(filepath.Dir(dstPath), f.Name())
+ defer os.Remove(f.Name())
+
+ // invoke callback
+ if err = fn(f); err != nil {
+ return err
+ }
+
+ // close temp file
+ if err = f.Close(); err != nil {
+ return err
+ }
+
+ // rename to destination file, return result
+ return os.Rename(f.Name(), dstPath)
+}
diff --git a/atomictemp/atomictemp_test.go b/atomictemp/atomictemp_test.go
new file mode 100644
index 0000000..401c730
--- /dev/null
+++ b/atomictemp/atomictemp_test.go
@@ -0,0 +1,68 @@
+package atomictemp
+
+import (
+ "errors"
+ "io"
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestCreate(t *testing.T) {
+ dir, err := os.MkdirTemp("", "")
+ if err != nil {
+ t.Error(err)
+ }
+
+ // test good file
+ goodTests := []string { "foo", "bar", "baz" }
+ for _, test := range(goodTests) {
+ t.Run(test, func(t *testing.T) {
+ // build destination path
+ path := filepath.Join(dir, test)
+
+ // create temp file
+ err := Create(path, func(f io.Writer) error {
+ _, err := f.Write([]byte(test))
+ return err
+ })
+
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ if got, err := os.ReadFile(path); err != nil {
+ t.Error(err)
+ } else if string(got) != test {
+ t.Errorf("got \"%s\", exp \"%s\"", string(got), test)
+ }
+ })
+ }
+
+ t.Run("badDir", func(t *testing.T) {
+ // build nonsense path to destination file
+ badPath := filepath.Join(dir, "does/not/exist")
+
+ err := Create(badPath, func(_ io.Writer) error {
+ return nil
+ })
+
+ if err == nil {
+ t.Errorf("got success, exp error")
+ }
+ })
+
+ t.Run("badFunc", func(t *testing.T) {
+ // build path
+ path := filepath.Join(dir, "badFunc")
+
+ err := Create(path, func(_ io.Writer) error {
+ return errors.New("ack!")
+ })
+
+ if err == nil {
+ t.Errorf("got success, exp error")
+ }
+ })
+}