diff options
Diffstat (limited to 'src/handlers/auth.rs')
-rw-r--r-- | src/handlers/auth.rs | 54 |
1 files changed, 40 insertions, 14 deletions
diff --git a/src/handlers/auth.rs b/src/handlers/auth.rs index 8d65b05..cafaeb8 100644 --- a/src/handlers/auth.rs +++ b/src/handlers/auth.rs @@ -1,15 +1,22 @@ -use std::{collections::HashMap, fs::{self, File}}; +use std::{ + collections::HashMap, + fs::{self, File}, + sync::Arc, +}; -use axum::{extract::Query, Extension}; +use axum::{extract::Query, middleware::Next, response::Response, Extension}; use axum_macros::debug_handler; -use chrono::{Utc, TimeZone, Datelike}; +use chrono::{Datelike, TimeZone, Utc}; +use hyper::Request; use ring::{rand::SystemRandom, signature::Ed25519KeyPair}; use uuid::Uuid; use std::io::Write; -use crate::{errors::{NoneResult, ServiceError, StringResult, TokenResult}, State}; - +use crate::{ + errors::{NoneResult, ServiceError, StringResult, TokenResult}, + State, +}; /** * Takes in a request to create a new token with a secret key that gets printed @@ -17,7 +24,10 @@ use crate::{errors::{NoneResult, ServiceError, StringResult, TokenResult}, State * for future authentication */ #[debug_handler] -pub async fn begin(Query(params): Query<HashMap<String, String>>, Extension(state): Extension<State>) -> TokenResult { +pub async fn begin( + Query(params): Query<HashMap<String, String>>, + Extension(state): Extension<Arc<State>>, +) -> TokenResult { if let Some(k) = params.get("key") { if k == &state.gen_key { let dt = Utc::now(); @@ -33,15 +43,17 @@ pub async fn begin(Query(params): Query<HashMap<String, String>>, Extension(stat .set_issuer("solard") .set_audience("solard") .set_not_before(&Utc::now()) - .build() { - Ok(token) => token, - Err(_) => { - return Err(ServiceError::Generic(String::from("could not generate paseto key"))); - } - }; + .build() + { + Ok(token) => token, + Err(_) => { + return Err(ServiceError::Generic(String::from( + "could not generate paseto key", + ))); + } + }; return Ok(token.to_string()); - } else { return Err(ServiceError::NotAuthorized); } @@ -50,6 +62,17 @@ pub async fn begin(Query(params): Query<HashMap<String, String>>, Extension(stat } } +pub async fn requires_auth<B>(req: Request<B>, next: Next<B>) -> Result<Response, ServiceError> { + let auth_header = req + .headers() + .get(axum::http::header::AUTHORIZATION) + .and_then(|h| h.to_str().ok()); + + match auth_header { + Some(h) => Ok(next.run(req).await), + None => Err(ServiceError::NotAuthorized), + } +} fn load_or_gen_keypair() -> Result<Ed25519KeyPair, ServiceError> { let kp: Ed25519KeyPair; @@ -59,7 +82,10 @@ fn load_or_gen_keypair() -> Result<Ed25519KeyPair, ServiceError> { let srand = SystemRandom::new(); let pkcs8 = Ed25519KeyPair::generate_pkcs8(&srand)?; - let mut file = File::open(".keypair").unwrap(); + let mut file = match File::open(".keypair") { + Ok(f) => f, + Err(_) => File::create(".keypair").unwrap(), + }; file.write(pkcs8.as_ref()); kp = Ed25519KeyPair::from_pkcs8(pkcs8.as_ref())?; |