# WebSockets

> Learn how websockets can upgrade your web service by providing live update functionality, using Axum.

## Description

This example shows how to use a WebSocket to show the live status of the Shuttle API on a web page.

There are a few routes available:

* `/` - the homepage route where you can find the `index.html` page.
* `/websocket` - the route that handles websockets.
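
Each message pushed over `/websocket` is a small JSON object that the page parses. The sketch below shows its shape using only the standard library. The field names come from the `Response` struct in `src/main.rs`; the helper name `status_json` and the sample timestamp are illustrative assumptions, and the real example serializes with `serde_json` rather than `format!`.

```rust
/// Build the JSON text that the status page expects.
/// Hypothetical helper: the example itself derives `Serialize` instead.
/// Assumes `date_time` contains no characters that need JSON escaping.
fn status_json(clients_count: usize, date_time: &str, is_up: bool) -> String {
    format!(
        r#"{{"clients_count":{clients_count},"dateTime":"{date_time}","is_up":{is_up}}}"#
    )
}

fn main() {
    // The timestamp is a made-up RFC 3339-style value for illustration.
    let msg = status_json(2, "2024-01-01T12:00:00Z", true);
    println!("{msg}");
}
```

Note that `date_time` is renamed to `dateTime` on the wire, which is why the page reads `response.dateTime` while the other two keys keep their snake_case names.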

You can clone the example by running the following command (you'll need the `shuttle` CLI installed):

```bash theme={null}
shuttle init --from shuttle-hq/shuttle-examples --subfolder axum/websocket
```

## Code

<CodeGroup>
  ```rust src/main.rs theme={null}
  use std::{sync::Arc, time::Duration};

  use axum::{
      extract::{
          ws::{Message, WebSocket},
          WebSocketUpgrade,
      },
      response::IntoResponse,
      routing::get,
      Extension, Router,
  };
  use chrono::{DateTime, Utc};
  use futures::{SinkExt, StreamExt};
  use serde::Serialize;
  use shuttle_axum::ShuttleAxum;
  use tokio::{
      sync::{watch, Mutex},
      time::sleep,
  };
  use tower_http::services::ServeDir;

  struct State {
      clients_count: usize,
      rx: watch::Receiver<Message>,
  }

  const PAUSE_SECS: u64 = 15;
  const STATUS_URI: &str = "https://api.shuttle.dev/.healthz";

  #[derive(Serialize)]
  struct Response {
      clients_count: usize,
      #[serde(rename = "dateTime")]
      date_time: DateTime<Utc>,
      is_up: bool,
  }

  #[shuttle_runtime::main]
  async fn main() -> ShuttleAxum {
      let (tx, rx) = watch::channel(Message::Text("{}".into()));

      let state = Arc::new(Mutex::new(State {
          clients_count: 0,
          rx,
      }));

      // Spawn a background task that periodically checks the status of the API
      let state_send = state.clone();
      tokio::spawn(async move {
          let duration = Duration::from_secs(PAUSE_SECS);

          loop {
              let is_up = reqwest::get(STATUS_URI)
                  .await
                  .is_ok_and(|r| r.status().is_success());

              let response = Response {
                  clients_count: state_send.lock().await.clients_count,
                  date_time: Utc::now(),
                  is_up,
              };
              let msg = serde_json::to_string(&response).unwrap();

              if tx.send(Message::Text(msg.into())).is_err() {
                  break;
              }

              sleep(duration).await;
          }
      });

      let router = Router::new()
          .route("/websocket", get(websocket_handler))
          .fallback_service(ServeDir::new("static"))
          .layer(Extension(state));

      Ok(router.into())
  }

  async fn websocket_handler(
      ws: WebSocketUpgrade,
      Extension(state): Extension<Arc<Mutex<State>>>,
  ) -> impl IntoResponse {
      ws.on_upgrade(|socket| websocket(socket, state))
  }

  async fn websocket(stream: WebSocket, state: Arc<Mutex<State>>) {
      // By splitting we can send and receive at the same time.
      let (mut sender, mut receiver) = stream.split();

      let mut rx = {
          let mut state = state.lock().await;
          state.clients_count += 1;
          state.rx.clone()
      };

      // This task will receive watch messages and forward them to this connected client.
      let mut send_task = tokio::spawn(async move {
          while let Ok(()) = rx.changed().await {
              let msg = rx.borrow().clone();

              if sender.send(msg).await.is_err() {
                  break;
              }
          }
      });

      // This task will receive messages from this client.
      let mut recv_task = tokio::spawn(async move {
          while let Some(Ok(Message::Text(text))) = receiver.next().await {
              println!("this example does not read any messages, but got: {text}");
          }
      });

      // If either task exits, abort the other.
      tokio::select! {
          _ = (&mut send_task) => recv_task.abort(),
          _ = (&mut recv_task) => send_task.abort(),
      };

      // This client disconnected
      state.lock().await.clients_count -= 1;
  }
  ```

  ```html static/index.html theme={null}
  <!DOCTYPE html>
  <html lang="en" class="bg-gray-600">
    <head>
      <meta charset="utf-8" />
      <meta http-equiv="X-UA-Compatible" content="IE=edge" />
      <meta name="viewport" content="width=device-width, initial-scale=1" />
      <title>Websocket status page</title>
      <script src="https://cdn.tailwindcss.com"></script>
    </head>
    <body
      class="flex justify-around items-center h-screen w-screen m-0 text-center"
    >
      <div
        class="flex max-w-sm flex-col overflow-hidden rounded-lg transition blur-md"
      >
        <div class="flex-shrink-0 bg-gray-800 text-slate-50 p-5">
          Current API status
        </div>
        <div
          id="is_ok"
          class="flex flex-1 flex-col justify-between p-6 bg-gray-500 text-xl font-bold uppercase"
        ></div>
      </div>
      <div
        class="flex max-w-sm flex-col overflow-hidden rounded-lg transition blur-md"
      >
        <div class="flex-shrink-0 bg-gray-800 text-slate-50 p-5">
          Last check time
        </div>
        <div
          id="dateTime"
          class="flex flex-1 flex-col justify-between p-6 bg-gray-500 text-xl font-bold"
        ></div>
      </div>
      <div
        class="flex max-w-sm flex-col overflow-hidden rounded-lg transition blur-md"
      >
        <div class="flex-shrink-0 bg-gray-800 text-slate-50 p-5">
          Clients watching
        </div>
        <div
          id="clients_count"
          class="flex flex-1 flex-col justify-between p-6 bg-gray-500 text-xl font-bold"
        ></div>
      </div>

      <button
        id="open"
        class="absolute text-2xl bg-gray-800 text-slate-50 p-2 rounded shadow-lg shadow-slate-800 hover:shadow-md scale-105 hover:scale-100 transition"
      >
        Open connection
      </button>

      <script type="text/javascript">
        const is_ok = document.querySelector('#is_ok');
        const dateTime = document.querySelector('#dateTime');
        const clients_count = document.querySelector('#clients_count');
        const button = document.querySelector('button');

        function track() {
          const proto = location.protocol.startsWith('https') ? 'wss' : 'ws';
          const websocket = new WebSocket(
            `${proto}://${window.location.host}/websocket`,
          );

          websocket.onopen = () => {
            console.log('connection opened');
            document
              .querySelectorAll('body > div')
              .forEach((e) => e.classList.remove('blur-md'));
            document.querySelector('body > button').classList.add('hidden');
          };

          websocket.onclose = () => {
            console.log('connection closed');
            document
              .querySelectorAll('body > div')
              .forEach((e) => e.classList.add('blur-md'));
            document.querySelector('body > button').classList.remove('hidden');
          };

          websocket.onmessage = (e) => {
            const response = JSON.parse(e.data);

            if (response.is_up) {
              is_ok.textContent = 'up';
              is_ok.classList.add('text-green-600');
              is_ok.classList.remove('text-rose-700');
            } else {
              is_ok.textContent = 'down';
              is_ok.classList.add('text-rose-700');
              is_ok.classList.remove('text-green-600');
            }

            dateTime.textContent = new Date(response.dateTime).toLocaleString();
            clients_count.textContent = response.clients_count;
          };
        }

        track();
        button.addEventListener('click', track);
      </script>
    </body>
  </html>
  ```

  ```toml Cargo.toml theme={null}
  [package]
  name = "websocket"
  version = "0.1.0"
  edition = "2021"

  [dependencies]
  axum = { version = "0.8", features = ["ws"] }
  chrono = { version = "0.4", features = ["serde"] }
  futures = "0.3.28"
  reqwest = "0.12"
  serde = { version = "1", features = ["derive"] }
  serde_json = "1"
  shuttle-axum = "0.57.0"
  shuttle-runtime = "0.57.0"
  tokio = "1"
  tower-http = { version = "0.6", features = ["fs"] }
  ```

  ```toml Shuttle.toml theme={null}
  [build]
  assets = [
      "static",
  ]
  ```
</CodeGroup>
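
The `websocket` function increments `clients_count` when a client connects and decrements it on the last line. If the function ever exited early, for instance through a panic, that decrement would be skipped. One possible alternative, which the example does not use, is a drop guard that ties the decrement to scope exit. The sketch below uses only the standard library; `ClientGuard` is a hypothetical name, and it swaps the example's `Mutex<State>` for an `AtomicUsize` to stay self-contained.

```rust
use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc,
};

/// Hypothetical helper (not part of the example): increments a shared
/// counter when created and decrements it again when dropped.
struct ClientGuard {
    count: Arc<AtomicUsize>,
}

impl ClientGuard {
    fn new(count: Arc<AtomicUsize>) -> Self {
        count.fetch_add(1, Ordering::SeqCst);
        Self { count }
    }
}

impl Drop for ClientGuard {
    fn drop(&mut self) {
        // Runs on every exit path from the owning scope, including unwinding.
        self.count.fetch_sub(1, Ordering::SeqCst);
    }
}

fn main() {
    let clients = Arc::new(AtomicUsize::new(0));
    {
        let _first = ClientGuard::new(clients.clone());
        let _second = ClientGuard::new(clients.clone());
        println!("connected: {}", clients.load(Ordering::SeqCst));
    }
    // Both guards were dropped at the end of the block above.
    println!("after disconnect: {}", clients.load(Ordering::SeqCst));
}
```

In a handler, such a guard would be created right after the upgrade and held until the function returns, replacing the manual `+= 1` and `-= 1` pair.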

## Usage

Once you've cloned the example, launch it locally using `shuttle run` and then go to `http://localhost:8000`. You should see a status page. If you open your browser's developer tools and go to the Network tab, you'll see that the request to `/websocket` received an HTTP status code of 101 (Switching Protocols).
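
The page chooses between `ws` and `wss` based on how it was loaded, so a local run over plain HTTP connects to `ws://localhost:8000/websocket`. The same rule from `static/index.html` is sketched in Rust below, which may help if you write a non-browser client. The function name `websocket_url` and the host `example.com` are illustrative assumptions.

```rust
/// Map the page's protocol to the matching WebSocket scheme and append
/// the `/websocket` route, mirroring the check in `static/index.html`.
fn websocket_url(page_protocol: &str, host: &str) -> String {
    let scheme = if page_protocol.starts_with("https") {
        "wss"
    } else {
        "ws"
    };
    format!("{scheme}://{host}/websocket")
}

fn main() {
    // Local development: the page is served over plain HTTP.
    println!("{}", websocket_url("http:", "localhost:8000"));
    // A page served over HTTPS (placeholder host) needs the secure scheme.
    println!("{}", websocket_url("https:", "example.com"));
}
```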

***

<Snippet file="other-frameworks.mdx" />

<Snippet file="check-examples.mdx" />
