diff options
Diffstat (limited to 'planetwars-server/src/modules/registry.rs')
-rw-r--r-- | planetwars-server/src/modules/registry.rs | 106 |
1 files changed, 75 insertions, 31 deletions
diff --git a/planetwars-server/src/modules/registry.rs b/planetwars-server/src/modules/registry.rs index a866dce..8bc3a7d 100644 --- a/planetwars-server/src/modules/registry.rs +++ b/planetwars-server/src/modules/registry.rs @@ -6,6 +6,7 @@ use axum::headers::Authorization; use axum::response::{IntoResponse, Response}; use axum::routing::{get, head, post, put}; use axum::{async_trait, Router}; +use futures::StreamExt; use hyper::StatusCode; use serde::Serialize; use sha2::{Digest, Sha256}; @@ -14,7 +15,7 @@ use tokio::io::AsyncWriteExt; use tokio_util::io::ReaderStream; use crate::util::gen_alphanumeric; -use crate::DatabaseConnection; +use crate::{db, DatabaseConnection}; use crate::db::users::{authenticate_user, Credentials, User}; @@ -133,22 +134,28 @@ pub struct RegistryError { } async fn check_blob_exists( - _auth: RegistryAuth, - Path((_repository_name, raw_digest)): Path<(String, String)>, -) -> impl IntoResponse { + db_conn: DatabaseConnection, + auth: RegistryAuth, + Path((repository_name, raw_digest)): Path<(String, String)>, +) -> Result<impl IntoResponse, StatusCode> { + check_access(&repository_name, &auth, &db_conn)?; + let digest = raw_digest.strip_prefix("sha256:").unwrap(); let blob_path = PathBuf::from(REGISTRY_PATH).join("sha256").join(&digest); if blob_path.exists() { - StatusCode::OK + Ok(StatusCode::OK) } else { - StatusCode::NOT_FOUND + Err(StatusCode::NOT_FOUND) } } async fn get_blob( - _auth: RegistryAuth, - Path((_repository_name, raw_digest)): Path<(String, String)>, -) -> impl IntoResponse { + db_conn: DatabaseConnection, + auth: RegistryAuth, + Path((repository_name, raw_digest)): Path<(String, String)>, +) -> Result<impl IntoResponse, StatusCode> { + check_access(&repository_name, &auth, &db_conn)?; + let digest = raw_digest.strip_prefix("sha256:").unwrap(); let blob_path = PathBuf::from(REGISTRY_PATH).join("sha256").join(&digest); if !blob_path.exists() { @@ -161,15 +168,18 @@ async fn get_blob( } async fn create_upload( - _auth: RegistryAuth, + db_conn: DatabaseConnection, + auth: RegistryAuth, Path(repository_name): Path<String>, -) -> impl IntoResponse { +) -> Result<impl IntoResponse, StatusCode> { + check_access(&repository_name, &auth, &db_conn)?; + let uuid = gen_alphanumeric(16); tokio::fs::File::create(PathBuf::from(REGISTRY_PATH).join("uploads").join(&uuid)) .await .unwrap(); - Response::builder() + Ok(Response::builder() .status(StatusCode::ACCEPTED) .header( "Location", @@ -178,16 +188,17 @@ async fn create_upload( .header("Docker-Upload-UUID", uuid) .header("Range", "bytes=0-0") .body(Body::empty()) - .unwrap() + .unwrap()) } -use futures::StreamExt; - async fn patch_upload( - _auth: RegistryAuth, + db_conn: DatabaseConnection, + auth: RegistryAuth, Path((repository_name, uuid)): Path<(String, String)>, mut stream: BodyStream, -) -> impl IntoResponse { +) -> Result<impl IntoResponse, StatusCode> { + check_access(&repository_name, &auth, &db_conn)?; + // let content_length = headers.get("Content-Length").unwrap(); // let content_range = headers.get("Content-Range").unwrap(); // let content_type = headers.get("Content-Type").unwrap(); @@ -207,7 +218,7 @@ async fn patch_upload( len += n_bytes; } - Response::builder() + Ok(Response::builder() .status(StatusCode::ACCEPTED) .header( "Location", @@ -216,7 +227,7 @@ async fn patch_upload( .header("Docker-Upload-UUID", uuid) .header("Range", format!("0-{}", len)) .body(Body::empty()) - .unwrap() + .unwrap()) } use serde::Deserialize; @@ -226,11 +237,14 @@ struct UploadParams { } async fn put_upload( - _auth: RegistryAuth, + db_conn: DatabaseConnection, + auth: RegistryAuth, Path((repository_name, uuid)): Path<(String, String)>, Query(params): Query<UploadParams>, mut stream: BodyStream, -) -> impl IntoResponse { +) -> Result<impl IntoResponse, StatusCode> { + check_access(&repository_name, &auth, &db_conn)?; + let mut _len = 0; let upload_path = PathBuf::from(REGISTRY_PATH).join("uploads").join(&uuid); let mut file = tokio::fs::OpenOptions::new() @@ -251,7 +265,7 @@ async fn put_upload( let target_path = PathBuf::from(REGISTRY_PATH).join("sha256").join(&digest); tokio::fs::rename(&upload_path, &target_path).await.unwrap(); - Response::builder() + Ok(Response::builder() .status(StatusCode::CREATED) .header( "Location", @@ -261,13 +275,16 @@ async fn put_upload( // .header("Range", format!("0-{}", len)) .header("Docker-Content-Digest", digest) .body(Body::empty()) - .unwrap() + .unwrap()) } async fn get_manifest( - _auth: RegistryAuth, + db_conn: DatabaseConnection, + auth: RegistryAuth, Path((repository_name, reference)): Path<(String, String)>, -) -> impl IntoResponse { +) -> Result<impl IntoResponse, StatusCode> { + check_access(&repository_name, &auth, &db_conn)?; + let manifest_path = PathBuf::from(REGISTRY_PATH) .join("manifests") .join(&repository_name) @@ -278,18 +295,21 @@ async fn get_manifest( let manifest: serde_json::Map<String, serde_json::Value> = serde_json::from_slice(&data).unwrap(); let media_type = manifest.get("mediaType").unwrap().as_str().unwrap(); - Response::builder() + Ok(Response::builder() .status(StatusCode::OK) .header("Content-Type", media_type) .body(axum::body::Full::from(data)) - .unwrap() + .unwrap()) } async fn put_manifest( - _auth: RegistryAuth, + db_conn: DatabaseConnection, + auth: RegistryAuth, Path((repository_name, reference)): Path<(String, String)>, mut stream: BodyStream, -) -> impl IntoResponse { +) -> Result<impl IntoResponse, StatusCode> { + check_access(&repository_name, &auth, &db_conn)?; + let repository_dir = PathBuf::from(REGISTRY_PATH) .join("manifests") .join(&repository_name); @@ -317,7 +337,7 @@ async fn put_manifest( let digest_path = repository_dir.join(&content_digest).with_extension("json"); tokio::fs::copy(manifest_path, digest_path).await.unwrap(); - Response::builder() + Ok(Response::builder() .status(StatusCode::CREATED) .header( "Location", @@ -325,5 +345,29 @@ async fn put_manifest( ) .header("Docker-Content-Digest", content_digest) .body(Body::empty()) - .unwrap() + .unwrap()) +} + +fn check_access( + repository_name: &str, + auth: &RegistryAuth, + db_conn: &DatabaseConnection, +) -> Result<(), StatusCode> { + use diesel::OptionalExtension; + + let res = db::bots::find_bot_by_name(repository_name, db_conn) + .optional() + .expect("could not run query"); + + match res { + None => Ok(()), // name has not been claimed yet (TODO: verify its validity) + Some(existing_bot) => { + let RegistryAuth::User(user) = auth; + if existing_bot.owner_id == Some(user.id) { + Ok(()) + } else { + Err(StatusCode::FORBIDDEN) + } + } + } } |