summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/database/database.rs7
-rw-r--r--src/database/sql/initialize.sql3
-rw-r--r--src/database/sql/insert_user.sql2
-rw-r--r--src/database/user.rs6
-rw-r--r--src/main.rs27
-rw-r--r--src/routes/user.rs4
-rw-r--r--src/utils.rs23
7 files changed, 60 insertions, 12 deletions
diff --git a/src/database/database.rs b/src/database/database.rs
index ca0e018..7444779 100644
--- a/src/database/database.rs
+++ b/src/database/database.rs
@@ -33,7 +33,7 @@ impl Database {
Ok(Database { pool })
}
- pub fn insert_user(&self, email: &str, username: &str, password_hash: &str) -> Result<User, DatabaseError> {
+ pub fn insert_user(&self, email: &str, username: &str, password_hash: &str, is_admin: bool) -> Result<User, DatabaseError> {
static QUERY: &str = include_str!("sql/insert_user.sql");
let conn = self.pool
.get()
@@ -42,12 +42,13 @@ impl Database {
.map_err(|e| DatabaseError::Query(e.to_string()))?;
Ok(statement
- .query_one((email, username, password_hash), |row| {
+ .query_one((email, username, password_hash, is_admin), |row| {
Ok(User::new(
row.get("id")?,
email.to_owned(),
username.to_owned(),
password_hash.to_owned(),
+ is_admin
))
})
.map_err(|e| DatabaseError::Query(e.to_string()))?
@@ -69,6 +70,7 @@ impl Database {
row.get("email")?,
row.get("username")?,
row.get("password_hash")?,
+ row.get("is_admin")?,
))
})
.optional()
@@ -91,6 +93,7 @@ impl Database {
row.get("email")?,
row.get("username")?,
row.get("password_hash")?,
+ row.get("is_admin")?
))
})
.optional()
diff --git a/src/database/sql/initialize.sql b/src/database/sql/initialize.sql
index 1c2a40f..72b7468 100644
--- a/src/database/sql/initialize.sql
+++ b/src/database/sql/initialize.sql
@@ -8,7 +8,8 @@ CREATE TABLE IF NOT EXISTS user (
id INTEGER PRIMARY KEY,
email TEXT UNIQUE NOT NULL,
username TEXT UNIQUE NOT NULL,
- password_hash TEXT NOT NULL
+ password_hash TEXT NOT NULL,
+ is_admin INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS submission (
diff --git a/src/database/sql/insert_user.sql b/src/database/sql/insert_user.sql
index 6b67678..6017cf5 100644
--- a/src/database/sql/insert_user.sql
+++ b/src/database/sql/insert_user.sql
@@ -1 +1 @@
-INSERT INTO user (email, username, password_hash) VALUES (?1, ?2, ?3) RETURNING id;
+INSERT INTO user (email, username, password_hash, is_admin) VALUES (?1, ?2, ?3, ?4) RETURNING id;
diff --git a/src/database/user.rs b/src/database/user.rs
index c580226..c6674c7 100644
--- a/src/database/user.rs
+++ b/src/database/user.rs
@@ -6,15 +6,17 @@ pub struct User {
email: String,
username: String,
password_hash: String,
+ is_admin: bool
}
impl User {
- pub(super) fn new(id: i64, email: String, username: String, password_hash: String) -> Self {
- Self { id, email, username, password_hash }
+ pub(super) fn new(id: i64, email: String, username: String, password_hash: String, is_admin: bool) -> Self {
+ Self { id, email, username, password_hash, is_admin }
}
pub fn id(&self) -> i64 { self.id }
pub fn email(&self) -> &str { &self.email }
pub fn username(&self) -> &str { &self.username }
+ pub fn is_admin(&self) -> bool { self.is_admin }
pub fn password_hash(&self) -> &str { &self.password_hash }
}
diff --git a/src/main.rs b/src/main.rs
index 36ef319..b1979b1 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -2,6 +2,7 @@ use std::env;
mod database;
mod routes;
+mod utils;
use axum::{
routing::{get, post},
@@ -13,7 +14,7 @@ use routes::user::{create_user, me};
use routes::auth::{login, logout};
use tower_http::services::ServeDir;
-use crate::database::Database;
+use crate::{database::Database, utils::register_admin};
#[derive(Clone)]
struct AppState {
@@ -24,13 +25,31 @@ struct AppState {
#[tokio::main]
async fn main() {
+ let database = Database::new_in_memory().unwrap();
+ database.initialize().unwrap();
+
+ // TODO: generically load environment variables, this is tedious
let Ok(secret) = env::var("JWT_SECRET") else {
eprintln!("missing environment variable JWT_SECRET");
return;
};
-
- let database = Database::new_in_memory().unwrap();
- database.initialize().unwrap();
+ let Ok(admin_email) = env::var("ADMIN_EMAIL") else {
+ eprintln!("missing environment variable ADMIN_EMAIL");
+ return;
+ };
+ let Ok(admin_username) = env::var("ADMIN_USERNAME") else {
+ eprintln!("missing environment variable ADMIN_USERNAME");
+ return;
+ };
+ let Ok(admin_password) = env::var("ADMIN_PASSWORD") else {
+ eprintln!("missing environment variable ADMIN_PASSWORD");
+ return;
+ };
+
+ let Ok(()) = register_admin(&database, &admin_email, &admin_username, &admin_password) else {
+ eprintln!("failed to register admin user");
+ return;
+ };
let state = AppState {
secret: secret,
diff --git a/src/routes/user.rs b/src/routes/user.rs
index 157cf09..178a272 100644
--- a/src/routes/user.rs
+++ b/src/routes/user.rs
@@ -26,7 +26,7 @@ pub async fn create_user(
};
match state.database.fetch_user_by_username(&request.username) {
- Err(_) => return Err(RouteError::Internal("database action failed B".into())),
+ Err(_) => return Err(RouteError::Internal("database action failed".into())),
Ok(Some(_)) => return Err(RouteError::UserCreateUsernameExists(request.username)),
Ok(None) => {},
};
@@ -35,7 +35,7 @@ pub async fn create_user(
return Err(RouteError::Internal("failed to hash password".into()))
};
- let Ok(user) = state.database.insert_user(&request.email, &request.username, &password_hash) else {
+ let Ok(user) = state.database.insert_user(&request.email, &request.username, &password_hash, false) else {
return Err(RouteError::Internal("failed to create user".into()));
};
diff --git a/src/utils.rs b/src/utils.rs
new file mode 100644
index 0000000..6a854d5
--- /dev/null
+++ b/src/utils.rs
@@ -0,0 +1,23 @@
+use crate::{database::Database, routes::auth::hash_password};
+
+pub fn register_admin(database: &Database, email: &str, username: &str, password: &str) -> Result<(), ()> {
+ match database.fetch_user_by_email(email) {
+ Err(_) | Ok(Some(_)) => return Err(()),
+ Ok(None) => {},
+ };
+
+ match database.fetch_user_by_username(username) {
+ Err(_) | Ok(Some(_)) => return Err(()),
+ Ok(None) => {},
+ };
+
+ let Ok(password_hash) = hash_password(password) else { return Err(()) };
+
+ match database.insert_user(email, username, &password_hash, true) {
+ Ok(_) => Ok(()),
+ Err(e) => {
+ eprintln!("{:?}", e);
+ return Err(());
+ }
+ }
+}