From 531cbb4458534b5c17d1e87260e085c805b21298 Mon Sep 17 00:00:00 2001 From: Tony Rewin Date: Fri, 6 Oct 2023 17:57:54 +0300 Subject: [PATCH] disconnect-handler --- src/main.rs | 57 ++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 9 deletions(-) diff --git a/src/main.rs b/src/main.rs index fcfa937..ad2c2f9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,19 +5,27 @@ use std::env; use futures::StreamExt; use tokio::sync::broadcast; use actix_web::error::{ErrorUnauthorized, ErrorInternalServerError as ServerError}; +use std::sync::{Arc, Mutex}; +use tokio::task::JoinHandle; mod data; -async fn sse_handler( +#[derive(Clone)] +struct AppState { + tasks: Arc>>>, + redis: Client, +} + +async fn connect_handler( token: web::Path, - redis: web::Data, + state: web::Data, ) -> Result { let listener_id = data::get_auth_id(&token).await.map_err(|e| { eprintln!("TOKEN check failed: {}", e); ErrorUnauthorized("Unauthorized") })?; - let mut con = redis.get_async_connection().await.map_err(|e| { + let mut con = state.redis.get_async_connection().await.map_err(|e| { eprintln!("Failed to get async connection: {}", e); ServerError("Internal Server Error") })?; @@ -33,8 +41,8 @@ async fn sse_handler( })?; let (tx, mut rx) = broadcast::channel(100); - let _handle = tokio::spawn(async move { - let conn = redis.get_async_connection().await.unwrap(); + let handle = tokio::spawn(async move { + let conn = state.redis.get_async_connection().await.unwrap(); let mut pubsub = conn.into_pubsub(); pubsub.subscribe("new_follower").await.unwrap(); @@ -57,6 +65,10 @@ async fn sse_handler( }; } }); + state.tasks + .lock() + .unwrap() + .insert(format!("{}", listener_id.clone()), handle); let server_event = rx.recv().await.map_err(|e| { eprintln!("Failed to receive server event: {}", e); @@ -70,17 +82,44 @@ async fn sse_handler( .streaming(server_event_stream)) } + +async fn disconnect_handler( + token: web::Path, + state: web::Data, +) -> Result { + let listener_id = data::get_auth_id(&token).await.map_err(|e| { + eprintln!("TOKEN check failed: {}", e); + ErrorUnauthorized("Unauthorized") + })?; + if let Some(handle) = state.tasks.lock().unwrap().remove(&format!("{}", listener_id)) { + handle.abort(); + let mut con = state.redis.get_async_connection().await.map_err(|e| { + eprintln!("Failed to get async connection: {}", e); + ServerError("Internal Server Error") + })?; + con.srem::<&str, &i32, usize>("authors-online", &listener_id).await.map_err(|e| { + eprintln!("Failed to remove author from online list: {}", e); + ServerError("Internal Server Error") + })?; + } + Ok(HttpResponse::Ok().finish()) +} + #[actix_web::main] async fn main() -> std::io::Result<()> { let redis_url = env::var("REDIS_URL").unwrap_or_else(|_| String::from("redis://127.0.0.1/")); let client = redis::Client::open(redis_url.clone()).unwrap(); - + let tasks = Arc::new(Mutex::new(HashMap::new())); + let state = AppState { + tasks: tasks.clone(), + redis: client.clone(), + }; println!("Connecting to Redis: {}", redis_url); HttpServer::new(move || { App::new() - .app_data(web::Data::new(client.clone())) - .route("/connect", web::get().to(sse_handler)) - .route("/disconnect", web::get().to(sse_handler)) + .app_data(web::Data::new(state.clone())) + .route("/connect", web::get().to(connect_handler)) + .route("/disconnect", web::get().to(disconnect_handler)) }) .bind("127.0.0.1:8080")? .run()