Advertisement
SpiderLordCoder1st

Untitled

Jun 8th, 2025
23
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 29.58 KB | None | 0 0
  1. use std::collections::HashMap;
  2. use std::convert::Infallible;
  3. use std::fs;
  4. use std::{net::SocketAddr, path::Path, sync::Arc};
  5. //
  6. use axum::extract::Query;
  7. use axum::http::header::AUTHORIZATION;
  8. use axum::http::Uri;
  9. use axum::response::Redirect;
  10. use axum::{Extension, Form};
  11. use axum_login::tower_sessions::{MemoryStore, SessionManagerLayer};
  12. use axum_login::AuthUser;
  13. use axum_login::{login_required, AuthManagerLayerBuilder, AuthnBackend, UserId};
  14. use axum::middleware::{self, Next};
  15. use axum::{
  16. body::Body,
  17. extract::{ws::{Message, WebSocket, WebSocketUpgrade}, Request, State},
  18. http::{self, Method, Response, StatusCode},
  19. response::{Html, IntoResponse, Json},
  20. routing::{get, post},
  21. Router,
  22. };
  23. use futures_util::{sink::SinkExt, stream::StreamExt};
  24. // use k8s_openapi::chrono;
  25.  
  26. use mime_guess::from_path;
  27. use serde::{Deserialize, Serialize};
  28. use serde_json::json;
  29. use tokio::io::{AsyncBufReadExt, BufReader};
  30. use tokio::{
  31. fs as tokio_fs,
  32. io::{AsyncReadExt, AsyncWriteExt},
  33. net::{TcpListener, TcpStream},
  34. sync::{broadcast, mpsc, Mutex},
  35. time::{timeout, Duration},
  36. };
  37. use tower_http::cors::{Any, CorsLayer};
  38. use bcrypt::{hash, verify, DEFAULT_COST};
  39. use chrono::{Utc, Duration as OtherDuration};
  40. use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, TokenData, Validation};
  41. use bcrypt::BcryptError;
  42. use async_trait::async_trait;
  43. use tower_http::add_extension::AddExtensionLayer;
  44. use axum::extract::FromRequest;
  45.  
  46. mod database;
  47. use database::User;
  48. use database::CreateUserData;
  49. use database::RemoveUserData;
  50.  
  51. #[cfg(feature = "full-stack")]
  52. mod docker;
  53.  
  54. #[cfg(feature = "full-stack")]
  55. use kube::Client;
  56.  
  57. #[cfg(feature = "full-stack")]
  58. mod kubernetes;
  59.  
  /// Placeholder for the real `docker` module when the "full-stack" feature
  /// is disabled, so callers still type-check.
  #[cfg(not(feature = "full-stack"))]
  mod docker {
      /// Always fails: image building is compiled out in this configuration.
      pub async fn build_docker_image() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
          Err(Box::from("This should not be running"))
      }
  }
  66. #[cfg(not(feature = "full-stack"))]
  67. mod kubernetes {
  68. pub async fn create_k8s_deployment(_: &crate::Client) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
  69. Err("This should not be running".into())
  70. }
  71. pub async fn list_node_names(_: crate::Client) -> Result<Vec<String>, Box<dyn std::error::Error>> {
  72. Err("This should not be running".into())
  73. }
  74. }
  // Build-specific endpoints. The lower-case names are kept because they are
  // referenced elsewhere in the file.

  // "full-stack" build: the game server is reached through its Kubernetes service.
  #[cfg(feature = "full-stack")]
  #[allow(non_upper_case_globals)]
  static TcpUrl: &'static str = "gameserver-service:8080";
  #[cfg(feature = "full-stack")]
  #[allow(non_upper_case_globals)]
  static LocalUrl: &'static str = "0.0.0.0:8080";
  #[cfg(feature = "full-stack")]
  static K8S_WORKS: bool = true;

  // Standalone build: the game server runs on this host and Kubernetes is off.
  #[cfg(not(feature = "full-stack"))]
  #[allow(non_upper_case_globals)]
  static TcpUrl: &'static str = "0.0.0.0:8082";
  #[cfg(not(feature = "full-stack"))]
  #[allow(non_upper_case_globals)]
  static LocalUrl: &'static str = "0.0.0.0:8081";
  #[cfg(not(feature = "full-stack"))]
  static K8S_WORKS: bool = false;
  92.  
  /// Stand-in for `kube::Client` when the "full-stack" feature is disabled,
  /// so `AppState` and `main` compile without the `kube` crate.
  #[cfg(not(feature = "full-stack"))]
  #[derive(Clone)]
  struct Client;

  #[cfg(not(feature = "full-stack"))]
  impl Client {
      /// Mirrors `kube::Client::try_default`, but always fails because
      /// Kubernetes support is compiled out.
      async fn try_default() -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
          let reason = "This should not be running";
          Err(Box::from(reason))
      }
  }
  103.  
  /// Pause between attempts to (re)connect to the game server.
  const CONNECTION_RETRY_DELAY: Duration = Duration::from_millis(2_000);
  /// Upper bound on a single TCP connect attempt.
  const CONNECTION_TIMEOUT: Duration = Duration::from_millis(3_000);
  /// Capacity of the broadcast (server -> WebSocket) and mpsc (-> server) channels.
  const CHANNEL_BUFFER_SIZE: usize = 32;
  107.  
  108. #[derive(Debug, Serialize, Deserialize)]
  109. struct MessagePayload {
  110. r#type: String,
  111. message: String,
  112. authcode: String,
  113. }
  114.  
  115. #[derive(Debug, Deserialize)]
  116. struct ConsoleMessage {
  117. data: String,
  118. #[serde(rename = "type")]
  119. message_type: String,
  120. authcode: String,
  121. }
  122.  
  123. #[derive(Debug, Deserialize, Serialize)]
  124. struct InnerData {
  125. data: String,
  126. #[serde(rename = "type")]
  127. message_type: String,
  128. authcode: String,
  129. }
  130.  
  131. #[derive(Debug, Serialize)]
  132. struct ResponseMessage {
  133. response: String,
  134. }
  135.  
  136. #[derive(Debug, Serialize, Deserialize)]
  137. struct List {
  138. list: ApiCalls,
  139. }
  140.  
  141. enum WebErrors {
  142. AuthError {
  143. message: String,
  144. status_code: StatusCode,
  145. }
  146. }
  147. impl IntoResponse for WebErrors {
  148. fn into_response(self) -> Response<Body> {
  149. let (status_code, message) = match self {
  150. WebErrors::AuthError { message, status_code } => (status_code, message),
  151. };
  152.  
  153. Response::builder()
  154. .status(status_code)
  155. .header("content-type", "application/json")
  156. .body(Body::from(serde_json::to_string(&json!({ "error": message })).unwrap()))
  157. .unwrap()
  158. }
  159. }
  160.  
  161.  
  162. // impl IntoResponse for WebErrors {
  163. // fn into_response(self) -> Response<axum::body::Body> {
  164. // match self {
  165. // WebErrors::AuthError { message, status_code } => {
  166. // Response::builder()
  167. // .status(status_code)
  168. // .header("content-type", "text/plain")
  169. // .body(message.into())
  170. // .unwrap()
  171. // }
  172. // }
  173. // }
  174. // }
  175.  
  176.  
  177. #[derive(Debug, Deserialize, Serialize, Clone)]
  178. struct IncomingMessage {
  179. message: String,
  180. #[serde(rename = "type")]
  181. message_type: String,
  182. authcode: String,
  183. }
  184.  
  185. #[derive(Debug, Deserialize, Serialize, Clone)]
  186. #[serde(tag = "kind", content = "data")]
  187. enum ApiCalls {
  188. None,
  189. Capabilities(Vec<String>),
  190. NodeList(Vec<String>),
  191. UserList(Vec<User>),
  192. // CreateUserData(CreateUserData),
  193. // LoginData(LoginData),
  194. UserData(LoginData),
  195. IncomingMessage(IncomingMessage),
  196. }
  197.  
  198. async fn attempt_connection() -> Result<TcpStream, Box<dyn std::error::Error + Send + Sync>> {
  199. timeout(CONNECTION_TIMEOUT, TcpStream::connect(TcpUrl)).await?.map_err(Into::into)
  200. }
  201.  
  202. async fn handle_server_data(
  203. data: &[u8],
  204. ws_tx: &broadcast::Sender<String>,
  205. ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
  206. if let Ok(text) = String::from_utf8(data.to_vec()) {
  207. println!("Raw message from server: {}", text);
  208.  
  209. if let Ok(outer_msg) = serde_json::from_str::<InnerData>(&text) {
  210. let inner_data_str = outer_msg.data.as_str();
  211. if let Ok(inner_data) = serde_json::from_str::<serde_json::Value>(inner_data_str) {
  212. if let Some(message_content) = inner_data["data"].as_str() {
  213. println!("Extracted message: {}", message_content);
  214. let _ = ws_tx.send(message_content.to_string());
  215. }
  216. } else {
  217. println!("Sending raw inner data: {}", inner_data_str);
  218. let _ = ws_tx.send(inner_data_str.to_string());
  219. }
  220. } else if let Ok(_) = serde_json::from_str::<MessagePayload>(&text) {
  221. todo!()
  222. } else if let Ok(_) = serde_json::from_str::<List>(&text) {
  223. todo!()
  224. } else {
  225. println!("Sending raw text: {}", text);
  226. let _ = ws_tx.send(text);
  227. }
  228. }
  229. Ok(())
  230. }
  231.  
  232. async fn handle_stream(
  233. rx: Arc<Mutex<mpsc::Receiver<Vec<u8>>>>,
  234. stream: &mut TcpStream,
  235. ws_tx: broadcast::Sender<String>,
  236. ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
  237. let (mut reader, mut writer) = stream.split();
  238. let mut buf = vec![0u8; 1024];
  239. let mut buf_reader = BufReader::new(reader);
  240. let mut line = String::new();
  241.  
  242. let msg = serde_json::to_string(
  243. &List {
  244. list: ApiCalls::Capabilities(vec!["all".to_string()])
  245. // {
  246. // //let items: Vec<ApiCalls> = vec!["all".to_string()].iter().map(|item| ApiCalls::Capabilities(item.to_string())).collect();
  247. // let items: Vec<ApiCalls> = ApiCalls::Capabilities(vec!["all".to_string()]);
  248. // items
  249. // },
  250. }
  251. )? + "\n";
  252.  
  253. writer.write_all(msg.as_bytes()).await?;
  254. match buf_reader.read_line(&mut line).await {
  255. Ok(0) => {
  256. println!("Error, possibly connection closed");
  257. }
  258. Ok(_) => {
  259. println!("Received line: {}", line.trim_end());
  260. }
  261. Err(e) => {
  262. println!("Error reading line: {:?}", e);
  263. return Err(e.into());
  264. }
  265. }
  266. reader = buf_reader.into_inner();
  267.  
  268. loop {
  269. let mut rx_guard = rx.lock().await;
  270. tokio::select! {
  271. result = reader.read(&mut buf) => match result {
  272. Ok(0) => return Ok(()),
  273. Ok(n) => handle_server_data(&buf[..n], &ws_tx).await?,
  274. Err(e) => return Err(e.into()),
  275. },
  276. result = rx_guard.recv() => if let Some(data) = result {
  277. writer.write_all(&data).await?;
  278. writer.write_all(b"\n").await?;
  279. writer.flush().await?;
  280. } else {
  281. return Ok(());
  282. }
  283. }
  284. }
  285. }
  286.  
  287. async fn connect_to_server(
  288. rx: Arc<Mutex<mpsc::Receiver<Vec<u8>>>>,
  289. ws_tx: broadcast::Sender<String>,
  290. ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
  291. loop {
  292. println!("→ trying to connect to {}…", TcpUrl);
  293. match timeout(CONNECTION_TIMEOUT, TcpStream::connect(TcpUrl)).await {
  294. Ok(Ok(mut stream)) => {
  295. println!("connected!");
  296. handle_stream(rx.clone(), &mut stream, ws_tx.clone()).await?
  297. }
  298. Ok(Err(e)) => {
  299. eprintln!("connect error: {}", e);
  300. tokio::time::sleep(CONNECTION_RETRY_DELAY).await;
  301. }
  302. Err(_) => {
  303. eprintln!("connect timed out after {:?}", CONNECTION_TIMEOUT);
  304. tokio::time::sleep(CONNECTION_RETRY_DELAY).await;
  305. }
  306. }
  307. }
  308. }
  309.  
  310. async fn try_initial_connection(
  311. ws_tx: broadcast::Sender<String>,
  312. tcp_tx: Arc<Mutex<mpsc::Sender<Vec<u8>>>>,
  313. ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
  314. match attempt_connection().await {
  315. Ok(mut stream) => {
  316. println!("Initial connection succeeded!");
  317.  
  318. let (temp_tx, temp_rx) = mpsc::channel::<Vec<u8>>(CHANNEL_BUFFER_SIZE);
  319.  
  320. {
  321. let mut guard = tcp_tx.lock().await;
  322. *guard = temp_tx;
  323. }
  324. handle_stream(Arc::new(Mutex::new(temp_rx)), &mut stream, ws_tx).await
  325. }
  326. Err(e) => {
  327. eprintln!("Initial connection failed: {}", e);
  328. Err(e)
  329. }
  330. }
  331. }
  332.  
  333. #[derive(Clone)]
  334. struct AppState {
  335. tcp_tx: Arc<Mutex<mpsc::Sender<Vec<u8>>>>,
  336. tcp_rx: Arc<Mutex<mpsc::Receiver<Vec<u8>>>>,
  337.  
  338. ws_tx: broadcast::Sender<String>,
  339.  
  340. base_path: String,
  341. client: Option<Client>,
  342. database: database::Postgres
  343. }
  344.  
  345. #[tokio::main]
  346. async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
  347. println!("Starting server...");
  348.  
  349. let db_user = std::env::var("POSTGRES_USER").unwrap_or("gameserver".to_string());
  350. let db_password = std::env::var("POSTGRES_PASSWORD").unwrap_or("gameserverpass".to_string());
  351. let db = std::env::var("POSTGRES_DB").unwrap_or("gameserver_db".to_string());
  352. let db_port = std::env::var("POSTGRES_PORT").unwrap_or("5432".to_string());
  353. let db_host = std::env::var("POSTGRES_HOST").unwrap_or("gameserver-postgres".to_string());
  354.  
  355. let conn = sqlx::postgres::PgPool::connect(&format!("postgres://{}:{}@{}:{}/{}", db_user, db_password, db_host, db_port, db)).await.unwrap();
  356. let database = database::Postgres::new(conn);
  357.  
  358. let verbose = std::env::var("VERBOSE").is_ok();
  359. let base_path = std::env::var("SITE_URL")
  360. .map(|s| {
  361. let mut s = s.trim().to_string();
  362. if !s.is_empty() {
  363. if !s.starts_with('/') { s.insert(0, '/'); }
  364. if s.ends_with('/') && s != "/" { s.pop(); }
  365. }
  366. s
  367. })
  368. .unwrap_or_default();
  369.  
  370. const ENABLE_K8S_CLIENT: bool = true;
  371. const ENABLE_INITIAL_CONNECTION: bool = false;
  372. const FORCE_REBUILD: bool = false;
  373. const BUILD_DOCKER_IMAGE: bool = true;
  374. const BUILD_DEPLOYMENT: bool = true;
  375.  
  376. let (ws_tx, _) = broadcast::channel::<String>(CHANNEL_BUFFER_SIZE);
  377. let (tcp_tx, tcp_rx) = mpsc::channel::<Vec<u8>>(CHANNEL_BUFFER_SIZE);
  378.  
  379. let mut client: Option<Client> = None;
  380. if ENABLE_K8S_CLIENT && K8S_WORKS {
  381. client = Some(Client::try_default().await?);
  382. }
  383.  
  384. let state = AppState {
  385. tcp_tx: Arc::new(Mutex::new(tcp_tx)),
  386. tcp_rx: Arc::new(Mutex::new(tcp_rx)),
  387. ws_tx: ws_tx.clone(),
  388. base_path: base_path.clone(),
  389. database,
  390. client,
  391. };
  392.  
  393. if ENABLE_INITIAL_CONNECTION && state.client.is_some() {
  394. println!("Trying initial connection...");
  395. if try_initial_connection(ws_tx.clone(), state.tcp_tx.clone()).await.is_err() || FORCE_REBUILD {
  396. eprintln!("Initial connection failed or force rebuild enabled");
  397. if BUILD_DOCKER_IMAGE {
  398. docker::build_docker_image().await?;
  399. }
  400. if BUILD_DEPLOYMENT {
  401. kubernetes::create_k8s_deployment(state.client.as_ref().unwrap()).await?;
  402. }
  403. }
  404. }
  405.  
  406. let bridge_rx = state.tcp_rx.clone();
  407. let bridge_tx = state.ws_tx.clone();
  408. tokio::spawn(async move {
  409. if let Err(e) = connect_to_server(bridge_rx, bridge_tx).await {
  410. eprintln!("Connection task failed: {}", e);
  411. }
  412. });
  413.  
  414. let cors = CorsLayer::new()
  415. .allow_origin(Any)
  416. .allow_methods([Method::GET, Method::POST])
  417. .allow_headers(Any);
  418.  
  419. let fallback_router = routes_static(state.clone().into());
  420.  
  421. let inner = Router::new()
  422. .route("/api/message", get(get_message))
  423. .route("/api/nodes", get(get_nodes))
  424. .route("/api/ws", get(ws_handler))
  425. .route("/api/users", get(users))
  426. .route("/api/send", post(receive_message))
  427. .route("/api/general", post(process_general))
  428. .route("/api/signin", post(sign_in))
  429. .route("/api/createuser", post(create_user))
  430. .route("/api/deleteuser", post(delete_user))
  431. .merge(fallback_router)
  432. .with_state(state.clone());
  433.  
  434.  
  435. let app = if base_path.is_empty() || base_path == "/" {
  436. inner.layer(cors)
  437. } else {
  438. Router::new().nest(&base_path, inner).layer(cors)
  439. };
  440.  
  441. let addr: SocketAddr = LocalUrl.parse().unwrap();
  442. println!("Listening on http://{}{}", addr, base_path);
  443.  
  444. // let addr: SocketAddr = "0.0.0.0:8081".parse().unwrap();
  445. // println!("Listening on http://{}{}", addr, base_path);
  446. // axum::serve(TcpListener::bind(addr).await?, app).await?;
  447.  
  448. // Updated server start:
  449. let listener = TcpListener::bind(addr).await?;
  450. axum::serve(listener, app.into_make_service())
  451. .await?;
  452.  
  453. Ok(())
  454. }
  455.  
  456. async fn create_user(
  457. State(state): State<AppState>,
  458. Json(request): Json<CreateUserData>
  459. ) -> impl IntoResponse {
  460. let result = state.database.create_user_in_db(request).await;
  461. StatusCode::CREATED
  462. }
  463.  
  464. async fn delete_user(
  465. State(state): State<AppState>,
  466. Json(request): Json<RemoveUserData>
  467. ) -> impl IntoResponse {
  468. let result = state.database.remove_user_in_db(request).await;
  469. StatusCode::CREATED
  470. }
  471.  
  472. async fn capabilities(
  473. State(_): State<AppState>,
  474. //Form(request): Form<LoginData>
  475. ) -> impl IntoResponse {
  476.  
  477. }
  478. fn routes_static(state: Arc<AppState>) -> Router<AppState> {
  479. let session_store = MemoryStore::default();
  480. let session_layer = SessionManagerLayer::new(session_store);
  481.  
  482. let backend = Backend::default();
  483. let auth_layer = AuthManagerLayerBuilder::new(backend, session_layer).build();
  484.  
  485. let public = Router::new()
  486. // .route("/login", get(login_page).post(sign_in))
  487. .route("/", get(handle_static_request))
  488. .route("/authenticate", get(authenticate_route))
  489. .route("/index.html", get(handle_static_request))
  490. .layer(AddExtensionLayer::new(state.clone()));
  491.  
  492. let protected = Router::new()
  493. .route("/{*wildcard}", get(handle_static_request))
  494. .layer(AddExtensionLayer::new(state.clone()))
  495. .route_layer(login_required!(Backend, login_url = "/index.html"));
  496.  
  497. public.merge(protected).route_layer(auth_layer)
  498. }
  499.  
  500.  
  501.  
  502. async fn ws_handler(
  503. ws: WebSocketUpgrade,
  504. State(state): State<AppState>,
  505. ) -> impl IntoResponse {
  506. ws.on_upgrade(move |socket| handle_socket(socket, state))
  507. }
  508.  
  509. async fn handle_socket(socket: WebSocket, state: AppState) {
  510. println!("WebSocket connected");
  511. let (mut sender, mut receiver) = socket.split();
  512.  
  513. let mut broadcast_rx = state.ws_tx.subscribe();
  514. tokio::spawn(async move {
  515. while let Ok(msg) = broadcast_rx.recv().await {
  516. println!("Forwarding: {}", msg);
  517. if sender.send(Message::Text(msg.into())).await.is_err() {
  518. break;
  519. }
  520. }
  521. });
  522.  
  523. while let Some(Ok(message)) = receiver.next().await {
  524. if let Message::Text(text) = message {
  525. println!("Got from client: {}", text);
  526. let payload = serde_json::from_str::<MessagePayload>(&text).unwrap_or(MessagePayload {
  527. r#type: "console".into(),
  528. message: text.to_string(),
  529. authcode: "0".into(),
  530. });
  531.  
  532. if let Ok(mut bytes) = serde_json::to_vec(&payload) {
  533. bytes.push(b'\n');
  534. let _ = state.tcp_tx.lock().await.send(bytes).await;
  535. }
  536. }
  537. }
  538.  
  539. println!("WebSocket disconnected");
  540. }
  541.  
  542. async fn process_general(
  543. State(state): State<AppState>,
  544. Json(res): Json<ApiCalls>,
  545. ) -> Result<Json<ResponseMessage>, (StatusCode, String)> {
  546. //let payload = res.standard_response().unwrap();
  547. if let ApiCalls::IncomingMessage(payload) = res {
  548. println!("Processing general message: {:?}", payload);
  549.  
  550. let json_payload = MessagePayload {
  551. r#type: payload.message_type.clone(),
  552. message: payload.message.clone(),
  553. authcode: payload.authcode.clone(),
  554. };
  555.  
  556. match serde_json::to_vec(&json_payload) {
  557. Ok(mut json_bytes) => {
  558. json_bytes.push(b'\n');
  559.  
  560. let tx = state.tcp_tx.clone();
  561. let tx_guard = tx.lock().await;
  562.  
  563. match tx_guard.send(json_bytes).await {
  564. Ok(_) => {
  565. println!("Successfully forwarded message to TCP server");
  566. Ok(Json(ResponseMessage {
  567. response: format!("Processed: {}", payload.message),
  568. }))
  569. },
  570. Err(e) => {
  571. eprintln!("Failed to send message to TCP channel: {}", e);
  572. Err((StatusCode::INTERNAL_SERVER_ERROR,
  573. "Failed to forward message to server".to_string()))
  574. }
  575. }
  576. }
  577. Err(e) => {
  578. eprintln!("Serialization error: {}", e);
  579. Err((StatusCode::BAD_REQUEST,
  580. "Invalid message format".to_string()))
  581. }
  582. }
  583. } else {
  584. Err((StatusCode::INTERNAL_SERVER_ERROR,
  585. "Failed to forward message to server".to_string()))
  586. }
  587. }
  588. async fn users(State(state): State<AppState>) -> Result<impl IntoResponse, StatusCode> {
  589. let users = state.database
  590. .fetch_all("users")
  591. .await
  592. .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
  593.  
  594. Ok(Json(List { list: ApiCalls::UserList(users) }))
  595. }
  596.  
  597. async fn get_nodes(State(state): State<AppState>) -> impl IntoResponse {
  598. if state.client.is_some() {
  599. match kubernetes::list_node_names(state.client.unwrap()).await {
  600. Ok(nodes) => {
  601. //let items: Json<List<Vec<ApiCalls>>> = Json(List { list: nodes.iter().map(|node| ApiCalls::NodeList(node.to_string())).collect() });
  602. // let items: Json<List> = Json(List { list: nodes.iter().map(|node| ApiCalls::NodeList(node.to_string())).collect() });
  603. let items: Json<List> = Json(List { list: ApiCalls::NodeList(nodes) });
  604. items
  605. },
  606. Err(err) => {
  607. eprintln!("Error listing nodes: {}", err);
  608. Json(List { list: ApiCalls::None })
  609. },
  610. }
  611. } else {
  612. Json(List { list: ApiCalls::None })
  613. }
  614. }
  615.  
  616.  
  617. async fn receive_message(
  618. State(state): State<AppState>,
  619. Json(res): Json<ApiCalls>,
  620. ) -> Result<Json<ResponseMessage>, (StatusCode, String)> {
  621. //let payload = res.standard_response().unwrap();
  622. if let ApiCalls::IncomingMessage(payload) = res {
  623. let json_payload = MessagePayload {
  624. r#type: payload.message_type.clone(),
  625. message: payload.message.clone(),
  626. authcode: payload.authcode.clone(),
  627. };
  628.  
  629. match serde_json::to_vec(&json_payload) {
  630. Ok(mut json_bytes) => {
  631. json_bytes.push(b'\n');
  632.  
  633. let tx_guard = state.tcp_tx.lock().await;
  634. match tx_guard.send(json_bytes).await {
  635. Ok(_) => Ok(Json(ResponseMessage {
  636. response: format!("Successfully sent message: {}", payload.message),
  637. })),
  638. Err(e) => {
  639. eprintln!("Failed to send message to TCP channel: {}", e);
  640. Err((StatusCode::INTERNAL_SERVER_ERROR,
  641. "Failed to forward message to server".to_string()))
  642. }
  643. }
  644. }
  645. Err(e) => {
  646. eprintln!("Serialization error: {}", e);
  647. Err((StatusCode::BAD_REQUEST,
  648. "Invalid message format".to_string()))
  649. }
  650. }
  651. } else {
  652. Err((StatusCode::BAD_REQUEST,
  653. "Invalid message format".to_string()))
  654. }
  655. }
  656.  
  657. pub type AuthSession = axum_login::AuthSession<Backend>;
  658.  
  659. #[derive(Deserialize, Serialize, Clone)]
  660. pub struct Claims {
  661. pub exp: usize,
  662. pub iat: usize,
  663. pub user: String,
  664. }
  665.  
  666.  
  667.  
  668. #[derive(Clone, Default)]
  669. pub struct Backend {
  670. pub users: HashMap<String, User>,
  671. }
  672.  
  673. #[async_trait]
  674. impl AuthnBackend for Backend {
  675. type User = User;
  676. type Credentials = String;
  677. type Error = Infallible;
  678.  
  679. async fn authenticate(&self, token: Self::Credentials) -> Result<Option<Self::User>, Self::Error> {
  680. let user = resolve_jwt(&token).ok().map(|data| User {
  681. username: data.claims.user,
  682. password_hash: None,
  683. });
  684. Ok(user)
  685. }
  686.  
  687. async fn get_user(&self, user_id: &String) -> Result<Option<Self::User>, Self::Error> {
  688. Ok(Some(User {
  689. username: user_id.clone(),
  690. password_hash: None,
  691. }))
  692. }
  693. }
  694.  
  695. // const SECRET: &str = "randomString";
  696.  
  697. fn resolve_jwt(token: &str) -> Result<TokenData<Claims>, StatusCode> {
  698. let secret = std::env::var("SECRET").unwrap_or_else(|_| {
  699. panic!("Need to specify a secret");
  700. });
  701. decode::<Claims>(
  702. token,
  703. &DecodingKey::from_secret(secret.as_bytes()),
  704. &Validation::default(),
  705. ).map_err(|_| StatusCode::UNAUTHORIZED)
  706. }
  707.  
  708. fn encode_token(user: String) -> Result<String, StatusCode> {
  709. let now = Utc::now();
  710. let exp = (now + chrono::Duration::hours(24)).timestamp() as usize;
  711. let iat = now.timestamp() as usize;
  712. let claims = Claims { exp, iat, user };
  713.  
  714. let secret = std::env::var("SECRET").unwrap_or_else(|_| {
  715. panic!("Need to specify a secret");
  716. });
  717.  
  718. encode(
  719. &Header::default(),
  720. &claims,
  721. &EncodingKey::from_secret(secret.as_bytes()),
  722. ).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
  723. }
  724.  
  725. #[derive(Debug, Deserialize, Serialize, Clone)]
  726. pub struct LoginData {
  727. pub user: String,
  728. pub password: String,
  729. }
  730.  
  731.  
  732.  
  733. #[derive(Deserialize)]
  734. pub struct JwtTokenForm {
  735. pub token: String,
  736. }
  737.  
  738.  
  739. impl AuthUser for User {
  740. type Id = String;
  741.  
  742. fn id(&self) -> Self::Id {
  743. self.username.clone()
  744. }
  745.  
  746. fn session_auth_hash(&self) -> &[u8] {
  747. self.username.as_bytes()
  748. }
  749. }
  750.  
  751. #[axum::debug_handler]
  752. pub async fn sign_in(
  753. State(state): State<AppState>,
  754. Form(request): Form<LoginData>
  755. ) -> Result<Json<ResponseMessage>, StatusCode> {
  756. let user = state.database.retrive_user(request.user.clone()).await.ok_or(StatusCode::UNAUTHORIZED)?;
  757.  
  758. let password_valid = verify_password(request.password, user.password_hash.unwrap())
  759. .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
  760.  
  761. if !password_valid {
  762. return Err(StatusCode::UNAUTHORIZED);
  763. }
  764.  
  765. let token = encode_token(user.username)?;
  766. Ok(Json(ResponseMessage { response: token }))
  767. }
  768.  
  769. // pub fn retrive_user(username: String) -> Option<User> {
  770. // if username == "testuser" {
  771. // let password_hash = bcrypt::hash("password123", bcrypt::DEFAULT_COST).ok();
  772. // Some(User {
  773. // username,
  774. // password_hash,
  775. // })
  776. // } else {
  777. // None
  778. // }
  779. // }
  780.  
  781. pub fn verify_password(password: String, hash: String) -> Result<bool, bcrypt::BcryptError> {
  782. bcrypt::verify(password, &hash)
  783. }
  784.  
  785. async fn serve_html_with_replacement(
  786. file: &str,
  787. state: &AppState,
  788. ) -> Result<Response<Body>, StatusCode> {
  789. let path = Path::new("src/svelte/dist").join(file);
  790.  
  791. if path.extension().and_then(|e| e.to_str()) == Some("html") {
  792. let html = tokio_fs::read_to_string(&path)
  793. .await
  794. .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
  795. let replaced = html.replace("[[SITE_URL]]", &state.base_path);
  796. return Ok(Html(replaced).into_response());
  797. }
  798.  
  799. let bytes = tokio_fs::read(&path)
  800. .await
  801. .map_err(|_| StatusCode::NOT_FOUND)?;
  802. let content_type = from_path(&path).first_or_octet_stream().to_string();
  803.  
  804. Ok(Response::builder()
  805. .header("Content-Type", content_type)
  806. .body(Body::from(bytes))
  807. .unwrap())
  808. }
  809.  
  810. async fn handle_static_request(
  811. Extension(state): Extension<Arc<AppState>>,
  812. req: Request<Body>,
  813. ) -> Result<Response<Body>, StatusCode> {
  814.  
  815. let path = req.uri().path();
  816.  
  817. let file = if path == "/" || path.is_empty() {
  818. "index.html"
  819. } else {
  820. &path[1..]
  821. };
  822.  
  823. match serve_html_with_replacement(file, &state).await {
  824. Ok(res) => Ok(res),
  825. Err(status) => Ok(Response::builder()
  826. .status(status)
  827. .header("content-type", "text/plain")
  828. .body(format!("Error serving `{}`", file).into())
  829. .unwrap()),
  830. }
  831. }
  832.  
  833. #[derive(Deserialize)]
  834. pub struct AuthenticateParams {
  835. next: String,
  836. jwk: String,
  837. }
  838.  
  839. async fn authenticate_route(
  840. State(_state): State<AppState>,
  841. Query(params): Query<AuthenticateParams>,
  842. mut auth_session: AuthSession,
  843. ) -> impl IntoResponse {
  844. match resolve_jwt(&params.jwk) {
  845. Ok(token_data) => {
  846. let user = User {
  847. username: token_data.claims.user,
  848. password_hash: None,
  849. };
  850.  
  851. if let Err(e) = auth_session.login(&user).await {
  852. eprintln!("Failed to log in user: {:?}", e);
  853. return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to log in").into_response();
  854. }
  855.  
  856. if params.next.starts_with('/') {
  857. if let Ok(uri) = params.next.parse::<Uri>() {
  858. return Redirect::to(&uri.to_string()).into_response();
  859. } else {
  860. return (StatusCode::BAD_REQUEST, "Invalid next parameter: unable to parse URI").into_response();
  861. }
  862. } else {
  863. return (StatusCode::BAD_REQUEST, "Invalid next parameter: must start with '/'").into_response();
  864. }
  865. }
  866. Err(_) => {
  867. (StatusCode::UNAUTHORIZED, "Invalid token").into_response()
  868. }
  869. }
  870. }
  871.  
  872.  
  873.  
  874.  
  875. async fn get_message(
  876. State(state): State<AppState>,
  877. ) -> Result<Json<MessagePayload>, (StatusCode, String)> {
  878. let request = MessagePayload {
  879. r#type: "request".to_string(),
  880. message: "get_message".to_string(),
  881. authcode: "0".to_owned(),
  882. };
  883.  
  884. let mut json_bytes = match serde_json::to_vec(&request) {
  885. Ok(mut v) => { v.push(b'\n'); v }
  886. Err(e) => {
  887. eprintln!("Serialization error: {}", e);
  888. return Err((
  889. StatusCode::INTERNAL_SERVER_ERROR,
  890. "Failed to serialize request".into(),
  891. ));
  892. }
  893. };
  894.  
  895. let tx_guard = state.tcp_tx.lock().await;
  896. if let Err(e) = tx_guard.send(json_bytes).await {
  897. eprintln!("Failed to send request: {}", e);
  898. return Err((
  899. StatusCode::INTERNAL_SERVER_ERROR,
  900. "Failed to send request to server".into(),
  901. ));
  902. }
  903. drop(tx_guard);
  904.  
  905. let mut rx_guard = state.tcp_rx.lock().await;
  906. match rx_guard.recv().await {
  907. Some(response_bytes) => {
  908. match serde_json::from_slice::<MessagePayload>(&response_bytes) {
  909. Ok(msg) => Ok(Json(msg)),
  910. Err(e) => {
  911. eprintln!("Deserialization error: {}", e);
  912. Err((
  913. StatusCode::INTERNAL_SERVER_ERROR,
  914. "Failed to parse server response".into(),
  915. ))
  916. }
  917. }
  918. }
  919. None => {
  920. eprintln!("No response received");
  921. Err((
  922. StatusCode::INTERNAL_SERVER_ERROR,
  923. "No response from server".into(),
  924. ))
  925. }
  926. }
  927. }
  928.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement