diff options
| -rw-r--r-- | src/database/database.rs | 7 | ||||
| -rw-r--r-- | src/database/sql/initialize.sql | 3 | ||||
| -rw-r--r-- | src/database/sql/insert_user.sql | 2 | ||||
| -rw-r--r-- | src/database/user.rs | 6 | ||||
| -rw-r--r-- | src/main.rs | 27 | ||||
| -rw-r--r-- | src/routes/user.rs | 4 | ||||
| -rw-r--r-- | src/utils.rs | 23 |
7 files changed, 60 insertions, 12 deletions
diff --git a/src/database/database.rs b/src/database/database.rs index ca0e018..7444779 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -33,7 +33,7 @@ impl Database { Ok(Database { pool }) } - pub fn insert_user(&self, email: &str, username: &str, password_hash: &str) -> Result<User, DatabaseError> { + pub fn insert_user(&self, email: &str, username: &str, password_hash: &str, is_admin: bool) -> Result<User, DatabaseError> { static QUERY: &str = include_str!("sql/insert_user.sql"); let conn = self.pool .get() @@ -42,12 +42,13 @@ impl Database { .map_err(|e| DatabaseError::Query(e.to_string()))?; Ok(statement - .query_one((email, username, password_hash), |row| { + .query_one((email, username, password_hash, is_admin), |row| { Ok(User::new( row.get("id")?, email.to_owned(), username.to_owned(), password_hash.to_owned(), + is_admin )) }) .map_err(|e| DatabaseError::Query(e.to_string()))? @@ -69,6 +70,7 @@ impl Database { row.get("email")?, row.get("username")?, row.get("password_hash")?, + row.get("is_admin")?, )) }) .optional() @@ -91,6 +93,7 @@ impl Database { row.get("email")?, row.get("username")?, row.get("password_hash")?, + row.get("is_admin")? )) }) .optional() diff --git a/src/database/sql/initialize.sql b/src/database/sql/initialize.sql index 1c2a40f..72b7468 100644 --- a/src/database/sql/initialize.sql +++ b/src/database/sql/initialize.sql @@ -8,7 +8,8 @@ CREATE TABLE IF NOT EXISTS user ( id INTEGER PRIMARY KEY, email TEXT UNIQUE NOT NULL, username TEXT UNIQUE NOT NULL, - password_hash TEXT NOT NULL + password_hash TEXT NOT NULL, + is_admin INTEGER NOT NULL ); CREATE TABLE IF NOT EXISTS submission ( diff --git a/src/database/sql/insert_user.sql b/src/database/sql/insert_user.sql index 6b67678..6017cf5 100644 --- a/src/database/sql/insert_user.sql +++ b/src/database/sql/insert_user.sql @@ -1 +1 @@ -INSERT INTO user (email, username, password_hash) VALUES (?1, ?2, ?3) RETURNING id; +INSERT INTO user (email, username, password_hash, is_admin) VALUES (?1, ?2, ?3, ?4) RETURNING id; diff --git a/src/database/user.rs b/src/database/user.rs index c580226..c6674c7 100644 --- a/src/database/user.rs +++ b/src/database/user.rs @@ -6,15 +6,17 @@ pub struct User { email: String, username: String, password_hash: String, + is_admin: bool } impl User { - pub(super) fn new(id: i64, email: String, username: String, password_hash: String) -> Self { - Self { id, email, username, password_hash } + pub(super) fn new(id: i64, email: String, username: String, password_hash: String, is_admin: bool) -> Self { + Self { id, email, username, password_hash, is_admin } } pub fn id(&self) -> i64 { self.id } pub fn email(&self) -> &str { &self.email } pub fn username(&self) -> &str { &self.username } + pub fn is_admin(&self) -> bool { self.is_admin } pub fn password_hash(&self) -> &str { &self.password_hash } } diff --git a/src/main.rs b/src/main.rs index 36ef319..b1979b1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ use std::env; mod database; mod routes; +mod utils; use axum::{ routing::{get, post}, @@ -13,7 +14,7 @@ use routes::user::{create_user, me}; use routes::auth::{login, logout}; use tower_http::services::ServeDir; -use crate::database::Database; +use crate::{database::Database, utils::register_admin}; #[derive(Clone)] struct AppState { @@ -24,13 +25,31 @@ struct AppState { #[tokio::main] async fn main() { + let database = Database::new_in_memory().unwrap(); + database.initialize().unwrap(); + + // TODO: generically load environment variables, this is tedious let Ok(secret) = env::var("JWT_SECRET") else { eprintln!("missing environment variable JWT_SECRET"); return; }; - - let database = Database::new_in_memory().unwrap(); - database.initialize().unwrap(); + let Ok(admin_email) = env::var("ADMIN_EMAIL") else { + eprintln!("missing environment variable ADMIN_EMAIL"); + return; + }; + let Ok(admin_username) = env::var("ADMIN_USERNAME") else { + eprintln!("missing environment variable ADMIN_USERNAME"); + return; + }; + let Ok(admin_password) = env::var("ADMIN_PASSWORD") else { + eprintln!("missing environment variable ADMIN_PASSWORD"); + return; + }; + + let Ok(()) = register_admin(&database, &admin_email, &admin_username, &admin_password) else { + eprintln!("failed to register admin user"); + return; + }; let state = AppState { secret: secret, diff --git a/src/routes/user.rs b/src/routes/user.rs index 157cf09..178a272 100644 --- a/src/routes/user.rs +++ b/src/routes/user.rs @@ -26,7 +26,7 @@ pub async fn create_user( }; match state.database.fetch_user_by_username(&request.username) { - Err(_) => return Err(RouteError::Internal("database action failed B".into())), + Err(_) => return Err(RouteError::Internal("database action failed".into())), Ok(Some(_)) => return Err(RouteError::UserCreateUsernameExists(request.username)), Ok(None) => {}, }; @@ -35,7 +35,7 @@ pub async fn create_user( return Err(RouteError::Internal("failed to hash password".into())) }; - let Ok(user) = state.database.insert_user(&request.email, &request.username, &password_hash) else { + let Ok(user) = state.database.insert_user(&request.email, &request.username, &password_hash, false) else { return Err(RouteError::Internal("failed to create user".into())); }; diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..6a854d5 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,23 @@ +use crate::{database::Database, routes::auth::hash_password}; + +pub fn register_admin(database: &Database, email: &str, username: &str, password: &str) -> Result<(), ()> { + match database.fetch_user_by_email(email) { + Err(_) | Ok(Some(_)) => return Err(()), + Ok(None) => {}, + }; + + match database.fetch_user_by_username(username) { + Err(_) | Ok(Some(_)) => return Err(()), + Ok(None) => {}, + }; + + let Ok(password_hash) = hash_password(password) else { return Err(()) }; + + match database.insert_user(email, username, &password_hash, true) { + Ok(_) => Ok(()), + Err(e) => { + eprintln!("{:?}", e); + return Err(()); + } + } +} |
