databeam/
sql.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
//! Shared SQL types
use serde::{Serialize, Deserialize};

/// Connection settings consumed by `create_db`.
///
/// Deserializable with serde; when `name` is missing it falls back to `"main"`.
/// `Debug` is implemented by hand so that formatting these options (e.g. in logs)
/// never reveals the password.
#[derive(Serialize, Deserialize, Clone)]
pub struct DatabaseOpts {
    /// The type of the database (`Default` sets `"sqlite"`).
    ///
    /// NOTE(review): the `create_db` functions in this file do not read this field;
    /// the backend is selected by cargo feature — confirm whether other code uses it.
    pub r#type: Option<String>,
    /// The host to connect to (`None` for SQLite). The MySQL/Postgres `create_db`
    /// variants fall back to `localhost` when this is `None`.
    pub host: Option<String>,
    /// The database user's name
    pub user: String,
    /// The database user's password (redacted in `Debug` output)
    pub pass: String,
    /// The name of the database; for SQLite this is the file stem (`<name>.db`).
    /// Defaults to `"main"` when omitted during deserialization.
    #[serde(default = "default_database_name")]
    pub name: String,
}

/// Mirrors the previously derived `Debug` output, except that `pass` is replaced by
/// `<redacted>` so credentials never end up in logs or panic messages.
impl std::fmt::Debug for DatabaseOpts {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DatabaseOpts")
            .field("type", &self.r#type)
            .field("host", &self.host)
            .field("user", &self.user)
            .field("pass", &format_args!("<redacted>"))
            .field("name", &self.name)
            .finish()
    }
}

/// Serde default for `DatabaseOpts::name`: the database is called `main`
/// unless the configuration says otherwise.
fn default_database_name() -> String {
    String::from("main")
}

impl Default for DatabaseOpts {
    /// A local SQLite configuration: no host, empty credentials and the
    /// default database name.
    fn default() -> Self {
        let name = default_database_name();
        let backend = Some(String::from("sqlite"));
        Self {
            name,
            user: String::default(),
            pass: String::default(),
            host: None,
            r#type: backend,
        }
    }
}

// ...
/// A database client handle paired with the name of its backend.
///
/// `client` is the underlying client (the `create_db` functions store an sqlx pool
/// here) and `r#type` is the backend name they record (`"mysql"`, `"postgres"` or
/// `"sqlite"`). `Debug` is available whenever `T: Debug`.
#[derive(Clone, Debug)]
pub struct Database<T> {
    /// The underlying client / connection pool.
    pub client: T,
    /// The backend name, e.g. `"sqlite"`.
    pub r#type: String,
}

// ...
#[cfg(feature = "mysql")]
/// Create a new "mysql" database
///
/// Builds a pool of at most 25 connections (2 s acquire timeout, 5 min idle timeout)
/// against `mysql://<user>:<pass>@<host>/<name>`, where `<host>` defaults to
/// `localhost` when `options.host` is `None`.
///
/// `user` and `pass` are percent-encoded before being placed in the URL. Credentials
/// containing reserved characters such as `@`, `:`, `/`, `#` or `%` therefore reach
/// the server unchanged instead of corrupting the connection string.
///
/// NOTE: `create_db` is defined once per backend feature (`mysql`, `postgres`,
/// `sqlite`), so only one of those features can be enabled at a time.
///
/// # Panics
///
/// Panics with `failed to connect to database: <error>` if the pool cannot connect.
pub async fn create_db(options: DatabaseOpts) -> Database<sqlx::MySqlPool> {
    /// Percent-encode every byte outside the RFC 3986 "unreserved" set so the value
    /// can sit safely in the userinfo section of a URL.
    fn encode_userinfo(raw: &str) -> String {
        const HEX: &[u8; 16] = b"0123456789ABCDEF";
        let mut encoded = String::with_capacity(raw.len());
        for byte in raw.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
                encoded.push(char::from(byte));
            } else {
                encoded.push('%');
                encoded.push(char::from(HEX[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
        encoded
    }

    let pool_options = sqlx::mysql::MySqlPoolOptions::new()
        .max_connections(25)
        .acquire_timeout(std::time::Duration::from_millis(2000))
        .idle_timeout(Some(std::time::Duration::from_secs(60 * 5)));

    // The host is left as-is so a `host:port` value keeps working.
    let host = options.host.as_deref().unwrap_or("localhost");
    let url = format!(
        "mysql://{}:{}@{}/{}",
        encode_userinfo(&options.user),
        encode_userinfo(&options.pass),
        host,
        options.name
    );

    let client = pool_options
        .connect(&url)
        .await
        .unwrap_or_else(|err| panic!("failed to connect to database: {}", err));

    Database {
        client,
        r#type: String::from("mysql"),
    }
}

#[cfg(feature = "postgres")]
/// Create a new "postgres" database
///
/// Builds a pool of at most 25 connections (2 s acquire timeout, 5 min idle timeout)
/// against `postgres://<user>:<pass>@<host>/<name>`, where `<host>` defaults to
/// `localhost` when `options.host` is `None`.
///
/// `user` and `pass` are percent-encoded before being placed in the URL. Credentials
/// containing reserved characters such as `@`, `:`, `/`, `#` or `%` therefore reach
/// the server unchanged instead of corrupting the connection string.
///
/// NOTE: `create_db` is defined once per backend feature (`mysql`, `postgres`,
/// `sqlite`), so only one of those features can be enabled at a time.
///
/// # Panics
///
/// Panics with `failed to connect to database: <error>` if the pool cannot connect.
pub async fn create_db(options: DatabaseOpts) -> Database<sqlx::PgPool> {
    /// Percent-encode every byte outside the RFC 3986 "unreserved" set so the value
    /// can sit safely in the userinfo section of a URL.
    fn encode_userinfo(raw: &str) -> String {
        const HEX: &[u8; 16] = b"0123456789ABCDEF";
        let mut encoded = String::with_capacity(raw.len());
        for byte in raw.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
                encoded.push(char::from(byte));
            } else {
                encoded.push('%');
                encoded.push(char::from(HEX[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
        encoded
    }

    let pool_options = sqlx::postgres::PgPoolOptions::new()
        .max_connections(25)
        .acquire_timeout(std::time::Duration::from_millis(2000))
        .idle_timeout(Some(std::time::Duration::from_secs(60 * 5)));

    // The host is left as-is so a `host:port` value keeps working.
    let host = options.host.as_deref().unwrap_or("localhost");
    let url = format!(
        "postgres://{}:{}@{}/{}",
        encode_userinfo(&options.user),
        encode_userinfo(&options.pass),
        host,
        options.name
    );

    let client = pool_options
        .connect(&url)
        .await
        .unwrap_or_else(|err| panic!("failed to connect to database: {}", err));

    Database {
        client,
        r#type: String::from("postgres"),
    }
}

#[cfg(feature = "sqlite")]
/// Create a new "sqlite" database backed by the file `<options.name>.db`
/// (`main.db` with the default options).
///
/// Only `options.name` is used; `host`, `user` and `pass` are ignored for SQLite.
///
/// NOTE: `create_db` is defined once per backend feature (`mysql`, `postgres`,
/// `sqlite`), so only one of those features can be enabled at a time.
///
/// # Panics
///
/// Panics with `failed to connect to database: <error>` if the pool cannot be
/// opened. The message matches the MySQL and Postgres variants and keeps the
/// underlying error.
pub async fn create_db(options: DatabaseOpts) -> Database<sqlx::SqlitePool> {
    let url = format!("sqlite://{}.db", options.name);
    let client = sqlx::SqlitePool::connect(&url)
        .await
        .unwrap_or_else(|err| panic!("failed to connect to database: {}", err));

    Database {
        client,
        r#type: String::from("sqlite"),
    }
}