From f42bf29ba97fc808433ae4217fd6b00469a12fae Mon Sep 17 00:00:00 2001 From: Cara Salter Date: Wed, 20 Jul 2022 07:59:44 -0400 Subject: Cookies --- .gitignore | 1 + src/handlers/auth.rs | 34 ++++++++++++++++++++++++++++----- src/handlers/mod.rs | 2 +- src/main.rs | 51 ++++++++++++++++++++++++++++++++++++++++++++++--- templates/index.rs.html | 8 +++++++- 5 files changed, 86 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 0133534..c159c44 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target design/ +.keypair diff --git a/src/handlers/auth.rs b/src/handlers/auth.rs index c4672aa..7e2642c 100644 --- a/src/handlers/auth.rs +++ b/src/handlers/auth.rs @@ -3,8 +3,8 @@ use std::sync::Arc; use axum::{response::{IntoResponse, Html, Redirect}, Form, Extension}; use axum_extra::extract::{PrivateCookieJar, cookie::Cookie}; use serde::Deserialize; -use sqlx::{query, query_as}; -use tracing::debug; +use sqlx::{query, query_as, pool::PoolConnection, Postgres}; +use tracing::{debug, instrument}; use crate::{errors::ServiceError, State, models::DbUser}; use chrono::prelude::*; @@ -39,15 +39,17 @@ pub async fn login_post(Form(login): Form, state: Extension impl IntoResponse { @@ -77,3 +79,25 @@ pub async fn register_post(Form(reg): Form, state: Extension) -> Result { + debug!("Starting middleware get_user_or_403"); + debug!("Displaying all cookies"); + for c in jar.iter() { + debug!("{}={}", c.name(), c.value()); + } + if let Some(id) = jar.get("user-id") { + debug!("Found user {}", id); + + let user: DbUser = query_as("SELECT * FROM users WHERE id=$1").bind(id.value()) + .fetch_one(conn) + .await?; + + Ok(user) + + } else { + debug!("No user found"); + Err(ServiceError::NotAuthorized) + } +} diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs index 4076e68..b83d83c 100644 --- a/src/handlers/mod.rs +++ b/src/handlers/mod.rs @@ -1,6 +1,6 @@ use axum::{Router, routing::get}; -mod auth; +pub mod auth; pub async fn gen_routers() -> Router { diff --git a/src/main.rs b/src/main.rs index ed943c7..acc9f9f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,6 +9,8 @@ use axum::body; use axum::extract::Path; use axum::{error_handling::HandleErrorLayer, routing::get, BoxError, Extension, Router}; use axum::response::{Html, IntoResponse, Response}; +use axum_extra::extract::PrivateCookieJar; +use axum_extra::extract::cookie::Key; use errors::{StringResult, HtmlResult}; use hyper::StatusCode; use sqlx::{PgPool, postgres::PgPoolOptions}; @@ -17,6 +19,8 @@ use tower_http::trace::TraceLayer; use tracing::{error, info, debug}; use crate::errors::ServiceError; use tracing_subscriber::prelude::*; +use crate::models::DbUser; +use crate::handlers::auth::get_user_or_403; pub struct State { pub config: config::Config, @@ -49,6 +53,8 @@ async fn main() { let shared_state = Arc::new(State { config, conn }); + let key = load_or_gen_keypair().unwrap(); + let app = Router::new() .route("/health", get(health_check)) .route("/", get(index)) @@ -70,7 +76,8 @@ async fn main() { .layer(TraceLayer::new_for_http()) .into_inner(), ) - .layer(Extension(shared_state)); + .layer(Extension(shared_state)) + .layer(Extension(key)); let addr = match SocketAddr::from_str(&bind_addr) { Ok(a) => a, @@ -93,9 +100,14 @@ async fn health_check() -> &'static str { } #[axum_macros::debug_handler] -async fn index() -> HtmlResult { +async fn index(state: Extension>, jar: PrivateCookieJar) -> HtmlResult { + let mut conn = state.conn.acquire().await?; + let user: Option = match get_user_or_403(jar, &mut conn).await { + Ok(u) => Some(u), + Err(_) => None, + }; let mut buf = Vec::new(); - crate::templates::index_html(&mut buf).unwrap(); + crate::templates::index_html(&mut buf, user).unwrap(); match String::from_utf8(buf) { Ok(s) => Ok(Html(s)), @@ -124,4 +136,37 @@ async fn statics(Path(name): Path) -> Result { } } +use std::fs::{self, File}; +fn load_or_gen_keypair() -> Result { + let kp: Key; + let mut file = match File::open(".keypair") { + Ok(f) => f, + Err(_) => { + debug!("File does not exist, creating at .keypair"); + File::create(".keypair").unwrap() + } + }; + if let Ok(c) = fs::read(".keypair") { + if c.len() == 0 { + debug!("No keypair found. Generating..."); + let key = Key::generate(); + fs::write(".keypair", key.master().as_ref()).unwrap(); + debug!("Written keypair {:?} to .keypair", key.master().as_ref()); + kp = key; + } else { + debug!("Found keypair file, contents: {:?}", c); + kp = Key::from(&c); + debug!("Loaded keypair from file"); + } + } else { + debug!("Generating new keypair"); + let key = Key::generate(); + fs::write(".keypair", key.master().as_ref()).unwrap(); + debug!("Written keypair {:?} to .keypair", key.master().as_ref()); + kp = key; + } + Ok(kp) +} + + include!(concat!(env!("OUT_DIR"), "/templates.rs")); diff --git a/templates/index.rs.html b/templates/index.rs.html index 4fb701c..d0196df 100644 --- a/templates/index.rs.html +++ b/templates/index.rs.html @@ -1,8 +1,14 @@ @use super::{header_html, footer_html}; +@use crate::models::DbUser; -@() +@(user: Option) @:header_html()

NCCd (Network Communications Control Daemon)

+ @if user.is_some() { +

Welcome @user.unwrap().pref_name

+ } else { +

Please Log in

+ } @:footer_html() -- cgit v1.2.3