diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/errors.rs | 10 | ||||
-rw-r--r-- | src/handlers/auth.rs | 54 | ||||
-rw-r--r-- | src/handlers/mod.rs | 2 | ||||
-rw-r--r-- | src/handlers/planets.rs | 47 | ||||
-rw-r--r-- | src/main.rs | 37 |
5 files changed, 99 insertions, 51 deletions
diff --git a/src/errors.rs b/src/errors.rs index f6e00e2..e32c6d5 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -2,9 +2,9 @@ use hex::FromHexError; use ring::error::KeyRejected; use thiserror::Error; -use axum::response::{Response, IntoResponse}; -use axum::http::StatusCode; use axum::body; +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; use axum::Json; use ring::error::Unspecified as RingUnspecified; @@ -56,7 +56,6 @@ pub type NoneResult = std::result::Result<(), ServiceError>; impl IntoResponse for ServiceError { fn into_response(self) -> Response { - let body = body::boxed(body::Full::from(self.to_string())); let status = match self { @@ -64,9 +63,6 @@ impl IntoResponse for ServiceError { ServiceError::NotAuthorized => StatusCode::UNAUTHORIZED, _ => StatusCode::INTERNAL_SERVER_ERROR, }; - Response::builder() - .status(status) - .body(body) - .unwrap() + Response::builder().status(status).body(body).unwrap() } } diff --git a/src/handlers/auth.rs b/src/handlers/auth.rs index 8d65b05..cafaeb8 100644 --- a/src/handlers/auth.rs +++ b/src/handlers/auth.rs @@ -1,15 +1,22 @@ -use std::{collections::HashMap, fs::{self, File}}; +use std::{ + collections::HashMap, + fs::{self, File}, + sync::Arc, +}; -use axum::{extract::Query, Extension}; +use axum::{extract::Query, middleware::Next, response::Response, Extension}; use axum_macros::debug_handler; -use chrono::{Utc, TimeZone, Datelike}; +use chrono::{Datelike, TimeZone, Utc}; +use hyper::Request; use ring::{rand::SystemRandom, signature::Ed25519KeyPair}; use uuid::Uuid; use std::io::Write; -use crate::{errors::{NoneResult, ServiceError, StringResult, TokenResult}, State}; - +use crate::{ + errors::{NoneResult, ServiceError, StringResult, TokenResult}, + State, +}; /** * Takes in a request to create a new token with a secret key that gets printed @@ -17,7 +24,10 @@ use crate::{errors::{NoneResult, ServiceError, StringResult, TokenResult}, State * for future authentication */ #[debug_handler] -pub async fn begin(Query(params): Query<HashMap<String, String>>, Extension(state): Extension<State>) -> TokenResult { +pub async fn begin( + Query(params): Query<HashMap<String, String>>, + Extension(state): Extension<Arc<State>>, +) -> TokenResult { if let Some(k) = params.get("key") { if k == &state.gen_key { let dt = Utc::now(); @@ -33,15 +43,17 @@ pub async fn begin(Query(params): Query<HashMap<String, String>>, Extension(stat .set_issuer("solard") .set_audience("solard") .set_not_before(&Utc::now()) - .build() { - Ok(token) => token, - Err(_) => { - return Err(ServiceError::Generic(String::from("could not generate paseto key"))); - } - }; + .build() + { + Ok(token) => token, + Err(_) => { + return Err(ServiceError::Generic(String::from( + "could not generate paseto key", + ))); + } + }; return Ok(token.to_string()); - } else { return Err(ServiceError::NotAuthorized); } @@ -50,6 +62,17 @@ pub async fn begin(Query(params): Query<HashMap<String, String>>, Extension(stat } } +pub async fn requires_auth<B>(req: Request<B>, next: Next<B>) -> Result<Response, ServiceError> { + let auth_header = req + .headers() + .get(axum::http::header::AUTHORIZATION) + .and_then(|h| h.to_str().ok()); + + match auth_header { + Some(h) => Ok(next.run(req).await), + None => Err(ServiceError::NotAuthorized), + } +} fn load_or_gen_keypair() -> Result<Ed25519KeyPair, ServiceError> { let kp: Ed25519KeyPair; @@ -59,7 +82,10 @@ fn load_or_gen_keypair() -> Result<Ed25519KeyPair, ServiceError> { let srand = SystemRandom::new(); let pkcs8 = Ed25519KeyPair::generate_pkcs8(&srand)?; - let mut file = File::open(".keypair").unwrap(); + let mut file = match File::open(".keypair") { + Ok(f) => f, + Err(_) => File::create(".keypair").unwrap(), + }; file.write(pkcs8.as_ref()); kp = Ed25519KeyPair::from_pkcs8(pkcs8.as_ref())?; diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs index 8f8224e..ddd4006 100644 --- a/src/handlers/mod.rs +++ b/src/handlers/mod.rs @@ -1,2 +1,2 @@ -pub mod planets; pub mod auth; +pub mod planets; diff --git a/src/handlers/planets.rs b/src/handlers/planets.rs index 2c66324..8593026 100644 --- a/src/handlers/planets.rs +++ b/src/handlers/planets.rs @@ -1,14 +1,14 @@ use axum::Extension; +use axum::{extract::Path, response::IntoResponse, Json}; use axum_macros::debug_handler; -use axum::{response::IntoResponse, Json, extract::Path}; use solarlib::ship::{DbShip, Ship}; use tokio::process::Command; use tracing::{error, instrument}; -use solarlib::star::{Star, NewPlanet}; -use solarlib::planet::Planet; use solarlib::errors::Error as SolarlibError; +use solarlib::planet::Planet; +use solarlib::star::{NewPlanet, Star}; use crate::{errors::*, get_star, State}; use std::sync::Arc; @@ -16,14 +16,13 @@ use std::sync::Arc; pub async fn list() -> JsonResult<Json<Vec<Planet>>> { let con_url = std::env::var("QEMU_URL").unwrap_or("qemu:///system".to_string()); let mut star = Star::new(con_url)?; - + let inhabitants = star.inhabitants()?; Ok(Json(inhabitants)) } pub async fn get(Path(uuid): Path<String>) -> JsonResult<Json<Planet>> { - let con_url = std::env::var("QEMU_URL").unwrap_or("qemu:///system".to_string()); let mut star = Star::new(con_url)?; @@ -79,7 +78,7 @@ pub async fn reboot(Path(uuid): Path<String>) -> NoneResult { } else { return Err(ServiceError::NotFound); } - + Ok(()) } @@ -113,7 +112,10 @@ pub async fn new_planet(Json(new_planet): Json<NewPlanet>) -> JsonResult<Json<Pl let ship_shasum = new_planet.clone().ship; - let res: DbShip = reqwest::get(format!("http://{}/ships/get/{}", hw_url, ship_shasum)).await?.json().await?; + let res: DbShip = reqwest::get(format!("http://{}/ships/get/{}", hw_url, ship_shasum)) + .await? + .json() + .await?; let ship: Ship = res.into(); let mut s = get_star()?; @@ -122,17 +124,27 @@ pub async fn new_planet(Json(new_planet): Json<NewPlanet>) -> JsonResult<Json<Pl // Try to create right away, if the Ship already exists on the system, it'll go through. If // not, we can download it by using the shasum - let r = s.planet(new_planet.clone().name, new_planet.max_mem, new_planet.max_cpus, new_planet.disk_size_mb, ship.clone()); + let r = s.planet( + new_planet.clone().name, + new_planet.max_mem, + new_planet.max_cpus, + new_planet.disk_size_mb, + ship.clone(), + ); match r { - Err(e) => { - match e { - SolarlibError::MissingImage(_) => { - ship.download(&s)?; - return Ok(Json(s.planet(new_planet.name, new_planet.max_mem, new_planet.max_cpus, new_planet.disk_size_mb, ship.clone())?)); - }, - _ => { - return Err(ServiceError::Solarlib(e)); - } + Err(e) => match e { + SolarlibError::MissingImage(_) => { + ship.download(&s)?; + return Ok(Json(s.planet( + new_planet.name, + new_planet.max_mem, + new_planet.max_cpus, + new_planet.disk_size_mb, + ship.clone(), + )?)); + } + _ => { + return Err(ServiceError::Solarlib(e)); } }, Ok(r) => { @@ -142,7 +154,6 @@ pub async fn new_planet(Json(new_planet): Json<NewPlanet>) -> JsonResult<Json<Pl }; } - pub async fn no_planet(Path(uuid): Path<String>) -> NoneResult { let mut s = get_star()?; diff --git a/src/main.rs b/src/main.rs index 3f889cd..97050e8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,16 +1,17 @@ use axum::{ error_handling::HandleErrorLayer, + handler::Handler, http::StatusCode, + middleware, response::IntoResponse, routing::{get, post}, - handler::Handler, - Json, Router, Extension + Extension, Json, Router, }; -use rand::{thread_rng, Rng, distributions::Alphanumeric}; +use rand::{distributions::Alphanumeric, thread_rng, Rng}; use serde::{Deserialize, Serialize}; use solarlib::star::Star; -use std::{net::SocketAddr, time::Duration, str::FromStr, sync::Arc}; +use std::{net::SocketAddr, str::FromStr, sync::Arc, time::Duration}; use tower::{BoxError, ServiceBuilder}; use tower_http::trace::TraceLayer; @@ -32,10 +33,9 @@ pub struct State { async fn main() { kankyo::init(); color_eyre::install().unwrap(); - tracing_subscriber::registry() + tracing_subscriber::registry() .with(tracing_subscriber::EnvFilter::new( - std::env::var("RUST_LOG") - .unwrap_or_else(|_| "solard=info,tower_http=debug".into()), + std::env::var("RUST_LOG").unwrap_or_else(|_| "solard=info,tower_http=debug".into()), )) .with(tracing_subscriber::fmt::layer()) .init(); @@ -63,17 +63,32 @@ async fn main() { .route("/health", get(health_check)) .route("/planets/list", get(handlers::planets::list)) .route("/planets/new", post(handlers::planets::new_planet)) + .route_layer(middleware::from_fn(handlers::auth::requires_auth)) .route("/planets/:uuid", get(handlers::planets::get)) .route("/planets/:uuid/shutdown", post(handlers::planets::shutdown)) - .route("/planets/:uuid/shutdown/hard", post(handlers::planets::force_shutdown)) + .route_layer(middleware::from_fn(handlers::auth::requires_auth)) + .route( + "/planets/:uuid/shutdown/hard", + post(handlers::planets::force_shutdown), + ) + .route_layer(middleware::from_fn(handlers::auth::requires_auth)) .route("/planets/:uuid/start", post(handlers::planets::start)) + .route_layer(middleware::from_fn(handlers::auth::requires_auth)) .route("/planets/:uuid/pause", post(handlers::planets::pause)) + .route_layer(middleware::from_fn(handlers::auth::requires_auth)) .route("/planets/:uuid/reboot", post(handlers::planets::reboot)) - .route("/planets/:uuid/reboot/hard", post(handlers::planets::force_reboot)) + .route_layer(middleware::from_fn(handlers::auth::requires_auth)) + .route( + "/planets/:uuid/reboot/hard", + post(handlers::planets::force_reboot), + ) + .route_layer(middleware::from_fn(handlers::auth::requires_auth)) .route("/planets/:uuid/destroy", post(handlers::planets::no_planet)) + .route_layer(middleware::from_fn(handlers::auth::requires_auth)) // Authentication .route("/auth/begin", post(handlers::auth::begin)) - .layer( ServiceBuilder::new() + .layer( + ServiceBuilder::new() .layer(HandleErrorLayer::new(|error: BoxError| async move { if error.is::<tower::timeout::error::Elapsed>() { Ok(StatusCode::REQUEST_TIMEOUT) @@ -106,7 +121,7 @@ async fn health_check() -> &'static str { fn get_star() -> Result<Star, errors::ServiceError> { let con_url = std::env::var("QEMU_URL").unwrap_or("qemu:///system".to_string()); - + Ok(Star::new(con_url)?) } |