1use std::{net::IpAddr, sync::Arc, time::Duration};
8
9use governor::{RateLimiter, clock::QuantaClock, state::keyed::DashMapStateStore};
10use mas_config::RateLimitingConfig;
11use mas_data_model::{User, UserEmailAuthentication};
12use ulid::Ulid;
13
14#[derive(Debug, Clone, thiserror::Error)]
15pub enum AccountRecoveryLimitedError {
16    #[error("Too many account recovery requests for requester {0}")]
17    Requester(RequesterFingerprint),
18
19    #[error("Too many account recovery requests for e-mail {0}")]
20    Email(String),
21}
22
23#[derive(Debug, Clone, Copy, thiserror::Error)]
24pub enum PasswordCheckLimitedError {
25    #[error("Too many password checks for requester {0}")]
26    Requester(RequesterFingerprint),
27
28    #[error("Too many password checks for user {0}")]
29    User(Ulid),
30}
31
32#[derive(Debug, Clone, thiserror::Error)]
33pub enum RegistrationLimitedError {
34    #[error("Too many account registration requests for requester {0}")]
35    Requester(RequesterFingerprint),
36}
37
38#[derive(Debug, Clone, thiserror::Error)]
39pub enum EmailAuthenticationLimitedError {
40    #[error("Too many email authentication requests for requester {0}")]
41    Requester(RequesterFingerprint),
42
43    #[error("Too many email authentication requests for authentication session {0}")]
44    Authentication(Ulid),
45
46    #[error("Too many email authentication requests for email {0}")]
47    Email(String),
48}
49
/// Key identifying who is making a request, for rate-limiting purposes.
///
/// It currently only wraps the client IP address, which may be absent. All
/// fingerprints without an IP compare (and hash) equal to each other.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RequesterFingerprint {
    /// IP address of the requester, if known.
    ip: Option<IpAddr>,
}

impl std::fmt::Display for RequesterFingerprint {
    /// Writes the IP address, or the `(NO CLIENT IP)` placeholder when the
    /// fingerprint carries none.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.ip {
            Some(address) => write!(f, "{address}"),
            None => f.write_str("(NO CLIENT IP)"),
        }
    }
}

impl RequesterFingerprint {
    /// A fingerprint carrying no client IP address.
    pub const EMPTY: Self = Self { ip: None };

    /// Creates a fingerprint for a requester with the given IP address.
    #[must_use]
    pub const fn new(ip: IpAddr) -> Self {
        Self { ip: Some(ip) }
    }
}
77
78#[derive(Debug, Clone)]
80pub struct Limiter {
81    inner: Arc<LimiterInner>,
82}
83
84type KeyedRateLimiter<K> = RateLimiter<K, DashMapStateStore<K>, QuantaClock>;
85
86#[derive(Debug)]
87struct LimiterInner {
88    account_recovery_per_requester: KeyedRateLimiter<RequesterFingerprint>,
89    account_recovery_per_email: KeyedRateLimiter<String>,
90    password_check_for_requester: KeyedRateLimiter<RequesterFingerprint>,
91    password_check_for_user: KeyedRateLimiter<Ulid>,
92    registration_per_requester: KeyedRateLimiter<RequesterFingerprint>,
93    email_authentication_per_requester: KeyedRateLimiter<RequesterFingerprint>,
94    email_authentication_per_email: KeyedRateLimiter<String>,
95    email_authentication_emails_per_session: KeyedRateLimiter<Ulid>,
96    email_authentication_attempt_per_session: KeyedRateLimiter<Ulid>,
97}
98
99impl LimiterInner {
100    fn new(config: &RateLimitingConfig) -> Option<Self> {
101        Some(Self {
102            account_recovery_per_requester: RateLimiter::keyed(
103                config.account_recovery.per_ip.to_quota()?,
104            ),
105            account_recovery_per_email: RateLimiter::keyed(
106                config.account_recovery.per_address.to_quota()?,
107            ),
108            password_check_for_requester: RateLimiter::keyed(config.login.per_ip.to_quota()?),
109            password_check_for_user: RateLimiter::keyed(config.login.per_account.to_quota()?),
110            registration_per_requester: RateLimiter::keyed(config.registration.to_quota()?),
111            email_authentication_per_email: RateLimiter::keyed(
112                config.email_authentication.per_address.to_quota()?,
113            ),
114            email_authentication_per_requester: RateLimiter::keyed(
115                config.email_authentication.per_ip.to_quota()?,
116            ),
117            email_authentication_emails_per_session: RateLimiter::keyed(
118                config.email_authentication.emails_per_session.to_quota()?,
119            ),
120            email_authentication_attempt_per_session: RateLimiter::keyed(
121                config.email_authentication.attempt_per_session.to_quota()?,
122            ),
123        })
124    }
125}
126
127impl Limiter {
128    #[must_use]
133    pub fn new(config: &RateLimitingConfig) -> Option<Self> {
134        Some(Self {
135            inner: Arc::new(LimiterInner::new(config)?),
136        })
137    }
138
139    pub fn start(&self) {
144        let this = self.clone();
146        tokio::spawn(async move {
147            let mut interval = tokio::time::interval(Duration::from_secs(60));
149            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
150
151            loop {
152                this.inner.account_recovery_per_email.retain_recent();
154                this.inner.account_recovery_per_requester.retain_recent();
155                this.inner.password_check_for_requester.retain_recent();
156                this.inner.password_check_for_user.retain_recent();
157                this.inner.registration_per_requester.retain_recent();
158                this.inner.email_authentication_per_email.retain_recent();
159                this.inner
160                    .email_authentication_per_requester
161                    .retain_recent();
162                this.inner
163                    .email_authentication_emails_per_session
164                    .retain_recent();
165                this.inner
166                    .email_authentication_attempt_per_session
167                    .retain_recent();
168
169                interval.tick().await;
170            }
171        });
172    }
173
174    pub fn check_account_recovery(
180        &self,
181        requester: RequesterFingerprint,
182        email_address: &str,
183    ) -> Result<(), AccountRecoveryLimitedError> {
184        self.inner
185            .account_recovery_per_requester
186            .check_key(&requester)
187            .map_err(|_| AccountRecoveryLimitedError::Requester(requester))?;
188
189        let canonical_email = email_address.to_lowercase();
193        self.inner
194            .account_recovery_per_email
195            .check_key(&canonical_email)
196            .map_err(|_| AccountRecoveryLimitedError::Email(canonical_email))?;
197
198        Ok(())
199    }
200
201    pub fn check_password(
207        &self,
208        key: RequesterFingerprint,
209        user: &User,
210    ) -> Result<(), PasswordCheckLimitedError> {
211        self.inner
212            .password_check_for_requester
213            .check_key(&key)
214            .map_err(|_| PasswordCheckLimitedError::Requester(key))?;
215
216        self.inner
217            .password_check_for_user
218            .check_key(&user.id)
219            .map_err(|_| PasswordCheckLimitedError::User(user.id))?;
220
221        Ok(())
222    }
223
224    pub fn check_registration(
230        &self,
231        requester: RequesterFingerprint,
232    ) -> Result<(), RegistrationLimitedError> {
233        self.inner
234            .registration_per_requester
235            .check_key(&requester)
236            .map_err(|_| RegistrationLimitedError::Requester(requester))?;
237
238        Ok(())
239    }
240
241    pub fn check_email_authentication_email(
248        &self,
249        requester: RequesterFingerprint,
250        email: &str,
251    ) -> Result<(), EmailAuthenticationLimitedError> {
252        self.inner
253            .email_authentication_per_requester
254            .check_key(&requester)
255            .map_err(|_| EmailAuthenticationLimitedError::Requester(requester))?;
256
257        let canonical_email = email.to_lowercase();
261        self.inner
262            .email_authentication_per_email
263            .check_key(&canonical_email)
264            .map_err(|_| EmailAuthenticationLimitedError::Email(email.to_owned()))?;
265        Ok(())
266    }
267
268    pub fn check_email_authentication_attempt(
274        &self,
275        authentication: &UserEmailAuthentication,
276    ) -> Result<(), EmailAuthenticationLimitedError> {
277        self.inner
278            .email_authentication_attempt_per_session
279            .check_key(&authentication.id)
280            .map_err(|_| EmailAuthenticationLimitedError::Authentication(authentication.id))
281    }
282
283    pub fn check_email_authentication_send_code(
290        &self,
291        requester: RequesterFingerprint,
292        authentication: &UserEmailAuthentication,
293    ) -> Result<(), EmailAuthenticationLimitedError> {
294        self.check_email_authentication_email(requester, &authentication.email)?;
295        self.inner
296            .email_authentication_emails_per_session
297            .check_key(&authentication.id)
298            .map_err(|_| EmailAuthenticationLimitedError::Authentication(authentication.id))
299    }
300}
301
#[cfg(test)]
mod tests {
    use mas_data_model::User;
    use mas_storage::{Clock, clock::MockClock};
    use rand::SeedableRng;

    use super::*;

    /// Exercises the interaction between the per-requester and per-user
    /// password check limits under the default configuration.
    #[test]
    fn test_password_check_limiter() {
        let now = MockClock::default().now();
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);

        let limiter = Limiter::new(&RateLimitingConfig::default()).unwrap();

        // 256 * 3 distinct requesters, each with its own IPv4 address.
        let requesters: Vec<RequesterFingerprint> = (0..=255)
            .flat_map(|a| (0..3).map(move |b| RequesterFingerprint::new([a, a, b, b].into())))
            .collect();
        assert_eq!(requesters.len(), 768);

        // Builds a user whose ID is drawn from the seeded RNG; call order
        // therefore determines the IDs.
        let mut new_user = |username: &str| User {
            id: Ulid::from_datetime_with_source(now.into(), &mut rng),
            username: username.to_owned(),
            sub: "123-456".to_owned(),
            created_at: now,
            locked_at: None,
            deactivated_at: None,
            can_request_admin: false,
        };
        let alice = new_user("alice");
        let bob = new_user("bob");

        // The first three checks from one requester are accepted...
        for _ in 0..3 {
            assert!(limiter.check_password(requesters[0], &alice).is_ok());
        }
        // ...the fourth is rejected...
        assert!(limiter.check_password(requesters[0], &alice).is_err());
        // ...and so is a check by that same requester against another user.
        assert!(limiter.check_password(requesters[0], &bob).is_err());

        // A different requester can still check alice's password.
        assert!(limiter.check_password(requesters[1], &alice).is_ok());

        // Requesters 2 through 599 each use up their own budget against alice.
        for requester in &requesters[2..600] {
            for _ in 0..3 {
                assert!(limiter.check_password(*requester, &alice).is_ok());
            }
            assert!(limiter.check_password(*requester, &alice).is_err());
        }

        // Two more fresh requesters get through, then alice's own per-user
        // limit kicks in even for a requester that has never been seen.
        assert!(limiter.check_password(requesters[600], &alice).is_ok());
        assert!(limiter.check_password(requesters[601], &alice).is_ok());
        assert!(limiter.check_password(requesters[602], &alice).is_err());

        // Bob is unaffected by alice being limited.
        assert!(limiter.check_password(requesters[603], &bob).is_ok());
    }
}