rainbeam/
main.rs

1#![doc = include_str!("../../../README.md")]
2#![doc(issue_tracker_base_url = "https://github.com/swmff/rainbeam/issues")]
3#![doc(html_favicon_url = "https://rainbeam.net/static/favicon.svg")]
4#![doc(html_logo_url = "https://rainbeam.net/static/favicon.svg")]
5use axum::routing::{get, get_service};
6use axum::Router;
7
8use tower_http::trace::{self, TraceLayer};
9use tracing::{info, Level};
10
11use authbeam::{api as AuthApi, Database as AuthDatabase};
12use databeam::config::Config as DataConf;
13use rainbeam_shared::fs;
14use pathbufd::{PathBufD, pathd};
15
16pub use rb::database;
17pub use rb::config;
18pub use rb::model;
19pub use rb::routing;
20
21use std::env::var;
22
// Optional allocator override: building with the `mimalloc` feature swaps the
// system allocator for mimalloc across the entire process.
#[cfg(feature = "mimalloc")]
use mimalloc::MiMalloc;

/// Process-wide global allocator; only compiled in when the `mimalloc` feature is enabled.
#[cfg(feature = "mimalloc")]
#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;
30
31/// Main server process
32#[tokio::main(flavor = "multi_thread")]
33pub async fn main() {
34    let mut config = config::Config::get_config();
35
36    let here = PathBufD::current();
37    let static_dir = here.join(".config").join("static");
38    let well_known_dir = here.join(".config").join(".well-known");
39    config.static_dir = static_dir.clone();
40
41    tracing_subscriber::fmt()
42        .with_target(false)
43        .compact()
44        .init();
45
46    // make sure media dir is created
47    // TODO: implement `.is_empty()` on `PathBufD`
48    if !config.media_dir.to_string().is_empty() {
49        fs::mkdir(&config.media_dir).expect("failed to create media dir");
50        fs::mkdir(pathd!("{}/avatars", config.media_dir)).expect("failed to create avatars dir");
51        fs::mkdir(pathd!("{}/banners", config.media_dir)).expect("failed to create banners dir");
52        fs::mkdir(pathd!("{}/carpgraph", config.media_dir))
53            .expect("failed to create carpgraph dir");
54    }
55
56    // load plugins
57    let plugins = rainbeam_plugins::config::get_plugins();
58
59    for plugin in plugins {
60        match rainbeam_plugins::run(plugin) {
61            Ok(_) => todo!(),
62            Err(e) => panic!("plugin error: {e}"),
63        };
64    }
65
66    // create databases
67    let auth_database = AuthDatabase::new(
68        DataConf::get_config().connection, // pull connection config from config file
69        authbeam::ServerOptions {
70            captcha: config.captcha.clone(),
71            registration_enabled: config.registration_enabled,
72            real_ip_header: config.real_ip_header.clone(),
73            static_dir: config.static_dir.clone(),
74            media_dir: config.media_dir.clone(),
75            host: config.host.clone(),
76            snowflake_server_id: config.snowflake_server_id.clone(),
77            blocked_hosts: config.blocked_hosts.clone(),
78        },
79    )
80    .await;
81    auth_database.init().await;
82
83    let database = database::Database::new(
84        DataConf::get_config().connection,
85        auth_database.clone(),
86        config.clone(),
87    )
88    .await;
89    database.init().await;
90
91    // create app
92    let app = Router::new()
93        // api
94        .nest_service("/api/v0/auth", AuthApi::routes(auth_database.clone()))
95        .nest("/api/v0/util", routing::api::util::routes(database.clone()))
96        .nest("/api/v1", routing::api::routes(database.clone()))
97        // pages
98        .merge(routing::pages::routes(database.clone()).await)
99        // ...
100        .nest_service(
101            "/.well-known",
102            get_service(tower_http::services::ServeDir::new(&well_known_dir)),
103        )
104        .nest_service(
105            "/static",
106            get_service(tower_http::services::ServeDir::new(&static_dir)),
107        )
108        .nest_service(
109            "/manifest.json",
110            get_service(tower_http::services::ServeFile::new(format!(
111                "{static_dir}/manifest.json"
112            ))),
113        )
114        .fallback_service(get(routing::pages::not_found).with_state(database.clone()))
115        .layer(axum::extract::DefaultBodyLimit::max(
116            var("MAX_BODY_LIMIT")
117                .unwrap_or("8388608".to_string())
118                .parse::<usize>()
119                .unwrap(),
120        ))
121        .layer(
122            TraceLayer::new_for_http()
123                .make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO))
124                .on_response(trace::DefaultOnResponse::new().level(Level::INFO)),
125        );
126
127    let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", config.port))
128        .await
129        .unwrap();
130
131    info!("🌈 Starting server at: http://localhost:{}!", config.port);
132    axum::serve(listener, app).await.unwrap();
133}