aboutsummaryrefslogtreecommitdiff
path: root/planetwars-server/src/modules
diff options
context:
space:
mode:
authorIlion Beyst <ilion.beyst@gmail.com>2022-07-18 21:03:34 +0200
committerIlion Beyst <ilion.beyst@gmail.com>2022-07-18 21:03:34 +0200
commit7daf8f643798ce76733006f8469890bf1a3fd05e (patch)
treed755c4f0979f56792684b109d263624aef4baaad /planetwars-server/src/modules
parent608d05bc167c57d190d3c06f250b5e4a5662e77e (diff)
parentd092f5d89c0fda5cc67349d5489b4ef1b294e053 (diff)
downloadplanetwars.dev-7daf8f643798ce76733006f8469890bf1a3fd05e.tar.xz
planetwars.dev-7daf8f643798ce76733006f8469890bf1a3fd05e.zip
Merge branch 'next'
Diffstat (limited to 'planetwars-server/src/modules')
-rw-r--r--planetwars-server/src/modules/bot_api.rs283
-rw-r--r--planetwars-server/src/modules/bots.rs17
-rw-r--r--planetwars-server/src/modules/matches.rs123
-rw-r--r--planetwars-server/src/modules/mod.rs2
-rw-r--r--planetwars-server/src/modules/ranking.rs37
-rw-r--r--planetwars-server/src/modules/registry.rs440
6 files changed, 840 insertions, 62 deletions
diff --git a/planetwars-server/src/modules/bot_api.rs b/planetwars-server/src/modules/bot_api.rs
new file mode 100644
index 0000000..33f5d87
--- /dev/null
+++ b/planetwars-server/src/modules/bot_api.rs
@@ -0,0 +1,283 @@
+pub mod pb {
+ tonic::include_proto!("grpc.planetwars.bot_api");
+}
+
+use std::collections::HashMap;
+use std::net::SocketAddr;
+use std::sync::{Arc, Mutex};
+use std::time::Duration;
+
+use runner::match_context::{EventBus, PlayerHandle, RequestError, RequestMessage};
+use runner::match_log::MatchLogger;
+use tokio::sync::{mpsc, oneshot};
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use tonic;
+use tonic::transport::Server;
+use tonic::{Request, Response, Status, Streaming};
+
+use planetwars_matchrunner as runner;
+
+use crate::db;
+use crate::util::gen_alphanumeric;
+use crate::ConnectionPool;
+use crate::GlobalConfig;
+
+use super::matches::{MatchPlayer, RunMatch};
+
+pub struct BotApiServer {
+ conn_pool: ConnectionPool,
+ runner_config: Arc<GlobalConfig>,
+ router: PlayerRouter,
+}
+
+/// Routes players to their handler
+#[derive(Clone)]
+struct PlayerRouter {
+ routing_table: Arc<Mutex<HashMap<String, SyncThingData>>>,
+}
+
+impl PlayerRouter {
+ pub fn new() -> Self {
+ PlayerRouter {
+ routing_table: Arc::new(Mutex::new(HashMap::new())),
+ }
+ }
+}
+
+impl Default for PlayerRouter {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+// TODO: implement a way to expire entries
+impl PlayerRouter {
+ fn put(&self, player_key: String, entry: SyncThingData) {
+ let mut routing_table = self.routing_table.lock().unwrap();
+ routing_table.insert(player_key, entry);
+ }
+
+ fn take(&self, player_key: &str) -> Option<SyncThingData> {
+ // TODO: this design does not allow for reconnects. Is this desired?
+ let mut routing_table = self.routing_table.lock().unwrap();
+ routing_table.remove(player_key)
+ }
+}
+
+#[tonic::async_trait]
+impl pb::bot_api_service_server::BotApiService for BotApiServer {
+ type ConnectBotStream = UnboundedReceiverStream<Result<pb::PlayerRequest, Status>>;
+
+ async fn connect_bot(
+ &self,
+ req: Request<Streaming<pb::PlayerRequestResponse>>,
+ ) -> Result<Response<Self::ConnectBotStream>, Status> {
+ // TODO: clean up errors
+ let player_key = req
+ .metadata()
+ .get("player_key")
+ .ok_or_else(|| Status::unauthenticated("no player_key provided"))?;
+
+ let player_key_str = player_key
+ .to_str()
+ .map_err(|_| Status::invalid_argument("unreadable string"))?;
+
+ let sync_data = self
+ .router
+ .take(player_key_str)
+ .ok_or_else(|| Status::not_found("player_key not found"))?;
+
+ let stream = req.into_inner();
+
+ sync_data.tx.send(stream).unwrap();
+ Ok(Response::new(UnboundedReceiverStream::new(
+ sync_data.server_messages,
+ )))
+ }
+
+ async fn create_match(
+ &self,
+ req: Request<pb::MatchRequest>,
+ ) -> Result<Response<pb::CreatedMatch>, Status> {
+ // TODO: unify with matchrunner module
+ let conn = self.conn_pool.get().await.unwrap();
+
+ let match_request = req.get_ref();
+
+ let opponent_bot = db::bots::find_bot_by_name(&match_request.opponent_name, &conn)
+ .map_err(|_| Status::not_found("opponent not found"))?;
+ let opponent_bot_version = db::bots::active_bot_version(opponent_bot.id, &conn)
+ .map_err(|_| Status::not_found("no opponent version found"))?;
+
+ let player_key = gen_alphanumeric(32);
+
+ let remote_bot_spec = Box::new(RemoteBotSpec {
+ player_key: player_key.clone(),
+ router: self.router.clone(),
+ });
+ let run_match = RunMatch::from_players(
+ self.runner_config.clone(),
+ vec![
+ MatchPlayer::BotSpec {
+ spec: remote_bot_spec,
+ },
+ MatchPlayer::BotVersion {
+ bot: Some(opponent_bot),
+ version: opponent_bot_version,
+ },
+ ],
+ );
+ let (created_match, _) = run_match
+ .run(self.conn_pool.clone())
+ .await
+ .expect("failed to create match");
+
+ Ok(Response::new(pb::CreatedMatch {
+ match_id: created_match.base.id,
+ player_key,
+ }))
+ }
+}
+
+// TODO: please rename me
+struct SyncThingData {
+ tx: oneshot::Sender<Streaming<pb::PlayerRequestResponse>>,
+ server_messages: mpsc::UnboundedReceiver<Result<pb::PlayerRequest, Status>>,
+}
+
+struct RemoteBotSpec {
+ player_key: String,
+ router: PlayerRouter,
+}
+
+#[tonic::async_trait]
+impl runner::BotSpec for RemoteBotSpec {
+ async fn run_bot(
+ &self,
+ player_id: u32,
+ event_bus: Arc<Mutex<EventBus>>,
+ _match_logger: MatchLogger,
+ ) -> Box<dyn PlayerHandle> {
+ let (tx, rx) = oneshot::channel();
+ let (server_msg_snd, server_msg_recv) = mpsc::unbounded_channel();
+ self.router.put(
+ self.player_key.clone(),
+ SyncThingData {
+ tx,
+ server_messages: server_msg_recv,
+ },
+ );
+
+ let fut = tokio::time::timeout(Duration::from_secs(10), rx);
+ match fut.await {
+ Ok(Ok(client_messages)) => {
+ // let client_messages = rx.await.unwrap();
+ tokio::spawn(handle_bot_messages(
+ player_id,
+ event_bus.clone(),
+ client_messages,
+ ));
+ }
+ _ => {
+ // ensure router cleanup
+ self.router.take(&self.player_key);
+ }
+ };
+
+ // If the player did not connect, the receiving half of `sender`
+ // will be dropped here, resulting in a time-out for every turn.
+ // This is fine for now, but
+ // TODO: provide a formal mechanism for player startup failure
+ Box::new(RemoteBotHandle {
+ sender: server_msg_snd,
+ player_id,
+ event_bus,
+ })
+ }
+}
+
+async fn handle_bot_messages(
+ player_id: u32,
+ event_bus: Arc<Mutex<EventBus>>,
+ mut messages: Streaming<pb::PlayerRequestResponse>,
+) {
+ while let Some(message) = messages.message().await.unwrap() {
+ let request_id = (player_id, message.request_id as u32);
+ event_bus
+ .lock()
+ .unwrap()
+ .resolve_request(request_id, Ok(message.content));
+ }
+}
+
+struct RemoteBotHandle {
+ sender: mpsc::UnboundedSender<Result<pb::PlayerRequest, Status>>,
+ player_id: u32,
+ event_bus: Arc<Mutex<EventBus>>,
+}
+
+impl PlayerHandle for RemoteBotHandle {
+ fn send_request(&mut self, r: RequestMessage) {
+ let res = self.sender.send(Ok(pb::PlayerRequest {
+ request_id: r.request_id as i32,
+ content: r.content,
+ }));
+ match res {
+ Ok(()) => {
+ // schedule a timeout. See comments at method implementation
+ tokio::spawn(schedule_timeout(
+ (self.player_id, r.request_id),
+ r.timeout,
+ self.event_bus.clone(),
+ ));
+ }
+ Err(_send_error) => {
+ // cannot contact the remote bot anymore;
+ // directly mark all requests as timed out.
+ // TODO: create a dedicated error type for this.
+ // should it be logged?
+ println!("send error: {:?}", _send_error);
+ self.event_bus
+ .lock()
+ .unwrap()
+ .resolve_request((self.player_id, r.request_id), Err(RequestError::Timeout));
+ }
+ }
+ }
+}
+
+// TODO: this will spawn a task for every request, which might not be ideal.
+// Some alternatives:
+// - create a single task that manages all time-outs.
+// - intersperse timeouts with incoming client messages
+// - push timeouts upwards, into the matchrunner logic (before we hit the playerhandle).
+// This was initially not done to allow timer start to be delayed until the message actually arrived
+// with the player. Is this still needed, or is there a different way to do this?
+//
+async fn schedule_timeout(
+ request_id: (u32, u32),
+ duration: Duration,
+ event_bus: Arc<Mutex<EventBus>>,
+) {
+ tokio::time::sleep(duration).await;
+ event_bus
+ .lock()
+ .unwrap()
+ .resolve_request(request_id, Err(RequestError::Timeout));
+}
+
+pub async fn run_bot_api(runner_config: Arc<GlobalConfig>, pool: ConnectionPool) {
+ let router = PlayerRouter::new();
+ let server = BotApiServer {
+ router,
+ conn_pool: pool,
+ runner_config,
+ };
+
+ let addr = SocketAddr::from(([127, 0, 0, 1], 50051));
+ Server::builder()
+ .add_service(pb::bot_api_service_server::BotApiServiceServer::new(server))
+ .serve(addr)
+ .await
+ .unwrap()
+}
diff --git a/planetwars-server/src/modules/bots.rs b/planetwars-server/src/modules/bots.rs
index 843e48d..5513539 100644
--- a/planetwars-server/src/modules/bots.rs
+++ b/planetwars-server/src/modules/bots.rs
@@ -2,22 +2,25 @@ use std::path::PathBuf;
use diesel::{PgConnection, QueryResult};
-use crate::{db, util::gen_alphanumeric, BOTS_DIR};
+use crate::{db, util::gen_alphanumeric, GlobalConfig};
-pub fn save_code_bundle(
+/// Save a string containing bot code as a code bundle.
+pub fn save_code_string(
bot_code: &str,
bot_id: Option<i32>,
conn: &PgConnection,
-) -> QueryResult<db::bots::CodeBundle> {
+ config: &GlobalConfig,
+) -> QueryResult<db::bots::BotVersion> {
let bundle_name = gen_alphanumeric(16);
- let code_bundle_dir = PathBuf::from(BOTS_DIR).join(&bundle_name);
+ let code_bundle_dir = PathBuf::from(&config.bots_directory).join(&bundle_name);
std::fs::create_dir(&code_bundle_dir).unwrap();
std::fs::write(code_bundle_dir.join("bot.py"), bot_code).unwrap();
- let new_code_bundle = db::bots::NewCodeBundle {
+ let new_code_bundle = db::bots::NewBotVersion {
bot_id,
- path: &bundle_name,
+ code_bundle_path: Some(&bundle_name),
+ container_digest: None,
};
- db::bots::create_code_bundle(&new_code_bundle, conn)
+ db::bots::create_bot_version(&new_code_bundle, conn)
}
diff --git a/planetwars-server/src/modules/matches.rs b/planetwars-server/src/modules/matches.rs
index a254bac..a1fe63d 100644
--- a/planetwars-server/src/modules/matches.rs
+++ b/planetwars-server/src/modules/matches.rs
@@ -1,4 +1,4 @@
-use std::path::PathBuf;
+use std::{path::PathBuf, sync::Arc};
use diesel::{PgConnection, QueryResult};
use planetwars_matchrunner::{self as runner, docker_runner::DockerBotSpec, BotSpec, MatchConfig};
@@ -11,77 +11,126 @@ use crate::{
matches::{MatchData, MatchResult},
},
util::gen_alphanumeric,
- ConnectionPool, BOTS_DIR, MAPS_DIR, MATCHES_DIR,
+ ConnectionPool, GlobalConfig,
};
-const PYTHON_IMAGE: &str = "python:3.10-slim-buster";
-
-pub struct RunMatch<'a> {
+pub struct RunMatch {
log_file_name: String,
- player_code_bundles: Vec<&'a db::bots::CodeBundle>,
- match_id: Option<i32>,
+ players: Vec<MatchPlayer>,
+ config: Arc<GlobalConfig>,
+}
+
+pub enum MatchPlayer {
+ BotVersion {
+ bot: Option<db::bots::Bot>,
+ version: db::bots::BotVersion,
+ },
+ BotSpec {
+ spec: Box<dyn BotSpec>,
+ },
}
-impl<'a> RunMatch<'a> {
- pub fn from_players(player_code_bundles: Vec<&'a db::bots::CodeBundle>) -> Self {
+impl RunMatch {
+ pub fn from_players(config: Arc<GlobalConfig>, players: Vec<MatchPlayer>) -> Self {
let log_file_name = format!("{}.log", gen_alphanumeric(16));
RunMatch {
+ config,
log_file_name,
- player_code_bundles,
- match_id: None,
+ players,
}
}
- pub fn runner_config(&self) -> runner::MatchConfig {
+ fn into_runner_config(self) -> runner::MatchConfig {
runner::MatchConfig {
- map_path: PathBuf::from(MAPS_DIR).join("hex.json"),
+ map_path: PathBuf::from(&self.config.maps_directory).join("hex.json"),
map_name: "hex".to_string(),
- log_path: PathBuf::from(MATCHES_DIR).join(&self.log_file_name),
+ log_path: PathBuf::from(&self.config.match_logs_directory).join(&self.log_file_name),
players: self
- .player_code_bundles
- .iter()
- .map(|b| runner::MatchPlayer {
- bot_spec: code_bundle_to_botspec(b),
+ .players
+ .into_iter()
+ .map(|player| runner::MatchPlayer {
+ bot_spec: match player {
+ MatchPlayer::BotVersion { bot, version } => {
+ bot_version_to_botspec(&self.config, bot.as_ref(), &version)
+ }
+ MatchPlayer::BotSpec { spec } => spec,
+ },
})
.collect(),
}
}
- pub fn store_in_database(&mut self, db_conn: &PgConnection) -> QueryResult<MatchData> {
- // don't store the same match twice
- assert!(self.match_id.is_none());
+ pub async fn run(
+ self,
+ conn_pool: ConnectionPool,
+ ) -> QueryResult<(MatchData, JoinHandle<MatchOutcome>)> {
+ let match_data = {
+ // TODO: it would be nice to get an already-open connection here when possible.
+ // Maybe we need an additional abstraction, bundling a connection and connection pool?
+ let db_conn = conn_pool.get().await.expect("could not get a connection");
+ self.store_in_database(&db_conn)?
+ };
+
+ let runner_config = self.into_runner_config();
+ let handle = tokio::spawn(run_match_task(conn_pool, runner_config, match_data.base.id));
+ Ok((match_data, handle))
+ }
+
+ fn store_in_database(&self, db_conn: &PgConnection) -> QueryResult<MatchData> {
let new_match_data = db::matches::NewMatch {
state: db::matches::MatchState::Playing,
log_path: &self.log_file_name,
};
let new_match_players = self
- .player_code_bundles
+ .players
.iter()
- .map(|b| db::matches::MatchPlayerData {
- code_bundle_id: b.id,
+ .map(|p| db::matches::MatchPlayerData {
+ code_bundle_id: match p {
+ MatchPlayer::BotVersion { version, .. } => Some(version.id),
+ MatchPlayer::BotSpec { .. } => None,
+ },
})
.collect::<Vec<_>>();
- let match_data = db::matches::create_match(&new_match_data, &new_match_players, &db_conn)?;
- self.match_id = Some(match_data.base.id);
- Ok(match_data)
+ db::matches::create_match(&new_match_data, &new_match_players, db_conn)
}
+}
- pub fn spawn(self, pool: ConnectionPool) -> JoinHandle<MatchOutcome> {
- let match_id = self.match_id.expect("match must be saved before running");
- let runner_config = self.runner_config();
- tokio::spawn(run_match_task(pool, runner_config, match_id))
+pub fn bot_version_to_botspec(
+ runner_config: &GlobalConfig,
+ bot: Option<&db::bots::Bot>,
+ bot_version: &db::bots::BotVersion,
+) -> Box<dyn BotSpec> {
+ if let Some(code_bundle_path) = &bot_version.code_bundle_path {
+ python_docker_bot_spec(runner_config, code_bundle_path)
+ } else if let (Some(container_digest), Some(bot)) = (&bot_version.container_digest, bot) {
+ Box::new(DockerBotSpec {
+ image: format!(
+ "{}/{}@{}",
+ runner_config.container_registry_url, bot.name, container_digest
+ ),
+ binds: None,
+ argv: None,
+ working_dir: None,
+ })
+ } else {
+ // TODO: ideally this would not be possible
+ panic!("bad bot version")
}
}
-pub fn code_bundle_to_botspec(code_bundle: &db::bots::CodeBundle) -> Box<dyn BotSpec> {
- let bundle_path = PathBuf::from(BOTS_DIR).join(&code_bundle.path);
+fn python_docker_bot_spec(config: &GlobalConfig, code_bundle_path: &str) -> Box<dyn BotSpec> {
+ let code_bundle_rel_path = PathBuf::from(&config.bots_directory).join(code_bundle_path);
+ let code_bundle_abs_path = std::fs::canonicalize(&code_bundle_rel_path).unwrap();
+ let code_bundle_path_str = code_bundle_abs_path.as_os_str().to_str().unwrap();
+ // TODO: it would be good to simplify this configuration
Box::new(DockerBotSpec {
- code_path: bundle_path,
- image: PYTHON_IMAGE.to_string(),
- argv: vec!["python".to_string(), "bot.py".to_string()],
+ image: config.python_runner_image.clone(),
+ binds: Some(vec![format!("{}:{}", code_bundle_path_str, "/workdir")]),
+ argv: Some(vec!["python".to_string(), "bot.py".to_string()]),
+ working_dir: Some("/workdir".to_string()),
})
}
@@ -104,5 +153,5 @@ async fn run_match_task(
db::matches::save_match_result(match_id, result, &conn).expect("could not save match result");
- return outcome;
+ outcome
}
diff --git a/planetwars-server/src/modules/mod.rs b/planetwars-server/src/modules/mod.rs
index bea28e0..1200f9d 100644
--- a/planetwars-server/src/modules/mod.rs
+++ b/planetwars-server/src/modules/mod.rs
@@ -1,5 +1,7 @@
// This module implements general domain logic, not directly
// tied to the database or API layers.
+pub mod bot_api;
pub mod bots;
pub mod matches;
pub mod ranking;
+pub mod registry;
diff --git a/planetwars-server/src/modules/ranking.rs b/planetwars-server/src/modules/ranking.rs
index 5d496d7..a9f6419 100644
--- a/planetwars-server/src/modules/ranking.rs
+++ b/planetwars-server/src/modules/ranking.rs
@@ -1,17 +1,18 @@
-use crate::{db::bots::Bot, DbPool};
+use crate::{db::bots::Bot, DbPool, GlobalConfig};
use crate::db;
-use crate::modules::matches::RunMatch;
+use crate::modules::matches::{MatchPlayer, RunMatch};
use diesel::{PgConnection, QueryResult};
use rand::seq::SliceRandom;
use std::collections::HashMap;
use std::mem;
+use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio;
const RANKER_INTERVAL: u64 = 60;
-pub async fn run_ranker(db_pool: DbPool) {
+pub async fn run_ranker(config: Arc<GlobalConfig>, db_pool: DbPool) {
// TODO: make this configurable
// play at most one match every n seconds
let mut interval = tokio::time::interval(Duration::from_secs(RANKER_INTERVAL));
@@ -30,30 +31,30 @@ pub async fn run_ranker(db_pool: DbPool) {
let mut rng = &mut rand::thread_rng();
bots.choose_multiple(&mut rng, 2).cloned().collect()
};
- play_ranking_match(selected_bots, db_pool.clone()).await;
+ play_ranking_match(config.clone(), selected_bots, db_pool.clone()).await;
recalculate_ratings(&db_conn).expect("could not recalculate ratings");
}
}
-async fn play_ranking_match(selected_bots: Vec<Bot>, db_pool: DbPool) {
+async fn play_ranking_match(config: Arc<GlobalConfig>, selected_bots: Vec<Bot>, db_pool: DbPool) {
let db_conn = db_pool.get().await.expect("could not get db pool");
- let mut code_bundles = Vec::new();
+ let mut players = Vec::new();
for bot in &selected_bots {
- let code_bundle = db::bots::active_code_bundle(bot.id, &db_conn)
- .expect("could not get active code bundle");
- code_bundles.push(code_bundle);
+ let version = db::bots::active_bot_version(bot.id, &db_conn)
+ .expect("could not get active bot version");
+ let player = MatchPlayer::BotVersion {
+ bot: Some(bot.clone()),
+ version,
+ };
+ players.push(player);
}
- let code_bundle_refs = code_bundles.iter().collect::<Vec<_>>();
-
- let mut run_match = RunMatch::from_players(code_bundle_refs);
- run_match
- .store_in_database(&db_conn)
- .expect("could not store match in db");
- run_match
- .spawn(db_pool.clone())
+ let (_, handle) = RunMatch::from_players(config, players)
+ .run(db_pool.clone())
.await
- .expect("running match failed");
+ .expect("failed to run match");
+ // wait for match to complete, so that only one ranking match can be running
+ let _outcome = handle.await;
}
fn recalculate_ratings(db_conn: &PgConnection) -> QueryResult<()> {
diff --git a/planetwars-server/src/modules/registry.rs b/planetwars-server/src/modules/registry.rs
new file mode 100644
index 0000000..3f6dad2
--- /dev/null
+++ b/planetwars-server/src/modules/registry.rs
@@ -0,0 +1,440 @@
+// TODO: this module is functional, but it needs a good refactor for proper error handling.
+
+use axum::body::{Body, StreamBody};
+use axum::extract::{BodyStream, FromRequest, Path, Query, RequestParts, TypedHeader};
+use axum::headers::authorization::Basic;
+use axum::headers::Authorization;
+use axum::response::{IntoResponse, Response};
+use axum::routing::{get, head, post, put};
+use axum::{async_trait, Extension, Router};
+use futures::StreamExt;
+use hyper::StatusCode;
+use serde::Serialize;
+use sha2::{Digest, Sha256};
+use std::path::PathBuf;
+use std::sync::Arc;
+use tokio::io::AsyncWriteExt;
+use tokio_util::io::ReaderStream;
+
+use crate::db::bots::NewBotVersion;
+use crate::util::gen_alphanumeric;
+use crate::{db, DatabaseConnection, GlobalConfig};
+
+use crate::db::users::{authenticate_user, Credentials, User};
+
+pub fn registry_service() -> Router {
+ Router::new()
+ // The docker API requires this trailing slash
+ .nest("/v2/", registry_api_v2())
+}
+
+fn registry_api_v2() -> Router {
+ Router::new()
+ .route("/", get(get_root))
+ .route(
+ "/:name/manifests/:reference",
+ get(get_manifest).put(put_manifest),
+ )
+ .route(
+ "/:name/blobs/:digest",
+ head(check_blob_exists).get(get_blob),
+ )
+ .route("/:name/blobs/uploads/", post(create_upload))
+ .route(
+ "/:name/blobs/uploads/:uuid",
+ put(put_upload).patch(patch_upload),
+ )
+}
+
+const ADMIN_USERNAME: &str = "admin";
+
+type AuthorizationHeader = TypedHeader<Authorization<Basic>>;
+
+enum RegistryAuth {
+ User(User),
+ Admin,
+}
+
+enum RegistryAuthError {
+ NoAuthHeader,
+ InvalidCredentials,
+}
+
+impl IntoResponse for RegistryAuthError {
+ fn into_response(self) -> Response {
+ // TODO: create enum for registry errors
+ let err = RegistryErrors {
+ errors: vec![RegistryError {
+ code: "UNAUTHORIZED".to_string(),
+ message: "please log in".to_string(),
+ detail: serde_json::Value::Null,
+ }],
+ };
+
+ (
+ StatusCode::UNAUTHORIZED,
+ [
+ ("Docker-Distribution-API-Version", "registry/2.0"),
+ ("WWW-Authenticate", "Basic"),
+ ],
+ serde_json::to_string(&err).unwrap(),
+ )
+ .into_response()
+ }
+}
+
+#[async_trait]
+impl<B> FromRequest<B> for RegistryAuth
+where
+ B: Send,
+{
+ type Rejection = RegistryAuthError;
+
+ async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
+ let TypedHeader(Authorization(basic)) = AuthorizationHeader::from_request(req)
+ .await
+ .map_err(|_| RegistryAuthError::NoAuthHeader)?;
+
+ // TODO: Into<Credentials> would be nice
+ let credentials = Credentials {
+ username: basic.username(),
+ password: basic.password(),
+ };
+
+ let Extension(config) = Extension::<Arc<GlobalConfig>>::from_request(req)
+ .await
+ .unwrap();
+
+ if credentials.username == ADMIN_USERNAME {
+ if credentials.password == config.registry_admin_password {
+ Ok(RegistryAuth::Admin)
+ } else {
+ Err(RegistryAuthError::InvalidCredentials)
+ }
+ } else {
+ let db_conn = DatabaseConnection::from_request(req).await.unwrap();
+ let user = authenticate_user(&credentials, &db_conn)
+ .ok_or(RegistryAuthError::InvalidCredentials)?;
+
+ Ok(RegistryAuth::User(user))
+ }
+ }
+}
+
+// Since async file io just calls spawn_blocking internally, it does not really make sense
+// to make this an async function
+fn file_sha256_digest(path: &std::path::Path) -> std::io::Result<String> {
+ let mut file = std::fs::File::open(path)?;
+ let mut hasher = Sha256::new();
+ let _n = std::io::copy(&mut file, &mut hasher)?;
+ Ok(format!("{:x}", hasher.finalize()))
+}
+
+/// Get the index of the last byte in a file
+async fn last_byte_pos(file: &tokio::fs::File) -> std::io::Result<u64> {
+ let n_bytes = file.metadata().await?.len();
+ let pos = if n_bytes == 0 { 0 } else { n_bytes - 1 };
+ Ok(pos)
+}
+
+async fn get_root(_auth: RegistryAuth) -> impl IntoResponse {
+ // root should return 200 OK to confirm api compliance
+ Response::builder()
+ .status(StatusCode::OK)
+ .header("Docker-Distribution-API-Version", "registry/2.0")
+ .body(Body::empty())
+ .unwrap()
+}
+
+#[derive(Serialize)]
+pub struct RegistryErrors {
+ errors: Vec<RegistryError>,
+}
+
+#[derive(Serialize)]
+pub struct RegistryError {
+ code: String,
+ message: String,
+ detail: serde_json::Value,
+}
+
+async fn check_blob_exists(
+ db_conn: DatabaseConnection,
+ auth: RegistryAuth,
+ Path((repository_name, raw_digest)): Path<(String, String)>,
+ Extension(config): Extension<Arc<GlobalConfig>>,
+) -> Result<impl IntoResponse, StatusCode> {
+ check_access(&repository_name, &auth, &db_conn)?;
+
+ let digest = raw_digest.strip_prefix("sha256:").unwrap();
+ let blob_path = PathBuf::from(&config.registry_directory)
+ .join("sha256")
+ .join(&digest);
+ if blob_path.exists() {
+ let metadata = std::fs::metadata(&blob_path).unwrap();
+ Ok((StatusCode::OK, [("Content-Length", metadata.len())]))
+ } else {
+ Err(StatusCode::NOT_FOUND)
+ }
+}
+
+async fn get_blob(
+ db_conn: DatabaseConnection,
+ auth: RegistryAuth,
+ Path((repository_name, raw_digest)): Path<(String, String)>,
+ Extension(config): Extension<Arc<GlobalConfig>>,
+) -> Result<impl IntoResponse, StatusCode> {
+ check_access(&repository_name, &auth, &db_conn)?;
+
+ let digest = raw_digest.strip_prefix("sha256:").unwrap();
+ let blob_path = PathBuf::from(&config.registry_directory)
+ .join("sha256")
+ .join(&digest);
+ if !blob_path.exists() {
+ return Err(StatusCode::NOT_FOUND);
+ }
+ let file = tokio::fs::File::open(&blob_path).await.unwrap();
+ let reader_stream = ReaderStream::new(file);
+ let stream_body = StreamBody::new(reader_stream);
+ Ok(stream_body)
+}
+
+async fn create_upload(
+ db_conn: DatabaseConnection,
+ auth: RegistryAuth,
+ Path(repository_name): Path<String>,
+ Extension(config): Extension<Arc<GlobalConfig>>,
+) -> Result<impl IntoResponse, StatusCode> {
+ check_access(&repository_name, &auth, &db_conn)?;
+
+ let uuid = gen_alphanumeric(16);
+ tokio::fs::File::create(
+ PathBuf::from(&config.registry_directory)
+ .join("uploads")
+ .join(&uuid),
+ )
+ .await
+ .unwrap();
+
+ Ok(Response::builder()
+ .status(StatusCode::ACCEPTED)
+ .header(
+ "Location",
+ format!("/v2/{}/blobs/uploads/{}", repository_name, uuid),
+ )
+ .header("Docker-Upload-UUID", uuid)
+ .header("Range", "bytes=0-0")
+ .body(Body::empty())
+ .unwrap())
+}
+
+async fn patch_upload(
+ db_conn: DatabaseConnection,
+ auth: RegistryAuth,
+ Path((repository_name, uuid)): Path<(String, String)>,
+ mut stream: BodyStream,
+ Extension(config): Extension<Arc<GlobalConfig>>,
+) -> Result<impl IntoResponse, StatusCode> {
+ check_access(&repository_name, &auth, &db_conn)?;
+
+ // TODO: support content range header in request
+ let upload_path = PathBuf::from(&config.registry_directory)
+ .join("uploads")
+ .join(&uuid);
+ let mut file = tokio::fs::OpenOptions::new()
+ .read(false)
+ .write(true)
+ .append(true)
+ .create(false)
+ .open(upload_path)
+ .await
+ .unwrap();
+ while let Some(Ok(chunk)) = stream.next().await {
+ file.write_all(&chunk).await.unwrap();
+ }
+
+ let last_byte = last_byte_pos(&file).await.unwrap();
+
+ Ok(Response::builder()
+ .status(StatusCode::ACCEPTED)
+ .header(
+ "Location",
+ format!("/v2/{}/blobs/uploads/{}", repository_name, uuid),
+ )
+ .header("Docker-Upload-UUID", uuid)
+ // range indicating current progress of the upload
+ .header("Range", format!("0-{}", last_byte))
+ .body(Body::empty())
+ .unwrap())
+}
+
+use serde::Deserialize;
+#[derive(Deserialize)]
+struct UploadParams {
+ digest: String,
+}
+
+async fn put_upload(
+ db_conn: DatabaseConnection,
+ auth: RegistryAuth,
+ Path((repository_name, uuid)): Path<(String, String)>,
+ Query(params): Query<UploadParams>,
+ mut stream: BodyStream,
+ Extension(config): Extension<Arc<GlobalConfig>>,
+) -> Result<impl IntoResponse, StatusCode> {
+ check_access(&repository_name, &auth, &db_conn)?;
+
+ let upload_path = PathBuf::from(&config.registry_directory)
+ .join("uploads")
+ .join(&uuid);
+ let mut file = tokio::fs::OpenOptions::new()
+ .read(false)
+ .write(true)
+ .append(true)
+ .create(false)
+ .open(&upload_path)
+ .await
+ .unwrap();
+
+ let range_begin = last_byte_pos(&file).await.unwrap();
+ while let Some(Ok(chunk)) = stream.next().await {
+ file.write_all(&chunk).await.unwrap();
+ }
+ file.flush().await.unwrap();
+ let range_end = last_byte_pos(&file).await.unwrap();
+
+ let expected_digest = params.digest.strip_prefix("sha256:").unwrap();
+ let digest = file_sha256_digest(&upload_path).unwrap();
+ if digest != expected_digest {
+ // TODO: return a docker error body
+ return Err(StatusCode::BAD_REQUEST);
+ }
+
+ let target_path = PathBuf::from(&config.registry_directory)
+ .join("sha256")
+ .join(&digest);
+ tokio::fs::rename(&upload_path, &target_path).await.unwrap();
+
+ Ok(Response::builder()
+ .status(StatusCode::CREATED)
+ .header(
+ "Location",
+ format!("/v2/{}/blobs/{}", repository_name, digest),
+ )
+ .header("Docker-Upload-UUID", uuid)
+ // content range for bytes that were in the body of this request
+ .header("Content-Range", format!("{}-{}", range_begin, range_end))
+ .header("Docker-Content-Digest", params.digest)
+ .body(Body::empty())
+ .unwrap())
+}
+
+async fn get_manifest(
+ db_conn: DatabaseConnection,
+ auth: RegistryAuth,
+ Path((repository_name, reference)): Path<(String, String)>,
+ Extension(config): Extension<Arc<GlobalConfig>>,
+) -> Result<impl IntoResponse, StatusCode> {
+ check_access(&repository_name, &auth, &db_conn)?;
+
+ let manifest_path = PathBuf::from(&config.registry_directory)
+ .join("manifests")
+ .join(&repository_name)
+ .join(&reference)
+ .with_extension("json");
+ let data = tokio::fs::read(&manifest_path).await.unwrap();
+
+ let manifest: serde_json::Map<String, serde_json::Value> =
+ serde_json::from_slice(&data).unwrap();
+ let media_type = manifest.get("mediaType").unwrap().as_str().unwrap();
+ Ok(Response::builder()
+ .status(StatusCode::OK)
+ .header("Content-Type", media_type)
+ .body(axum::body::Full::from(data))
+ .unwrap())
+}
+
+async fn put_manifest(
+ db_conn: DatabaseConnection,
+ auth: RegistryAuth,
+ Path((repository_name, reference)): Path<(String, String)>,
+ mut stream: BodyStream,
+ Extension(config): Extension<Arc<GlobalConfig>>,
+) -> Result<impl IntoResponse, StatusCode> {
+ let bot = check_access(&repository_name, &auth, &db_conn)?;
+
+ let repository_dir = PathBuf::from(&config.registry_directory)
+ .join("manifests")
+ .join(&repository_name);
+
+ tokio::fs::create_dir_all(&repository_dir).await.unwrap();
+
+ let mut hasher = Sha256::new();
+ let manifest_path = repository_dir.join(&reference).with_extension("json");
+ {
+ let mut file = tokio::fs::OpenOptions::new()
+ .write(true)
+ .create(true)
+ .truncate(true)
+ .open(&manifest_path)
+ .await
+ .unwrap();
+ while let Some(Ok(chunk)) = stream.next().await {
+ hasher.update(&chunk);
+ file.write_all(&chunk).await.unwrap();
+ }
+ }
+ let digest = hasher.finalize();
+ // TODO: store content-adressable manifests separately
+ let content_digest = format!("sha256:{:x}", digest);
+ let digest_path = repository_dir.join(&content_digest).with_extension("json");
+ tokio::fs::copy(manifest_path, digest_path).await.unwrap();
+
+ // Register the new image as a bot version
+ // TODO: how should tags be handled?
+ let new_version = NewBotVersion {
+ bot_id: Some(bot.id),
+ code_bundle_path: None,
+ container_digest: Some(&content_digest),
+ };
+ db::bots::create_bot_version(&new_version, &db_conn).expect("could not save bot version");
+
+ Ok(Response::builder()
+ .status(StatusCode::CREATED)
+ .header(
+ "Location",
+ format!("/v2/{}/manifests/{}", repository_name, reference),
+ )
+ .header("Docker-Content-Digest", content_digest)
+ .body(Body::empty())
+ .unwrap())
+}
+
+/// Ensure that the accessed repository exists
+/// and the user is allowed to access it.
+/// Returns the associated bot.
+fn check_access(
+ repository_name: &str,
+ auth: &RegistryAuth,
+ db_conn: &DatabaseConnection,
+) -> Result<db::bots::Bot, StatusCode> {
+ use diesel::OptionalExtension;
+
+ // TODO: it would be nice to provide the found repository
+ // to the route handlers
+ let bot = db::bots::find_bot_by_name(repository_name, db_conn)
+ .optional()
+ .expect("could not run query")
+ .ok_or(StatusCode::NOT_FOUND)?;
+
+ match &auth {
+ RegistryAuth::Admin => Ok(bot),
+ RegistryAuth::User(user) => {
+ if bot.owner_id == Some(user.id) {
+ Ok(bot)
+ } else {
+ Err(StatusCode::FORBIDDEN)
+ }
+ }
+ }
+}