aboutsummaryrefslogtreecommitdiff
path: root/backend/src
diff options
context:
space:
mode:
Diffstat (limited to 'backend/src')
-rw-r--r--backend/src/db/users.rs32
-rw-r--r--backend/src/main.rs2
-rw-r--r--backend/src/routes/users.rs40
3 files changed, 47 insertions, 27 deletions
diff --git a/backend/src/db/users.rs b/backend/src/db/users.rs
index c06e5b3..29cee88 100644
--- a/backend/src/db/users.rs
+++ b/backend/src/db/users.rs
@@ -58,24 +58,26 @@ pub fn create_user(credentials: &Credentials, conn: &PgConnection) -> QueryResul
}
pub fn authenticate_user(credentials: &Credentials, db_conn: &PgConnection) -> Option<User> {
- let user = users::table
+ users::table
.filter(users::username.eq(&credentials.username))
.first::<User>(db_conn)
- .unwrap();
+ .optional()
+ .unwrap()
+ .and_then(|user| {
+ let password_matches = argon2::verify_raw(
+ credentials.password.as_bytes(),
+ &user.password_salt,
+ &user.password_hash,
+ &argon2_config(),
+ )
+ .unwrap();
- let password_matches = argon2::verify_raw(
- credentials.password.as_bytes(),
- &user.password_salt,
- &user.password_hash,
- &argon2_config(),
- )
- .unwrap();
-
- if password_matches {
- return Some(user);
- } else {
- return None;
- }
+ if password_matches {
+ return Some(user);
+ } else {
+ return None;
+ }
+ })
}
#[test]
diff --git a/backend/src/main.rs b/backend/src/main.rs
index 65be48d..3c0efa8 100644
--- a/backend/src/main.rs
+++ b/backend/src/main.rs
@@ -3,6 +3,6 @@ extern crate rocket;
extern crate mozaic4_backend;
#[launch]
-fn launch() -> _ {
+fn launch() -> rocket::Rocket<rocket::Build> {
mozaic4_backend::rocket()
}
diff --git a/backend/src/routes/users.rs b/backend/src/routes/users.rs
index 274b712..72a857f 100644
--- a/backend/src/routes/users.rs
+++ b/backend/src/routes/users.rs
@@ -7,7 +7,8 @@ use rocket::serde::json::Json;
use serde::{Deserialize, Serialize};
use rocket::http::Status;
-use rocket::request::{self, FromRequest, Outcome, Request};
+use rocket::request::{FromRequest, Outcome, Request};
+use rocket::response::status;
#[derive(Debug)]
pub enum AuthTokenError {
@@ -23,17 +24,25 @@ impl<'r> FromRequest<'r> for User {
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let keys: Vec<_> = request.headers().get("Authorization").collect();
- let token = match keys.len() {
+ let auth_header = match keys.len() {
0 => return Outcome::Failure((Status::BadRequest, AuthTokenError::Missing)),
- 1 => keys[0].to_string(),
+ 1 => keys[0],
_ => return Outcome::Failure((Status::BadRequest, AuthTokenError::BadCount)),
};
+
+ let token = match auth_header.strip_prefix("Bearer ") {
+ Some(token) => token.to_string(),
+ None => return Outcome::Failure((Status::BadRequest, AuthTokenError::Invalid)),
+ };
+
let db = request.guard::<DbConn>().await.unwrap();
- let (_session, user) = db
+ let res = db
.run(move |conn| sessions::find_user_by_session(&token, conn))
- .await
- .unwrap();
- Outcome::Success(user)
+ .await;
+ match res {
+ Ok((_session, user)) => Outcome::Success(user),
+ Err(_) => Outcome::Failure((Status::Unauthorized, AuthTokenError::Invalid)),
+ }
}
}
@@ -79,7 +88,10 @@ pub struct LoginParams {
}
#[post("/login", data = "<params>")]
-pub async fn login(db_conn: DbConn, params: Json<LoginParams>) -> String {
+pub async fn login(
+ db_conn: DbConn,
+ params: Json<LoginParams>,
+) -> Result<String, status::Forbidden<&'static str>> {
db_conn
.run(move |conn| {
let credentials = Credentials {
@@ -87,9 +99,15 @@ pub async fn login(db_conn: DbConn, params: Json<LoginParams>) -> String {
password: &params.password,
};
// TODO: handle failures
- let user = users::authenticate_user(&credentials, conn).unwrap();
- let session = sessions::create_session(&user, conn);
- return session.token;
+ let authenticated = users::authenticate_user(&credentials, conn);
+
+ match authenticated {
+ None => Err(status::Forbidden(Some("invalid auth"))),
+ Some(user) => {
+ let session = sessions::create_session(&user, conn);
+ Ok(session.token)
+ }
+ }
})
.await
}