diff options
author | Ilion Beyst <ilion.beyst@gmail.com> | 2022-10-12 22:52:15 +0200 |
---|---|---|
committer | Ilion Beyst <ilion.beyst@gmail.com> | 2022-10-12 22:52:15 +0200 |
commit | ae57359353cf31ff374a8932999742920878bf00 (patch) | |
tree | 0db27d394a2a61a5cc94e73014c82954829c1338 /planetwars-server/src | |
parent | ed016773b112460ebbf0ff023b0915545229ed41 (diff) | |
download | planetwars.dev-ae57359353cf31ff374a8932999742920878bf00.tar.xz planetwars.dev-ae57359353cf31ff374a8932999742920878bf00.zip |
upgrade to diesel 2.0
Diffstat (limited to 'planetwars-server/src')
-rw-r--r-- | planetwars-server/src/cli.rs | 4 | ||||
-rw-r--r-- | planetwars-server/src/db/bots.rs | 30 | ||||
-rw-r--r-- | planetwars-server/src/db/maps.rs | 10 | ||||
-rw-r--r-- | planetwars-server/src/db/matches.rs | 96 | ||||
-rw-r--r-- | planetwars-server/src/db/ratings.rs | 6 | ||||
-rw-r--r-- | planetwars-server/src/db/sessions.rs | 6 | ||||
-rw-r--r-- | planetwars-server/src/db/users.rs | 12 | ||||
-rw-r--r-- | planetwars-server/src/db_types.rs | 2 | ||||
-rw-r--r-- | planetwars-server/src/lib.rs | 16 | ||||
-rw-r--r-- | planetwars-server/src/modules/bots.rs | 2 | ||||
-rw-r--r-- | planetwars-server/src/modules/client_api.rs | 6 | ||||
-rw-r--r-- | planetwars-server/src/modules/matches.rs | 11 | ||||
-rw-r--r-- | planetwars-server/src/modules/ranking.rs | 13 | ||||
-rw-r--r-- | planetwars-server/src/modules/registry.rs | 40 | ||||
-rw-r--r-- | planetwars-server/src/routes/bots.rs | 56 | ||||
-rw-r--r-- | planetwars-server/src/routes/demo.rs | 9 | ||||
-rw-r--r-- | planetwars-server/src/routes/maps.rs | 4 | ||||
-rw-r--r-- | planetwars-server/src/routes/matches.rs | 16 | ||||
-rw-r--r-- | planetwars-server/src/routes/users.rs | 20 | ||||
-rw-r--r-- | planetwars-server/src/schema.rs | 41 |
20 files changed, 213 insertions, 187 deletions
diff --git a/planetwars-server/src/cli.rs b/planetwars-server/src/cli.rs index f33506e..e1eeac3 100644 --- a/planetwars-server/src/cli.rs +++ b/planetwars-server/src/cli.rs @@ -38,12 +38,12 @@ impl SetPassword { let global_config = get_config().unwrap(); let pool = create_db_pool(&global_config).await; - let conn = pool.get().await.expect("could not get database connection"); + let mut conn = pool.get().await.expect("could not get database connection"); let credentials = db::users::Credentials { username: &self.username, password: &self.new_password, }; - db::users::set_user_password(credentials, &conn).expect("could not set password"); + db::users::set_user_password(credentials, &mut conn).expect("could not set password"); } } diff --git a/planetwars-server/src/db/bots.rs b/planetwars-server/src/db/bots.rs index a0a31b0..cf8bbb5 100644 --- a/planetwars-server/src/db/bots.rs +++ b/planetwars-server/src/db/bots.rs @@ -5,7 +5,7 @@ use crate::schema::{bot_versions, bots}; use chrono; #[derive(Insertable)] -#[table_name = "bots"] +#[diesel(table_name = bots)] pub struct NewBot<'a> { pub owner_id: Option<i32>, pub name: &'a str, @@ -19,29 +19,29 @@ pub struct Bot { pub active_version: Option<i32>, } -pub fn create_bot(new_bot: &NewBot, conn: &PgConnection) -> QueryResult<Bot> { +pub fn create_bot(new_bot: &NewBot, conn: &mut PgConnection) -> QueryResult<Bot> { diesel::insert_into(bots::table) .values(new_bot) .get_result(conn) } -pub fn find_bot(id: i32, conn: &PgConnection) -> QueryResult<Bot> { +pub fn find_bot(id: i32, conn: &mut PgConnection) -> QueryResult<Bot> { bots::table.find(id).first(conn) } -pub fn find_bots_by_owner(owner_id: i32, conn: &PgConnection) -> QueryResult<Vec<Bot>> { +pub fn find_bots_by_owner(owner_id: i32, conn: &mut PgConnection) -> QueryResult<Vec<Bot>> { bots::table .filter(bots::owner_id.eq(owner_id)) .get_results(conn) } -pub fn find_bot_by_name(name: &str, conn: &PgConnection) -> QueryResult<Bot> { +pub fn find_bot_by_name(name: &str, conn: &mut PgConnection) -> QueryResult<Bot> { bots::table.filter(bots::name.eq(name)).first(conn) } pub fn find_bot_with_version_by_name( bot_name: &str, - conn: &PgConnection, + conn: &mut PgConnection, ) -> QueryResult<(Bot, BotVersion)> { bots::table .inner_join(bot_versions::table.on(bots::active_version.eq(bot_versions::id.nullable()))) @@ -49,26 +49,28 @@ pub fn find_bot_with_version_by_name( .first(conn) } -pub fn all_active_bots_with_version(conn: &PgConnection) -> QueryResult<Vec<(Bot, BotVersion)>> { +pub fn all_active_bots_with_version( + conn: &mut PgConnection, +) -> QueryResult<Vec<(Bot, BotVersion)>> { bots::table .inner_join(bot_versions::table.on(bots::active_version.eq(bot_versions::id.nullable()))) .get_results(conn) } -pub fn find_all_bots(conn: &PgConnection) -> QueryResult<Vec<Bot>> { +pub fn find_all_bots(conn: &mut PgConnection) -> QueryResult<Vec<Bot>> { bots::table.get_results(conn) } /// Find all bots that have an associated active version. /// These are the bots that can be run. -pub fn find_active_bots(conn: &PgConnection) -> QueryResult<Vec<Bot>> { +pub fn find_active_bots(conn: &mut PgConnection) -> QueryResult<Vec<Bot>> { bots::table .filter(bots::active_version.is_not_null()) .get_results(conn) } #[derive(Insertable)] -#[table_name = "bot_versions"] +#[diesel(table_name = bot_versions)] pub struct NewBotVersion<'a> { pub bot_id: Option<i32>, pub code_bundle_path: Option<&'a str>, @@ -86,7 +88,7 @@ pub struct BotVersion { pub fn create_bot_version( new_bot_version: &NewBotVersion, - conn: &PgConnection, + conn: &mut PgConnection, ) -> QueryResult<BotVersion> { diesel::insert_into(bot_versions::table) .values(new_bot_version) @@ -96,7 +98,7 @@ pub fn create_bot_version( pub fn set_active_version( bot_id: i32, version_id: Option<i32>, - conn: &PgConnection, + conn: &mut PgConnection, ) -> QueryResult<()> { diesel::update(bots::table.filter(bots::id.eq(bot_id))) .set(bots::active_version.eq(version_id)) @@ -104,13 +106,13 @@ pub fn set_active_version( Ok(()) } -pub fn find_bot_version(version_id: i32, conn: &PgConnection) -> QueryResult<BotVersion> { +pub fn find_bot_version(version_id: i32, conn: &mut PgConnection) -> QueryResult<BotVersion> { bot_versions::table .filter(bot_versions::id.eq(version_id)) .first(conn) } -pub fn find_bot_versions(bot_id: i32, conn: &PgConnection) -> QueryResult<Vec<BotVersion>> { +pub fn find_bot_versions(bot_id: i32, conn: &mut PgConnection) -> QueryResult<Vec<BotVersion>> { bot_versions::table .filter(bot_versions::bot_id.eq(bot_id)) .get_results(conn) diff --git a/planetwars-server/src/db/maps.rs b/planetwars-server/src/db/maps.rs index dffe4fd..8972461 100644 --- a/planetwars-server/src/db/maps.rs +++ b/planetwars-server/src/db/maps.rs @@ -3,7 +3,7 @@ use diesel::prelude::*; use crate::schema::maps; #[derive(Insertable)] -#[table_name = "maps"] +#[diesel(table_name = maps)] pub struct NewMap<'a> { pub name: &'a str, pub file_path: &'a str, @@ -16,20 +16,20 @@ pub struct Map { pub file_path: String, } -pub fn create_map(new_map: NewMap, conn: &PgConnection) -> QueryResult<Map> { +pub fn create_map(new_map: NewMap, conn: &mut PgConnection) -> QueryResult<Map> { diesel::insert_into(maps::table) .values(new_map) .get_result(conn) } -pub fn find_map(id: i32, conn: &PgConnection) -> QueryResult<Map> { +pub fn find_map(id: i32, conn: &mut PgConnection) -> QueryResult<Map> { maps::table.find(id).get_result(conn) } -pub fn find_map_by_name(name: &str, conn: &PgConnection) -> QueryResult<Map> { +pub fn find_map_by_name(name: &str, conn: &mut PgConnection) -> QueryResult<Map> { maps::table.filter(maps::name.eq(name)).first(conn) } -pub fn list_maps(conn: &PgConnection) -> QueryResult<Vec<Map>> { +pub fn list_maps(conn: &mut PgConnection) -> QueryResult<Vec<Map>> { maps::table.get_results(conn) } diff --git a/planetwars-server/src/db/matches.rs b/planetwars-server/src/db/matches.rs index 1dded43..bfec892 100644 --- a/planetwars-server/src/db/matches.rs +++ b/planetwars-server/src/db/matches.rs @@ -1,9 +1,6 @@ pub use crate::db_types::MatchState; use chrono::NaiveDateTime; use diesel::associations::BelongsTo; -use diesel::pg::Pg; -use diesel::query_builder::BoxedSelectStatement; -use diesel::query_source::{AppearsInFromClause, Once}; use diesel::sql_types::*; use diesel::{ BelongingToDsl, ExpressionMethods, JoinOnDsl, NullableExpressionMethods, QueryDsl, RunQueryDsl, @@ -18,7 +15,7 @@ use super::bots::{Bot, BotVersion}; use super::maps::Map; #[derive(Insertable)] -#[table_name = "matches"] +#[diesel(table_name = matches)] pub struct NewMatch<'a> { pub state: MatchState, pub log_path: &'a str, @@ -27,7 +24,7 @@ pub struct NewMatch<'a> { } #[derive(Insertable)] -#[table_name = "match_players"] +#[diesel(table_name = match_players)] pub struct NewMatchPlayer { /// id of the match this player is in pub match_id: i32, @@ -38,7 +35,7 @@ pub struct NewMatchPlayer { } #[derive(Queryable, Identifiable)] -#[table_name = "matches"] +#[diesel(table_name = matches)] pub struct MatchBase { pub id: i32, pub state: MatchState, @@ -50,8 +47,8 @@ pub struct MatchBase { } #[derive(Queryable, Identifiable, Associations, Clone)] -#[primary_key(match_id, player_id)] -#[belongs_to(MatchBase, foreign_key = "match_id")] +#[diesel(primary_key(match_id, player_id))] +#[diesel(belongs_to(MatchBase, foreign_key = match_id))] pub struct MatchPlayer { pub match_id: i32, pub player_id: i32, @@ -65,9 +62,9 @@ pub struct MatchPlayerData { pub fn create_match( new_match_base: &NewMatch, new_match_players: &[MatchPlayerData], - conn: &PgConnection, + conn: &mut PgConnection, ) -> QueryResult<MatchData> { - conn.transaction(|| { + conn.transaction(|conn| { let match_base = diesel::insert_into(matches::table) .values(new_match_base) .get_result::<MatchBase>(conn)?; @@ -101,7 +98,7 @@ pub struct MatchData { /// Add player information to MatchBase instances fn fetch_full_match_data( matches: Vec<MatchBase>, - conn: &PgConnection, + conn: &mut PgConnection, ) -> QueryResult<Vec<FullMatchData>> { let map_ids: HashSet<i32> = matches.iter().filter_map(|m| m.map_id).collect(); @@ -140,8 +137,8 @@ fn fetch_full_match_data( } // TODO: this method should disappear -pub fn list_matches(amount: i64, conn: &PgConnection) -> QueryResult<Vec<FullMatchData>> { - conn.transaction(|| { +pub fn list_matches(amount: i64, conn: &mut PgConnection) -> QueryResult<Vec<FullMatchData>> { + conn.transaction(|conn| { let matches = matches::table .filter(matches::state.eq(MatchState::Finished)) .order_by(matches::created_at.desc()) @@ -164,17 +161,32 @@ pub fn list_public_matches( amount: i64, before: Option<NaiveDateTime>, after: Option<NaiveDateTime>, - conn: &PgConnection, + conn: &mut PgConnection, ) -> QueryResult<Vec<FullMatchData>> { - conn.transaction(|| { + conn.transaction(|conn| { // TODO: how can this common logic be abstracted? - let query = matches::table + let mut query = matches::table .filter(matches::state.eq(MatchState::Finished)) .filter(matches::is_public.eq(true)) .into_boxed(); - let matches = - select_matches_page(query, amount, before, after).get_results::<MatchBase>(conn)?; + // TODO: how to remove this duplication? + query = match (before, after) { + (None, None) => query.order_by(matches::created_at.desc()), + (Some(before), None) => query + .filter(matches::created_at.lt(before)) + .order_by(matches::created_at.desc()), + (None, Some(after)) => query + .filter(matches::created_at.gt(after)) + .order_by(matches::created_at.asc()), + (Some(before), Some(after)) => query + .filter(matches::created_at.lt(before)) + .filter(matches::created_at.gt(after)) + .order_by(matches::created_at.desc()), + }; + query = query.limit(amount); + + let matches = query.get_results::<MatchBase>(conn)?; fetch_full_match_data(matches, conn) }) } @@ -185,7 +197,7 @@ pub fn list_bot_matches( amount: i64, before: Option<NaiveDateTime>, after: Option<NaiveDateTime>, - conn: &PgConnection, + conn: &mut PgConnection, ) -> QueryResult<Vec<FullMatchData>> { let mut query = matches::table .filter(matches::state.eq(MatchState::Finished)) @@ -211,22 +223,8 @@ pub fn list_bot_matches( }; } - let matches = - select_matches_page(query, amount, before, after).get_results::<MatchBase>(conn)?; - fetch_full_match_data(matches, conn) -} - -fn select_matches_page<QS>( - query: BoxedSelectStatement<'static, matches::SqlType, QS, Pg>, - amount: i64, - before: Option<NaiveDateTime>, - after: Option<NaiveDateTime>, -) -> BoxedSelectStatement<'static, matches::SqlType, QS, Pg> -where - QS: AppearsInFromClause<matches::table, Count = Once>, -{ - // TODO: this is not nice. Replace this with proper cursor logic. - match (before, after) { + // TODO: how to remove this duplication? + query = match (before, after) { (None, None) => query.order_by(matches::created_at.desc()), (Some(before), None) => query .filter(matches::created_at.lt(before)) @@ -238,8 +236,11 @@ where .filter(matches::created_at.lt(before)) .filter(matches::created_at.gt(after)) .order_by(matches::created_at.desc()), - } - .limit(amount) + }; + query = query.limit(amount); + + let matches = query.get_results::<MatchBase>(conn)?; + fetch_full_match_data(matches, conn) } // TODO: maybe unify this with matchdata? @@ -270,8 +271,8 @@ impl BelongsTo<MatchBase> for FullMatchPlayerData { } } -pub fn find_match(id: i32, conn: &PgConnection) -> QueryResult<FullMatchData> { - conn.transaction(|| { +pub fn find_match(id: i32, conn: &mut PgConnection) -> QueryResult<FullMatchData> { + conn.transaction(|conn| { let match_base = matches::table.find(id).get_result::<MatchBase>(conn)?; let map = match match_base.map_id { @@ -298,7 +299,7 @@ pub fn find_match(id: i32, conn: &PgConnection) -> QueryResult<FullMatchData> { }) } -pub fn find_match_base(id: i32, conn: &PgConnection) -> QueryResult<MatchBase> { +pub fn find_match_base(id: i32, conn: &mut PgConnection) -> QueryResult<MatchBase> { matches::table.find(id).get_result::<MatchBase>(conn) } @@ -306,7 +307,7 @@ pub enum MatchResult { Finished { winner: Option<i32> }, } -pub fn save_match_result(id: i32, result: MatchResult, conn: &PgConnection) -> QueryResult<()> { +pub fn save_match_result(id: i32, result: MatchResult, conn: &mut PgConnection) -> QueryResult<()> { let MatchResult::Finished { winner } = result; diesel::update(matches::table.find(id)) @@ -320,17 +321,20 @@ pub fn save_match_result(id: i32, result: MatchResult, conn: &PgConnection) -> Q #[derive(QueryableByName)] pub struct BotStatsRecord { - #[sql_type = "Text"] + #[diesel(sql_type = Text)] pub opponent: String, - #[sql_type = "Text"] + #[diesel(sql_type = Text)] pub map: String, - #[sql_type = "Nullable<Bool>"] + #[diesel(sql_type = Nullable<Bool>)] pub win: Option<bool>, - #[sql_type = "Int8"] + #[diesel(sql_type = Int8)] pub count: i64, } -pub fn fetch_bot_stats(bot_name: &str, db_conn: &PgConnection) -> QueryResult<Vec<BotStatsRecord>> { +pub fn fetch_bot_stats( + bot_name: &str, + db_conn: &mut PgConnection, +) -> QueryResult<Vec<BotStatsRecord>> { diesel::sql_query( " SELECT opponent, map, win, COUNT(*) as count diff --git a/planetwars-server/src/db/ratings.rs b/planetwars-server/src/db/ratings.rs index 8262fed..0a510d4 100644 --- a/planetwars-server/src/db/ratings.rs +++ b/planetwars-server/src/db/ratings.rs @@ -10,7 +10,7 @@ pub struct Rating { pub rating: f64, } -pub fn get_rating(bot_id: i32, db_conn: &PgConnection) -> QueryResult<Option<f64>> { +pub fn get_rating(bot_id: i32, db_conn: &mut PgConnection) -> QueryResult<Option<f64>> { ratings::table .filter(ratings::bot_id.eq(bot_id)) .select(ratings::rating) @@ -18,7 +18,7 @@ pub fn get_rating(bot_id: i32, db_conn: &PgConnection) -> QueryResult<Option<f64 .optional() } -pub fn set_rating(bot_id: i32, rating: f64, db_conn: &PgConnection) -> QueryResult<usize> { +pub fn set_rating(bot_id: i32, rating: f64, db_conn: &mut PgConnection) -> QueryResult<usize> { diesel::insert_into(ratings::table) .values(Rating { bot_id, rating }) .on_conflict(ratings::bot_id) @@ -40,7 +40,7 @@ pub struct RankedBot { pub rating: f64, } -pub fn get_bot_ranking(db_conn: &PgConnection) -> QueryResult<Vec<RankedBot>> { +pub fn get_bot_ranking(db_conn: &mut PgConnection) -> QueryResult<Vec<RankedBot>> { bots::table .left_join(users::table) .inner_join(ratings::table) diff --git a/planetwars-server/src/db/sessions.rs b/planetwars-server/src/db/sessions.rs index a91d954..f8108cc 100644 --- a/planetwars-server/src/db/sessions.rs +++ b/planetwars-server/src/db/sessions.rs @@ -6,7 +6,7 @@ use diesel::{insert_into, prelude::*, Insertable, RunQueryDsl}; use rand::{self, Rng}; #[derive(Insertable)] -#[table_name = "sessions"] +#[diesel(table_name = sessions)] struct NewSession { token: String, user_id: i32, @@ -19,7 +19,7 @@ pub struct Session { pub token: String, } -pub fn create_session(user: &User, conn: &PgConnection) -> Session { +pub fn create_session(user: &User, conn: &mut PgConnection) -> Session { let new_session = NewSession { token: gen_session_token(), user_id: user.id, @@ -31,7 +31,7 @@ pub fn create_session(user: &User, conn: &PgConnection) -> Session { .unwrap() } -pub fn find_user_by_session(token: &str, conn: &PgConnection) -> QueryResult<(Session, User)> { +pub fn find_user_by_session(token: &str, conn: &mut PgConnection) -> QueryResult<(Session, User)> { sessions::table .inner_join(users::table) .filter(sessions::token.eq(&token)) diff --git a/planetwars-server/src/db/users.rs b/planetwars-server/src/db/users.rs index 9676dae..60cc20a 100644 --- a/planetwars-server/src/db/users.rs +++ b/planetwars-server/src/db/users.rs @@ -11,7 +11,7 @@ pub struct Credentials<'a> { } #[derive(Insertable)] -#[table_name = "users"] +#[diesel(table_name = users)] pub struct NewUser<'a> { pub username: &'a str, pub password_hash: &'a [u8], @@ -50,7 +50,7 @@ pub fn hash_password(password: &str) -> (Vec<u8>, [u8; 32]) { (hash, salt) } -pub fn create_user(credentials: &Credentials, conn: &PgConnection) -> QueryResult<User> { +pub fn create_user(credentials: &Credentials, conn: &mut PgConnection) -> QueryResult<User> { let (hash, salt) = hash_password(&credentials.password); let new_user = NewUser { @@ -63,19 +63,19 @@ pub fn create_user(credentials: &Credentials, conn: &PgConnection) -> QueryResul .get_result::<User>(conn) } -pub fn find_user(user_id: i32, db_conn: &PgConnection) -> QueryResult<User> { +pub fn find_user(user_id: i32, db_conn: &mut PgConnection) -> QueryResult<User> { users::table .filter(users::id.eq(user_id)) .first::<User>(db_conn) } -pub fn find_user_by_name(username: &str, db_conn: &PgConnection) -> QueryResult<User> { +pub fn find_user_by_name(username: &str, db_conn: &mut PgConnection) -> QueryResult<User> { users::table .filter(users::username.eq(username)) .first::<User>(db_conn) } -pub fn set_user_password(credentials: Credentials, db_conn: &PgConnection) -> QueryResult<()> { +pub fn set_user_password(credentials: Credentials, db_conn: &mut PgConnection) -> QueryResult<()> { let (hash, salt) = hash_password(&credentials.password); let n_changes = diesel::update(users::table.filter(users::username.eq(&credentials.username))) @@ -91,7 +91,7 @@ pub fn set_user_password(credentials: Credentials, db_conn: &PgConnection) -> Qu } } -pub fn authenticate_user(credentials: &Credentials, db_conn: &PgConnection) -> Option<User> { +pub fn authenticate_user(credentials: &Credentials, db_conn: &mut PgConnection) -> Option<User> { find_user_by_name(credentials.username, db_conn) .optional() .unwrap() diff --git a/planetwars-server/src/db_types.rs b/planetwars-server/src/db_types.rs index 1b99e49..29b1e9b 100644 --- a/planetwars-server/src/db_types.rs +++ b/planetwars-server/src/db_types.rs @@ -2,7 +2,7 @@ use diesel_derive_enum::DbEnum; use serde::{Deserialize, Serialize}; #[derive(DbEnum, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] -#[DieselType = "Match_state"] +#[DieselTypePath = "crate::schema::sql_types::MatchState"] pub enum MatchState { Playing, diff --git a/planetwars-server/src/lib.rs b/planetwars-server/src/lib.rs index 1696f1a..316458c 100644 --- a/planetwars-server/src/lib.rs +++ b/planetwars-server/src/lib.rs @@ -8,7 +8,7 @@ pub mod routes; pub mod schema; pub mod util; -use std::ops::Deref; +use std::ops::{Deref, DerefMut}; use std::path::PathBuf; use std::sync::Arc; use std::{fs, net::SocketAddr}; @@ -70,9 +70,9 @@ pub struct GlobalConfig { const SIMPLEBOT_PATH: &str = "../simplebot/simplebot.py"; pub async fn seed_simplebot(config: &GlobalConfig, pool: &ConnectionPool) { - let conn = pool.get().await.expect("could not get database connection"); + let mut conn = pool.get().await.expect("could not get database connection"); // This transaction is expected to fail when simplebot already exists. - let _res = conn.transaction::<(), diesel::result::Error, _>(|| { + let _res = conn.transaction::<(), diesel::result::Error, _>(|conn| { use db::bots::NewBot; let new_bot = NewBot { @@ -80,12 +80,12 @@ pub async fn seed_simplebot(config: &GlobalConfig, pool: &ConnectionPool) { owner_id: None, }; - let simplebot = db::bots::create_bot(&new_bot, &conn)?; + let simplebot = db::bots::create_bot(&new_bot, conn)?; let simplebot_code = std::fs::read_to_string(SIMPLEBOT_PATH).expect("could not read simplebot code"); - modules::bots::save_code_string(&simplebot_code, Some(simplebot.id), &conn, config)?; + modules::bots::save_code_string(&simplebot_code, Some(simplebot.id), conn, config)?; println!("initialized simplebot"); @@ -209,6 +209,12 @@ impl Deref for DatabaseConnection { } } +impl DerefMut for DatabaseConnection { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + #[async_trait] impl<B> FromRequest<B> for DatabaseConnection where diff --git a/planetwars-server/src/modules/bots.rs b/planetwars-server/src/modules/bots.rs index 6a2883c..6893581 100644 --- a/planetwars-server/src/modules/bots.rs +++ b/planetwars-server/src/modules/bots.rs @@ -9,7 +9,7 @@ use crate::{db, util::gen_alphanumeric, GlobalConfig}; pub fn save_code_string( bot_code: &str, bot_id: Option<i32>, - conn: &PgConnection, + conn: &mut PgConnection, config: &GlobalConfig, ) -> QueryResult<db::bots::BotVersion> { let bundle_name = gen_alphanumeric(16); diff --git a/planetwars-server/src/modules/client_api.rs b/planetwars-server/src/modules/client_api.rs index 6e5d05a..9c0bbe7 100644 --- a/planetwars-server/src/modules/client_api.rs +++ b/planetwars-server/src/modules/client_api.rs @@ -149,19 +149,19 @@ impl pb::client_api_service_server::ClientApiService for ClientApiServer { req: Request<pb::CreateMatchRequest>, ) -> Result<Response<pb::CreateMatchResponse>, Status> { // TODO: unify with matchrunner module - let conn = self.conn_pool.get().await.unwrap(); + let mut conn = self.conn_pool.get().await.unwrap(); let match_request = req.get_ref(); let (opponent_bot, opponent_bot_version) = - db::bots::find_bot_with_version_by_name(&match_request.opponent_name, &conn) + db::bots::find_bot_with_version_by_name(&match_request.opponent_name, &mut conn) .map_err(|_| Status::not_found("opponent not found"))?; let map_name = match match_request.map_name.as_str() { "" => "hex", name => name, }; - let map = db::maps::find_map_by_name(map_name, &conn) + let map = db::maps::find_map_by_name(map_name, &mut conn) .map_err(|_| Status::not_found("map not found"))?; let player_key = gen_alphanumeric(32); diff --git a/planetwars-server/src/modules/matches.rs b/planetwars-server/src/modules/matches.rs index ecc7976..71e8a98 100644 --- a/planetwars-server/src/modules/matches.rs +++ b/planetwars-server/src/modules/matches.rs @@ -80,8 +80,8 @@ impl RunMatch { let match_data = { // TODO: it would be nice to get an already-open connection here when possible. // Maybe we need an additional abstraction, bundling a connection and connection pool? - let db_conn = conn_pool.get().await.expect("could not get a connection"); - self.store_in_database(&db_conn)? + let mut db_conn = conn_pool.get().await.expect("could not get a connection"); + self.store_in_database(&mut db_conn)? }; let runner_config = self.into_runner_config(); @@ -90,7 +90,7 @@ impl RunMatch { Ok((match_data, handle)) } - fn store_in_database(&self, db_conn: &PgConnection) -> QueryResult<MatchData> { + fn store_in_database(&self, db_conn: &mut PgConnection) -> QueryResult<MatchData> { let new_match_data = db::matches::NewMatch { state: db::matches::MatchState::Playing, log_path: &self.log_file_name, @@ -167,7 +167,7 @@ async fn run_match_task( let outcome = runner::run_match(match_config).await; // update match state in database - let conn = connection_pool + let mut conn = connection_pool .get() .await .expect("could not get database connection"); @@ -176,7 +176,8 @@ async fn run_match_task( winner: outcome.winner.map(|w| (w - 1) as i32), // player numbers in matchrunner start at 1 }; - db::matches::save_match_result(match_id, result, &conn).expect("could not save match result"); + db::matches::save_match_result(match_id, result, &mut conn) + .expect("could not save match result"); outcome } diff --git a/planetwars-server/src/modules/ranking.rs b/planetwars-server/src/modules/ranking.rs index 90c4a56..92f0f8a 100644 --- a/planetwars-server/src/modules/ranking.rs +++ b/planetwars-server/src/modules/ranking.rs @@ -20,13 +20,14 @@ pub async fn run_ranker(config: Arc<GlobalConfig>, db_pool: DbPool) { // TODO: make this configurable // play at most one match every n seconds let mut interval = tokio::time::interval(Duration::from_secs(RANKER_INTERVAL)); - let db_conn = db_pool + let mut db_conn = db_pool .get() .await .expect("could not get database connection"); loop { interval.tick().await; - let bots = db::bots::all_active_bots_with_version(&db_conn).expect("could not load bots"); + let bots = + db::bots::all_active_bots_with_version(&mut db_conn).expect("could not load bots"); if bots.len() < 2 { // not enough bots to play a match continue; @@ -37,14 +38,14 @@ pub async fn run_ranker(config: Arc<GlobalConfig>, db_pool: DbPool) { .cloned() .collect(); - let maps = db::maps::list_maps(&db_conn).expect("could not load map"); + let maps = db::maps::list_maps(&mut db_conn).expect("could not load map"); let map = match maps.choose(&mut rand::thread_rng()).cloned() { None => continue, // no maps available Some(map) => map, }; play_ranking_match(config.clone(), map, selected_bots, db_pool.clone()).await; - recalculate_ratings(&db_conn).expect("could not recalculate ratings"); + recalculate_ratings(&mut db_conn).expect("could not recalculate ratings"); } } @@ -71,7 +72,7 @@ async fn play_ranking_match( let _outcome = handle.await; } -fn recalculate_ratings(db_conn: &PgConnection) -> QueryResult<()> { +fn recalculate_ratings(db_conn: &mut PgConnection) -> QueryResult<()> { let start = Instant::now(); let match_stats = fetch_match_stats(db_conn)?; let ratings = estimate_ratings_from_stats(match_stats); @@ -91,7 +92,7 @@ struct MatchStats { num_matches: usize, } -fn fetch_match_stats(db_conn: &PgConnection) -> QueryResult<HashMap<(i32, i32), MatchStats>> { +fn fetch_match_stats(db_conn: &mut PgConnection) -> QueryResult<HashMap<(i32, i32), MatchStats>> { let matches = db::matches::list_matches(RANKER_NUM_MATCHES, db_conn)?; let mut match_stats = HashMap::<(i32, i32), MatchStats>::new(); diff --git a/planetwars-server/src/modules/registry.rs b/planetwars-server/src/modules/registry.rs index 4a79d59..5e1e05b 100644 --- a/planetwars-server/src/modules/registry.rs +++ b/planetwars-server/src/modules/registry.rs @@ -112,8 +112,8 @@ where Err(RegistryAuthError::InvalidCredentials) } } else { - let db_conn = DatabaseConnection::from_request(req).await.unwrap(); - let user = authenticate_user(&credentials, &db_conn) + let mut db_conn = DatabaseConnection::from_request(req).await.unwrap(); + let user = authenticate_user(&credentials, &mut db_conn) .ok_or(RegistryAuthError::InvalidCredentials)?; Ok(RegistryAuth::User(user)) @@ -159,12 +159,12 @@ pub struct RegistryError { } async fn check_blob_exists( - db_conn: DatabaseConnection, + mut db_conn: DatabaseConnection, auth: RegistryAuth, Path((repository_name, raw_digest)): Path<(String, String)>, Extension(config): Extension<Arc<GlobalConfig>>, ) -> Result<impl IntoResponse, StatusCode> { - check_access(&repository_name, &auth, &db_conn)?; + check_access(&repository_name, &auth, &mut db_conn)?; let digest = raw_digest.strip_prefix("sha256:").unwrap(); let blob_path = PathBuf::from(&config.registry_directory) @@ -179,12 +179,12 @@ async fn check_blob_exists( } async fn get_blob( - db_conn: DatabaseConnection, + mut db_conn: DatabaseConnection, auth: RegistryAuth, Path((repository_name, raw_digest)): Path<(String, String)>, Extension(config): Extension<Arc<GlobalConfig>>, ) -> Result<impl IntoResponse, StatusCode> { - check_access(&repository_name, &auth, &db_conn)?; + check_access(&repository_name, &auth, &mut db_conn)?; let digest = raw_digest.strip_prefix("sha256:").unwrap(); let blob_path = PathBuf::from(&config.registry_directory) @@ -200,12 +200,12 @@ async fn get_blob( } async fn create_upload( - db_conn: DatabaseConnection, + mut db_conn: DatabaseConnection, auth: RegistryAuth, Path(repository_name): Path<String>, Extension(config): Extension<Arc<GlobalConfig>>, ) -> Result<impl IntoResponse, StatusCode> { - check_access(&repository_name, &auth, &db_conn)?; + check_access(&repository_name, &auth, &mut db_conn)?; let uuid = gen_alphanumeric(16); tokio::fs::File::create( @@ -229,13 +229,13 @@ async fn create_upload( } async fn patch_upload( - db_conn: DatabaseConnection, + mut db_conn: DatabaseConnection, auth: RegistryAuth, Path((repository_name, uuid)): Path<(String, String)>, mut stream: BodyStream, Extension(config): Extension<Arc<GlobalConfig>>, ) -> Result<impl IntoResponse, StatusCode> { - check_access(&repository_name, &auth, &db_conn)?; + check_access(&repository_name, &auth, &mut db_conn)?; // TODO: support content range header in request let upload_path = PathBuf::from(&config.registry_directory) @@ -275,14 +275,14 @@ struct UploadParams { } async fn put_upload( - db_conn: DatabaseConnection, + mut db_conn: DatabaseConnection, auth: RegistryAuth, Path((repository_name, uuid)): Path<(String, String)>, Query(params): Query<UploadParams>, mut stream: BodyStream, Extension(config): Extension<Arc<GlobalConfig>>, ) -> Result<impl IntoResponse, StatusCode> { - check_access(&repository_name, &auth, &db_conn)?; + check_access(&repository_name, &auth, &mut db_conn)?; let upload_path = PathBuf::from(&config.registry_directory) .join("uploads") @@ -332,12 +332,12 @@ async fn put_upload( } async fn get_manifest( - db_conn: DatabaseConnection, + mut db_conn: DatabaseConnection, auth: RegistryAuth, Path((repository_name, reference)): Path<(String, String)>, Extension(config): Extension<Arc<GlobalConfig>>, ) -> Result<impl IntoResponse, StatusCode> { - check_access(&repository_name, &auth, &db_conn)?; + check_access(&repository_name, &auth, &mut db_conn)?; let manifest_path = PathBuf::from(&config.registry_directory) .join("manifests") @@ -357,13 +357,13 @@ async fn get_manifest( } async fn put_manifest( - db_conn: DatabaseConnection, + mut db_conn: DatabaseConnection, auth: RegistryAuth, Path((repository_name, reference)): Path<(String, String)>, mut stream: BodyStream, Extension(config): Extension<Arc<GlobalConfig>>, ) -> Result<impl IntoResponse, StatusCode> { - let bot = check_access(&repository_name, &auth, &db_conn)?; + let bot = check_access(&repository_name, &auth, &mut db_conn)?; let repository_dir = PathBuf::from(&config.registry_directory) .join("manifests") @@ -399,9 +399,9 @@ async fn put_manifest( code_bundle_path: None, container_digest: Some(&content_digest), }; - let version = - db::bots::create_bot_version(&new_version, &db_conn).expect("could not save bot version"); - db::bots::set_active_version(bot.id, Some(version.id), &db_conn) + let version = db::bots::create_bot_version(&new_version, &mut db_conn) + .expect("could not save bot version"); + db::bots::set_active_version(bot.id, Some(version.id), &mut db_conn) .expect("could not update bot version"); Ok(Response::builder() @@ -421,7 +421,7 @@ async fn put_manifest( fn check_access( repository_name: &str, auth: &RegistryAuth, - db_conn: &DatabaseConnection, + db_conn: &mut DatabaseConnection, ) -> Result<db::bots::Bot, StatusCode> { use diesel::OptionalExtension; diff --git a/planetwars-server/src/routes/bots.rs b/planetwars-server/src/routes/bots.rs index f8087fd..f0ff9bf 100644 --- a/planetwars-server/src/routes/bots.rs +++ b/planetwars-server/src/routes/bots.rs @@ -100,10 +100,10 @@ pub fn validate_bot_name(bot_name: &str) -> Result<(), SaveBotError> { pub async fn save_bot( Json(params): Json<SaveBotParams>, user: User, - conn: DatabaseConnection, + mut conn: DatabaseConnection, Extension(config): Extension<Arc<GlobalConfig>>, ) -> Result<Json<Bot>, SaveBotError> { - let res = bots::find_bot_by_name(¶ms.bot_name, &conn) + let res = bots::find_bot_by_name(¶ms.bot_name, &mut conn) .optional() .expect("could not run query"); @@ -122,10 +122,10 @@ pub async fn save_bot( name: ¶ms.bot_name, }; - bots::create_bot(&new_bot, &conn).expect("could not create bot") + bots::create_bot(&new_bot, &mut conn).expect("could not create bot") } }; - let _code_bundle = save_code_string(¶ms.code, Some(bot.id), &conn, &config) + let _code_bundle = save_code_string(¶ms.code, Some(bot.id), &mut conn, &config) .expect("failed to save code bundle"); Ok(Json(bot)) } @@ -137,12 +137,12 @@ pub struct BotParams { // TODO: can we unify this with save_bot? pub async fn create_bot( - conn: DatabaseConnection, + mut conn: DatabaseConnection, user: User, params: Json<BotParams>, ) -> Result<(StatusCode, Json<Bot>), SaveBotError> { validate_bot_name(¶ms.name)?; - let existing_bot = bots::find_bot_by_name(¶ms.name, &conn) + let existing_bot = bots::find_bot_by_name(¶ms.name, &mut conn) .optional() .expect("could not run query"); if existing_bot.is_some() { @@ -152,26 +152,27 @@ pub async fn create_bot( owner_id: Some(user.id), name: ¶ms.name, }; - let bot = bots::create_bot(&bot_params, &conn).unwrap(); + let bot = bots::create_bot(&bot_params, &mut conn).unwrap(); Ok((StatusCode::CREATED, Json(bot))) } // TODO: handle errors pub async fn get_bot( - conn: DatabaseConnection, + mut conn: DatabaseConnection, Path(bot_name): Path<String>, ) -> Result<Json<JsonValue>, StatusCode> { - let bot = db::bots::find_bot_by_name(&bot_name, &conn).map_err(|_| StatusCode::NOT_FOUND)?; + let bot = + db::bots::find_bot_by_name(&bot_name, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?; let owner: Option<UserData> = match bot.owner_id { Some(user_id) => { - let user = db::users::find_user(user_id, &conn) + let user = db::users::find_user(user_id, &mut conn) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; Some(user.into()) } None => None, }; - let versions = - bots::find_bot_versions(bot.id, &conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let versions = bots::find_bot_versions(bot.id, &mut conn) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; Ok(Json(json!({ "bot": bot, "owner": owner, @@ -180,32 +181,32 @@ pub async fn get_bot( } pub async fn get_user_bots( - conn: DatabaseConnection, + mut conn: DatabaseConnection, Path(user_name): Path<String>, ) -> Result<Json<Vec<Bot>>, StatusCode> { let user = - db::users::find_user_by_name(&user_name, &conn).map_err(|_| StatusCode::NOT_FOUND)?; - db::bots::find_bots_by_owner(user.id, &conn) + db::users::find_user_by_name(&user_name, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?; + db::bots::find_bots_by_owner(user.id, &mut conn) .map(Json) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } /// List all active bots -pub async fn list_bots(conn: DatabaseConnection) -> Result<Json<Vec<Bot>>, StatusCode> { - bots::find_active_bots(&conn) +pub async fn list_bots(mut conn: DatabaseConnection) -> Result<Json<Vec<Bot>>, StatusCode> { + bots::find_active_bots(&mut conn) .map(Json) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } -pub async fn get_ranking(conn: DatabaseConnection) -> Result<Json<Vec<RankedBot>>, StatusCode> { - ratings::get_bot_ranking(&conn) +pub async fn get_ranking(mut conn: DatabaseConnection) -> Result<Json<Vec<RankedBot>>, StatusCode> { + ratings::get_bot_ranking(&mut conn) .map(Json) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } // TODO: currently this only implements the happy flow pub async fn upload_code_multipart( - conn: DatabaseConnection, + mut conn: DatabaseConnection, user: User, Path(bot_name): Path<String>, mut multipart: Multipart, @@ -213,7 +214,7 @@ pub async fn upload_code_multipart( ) -> Result<Json<BotVersion>, StatusCode> { let bots_dir = PathBuf::from(&config.bots_directory); - let bot = bots::find_bot_by_name(&bot_name, &conn).map_err(|_| StatusCode::NOT_FOUND)?; + let bot = bots::find_bot_by_name(&bot_name, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?; if Some(user.id) != bot.owner_id { return Err(StatusCode::FORBIDDEN); @@ -246,21 +247,22 @@ pub async fn upload_code_multipart( container_digest: None, }; let code_bundle = - bots::create_bot_version(&bot_version, &conn).expect("Failed to create code bundle"); + bots::create_bot_version(&bot_version, &mut conn).expect("Failed to create code bundle"); Ok(Json(code_bundle)) } pub async fn get_code( - conn: DatabaseConnection, + mut conn: DatabaseConnection, user: User, Path(bundle_id): Path<i32>, Extension(config): Extension<Arc<GlobalConfig>>, ) -> Result<Vec<u8>, StatusCode> { let version = - db::bots::find_bot_version(bundle_id, &conn).map_err(|_| StatusCode::NOT_FOUND)?; + db::bots::find_bot_version(bundle_id, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?; let bot_id = version.bot_id.ok_or(StatusCode::FORBIDDEN)?; - let bot = db::bots::find_bot(bot_id, &conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let bot = + db::bots::find_bot(bot_id, &mut conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; if bot.owner_id != Some(user.id) { return Err(StatusCode::FORBIDDEN); @@ -297,10 +299,10 @@ impl MatchupStats { type BotStats = HashMap<String, HashMap<String, MatchupStats>>; pub async fn get_bot_stats( - conn: DatabaseConnection, + mut conn: DatabaseConnection, Path(bot_name): Path<String>, ) -> Result<Json<BotStats>, StatusCode> { - let stats_records = db::matches::fetch_bot_stats(&bot_name, &conn) + let stats_records = db::matches::fetch_bot_stats(&bot_name, &mut conn) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; let mut bot_stats: BotStats = HashMap::new(); for record in stats_records { diff --git a/planetwars-server/src/routes/demo.rs b/planetwars-server/src/routes/demo.rs index 1ec8825..cd490ef 100644 --- a/planetwars-server/src/routes/demo.rs +++ b/planetwars-server/src/routes/demo.rs @@ -35,7 +35,7 @@ pub async fn submit_bot( Extension(pool): Extension<ConnectionPool>, Extension(config): Extension<Arc<GlobalConfig>>, ) -> Result<Json<SubmitBotResponse>, StatusCode> { - let conn = pool.get().await.expect("could not get database connection"); + let mut conn = pool.get().await.expect("could not get database connection"); let opponent_name = params .opponent_name @@ -46,12 +46,13 @@ pub async fn submit_bot( .unwrap_or_else(|| DEFAULT_MAP_NAME.to_string()); let (opponent_bot, opponent_bot_version) = - db::bots::find_bot_with_version_by_name(&opponent_name, &conn) + db::bots::find_bot_with_version_by_name(&opponent_name, &mut conn) .map_err(|_| StatusCode::BAD_REQUEST)?; - let map = db::maps::find_map_by_name(&map_name, &conn).map_err(|_| StatusCode::BAD_REQUEST)?; + let map = + db::maps::find_map_by_name(&map_name, &mut conn).map_err(|_| StatusCode::BAD_REQUEST)?; - let player_bot_version = save_code_string(¶ms.code, None, &conn, &config) + let player_bot_version = save_code_string(¶ms.code, None, &mut conn, &config) // TODO: can we recover from this? .expect("could not save bot code"); diff --git a/planetwars-server/src/routes/maps.rs b/planetwars-server/src/routes/maps.rs index 689b11e..188089f 100644 --- a/planetwars-server/src/routes/maps.rs +++ b/planetwars-server/src/routes/maps.rs @@ -8,8 +8,8 @@ pub struct ApiMap { pub name: String, } -pub async fn list_maps(conn: DatabaseConnection) -> Result<Json<Vec<ApiMap>>, StatusCode> { - let maps = db::maps::list_maps(&conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; +pub async fn list_maps(mut conn: DatabaseConnection) -> Result<Json<Vec<ApiMap>>, StatusCode> { + let maps = db::maps::list_maps(&mut conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; let api_maps = maps .into_iter() diff --git a/planetwars-server/src/routes/matches.rs b/planetwars-server/src/routes/matches.rs index 1d7403c..3ad10cf 100644 --- a/planetwars-server/src/routes/matches.rs +++ b/planetwars-server/src/routes/matches.rs @@ -56,7 +56,7 @@ pub struct ListMatchesResponse { pub async fn list_recent_matches( Query(params): Query<ListRecentMatchesParams>, - conn: DatabaseConnection, + mut conn: DatabaseConnection, ) -> Result<Json<ListMatchesResponse>, StatusCode> { let requested_count = std::cmp::min( params.count.unwrap_or(DEFAULT_NUM_RETURNED_MATCHES), @@ -68,7 +68,7 @@ pub async fn list_recent_matches( let matches_result = match params.bot { Some(bot_name) => { - let bot = db::bots::find_bot_by_name(&bot_name, &conn) + let bot = db::bots::find_bot_by_name(&bot_name, &mut conn) .map_err(|_| StatusCode::BAD_REQUEST)?; matches::list_bot_matches( bot.id, @@ -76,10 +76,10 @@ pub async fn list_recent_matches( count, params.before, params.after, - &conn, + &mut conn, ) } - None => matches::list_public_matches(count, params.before, params.after, &conn), + None => matches::list_public_matches(count, params.before, params.after, &mut conn), }; let mut matches = matches_result.map_err(|_| StatusCode::BAD_REQUEST)?; @@ -119,9 +119,9 @@ pub fn match_data_to_api(data: matches::FullMatchData) -> ApiMatch { pub async fn get_match_data( Path(match_id): Path<i32>, - conn: DatabaseConnection, + mut conn: DatabaseConnection, ) -> Result<Json<ApiMatch>, StatusCode> { - let match_data = matches::find_match(match_id, &conn) + let match_data = matches::find_match(match_id, &mut conn) .map_err(|_| StatusCode::NOT_FOUND) .map(match_data_to_api)?; Ok(Json(match_data)) @@ -129,11 +129,11 @@ pub async fn get_match_data( pub async fn get_match_log( Path(match_id): Path<i32>, - conn: DatabaseConnection, + mut conn: DatabaseConnection, Extension(config): Extension<Arc<GlobalConfig>>, ) -> Result<Vec<u8>, StatusCode> { let match_base = - matches::find_match_base(match_id, &conn).map_err(|_| StatusCode::NOT_FOUND)?; + matches::find_match_base(match_id, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?; let log_path = PathBuf::from(&config.match_logs_directory).join(&match_base.log_path); let log_contents = std::fs::read(log_path).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; Ok(log_contents) diff --git a/planetwars-server/src/routes/users.rs b/planetwars-server/src/routes/users.rs index 264e5b9..d072d0a 100644 --- a/planetwars-server/src/routes/users.rs +++ b/planetwars-server/src/routes/users.rs @@ -23,13 +23,13 @@ where type Rejection = (StatusCode, String); async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { - let conn = DatabaseConnection::from_request(req).await?; + let mut conn = DatabaseConnection::from_request(req).await?; let TypedHeader(Authorization(bearer)) = AuthorizationHeader::from_request(req) .await .map_err(|_| (StatusCode::UNAUTHORIZED, "".to_string()))?; - let (_session, user) = sessions::find_user_by_session(bearer.token(), &conn) + let (_session, user) = sessions::find_user_by_session(bearer.token(), &mut conn) .map_err(|_| (StatusCode::UNAUTHORIZED, "".to_string()))?; Ok(user) @@ -66,7 +66,7 @@ pub enum RegistrationError { } impl RegistrationParams { - fn validate(&self, conn: &DatabaseConnection) -> Result<(), RegistrationError> { + fn validate(&self, conn: &mut DatabaseConnection) -> Result<(), RegistrationError> { let mut errors = Vec::new(); // TODO: do we want to support cased usernames? @@ -95,7 +95,7 @@ impl RegistrationParams { errors.push("that username is not allowed".to_string()); } - if users::find_user_by_name(&self.username, &conn).is_ok() { + if users::find_user_by_name(&self.username, conn).is_ok() { errors.push("username is already taken".to_string()); } @@ -137,16 +137,16 @@ impl IntoResponse for RegistrationError { } pub async fn register( - conn: DatabaseConnection, + mut conn: DatabaseConnection, params: Json<RegistrationParams>, ) -> Result<Json<UserData>, RegistrationError> { - params.validate(&conn)?; + params.validate(&mut conn)?; let credentials = Credentials { username: ¶ms.username, password: ¶ms.password, }; - let user = users::create_user(&credentials, &conn)?; + let user = users::create_user(&credentials, &mut conn)?; Ok(Json(user.into())) } @@ -156,18 +156,18 @@ pub struct LoginParams { pub password: String, } -pub async fn login(conn: DatabaseConnection, params: Json<LoginParams>) -> Response { +pub async fn login(mut conn: DatabaseConnection, params: Json<LoginParams>) -> Response { let credentials = Credentials { username: ¶ms.username, password: ¶ms.password, }; // TODO: handle failures - let authenticated = users::authenticate_user(&credentials, &conn); + let authenticated = users::authenticate_user(&credentials, &mut conn); match authenticated { None => StatusCode::FORBIDDEN.into_response(), Some(user) => { - let session = sessions::create_session(&user, &conn); + let session = sessions::create_session(&user, &mut conn); let user_data: UserData = user.into(); let headers = [("Token", &session.token)]; diff --git a/planetwars-server/src/schema.rs b/planetwars-server/src/schema.rs index adc6555..27ebebe 100644 --- a/planetwars-server/src/schema.rs +++ b/planetwars-server/src/schema.rs @@ -1,7 +1,15 @@ // This file is autogenerated by diesel #![allow(unused_imports)] -table! { +// @generated automatically by Diesel CLI. + +pub mod sql_types { + #[derive(diesel::sql_types::SqlType)] + #[diesel(postgres_type(name = "match_state"))] + pub struct MatchState; +} + +diesel::table! { use diesel::sql_types::*; use crate::db_types::*; @@ -14,7 +22,7 @@ table! { } } -table! { +diesel::table! { use diesel::sql_types::*; use crate::db_types::*; @@ -26,7 +34,7 @@ table! { } } -table! { +diesel::table! { use diesel::sql_types::*; use crate::db_types::*; @@ -37,7 +45,7 @@ table! { } } -table! { +diesel::table! { use diesel::sql_types::*; use crate::db_types::*; @@ -48,13 +56,14 @@ table! { } } -table! { +diesel::table! { use diesel::sql_types::*; use crate::db_types::*; + use super::sql_types::MatchState; matches (id) { id -> Int4, - state -> Match_state, + state -> MatchState, log_path -> Text, created_at -> Timestamp, winner -> Nullable<Int4>, @@ -63,7 +72,7 @@ table! { } } -table! { +diesel::table! { use diesel::sql_types::*; use crate::db_types::*; @@ -73,7 +82,7 @@ table! { } } -table! { +diesel::table! { use diesel::sql_types::*; use crate::db_types::*; @@ -84,7 +93,7 @@ table! { } } -table! { +diesel::table! { use diesel::sql_types::*; use crate::db_types::*; @@ -96,14 +105,14 @@ table! { } } -joinable!(bots -> users (owner_id)); -joinable!(match_players -> bot_versions (bot_version_id)); -joinable!(match_players -> matches (match_id)); -joinable!(matches -> maps (map_id)); -joinable!(ratings -> bots (bot_id)); -joinable!(sessions -> users (user_id)); +diesel::joinable!(bots -> users (owner_id)); +diesel::joinable!(match_players -> bot_versions (bot_version_id)); +diesel::joinable!(match_players -> matches (match_id)); +diesel::joinable!(matches -> maps (map_id)); +diesel::joinable!(ratings -> bots (bot_id)); +diesel::joinable!(sessions -> users (user_id)); -allow_tables_to_appear_in_same_query!( +diesel::allow_tables_to_appear_in_same_query!( bot_versions, bots, maps, |