diff options
| author | Cara Salter <cara@devcara.com> | 2022-07-20 07:59:44 -0400 | 
|---|---|---|
| committer | Cara Salter <cara@devcara.com> | 2022-07-20 07:59:44 -0400 | 
| commit | f42bf29ba97fc808433ae4217fd6b00469a12fae (patch) | |
| tree | 0e04421da46d41f20d19327a0e1280509957a1e7 | |
| parent | c742b752140ab0eee6e353c779bd897042ba6739 (diff) | |
| download | nccd-f42bf29ba97fc808433ae4217fd6b00469a12fae.tar.gz nccd-f42bf29ba97fc808433ae4217fd6b00469a12fae.zip  | |
Cookies
| -rw-r--r-- | .gitignore | 1 | ||||
| -rw-r--r-- | src/handlers/auth.rs | 34 | ||||
| -rw-r--r-- | src/handlers/mod.rs | 2 | ||||
| -rw-r--r-- | src/main.rs | 51 | ||||
| -rw-r--r-- | templates/index.rs.html | 8 | 
5 files changed, 86 insertions, 10 deletions
@@ -1,2 +1,3 @@  /target  design/ +.keypair diff --git a/src/handlers/auth.rs b/src/handlers/auth.rs index c4672aa..7e2642c 100644 --- a/src/handlers/auth.rs +++ b/src/handlers/auth.rs @@ -3,8 +3,8 @@ use std::sync::Arc;  use axum::{response::{IntoResponse, Html, Redirect}, Form, Extension};  use axum_extra::extract::{PrivateCookieJar, cookie::Cookie};  use serde::Deserialize; -use sqlx::{query, query_as}; -use tracing::debug; +use sqlx::{query, query_as, pool::PoolConnection, Postgres}; +use tracing::{debug, instrument};  use crate::{errors::ServiceError, State, models::DbUser};  use chrono::prelude::*; @@ -39,15 +39,17 @@ pub async fn login_post(Form(login): Form<LoginForm>, state: Extension<Arc<State      if bcrypt::verify(login.password, &user.pw_hash)? {          debug!("Logged in ID {} (email {})", user.id, user.email); -        query("UPDATE users SET last_login=$1 WHERE id=$2").bind(Utc::now()).bind(user.id) +        query("UPDATE users SET last_login=$1 WHERE id=$2").bind(Utc::now()).bind(user.id.clone())              .execute(&mut conn)              .await?;          let updated_jar = jar.add(Cookie::new("user-id", user.id.clone())); -    } else { +        Ok((updated_jar, Redirect::to("/"))) +    } else { +        let updated_jar = jar; +        Ok((updated_jar, Redirect::to("/dash/auth/login")))      } -    Ok((updated_jar, Redirect::to("/")))  }  pub async fn register() -> impl IntoResponse { @@ -77,3 +79,25 @@ pub async fn register_post(Form(reg): Form<RegisterForm>, state: Extension<Arc<S      Ok(Redirect::to("/dash/auth/login"))  } + +#[instrument] +pub async fn get_user_or_403(jar: PrivateCookieJar, conn: &mut PoolConnection<Postgres>) -> Result<DbUser, ServiceError> { +    debug!("Starting middleware get_user_or_403"); +    debug!("Displaying all cookies"); +    for c in jar.iter() { +        debug!("{}={}", c.name(), c.value()); +    } +    if let Some(id) = jar.get("user-id") { +        debug!("Found user {}", id); + +        let user: DbUser = query_as("SELECT * FROM users WHERE id=$1").bind(id.value()) +            .fetch_one(conn) +            .await?; + +        Ok(user) + +    } else { +        debug!("No user found"); +        Err(ServiceError::NotAuthorized) +    } +} diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs index 4076e68..b83d83c 100644 --- a/src/handlers/mod.rs +++ b/src/handlers/mod.rs @@ -1,6 +1,6 @@  use axum::{Router, routing::get}; -mod auth; +pub mod auth;  pub async fn gen_routers() -> Router { diff --git a/src/main.rs b/src/main.rs index ed943c7..acc9f9f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,6 +9,8 @@ use axum::body;  use axum::extract::Path;  use axum::{error_handling::HandleErrorLayer, routing::get, BoxError, Extension, Router};  use axum::response::{Html, IntoResponse, Response}; +use axum_extra::extract::PrivateCookieJar; +use axum_extra::extract::cookie::Key;  use errors::{StringResult, HtmlResult};  use hyper::StatusCode;  use sqlx::{PgPool, postgres::PgPoolOptions}; @@ -17,6 +19,8 @@ use tower_http::trace::TraceLayer;  use tracing::{error, info, debug};  use crate::errors::ServiceError;  use tracing_subscriber::prelude::*; +use crate::models::DbUser; +use crate::handlers::auth::get_user_or_403;  pub struct State {      pub config: config::Config, @@ -49,6 +53,8 @@ async fn main() {      let shared_state = Arc::new(State { config, conn }); +    let key = load_or_gen_keypair().unwrap(); +      let app = Router::new()          .route("/health", get(health_check))          .route("/", get(index)) @@ -70,7 +76,8 @@ async fn main() {                  .layer(TraceLayer::new_for_http())                  .into_inner(),          ) -        .layer(Extension(shared_state)); +        .layer(Extension(shared_state)) +        .layer(Extension(key));      let addr = match SocketAddr::from_str(&bind_addr) {          Ok(a) => a, @@ -93,9 +100,14 @@ async fn health_check() -> &'static str {  }  #[axum_macros::debug_handler] -async fn index() -> HtmlResult { +async fn index(state: Extension<Arc<State>>, jar: PrivateCookieJar) -> HtmlResult { +    let mut conn = state.conn.acquire().await?; +    let user: Option<DbUser> = match get_user_or_403(jar, &mut conn).await { +        Ok(u) => Some(u), +        Err(_) => None, +    };      let mut buf = Vec::new(); -    crate::templates::index_html(&mut buf).unwrap(); +    crate::templates::index_html(&mut buf, user).unwrap();      match String::from_utf8(buf) {          Ok(s) => Ok(Html(s)), @@ -124,4 +136,37 @@ async fn statics(Path(name): Path<String>) -> Result<Response, ServiceError> {      }  } +use std::fs::{self, File}; +fn load_or_gen_keypair() -> Result<Key, ServiceError> { +    let kp: Key; +        let mut file = match File::open(".keypair") { +            Ok(f) => f, +            Err(_) => { +                debug!("File does not exist, creating at .keypair"); +                File::create(".keypair").unwrap() +            } +        }; +    if let Ok(c) = fs::read(".keypair") { +        if c.len() == 0 { +            debug!("No keypair found. Generating..."); +            let key = Key::generate(); +            fs::write(".keypair", key.master().as_ref()).unwrap();  +            debug!("Written keypair {:?} to .keypair", key.master().as_ref()); +            kp = key; +        } else { +            debug!("Found keypair file, contents: {:?}", c); +            kp = Key::from(&c); +            debug!("Loaded keypair from file"); +        } +    } else { +        debug!("Generating new keypair"); +            let key = Key::generate(); +            fs::write(".keypair", key.master().as_ref()).unwrap();  +            debug!("Written keypair {:?} to .keypair", key.master().as_ref()); +            kp = key; +   } +    Ok(kp) +} + +  include!(concat!(env!("OUT_DIR"), "/templates.rs")); diff --git a/templates/index.rs.html b/templates/index.rs.html index 4fb701c..d0196df 100644 --- a/templates/index.rs.html +++ b/templates/index.rs.html @@ -1,8 +1,14 @@  @use super::{header_html, footer_html}; +@use crate::models::DbUser; -@() +@(user: Option<DbUser>)  @:header_html()              <h1>NCCd (Network Communications Control Daemon)</h1> +            @if user.is_some() { +                <h3>Welcome @user.unwrap().pref_name</h3> +            } else { +              <h3>Please <a href="/dash/auth/login">Log in</a></h3> +            }  @:footer_html()   | 
