use std::path::Path; use r2d2::Pool; use r2d2_sqlite::SqliteConnectionManager; use r2d2_sqlite::rusqlite::OptionalExtension; use super::problem::Problem; use super::user::User; #[derive(Clone)] pub struct Database { pool: Pool } #[derive(Debug)] pub enum DatabaseError { Connection(String), Query(String), } impl Database { pub fn new(database_path: impl AsRef) -> Result { let manager = SqliteConnectionManager::file(database_path); let pool = Pool::new(manager) .map_err(|e| DatabaseError::Connection(e.to_string()))?; Ok(Database { pool }) } pub fn new_in_memory() -> Result { let manager = SqliteConnectionManager::memory(); let pool = Pool::new(manager) .map_err(|e| DatabaseError::Connection(e.to_string()))?; Ok(Database { pool }) } pub fn insert_user(&self, email: &str, username: &str, password_hash: &str) -> Result { static QUERY: &str = include_str!("sql/insert_user.sql"); let conn = self.pool .get() .map_err(|e| DatabaseError::Connection(e.to_string()))?; let mut statement = conn.prepare(QUERY) .map_err(|e| DatabaseError::Query(e.to_string()))?; Ok(statement .query_one((email, username, password_hash), |row| { Ok(User::new( row.get("id")?, email.to_owned(), username.to_owned(), password_hash.to_owned(), )) }) .map_err(|e| DatabaseError::Query(e.to_string()))? ) } pub fn fetch_user_by_email(&self, email: &str) -> Result, DatabaseError> { static QUERY: &str = include_str!("sql/fetch_user_by_email.sql"); let conn = self.pool .get() .map_err(|e| DatabaseError::Connection(e.to_string()))?; let mut statement = conn.prepare(QUERY) .map_err(|e| DatabaseError::Query(e.to_string()))?; Ok(statement .query_one([email], |row| { Ok(User::new( row.get("id")?, row.get("email")?, row.get("username")?, row.get("password_hash")?, )) }) .optional() .map_err(|e| DatabaseError::Query(e.to_string()))? ) } pub fn fetch_user_by_username(&self, username: &str) -> Result, DatabaseError> { static QUERY: &str = include_str!("sql/fetch_user_by_username.sql"); let conn = self.pool .get() .map_err(|e| DatabaseError::Connection(e.to_string()))?; let mut statement = conn.prepare(QUERY) .map_err(|e| DatabaseError::Query(e.to_string()))?; Ok(statement .query_one([username], |row| { Ok(User::new( row.get("id")?, row.get("email")?, row.get("username")?, row.get("password_hash")?, )) }) .optional() .map_err(|e| DatabaseError::Query(e.to_string()))? ) } pub fn insert_problem(&self, title: &str, description: &str) -> Result { static QUERY: &str = include_str!("sql/insert_problem.sql"); let conn = self.pool .get() .map_err(|e| DatabaseError::Connection(e.to_string()))?; let mut statement = conn.prepare(QUERY) .map_err(|e| DatabaseError::Query(e.to_string()))?; Ok(statement .query_one((title, description), |row| { Ok(Problem::new( row.get("id")?, title.to_owned(), description.to_owned(), )) }) .map_err(|e| DatabaseError::Query(e.to_string()))? ) } pub fn fetch_problems(&self) -> Result, DatabaseError> { static QUERY: &str = include_str!("sql/fetch_problems.sql"); let conn = self.pool .get() .map_err(|e| DatabaseError::Connection(e.to_string()))?; let mut statement = conn.prepare(QUERY) .map_err(|e| DatabaseError::Query(e.to_string()))?; Ok(statement .query_map([], |row| { Ok(Problem::new( row.get("id")?, row.get("title")?, row.get("description")?, )) }) .map_err(|e| DatabaseError::Query(e.to_string()))? .collect::, _>>() .map_err(|e| DatabaseError::Query(e.to_string()))? ) } pub fn initialize(&self) -> Result<(), DatabaseError> { static QUERY: &str = include_str!("sql/initialize.sql"); let conn = self.pool .get() .map_err(|e| DatabaseError::Connection(e.to_string()))?; conn.execute_batch(QUERY) .map_err(|e| DatabaseError::Query(e.to_string()))?; Ok(()) } } #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn test_problem_database() { let db = Database::new_in_memory().unwrap(); db.initialize().unwrap(); let title = "test problem 1"; let description = "description of test problem 1"; let problem = db.insert_problem(title, description).unwrap(); assert_eq!(problem.title(), title); assert_eq!(problem.description(), description); let problems = db.fetch_problems().unwrap(); assert_eq!(problems.len(), 1); assert_eq!(problems[0].title(), title); assert_eq!(problems[0].description(), description); assert_eq!(problems[0].id(), problem.id()); } }