diff options
Diffstat (limited to 'dbstore/tx.go')
-rw-r--r-- | dbstore/tx.go | 105 |
1 files changed, 105 insertions, 0 deletions
diff --git a/dbstore/tx.go b/dbstore/tx.go new file mode 100644 index 0000000..460bbe1 --- /dev/null +++ b/dbstore/tx.go @@ -0,0 +1,105 @@ +// database storage +package dbstore + +import ( + "context" + db_sql "database/sql" + "fmt" + _ "github.com/mattn/go-sqlite3" +) + +type Tx struct { + tx *db_sql.Tx // underlying transaction + sts map[string]*db_sql.Stmt // prepared statements + closed bool // are the statements closed? + done bool // is the transaction done? + err error // last error +} + +func newTx(ctx context.Context, db *db_sql.DB, queryIds []string) (Tx, error) { + var r Tx + + // begin context + if tx, err := db.BeginTx(ctx, nil); err != nil { + return r, err + } else { + r.tx = tx + } + + // build query map + queries, err := getQueries(queryIds) + if err != nil { + return r, err + } + + // build statements + sts := make(map[string]*db_sql.Stmt) + for id, sql := range(queries) { + if st, err := r.tx.PrepareContext(ctx, sql); err != nil { + return r, err + } else { + sts[id] = st + } + } + r.sts = sts + + // return success + return r, nil +} + +// Finalize statements. +func (tx Tx) Close() { + if !tx.closed { + for id, st := range(tx.sts) { + // close statement + st.Close() + + // delete key + delete(tx.sts, id) + } + + // mark transaction as closed + tx.closed = true + } +} + +// Finalize statements, commit transaction. +func (tx Tx) Commit() error { + // close statements + // FIXME: this isn't really necessary, rollback and commit will take + // care of it according to the database/sql docs + tx.Close() + + if !tx.done { + tx.err = tx.tx.Commit() + tx.done = true + } + + // return last error + return tx.err +} + +// Finalize statements, rollback transaction. +func (tx Tx) Rollback() error { + // close statements + // FIXME: this isn't really necessary, rollback and commit will take + // care of it according to the database/sql docs + tx.Close() + + if !tx.done { + tx.err = tx.tx.Rollback() + tx.done = true + } + + // return last error + return tx.err +} + +// execute given prepared statement +func (tx Tx) Exec(ctx context.Context, id string, args... interface{}) (db_sql.Result, error) { + if st, ok := tx.sts[id]; !ok { + return nil, fmt.Errorf("unknown statement: %s", id) + } else { + return st.ExecContext(ctx, args...) + } +} |