diff options
-rw-r--r-- | src/guff/database.cr | 178 |
1 files changed, 178 insertions, 0 deletions
diff --git a/src/guff/database.cr b/src/guff/database.cr new file mode 100644 index 0000000..6e5906e --- /dev/null +++ b/src/guff/database.cr @@ -0,0 +1,178 @@ +require "sqlite3" + +module Guff + class Database < ::SQLite3::Database + SQL = { + table_exists: " + SELECT name + + FROM sqlite_master + + WHERE type = 'table' + AND name = ? + ", + + pragma_foreign_keys: " + PRAGMA foreign_keys = true + ", + } + + def initialize(path) + super(path) + @savepoint_id = 0_i64 + query(SQL[:pragma_foreign_keys]) + end + + def initialize(path, &block : Database ->) + super(path) + @savepoint_id = 0_i64 + query(SQL[:pragma_foreign_keys]) + + begin + block.call(self) + ensure + close unless closed? + end + end + + def table_exists?(table : String) : Bool + one(SQL[:table_exists], [table]) == table + end + + def one( + sql : String, + args : Array(String) | Hash(String, String) | Nil = nil + ) + r = nil + + run(sql, args) do |rs| + if rs.next + r = rs[0].to_s + end + end + + # return result + r + end + + def row( + sql : String, + args : Array(String) | Hash(String, String) | Nil = nil + ) + r = nil + + # exec query + run(sql, args) do |rs| + r = to_row(rs) if rs.next + end + + # return result + r + end + + def all( + sql : String, + args : Array(String) | Hash(String, String) | Nil = nil, + &block : Hash(String, ::SQLite3::Value) -> \ + ) + # build statement + run(sql, args) do |rs| + # walk results + while rs.next + # build row and pass it to callback + block.call(to_row(rs)) + end + end + + nil + end + + def query( + sql : String + ) + run(sql, nil) do |rs| + # make sure query executes + rs.next + nil + end + end + + def query( + sql : String, + args : Array(String) | Hash(String, String) | Nil = nil, + ) + run(sql, args) do |rs| + # make sure query executes + rs.next + nil + end + end + + # + # NOTE: if you pass a block, be sure to call rs.next at least once, + # or the query will _not_ execute!!! + # + def query( + sql : String, + args : Array(String) | Hash(String, String) | Nil = nil, + &block : ::SQLite3::ResultSet -> \ + ) + run(sql, args, &block) + end + + def transaction(&block) + # get next savepoint id + id = next_savepoint_id + + begin + query("SAVEPOINT %s" % [id]) + block.call + rescue e + query("ROLLBACK TO %s" % [id]) + raise e + ensure + query("RELEASE %s" % [id]) + end + end + + private def next_savepoint_id : String + # increment savepoint counter + @savepoint_id += 1 + + # return savepoint id + "guff_savepoint_%s" % [@savepoint_id] + end + + private def run( + sql : String, + args : Hash(String, String), + &block : ::SQLite3::ResultSet -> \ + ) + run(sql, [args], &block) + end + + private def run( + sql : String, + args : Array(String | Hash(String, String))? = nil, + &block : ::SQLite3::ResultSet -> \ + ) + # build statement + puts "sql = %s" % [sql] + st = prepare(sql) + + # exec and close statement + if args && args.size > 0 + st.execute(args, &block) + else + st.execute(&block) + end + + # return result + nil + end + + private def to_row(rs) + Hash(String, ::SQLite3::Value).zip(rs.columns, rs.to_a) + end + end +end |