aboutsummaryrefslogblamecommitdiff
path: root/src/main.rs
blob: 6ebe145c57da2dfeabb4630d8ba9417b263d2525 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12

           

             
          






                                                                                        
                                                             
                                     

                                       

                            


                                            
                                         

                                   

                                           































                                                                                 

                                             


                                            
                                                     
















                                                                           

                                       





















                                                       





                                                                                   



                                                                                            
                             
                                                                       






                                              





                                                                                                                  

                                                                              
                                                                   

















                                                                
































                                                                              
                                                    
mod config;
mod errors;
mod handlers;
mod models;
mod utils;

use std::{net::SocketAddr, str::FromStr, sync::Arc, time::Duration};

use axum::body;
use axum::extract::Path;
use axum::{error_handling::HandleErrorLayer, routing::get, BoxError, Extension, Router};
use axum::response::{Html, IntoResponse, Response};
use axum_extra::extract::{PrivateCookieJar, SignedCookieJar};
use axum_extra::extract::cookie::Key;
use errors::{StringResult, HtmlResult};
use hyper::StatusCode;
use models::{Peer, Network};
use sqlx::query_as;
use sqlx::{PgPool, postgres::PgPoolOptions};
use tower::ServiceBuilder;
use tower_http::trace::TraceLayer;
use tracing::{error, info, debug, trace};
use crate::errors::ServiceError;
use tracing_subscriber::prelude::*;
use crate::models::DbUser;
use crate::handlers::auth::get_user_or_403;

/// Application-wide state, handed to handlers as `Extension<Arc<State>>`.
pub struct State {
    /// Service configuration, loaded once in `main` before the server starts.
    pub config: config::Config,
    /// Postgres connection pool the handlers acquire connections from.
    pub conn: PgPool,
}

#[tokio::main]
async fn main() {
    color_eyre::install().unwrap();
    tracing_subscriber::registry()
        .with(tracing_subscriber::EnvFilter::new("debug"))
        .with(tracing_subscriber::fmt::layer())
        .init();
    let config = match config::Config::init("/etc/nccd/config.toml".to_owned()) {
        Ok(c) => c,
        Err(e) => {
            error!("Config Error: {:?}", e);
            std::process::exit(1);
        }
    };

    let bind_addr = config.server.bind_addr.clone();

    let db_config = config.database.clone();

    let conn = PgPoolOptions::new()
        .max_connections(db_config.max_connections)
        .connect(&db_config.postgres_url)
        .await.unwrap();

    let shared_state = Arc::new(State { config, conn });

    let key = load_or_gen_keypair().unwrap();

    let app = Router::new()
        .route("/health", get(health_check))
        .route("/", get(index))
        .nest("/dash", handlers::gen_routers().await)
        .route("/static/:name", get(statics))
        .layer(
            ServiceBuilder::new()
                .layer(HandleErrorLayer::new(|error: BoxError| async move {
                    if error.is::<tower::timeout::error::Elapsed>() {
                        Ok(StatusCode::REQUEST_TIMEOUT)
                    } else {
                        Err((
                            StatusCode::INTERNAL_SERVER_ERROR,
                            format!("Unhandled internal error: {}", error),
                        ))
                    }
                }))
                .timeout(Duration::from_secs(10))
                .layer(TraceLayer::new_for_http())
                .into_inner(),
        )
        .layer(Extension(shared_state))
        .layer(Extension(key));

    let addr = match SocketAddr::from_str(&bind_addr) {
        Ok(a) => a,
        Err(e) => {
            error!("Invalid bind addr: {:?}", e);
            std::process::exit(1);
        }
    };

    info!("Listening on {}", addr);

    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await
        .unwrap();
}

/// Liveness probe served at `/health`; always answers with the plain body `OK`.
async fn health_check() -> &'static str {
    const HEALTHY: &str = "OK";
    HEALTHY
}

#[axum_macros::debug_handler]
async fn index(state: Extension<Arc<State>>, jar: PrivateCookieJar) -> HtmlResult {
    let mut conn = state.conn.acquire().await?;
    let user: Option<DbUser> = match get_user_or_403(jar, &mut conn).await {
        Ok(u) => Some(u),
        Err(_) => None,
    };
    
    let peers: Vec<Peer> = query_as("SELECT * FROM peers").fetch_all(&mut conn).await?;
    let nets: Vec<Network> = query_as("SELECT * FROM networks").fetch_all(&mut conn).await?;

    let mut buf = Vec::new();
    crate::templates::index_html(&mut buf, user, peers, nets).unwrap();

    match String::from_utf8(buf) {
        Ok(s) => Ok(Html(s)),
        Err(_) => Err(ServiceError::NotFound),
    }
}

/// Sends a fixed test message through `utils::send_email` to check that
/// outgoing mail works end to end.
///
/// Nothing in this file calls it; it is kept as a manual smoke test.
///
/// # Errors
/// Propagates any error returned by `utils::send_email`, converted via `?`.
// Manual smoke test that is not wired into any route. Without the allow it
// trips `dead_code` and breaks `-D warnings` builds.
#[allow(dead_code)]
async fn test_email() -> Result<(), ServiceError> {
    // Hard-coded developer address; change it before using this elsewhere.
    const RECIPIENT: &str = "csalter@carathe.dev";

    utils::send_email(
        RECIPIENT.to_string(),
        "Test Email".to_string(),
        "Hi, test".to_string(),
    )
    .await?;

    Ok(())
}

/// Serves a file embedded at build time in `templates::statics`.
///
/// The embedded bytes are sent unmodified, so binary assets such as images and
/// fonts work as well as text. The `Content-Type` comes from the file extension
/// instead of always being `text/css`.
///
/// # Errors
/// Returns `ServiceError::NotFound` when no embedded file matches `name`.
async fn statics(Path(name): Path<String>) -> Result<Response, ServiceError> {
    let file = templates::statics::StaticFile::get(&name).ok_or(ServiceError::NotFound)?;
    trace!("Serving static file {}", file.name);

    let response = Response::builder()
        .status(StatusCode::OK)
        .header("Content-Type", static_content_type(&name))
        .body(body::boxed(body::Full::from(file.content)))
        .expect("status and static Content-Type value are always valid");

    Ok(response)
}

/// Maps a static file name to a `Content-Type` value using its extension,
/// case-insensitively. Unknown or missing extensions get
/// `application/octet-stream`.
fn static_content_type(name: &str) -> &'static str {
    // Fully qualified: `Path` in this module is `axum::extract::Path`.
    let ext = std::path::Path::new(name)
        .extension()
        .and_then(|e| e.to_str())
        .unwrap_or("")
        .to_ascii_lowercase();

    match ext.as_str() {
        "css" => "text/css",
        "js" | "mjs" => "text/javascript",
        "json" | "map" => "application/json",
        "html" | "htm" => "text/html",
        "txt" => "text/plain",
        "svg" => "image/svg+xml",
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "webp" => "image/webp",
        "ico" => "image/x-icon",
        "woff" => "font/woff",
        "woff2" => "font/woff2",
        "ttf" => "font/ttf",
        _ => "application/octet-stream",
    }
}

use std::fs::{self, File};
fn load_or_gen_keypair() -> Result<Key, ServiceError> {
    let kp: Key;
        let mut file = match File::open(".keypair") {
            Ok(f) => f,
            Err(_) => {
                debug!("File does not exist, creating at .keypair");
                File::create(".keypair").unwrap()
            }
        };
    if let Ok(c) = fs::read(".keypair") {
        if c.len() == 0 {
            debug!("No keypair found. Generating...");
            let key = Key::generate();
            fs::write(".keypair", key.master().as_ref()).unwrap(); 
            debug!("Written keypair {:?} to .keypair", key.master().as_ref());
            kp = key;
        } else {
            debug!("Found keypair file, contents: {:?}", c);
            kp = Key::from(&c);
            debug!("Loaded keypair from file");
        }
    } else {
        debug!("Generating new keypair");
            let key = Key::generate();
            fs::write(".keypair", key.master().as_ref()).unwrap(); 
            debug!("Written keypair {:?} to .keypair", key.master().as_ref());
            kp = key;
   }
    Ok(kp)
}


include!(concat!(env!("OUT_DIR"), "/templates.rs"));