aboutsummaryrefslogtreecommitdiff
path: root/planetwars-server/src/modules/registry.rs
diff options
context:
space:
mode:
authorIlion Beyst <ilion.beyst@gmail.com>2022-07-18 21:03:34 +0200
committerIlion Beyst <ilion.beyst@gmail.com>2022-07-18 21:03:34 +0200
commit7daf8f643798ce76733006f8469890bf1a3fd05e (patch)
treed755c4f0979f56792684b109d263624aef4baaad /planetwars-server/src/modules/registry.rs
parent608d05bc167c57d190d3c06f250b5e4a5662e77e (diff)
parentd092f5d89c0fda5cc67349d5489b4ef1b294e053 (diff)
downloadplanetwars.dev-7daf8f643798ce76733006f8469890bf1a3fd05e.tar.xz
planetwars.dev-7daf8f643798ce76733006f8469890bf1a3fd05e.zip
Merge branch 'next'
Diffstat (limited to 'planetwars-server/src/modules/registry.rs')
-rw-r--r--planetwars-server/src/modules/registry.rs440
1 files changed, 440 insertions, 0 deletions
diff --git a/planetwars-server/src/modules/registry.rs b/planetwars-server/src/modules/registry.rs
new file mode 100644
index 0000000..3f6dad2
--- /dev/null
+++ b/planetwars-server/src/modules/registry.rs
@@ -0,0 +1,440 @@
+// TODO: this module is functional, but it needs a good refactor for proper error handling.
+
+use axum::body::{Body, StreamBody};
+use axum::extract::{BodyStream, FromRequest, Path, Query, RequestParts, TypedHeader};
+use axum::headers::authorization::Basic;
+use axum::headers::Authorization;
+use axum::response::{IntoResponse, Response};
+use axum::routing::{get, head, post, put};
+use axum::{async_trait, Extension, Router};
+use futures::StreamExt;
+use hyper::StatusCode;
+use serde::Serialize;
+use sha2::{Digest, Sha256};
+use std::path::PathBuf;
+use std::sync::Arc;
+use tokio::io::AsyncWriteExt;
+use tokio_util::io::ReaderStream;
+
+use crate::db::bots::NewBotVersion;
+use crate::util::gen_alphanumeric;
+use crate::{db, DatabaseConnection, GlobalConfig};
+
+use crate::db::users::{authenticate_user, Credentials, User};
+
+pub fn registry_service() -> Router {
+ Router::new()
+ // The docker API requires this trailing slash
+ .nest("/v2/", registry_api_v2())
+}
+
+fn registry_api_v2() -> Router {
+ Router::new()
+ .route("/", get(get_root))
+ .route(
+ "/:name/manifests/:reference",
+ get(get_manifest).put(put_manifest),
+ )
+ .route(
+ "/:name/blobs/:digest",
+ head(check_blob_exists).get(get_blob),
+ )
+ .route("/:name/blobs/uploads/", post(create_upload))
+ .route(
+ "/:name/blobs/uploads/:uuid",
+ put(put_upload).patch(patch_upload),
+ )
+}
+
+const ADMIN_USERNAME: &str = "admin";
+
+type AuthorizationHeader = TypedHeader<Authorization<Basic>>;
+
+enum RegistryAuth {
+ User(User),
+ Admin,
+}
+
+enum RegistryAuthError {
+ NoAuthHeader,
+ InvalidCredentials,
+}
+
+impl IntoResponse for RegistryAuthError {
+ fn into_response(self) -> Response {
+ // TODO: create enum for registry errors
+ let err = RegistryErrors {
+ errors: vec![RegistryError {
+ code: "UNAUTHORIZED".to_string(),
+ message: "please log in".to_string(),
+ detail: serde_json::Value::Null,
+ }],
+ };
+
+ (
+ StatusCode::UNAUTHORIZED,
+ [
+ ("Docker-Distribution-API-Version", "registry/2.0"),
+ ("WWW-Authenticate", "Basic"),
+ ],
+ serde_json::to_string(&err).unwrap(),
+ )
+ .into_response()
+ }
+}
+
+#[async_trait]
+impl<B> FromRequest<B> for RegistryAuth
+where
+ B: Send,
+{
+ type Rejection = RegistryAuthError;
+
+ async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
+ let TypedHeader(Authorization(basic)) = AuthorizationHeader::from_request(req)
+ .await
+ .map_err(|_| RegistryAuthError::NoAuthHeader)?;
+
+ // TODO: Into<Credentials> would be nice
+ let credentials = Credentials {
+ username: basic.username(),
+ password: basic.password(),
+ };
+
+ let Extension(config) = Extension::<Arc<GlobalConfig>>::from_request(req)
+ .await
+ .unwrap();
+
+ if credentials.username == ADMIN_USERNAME {
+ if credentials.password == config.registry_admin_password {
+ Ok(RegistryAuth::Admin)
+ } else {
+ Err(RegistryAuthError::InvalidCredentials)
+ }
+ } else {
+ let db_conn = DatabaseConnection::from_request(req).await.unwrap();
+ let user = authenticate_user(&credentials, &db_conn)
+ .ok_or(RegistryAuthError::InvalidCredentials)?;
+
+ Ok(RegistryAuth::User(user))
+ }
+ }
+}
+
+// Since async file io just calls spawn_blocking internally, it does not really make sense
+// to make this an async function
+fn file_sha256_digest(path: &std::path::Path) -> std::io::Result<String> {
+ let mut file = std::fs::File::open(path)?;
+ let mut hasher = Sha256::new();
+ let _n = std::io::copy(&mut file, &mut hasher)?;
+ Ok(format!("{:x}", hasher.finalize()))
+}
+
+/// Get the index of the last byte in a file
+async fn last_byte_pos(file: &tokio::fs::File) -> std::io::Result<u64> {
+ let n_bytes = file.metadata().await?.len();
+ let pos = if n_bytes == 0 { 0 } else { n_bytes - 1 };
+ Ok(pos)
+}
+
+async fn get_root(_auth: RegistryAuth) -> impl IntoResponse {
+ // root should return 200 OK to confirm api compliance
+ Response::builder()
+ .status(StatusCode::OK)
+ .header("Docker-Distribution-API-Version", "registry/2.0")
+ .body(Body::empty())
+ .unwrap()
+}
+
+#[derive(Serialize)]
+pub struct RegistryErrors {
+ errors: Vec<RegistryError>,
+}
+
+#[derive(Serialize)]
+pub struct RegistryError {
+ code: String,
+ message: String,
+ detail: serde_json::Value,
+}
+
+async fn check_blob_exists(
+ db_conn: DatabaseConnection,
+ auth: RegistryAuth,
+ Path((repository_name, raw_digest)): Path<(String, String)>,
+ Extension(config): Extension<Arc<GlobalConfig>>,
+) -> Result<impl IntoResponse, StatusCode> {
+ check_access(&repository_name, &auth, &db_conn)?;
+
+ let digest = raw_digest.strip_prefix("sha256:").unwrap();
+ let blob_path = PathBuf::from(&config.registry_directory)
+ .join("sha256")
+ .join(&digest);
+ if blob_path.exists() {
+ let metadata = std::fs::metadata(&blob_path).unwrap();
+ Ok((StatusCode::OK, [("Content-Length", metadata.len())]))
+ } else {
+ Err(StatusCode::NOT_FOUND)
+ }
+}
+
+async fn get_blob(
+ db_conn: DatabaseConnection,
+ auth: RegistryAuth,
+ Path((repository_name, raw_digest)): Path<(String, String)>,
+ Extension(config): Extension<Arc<GlobalConfig>>,
+) -> Result<impl IntoResponse, StatusCode> {
+ check_access(&repository_name, &auth, &db_conn)?;
+
+ let digest = raw_digest.strip_prefix("sha256:").unwrap();
+ let blob_path = PathBuf::from(&config.registry_directory)
+ .join("sha256")
+ .join(&digest);
+ if !blob_path.exists() {
+ return Err(StatusCode::NOT_FOUND);
+ }
+ let file = tokio::fs::File::open(&blob_path).await.unwrap();
+ let reader_stream = ReaderStream::new(file);
+ let stream_body = StreamBody::new(reader_stream);
+ Ok(stream_body)
+}
+
+async fn create_upload(
+ db_conn: DatabaseConnection,
+ auth: RegistryAuth,
+ Path(repository_name): Path<String>,
+ Extension(config): Extension<Arc<GlobalConfig>>,
+) -> Result<impl IntoResponse, StatusCode> {
+ check_access(&repository_name, &auth, &db_conn)?;
+
+ let uuid = gen_alphanumeric(16);
+ tokio::fs::File::create(
+ PathBuf::from(&config.registry_directory)
+ .join("uploads")
+ .join(&uuid),
+ )
+ .await
+ .unwrap();
+
+ Ok(Response::builder()
+ .status(StatusCode::ACCEPTED)
+ .header(
+ "Location",
+ format!("/v2/{}/blobs/uploads/{}", repository_name, uuid),
+ )
+ .header("Docker-Upload-UUID", uuid)
+ .header("Range", "bytes=0-0")
+ .body(Body::empty())
+ .unwrap())
+}
+
+async fn patch_upload(
+ db_conn: DatabaseConnection,
+ auth: RegistryAuth,
+ Path((repository_name, uuid)): Path<(String, String)>,
+ mut stream: BodyStream,
+ Extension(config): Extension<Arc<GlobalConfig>>,
+) -> Result<impl IntoResponse, StatusCode> {
+ check_access(&repository_name, &auth, &db_conn)?;
+
+ // TODO: support content range header in request
+ let upload_path = PathBuf::from(&config.registry_directory)
+ .join("uploads")
+ .join(&uuid);
+ let mut file = tokio::fs::OpenOptions::new()
+ .read(false)
+ .write(true)
+ .append(true)
+ .create(false)
+ .open(upload_path)
+ .await
+ .unwrap();
+ while let Some(Ok(chunk)) = stream.next().await {
+ file.write_all(&chunk).await.unwrap();
+ }
+
+ let last_byte = last_byte_pos(&file).await.unwrap();
+
+ Ok(Response::builder()
+ .status(StatusCode::ACCEPTED)
+ .header(
+ "Location",
+ format!("/v2/{}/blobs/uploads/{}", repository_name, uuid),
+ )
+ .header("Docker-Upload-UUID", uuid)
+ // range indicating current progress of the upload
+ .header("Range", format!("0-{}", last_byte))
+ .body(Body::empty())
+ .unwrap())
+}
+
+use serde::Deserialize;
+#[derive(Deserialize)]
+struct UploadParams {
+ digest: String,
+}
+
+async fn put_upload(
+ db_conn: DatabaseConnection,
+ auth: RegistryAuth,
+ Path((repository_name, uuid)): Path<(String, String)>,
+ Query(params): Query<UploadParams>,
+ mut stream: BodyStream,
+ Extension(config): Extension<Arc<GlobalConfig>>,
+) -> Result<impl IntoResponse, StatusCode> {
+ check_access(&repository_name, &auth, &db_conn)?;
+
+ let upload_path = PathBuf::from(&config.registry_directory)
+ .join("uploads")
+ .join(&uuid);
+ let mut file = tokio::fs::OpenOptions::new()
+ .read(false)
+ .write(true)
+ .append(true)
+ .create(false)
+ .open(&upload_path)
+ .await
+ .unwrap();
+
+ let range_begin = last_byte_pos(&file).await.unwrap();
+ while let Some(Ok(chunk)) = stream.next().await {
+ file.write_all(&chunk).await.unwrap();
+ }
+ file.flush().await.unwrap();
+ let range_end = last_byte_pos(&file).await.unwrap();
+
+ let expected_digest = params.digest.strip_prefix("sha256:").unwrap();
+ let digest = file_sha256_digest(&upload_path).unwrap();
+ if digest != expected_digest {
+ // TODO: return a docker error body
+ return Err(StatusCode::BAD_REQUEST);
+ }
+
+ let target_path = PathBuf::from(&config.registry_directory)
+ .join("sha256")
+ .join(&digest);
+ tokio::fs::rename(&upload_path, &target_path).await.unwrap();
+
+ Ok(Response::builder()
+ .status(StatusCode::CREATED)
+ .header(
+ "Location",
+ format!("/v2/{}/blobs/{}", repository_name, digest),
+ )
+ .header("Docker-Upload-UUID", uuid)
+ // content range for bytes that were in the body of this request
+ .header("Content-Range", format!("{}-{}", range_begin, range_end))
+ .header("Docker-Content-Digest", params.digest)
+ .body(Body::empty())
+ .unwrap())
+}
+
+async fn get_manifest(
+ db_conn: DatabaseConnection,
+ auth: RegistryAuth,
+ Path((repository_name, reference)): Path<(String, String)>,
+ Extension(config): Extension<Arc<GlobalConfig>>,
+) -> Result<impl IntoResponse, StatusCode> {
+ check_access(&repository_name, &auth, &db_conn)?;
+
+ let manifest_path = PathBuf::from(&config.registry_directory)
+ .join("manifests")
+ .join(&repository_name)
+ .join(&reference)
+ .with_extension("json");
+ let data = tokio::fs::read(&manifest_path).await.unwrap();
+
+ let manifest: serde_json::Map<String, serde_json::Value> =
+ serde_json::from_slice(&data).unwrap();
+ let media_type = manifest.get("mediaType").unwrap().as_str().unwrap();
+ Ok(Response::builder()
+ .status(StatusCode::OK)
+ .header("Content-Type", media_type)
+ .body(axum::body::Full::from(data))
+ .unwrap())
+}
+
+async fn put_manifest(
+ db_conn: DatabaseConnection,
+ auth: RegistryAuth,
+ Path((repository_name, reference)): Path<(String, String)>,
+ mut stream: BodyStream,
+ Extension(config): Extension<Arc<GlobalConfig>>,
+) -> Result<impl IntoResponse, StatusCode> {
+ let bot = check_access(&repository_name, &auth, &db_conn)?;
+
+ let repository_dir = PathBuf::from(&config.registry_directory)
+ .join("manifests")
+ .join(&repository_name);
+
+ tokio::fs::create_dir_all(&repository_dir).await.unwrap();
+
+ let mut hasher = Sha256::new();
+ let manifest_path = repository_dir.join(&reference).with_extension("json");
+ {
+ let mut file = tokio::fs::OpenOptions::new()
+ .write(true)
+ .create(true)
+ .truncate(true)
+ .open(&manifest_path)
+ .await
+ .unwrap();
+ while let Some(Ok(chunk)) = stream.next().await {
+ hasher.update(&chunk);
+ file.write_all(&chunk).await.unwrap();
+ }
+ }
+ let digest = hasher.finalize();
+ // TODO: store content-adressable manifests separately
+ let content_digest = format!("sha256:{:x}", digest);
+ let digest_path = repository_dir.join(&content_digest).with_extension("json");
+ tokio::fs::copy(manifest_path, digest_path).await.unwrap();
+
+ // Register the new image as a bot version
+ // TODO: how should tags be handled?
+ let new_version = NewBotVersion {
+ bot_id: Some(bot.id),
+ code_bundle_path: None,
+ container_digest: Some(&content_digest),
+ };
+ db::bots::create_bot_version(&new_version, &db_conn).expect("could not save bot version");
+
+ Ok(Response::builder()
+ .status(StatusCode::CREATED)
+ .header(
+ "Location",
+ format!("/v2/{}/manifests/{}", repository_name, reference),
+ )
+ .header("Docker-Content-Digest", content_digest)
+ .body(Body::empty())
+ .unwrap())
+}
+
+/// Ensure that the accessed repository exists
+/// and the user is allowed to access it.
+/// Returns the associated bot.
+fn check_access(
+ repository_name: &str,
+ auth: &RegistryAuth,
+ db_conn: &DatabaseConnection,
+) -> Result<db::bots::Bot, StatusCode> {
+ use diesel::OptionalExtension;
+
+ // TODO: it would be nice to provide the found repository
+ // to the route handlers
+ let bot = db::bots::find_bot_by_name(repository_name, db_conn)
+ .optional()
+ .expect("could not run query")
+ .ok_or(StatusCode::NOT_FOUND)?;
+
+ match &auth {
+ RegistryAuth::Admin => Ok(bot),
+ RegistryAuth::User(user) => {
+ if bot.owner_id == Some(user.id) {
+ Ok(bot)
+ } else {
+ Err(StatusCode::FORBIDDEN)
+ }
+ }
+ }
+}