aboutsummaryrefslogtreecommitdiff
path: root/planetwars-server/src
diff options
context:
space:
mode:
Diffstat (limited to 'planetwars-server/src')
-rw-r--r--planetwars-server/src/lib.rs20
-rw-r--r--planetwars-server/src/modules/mod.rs1
-rw-r--r--planetwars-server/src/modules/ranking.rs295
-rw-r--r--planetwars-server/src/modules/registry.rs408
-rw-r--r--planetwars-server/src/routes/users.rs6
5 files changed, 697 insertions, 33 deletions
diff --git a/planetwars-server/src/lib.rs b/planetwars-server/src/lib.rs
index 28d7a76..9c9a03c 100644
--- a/planetwars-server/src/lib.rs
+++ b/planetwars-server/src/lib.rs
@@ -16,6 +16,7 @@ use bb8_diesel::{self, DieselConnectionManager};
use config::ConfigError;
use diesel::{Connection, PgConnection};
use modules::ranking::run_ranker;
+use modules::registry::registry_service;
use serde::Deserialize;
use axum::{
@@ -23,7 +24,7 @@ use axum::{
extract::{Extension, FromRequest, RequestParts},
http::StatusCode,
routing::{get, post},
- AddExtensionLayer, Router,
+ Router,
};
// TODO: make these configurable
@@ -104,15 +105,30 @@ pub fn get_config() -> Result<Configuration, ConfigError> {
.try_deserialize()
}
+async fn run_registry(db_pool: DbPool) {
+ // TODO: put in config
+ let addr = SocketAddr::from(([127, 0, 0, 1], 9001));
+
+ axum::Server::bind(&addr)
+ .serve(
+ registry_service()
+ .layer(Extension(db_pool))
+ .into_make_service(),
+ )
+ .await
+ .unwrap();
+}
+
pub async fn run_app() {
let configuration = get_config().unwrap();
let db_pool = prepare_db(&configuration.database_url).await;
tokio::spawn(run_ranker(db_pool.clone()));
+ tokio::spawn(run_registry(db_pool.clone()));
let api_service = Router::new()
.nest("/api", api())
- .layer(AddExtensionLayer::new(db_pool))
+ .layer(Extension(db_pool))
.into_make_service();
// TODO: put in config
diff --git a/planetwars-server/src/modules/mod.rs b/planetwars-server/src/modules/mod.rs
index 43c2507..1200f9d 100644
--- a/planetwars-server/src/modules/mod.rs
+++ b/planetwars-server/src/modules/mod.rs
@@ -4,3 +4,4 @@ pub mod bot_api;
pub mod bots;
pub mod matches;
pub mod ranking;
+pub mod registry;
diff --git a/planetwars-server/src/modules/ranking.rs b/planetwars-server/src/modules/ranking.rs
index d83debb..72156ee 100644
--- a/planetwars-server/src/modules/ranking.rs
+++ b/planetwars-server/src/modules/ranking.rs
@@ -1,15 +1,15 @@
use crate::{db::bots::Bot, DbPool};
use crate::db;
+use diesel::{PgConnection, QueryResult};
use crate::modules::matches::{MatchPlayer, RunMatch};
use rand::seq::SliceRandom;
-use std::time::Duration;
+use std::collections::HashMap;
+use std::mem;
+use std::time::{Duration, Instant};
use tokio;
const RANKER_INTERVAL: u64 = 60;
-const START_RATING: f64 = 0.0;
-const SCALE: f64 = 100.0;
-const MAX_UPDATE: f64 = 0.1;
pub async fn run_ranker(db_pool: DbPool) {
// TODO: make this configurable
@@ -31,6 +31,7 @@ pub async fn run_ranker(db_pool: DbPool) {
bots.choose_multiple(&mut rng, 2).cloned().collect()
};
play_ranking_match(selected_bots, db_pool.clone()).await;
+ recalculate_ratings(&db_conn).expect("could not recalculate ratings");
}
}
@@ -52,39 +53,277 @@ async fn play_ranking_match(selected_bots: Vec<Bot>, db_pool: DbPool) {
run_match
.store_in_database(&db_conn)
.expect("could not store match in db");
- let outcome = run_match
+ run_match
.spawn(db_pool.clone())
.await
.expect("running match failed");
+}
- let mut ratings = Vec::new();
- for bot in &selected_bots {
- let rating = db::ratings::get_rating(bot.id, &db_conn)
- .expect("could not get bot rating")
- .unwrap_or(START_RATING);
- ratings.push(rating);
+fn recalculate_ratings(db_conn: &PgConnection) -> QueryResult<()> {
+ let start = Instant::now();
+ let match_stats = fetch_match_stats(db_conn)?;
+ let ratings = estimate_ratings_from_stats(match_stats);
+
+ for (bot_id, rating) in ratings {
+ db::ratings::set_rating(bot_id, rating, db_conn).expect("could not update bot rating");
+ }
+ let elapsed = Instant::now() - start;
+ // TODO: set up proper logging infrastructure
+ println!("computed ratings in {} ms", elapsed.subsec_millis());
+ Ok(())
+}
+
+#[derive(Default)]
+struct MatchStats {
+ total_score: f64,
+ num_matches: usize,
+}
+
+fn fetch_match_stats(db_conn: &PgConnection) -> QueryResult<HashMap<(i32, i32), MatchStats>> {
+ let matches = db::matches::list_matches(db_conn)?;
+
+ let mut match_stats = HashMap::<(i32, i32), MatchStats>::new();
+ for m in matches {
+ if m.match_players.len() != 2 {
+ continue;
+ }
+ let (mut a_id, mut b_id) = match (&m.match_players[0].bot, &m.match_players[1].bot) {
+ (Some(ref a), Some(ref b)) => (a.id, b.id),
+ _ => continue,
+ };
+ // score of player a
+ let mut score = match m.base.winner {
+ None => 0.5,
+ Some(0) => 1.0,
+ Some(1) => 0.0,
+ _ => panic!("invalid winner"),
+ };
+
+ // put players in canonical order: smallest id first
+ if b_id < a_id {
+ mem::swap(&mut a_id, &mut b_id);
+ score = 1.0 - score;
+ }
+
+ let entry = match_stats.entry((a_id, b_id)).or_default();
+ entry.num_matches += 1;
+ entry.total_score += score;
}
+ Ok(match_stats)
+}
+
+/// Tokenizes player ids to a set of consecutive numbers
+struct PlayerTokenizer {
+ id_to_ix: HashMap<i32, usize>,
+ ids: Vec<i32>,
+}
- // simple elo rating
+impl PlayerTokenizer {
+ fn new() -> Self {
+ PlayerTokenizer {
+ id_to_ix: HashMap::new(),
+ ids: Vec::new(),
+ }
+ }
- let scores = match outcome.winner {
- None => vec![0.5; 2],
- Some(player_num) => {
- // TODO: please get rid of this offset
- let player_ix = player_num - 1;
- let mut scores = vec![0.0; 2];
- scores[player_ix] = 1.0;
- scores
+ fn tokenize(&mut self, id: i32) -> usize {
+ match self.id_to_ix.get(&id) {
+ Some(&ix) => ix,
+ None => {
+ let ix = self.ids.len();
+ self.ids.push(id);
+ self.id_to_ix.insert(id, ix);
+ ix
+ }
}
- };
+ }
+
+ fn detokenize(&self, ix: usize) -> i32 {
+ self.ids[ix]
+ }
+
+ fn player_count(&self) -> usize {
+ self.ids.len()
+ }
+}
+
+fn sigmoid(logit: f64) -> f64 {
+ 1.0 / (1.0 + (-logit).exp())
+}
+
+fn estimate_ratings_from_stats(match_stats: HashMap<(i32, i32), MatchStats>) -> Vec<(i32, f64)> {
+ // map player ids to player indexes in the ratings array
+ let mut input_records = Vec::<RatingInputRecord>::with_capacity(match_stats.len());
+ let mut player_tokenizer = PlayerTokenizer::new();
+
+ for ((a_id, b_id), stats) in match_stats.into_iter() {
+ input_records.push(RatingInputRecord {
+ p1_ix: player_tokenizer.tokenize(a_id),
+ p2_ix: player_tokenizer.tokenize(b_id),
+ score: stats.total_score / stats.num_matches as f64,
+ weight: stats.num_matches as f64,
+ })
+ }
+
+ let mut ratings = vec![0f64; player_tokenizer.player_count()];
+ // TODO: fetch these from config
+ let params = OptimizeRatingsParams::default();
+ optimize_ratings(&mut ratings, &input_records, &params);
+
+ ratings
+ .into_iter()
+ .enumerate()
+ .map(|(ix, rating)| {
+ (
+ player_tokenizer.detokenize(ix),
+ rating * 100f64 / 10f64.ln(),
+ )
+ })
+ .collect()
+}
+
+struct RatingInputRecord {
+ /// index of first player
+ p1_ix: usize,
+ /// index of secord player
+ p2_ix: usize,
+ /// score of player 1 (= 1 - score of player 2)
+ score: f64,
+ /// weight of this record
+ weight: f64,
+}
+
+struct OptimizeRatingsParams {
+ tolerance: f64,
+ learning_rate: f64,
+ max_iterations: usize,
+ regularization_weight: f64,
+}
+
+impl Default for OptimizeRatingsParams {
+ fn default() -> Self {
+ OptimizeRatingsParams {
+ tolerance: 10f64.powi(-8),
+ learning_rate: 0.1,
+ max_iterations: 10_000,
+ regularization_weight: 10.0,
+ }
+ }
+}
+
+fn optimize_ratings(
+ ratings: &mut [f64],
+ input_records: &[RatingInputRecord],
+ params: &OptimizeRatingsParams,
+) {
+ let total_weight =
+ params.regularization_weight + input_records.iter().map(|r| r.weight).sum::<f64>();
+
+ for _iteration in 0..params.max_iterations {
+ let mut gradients = vec![0f64; ratings.len()];
+
+ // calculate gradients
+ for record in input_records.iter() {
+ let predicted = sigmoid(ratings[record.p1_ix] - ratings[record.p2_ix]);
+ let gradient = record.weight * (predicted - record.score);
+ gradients[record.p1_ix] += gradient;
+ gradients[record.p2_ix] -= gradient;
+ }
+
+ // apply update step
+ let mut converged = true;
+ for (rating, gradient) in ratings.iter_mut().zip(&gradients) {
+ let update = params.learning_rate * (gradient + params.regularization_weight * *rating)
+ / total_weight;
+ if update > params.tolerance {
+ converged = false;
+ }
+ *rating -= update;
+ }
+
+ if converged {
+ break;
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ fn is_close(a: f64, b: f64) -> bool {
+ (a - b).abs() < 10f64.powi(-6)
+ }
+
+ #[test]
+ fn test_optimize_ratings() {
+ let input_records = vec![RatingInputRecord {
+ p1_ix: 0,
+ p2_ix: 1,
+ score: 0.8,
+ weight: 1.0,
+ }];
+
+ let mut ratings = vec![0.0; 2];
+ optimize_ratings(
+ &mut ratings,
+ &input_records,
+ &OptimizeRatingsParams {
+ regularization_weight: 0.0,
+ ..Default::default()
+ },
+ );
+ assert!(is_close(sigmoid(ratings[0] - ratings[1]), 0.8));
+ }
+
+ #[test]
+ fn test_optimize_ratings_weight() {
+ let input_records = vec![
+ RatingInputRecord {
+ p1_ix: 0,
+ p2_ix: 1,
+ score: 1.0,
+ weight: 1.0,
+ },
+ RatingInputRecord {
+ p1_ix: 1,
+ p2_ix: 0,
+ score: 1.0,
+ weight: 3.0,
+ },
+ ];
+
+ let mut ratings = vec![0.0; 2];
+ optimize_ratings(
+ &mut ratings,
+ &input_records,
+ &OptimizeRatingsParams {
+ regularization_weight: 0.0,
+ ..Default::default()
+ },
+ );
+ assert!(is_close(sigmoid(ratings[0] - ratings[1]), 0.25));
+ }
- for i in 0..2 {
- let j = 1 - i;
+ #[test]
+ fn test_optimize_ratings_regularization() {
+ let input_records = vec![RatingInputRecord {
+ p1_ix: 0,
+ p2_ix: 1,
+ score: 0.8,
+ weight: 100.0,
+ }];
- let scaled_difference = (ratings[j] - ratings[i]) / SCALE;
- let expected = 1.0 / (1.0 + 10f64.powf(scaled_difference));
- let new_rating = ratings[i] + MAX_UPDATE * (scores[i] - expected);
- db::ratings::set_rating(selected_bots[i].id, new_rating, &db_conn)
- .expect("could not update bot rating");
+ let mut ratings = vec![0.0; 2];
+ optimize_ratings(
+ &mut ratings,
+ &input_records,
+ &OptimizeRatingsParams {
+ regularization_weight: 1.0,
+ ..Default::default()
+ },
+ );
+ let predicted = sigmoid(ratings[0] - ratings[1]);
+ assert!(0.5 < predicted && predicted < 0.8);
}
}
diff --git a/planetwars-server/src/modules/registry.rs b/planetwars-server/src/modules/registry.rs
new file mode 100644
index 0000000..c8ec4fa
--- /dev/null
+++ b/planetwars-server/src/modules/registry.rs
@@ -0,0 +1,408 @@
+// TODO: this module is functional, but it needs a good refactor for proper error handling.
+
+use axum::body::{Body, StreamBody};
+use axum::extract::{BodyStream, FromRequest, Path, Query, RequestParts, TypedHeader};
+use axum::headers::authorization::Basic;
+use axum::headers::Authorization;
+use axum::response::{IntoResponse, Response};
+use axum::routing::{get, head, post, put};
+use axum::{async_trait, Router};
+use futures::StreamExt;
+use hyper::StatusCode;
+use serde::Serialize;
+use sha2::{Digest, Sha256};
+use std::path::PathBuf;
+use tokio::io::AsyncWriteExt;
+use tokio_util::io::ReaderStream;
+
+use crate::util::gen_alphanumeric;
+use crate::{db, DatabaseConnection};
+
+use crate::db::users::{authenticate_user, Credentials, User};
+
+// TODO: put this in a config file
+const REGISTRY_PATH: &str = "./data/registry";
+
+pub fn registry_service() -> Router {
+ Router::new()
+ // The docker API requires this trailing slash
+ .nest("/v2/", registry_api_v2())
+}
+
+fn registry_api_v2() -> Router {
+ Router::new()
+ .route("/", get(get_root))
+ .route(
+ "/:name/manifests/:reference",
+ get(get_manifest).put(put_manifest),
+ )
+ .route(
+ "/:name/blobs/:digest",
+ head(check_blob_exists).get(get_blob),
+ )
+ .route("/:name/blobs/uploads/", post(create_upload))
+ .route(
+ "/:name/blobs/uploads/:uuid",
+ put(put_upload).patch(patch_upload),
+ )
+}
+
+const ADMIN_USERNAME: &str = "admin";
+// TODO: put this in some configuration
+const ADMIN_PASSWORD: &str = "supersecretpassword";
+
+type AuthorizationHeader = TypedHeader<Authorization<Basic>>;
+
+enum RegistryAuth {
+ User(User),
+ Admin,
+}
+
+enum RegistryAuthError {
+ NoAuthHeader,
+ InvalidCredentials,
+}
+
+impl IntoResponse for RegistryAuthError {
+ fn into_response(self) -> Response {
+ // TODO: create enum for registry errors
+ let err = RegistryErrors {
+ errors: vec![RegistryError {
+ code: "UNAUTHORIZED".to_string(),
+ message: "please log in".to_string(),
+ detail: serde_json::Value::Null,
+ }],
+ };
+
+ (
+ StatusCode::UNAUTHORIZED,
+ [
+ ("Docker-Distribution-API-Version", "registry/2.0"),
+ ("WWW-Authenticate", "Basic"),
+ ],
+ serde_json::to_string(&err).unwrap(),
+ )
+ .into_response()
+ }
+}
+
+#[async_trait]
+impl<B> FromRequest<B> for RegistryAuth
+where
+ B: Send,
+{
+ type Rejection = RegistryAuthError;
+
+ async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
+ let TypedHeader(Authorization(basic)) = AuthorizationHeader::from_request(req)
+ .await
+ .map_err(|_| RegistryAuthError::NoAuthHeader)?;
+
+ // TODO: Into<Credentials> would be nice
+ let credentials = Credentials {
+ username: basic.username(),
+ password: basic.password(),
+ };
+
+ if credentials.username == ADMIN_USERNAME {
+ if credentials.password == ADMIN_PASSWORD {
+ Ok(RegistryAuth::Admin)
+ } else {
+ Err(RegistryAuthError::InvalidCredentials)
+ }
+ } else {
+ let db_conn = DatabaseConnection::from_request(req).await.unwrap();
+ let user = authenticate_user(&credentials, &db_conn)
+ .ok_or(RegistryAuthError::InvalidCredentials)?;
+
+ Ok(RegistryAuth::User(user))
+ }
+ }
+}
+
+// Since async file io just calls spawn_blocking internally, it does not really make sense
+// to make this an async function
+fn file_sha256_digest(path: &std::path::Path) -> std::io::Result<String> {
+ let mut file = std::fs::File::open(path)?;
+ let mut hasher = Sha256::new();
+ let _n = std::io::copy(&mut file, &mut hasher)?;
+ Ok(format!("{:x}", hasher.finalize()))
+}
+
+/// Get the index of the last byte in a file
+async fn last_byte_pos(file: &tokio::fs::File) -> std::io::Result<u64> {
+ let n_bytes = file.metadata().await?.len();
+ let pos = if n_bytes == 0 { 0 } else { n_bytes - 1 };
+ Ok(pos)
+}
+
+async fn get_root(_auth: RegistryAuth) -> impl IntoResponse {
+ // root should return 200 OK to confirm api compliance
+ Response::builder()
+ .status(StatusCode::OK)
+ .header("Docker-Distribution-API-Version", "registry/2.0")
+ .body(Body::empty())
+ .unwrap()
+}
+
+#[derive(Serialize)]
+pub struct RegistryErrors {
+ errors: Vec<RegistryError>,
+}
+
+#[derive(Serialize)]
+pub struct RegistryError {
+ code: String,
+ message: String,
+ detail: serde_json::Value,
+}
+
+async fn check_blob_exists(
+ db_conn: DatabaseConnection,
+ auth: RegistryAuth,
+ Path((repository_name, raw_digest)): Path<(String, String)>,
+) -> Result<impl IntoResponse, StatusCode> {
+ check_access(&repository_name, &auth, &db_conn)?;
+
+ let digest = raw_digest.strip_prefix("sha256:").unwrap();
+ let blob_path = PathBuf::from(REGISTRY_PATH).join("sha256").join(&digest);
+ if blob_path.exists() {
+ let metadata = std::fs::metadata(&blob_path).unwrap();
+ Ok((StatusCode::OK, [("Content-Length", metadata.len())]))
+ } else {
+ Err(StatusCode::NOT_FOUND)
+ }
+}
+
+async fn get_blob(
+ db_conn: DatabaseConnection,
+ auth: RegistryAuth,
+ Path((repository_name, raw_digest)): Path<(String, String)>,
+) -> Result<impl IntoResponse, StatusCode> {
+ check_access(&repository_name, &auth, &db_conn)?;
+
+ let digest = raw_digest.strip_prefix("sha256:").unwrap();
+ let blob_path = PathBuf::from(REGISTRY_PATH).join("sha256").join(&digest);
+ if !blob_path.exists() {
+ return Err(StatusCode::NOT_FOUND);
+ }
+ let file = tokio::fs::File::open(&blob_path).await.unwrap();
+ let reader_stream = ReaderStream::new(file);
+ let stream_body = StreamBody::new(reader_stream);
+ Ok(stream_body)
+}
+
+async fn create_upload(
+ db_conn: DatabaseConnection,
+ auth: RegistryAuth,
+ Path(repository_name): Path<String>,
+) -> Result<impl IntoResponse, StatusCode> {
+ check_access(&repository_name, &auth, &db_conn)?;
+
+ let uuid = gen_alphanumeric(16);
+ tokio::fs::File::create(PathBuf::from(REGISTRY_PATH).join("uploads").join(&uuid))
+ .await
+ .unwrap();
+
+ Ok(Response::builder()
+ .status(StatusCode::ACCEPTED)
+ .header(
+ "Location",
+ format!("/v2/{}/blobs/uploads/{}", repository_name, uuid),
+ )
+ .header("Docker-Upload-UUID", uuid)
+ .header("Range", "bytes=0-0")
+ .body(Body::empty())
+ .unwrap())
+}
+
+async fn patch_upload(
+ db_conn: DatabaseConnection,
+ auth: RegistryAuth,
+ Path((repository_name, uuid)): Path<(String, String)>,
+ mut stream: BodyStream,
+) -> Result<impl IntoResponse, StatusCode> {
+ check_access(&repository_name, &auth, &db_conn)?;
+
+ // TODO: support content range header in request
+ let upload_path = PathBuf::from(REGISTRY_PATH).join("uploads").join(&uuid);
+ let mut file = tokio::fs::OpenOptions::new()
+ .read(false)
+ .write(true)
+ .append(true)
+ .create(false)
+ .open(upload_path)
+ .await
+ .unwrap();
+ while let Some(Ok(chunk)) = stream.next().await {
+ file.write_all(&chunk).await.unwrap();
+ }
+
+ let last_byte = last_byte_pos(&file).await.unwrap();
+
+ Ok(Response::builder()
+ .status(StatusCode::ACCEPTED)
+ .header(
+ "Location",
+ format!("/v2/{}/blobs/uploads/{}", repository_name, uuid),
+ )
+ .header("Docker-Upload-UUID", uuid)
+ // range indicating current progress of the upload
+ .header("Range", format!("0-{}", last_byte))
+ .body(Body::empty())
+ .unwrap())
+}
+
+use serde::Deserialize;
+#[derive(Deserialize)]
+struct UploadParams {
+ digest: String,
+}
+
+async fn put_upload(
+ db_conn: DatabaseConnection,
+ auth: RegistryAuth,
+ Path((repository_name, uuid)): Path<(String, String)>,
+ Query(params): Query<UploadParams>,
+ mut stream: BodyStream,
+) -> Result<impl IntoResponse, StatusCode> {
+ check_access(&repository_name, &auth, &db_conn)?;
+
+ let upload_path = PathBuf::from(REGISTRY_PATH).join("uploads").join(&uuid);
+ let mut file = tokio::fs::OpenOptions::new()
+ .read(false)
+ .write(true)
+ .append(true)
+ .create(false)
+ .open(&upload_path)
+ .await
+ .unwrap();
+
+ let range_begin = last_byte_pos(&file).await.unwrap();
+ while let Some(Ok(chunk)) = stream.next().await {
+ file.write_all(&chunk).await.unwrap();
+ }
+ file.flush().await.unwrap();
+ let range_end = last_byte_pos(&file).await.unwrap();
+
+ let expected_digest = params.digest.strip_prefix("sha256:").unwrap();
+ let digest = file_sha256_digest(&upload_path).unwrap();
+ if digest != expected_digest {
+ // TODO: return a docker error body
+ return Err(StatusCode::BAD_REQUEST);
+ }
+
+ let target_path = PathBuf::from(REGISTRY_PATH).join("sha256").join(&digest);
+ tokio::fs::rename(&upload_path, &target_path).await.unwrap();
+
+ Ok(Response::builder()
+ .status(StatusCode::CREATED)
+ .header(
+ "Location",
+ format!("/v2/{}/blobs/{}", repository_name, digest),
+ )
+ .header("Docker-Upload-UUID", uuid)
+ // content range for bytes that were in the body of this request
+ .header("Content-Range", format!("{}-{}", range_begin, range_end))
+ .header("Docker-Content-Digest", params.digest)
+ .body(Body::empty())
+ .unwrap())
+}
+
+async fn get_manifest(
+ db_conn: DatabaseConnection,
+ auth: RegistryAuth,
+ Path((repository_name, reference)): Path<(String, String)>,
+) -> Result<impl IntoResponse, StatusCode> {
+ check_access(&repository_name, &auth, &db_conn)?;
+
+ let manifest_path = PathBuf::from(REGISTRY_PATH)
+ .join("manifests")
+ .join(&repository_name)
+ .join(&reference)
+ .with_extension("json");
+ let data = tokio::fs::read(&manifest_path).await.unwrap();
+
+ let manifest: serde_json::Map<String, serde_json::Value> =
+ serde_json::from_slice(&data).unwrap();
+ let media_type = manifest.get("mediaType").unwrap().as_str().unwrap();
+ Ok(Response::builder()
+ .status(StatusCode::OK)
+ .header("Content-Type", media_type)
+ .body(axum::body::Full::from(data))
+ .unwrap())
+}
+
+async fn put_manifest(
+ db_conn: DatabaseConnection,
+ auth: RegistryAuth,
+ Path((repository_name, reference)): Path<(String, String)>,
+ mut stream: BodyStream,
+) -> Result<impl IntoResponse, StatusCode> {
+ check_access(&repository_name, &auth, &db_conn)?;
+
+ let repository_dir = PathBuf::from(REGISTRY_PATH)
+ .join("manifests")
+ .join(&repository_name);
+
+ tokio::fs::create_dir_all(&repository_dir).await.unwrap();
+
+ let mut hasher = Sha256::new();
+ let manifest_path = repository_dir.join(&reference).with_extension("json");
+ {
+ let mut file = tokio::fs::OpenOptions::new()
+ .write(true)
+ .create(true)
+ .truncate(true)
+ .open(&manifest_path)
+ .await
+ .unwrap();
+ while let Some(Ok(chunk)) = stream.next().await {
+ hasher.update(&chunk);
+ file.write_all(&chunk).await.unwrap();
+ }
+ }
+ let digest = hasher.finalize();
+ // TODO: store content-adressable manifests separately
+ let content_digest = format!("sha256:{:x}", digest);
+ let digest_path = repository_dir.join(&content_digest).with_extension("json");
+ tokio::fs::copy(manifest_path, digest_path).await.unwrap();
+
+ Ok(Response::builder()
+ .status(StatusCode::CREATED)
+ .header(
+ "Location",
+ format!("/v2/{}/manifests/{}", repository_name, reference),
+ )
+ .header("Docker-Content-Digest", content_digest)
+ .body(Body::empty())
+ .unwrap())
+}
+
+/// Ensure that the accessed repository exists
+/// and the user is allowed to access ti
+fn check_access(
+ repository_name: &str,
+ auth: &RegistryAuth,
+ db_conn: &DatabaseConnection,
+) -> Result<(), StatusCode> {
+ use diesel::OptionalExtension;
+
+ // TODO: it would be nice to provide the found repository
+ // to the route handlers
+ let bot = db::bots::find_bot_by_name(repository_name, db_conn)
+ .optional()
+ .expect("could not run query")
+ .ok_or(StatusCode::NOT_FOUND)?;
+
+ match &auth {
+ RegistryAuth::Admin => Ok(()),
+ RegistryAuth::User(user) => {
+ if bot.owner_id == Some(user.id) {
+ Ok(())
+ } else {
+ Err(StatusCode::FORBIDDEN)
+ }
+ }
+ }
+}
diff --git a/planetwars-server/src/routes/users.rs b/planetwars-server/src/routes/users.rs
index 54ddd09..1989904 100644
--- a/planetwars-server/src/routes/users.rs
+++ b/planetwars-server/src/routes/users.rs
@@ -5,7 +5,7 @@ use axum::extract::{FromRequest, RequestParts, TypedHeader};
use axum::headers::authorization::Bearer;
use axum::headers::Authorization;
use axum::http::StatusCode;
-use axum::response::{Headers, IntoResponse, Response};
+use axum::response::{IntoResponse, Response};
use axum::{async_trait, Json};
use serde::{Deserialize, Serialize};
use serde_json::json;
@@ -163,9 +163,9 @@ pub async fn login(conn: DatabaseConnection, params: Json<LoginParams>) -> Respo
Some(user) => {
let session = sessions::create_session(&user, &conn);
let user_data: UserData = user.into();
- let headers = Headers(vec![("Token", &session.token)]);
+ let headers = [("Token", &session.token)];
- (headers, Json(user_data)).into_response()
+ (StatusCode::OK, headers, Json(user_data)).into_response()
}
}
}