diff options
author | Ilion Beyst <ilion.beyst@gmail.com> | 2022-07-18 21:03:34 +0200 |
---|---|---|
committer | Ilion Beyst <ilion.beyst@gmail.com> | 2022-07-18 21:03:34 +0200 |
commit | 7daf8f643798ce76733006f8469890bf1a3fd05e (patch) | |
tree | d755c4f0979f56792684b109d263624aef4baaad /planetwars-server/src/modules | |
parent | 608d05bc167c57d190d3c06f250b5e4a5662e77e (diff) | |
parent | d092f5d89c0fda5cc67349d5489b4ef1b294e053 (diff) | |
download | planetwars.dev-7daf8f643798ce76733006f8469890bf1a3fd05e.tar.xz planetwars.dev-7daf8f643798ce76733006f8469890bf1a3fd05e.zip |
Merge branch 'next'
Diffstat (limited to 'planetwars-server/src/modules')
-rw-r--r-- | planetwars-server/src/modules/bot_api.rs | 283 | ||||
-rw-r--r-- | planetwars-server/src/modules/bots.rs | 17 | ||||
-rw-r--r-- | planetwars-server/src/modules/matches.rs | 123 | ||||
-rw-r--r-- | planetwars-server/src/modules/mod.rs | 2 | ||||
-rw-r--r-- | planetwars-server/src/modules/ranking.rs | 37 | ||||
-rw-r--r-- | planetwars-server/src/modules/registry.rs | 440 |
6 files changed, 840 insertions, 62 deletions
diff --git a/planetwars-server/src/modules/bot_api.rs b/planetwars-server/src/modules/bot_api.rs new file mode 100644 index 0000000..33f5d87 --- /dev/null +++ b/planetwars-server/src/modules/bot_api.rs @@ -0,0 +1,283 @@ +pub mod pb { + tonic::include_proto!("grpc.planetwars.bot_api"); +} + +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +use runner::match_context::{EventBus, PlayerHandle, RequestError, RequestMessage}; +use runner::match_log::MatchLogger; +use tokio::sync::{mpsc, oneshot}; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tonic; +use tonic::transport::Server; +use tonic::{Request, Response, Status, Streaming}; + +use planetwars_matchrunner as runner; + +use crate::db; +use crate::util::gen_alphanumeric; +use crate::ConnectionPool; +use crate::GlobalConfig; + +use super::matches::{MatchPlayer, RunMatch}; + +pub struct BotApiServer { + conn_pool: ConnectionPool, + runner_config: Arc<GlobalConfig>, + router: PlayerRouter, +} + +/// Routes players to their handler +#[derive(Clone)] +struct PlayerRouter { + routing_table: Arc<Mutex<HashMap<String, SyncThingData>>>, +} + +impl PlayerRouter { + pub fn new() -> Self { + PlayerRouter { + routing_table: Arc::new(Mutex::new(HashMap::new())), + } + } +} + +impl Default for PlayerRouter { + fn default() -> Self { + Self::new() + } +} + +// TODO: implement a way to expire entries +impl PlayerRouter { + fn put(&self, player_key: String, entry: SyncThingData) { + let mut routing_table = self.routing_table.lock().unwrap(); + routing_table.insert(player_key, entry); + } + + fn take(&self, player_key: &str) -> Option<SyncThingData> { + // TODO: this design does not allow for reconnects. Is this desired? + let mut routing_table = self.routing_table.lock().unwrap(); + routing_table.remove(player_key) + } +} + +#[tonic::async_trait] +impl pb::bot_api_service_server::BotApiService for BotApiServer { + type ConnectBotStream = UnboundedReceiverStream<Result<pb::PlayerRequest, Status>>; + + async fn connect_bot( + &self, + req: Request<Streaming<pb::PlayerRequestResponse>>, + ) -> Result<Response<Self::ConnectBotStream>, Status> { + // TODO: clean up errors + let player_key = req + .metadata() + .get("player_key") + .ok_or_else(|| Status::unauthenticated("no player_key provided"))?; + + let player_key_str = player_key + .to_str() + .map_err(|_| Status::invalid_argument("unreadable string"))?; + + let sync_data = self + .router + .take(player_key_str) + .ok_or_else(|| Status::not_found("player_key not found"))?; + + let stream = req.into_inner(); + + sync_data.tx.send(stream).unwrap(); + Ok(Response::new(UnboundedReceiverStream::new( + sync_data.server_messages, + ))) + } + + async fn create_match( + &self, + req: Request<pb::MatchRequest>, + ) -> Result<Response<pb::CreatedMatch>, Status> { + // TODO: unify with matchrunner module + let conn = self.conn_pool.get().await.unwrap(); + + let match_request = req.get_ref(); + + let opponent_bot = db::bots::find_bot_by_name(&match_request.opponent_name, &conn) + .map_err(|_| Status::not_found("opponent not found"))?; + let opponent_bot_version = db::bots::active_bot_version(opponent_bot.id, &conn) + .map_err(|_| Status::not_found("no opponent version found"))?; + + let player_key = gen_alphanumeric(32); + + let remote_bot_spec = Box::new(RemoteBotSpec { + player_key: player_key.clone(), + router: self.router.clone(), + }); + let run_match = RunMatch::from_players( + self.runner_config.clone(), + vec![ + MatchPlayer::BotSpec { + spec: remote_bot_spec, + }, + MatchPlayer::BotVersion { + bot: Some(opponent_bot), + version: opponent_bot_version, + }, + ], + ); + let (created_match, _) = run_match + .run(self.conn_pool.clone()) + .await + .expect("failed to create match"); + + Ok(Response::new(pb::CreatedMatch { + match_id: created_match.base.id, + player_key, + })) + } +} + +// TODO: please rename me +struct SyncThingData { + tx: oneshot::Sender<Streaming<pb::PlayerRequestResponse>>, + server_messages: mpsc::UnboundedReceiver<Result<pb::PlayerRequest, Status>>, +} + +struct RemoteBotSpec { + player_key: String, + router: PlayerRouter, +} + +#[tonic::async_trait] +impl runner::BotSpec for RemoteBotSpec { + async fn run_bot( + &self, + player_id: u32, + event_bus: Arc<Mutex<EventBus>>, + _match_logger: MatchLogger, + ) -> Box<dyn PlayerHandle> { + let (tx, rx) = oneshot::channel(); + let (server_msg_snd, server_msg_recv) = mpsc::unbounded_channel(); + self.router.put( + self.player_key.clone(), + SyncThingData { + tx, + server_messages: server_msg_recv, + }, + ); + + let fut = tokio::time::timeout(Duration::from_secs(10), rx); + match fut.await { + Ok(Ok(client_messages)) => { + // let client_messages = rx.await.unwrap(); + tokio::spawn(handle_bot_messages( + player_id, + event_bus.clone(), + client_messages, + )); + } + _ => { + // ensure router cleanup + self.router.take(&self.player_key); + } + }; + + // If the player did not connect, the receiving half of `sender` + // will be dropped here, resulting in a time-out for every turn. + // This is fine for now, but + // TODO: provide a formal mechanism for player startup failure + Box::new(RemoteBotHandle { + sender: server_msg_snd, + player_id, + event_bus, + }) + } +} + +async fn handle_bot_messages( + player_id: u32, + event_bus: Arc<Mutex<EventBus>>, + mut messages: Streaming<pb::PlayerRequestResponse>, +) { + while let Some(message) = messages.message().await.unwrap() { + let request_id = (player_id, message.request_id as u32); + event_bus + .lock() + .unwrap() + .resolve_request(request_id, Ok(message.content)); + } +} + +struct RemoteBotHandle { + sender: mpsc::UnboundedSender<Result<pb::PlayerRequest, Status>>, + player_id: u32, + event_bus: Arc<Mutex<EventBus>>, +} + +impl PlayerHandle for RemoteBotHandle { + fn send_request(&mut self, r: RequestMessage) { + let res = self.sender.send(Ok(pb::PlayerRequest { + request_id: r.request_id as i32, + content: r.content, + })); + match res { + Ok(()) => { + // schedule a timeout. See comments at method implementation + tokio::spawn(schedule_timeout( + (self.player_id, r.request_id), + r.timeout, + self.event_bus.clone(), + )); + } + Err(_send_error) => { + // cannot contact the remote bot anymore; + // directly mark all requests as timed out. + // TODO: create a dedicated error type for this. + // should it be logged? + println!("send error: {:?}", _send_error); + self.event_bus + .lock() + .unwrap() + .resolve_request((self.player_id, r.request_id), Err(RequestError::Timeout)); + } + } + } +} + +// TODO: this will spawn a task for every request, which might not be ideal. +// Some alternatives: +// - create a single task that manages all time-outs. +// - intersperse timeouts with incoming client messages +// - push timeouts upwards, into the matchrunner logic (before we hit the playerhandle). +// This was initially not done to allow timer start to be delayed until the message actually arrived +// with the player. Is this still needed, or is there a different way to do this? +// +async fn schedule_timeout( + request_id: (u32, u32), + duration: Duration, + event_bus: Arc<Mutex<EventBus>>, +) { + tokio::time::sleep(duration).await; + event_bus + .lock() + .unwrap() + .resolve_request(request_id, Err(RequestError::Timeout)); +} + +pub async fn run_bot_api(runner_config: Arc<GlobalConfig>, pool: ConnectionPool) { + let router = PlayerRouter::new(); + let server = BotApiServer { + router, + conn_pool: pool, + runner_config, + }; + + let addr = SocketAddr::from(([127, 0, 0, 1], 50051)); + Server::builder() + .add_service(pb::bot_api_service_server::BotApiServiceServer::new(server)) + .serve(addr) + .await + .unwrap() +} diff --git a/planetwars-server/src/modules/bots.rs b/planetwars-server/src/modules/bots.rs index 843e48d..5513539 100644 --- a/planetwars-server/src/modules/bots.rs +++ b/planetwars-server/src/modules/bots.rs @@ -2,22 +2,25 @@ use std::path::PathBuf; use diesel::{PgConnection, QueryResult}; -use crate::{db, util::gen_alphanumeric, BOTS_DIR}; +use crate::{db, util::gen_alphanumeric, GlobalConfig}; -pub fn save_code_bundle( +/// Save a string containing bot code as a code bundle. +pub fn save_code_string( bot_code: &str, bot_id: Option<i32>, conn: &PgConnection, -) -> QueryResult<db::bots::CodeBundle> { + config: &GlobalConfig, +) -> QueryResult<db::bots::BotVersion> { let bundle_name = gen_alphanumeric(16); - let code_bundle_dir = PathBuf::from(BOTS_DIR).join(&bundle_name); + let code_bundle_dir = PathBuf::from(&config.bots_directory).join(&bundle_name); std::fs::create_dir(&code_bundle_dir).unwrap(); std::fs::write(code_bundle_dir.join("bot.py"), bot_code).unwrap(); - let new_code_bundle = db::bots::NewCodeBundle { + let new_code_bundle = db::bots::NewBotVersion { bot_id, - path: &bundle_name, + code_bundle_path: Some(&bundle_name), + container_digest: None, }; - db::bots::create_code_bundle(&new_code_bundle, conn) + db::bots::create_bot_version(&new_code_bundle, conn) } diff --git a/planetwars-server/src/modules/matches.rs b/planetwars-server/src/modules/matches.rs index a254bac..a1fe63d 100644 --- a/planetwars-server/src/modules/matches.rs +++ b/planetwars-server/src/modules/matches.rs @@ -1,4 +1,4 @@ -use std::path::PathBuf; +use std::{path::PathBuf, sync::Arc}; use diesel::{PgConnection, QueryResult}; use planetwars_matchrunner::{self as runner, docker_runner::DockerBotSpec, BotSpec, MatchConfig}; @@ -11,77 +11,126 @@ use crate::{ matches::{MatchData, MatchResult}, }, util::gen_alphanumeric, - ConnectionPool, BOTS_DIR, MAPS_DIR, MATCHES_DIR, + ConnectionPool, GlobalConfig, }; -const PYTHON_IMAGE: &str = "python:3.10-slim-buster"; - -pub struct RunMatch<'a> { +pub struct RunMatch { log_file_name: String, - player_code_bundles: Vec<&'a db::bots::CodeBundle>, - match_id: Option<i32>, + players: Vec<MatchPlayer>, + config: Arc<GlobalConfig>, +} + +pub enum MatchPlayer { + BotVersion { + bot: Option<db::bots::Bot>, + version: db::bots::BotVersion, + }, + BotSpec { + spec: Box<dyn BotSpec>, + }, } -impl<'a> RunMatch<'a> { - pub fn from_players(player_code_bundles: Vec<&'a db::bots::CodeBundle>) -> Self { +impl RunMatch { + pub fn from_players(config: Arc<GlobalConfig>, players: Vec<MatchPlayer>) -> Self { let log_file_name = format!("{}.log", gen_alphanumeric(16)); RunMatch { + config, log_file_name, - player_code_bundles, - match_id: None, + players, } } - pub fn runner_config(&self) -> runner::MatchConfig { + fn into_runner_config(self) -> runner::MatchConfig { runner::MatchConfig { - map_path: PathBuf::from(MAPS_DIR).join("hex.json"), + map_path: PathBuf::from(&self.config.maps_directory).join("hex.json"), map_name: "hex".to_string(), - log_path: PathBuf::from(MATCHES_DIR).join(&self.log_file_name), + log_path: PathBuf::from(&self.config.match_logs_directory).join(&self.log_file_name), players: self - .player_code_bundles - .iter() - .map(|b| runner::MatchPlayer { - bot_spec: code_bundle_to_botspec(b), + .players + .into_iter() + .map(|player| runner::MatchPlayer { + bot_spec: match player { + MatchPlayer::BotVersion { bot, version } => { + bot_version_to_botspec(&self.config, bot.as_ref(), &version) + } + MatchPlayer::BotSpec { spec } => spec, + }, }) .collect(), } } - pub fn store_in_database(&mut self, db_conn: &PgConnection) -> QueryResult<MatchData> { - // don't store the same match twice - assert!(self.match_id.is_none()); + pub async fn run( + self, + conn_pool: ConnectionPool, + ) -> QueryResult<(MatchData, JoinHandle<MatchOutcome>)> { + let match_data = { + // TODO: it would be nice to get an already-open connection here when possible. + // Maybe we need an additional abstraction, bundling a connection and connection pool? + let db_conn = conn_pool.get().await.expect("could not get a connection"); + self.store_in_database(&db_conn)? + }; + + let runner_config = self.into_runner_config(); + let handle = tokio::spawn(run_match_task(conn_pool, runner_config, match_data.base.id)); + Ok((match_data, handle)) + } + + fn store_in_database(&self, db_conn: &PgConnection) -> QueryResult<MatchData> { let new_match_data = db::matches::NewMatch { state: db::matches::MatchState::Playing, log_path: &self.log_file_name, }; let new_match_players = self - .player_code_bundles + .players .iter() - .map(|b| db::matches::MatchPlayerData { - code_bundle_id: b.id, + .map(|p| db::matches::MatchPlayerData { + code_bundle_id: match p { + MatchPlayer::BotVersion { version, .. } => Some(version.id), + MatchPlayer::BotSpec { .. } => None, + }, }) .collect::<Vec<_>>(); - let match_data = db::matches::create_match(&new_match_data, &new_match_players, &db_conn)?; - self.match_id = Some(match_data.base.id); - Ok(match_data) + db::matches::create_match(&new_match_data, &new_match_players, db_conn) } +} - pub fn spawn(self, pool: ConnectionPool) -> JoinHandle<MatchOutcome> { - let match_id = self.match_id.expect("match must be saved before running"); - let runner_config = self.runner_config(); - tokio::spawn(run_match_task(pool, runner_config, match_id)) +pub fn bot_version_to_botspec( + runner_config: &GlobalConfig, + bot: Option<&db::bots::Bot>, + bot_version: &db::bots::BotVersion, +) -> Box<dyn BotSpec> { + if let Some(code_bundle_path) = &bot_version.code_bundle_path { + python_docker_bot_spec(runner_config, code_bundle_path) + } else if let (Some(container_digest), Some(bot)) = (&bot_version.container_digest, bot) { + Box::new(DockerBotSpec { + image: format!( + "{}/{}@{}", + runner_config.container_registry_url, bot.name, container_digest + ), + binds: None, + argv: None, + working_dir: None, + }) + } else { + // TODO: ideally this would not be possible + panic!("bad bot version") } } -pub fn code_bundle_to_botspec(code_bundle: &db::bots::CodeBundle) -> Box<dyn BotSpec> { - let bundle_path = PathBuf::from(BOTS_DIR).join(&code_bundle.path); +fn python_docker_bot_spec(config: &GlobalConfig, code_bundle_path: &str) -> Box<dyn BotSpec> { + let code_bundle_rel_path = PathBuf::from(&config.bots_directory).join(code_bundle_path); + let code_bundle_abs_path = std::fs::canonicalize(&code_bundle_rel_path).unwrap(); + let code_bundle_path_str = code_bundle_abs_path.as_os_str().to_str().unwrap(); + // TODO: it would be good to simplify this configuration Box::new(DockerBotSpec { - code_path: bundle_path, - image: PYTHON_IMAGE.to_string(), - argv: vec!["python".to_string(), "bot.py".to_string()], + image: config.python_runner_image.clone(), + binds: Some(vec![format!("{}:{}", code_bundle_path_str, "/workdir")]), + argv: Some(vec!["python".to_string(), "bot.py".to_string()]), + working_dir: Some("/workdir".to_string()), }) } @@ -104,5 +153,5 @@ async fn run_match_task( db::matches::save_match_result(match_id, result, &conn).expect("could not save match result"); - return outcome; + outcome } diff --git a/planetwars-server/src/modules/mod.rs b/planetwars-server/src/modules/mod.rs index bea28e0..1200f9d 100644 --- a/planetwars-server/src/modules/mod.rs +++ b/planetwars-server/src/modules/mod.rs @@ -1,5 +1,7 @@ // This module implements general domain logic, not directly // tied to the database or API layers. +pub mod bot_api; pub mod bots; pub mod matches; pub mod ranking; +pub mod registry; diff --git a/planetwars-server/src/modules/ranking.rs b/planetwars-server/src/modules/ranking.rs index 5d496d7..a9f6419 100644 --- a/planetwars-server/src/modules/ranking.rs +++ b/planetwars-server/src/modules/ranking.rs @@ -1,17 +1,18 @@ -use crate::{db::bots::Bot, DbPool}; +use crate::{db::bots::Bot, DbPool, GlobalConfig}; use crate::db; -use crate::modules::matches::RunMatch; +use crate::modules::matches::{MatchPlayer, RunMatch}; use diesel::{PgConnection, QueryResult}; use rand::seq::SliceRandom; use std::collections::HashMap; use std::mem; +use std::sync::Arc; use std::time::{Duration, Instant}; use tokio; const RANKER_INTERVAL: u64 = 60; -pub async fn run_ranker(db_pool: DbPool) { +pub async fn run_ranker(config: Arc<GlobalConfig>, db_pool: DbPool) { // TODO: make this configurable // play at most one match every n seconds let mut interval = tokio::time::interval(Duration::from_secs(RANKER_INTERVAL)); @@ -30,30 +31,30 @@ pub async fn run_ranker(db_pool: DbPool) { let mut rng = &mut rand::thread_rng(); bots.choose_multiple(&mut rng, 2).cloned().collect() }; - play_ranking_match(selected_bots, db_pool.clone()).await; + play_ranking_match(config.clone(), selected_bots, db_pool.clone()).await; recalculate_ratings(&db_conn).expect("could not recalculate ratings"); } } -async fn play_ranking_match(selected_bots: Vec<Bot>, db_pool: DbPool) { +async fn play_ranking_match(config: Arc<GlobalConfig>, selected_bots: Vec<Bot>, db_pool: DbPool) { let db_conn = db_pool.get().await.expect("could not get db pool"); - let mut code_bundles = Vec::new(); + let mut players = Vec::new(); for bot in &selected_bots { - let code_bundle = db::bots::active_code_bundle(bot.id, &db_conn) - .expect("could not get active code bundle"); - code_bundles.push(code_bundle); + let version = db::bots::active_bot_version(bot.id, &db_conn) + .expect("could not get active bot version"); + let player = MatchPlayer::BotVersion { + bot: Some(bot.clone()), + version, + }; + players.push(player); } - let code_bundle_refs = code_bundles.iter().collect::<Vec<_>>(); - - let mut run_match = RunMatch::from_players(code_bundle_refs); - run_match - .store_in_database(&db_conn) - .expect("could not store match in db"); - run_match - .spawn(db_pool.clone()) + let (_, handle) = RunMatch::from_players(config, players) + .run(db_pool.clone()) .await - .expect("running match failed"); + .expect("failed to run match"); + // wait for match to complete, so that only one ranking match can be running + let _outcome = handle.await; } fn recalculate_ratings(db_conn: &PgConnection) -> QueryResult<()> { diff --git a/planetwars-server/src/modules/registry.rs b/planetwars-server/src/modules/registry.rs new file mode 100644 index 0000000..3f6dad2 --- /dev/null +++ b/planetwars-server/src/modules/registry.rs @@ -0,0 +1,440 @@ +// TODO: this module is functional, but it needs a good refactor for proper error handling. + +use axum::body::{Body, StreamBody}; +use axum::extract::{BodyStream, FromRequest, Path, Query, RequestParts, TypedHeader}; +use axum::headers::authorization::Basic; +use axum::headers::Authorization; +use axum::response::{IntoResponse, Response}; +use axum::routing::{get, head, post, put}; +use axum::{async_trait, Extension, Router}; +use futures::StreamExt; +use hyper::StatusCode; +use serde::Serialize; +use sha2::{Digest, Sha256}; +use std::path::PathBuf; +use std::sync::Arc; +use tokio::io::AsyncWriteExt; +use tokio_util::io::ReaderStream; + +use crate::db::bots::NewBotVersion; +use crate::util::gen_alphanumeric; +use crate::{db, DatabaseConnection, GlobalConfig}; + +use crate::db::users::{authenticate_user, Credentials, User}; + +pub fn registry_service() -> Router { + Router::new() + // The docker API requires this trailing slash + .nest("/v2/", registry_api_v2()) +} + +fn registry_api_v2() -> Router { + Router::new() + .route("/", get(get_root)) + .route( + "/:name/manifests/:reference", + get(get_manifest).put(put_manifest), + ) + .route( + "/:name/blobs/:digest", + head(check_blob_exists).get(get_blob), + ) + .route("/:name/blobs/uploads/", post(create_upload)) + .route( + "/:name/blobs/uploads/:uuid", + put(put_upload).patch(patch_upload), + ) +} + +const ADMIN_USERNAME: &str = "admin"; + +type AuthorizationHeader = TypedHeader<Authorization<Basic>>; + +enum RegistryAuth { + User(User), + Admin, +} + +enum RegistryAuthError { + NoAuthHeader, + InvalidCredentials, +} + +impl IntoResponse for RegistryAuthError { + fn into_response(self) -> Response { + // TODO: create enum for registry errors + let err = RegistryErrors { + errors: vec![RegistryError { + code: "UNAUTHORIZED".to_string(), + message: "please log in".to_string(), + detail: serde_json::Value::Null, + }], + }; + + ( + StatusCode::UNAUTHORIZED, + [ + ("Docker-Distribution-API-Version", "registry/2.0"), + ("WWW-Authenticate", "Basic"), + ], + serde_json::to_string(&err).unwrap(), + ) + .into_response() + } +} + +#[async_trait] +impl<B> FromRequest<B> for RegistryAuth +where + B: Send, +{ + type Rejection = RegistryAuthError; + + async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { + let TypedHeader(Authorization(basic)) = AuthorizationHeader::from_request(req) + .await + .map_err(|_| RegistryAuthError::NoAuthHeader)?; + + // TODO: Into<Credentials> would be nice + let credentials = Credentials { + username: basic.username(), + password: basic.password(), + }; + + let Extension(config) = Extension::<Arc<GlobalConfig>>::from_request(req) + .await + .unwrap(); + + if credentials.username == ADMIN_USERNAME { + if credentials.password == config.registry_admin_password { + Ok(RegistryAuth::Admin) + } else { + Err(RegistryAuthError::InvalidCredentials) + } + } else { + let db_conn = DatabaseConnection::from_request(req).await.unwrap(); + let user = authenticate_user(&credentials, &db_conn) + .ok_or(RegistryAuthError::InvalidCredentials)?; + + Ok(RegistryAuth::User(user)) + } + } +} + +// Since async file io just calls spawn_blocking internally, it does not really make sense +// to make this an async function +fn file_sha256_digest(path: &std::path::Path) -> std::io::Result<String> { + let mut file = std::fs::File::open(path)?; + let mut hasher = Sha256::new(); + let _n = std::io::copy(&mut file, &mut hasher)?; + Ok(format!("{:x}", hasher.finalize())) +} + +/// Get the index of the last byte in a file +async fn last_byte_pos(file: &tokio::fs::File) -> std::io::Result<u64> { + let n_bytes = file.metadata().await?.len(); + let pos = if n_bytes == 0 { 0 } else { n_bytes - 1 }; + Ok(pos) +} + +async fn get_root(_auth: RegistryAuth) -> impl IntoResponse { + // root should return 200 OK to confirm api compliance + Response::builder() + .status(StatusCode::OK) + .header("Docker-Distribution-API-Version", "registry/2.0") + .body(Body::empty()) + .unwrap() +} + +#[derive(Serialize)] +pub struct RegistryErrors { + errors: Vec<RegistryError>, +} + +#[derive(Serialize)] +pub struct RegistryError { + code: String, + message: String, + detail: serde_json::Value, +} + +async fn check_blob_exists( + db_conn: DatabaseConnection, + auth: RegistryAuth, + Path((repository_name, raw_digest)): Path<(String, String)>, + Extension(config): Extension<Arc<GlobalConfig>>, +) -> Result<impl IntoResponse, StatusCode> { + check_access(&repository_name, &auth, &db_conn)?; + + let digest = raw_digest.strip_prefix("sha256:").unwrap(); + let blob_path = PathBuf::from(&config.registry_directory) + .join("sha256") + .join(&digest); + if blob_path.exists() { + let metadata = std::fs::metadata(&blob_path).unwrap(); + Ok((StatusCode::OK, [("Content-Length", metadata.len())])) + } else { + Err(StatusCode::NOT_FOUND) + } +} + +async fn get_blob( + db_conn: DatabaseConnection, + auth: RegistryAuth, + Path((repository_name, raw_digest)): Path<(String, String)>, + Extension(config): Extension<Arc<GlobalConfig>>, +) -> Result<impl IntoResponse, StatusCode> { + check_access(&repository_name, &auth, &db_conn)?; + + let digest = raw_digest.strip_prefix("sha256:").unwrap(); + let blob_path = PathBuf::from(&config.registry_directory) + .join("sha256") + .join(&digest); + if !blob_path.exists() { + return Err(StatusCode::NOT_FOUND); + } + let file = tokio::fs::File::open(&blob_path).await.unwrap(); + let reader_stream = ReaderStream::new(file); + let stream_body = StreamBody::new(reader_stream); + Ok(stream_body) +} + +async fn create_upload( + db_conn: DatabaseConnection, + auth: RegistryAuth, + Path(repository_name): Path<String>, + Extension(config): Extension<Arc<GlobalConfig>>, +) -> Result<impl IntoResponse, StatusCode> { + check_access(&repository_name, &auth, &db_conn)?; + + let uuid = gen_alphanumeric(16); + tokio::fs::File::create( + PathBuf::from(&config.registry_directory) + .join("uploads") + .join(&uuid), + ) + .await + .unwrap(); + + Ok(Response::builder() + .status(StatusCode::ACCEPTED) + .header( + "Location", + format!("/v2/{}/blobs/uploads/{}", repository_name, uuid), + ) + .header("Docker-Upload-UUID", uuid) + .header("Range", "bytes=0-0") + .body(Body::empty()) + .unwrap()) +} + +async fn patch_upload( + db_conn: DatabaseConnection, + auth: RegistryAuth, + Path((repository_name, uuid)): Path<(String, String)>, + mut stream: BodyStream, + Extension(config): Extension<Arc<GlobalConfig>>, +) -> Result<impl IntoResponse, StatusCode> { + check_access(&repository_name, &auth, &db_conn)?; + + // TODO: support content range header in request + let upload_path = PathBuf::from(&config.registry_directory) + .join("uploads") + .join(&uuid); + let mut file = tokio::fs::OpenOptions::new() + .read(false) + .write(true) + .append(true) + .create(false) + .open(upload_path) + .await + .unwrap(); + while let Some(Ok(chunk)) = stream.next().await { + file.write_all(&chunk).await.unwrap(); + } + + let last_byte = last_byte_pos(&file).await.unwrap(); + + Ok(Response::builder() + .status(StatusCode::ACCEPTED) + .header( + "Location", + format!("/v2/{}/blobs/uploads/{}", repository_name, uuid), + ) + .header("Docker-Upload-UUID", uuid) + // range indicating current progress of the upload + .header("Range", format!("0-{}", last_byte)) + .body(Body::empty()) + .unwrap()) +} + +use serde::Deserialize; +#[derive(Deserialize)] +struct UploadParams { + digest: String, +} + +async fn put_upload( + db_conn: DatabaseConnection, + auth: RegistryAuth, + Path((repository_name, uuid)): Path<(String, String)>, + Query(params): Query<UploadParams>, + mut stream: BodyStream, + Extension(config): Extension<Arc<GlobalConfig>>, +) -> Result<impl IntoResponse, StatusCode> { + check_access(&repository_name, &auth, &db_conn)?; + + let upload_path = PathBuf::from(&config.registry_directory) + .join("uploads") + .join(&uuid); + let mut file = tokio::fs::OpenOptions::new() + .read(false) + .write(true) + .append(true) + .create(false) + .open(&upload_path) + .await + .unwrap(); + + let range_begin = last_byte_pos(&file).await.unwrap(); + while let Some(Ok(chunk)) = stream.next().await { + file.write_all(&chunk).await.unwrap(); + } + file.flush().await.unwrap(); + let range_end = last_byte_pos(&file).await.unwrap(); + + let expected_digest = params.digest.strip_prefix("sha256:").unwrap(); + let digest = file_sha256_digest(&upload_path).unwrap(); + if digest != expected_digest { + // TODO: return a docker error body + return Err(StatusCode::BAD_REQUEST); + } + + let target_path = PathBuf::from(&config.registry_directory) + .join("sha256") + .join(&digest); + tokio::fs::rename(&upload_path, &target_path).await.unwrap(); + + Ok(Response::builder() + .status(StatusCode::CREATED) + .header( + "Location", + format!("/v2/{}/blobs/{}", repository_name, digest), + ) + .header("Docker-Upload-UUID", uuid) + // content range for bytes that were in the body of this request + .header("Content-Range", format!("{}-{}", range_begin, range_end)) + .header("Docker-Content-Digest", params.digest) + .body(Body::empty()) + .unwrap()) +} + +async fn get_manifest( + db_conn: DatabaseConnection, + auth: RegistryAuth, + Path((repository_name, reference)): Path<(String, String)>, + Extension(config): Extension<Arc<GlobalConfig>>, +) -> Result<impl IntoResponse, StatusCode> { + check_access(&repository_name, &auth, &db_conn)?; + + let manifest_path = PathBuf::from(&config.registry_directory) + .join("manifests") + .join(&repository_name) + .join(&reference) + .with_extension("json"); + let data = tokio::fs::read(&manifest_path).await.unwrap(); + + let manifest: serde_json::Map<String, serde_json::Value> = + serde_json::from_slice(&data).unwrap(); + let media_type = manifest.get("mediaType").unwrap().as_str().unwrap(); + Ok(Response::builder() + .status(StatusCode::OK) + .header("Content-Type", media_type) + .body(axum::body::Full::from(data)) + .unwrap()) +} + +async fn put_manifest( + db_conn: DatabaseConnection, + auth: RegistryAuth, + Path((repository_name, reference)): Path<(String, String)>, + mut stream: BodyStream, + Extension(config): Extension<Arc<GlobalConfig>>, +) -> Result<impl IntoResponse, StatusCode> { + let bot = check_access(&repository_name, &auth, &db_conn)?; + + let repository_dir = PathBuf::from(&config.registry_directory) + .join("manifests") + .join(&repository_name); + + tokio::fs::create_dir_all(&repository_dir).await.unwrap(); + + let mut hasher = Sha256::new(); + let manifest_path = repository_dir.join(&reference).with_extension("json"); + { + let mut file = tokio::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(&manifest_path) + .await + .unwrap(); + while let Some(Ok(chunk)) = stream.next().await { + hasher.update(&chunk); + file.write_all(&chunk).await.unwrap(); + } + } + let digest = hasher.finalize(); + // TODO: store content-adressable manifests separately + let content_digest = format!("sha256:{:x}", digest); + let digest_path = repository_dir.join(&content_digest).with_extension("json"); + tokio::fs::copy(manifest_path, digest_path).await.unwrap(); + + // Register the new image as a bot version + // TODO: how should tags be handled? + let new_version = NewBotVersion { + bot_id: Some(bot.id), + code_bundle_path: None, + container_digest: Some(&content_digest), + }; + db::bots::create_bot_version(&new_version, &db_conn).expect("could not save bot version"); + + Ok(Response::builder() + .status(StatusCode::CREATED) + .header( + "Location", + format!("/v2/{}/manifests/{}", repository_name, reference), + ) + .header("Docker-Content-Digest", content_digest) + .body(Body::empty()) + .unwrap()) +} + +/// Ensure that the accessed repository exists +/// and the user is allowed to access it. +/// Returns the associated bot. +fn check_access( + repository_name: &str, + auth: &RegistryAuth, + db_conn: &DatabaseConnection, +) -> Result<db::bots::Bot, StatusCode> { + use diesel::OptionalExtension; + + // TODO: it would be nice to provide the found repository + // to the route handlers + let bot = db::bots::find_bot_by_name(repository_name, db_conn) + .optional() + .expect("could not run query") + .ok_or(StatusCode::NOT_FOUND)?; + + match &auth { + RegistryAuth::Admin => Ok(bot), + RegistryAuth::User(user) => { + if bot.owner_id == Some(user.id) { + Ok(bot) + } else { + Err(StatusCode::FORBIDDEN) + } + } + } +} |