aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/guff/database.cr178
1 files changed, 178 insertions, 0 deletions
diff --git a/src/guff/database.cr b/src/guff/database.cr
new file mode 100644
index 0000000..6e5906e
--- /dev/null
+++ b/src/guff/database.cr
@@ -0,0 +1,178 @@
+require "sqlite3"
+
+module Guff
+ class Database < ::SQLite3::Database
+ SQL = {
+ table_exists: "
+ SELECT name
+
+ FROM sqlite_master
+
+ WHERE type = 'table'
+ AND name = ?
+ ",
+
+ pragma_foreign_keys: "
+ PRAGMA foreign_keys = true
+ ",
+ }
+
+ def initialize(path)
+ super(path)
+ @savepoint_id = 0_i64
+ query(SQL[:pragma_foreign_keys])
+ end
+
+ def initialize(path, &block : Database ->)
+ super(path)
+ @savepoint_id = 0_i64
+ query(SQL[:pragma_foreign_keys])
+
+ begin
+ block.call(self)
+ ensure
+ close unless closed?
+ end
+ end
+
+ def table_exists?(table : String) : Bool
+ one(SQL[:table_exists], [table]) == table
+ end
+
+ def one(
+ sql : String,
+ args : Array(String) | Hash(String, String) | Nil = nil
+ )
+ r = nil
+
+ run(sql, args) do |rs|
+ if rs.next
+ r = rs[0].to_s
+ end
+ end
+
+ # return result
+ r
+ end
+
+ def row(
+ sql : String,
+ args : Array(String) | Hash(String, String) | Nil = nil
+ )
+ r = nil
+
+ # exec query
+ run(sql, args) do |rs|
+ r = to_row(rs) if rs.next
+ end
+
+ # return result
+ r
+ end
+
+ def all(
+ sql : String,
+ args : Array(String) | Hash(String, String) | Nil = nil,
+ &block : Hash(String, ::SQLite3::Value) -> \
+ )
+ # build statement
+ run(sql, args) do |rs|
+ # walk results
+ while rs.next
+ # build row and pass it to callback
+ block.call(to_row(rs))
+ end
+ end
+
+ nil
+ end
+
+ def query(
+ sql : String
+ )
+ run(sql, nil) do |rs|
+ # make sure query executes
+ rs.next
+ nil
+ end
+ end
+
+ def query(
+ sql : String,
+ args : Array(String) | Hash(String, String) | Nil = nil,
+ )
+ run(sql, args) do |rs|
+ # make sure query executes
+ rs.next
+ nil
+ end
+ end
+
+ #
+ # NOTE: if you pass a block, be sure to call rs.next at least once,
+ # or the query will _not_ execute!!!
+ #
+ def query(
+ sql : String,
+ args : Array(String) | Hash(String, String) | Nil = nil,
+ &block : ::SQLite3::ResultSet -> \
+ )
+ run(sql, args, &block)
+ end
+
+ def transaction(&block)
+ # get next savepoint id
+ id = next_savepoint_id
+
+ begin
+ query("SAVEPOINT %s" % [id])
+ block.call
+ rescue e
+ query("ROLLBACK TO %s" % [id])
+ raise e
+ ensure
+ query("RELEASE %s" % [id])
+ end
+ end
+
+ private def next_savepoint_id : String
+ # increment savepoint counter
+ @savepoint_id += 1
+
+ # return savepoint id
+ "guff_savepoint_%s" % [@savepoint_id]
+ end
+
+ private def run(
+ sql : String,
+ args : Hash(String, String),
+ &block : ::SQLite3::ResultSet -> \
+ )
+ run(sql, [args], &block)
+ end
+
+ private def run(
+ sql : String,
+ args : Array(String | Hash(String, String))? = nil,
+ &block : ::SQLite3::ResultSet -> \
+ )
+ # build statement
+ puts "sql = %s" % [sql]
+ st = prepare(sql)
+
+ # exec and close statement
+ if args && args.size > 0
+ st.execute(args, &block)
+ else
+ st.execute(&block)
+ end
+
+ # return result
+ nil
+ end
+
+ private def to_row(rs)
+ Hash(String, ::SQLite3::Value).zip(rs.columns, rs.to_a)
+ end
+ end
+end