From ecca0e8cef755899049e7f541c86b9d06b2438a4 Mon Sep 17 00:00:00 2001 From: Cara Salter Date: Mon, 30 May 2022 00:24:40 -0400 Subject: meta: Add Bearer authentication Makes use of a pre-shared key. Do not expose this to the internet! --- Cargo.lock | 12 ++++++++++++ Cargo.toml | 3 +++ src/errors.rs | 4 ++++ src/handlers/ships.rs | 24 +++++++++++++++++++++--- src/main.rs | 2 +- 5 files changed, 41 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3cab99f..d43b979 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -94,6 +94,17 @@ dependencies = [ "tower-service", ] +[[package]] +name = "axum-auth" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a78cc399f2af2dd7adf88e0fcc0e21dbf730258c1b34785f47816ff224238f74" +dependencies = [ + "axum", + "base64", + "http", +] + [[package]] name = "axum-core" version = "0.2.4" @@ -499,6 +510,7 @@ name = "homeworld" version = "0.1.0" dependencies = [ "axum", + "axum-auth", "color-eyre", "eyre", "hyper", diff --git a/Cargo.toml b/Cargo.toml index 8290111..afb3d62 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,9 @@ thiserror = "1" kankyo = "0.3" +# Middleware +axum-auth = "0.2" + [dependencies.solarlib] git = "https://git.carathe.dev/solard/solarlib" branch = "master" diff --git a/src/errors.rs b/src/errors.rs index c1d6672..d4b365f 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -15,6 +15,9 @@ pub enum ServiceError { #[error("Not Found")] NotFound, + + #[error("Not Authorized")] + NotAuthorized, } pub type StringResult = Result; @@ -27,6 +30,7 @@ impl IntoResponse for ServiceError { let status = match self { ServiceError::NotFound => StatusCode::NOT_FOUND, + ServiceError::NotAuthorized => StatusCode::UNAUTHORIZED, _ => StatusCode::INTERNAL_SERVER_ERROR, }; diff --git a/src/handlers/ships.rs b/src/handlers/ships.rs index 75f5dc5..42367ea 100644 --- a/src/handlers/ships.rs +++ b/src/handlers/ships.rs @@ -4,13 +4,13 @@ use axum::{Json, extract::Path, Extension}; use hyper::StatusCode; use solarlib::ship::{Ship, DbShip, Sha256}; use sqlx::{query_as, query, Error as SqlxError}; +use axum_auth::AuthBearer; +use tracing::log::warn; use crate::{errors::{JsonResult, StringResult, ServiceError}, State}; pub async fn list(state: Extension>) -> JsonResult>> { - let mut result: Vec = Vec::new(); - let mut conn = state.conn.acquire().await?; let db_ships = query_as!(DbShip, "SELECT * FROM ships").fetch_all(&mut conn).await?; @@ -20,7 +20,8 @@ pub async fn list(state: Extension>) -> JsonResult>> { Ok(Json(ships)) } -pub async fn new(Json(new_ship): Json, state: Extension>) -> StringResult { +pub async fn new(Json(new_ship): Json, state: Extension>, AuthBearer(token): AuthBearer) -> StringResult { + check_bearer(token)?; let mut conn = state.conn.acquire().await?; query!("INSERT INTO ships (name, shasum, download_url, version) VALUES ($1, $2, $3, $4)", new_ship.name, new_ship.shasum.to_string(), new_ship.download_url, new_ship.version).execute(&mut conn).await?; @@ -29,6 +30,7 @@ pub async fn new(Json(new_ship): Json, state: Extension>) -> St } pub async fn delete(Path(shasum): Path, state: Extension>) -> StringResult { + check_bearer(token)?; let mut conn = state.conn.acquire().await?; query!("DELETE FROM ships WHERE shasum=$1", shasum.to_string()).execute(&mut conn).await?; @@ -55,3 +57,19 @@ pub async fn get(Path(shasum): Path, state: Extension>) -> Js Ok(Json(db_ship)) } + +fn check_bearer(token: String) -> Result<(), ServiceError> { + let expected_token = match std::env::var("SHARED_KEY") { + Ok(t) => t, + Err(_) => { + warn!("No pre-shared key set in environment. This is not secure!"); + "bad-key".into() + } + }; + + if token != expected_token { + Err(ServiceError::NotAuthorized) + } else { + Ok(()) + } +} diff --git a/src/main.rs b/src/main.rs index 355cc88..590d714 100644 --- a/src/main.rs +++ b/src/main.rs @@ -31,7 +31,7 @@ async fn main() { tracing_subscriber::registry() .with(tracing_subscriber::EnvFilter::new( std::env::var("RUST_LOG") - .unwrap_or_else(|_| "waifud=info,tower_http=debug".into()), + .unwrap_or_else(|_| "homeworld=info,tower_http=debug".into()), )) .with(tracing_subscriber::fmt::layer()) .init(); -- cgit v1.2.3