aboutsummaryrefslogtreecommitdiff
path: root/planetwars-server
diff options
context:
space:
mode:
Diffstat (limited to 'planetwars-server')
-rw-r--r--planetwars-server/Cargo.toml10
-rw-r--r--planetwars-server/src/cli.rs4
-rw-r--r--planetwars-server/src/db/bots.rs30
-rw-r--r--planetwars-server/src/db/maps.rs10
-rw-r--r--planetwars-server/src/db/matches.rs96
-rw-r--r--planetwars-server/src/db/ratings.rs6
-rw-r--r--planetwars-server/src/db/sessions.rs6
-rw-r--r--planetwars-server/src/db/users.rs12
-rw-r--r--planetwars-server/src/db_types.rs2
-rw-r--r--planetwars-server/src/lib.rs16
-rw-r--r--planetwars-server/src/modules/bots.rs2
-rw-r--r--planetwars-server/src/modules/client_api.rs6
-rw-r--r--planetwars-server/src/modules/matches.rs11
-rw-r--r--planetwars-server/src/modules/ranking.rs13
-rw-r--r--planetwars-server/src/modules/registry.rs40
-rw-r--r--planetwars-server/src/routes/bots.rs56
-rw-r--r--planetwars-server/src/routes/demo.rs9
-rw-r--r--planetwars-server/src/routes/maps.rs4
-rw-r--r--planetwars-server/src/routes/matches.rs16
-rw-r--r--planetwars-server/src/routes/users.rs20
-rw-r--r--planetwars-server/src/schema.rs41
-rw-r--r--planetwars-server/tests/integration.rs16
22 files changed, 226 insertions, 200 deletions
diff --git a/planetwars-server/Cargo.toml b/planetwars-server/Cargo.toml
index f5641d1..183bb90 100644
--- a/planetwars-server/Cargo.toml
+++ b/planetwars-server/Cargo.toml
@@ -15,15 +15,15 @@ path = "src/cli.rs"
[dependencies]
futures = "0.3"
-tokio = { version = "1.15", features = ["full"] }
+tokio = { version = "1.21", features = ["full"] }
tokio-stream = "0.1.9"
hyper = "0.14"
tower-http = { version = "0.3.4", features = ["full"] }
axum = { version = "0.5", features = ["json", "headers", "multipart"] }
-diesel = { version = "1.4.4", features = ["postgres", "chrono"] }
-diesel-derive-enum = { version = "1.1", features = ["postgres"] }
-bb8 = "0.7"
-bb8-diesel = "0.2"
+diesel = { version = "2.0", features = ["postgres", "chrono"] }
+diesel-derive-enum = { version = "2.0.0-rc.0", features = ["postgres"] }
+bb8 = "0.8"
+bb8-diesel = { git = "https://github.com/overdrivenpotato/bb8-diesel.git" }
dotenv = "0.15.0"
rust-argon2 = "0.8"
rand = "0.8.4"
diff --git a/planetwars-server/src/cli.rs b/planetwars-server/src/cli.rs
index f33506e..e1eeac3 100644
--- a/planetwars-server/src/cli.rs
+++ b/planetwars-server/src/cli.rs
@@ -38,12 +38,12 @@ impl SetPassword {
let global_config = get_config().unwrap();
let pool = create_db_pool(&global_config).await;
- let conn = pool.get().await.expect("could not get database connection");
+ let mut conn = pool.get().await.expect("could not get database connection");
let credentials = db::users::Credentials {
username: &self.username,
password: &self.new_password,
};
- db::users::set_user_password(credentials, &conn).expect("could not set password");
+ db::users::set_user_password(credentials, &mut conn).expect("could not set password");
}
}
diff --git a/planetwars-server/src/db/bots.rs b/planetwars-server/src/db/bots.rs
index a0a31b0..cf8bbb5 100644
--- a/planetwars-server/src/db/bots.rs
+++ b/planetwars-server/src/db/bots.rs
@@ -5,7 +5,7 @@ use crate::schema::{bot_versions, bots};
use chrono;
#[derive(Insertable)]
-#[table_name = "bots"]
+#[diesel(table_name = bots)]
pub struct NewBot<'a> {
pub owner_id: Option<i32>,
pub name: &'a str,
@@ -19,29 +19,29 @@ pub struct Bot {
pub active_version: Option<i32>,
}
-pub fn create_bot(new_bot: &NewBot, conn: &PgConnection) -> QueryResult<Bot> {
+pub fn create_bot(new_bot: &NewBot, conn: &mut PgConnection) -> QueryResult<Bot> {
diesel::insert_into(bots::table)
.values(new_bot)
.get_result(conn)
}
-pub fn find_bot(id: i32, conn: &PgConnection) -> QueryResult<Bot> {
+pub fn find_bot(id: i32, conn: &mut PgConnection) -> QueryResult<Bot> {
bots::table.find(id).first(conn)
}
-pub fn find_bots_by_owner(owner_id: i32, conn: &PgConnection) -> QueryResult<Vec<Bot>> {
+pub fn find_bots_by_owner(owner_id: i32, conn: &mut PgConnection) -> QueryResult<Vec<Bot>> {
bots::table
.filter(bots::owner_id.eq(owner_id))
.get_results(conn)
}
-pub fn find_bot_by_name(name: &str, conn: &PgConnection) -> QueryResult<Bot> {
+pub fn find_bot_by_name(name: &str, conn: &mut PgConnection) -> QueryResult<Bot> {
bots::table.filter(bots::name.eq(name)).first(conn)
}
pub fn find_bot_with_version_by_name(
bot_name: &str,
- conn: &PgConnection,
+ conn: &mut PgConnection,
) -> QueryResult<(Bot, BotVersion)> {
bots::table
.inner_join(bot_versions::table.on(bots::active_version.eq(bot_versions::id.nullable())))
@@ -49,26 +49,28 @@ pub fn find_bot_with_version_by_name(
.first(conn)
}
-pub fn all_active_bots_with_version(conn: &PgConnection) -> QueryResult<Vec<(Bot, BotVersion)>> {
+pub fn all_active_bots_with_version(
+ conn: &mut PgConnection,
+) -> QueryResult<Vec<(Bot, BotVersion)>> {
bots::table
.inner_join(bot_versions::table.on(bots::active_version.eq(bot_versions::id.nullable())))
.get_results(conn)
}
-pub fn find_all_bots(conn: &PgConnection) -> QueryResult<Vec<Bot>> {
+pub fn find_all_bots(conn: &mut PgConnection) -> QueryResult<Vec<Bot>> {
bots::table.get_results(conn)
}
/// Find all bots that have an associated active version.
/// These are the bots that can be run.
-pub fn find_active_bots(conn: &PgConnection) -> QueryResult<Vec<Bot>> {
+pub fn find_active_bots(conn: &mut PgConnection) -> QueryResult<Vec<Bot>> {
bots::table
.filter(bots::active_version.is_not_null())
.get_results(conn)
}
#[derive(Insertable)]
-#[table_name = "bot_versions"]
+#[diesel(table_name = bot_versions)]
pub struct NewBotVersion<'a> {
pub bot_id: Option<i32>,
pub code_bundle_path: Option<&'a str>,
@@ -86,7 +88,7 @@ pub struct BotVersion {
pub fn create_bot_version(
new_bot_version: &NewBotVersion,
- conn: &PgConnection,
+ conn: &mut PgConnection,
) -> QueryResult<BotVersion> {
diesel::insert_into(bot_versions::table)
.values(new_bot_version)
@@ -96,7 +98,7 @@ pub fn create_bot_version(
pub fn set_active_version(
bot_id: i32,
version_id: Option<i32>,
- conn: &PgConnection,
+ conn: &mut PgConnection,
) -> QueryResult<()> {
diesel::update(bots::table.filter(bots::id.eq(bot_id)))
.set(bots::active_version.eq(version_id))
@@ -104,13 +106,13 @@ pub fn set_active_version(
Ok(())
}
-pub fn find_bot_version(version_id: i32, conn: &PgConnection) -> QueryResult<BotVersion> {
+pub fn find_bot_version(version_id: i32, conn: &mut PgConnection) -> QueryResult<BotVersion> {
bot_versions::table
.filter(bot_versions::id.eq(version_id))
.first(conn)
}
-pub fn find_bot_versions(bot_id: i32, conn: &PgConnection) -> QueryResult<Vec<BotVersion>> {
+pub fn find_bot_versions(bot_id: i32, conn: &mut PgConnection) -> QueryResult<Vec<BotVersion>> {
bot_versions::table
.filter(bot_versions::bot_id.eq(bot_id))
.get_results(conn)
diff --git a/planetwars-server/src/db/maps.rs b/planetwars-server/src/db/maps.rs
index dffe4fd..8972461 100644
--- a/planetwars-server/src/db/maps.rs
+++ b/planetwars-server/src/db/maps.rs
@@ -3,7 +3,7 @@ use diesel::prelude::*;
use crate::schema::maps;
#[derive(Insertable)]
-#[table_name = "maps"]
+#[diesel(table_name = maps)]
pub struct NewMap<'a> {
pub name: &'a str,
pub file_path: &'a str,
@@ -16,20 +16,20 @@ pub struct Map {
pub file_path: String,
}
-pub fn create_map(new_map: NewMap, conn: &PgConnection) -> QueryResult<Map> {
+pub fn create_map(new_map: NewMap, conn: &mut PgConnection) -> QueryResult<Map> {
diesel::insert_into(maps::table)
.values(new_map)
.get_result(conn)
}
-pub fn find_map(id: i32, conn: &PgConnection) -> QueryResult<Map> {
+pub fn find_map(id: i32, conn: &mut PgConnection) -> QueryResult<Map> {
maps::table.find(id).get_result(conn)
}
-pub fn find_map_by_name(name: &str, conn: &PgConnection) -> QueryResult<Map> {
+pub fn find_map_by_name(name: &str, conn: &mut PgConnection) -> QueryResult<Map> {
maps::table.filter(maps::name.eq(name)).first(conn)
}
-pub fn list_maps(conn: &PgConnection) -> QueryResult<Vec<Map>> {
+pub fn list_maps(conn: &mut PgConnection) -> QueryResult<Vec<Map>> {
maps::table.get_results(conn)
}
diff --git a/planetwars-server/src/db/matches.rs b/planetwars-server/src/db/matches.rs
index 1dded43..bfec892 100644
--- a/planetwars-server/src/db/matches.rs
+++ b/planetwars-server/src/db/matches.rs
@@ -1,9 +1,6 @@
pub use crate::db_types::MatchState;
use chrono::NaiveDateTime;
use diesel::associations::BelongsTo;
-use diesel::pg::Pg;
-use diesel::query_builder::BoxedSelectStatement;
-use diesel::query_source::{AppearsInFromClause, Once};
use diesel::sql_types::*;
use diesel::{
BelongingToDsl, ExpressionMethods, JoinOnDsl, NullableExpressionMethods, QueryDsl, RunQueryDsl,
@@ -18,7 +15,7 @@ use super::bots::{Bot, BotVersion};
use super::maps::Map;
#[derive(Insertable)]
-#[table_name = "matches"]
+#[diesel(table_name = matches)]
pub struct NewMatch<'a> {
pub state: MatchState,
pub log_path: &'a str,
@@ -27,7 +24,7 @@ pub struct NewMatch<'a> {
}
#[derive(Insertable)]
-#[table_name = "match_players"]
+#[diesel(table_name = match_players)]
pub struct NewMatchPlayer {
/// id of the match this player is in
pub match_id: i32,
@@ -38,7 +35,7 @@ pub struct NewMatchPlayer {
}
#[derive(Queryable, Identifiable)]
-#[table_name = "matches"]
+#[diesel(table_name = matches)]
pub struct MatchBase {
pub id: i32,
pub state: MatchState,
@@ -50,8 +47,8 @@ pub struct MatchBase {
}
#[derive(Queryable, Identifiable, Associations, Clone)]
-#[primary_key(match_id, player_id)]
-#[belongs_to(MatchBase, foreign_key = "match_id")]
+#[diesel(primary_key(match_id, player_id))]
+#[diesel(belongs_to(MatchBase, foreign_key = match_id))]
pub struct MatchPlayer {
pub match_id: i32,
pub player_id: i32,
@@ -65,9 +62,9 @@ pub struct MatchPlayerData {
pub fn create_match(
new_match_base: &NewMatch,
new_match_players: &[MatchPlayerData],
- conn: &PgConnection,
+ conn: &mut PgConnection,
) -> QueryResult<MatchData> {
- conn.transaction(|| {
+ conn.transaction(|conn| {
let match_base = diesel::insert_into(matches::table)
.values(new_match_base)
.get_result::<MatchBase>(conn)?;
@@ -101,7 +98,7 @@ pub struct MatchData {
/// Add player information to MatchBase instances
fn fetch_full_match_data(
matches: Vec<MatchBase>,
- conn: &PgConnection,
+ conn: &mut PgConnection,
) -> QueryResult<Vec<FullMatchData>> {
let map_ids: HashSet<i32> = matches.iter().filter_map(|m| m.map_id).collect();
@@ -140,8 +137,8 @@ fn fetch_full_match_data(
}
// TODO: this method should disappear
-pub fn list_matches(amount: i64, conn: &PgConnection) -> QueryResult<Vec<FullMatchData>> {
- conn.transaction(|| {
+pub fn list_matches(amount: i64, conn: &mut PgConnection) -> QueryResult<Vec<FullMatchData>> {
+ conn.transaction(|conn| {
let matches = matches::table
.filter(matches::state.eq(MatchState::Finished))
.order_by(matches::created_at.desc())
@@ -164,17 +161,32 @@ pub fn list_public_matches(
amount: i64,
before: Option<NaiveDateTime>,
after: Option<NaiveDateTime>,
- conn: &PgConnection,
+ conn: &mut PgConnection,
) -> QueryResult<Vec<FullMatchData>> {
- conn.transaction(|| {
+ conn.transaction(|conn| {
// TODO: how can this common logic be abstracted?
- let query = matches::table
+ let mut query = matches::table
.filter(matches::state.eq(MatchState::Finished))
.filter(matches::is_public.eq(true))
.into_boxed();
- let matches =
- select_matches_page(query, amount, before, after).get_results::<MatchBase>(conn)?;
+ // TODO: how to remove this duplication?
+ query = match (before, after) {
+ (None, None) => query.order_by(matches::created_at.desc()),
+ (Some(before), None) => query
+ .filter(matches::created_at.lt(before))
+ .order_by(matches::created_at.desc()),
+ (None, Some(after)) => query
+ .filter(matches::created_at.gt(after))
+ .order_by(matches::created_at.asc()),
+ (Some(before), Some(after)) => query
+ .filter(matches::created_at.lt(before))
+ .filter(matches::created_at.gt(after))
+ .order_by(matches::created_at.desc()),
+ };
+ query = query.limit(amount);
+
+ let matches = query.get_results::<MatchBase>(conn)?;
fetch_full_match_data(matches, conn)
})
}
@@ -185,7 +197,7 @@ pub fn list_bot_matches(
amount: i64,
before: Option<NaiveDateTime>,
after: Option<NaiveDateTime>,
- conn: &PgConnection,
+ conn: &mut PgConnection,
) -> QueryResult<Vec<FullMatchData>> {
let mut query = matches::table
.filter(matches::state.eq(MatchState::Finished))
@@ -211,22 +223,8 @@ pub fn list_bot_matches(
};
}
- let matches =
- select_matches_page(query, amount, before, after).get_results::<MatchBase>(conn)?;
- fetch_full_match_data(matches, conn)
-}
-
-fn select_matches_page<QS>(
- query: BoxedSelectStatement<'static, matches::SqlType, QS, Pg>,
- amount: i64,
- before: Option<NaiveDateTime>,
- after: Option<NaiveDateTime>,
-) -> BoxedSelectStatement<'static, matches::SqlType, QS, Pg>
-where
- QS: AppearsInFromClause<matches::table, Count = Once>,
-{
- // TODO: this is not nice. Replace this with proper cursor logic.
- match (before, after) {
+ // TODO: how to remove this duplication?
+ query = match (before, after) {
(None, None) => query.order_by(matches::created_at.desc()),
(Some(before), None) => query
.filter(matches::created_at.lt(before))
@@ -238,8 +236,11 @@ where
.filter(matches::created_at.lt(before))
.filter(matches::created_at.gt(after))
.order_by(matches::created_at.desc()),
- }
- .limit(amount)
+ };
+ query = query.limit(amount);
+
+ let matches = query.get_results::<MatchBase>(conn)?;
+ fetch_full_match_data(matches, conn)
}
// TODO: maybe unify this with matchdata?
@@ -270,8 +271,8 @@ impl BelongsTo<MatchBase> for FullMatchPlayerData {
}
}
-pub fn find_match(id: i32, conn: &PgConnection) -> QueryResult<FullMatchData> {
- conn.transaction(|| {
+pub fn find_match(id: i32, conn: &mut PgConnection) -> QueryResult<FullMatchData> {
+ conn.transaction(|conn| {
let match_base = matches::table.find(id).get_result::<MatchBase>(conn)?;
let map = match match_base.map_id {
@@ -298,7 +299,7 @@ pub fn find_match(id: i32, conn: &PgConnection) -> QueryResult<FullMatchData> {
})
}
-pub fn find_match_base(id: i32, conn: &PgConnection) -> QueryResult<MatchBase> {
+pub fn find_match_base(id: i32, conn: &mut PgConnection) -> QueryResult<MatchBase> {
matches::table.find(id).get_result::<MatchBase>(conn)
}
@@ -306,7 +307,7 @@ pub enum MatchResult {
Finished { winner: Option<i32> },
}
-pub fn save_match_result(id: i32, result: MatchResult, conn: &PgConnection) -> QueryResult<()> {
+pub fn save_match_result(id: i32, result: MatchResult, conn: &mut PgConnection) -> QueryResult<()> {
let MatchResult::Finished { winner } = result;
diesel::update(matches::table.find(id))
@@ -320,17 +321,20 @@ pub fn save_match_result(id: i32, result: MatchResult, conn: &PgConnection) -> Q
#[derive(QueryableByName)]
pub struct BotStatsRecord {
- #[sql_type = "Text"]
+ #[diesel(sql_type = Text)]
pub opponent: String,
- #[sql_type = "Text"]
+ #[diesel(sql_type = Text)]
pub map: String,
- #[sql_type = "Nullable<Bool>"]
+ #[diesel(sql_type = Nullable<Bool>)]
pub win: Option<bool>,
- #[sql_type = "Int8"]
+ #[diesel(sql_type = Int8)]
pub count: i64,
}
-pub fn fetch_bot_stats(bot_name: &str, db_conn: &PgConnection) -> QueryResult<Vec<BotStatsRecord>> {
+pub fn fetch_bot_stats(
+ bot_name: &str,
+ db_conn: &mut PgConnection,
+) -> QueryResult<Vec<BotStatsRecord>> {
diesel::sql_query(
"
SELECT opponent, map, win, COUNT(*) as count
diff --git a/planetwars-server/src/db/ratings.rs b/planetwars-server/src/db/ratings.rs
index 8262fed..0a510d4 100644
--- a/planetwars-server/src/db/ratings.rs
+++ b/planetwars-server/src/db/ratings.rs
@@ -10,7 +10,7 @@ pub struct Rating {
pub rating: f64,
}
-pub fn get_rating(bot_id: i32, db_conn: &PgConnection) -> QueryResult<Option<f64>> {
+pub fn get_rating(bot_id: i32, db_conn: &mut PgConnection) -> QueryResult<Option<f64>> {
ratings::table
.filter(ratings::bot_id.eq(bot_id))
.select(ratings::rating)
@@ -18,7 +18,7 @@ pub fn get_rating(bot_id: i32, db_conn: &PgConnection) -> QueryResult<Option<f64
.optional()
}
-pub fn set_rating(bot_id: i32, rating: f64, db_conn: &PgConnection) -> QueryResult<usize> {
+pub fn set_rating(bot_id: i32, rating: f64, db_conn: &mut PgConnection) -> QueryResult<usize> {
diesel::insert_into(ratings::table)
.values(Rating { bot_id, rating })
.on_conflict(ratings::bot_id)
@@ -40,7 +40,7 @@ pub struct RankedBot {
pub rating: f64,
}
-pub fn get_bot_ranking(db_conn: &PgConnection) -> QueryResult<Vec<RankedBot>> {
+pub fn get_bot_ranking(db_conn: &mut PgConnection) -> QueryResult<Vec<RankedBot>> {
bots::table
.left_join(users::table)
.inner_join(ratings::table)
diff --git a/planetwars-server/src/db/sessions.rs b/planetwars-server/src/db/sessions.rs
index a91d954..f8108cc 100644
--- a/planetwars-server/src/db/sessions.rs
+++ b/planetwars-server/src/db/sessions.rs
@@ -6,7 +6,7 @@ use diesel::{insert_into, prelude::*, Insertable, RunQueryDsl};
use rand::{self, Rng};
#[derive(Insertable)]
-#[table_name = "sessions"]
+#[diesel(table_name = sessions)]
struct NewSession {
token: String,
user_id: i32,
@@ -19,7 +19,7 @@ pub struct Session {
pub token: String,
}
-pub fn create_session(user: &User, conn: &PgConnection) -> Session {
+pub fn create_session(user: &User, conn: &mut PgConnection) -> Session {
let new_session = NewSession {
token: gen_session_token(),
user_id: user.id,
@@ -31,7 +31,7 @@ pub fn create_session(user: &User, conn: &PgConnection) -> Session {
.unwrap()
}
-pub fn find_user_by_session(token: &str, conn: &PgConnection) -> QueryResult<(Session, User)> {
+pub fn find_user_by_session(token: &str, conn: &mut PgConnection) -> QueryResult<(Session, User)> {
sessions::table
.inner_join(users::table)
.filter(sessions::token.eq(&token))
diff --git a/planetwars-server/src/db/users.rs b/planetwars-server/src/db/users.rs
index 9676dae..60cc20a 100644
--- a/planetwars-server/src/db/users.rs
+++ b/planetwars-server/src/db/users.rs
@@ -11,7 +11,7 @@ pub struct Credentials<'a> {
}
#[derive(Insertable)]
-#[table_name = "users"]
+#[diesel(table_name = users)]
pub struct NewUser<'a> {
pub username: &'a str,
pub password_hash: &'a [u8],
@@ -50,7 +50,7 @@ pub fn hash_password(password: &str) -> (Vec<u8>, [u8; 32]) {
(hash, salt)
}
-pub fn create_user(credentials: &Credentials, conn: &PgConnection) -> QueryResult<User> {
+pub fn create_user(credentials: &Credentials, conn: &mut PgConnection) -> QueryResult<User> {
let (hash, salt) = hash_password(&credentials.password);
let new_user = NewUser {
@@ -63,19 +63,19 @@ pub fn create_user(credentials: &Credentials, conn: &PgConnection) -> QueryResul
.get_result::<User>(conn)
}
-pub fn find_user(user_id: i32, db_conn: &PgConnection) -> QueryResult<User> {
+pub fn find_user(user_id: i32, db_conn: &mut PgConnection) -> QueryResult<User> {
users::table
.filter(users::id.eq(user_id))
.first::<User>(db_conn)
}
-pub fn find_user_by_name(username: &str, db_conn: &PgConnection) -> QueryResult<User> {
+pub fn find_user_by_name(username: &str, db_conn: &mut PgConnection) -> QueryResult<User> {
users::table
.filter(users::username.eq(username))
.first::<User>(db_conn)
}
-pub fn set_user_password(credentials: Credentials, db_conn: &PgConnection) -> QueryResult<()> {
+pub fn set_user_password(credentials: Credentials, db_conn: &mut PgConnection) -> QueryResult<()> {
let (hash, salt) = hash_password(&credentials.password);
let n_changes = diesel::update(users::table.filter(users::username.eq(&credentials.username)))
@@ -91,7 +91,7 @@ pub fn set_user_password(credentials: Credentials, db_conn: &PgConnection) -> Qu
}
}
-pub fn authenticate_user(credentials: &Credentials, db_conn: &PgConnection) -> Option<User> {
+pub fn authenticate_user(credentials: &Credentials, db_conn: &mut PgConnection) -> Option<User> {
find_user_by_name(credentials.username, db_conn)
.optional()
.unwrap()
diff --git a/planetwars-server/src/db_types.rs b/planetwars-server/src/db_types.rs
index 1b99e49..29b1e9b 100644
--- a/planetwars-server/src/db_types.rs
+++ b/planetwars-server/src/db_types.rs
@@ -2,7 +2,7 @@ use diesel_derive_enum::DbEnum;
use serde::{Deserialize, Serialize};
#[derive(DbEnum, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
-#[DieselType = "Match_state"]
+#[DieselTypePath = "crate::schema::sql_types::MatchState"]
pub enum MatchState {
Playing,
diff --git a/planetwars-server/src/lib.rs b/planetwars-server/src/lib.rs
index 1696f1a..316458c 100644
--- a/planetwars-server/src/lib.rs
+++ b/planetwars-server/src/lib.rs
@@ -8,7 +8,7 @@ pub mod routes;
pub mod schema;
pub mod util;
-use std::ops::Deref;
+use std::ops::{Deref, DerefMut};
use std::path::PathBuf;
use std::sync::Arc;
use std::{fs, net::SocketAddr};
@@ -70,9 +70,9 @@ pub struct GlobalConfig {
const SIMPLEBOT_PATH: &str = "../simplebot/simplebot.py";
pub async fn seed_simplebot(config: &GlobalConfig, pool: &ConnectionPool) {
- let conn = pool.get().await.expect("could not get database connection");
+ let mut conn = pool.get().await.expect("could not get database connection");
// This transaction is expected to fail when simplebot already exists.
- let _res = conn.transaction::<(), diesel::result::Error, _>(|| {
+ let _res = conn.transaction::<(), diesel::result::Error, _>(|conn| {
use db::bots::NewBot;
let new_bot = NewBot {
@@ -80,12 +80,12 @@ pub async fn seed_simplebot(config: &GlobalConfig, pool: &ConnectionPool) {
owner_id: None,
};
- let simplebot = db::bots::create_bot(&new_bot, &conn)?;
+ let simplebot = db::bots::create_bot(&new_bot, conn)?;
let simplebot_code =
std::fs::read_to_string(SIMPLEBOT_PATH).expect("could not read simplebot code");
- modules::bots::save_code_string(&simplebot_code, Some(simplebot.id), &conn, config)?;
+ modules::bots::save_code_string(&simplebot_code, Some(simplebot.id), conn, config)?;
println!("initialized simplebot");
@@ -209,6 +209,12 @@ impl Deref for DatabaseConnection {
}
}
+impl DerefMut for DatabaseConnection {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.0
+ }
+}
+
#[async_trait]
impl<B> FromRequest<B> for DatabaseConnection
where
diff --git a/planetwars-server/src/modules/bots.rs b/planetwars-server/src/modules/bots.rs
index 6a2883c..6893581 100644
--- a/planetwars-server/src/modules/bots.rs
+++ b/planetwars-server/src/modules/bots.rs
@@ -9,7 +9,7 @@ use crate::{db, util::gen_alphanumeric, GlobalConfig};
pub fn save_code_string(
bot_code: &str,
bot_id: Option<i32>,
- conn: &PgConnection,
+ conn: &mut PgConnection,
config: &GlobalConfig,
) -> QueryResult<db::bots::BotVersion> {
let bundle_name = gen_alphanumeric(16);
diff --git a/planetwars-server/src/modules/client_api.rs b/planetwars-server/src/modules/client_api.rs
index 6e5d05a..9c0bbe7 100644
--- a/planetwars-server/src/modules/client_api.rs
+++ b/planetwars-server/src/modules/client_api.rs
@@ -149,19 +149,19 @@ impl pb::client_api_service_server::ClientApiService for ClientApiServer {
req: Request<pb::CreateMatchRequest>,
) -> Result<Response<pb::CreateMatchResponse>, Status> {
// TODO: unify with matchrunner module
- let conn = self.conn_pool.get().await.unwrap();
+ let mut conn = self.conn_pool.get().await.unwrap();
let match_request = req.get_ref();
let (opponent_bot, opponent_bot_version) =
- db::bots::find_bot_with_version_by_name(&match_request.opponent_name, &conn)
+ db::bots::find_bot_with_version_by_name(&match_request.opponent_name, &mut conn)
.map_err(|_| Status::not_found("opponent not found"))?;
let map_name = match match_request.map_name.as_str() {
"" => "hex",
name => name,
};
- let map = db::maps::find_map_by_name(map_name, &conn)
+ let map = db::maps::find_map_by_name(map_name, &mut conn)
.map_err(|_| Status::not_found("map not found"))?;
let player_key = gen_alphanumeric(32);
diff --git a/planetwars-server/src/modules/matches.rs b/planetwars-server/src/modules/matches.rs
index ecc7976..71e8a98 100644
--- a/planetwars-server/src/modules/matches.rs
+++ b/planetwars-server/src/modules/matches.rs
@@ -80,8 +80,8 @@ impl RunMatch {
let match_data = {
// TODO: it would be nice to get an already-open connection here when possible.
// Maybe we need an additional abstraction, bundling a connection and connection pool?
- let db_conn = conn_pool.get().await.expect("could not get a connection");
- self.store_in_database(&db_conn)?
+ let mut db_conn = conn_pool.get().await.expect("could not get a connection");
+ self.store_in_database(&mut db_conn)?
};
let runner_config = self.into_runner_config();
@@ -90,7 +90,7 @@ impl RunMatch {
Ok((match_data, handle))
}
- fn store_in_database(&self, db_conn: &PgConnection) -> QueryResult<MatchData> {
+ fn store_in_database(&self, db_conn: &mut PgConnection) -> QueryResult<MatchData> {
let new_match_data = db::matches::NewMatch {
state: db::matches::MatchState::Playing,
log_path: &self.log_file_name,
@@ -167,7 +167,7 @@ async fn run_match_task(
let outcome = runner::run_match(match_config).await;
// update match state in database
- let conn = connection_pool
+ let mut conn = connection_pool
.get()
.await
.expect("could not get database connection");
@@ -176,7 +176,8 @@ async fn run_match_task(
winner: outcome.winner.map(|w| (w - 1) as i32), // player numbers in matchrunner start at 1
};
- db::matches::save_match_result(match_id, result, &conn).expect("could not save match result");
+ db::matches::save_match_result(match_id, result, &mut conn)
+ .expect("could not save match result");
outcome
}
diff --git a/planetwars-server/src/modules/ranking.rs b/planetwars-server/src/modules/ranking.rs
index 90c4a56..92f0f8a 100644
--- a/planetwars-server/src/modules/ranking.rs
+++ b/planetwars-server/src/modules/ranking.rs
@@ -20,13 +20,14 @@ pub async fn run_ranker(config: Arc<GlobalConfig>, db_pool: DbPool) {
// TODO: make this configurable
// play at most one match every n seconds
let mut interval = tokio::time::interval(Duration::from_secs(RANKER_INTERVAL));
- let db_conn = db_pool
+ let mut db_conn = db_pool
.get()
.await
.expect("could not get database connection");
loop {
interval.tick().await;
- let bots = db::bots::all_active_bots_with_version(&db_conn).expect("could not load bots");
+ let bots =
+ db::bots::all_active_bots_with_version(&mut db_conn).expect("could not load bots");
if bots.len() < 2 {
// not enough bots to play a match
continue;
@@ -37,14 +38,14 @@ pub async fn run_ranker(config: Arc<GlobalConfig>, db_pool: DbPool) {
.cloned()
.collect();
- let maps = db::maps::list_maps(&db_conn).expect("could not load map");
+ let maps = db::maps::list_maps(&mut db_conn).expect("could not load map");
let map = match maps.choose(&mut rand::thread_rng()).cloned() {
None => continue, // no maps available
Some(map) => map,
};
play_ranking_match(config.clone(), map, selected_bots, db_pool.clone()).await;
- recalculate_ratings(&db_conn).expect("could not recalculate ratings");
+ recalculate_ratings(&mut db_conn).expect("could not recalculate ratings");
}
}
@@ -71,7 +72,7 @@ async fn play_ranking_match(
let _outcome = handle.await;
}
-fn recalculate_ratings(db_conn: &PgConnection) -> QueryResult<()> {
+fn recalculate_ratings(db_conn: &mut PgConnection) -> QueryResult<()> {
let start = Instant::now();
let match_stats = fetch_match_stats(db_conn)?;
let ratings = estimate_ratings_from_stats(match_stats);
@@ -91,7 +92,7 @@ struct MatchStats {
num_matches: usize,
}
-fn fetch_match_stats(db_conn: &PgConnection) -> QueryResult<HashMap<(i32, i32), MatchStats>> {
+fn fetch_match_stats(db_conn: &mut PgConnection) -> QueryResult<HashMap<(i32, i32), MatchStats>> {
let matches = db::matches::list_matches(RANKER_NUM_MATCHES, db_conn)?;
let mut match_stats = HashMap::<(i32, i32), MatchStats>::new();
diff --git a/planetwars-server/src/modules/registry.rs b/planetwars-server/src/modules/registry.rs
index 4a79d59..5e1e05b 100644
--- a/planetwars-server/src/modules/registry.rs
+++ b/planetwars-server/src/modules/registry.rs
@@ -112,8 +112,8 @@ where
Err(RegistryAuthError::InvalidCredentials)
}
} else {
- let db_conn = DatabaseConnection::from_request(req).await.unwrap();
- let user = authenticate_user(&credentials, &db_conn)
+ let mut db_conn = DatabaseConnection::from_request(req).await.unwrap();
+ let user = authenticate_user(&credentials, &mut db_conn)
.ok_or(RegistryAuthError::InvalidCredentials)?;
Ok(RegistryAuth::User(user))
@@ -159,12 +159,12 @@ pub struct RegistryError {
}
async fn check_blob_exists(
- db_conn: DatabaseConnection,
+ mut db_conn: DatabaseConnection,
auth: RegistryAuth,
Path((repository_name, raw_digest)): Path<(String, String)>,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<impl IntoResponse, StatusCode> {
- check_access(&repository_name, &auth, &db_conn)?;
+ check_access(&repository_name, &auth, &mut db_conn)?;
let digest = raw_digest.strip_prefix("sha256:").unwrap();
let blob_path = PathBuf::from(&config.registry_directory)
@@ -179,12 +179,12 @@ async fn check_blob_exists(
}
async fn get_blob(
- db_conn: DatabaseConnection,
+ mut db_conn: DatabaseConnection,
auth: RegistryAuth,
Path((repository_name, raw_digest)): Path<(String, String)>,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<impl IntoResponse, StatusCode> {
- check_access(&repository_name, &auth, &db_conn)?;
+ check_access(&repository_name, &auth, &mut db_conn)?;
let digest = raw_digest.strip_prefix("sha256:").unwrap();
let blob_path = PathBuf::from(&config.registry_directory)
@@ -200,12 +200,12 @@ async fn get_blob(
}
async fn create_upload(
- db_conn: DatabaseConnection,
+ mut db_conn: DatabaseConnection,
auth: RegistryAuth,
Path(repository_name): Path<String>,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<impl IntoResponse, StatusCode> {
- check_access(&repository_name, &auth, &db_conn)?;
+ check_access(&repository_name, &auth, &mut db_conn)?;
let uuid = gen_alphanumeric(16);
tokio::fs::File::create(
@@ -229,13 +229,13 @@ async fn create_upload(
}
async fn patch_upload(
- db_conn: DatabaseConnection,
+ mut db_conn: DatabaseConnection,
auth: RegistryAuth,
Path((repository_name, uuid)): Path<(String, String)>,
mut stream: BodyStream,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<impl IntoResponse, StatusCode> {
- check_access(&repository_name, &auth, &db_conn)?;
+ check_access(&repository_name, &auth, &mut db_conn)?;
// TODO: support content range header in request
let upload_path = PathBuf::from(&config.registry_directory)
@@ -275,14 +275,14 @@ struct UploadParams {
}
async fn put_upload(
- db_conn: DatabaseConnection,
+ mut db_conn: DatabaseConnection,
auth: RegistryAuth,
Path((repository_name, uuid)): Path<(String, String)>,
Query(params): Query<UploadParams>,
mut stream: BodyStream,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<impl IntoResponse, StatusCode> {
- check_access(&repository_name, &auth, &db_conn)?;
+ check_access(&repository_name, &auth, &mut db_conn)?;
let upload_path = PathBuf::from(&config.registry_directory)
.join("uploads")
@@ -332,12 +332,12 @@ async fn put_upload(
}
async fn get_manifest(
- db_conn: DatabaseConnection,
+ mut db_conn: DatabaseConnection,
auth: RegistryAuth,
Path((repository_name, reference)): Path<(String, String)>,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<impl IntoResponse, StatusCode> {
- check_access(&repository_name, &auth, &db_conn)?;
+ check_access(&repository_name, &auth, &mut db_conn)?;
let manifest_path = PathBuf::from(&config.registry_directory)
.join("manifests")
@@ -357,13 +357,13 @@ async fn get_manifest(
}
async fn put_manifest(
- db_conn: DatabaseConnection,
+ mut db_conn: DatabaseConnection,
auth: RegistryAuth,
Path((repository_name, reference)): Path<(String, String)>,
mut stream: BodyStream,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<impl IntoResponse, StatusCode> {
- let bot = check_access(&repository_name, &auth, &db_conn)?;
+ let bot = check_access(&repository_name, &auth, &mut db_conn)?;
let repository_dir = PathBuf::from(&config.registry_directory)
.join("manifests")
@@ -399,9 +399,9 @@ async fn put_manifest(
code_bundle_path: None,
container_digest: Some(&content_digest),
};
- let version =
- db::bots::create_bot_version(&new_version, &db_conn).expect("could not save bot version");
- db::bots::set_active_version(bot.id, Some(version.id), &db_conn)
+ let version = db::bots::create_bot_version(&new_version, &mut db_conn)
+ .expect("could not save bot version");
+ db::bots::set_active_version(bot.id, Some(version.id), &mut db_conn)
.expect("could not update bot version");
Ok(Response::builder()
@@ -421,7 +421,7 @@ async fn put_manifest(
fn check_access(
repository_name: &str,
auth: &RegistryAuth,
- db_conn: &DatabaseConnection,
+ db_conn: &mut DatabaseConnection,
) -> Result<db::bots::Bot, StatusCode> {
use diesel::OptionalExtension;
diff --git a/planetwars-server/src/routes/bots.rs b/planetwars-server/src/routes/bots.rs
index f8087fd..f0ff9bf 100644
--- a/planetwars-server/src/routes/bots.rs
+++ b/planetwars-server/src/routes/bots.rs
@@ -100,10 +100,10 @@ pub fn validate_bot_name(bot_name: &str) -> Result<(), SaveBotError> {
pub async fn save_bot(
Json(params): Json<SaveBotParams>,
user: User,
- conn: DatabaseConnection,
+ mut conn: DatabaseConnection,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<Json<Bot>, SaveBotError> {
- let res = bots::find_bot_by_name(&params.bot_name, &conn)
+ let res = bots::find_bot_by_name(&params.bot_name, &mut conn)
.optional()
.expect("could not run query");
@@ -122,10 +122,10 @@ pub async fn save_bot(
name: &params.bot_name,
};
- bots::create_bot(&new_bot, &conn).expect("could not create bot")
+ bots::create_bot(&new_bot, &mut conn).expect("could not create bot")
}
};
- let _code_bundle = save_code_string(&params.code, Some(bot.id), &conn, &config)
+ let _code_bundle = save_code_string(&params.code, Some(bot.id), &mut conn, &config)
.expect("failed to save code bundle");
Ok(Json(bot))
}
@@ -137,12 +137,12 @@ pub struct BotParams {
// TODO: can we unify this with save_bot?
pub async fn create_bot(
- conn: DatabaseConnection,
+ mut conn: DatabaseConnection,
user: User,
params: Json<BotParams>,
) -> Result<(StatusCode, Json<Bot>), SaveBotError> {
validate_bot_name(&params.name)?;
- let existing_bot = bots::find_bot_by_name(&params.name, &conn)
+ let existing_bot = bots::find_bot_by_name(&params.name, &mut conn)
.optional()
.expect("could not run query");
if existing_bot.is_some() {
@@ -152,26 +152,27 @@ pub async fn create_bot(
owner_id: Some(user.id),
name: &params.name,
};
- let bot = bots::create_bot(&bot_params, &conn).unwrap();
+ let bot = bots::create_bot(&bot_params, &mut conn).unwrap();
Ok((StatusCode::CREATED, Json(bot)))
}
// TODO: handle errors
pub async fn get_bot(
- conn: DatabaseConnection,
+ mut conn: DatabaseConnection,
Path(bot_name): Path<String>,
) -> Result<Json<JsonValue>, StatusCode> {
- let bot = db::bots::find_bot_by_name(&bot_name, &conn).map_err(|_| StatusCode::NOT_FOUND)?;
+ let bot =
+ db::bots::find_bot_by_name(&bot_name, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?;
let owner: Option<UserData> = match bot.owner_id {
Some(user_id) => {
- let user = db::users::find_user(user_id, &conn)
+ let user = db::users::find_user(user_id, &mut conn)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Some(user.into())
}
None => None,
};
- let versions =
- bots::find_bot_versions(bot.id, &conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
+ let versions = bots::find_bot_versions(bot.id, &mut conn)
+ .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(json!({
"bot": bot,
"owner": owner,
@@ -180,32 +181,32 @@ pub async fn get_bot(
}
pub async fn get_user_bots(
- conn: DatabaseConnection,
+ mut conn: DatabaseConnection,
Path(user_name): Path<String>,
) -> Result<Json<Vec<Bot>>, StatusCode> {
let user =
- db::users::find_user_by_name(&user_name, &conn).map_err(|_| StatusCode::NOT_FOUND)?;
- db::bots::find_bots_by_owner(user.id, &conn)
+ db::users::find_user_by_name(&user_name, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?;
+ db::bots::find_bots_by_owner(user.id, &mut conn)
.map(Json)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
/// List all active bots
-pub async fn list_bots(conn: DatabaseConnection) -> Result<Json<Vec<Bot>>, StatusCode> {
- bots::find_active_bots(&conn)
+pub async fn list_bots(mut conn: DatabaseConnection) -> Result<Json<Vec<Bot>>, StatusCode> {
+ bots::find_active_bots(&mut conn)
.map(Json)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
-pub async fn get_ranking(conn: DatabaseConnection) -> Result<Json<Vec<RankedBot>>, StatusCode> {
- ratings::get_bot_ranking(&conn)
+pub async fn get_ranking(mut conn: DatabaseConnection) -> Result<Json<Vec<RankedBot>>, StatusCode> {
+ ratings::get_bot_ranking(&mut conn)
.map(Json)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
// TODO: currently this only implements the happy flow
pub async fn upload_code_multipart(
- conn: DatabaseConnection,
+ mut conn: DatabaseConnection,
user: User,
Path(bot_name): Path<String>,
mut multipart: Multipart,
@@ -213,7 +214,7 @@ pub async fn upload_code_multipart(
) -> Result<Json<BotVersion>, StatusCode> {
let bots_dir = PathBuf::from(&config.bots_directory);
- let bot = bots::find_bot_by_name(&bot_name, &conn).map_err(|_| StatusCode::NOT_FOUND)?;
+ let bot = bots::find_bot_by_name(&bot_name, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?;
if Some(user.id) != bot.owner_id {
return Err(StatusCode::FORBIDDEN);
@@ -246,21 +247,22 @@ pub async fn upload_code_multipart(
container_digest: None,
};
let code_bundle =
- bots::create_bot_version(&bot_version, &conn).expect("Failed to create code bundle");
+ bots::create_bot_version(&bot_version, &mut conn).expect("Failed to create code bundle");
Ok(Json(code_bundle))
}
pub async fn get_code(
- conn: DatabaseConnection,
+ mut conn: DatabaseConnection,
user: User,
Path(bundle_id): Path<i32>,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<Vec<u8>, StatusCode> {
let version =
- db::bots::find_bot_version(bundle_id, &conn).map_err(|_| StatusCode::NOT_FOUND)?;
+ db::bots::find_bot_version(bundle_id, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?;
let bot_id = version.bot_id.ok_or(StatusCode::FORBIDDEN)?;
- let bot = db::bots::find_bot(bot_id, &conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
+ let bot =
+ db::bots::find_bot(bot_id, &mut conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if bot.owner_id != Some(user.id) {
return Err(StatusCode::FORBIDDEN);
@@ -297,10 +299,10 @@ impl MatchupStats {
type BotStats = HashMap<String, HashMap<String, MatchupStats>>;
pub async fn get_bot_stats(
- conn: DatabaseConnection,
+ mut conn: DatabaseConnection,
Path(bot_name): Path<String>,
) -> Result<Json<BotStats>, StatusCode> {
- let stats_records = db::matches::fetch_bot_stats(&bot_name, &conn)
+ let stats_records = db::matches::fetch_bot_stats(&bot_name, &mut conn)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let mut bot_stats: BotStats = HashMap::new();
for record in stats_records {
diff --git a/planetwars-server/src/routes/demo.rs b/planetwars-server/src/routes/demo.rs
index 1ec8825..cd490ef 100644
--- a/planetwars-server/src/routes/demo.rs
+++ b/planetwars-server/src/routes/demo.rs
@@ -35,7 +35,7 @@ pub async fn submit_bot(
Extension(pool): Extension<ConnectionPool>,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<Json<SubmitBotResponse>, StatusCode> {
- let conn = pool.get().await.expect("could not get database connection");
+ let mut conn = pool.get().await.expect("could not get database connection");
let opponent_name = params
.opponent_name
@@ -46,12 +46,13 @@ pub async fn submit_bot(
.unwrap_or_else(|| DEFAULT_MAP_NAME.to_string());
let (opponent_bot, opponent_bot_version) =
- db::bots::find_bot_with_version_by_name(&opponent_name, &conn)
+ db::bots::find_bot_with_version_by_name(&opponent_name, &mut conn)
.map_err(|_| StatusCode::BAD_REQUEST)?;
- let map = db::maps::find_map_by_name(&map_name, &conn).map_err(|_| StatusCode::BAD_REQUEST)?;
+ let map =
+ db::maps::find_map_by_name(&map_name, &mut conn).map_err(|_| StatusCode::BAD_REQUEST)?;
- let player_bot_version = save_code_string(&params.code, None, &conn, &config)
+ let player_bot_version = save_code_string(&params.code, None, &mut conn, &config)
// TODO: can we recover from this?
.expect("could not save bot code");
diff --git a/planetwars-server/src/routes/maps.rs b/planetwars-server/src/routes/maps.rs
index 689b11e..188089f 100644
--- a/planetwars-server/src/routes/maps.rs
+++ b/planetwars-server/src/routes/maps.rs
@@ -8,8 +8,8 @@ pub struct ApiMap {
pub name: String,
}
-pub async fn list_maps(conn: DatabaseConnection) -> Result<Json<Vec<ApiMap>>, StatusCode> {
- let maps = db::maps::list_maps(&conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
+pub async fn list_maps(mut conn: DatabaseConnection) -> Result<Json<Vec<ApiMap>>, StatusCode> {
+ let maps = db::maps::list_maps(&mut conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let api_maps = maps
.into_iter()
diff --git a/planetwars-server/src/routes/matches.rs b/planetwars-server/src/routes/matches.rs
index 1d7403c..3ad10cf 100644
--- a/planetwars-server/src/routes/matches.rs
+++ b/planetwars-server/src/routes/matches.rs
@@ -56,7 +56,7 @@ pub struct ListMatchesResponse {
pub async fn list_recent_matches(
Query(params): Query<ListRecentMatchesParams>,
- conn: DatabaseConnection,
+ mut conn: DatabaseConnection,
) -> Result<Json<ListMatchesResponse>, StatusCode> {
let requested_count = std::cmp::min(
params.count.unwrap_or(DEFAULT_NUM_RETURNED_MATCHES),
@@ -68,7 +68,7 @@ pub async fn list_recent_matches(
let matches_result = match params.bot {
Some(bot_name) => {
- let bot = db::bots::find_bot_by_name(&bot_name, &conn)
+ let bot = db::bots::find_bot_by_name(&bot_name, &mut conn)
.map_err(|_| StatusCode::BAD_REQUEST)?;
matches::list_bot_matches(
bot.id,
@@ -76,10 +76,10 @@ pub async fn list_recent_matches(
count,
params.before,
params.after,
- &conn,
+ &mut conn,
)
}
- None => matches::list_public_matches(count, params.before, params.after, &conn),
+ None => matches::list_public_matches(count, params.before, params.after, &mut conn),
};
let mut matches = matches_result.map_err(|_| StatusCode::BAD_REQUEST)?;
@@ -119,9 +119,9 @@ pub fn match_data_to_api(data: matches::FullMatchData) -> ApiMatch {
pub async fn get_match_data(
Path(match_id): Path<i32>,
- conn: DatabaseConnection,
+ mut conn: DatabaseConnection,
) -> Result<Json<ApiMatch>, StatusCode> {
- let match_data = matches::find_match(match_id, &conn)
+ let match_data = matches::find_match(match_id, &mut conn)
.map_err(|_| StatusCode::NOT_FOUND)
.map(match_data_to_api)?;
Ok(Json(match_data))
@@ -129,11 +129,11 @@ pub async fn get_match_data(
pub async fn get_match_log(
Path(match_id): Path<i32>,
- conn: DatabaseConnection,
+ mut conn: DatabaseConnection,
Extension(config): Extension<Arc<GlobalConfig>>,
) -> Result<Vec<u8>, StatusCode> {
let match_base =
- matches::find_match_base(match_id, &conn).map_err(|_| StatusCode::NOT_FOUND)?;
+ matches::find_match_base(match_id, &mut conn).map_err(|_| StatusCode::NOT_FOUND)?;
let log_path = PathBuf::from(&config.match_logs_directory).join(&match_base.log_path);
let log_contents = std::fs::read(log_path).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(log_contents)
diff --git a/planetwars-server/src/routes/users.rs b/planetwars-server/src/routes/users.rs
index 264e5b9..d072d0a 100644
--- a/planetwars-server/src/routes/users.rs
+++ b/planetwars-server/src/routes/users.rs
@@ -23,13 +23,13 @@ where
type Rejection = (StatusCode, String);
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
- let conn = DatabaseConnection::from_request(req).await?;
+ let mut conn = DatabaseConnection::from_request(req).await?;
let TypedHeader(Authorization(bearer)) = AuthorizationHeader::from_request(req)
.await
.map_err(|_| (StatusCode::UNAUTHORIZED, "".to_string()))?;
- let (_session, user) = sessions::find_user_by_session(bearer.token(), &conn)
+ let (_session, user) = sessions::find_user_by_session(bearer.token(), &mut conn)
.map_err(|_| (StatusCode::UNAUTHORIZED, "".to_string()))?;
Ok(user)
@@ -66,7 +66,7 @@ pub enum RegistrationError {
}
impl RegistrationParams {
- fn validate(&self, conn: &DatabaseConnection) -> Result<(), RegistrationError> {
+ fn validate(&self, conn: &mut DatabaseConnection) -> Result<(), RegistrationError> {
let mut errors = Vec::new();
// TODO: do we want to support cased usernames?
@@ -95,7 +95,7 @@ impl RegistrationParams {
errors.push("that username is not allowed".to_string());
}
- if users::find_user_by_name(&self.username, &conn).is_ok() {
+ if users::find_user_by_name(&self.username, conn).is_ok() {
errors.push("username is already taken".to_string());
}
@@ -137,16 +137,16 @@ impl IntoResponse for RegistrationError {
}
pub async fn register(
- conn: DatabaseConnection,
+ mut conn: DatabaseConnection,
params: Json<RegistrationParams>,
) -> Result<Json<UserData>, RegistrationError> {
- params.validate(&conn)?;
+ params.validate(&mut conn)?;
let credentials = Credentials {
username: &params.username,
password: &params.password,
};
- let user = users::create_user(&credentials, &conn)?;
+ let user = users::create_user(&credentials, &mut conn)?;
Ok(Json(user.into()))
}
@@ -156,18 +156,18 @@ pub struct LoginParams {
pub password: String,
}
-pub async fn login(conn: DatabaseConnection, params: Json<LoginParams>) -> Response {
+pub async fn login(mut conn: DatabaseConnection, params: Json<LoginParams>) -> Response {
let credentials = Credentials {
username: &params.username,
password: &params.password,
};
// TODO: handle failures
- let authenticated = users::authenticate_user(&credentials, &conn);
+ let authenticated = users::authenticate_user(&credentials, &mut conn);
match authenticated {
None => StatusCode::FORBIDDEN.into_response(),
Some(user) => {
- let session = sessions::create_session(&user, &conn);
+ let session = sessions::create_session(&user, &mut conn);
let user_data: UserData = user.into();
let headers = [("Token", &session.token)];
diff --git a/planetwars-server/src/schema.rs b/planetwars-server/src/schema.rs
index adc6555..27ebebe 100644
--- a/planetwars-server/src/schema.rs
+++ b/planetwars-server/src/schema.rs
@@ -1,7 +1,15 @@
// This file is autogenerated by diesel
#![allow(unused_imports)]
-table! {
+// @generated automatically by Diesel CLI.
+
+pub mod sql_types {
+ #[derive(diesel::sql_types::SqlType)]
+ #[diesel(postgres_type(name = "match_state"))]
+ pub struct MatchState;
+}
+
+diesel::table! {
use diesel::sql_types::*;
use crate::db_types::*;
@@ -14,7 +22,7 @@ table! {
}
}
-table! {
+diesel::table! {
use diesel::sql_types::*;
use crate::db_types::*;
@@ -26,7 +34,7 @@ table! {
}
}
-table! {
+diesel::table! {
use diesel::sql_types::*;
use crate::db_types::*;
@@ -37,7 +45,7 @@ table! {
}
}
-table! {
+diesel::table! {
use diesel::sql_types::*;
use crate::db_types::*;
@@ -48,13 +56,14 @@ table! {
}
}
-table! {
+diesel::table! {
use diesel::sql_types::*;
use crate::db_types::*;
+ use super::sql_types::MatchState;
matches (id) {
id -> Int4,
- state -> Match_state,
+ state -> MatchState,
log_path -> Text,
created_at -> Timestamp,
winner -> Nullable<Int4>,
@@ -63,7 +72,7 @@ table! {
}
}
-table! {
+diesel::table! {
use diesel::sql_types::*;
use crate::db_types::*;
@@ -73,7 +82,7 @@ table! {
}
}
-table! {
+diesel::table! {
use diesel::sql_types::*;
use crate::db_types::*;
@@ -84,7 +93,7 @@ table! {
}
}
-table! {
+diesel::table! {
use diesel::sql_types::*;
use crate::db_types::*;
@@ -96,14 +105,14 @@ table! {
}
}
-joinable!(bots -> users (owner_id));
-joinable!(match_players -> bot_versions (bot_version_id));
-joinable!(match_players -> matches (match_id));
-joinable!(matches -> maps (map_id));
-joinable!(ratings -> bots (bot_id));
-joinable!(sessions -> users (user_id));
+diesel::joinable!(bots -> users (owner_id));
+diesel::joinable!(match_players -> bot_versions (bot_version_id));
+diesel::joinable!(match_players -> matches (match_id));
+diesel::joinable!(matches -> maps (map_id));
+diesel::joinable!(ratings -> bots (bot_id));
+diesel::joinable!(sessions -> users (user_id));
-allow_tables_to_appear_in_same_query!(
+diesel::allow_tables_to_appear_in_same_query!(
bot_versions,
bots,
maps,
diff --git a/planetwars-server/tests/integration.rs b/planetwars-server/tests/integration.rs
index ad63e0e..83de912 100644
--- a/planetwars-server/tests/integration.rs
+++ b/planetwars-server/tests/integration.rs
@@ -27,7 +27,7 @@ fn create_subdir<P: AsRef<Path>>(base_path: &Path, p: P) -> io::Result<String> {
Ok(dir_path_string)
}
-fn clear_database(conn: &PgConnection) {
+fn clear_database(conn: &mut PgConnection) {
diesel::sql_query(
"TRUNCATE TABLE
bots,
@@ -45,20 +45,20 @@ fn clear_database(conn: &PgConnection) {
/// Setup a simple text fixture, having simplebot and the hex map.
/// This is enough to run a simple match.
-fn setup_simple_fixture(db_conn: &PgConnection, config: &GlobalConfig) {
+fn setup_simple_fixture(db_conn: &mut PgConnection, config: &GlobalConfig) {
let bot = db::bots::create_bot(
&db::bots::NewBot {
owner_id: None,
name: "simplebot",
},
- &db_conn,
+ db_conn,
)
.expect("could not create simplebot");
let simplebot_code = std::fs::read_to_string("../simplebot/simplebot.py")
.expect("could not read simplebot code");
let _bot_version =
- modules::bots::save_code_string(&simplebot_code, Some(bot.id), &db_conn, &config)
+ modules::bots::save_code_string(&simplebot_code, Some(bot.id), db_conn, &config)
.expect("could not save bot version");
std::fs::copy(
@@ -71,7 +71,7 @@ fn setup_simple_fixture(db_conn: &PgConnection, config: &GlobalConfig) {
name: "hex",
file_path: "hex.json",
},
- &db_conn,
+ db_conn,
)
.expect("could not save map");
}
@@ -119,14 +119,14 @@ impl<'a> TestApp<'a> {
async fn with_db_conn<F, R>(&self, function: F) -> R
where
- F: FnOnce(&PgConnection) -> R,
+ F: FnOnce(&mut PgConnection) -> R,
{
- let db_conn = self
+ let mut db_conn = self
.db_pool
.get()
.await
.expect("could not get db connection");
- function(&db_conn)
+ function(&mut db_conn)
}
}