/* This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ mod builder; mod bytesize; mod cache_control; mod clock; mod connection_initializer; mod outcome; mod request_hash; mod store; mod strategy; use self::{ builder::HttpCacheBuilder, store::HttpCacheStore, strategy::{CacheFirst, NetworkFirst}, }; use std::hash::Hash; use viaduct::{Client, Request, Response}; pub use self::builder::HttpCacheBuilderError; pub use self::bytesize::ByteSize; pub use self::outcome::CacheOutcome; pub use self::request_hash::RequestHash; use std::path::Path; use std::time::Duration; pub type HttpCacheSendResult = std::result::Result<(Response, Vec), viaduct::ViaductError>; #[derive(Clone, Copy, Debug)] pub enum CachePolicy { CacheFirst { ttl: Option }, NetworkFirst { ttl: Option }, } impl Default for CachePolicy { fn default() -> Self { Self::CacheFirst { ttl: None } } } pub struct HttpCache { default_ttl: Duration, max_size: ByteSize, store: HttpCacheStore, } impl HttpCache { pub fn builder>(db_path: P) -> HttpCacheBuilder { HttpCacheBuilder::new(db_path.as_ref()) } pub fn clear(&self) -> Result<(), rusqlite::Error> { self.store.clear_all()?; Ok(()) } pub fn invalidate_by_hash(&self, request_hash: &RequestHash) -> Result<(), rusqlite::Error> { self.store.invalidate_by_hash(request_hash)?; Ok(()) } pub fn send_with_policy>( &self, client: &Client, item: T, policy: &CachePolicy, ) -> HttpCacheSendResult { let hash = RequestHash::new(&item); let request = item.into(); let mut outcomes = vec![]; // Clean up expired entries before applying the policy if let Err(e) = self.store.delete_expired_entries() { outcomes.push(CacheOutcome::CleanupFailed(e)); } // Apply the cache policy and collect outcomes let (response, mut strategy_outcomes) = match policy { CachePolicy::CacheFirst { ttl } => CacheFirst { hash, request, ttl: ttl.unwrap_or(self.default_ttl), } .apply(client, &self.store), CachePolicy::NetworkFirst { ttl } => NetworkFirst { hash, request, ttl: ttl.unwrap_or(self.default_ttl), } .apply(client, &self.store), }?; outcomes.append(&mut strategy_outcomes); // Trim the cache to the max size only when something was actually stored if outcomes .iter() .any(|o| matches!(o, CacheOutcome::MissStored)) { if let Err(e) = self.store.trim_to_max_size(&self.max_size) { outcomes.push(CacheOutcome::TrimFailed(e)); } } Ok((response, outcomes)) } } #[cfg(test)] mod tests { use mockito::mock; use std::hash::{Hash, Hasher}; use super::*; use viaduct::ClientSettings; fn make_client() -> Client { Client::new(ClientSettings::default()) } /// Test-only hashable wrapper around Request. /// Hashes method + url for cache key purposes. #[derive(Clone)] struct TestRequest(Request); impl Hash for TestRequest { fn hash(&self, state: &mut H) { self.0.method.as_str().hash(state); self.0.url.as_str().hash(state); } } impl From for Request { fn from(t: TestRequest) -> Self { t.0 } } fn make_post_request() -> TestRequest { let url = format!("{}/ads", mockito::server_url()).parse().unwrap(); TestRequest(Request::post(url).json(&serde_json::json!({"fake":"data"}))) } fn make_cache() -> HttpCache { // Our store opens an in-memory cache for tests. So the name is irrelevant. HttpCache::builder("ignored_in_tests.db") .default_ttl(Duration::from_secs(60)) .max_size(ByteSize::mib(1)) .build() .expect("cache build should succeed") } fn make_cache_with_ttl(secs: u64) -> HttpCache { // In tests our store uses an in-memory DB; filename is irrelevant. HttpCache::builder("ignored_in_tests.db") .default_ttl(Duration::from_secs(secs)) .max_size(ByteSize::mib(4)) .build_for_time_dependent_tests() .expect("cache build should succeed") } #[test] fn test_http_cache_creation() { // Test that HttpCache can be created successfully with test config let cache: Result = HttpCache::builder("test_cache.db").build(); assert!(cache.is_ok()); // Test with custom config let cache_with_config: Result = HttpCache::builder("custom_test.db") .max_size(ByteSize::mib(1)) .default_ttl(Duration::from_secs(60)) .build(); assert!(cache_with_config.is_ok()); } #[test] fn test_clear_cache() { let cache: HttpCache = HttpCache::builder("test_clear.db").build().unwrap(); // Create a test request and response let hash = RequestHash::new(&("Get", "https://example.com/test")); let response = viaduct::Response { request_method: viaduct::Method::Get, url: "https://example.com/test".parse().unwrap(), status: 200, headers: viaduct::Headers::new(), body: b"test response".to_vec(), }; cache .store .store_with_ttl(&hash, &response, &Duration::from_secs(300)) .unwrap(); // Verify it's cached let retrieved = cache.store.lookup(&hash).unwrap(); assert!(retrieved.is_some()); // Clear the cache cache.clear().unwrap(); // Verify it's cleared let retrieved_after_clear = cache.store.lookup(&hash).unwrap(); assert!(retrieved_after_clear.is_none()); } #[test] fn test_default_policy_miss_then_store_then_hit() { viaduct_dev::init_backend_dev(); let body = r#"{"ok":true}"#; let _m = mock("POST", "/ads") .with_status(200) .with_header("content-type", "application/json") .with_body(body) .expect(1) // only the first call should hit the network .create(); let cache = make_cache(); let req = make_post_request(); let client = make_client(); // First call: miss -> store let (_, outcomes) = cache .send_with_policy(&client, req.clone(), &CachePolicy::default()) .unwrap(); assert!(matches!(outcomes.last().unwrap(), CacheOutcome::MissStored)); // Second call: hit (no extra HTTP request due to expect(1)) let (response, outcomes) = cache .send_with_policy(&client, req, &CachePolicy::default()) .unwrap(); assert!(matches!(outcomes.last().unwrap(), CacheOutcome::Hit)); assert_eq!(response.status, 200); } #[test] fn test_refresh_policy_always_uses_network_then_caches() { viaduct_dev::init_backend_dev(); let body1 = r#"{"ok":true,"n":1}"#; let body2 = r#"{"ok":true,"n":2}"#; // Two live responses expected on refresh let _m1 = mock("POST", "/ads") .with_status(200) .with_header("content-type", "application/json") .with_body(body1) .create(); let _m2 = mock("POST", "/ads") .with_status(200) .with_header("content-type", "application/json") .with_body(body2) .create(); let cache = make_cache(); let req = make_post_request(); let client = make_client(); // First refresh: live -> MissStored let (_, outcomes) = cache .send_with_policy( &client, req.clone(), &CachePolicy::NetworkFirst { ttl: None }, ) .unwrap(); assert!(matches!(outcomes.last().unwrap(), CacheOutcome::MissStored)); // Second refresh: live again (different body), still MissStored let (response, outcomes) = cache .send_with_policy(&client, req, &CachePolicy::NetworkFirst { ttl: None }) .unwrap(); assert!(matches!(outcomes.last().unwrap(), CacheOutcome::MissStored)); assert_eq!(response.status, 200); } #[test] fn test_not_cacheable_no_store() { viaduct_dev::init_backend_dev(); let _m = mock("POST", "/ads") .with_status(200) .with_header("content-type", "application/json") .with_header("cache-control", "no-store") // should block caching .with_body(r#"{"ok":true}"#) .expect(1) .create(); let cache = make_cache(); let req = make_post_request(); let client = make_client(); let (_, outcomes) = cache .send_with_policy(&client, req.clone(), &CachePolicy::default()) .unwrap(); assert!(matches!( outcomes.last().unwrap(), CacheOutcome::MissNotCacheable )); // Next call should hit network again (since we didn't cache) let _m2 = mock("POST", "/ads") .with_status(200) .with_header("content-type", "application/json") .with_body(r#"{"ok":true}"#) .expect(1) .create(); let (_, outcomes) = cache .send_with_policy(&client, req, &CachePolicy::default()) .unwrap(); // Either MissStored (if headers differ) or MissNotCacheable if still no-store assert!(matches!( outcomes.last().unwrap(), CacheOutcome::MissStored | CacheOutcome::MissNotCacheable )); } #[test] fn ttl_resolution_min_of_server_request_default() { viaduct_dev::init_backend_dev(); let _m = mockito::mock("POST", "/ads") .with_status(200) .with_header("content-type", "application/json") .with_header("cache-control", "max-age=1") // Set max age to 1 second .with_body(r#"{"ok":true}"#) .expect(1) .create(); let cache = make_cache_with_ttl(300); let req = make_post_request(); let hash = RequestHash::new(&req); let policy = CachePolicy::CacheFirst { ttl: Some(Duration::from_secs(20)), // 20 second ttl specified vs the cache's default of 300s }; let client = make_client(); // Store ttl should resolve to 1s as specified by response headers let (_, outcomes) = cache.send_with_policy(&client, req, &policy).unwrap(); assert!(matches!(outcomes.last().unwrap(), CacheOutcome::MissStored)); // After ~>1s, cleanup should remove it cache.store.get_clock().advance(2); cache.store.delete_expired_entries().unwrap(); assert!(cache.store.lookup(&hash).unwrap().is_none()); } #[test] fn ttl_resolution_request_overrides_default_when_smaller() { viaduct_dev::init_backend_dev(); let _m = mockito::mock("POST", "/ads") .with_status(200) .with_header("content-type", "application/json") .with_body(r#"{"ok":true}"#) .expect(1) .create(); let cache = make_cache_with_ttl(60); let req = make_post_request(); let hash = RequestHash::new(&req); let policy = CachePolicy::CacheFirst { ttl: Some(Duration::from_secs(2)), }; let client = make_client(); // Store with effective TTL = 2s let (_, outcomes) = cache.send_with_policy(&client, req, &policy).unwrap(); assert!(matches!(outcomes.last().unwrap(), CacheOutcome::MissStored)); // Not expired yet at ~1s cache.store.get_clock().advance(1); cache.store.delete_expired_entries().unwrap(); assert!(cache.store.lookup(&hash).unwrap().is_some()); // Expired after ~2s cache.store.get_clock().advance(2); cache.store.delete_expired_entries().unwrap(); assert!(cache.store.lookup(&hash).unwrap().is_none()); } #[test] fn ttl_resolution_uses_default_when_no_server_and_no_request_override() { viaduct_dev::init_backend_dev(); let _m = mockito::mock("POST", "/ads") .with_status(200) // No response policy ttl .with_header("content-type", "application/json") .with_body(r#"{"ok":true}"#) .expect(1) .create(); let cache = make_cache_with_ttl(2); let req = make_post_request(); let hash = RequestHash::new(&req); let client = make_client(); // Store with effective TTL = 2s from client default let (_, outcomes) = cache .send_with_policy(&client, req, &CachePolicy::default()) .unwrap(); assert!(matches!(outcomes.last().unwrap(), CacheOutcome::MissStored)); // Not expired at ~1s cache.store.get_clock().advance(1); cache.store.delete_expired_entries().unwrap(); assert!(cache.store.lookup(&hash).unwrap().is_some()); // Expired after ~3s cache.store.get_clock().advance(3); cache.store.delete_expired_entries().unwrap(); assert!(cache.store.lookup(&hash).unwrap().is_none()); } #[test] fn test_expired_entry_is_a_miss_on_next_send() { viaduct_dev::init_backend_dev(); let _m1 = mockito::mock("POST", "/ads") .with_status(200) .with_header("content-type", "application/json") .with_body(r#"{"ok":true,"n":1}"#) .create(); let _m2 = mockito::mock("POST", "/ads") .with_status(200) .with_header("content-type", "application/json") .with_body(r#"{"ok":true,"n":2}"#) .create(); let cache = make_cache_with_ttl(2); let req = make_post_request(); let client = make_client(); // First call: miss -> store with 2s TTL let (_, outcomes) = cache .send_with_policy(&client, req.clone(), &CachePolicy::default()) .unwrap(); assert!(matches!(outcomes.last().unwrap(), CacheOutcome::MissStored)); // Advance clock past the TTL cache.store.get_clock().advance(3); // Second call: expired entry must be a miss, not a hit let (_, outcomes) = cache .send_with_policy(&client, req, &CachePolicy::default()) .unwrap(); assert!(matches!(outcomes.last().unwrap(), CacheOutcome::MissStored)); } #[test] fn test_invalidate_by_hash() { let cache: HttpCache = HttpCache::builder("test_invalidate.db").build().unwrap(); let hash1 = RequestHash::new(&("Post", "https://example.com/api1")); let hash2 = RequestHash::new(&("Post", "https://example.com/api2")); let response = viaduct::Response { request_method: viaduct::Method::Post, url: "https://example.com/test".parse().unwrap(), status: 200, headers: viaduct::Headers::new(), body: b"test response".to_vec(), }; cache .store .store_with_ttl(&hash1, &response, &Duration::from_secs(300)) .unwrap(); cache .store .store_with_ttl(&hash2, &response, &Duration::from_secs(300)) .unwrap(); assert!(cache.store.lookup(&hash1).unwrap().is_some()); assert!(cache.store.lookup(&hash2).unwrap().is_some()); cache.invalidate_by_hash(&hash1).unwrap(); assert!(cache.store.lookup(&hash1).unwrap().is_none()); assert!(cache.store.lookup(&hash2).unwrap().is_some()); } }