diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/errors.rs | 3 | ||||
-rw-r--r-- | src/handlers/ships.rs | 43 | ||||
-rw-r--r-- | src/main.rs | 28 |
3 files changed, 57 insertions, 17 deletions
diff --git a/src/errors.rs b/src/errors.rs index 70b1a43..42477c0 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -9,6 +9,9 @@ pub enum ServiceError { #[error("Axum error: {0}")] Axum(#[from] axum::Error), + + #[error("SQL error: {0}")] + Sql(#[from] sqlx::Error), } pub type StringResult<T = &'static str> = Result<T, ServiceError>; diff --git a/src/handlers/ships.rs b/src/handlers/ships.rs index 5e8c000..ddb81a6 100644 --- a/src/handlers/ships.rs +++ b/src/handlers/ships.rs @@ -1,23 +1,44 @@ -use axum::{Json, extract::Path}; -use solarlib::ship::{Ship, Sha256}; +use std::sync::Arc; -use crate::errors::{JsonResult, StringResult}; +use axum::{Json, extract::Path, Extension}; +use solarlib::ship::{Ship, DbShip, Sha256}; +use sqlx::{query_as, query}; +use crate::{errors::{JsonResult, StringResult}, State}; -pub async fn list() -> JsonResult<Json<Vec<Ship>>> { + +pub async fn list(state: Extension<Arc<State>>) -> JsonResult<Json<Vec<Ship>>> { let mut result: Vec<Ship> = Vec::new(); - Ok(Json(result)) + let mut conn = state.conn.acquire().await?; + + let db_ships = query_as!(DbShip, "SELECT * FROM ships").fetch_all(&mut conn).await?; + + let ships = db_ships.into_iter().map(|d| d.into()).collect::<Vec<Ship>>(); + + Ok(Json(ships)) } -pub async fn new(Json(new_ship): Json<Ship>) -> StringResult { - unimplemented!(); +pub async fn new(Json(new_ship): Json<Ship>, state: Extension<Arc<State>>) -> StringResult { + let mut conn = state.conn.acquire().await?; + + query!("INSERT INTO ships (name, shasum, download_url, version) VALUES ($1, $2, $3, $4)", new_ship.name, new_ship.shasum.to_string(), new_ship.download_url, new_ship.version).execute(&mut conn).await?; + + Ok("OK") } -pub async fn update(Json(new_ship): Json<Ship>) -> StringResult { - todo!(); +pub async fn delete(Path(shasum): Path<Sha256>, state: Extension<Arc<State>>) -> StringResult { + let mut conn = state.conn.acquire().await?; + + query!("DELETE FROM ships WHERE shasum=$1", shasum.to_string()).execute(&mut conn).await?; + + Ok("OK") } -pub async fn delete(Path(shasum): Path<Sha256>) -> StringResult { - todo!(); +pub async fn get(Path(shasum): Path<Sha256>, state: Extension<Arc<State>>) -> JsonResult<Json<DbShip>> { + let mut conn = state.conn.acquire().await?; + + let db_ship = query_as!(DbShip, "SELECT * FROM ships WHERE shasum=$1", shasum.to_string()).fetch_one(&mut conn).await?; + + Ok(Json(db_ship)) } diff --git a/src/main.rs b/src/main.rs index 31271e4..65a32d4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,11 +3,13 @@ use axum::{ http::StatusCode, response::IntoResponse, routing::{get, post, delete}, - Json, Router + Json, Router, Extension }; +use errors::ServiceError; use serde::{Deserialize, Serialize}; -use std::{net::SocketAddr, time::Duration, str::FromStr}; +use sqlx::{Connection, query, PgConnection, PgPool, postgres::PgPoolOptions}; +use std::{net::SocketAddr, time::Duration, str::FromStr, sync::Arc}; use tower::{BoxError, ServiceBuilder}; use tower_http::trace::TraceLayer; @@ -18,9 +20,13 @@ mod errors; mod handlers; +pub struct State { + pub conn: PgPool, +} + #[tokio::main] async fn main() { - kankyo::init(); + kankyo::init().unwrap(); color_eyre::install().unwrap(); tracing_subscriber::registry() .with(tracing_subscriber::EnvFilter::new( @@ -30,12 +36,20 @@ async fn main() { .with(tracing_subscriber::fmt::layer()) .init(); + let mut conn = PgPoolOptions::new() + .max_connections(5) + .connect(&std::env::var("DATABASE_URL").unwrap_or("postgres://postgres@localhost/homeworld".to_string())).await.unwrap(); + + let shared_state = Arc::new(State { + conn + }); + let app = Router::new() .route("/health", get(health_check)) .route("/ships/list", get(handlers::ships::list)) - .route("/ships/new", post(handlers::ships::new)) - .route("/ships/update", post(handlers::ships::update)) + .route("/ships/new", post(handlers::ships::new)) .route("/ships/delete/:shasum", delete(handlers::ships::delete)) + .route("/ships/get/:shasum", get(handlers::ships::get)) .layer( ServiceBuilder::new() .layer(HandleErrorLayer::new(|error: BoxError| async move { if error.is::<tower::timeout::error::Elapsed>() { @@ -50,7 +64,8 @@ async fn main() { .timeout(Duration::from_secs(10)) .layer(TraceLayer::new_for_http()) .into_inner(), - ); + ) + .layer(Extension(shared_state)); let addr = SocketAddr::from_str(std::env::var("BIND_ADDR").unwrap().as_str().into()).unwrap(); tracing::info!("Listening on {}", addr); @@ -64,3 +79,4 @@ async fn main() { async fn health_check() -> &'static str { "OK" } + |