diff options
Diffstat (limited to 'planetwars-server')
-rw-r--r-- | planetwars-server/Cargo.toml | 5 | ||||
-rw-r--r-- | planetwars-server/src/lib.rs | 20 | ||||
-rw-r--r-- | planetwars-server/src/modules/mod.rs | 1 | ||||
-rw-r--r-- | planetwars-server/src/modules/ranking.rs | 295 | ||||
-rw-r--r-- | planetwars-server/src/modules/registry.rs | 408 | ||||
-rw-r--r-- | planetwars-server/src/routes/users.rs | 6 |
6 files changed, 701 insertions, 34 deletions
diff --git a/planetwars-server/Cargo.toml b/planetwars-server/Cargo.toml index 0ceabbc..f2444c1 100644 --- a/planetwars-server/Cargo.toml +++ b/planetwars-server/Cargo.toml @@ -6,10 +6,11 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +futures = "0.3" tokio = { version = "1.15", features = ["full"] } tokio-stream = "0.1.9" hyper = "0.14" -axum = { version = "0.4", features = ["json", "headers", "multipart"] } +axum = { version = "0.5", features = ["json", "headers", "multipart"] } diesel = { version = "1.4.4", features = ["postgres", "chrono"] } diesel-derive-enum = { version = "1.1", features = ["postgres"] } bb8 = "0.7" @@ -27,6 +28,8 @@ toml = "0.5" planetwars-matchrunner = { path = "../planetwars-matchrunner" } config = { version = "0.12", features = ["toml"] } thiserror = "1.0.31" +sha2 = "0.10" +tokio-util = { version="0.7.3", features=["io"] } prost = "0.10" tonic = "0.7.2" diff --git a/planetwars-server/src/lib.rs b/planetwars-server/src/lib.rs index 28d7a76..9c9a03c 100644 --- a/planetwars-server/src/lib.rs +++ b/planetwars-server/src/lib.rs @@ -16,6 +16,7 @@ use bb8_diesel::{self, DieselConnectionManager}; use config::ConfigError; use diesel::{Connection, PgConnection}; use modules::ranking::run_ranker; +use modules::registry::registry_service; use serde::Deserialize; use axum::{ @@ -23,7 +24,7 @@ use axum::{ extract::{Extension, FromRequest, RequestParts}, http::StatusCode, routing::{get, post}, - AddExtensionLayer, Router, + Router, }; // TODO: make these configurable @@ -104,15 +105,30 @@ pub fn get_config() -> Result<Configuration, ConfigError> { .try_deserialize() } +async fn run_registry(db_pool: DbPool) { + // TODO: put in config + let addr = SocketAddr::from(([127, 0, 0, 1], 9001)); + + axum::Server::bind(&addr) + .serve( + registry_service() + .layer(Extension(db_pool)) + .into_make_service(), + ) + .await + .unwrap(); +} + pub async fn run_app() { let configuration = get_config().unwrap(); let db_pool = prepare_db(&configuration.database_url).await; tokio::spawn(run_ranker(db_pool.clone())); + tokio::spawn(run_registry(db_pool.clone())); let api_service = Router::new() .nest("/api", api()) - .layer(AddExtensionLayer::new(db_pool)) + .layer(Extension(db_pool)) .into_make_service(); // TODO: put in config diff --git a/planetwars-server/src/modules/mod.rs b/planetwars-server/src/modules/mod.rs index 43c2507..1200f9d 100644 --- a/planetwars-server/src/modules/mod.rs +++ b/planetwars-server/src/modules/mod.rs @@ -4,3 +4,4 @@ pub mod bot_api; pub mod bots; pub mod matches; pub mod ranking; +pub mod registry; diff --git a/planetwars-server/src/modules/ranking.rs b/planetwars-server/src/modules/ranking.rs index d83debb..72156ee 100644 --- a/planetwars-server/src/modules/ranking.rs +++ b/planetwars-server/src/modules/ranking.rs @@ -1,15 +1,15 @@ use crate::{db::bots::Bot, DbPool}; use crate::db; +use diesel::{PgConnection, QueryResult}; use crate::modules::matches::{MatchPlayer, RunMatch}; use rand::seq::SliceRandom; -use std::time::Duration; +use std::collections::HashMap; +use std::mem; +use std::time::{Duration, Instant}; use tokio; const RANKER_INTERVAL: u64 = 60; -const START_RATING: f64 = 0.0; -const SCALE: f64 = 100.0; -const MAX_UPDATE: f64 = 0.1; pub async fn run_ranker(db_pool: DbPool) { // TODO: make this configurable @@ -31,6 +31,7 @@ pub async fn run_ranker(db_pool: DbPool) { bots.choose_multiple(&mut rng, 2).cloned().collect() }; play_ranking_match(selected_bots, db_pool.clone()).await; + recalculate_ratings(&db_conn).expect("could not recalculate ratings"); } } @@ -52,39 +53,277 @@ async fn play_ranking_match(selected_bots: Vec<Bot>, db_pool: DbPool) { run_match .store_in_database(&db_conn) .expect("could not store match in db"); - let outcome = run_match + run_match .spawn(db_pool.clone()) .await .expect("running match failed"); +} - let mut ratings = Vec::new(); - for bot in &selected_bots { - let rating = db::ratings::get_rating(bot.id, &db_conn) - .expect("could not get bot rating") - .unwrap_or(START_RATING); - ratings.push(rating); +fn recalculate_ratings(db_conn: &PgConnection) -> QueryResult<()> { + let start = Instant::now(); + let match_stats = fetch_match_stats(db_conn)?; + let ratings = estimate_ratings_from_stats(match_stats); + + for (bot_id, rating) in ratings { + db::ratings::set_rating(bot_id, rating, db_conn).expect("could not update bot rating"); + } + let elapsed = Instant::now() - start; + // TODO: set up proper logging infrastructure + println!("computed ratings in {} ms", elapsed.subsec_millis()); + Ok(()) +} + +#[derive(Default)] +struct MatchStats { + total_score: f64, + num_matches: usize, +} + +fn fetch_match_stats(db_conn: &PgConnection) -> QueryResult<HashMap<(i32, i32), MatchStats>> { + let matches = db::matches::list_matches(db_conn)?; + + let mut match_stats = HashMap::<(i32, i32), MatchStats>::new(); + for m in matches { + if m.match_players.len() != 2 { + continue; + } + let (mut a_id, mut b_id) = match (&m.match_players[0].bot, &m.match_players[1].bot) { + (Some(ref a), Some(ref b)) => (a.id, b.id), + _ => continue, + }; + // score of player a + let mut score = match m.base.winner { + None => 0.5, + Some(0) => 1.0, + Some(1) => 0.0, + _ => panic!("invalid winner"), + }; + + // put players in canonical order: smallest id first + if b_id < a_id { + mem::swap(&mut a_id, &mut b_id); + score = 1.0 - score; + } + + let entry = match_stats.entry((a_id, b_id)).or_default(); + entry.num_matches += 1; + entry.total_score += score; } + Ok(match_stats) +} + +/// Tokenizes player ids to a set of consecutive numbers +struct PlayerTokenizer { + id_to_ix: HashMap<i32, usize>, + ids: Vec<i32>, +} - // simple elo rating +impl PlayerTokenizer { + fn new() -> Self { + PlayerTokenizer { + id_to_ix: HashMap::new(), + ids: Vec::new(), + } + } - let scores = match outcome.winner { - None => vec![0.5; 2], - Some(player_num) => { - // TODO: please get rid of this offset - let player_ix = player_num - 1; - let mut scores = vec![0.0; 2]; - scores[player_ix] = 1.0; - scores + fn tokenize(&mut self, id: i32) -> usize { + match self.id_to_ix.get(&id) { + Some(&ix) => ix, + None => { + let ix = self.ids.len(); + self.ids.push(id); + self.id_to_ix.insert(id, ix); + ix + } } - }; + } + + fn detokenize(&self, ix: usize) -> i32 { + self.ids[ix] + } + + fn player_count(&self) -> usize { + self.ids.len() + } +} + +fn sigmoid(logit: f64) -> f64 { + 1.0 / (1.0 + (-logit).exp()) +} + +fn estimate_ratings_from_stats(match_stats: HashMap<(i32, i32), MatchStats>) -> Vec<(i32, f64)> { + // map player ids to player indexes in the ratings array + let mut input_records = Vec::<RatingInputRecord>::with_capacity(match_stats.len()); + let mut player_tokenizer = PlayerTokenizer::new(); + + for ((a_id, b_id), stats) in match_stats.into_iter() { + input_records.push(RatingInputRecord { + p1_ix: player_tokenizer.tokenize(a_id), + p2_ix: player_tokenizer.tokenize(b_id), + score: stats.total_score / stats.num_matches as f64, + weight: stats.num_matches as f64, + }) + } + + let mut ratings = vec![0f64; player_tokenizer.player_count()]; + // TODO: fetch these from config + let params = OptimizeRatingsParams::default(); + optimize_ratings(&mut ratings, &input_records, ¶ms); + + ratings + .into_iter() + .enumerate() + .map(|(ix, rating)| { + ( + player_tokenizer.detokenize(ix), + rating * 100f64 / 10f64.ln(), + ) + }) + .collect() +} + +struct RatingInputRecord { + /// index of first player + p1_ix: usize, + /// index of secord player + p2_ix: usize, + /// score of player 1 (= 1 - score of player 2) + score: f64, + /// weight of this record + weight: f64, +} + +struct OptimizeRatingsParams { + tolerance: f64, + learning_rate: f64, + max_iterations: usize, + regularization_weight: f64, +} + +impl Default for OptimizeRatingsParams { + fn default() -> Self { + OptimizeRatingsParams { + tolerance: 10f64.powi(-8), + learning_rate: 0.1, + max_iterations: 10_000, + regularization_weight: 10.0, + } + } +} + +fn optimize_ratings( + ratings: &mut [f64], + input_records: &[RatingInputRecord], + params: &OptimizeRatingsParams, +) { + let total_weight = + params.regularization_weight + input_records.iter().map(|r| r.weight).sum::<f64>(); + + for _iteration in 0..params.max_iterations { + let mut gradients = vec![0f64; ratings.len()]; + + // calculate gradients + for record in input_records.iter() { + let predicted = sigmoid(ratings[record.p1_ix] - ratings[record.p2_ix]); + let gradient = record.weight * (predicted - record.score); + gradients[record.p1_ix] += gradient; + gradients[record.p2_ix] -= gradient; + } + + // apply update step + let mut converged = true; + for (rating, gradient) in ratings.iter_mut().zip(&gradients) { + let update = params.learning_rate * (gradient + params.regularization_weight * *rating) + / total_weight; + if update > params.tolerance { + converged = false; + } + *rating -= update; + } + + if converged { + break; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn is_close(a: f64, b: f64) -> bool { + (a - b).abs() < 10f64.powi(-6) + } + + #[test] + fn test_optimize_ratings() { + let input_records = vec![RatingInputRecord { + p1_ix: 0, + p2_ix: 1, + score: 0.8, + weight: 1.0, + }]; + + let mut ratings = vec![0.0; 2]; + optimize_ratings( + &mut ratings, + &input_records, + &OptimizeRatingsParams { + regularization_weight: 0.0, + ..Default::default() + }, + ); + assert!(is_close(sigmoid(ratings[0] - ratings[1]), 0.8)); + } + + #[test] + fn test_optimize_ratings_weight() { + let input_records = vec![ + RatingInputRecord { + p1_ix: 0, + p2_ix: 1, + score: 1.0, + weight: 1.0, + }, + RatingInputRecord { + p1_ix: 1, + p2_ix: 0, + score: 1.0, + weight: 3.0, + }, + ]; + + let mut ratings = vec![0.0; 2]; + optimize_ratings( + &mut ratings, + &input_records, + &OptimizeRatingsParams { + regularization_weight: 0.0, + ..Default::default() + }, + ); + assert!(is_close(sigmoid(ratings[0] - ratings[1]), 0.25)); + } - for i in 0..2 { - let j = 1 - i; + #[test] + fn test_optimize_ratings_regularization() { + let input_records = vec![RatingInputRecord { + p1_ix: 0, + p2_ix: 1, + score: 0.8, + weight: 100.0, + }]; - let scaled_difference = (ratings[j] - ratings[i]) / SCALE; - let expected = 1.0 / (1.0 + 10f64.powf(scaled_difference)); - let new_rating = ratings[i] + MAX_UPDATE * (scores[i] - expected); - db::ratings::set_rating(selected_bots[i].id, new_rating, &db_conn) - .expect("could not update bot rating"); + let mut ratings = vec![0.0; 2]; + optimize_ratings( + &mut ratings, + &input_records, + &OptimizeRatingsParams { + regularization_weight: 1.0, + ..Default::default() + }, + ); + let predicted = sigmoid(ratings[0] - ratings[1]); + assert!(0.5 < predicted && predicted < 0.8); } } diff --git a/planetwars-server/src/modules/registry.rs b/planetwars-server/src/modules/registry.rs new file mode 100644 index 0000000..c8ec4fa --- /dev/null +++ b/planetwars-server/src/modules/registry.rs @@ -0,0 +1,408 @@ +// TODO: this module is functional, but it needs a good refactor for proper error handling. + +use axum::body::{Body, StreamBody}; +use axum::extract::{BodyStream, FromRequest, Path, Query, RequestParts, TypedHeader}; +use axum::headers::authorization::Basic; +use axum::headers::Authorization; +use axum::response::{IntoResponse, Response}; +use axum::routing::{get, head, post, put}; +use axum::{async_trait, Router}; +use futures::StreamExt; +use hyper::StatusCode; +use serde::Serialize; +use sha2::{Digest, Sha256}; +use std::path::PathBuf; +use tokio::io::AsyncWriteExt; +use tokio_util::io::ReaderStream; + +use crate::util::gen_alphanumeric; +use crate::{db, DatabaseConnection}; + +use crate::db::users::{authenticate_user, Credentials, User}; + +// TODO: put this in a config file +const REGISTRY_PATH: &str = "./data/registry"; + +pub fn registry_service() -> Router { + Router::new() + // The docker API requires this trailing slash + .nest("/v2/", registry_api_v2()) +} + +fn registry_api_v2() -> Router { + Router::new() + .route("/", get(get_root)) + .route( + "/:name/manifests/:reference", + get(get_manifest).put(put_manifest), + ) + .route( + "/:name/blobs/:digest", + head(check_blob_exists).get(get_blob), + ) + .route("/:name/blobs/uploads/", post(create_upload)) + .route( + "/:name/blobs/uploads/:uuid", + put(put_upload).patch(patch_upload), + ) +} + +const ADMIN_USERNAME: &str = "admin"; +// TODO: put this in some configuration +const ADMIN_PASSWORD: &str = "supersecretpassword"; + +type AuthorizationHeader = TypedHeader<Authorization<Basic>>; + +enum RegistryAuth { + User(User), + Admin, +} + +enum RegistryAuthError { + NoAuthHeader, + InvalidCredentials, +} + +impl IntoResponse for RegistryAuthError { + fn into_response(self) -> Response { + // TODO: create enum for registry errors + let err = RegistryErrors { + errors: vec![RegistryError { + code: "UNAUTHORIZED".to_string(), + message: "please log in".to_string(), + detail: serde_json::Value::Null, + }], + }; + + ( + StatusCode::UNAUTHORIZED, + [ + ("Docker-Distribution-API-Version", "registry/2.0"), + ("WWW-Authenticate", "Basic"), + ], + serde_json::to_string(&err).unwrap(), + ) + .into_response() + } +} + +#[async_trait] +impl<B> FromRequest<B> for RegistryAuth +where + B: Send, +{ + type Rejection = RegistryAuthError; + + async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { + let TypedHeader(Authorization(basic)) = AuthorizationHeader::from_request(req) + .await + .map_err(|_| RegistryAuthError::NoAuthHeader)?; + + // TODO: Into<Credentials> would be nice + let credentials = Credentials { + username: basic.username(), + password: basic.password(), + }; + + if credentials.username == ADMIN_USERNAME { + if credentials.password == ADMIN_PASSWORD { + Ok(RegistryAuth::Admin) + } else { + Err(RegistryAuthError::InvalidCredentials) + } + } else { + let db_conn = DatabaseConnection::from_request(req).await.unwrap(); + let user = authenticate_user(&credentials, &db_conn) + .ok_or(RegistryAuthError::InvalidCredentials)?; + + Ok(RegistryAuth::User(user)) + } + } +} + +// Since async file io just calls spawn_blocking internally, it does not really make sense +// to make this an async function +fn file_sha256_digest(path: &std::path::Path) -> std::io::Result<String> { + let mut file = std::fs::File::open(path)?; + let mut hasher = Sha256::new(); + let _n = std::io::copy(&mut file, &mut hasher)?; + Ok(format!("{:x}", hasher.finalize())) +} + +/// Get the index of the last byte in a file +async fn last_byte_pos(file: &tokio::fs::File) -> std::io::Result<u64> { + let n_bytes = file.metadata().await?.len(); + let pos = if n_bytes == 0 { 0 } else { n_bytes - 1 }; + Ok(pos) +} + +async fn get_root(_auth: RegistryAuth) -> impl IntoResponse { + // root should return 200 OK to confirm api compliance + Response::builder() + .status(StatusCode::OK) + .header("Docker-Distribution-API-Version", "registry/2.0") + .body(Body::empty()) + .unwrap() +} + +#[derive(Serialize)] +pub struct RegistryErrors { + errors: Vec<RegistryError>, +} + +#[derive(Serialize)] +pub struct RegistryError { + code: String, + message: String, + detail: serde_json::Value, +} + +async fn check_blob_exists( + db_conn: DatabaseConnection, + auth: RegistryAuth, + Path((repository_name, raw_digest)): Path<(String, String)>, +) -> Result<impl IntoResponse, StatusCode> { + check_access(&repository_name, &auth, &db_conn)?; + + let digest = raw_digest.strip_prefix("sha256:").unwrap(); + let blob_path = PathBuf::from(REGISTRY_PATH).join("sha256").join(&digest); + if blob_path.exists() { + let metadata = std::fs::metadata(&blob_path).unwrap(); + Ok((StatusCode::OK, [("Content-Length", metadata.len())])) + } else { + Err(StatusCode::NOT_FOUND) + } +} + +async fn get_blob( + db_conn: DatabaseConnection, + auth: RegistryAuth, + Path((repository_name, raw_digest)): Path<(String, String)>, +) -> Result<impl IntoResponse, StatusCode> { + check_access(&repository_name, &auth, &db_conn)?; + + let digest = raw_digest.strip_prefix("sha256:").unwrap(); + let blob_path = PathBuf::from(REGISTRY_PATH).join("sha256").join(&digest); + if !blob_path.exists() { + return Err(StatusCode::NOT_FOUND); + } + let file = tokio::fs::File::open(&blob_path).await.unwrap(); + let reader_stream = ReaderStream::new(file); + let stream_body = StreamBody::new(reader_stream); + Ok(stream_body) +} + +async fn create_upload( + db_conn: DatabaseConnection, + auth: RegistryAuth, + Path(repository_name): Path<String>, +) -> Result<impl IntoResponse, StatusCode> { + check_access(&repository_name, &auth, &db_conn)?; + + let uuid = gen_alphanumeric(16); + tokio::fs::File::create(PathBuf::from(REGISTRY_PATH).join("uploads").join(&uuid)) + .await + .unwrap(); + + Ok(Response::builder() + .status(StatusCode::ACCEPTED) + .header( + "Location", + format!("/v2/{}/blobs/uploads/{}", repository_name, uuid), + ) + .header("Docker-Upload-UUID", uuid) + .header("Range", "bytes=0-0") + .body(Body::empty()) + .unwrap()) +} + +async fn patch_upload( + db_conn: DatabaseConnection, + auth: RegistryAuth, + Path((repository_name, uuid)): Path<(String, String)>, + mut stream: BodyStream, +) -> Result<impl IntoResponse, StatusCode> { + check_access(&repository_name, &auth, &db_conn)?; + + // TODO: support content range header in request + let upload_path = PathBuf::from(REGISTRY_PATH).join("uploads").join(&uuid); + let mut file = tokio::fs::OpenOptions::new() + .read(false) + .write(true) + .append(true) + .create(false) + .open(upload_path) + .await + .unwrap(); + while let Some(Ok(chunk)) = stream.next().await { + file.write_all(&chunk).await.unwrap(); + } + + let last_byte = last_byte_pos(&file).await.unwrap(); + + Ok(Response::builder() + .status(StatusCode::ACCEPTED) + .header( + "Location", + format!("/v2/{}/blobs/uploads/{}", repository_name, uuid), + ) + .header("Docker-Upload-UUID", uuid) + // range indicating current progress of the upload + .header("Range", format!("0-{}", last_byte)) + .body(Body::empty()) + .unwrap()) +} + +use serde::Deserialize; +#[derive(Deserialize)] +struct UploadParams { + digest: String, +} + +async fn put_upload( + db_conn: DatabaseConnection, + auth: RegistryAuth, + Path((repository_name, uuid)): Path<(String, String)>, + Query(params): Query<UploadParams>, + mut stream: BodyStream, +) -> Result<impl IntoResponse, StatusCode> { + check_access(&repository_name, &auth, &db_conn)?; + + let upload_path = PathBuf::from(REGISTRY_PATH).join("uploads").join(&uuid); + let mut file = tokio::fs::OpenOptions::new() + .read(false) + .write(true) + .append(true) + .create(false) + .open(&upload_path) + .await + .unwrap(); + + let range_begin = last_byte_pos(&file).await.unwrap(); + while let Some(Ok(chunk)) = stream.next().await { + file.write_all(&chunk).await.unwrap(); + } + file.flush().await.unwrap(); + let range_end = last_byte_pos(&file).await.unwrap(); + + let expected_digest = params.digest.strip_prefix("sha256:").unwrap(); + let digest = file_sha256_digest(&upload_path).unwrap(); + if digest != expected_digest { + // TODO: return a docker error body + return Err(StatusCode::BAD_REQUEST); + } + + let target_path = PathBuf::from(REGISTRY_PATH).join("sha256").join(&digest); + tokio::fs::rename(&upload_path, &target_path).await.unwrap(); + + Ok(Response::builder() + .status(StatusCode::CREATED) + .header( + "Location", + format!("/v2/{}/blobs/{}", repository_name, digest), + ) + .header("Docker-Upload-UUID", uuid) + // content range for bytes that were in the body of this request + .header("Content-Range", format!("{}-{}", range_begin, range_end)) + .header("Docker-Content-Digest", params.digest) + .body(Body::empty()) + .unwrap()) +} + +async fn get_manifest( + db_conn: DatabaseConnection, + auth: RegistryAuth, + Path((repository_name, reference)): Path<(String, String)>, +) -> Result<impl IntoResponse, StatusCode> { + check_access(&repository_name, &auth, &db_conn)?; + + let manifest_path = PathBuf::from(REGISTRY_PATH) + .join("manifests") + .join(&repository_name) + .join(&reference) + .with_extension("json"); + let data = tokio::fs::read(&manifest_path).await.unwrap(); + + let manifest: serde_json::Map<String, serde_json::Value> = + serde_json::from_slice(&data).unwrap(); + let media_type = manifest.get("mediaType").unwrap().as_str().unwrap(); + Ok(Response::builder() + .status(StatusCode::OK) + .header("Content-Type", media_type) + .body(axum::body::Full::from(data)) + .unwrap()) +} + +async fn put_manifest( + db_conn: DatabaseConnection, + auth: RegistryAuth, + Path((repository_name, reference)): Path<(String, String)>, + mut stream: BodyStream, +) -> Result<impl IntoResponse, StatusCode> { + check_access(&repository_name, &auth, &db_conn)?; + + let repository_dir = PathBuf::from(REGISTRY_PATH) + .join("manifests") + .join(&repository_name); + + tokio::fs::create_dir_all(&repository_dir).await.unwrap(); + + let mut hasher = Sha256::new(); + let manifest_path = repository_dir.join(&reference).with_extension("json"); + { + let mut file = tokio::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(&manifest_path) + .await + .unwrap(); + while let Some(Ok(chunk)) = stream.next().await { + hasher.update(&chunk); + file.write_all(&chunk).await.unwrap(); + } + } + let digest = hasher.finalize(); + // TODO: store content-adressable manifests separately + let content_digest = format!("sha256:{:x}", digest); + let digest_path = repository_dir.join(&content_digest).with_extension("json"); + tokio::fs::copy(manifest_path, digest_path).await.unwrap(); + + Ok(Response::builder() + .status(StatusCode::CREATED) + .header( + "Location", + format!("/v2/{}/manifests/{}", repository_name, reference), + ) + .header("Docker-Content-Digest", content_digest) + .body(Body::empty()) + .unwrap()) +} + +/// Ensure that the accessed repository exists +/// and the user is allowed to access ti +fn check_access( + repository_name: &str, + auth: &RegistryAuth, + db_conn: &DatabaseConnection, +) -> Result<(), StatusCode> { + use diesel::OptionalExtension; + + // TODO: it would be nice to provide the found repository + // to the route handlers + let bot = db::bots::find_bot_by_name(repository_name, db_conn) + .optional() + .expect("could not run query") + .ok_or(StatusCode::NOT_FOUND)?; + + match &auth { + RegistryAuth::Admin => Ok(()), + RegistryAuth::User(user) => { + if bot.owner_id == Some(user.id) { + Ok(()) + } else { + Err(StatusCode::FORBIDDEN) + } + } + } +} diff --git a/planetwars-server/src/routes/users.rs b/planetwars-server/src/routes/users.rs index 54ddd09..1989904 100644 --- a/planetwars-server/src/routes/users.rs +++ b/planetwars-server/src/routes/users.rs @@ -5,7 +5,7 @@ use axum::extract::{FromRequest, RequestParts, TypedHeader}; use axum::headers::authorization::Bearer; use axum::headers::Authorization; use axum::http::StatusCode; -use axum::response::{Headers, IntoResponse, Response}; +use axum::response::{IntoResponse, Response}; use axum::{async_trait, Json}; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -163,9 +163,9 @@ pub async fn login(conn: DatabaseConnection, params: Json<LoginParams>) -> Respo Some(user) => { let session = sessions::create_session(&user, &conn); let user_data: UserData = user.into(); - let headers = Headers(vec![("Token", &session.token)]); + let headers = [("Token", &session.token)]; - (headers, Json(user_data)).into_response() + (StatusCode::OK, headers, Json(user_data)).into_response() } } } |