
Build a Distributed Cache

Difficulty: Advanced Est. Time: ~5 hours

Introduction

Distributed caching is essential for scaling modern applications. A distributed cache stores data across multiple servers, providing high availability, fault tolerance, and horizontal scalability.

In this tutorial, we'll build "DistriCache" - a distributed caching system with consistent hashing, data replication, and cluster management.

What You'll Build
  • A distributed cache cluster
  • Consistent hashing for data distribution
  • Multi-node data replication
  • Automatic failover and recovery
  • Client library for cache operations
What You'll Learn
  • Distributed systems fundamentals
  • Consistent hashing algorithms
  • Data replication strategies
  • CAP theorem tradeoffs
  • Network programming in Python

Core Concepts

Let's understand the fundamentals of distributed caching.

Why Distributed Cache?

  • Scalability - Handle more requests by adding nodes
  • High Availability - Survive node failures
  • Low Latency - Data closer to users
  • Fault Tolerance - Replicated data survives failures

The CAP Theorem

The CAP theorem states that when a network partition occurs, a distributed system must choose between consistency and availability; it cannot guarantee all three of Consistency, Availability, and Partition tolerance at once. Most distributed caches choose availability and partition tolerance (AP), providing eventual rather than strong consistency.

Consistent Hashing

Consistent hashing minimizes data movement when nodes are added or removed: only the keys owned by the affected node are remapped, rather than nearly all keys as with naive modulo placement. This keeps the cluster stable as it scales.
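
To see the problem consistent hashing solves, here is a standalone sketch (not part of DistriCache) showing that naive `hash(key) % N` placement remaps most keys when the node count changes:

```python
# Standalone sketch: with naive modulo placement, changing the node
# count remaps most keys, invalidating most of the cache.
import hashlib

def node_for(key: str, num_nodes: int) -> int:
    # Stable hash; Python's built-in hash() is randomized per process
    h = int(hashlib.md5(key.encode()).hexdigest(), 16)
    return h % num_nodes

keys = [f"user:{i}" for i in range(1000)]
before = {k: node_for(k, 3) for k in keys}  # 3-node cluster
after = {k: node_for(k, 4) for k in keys}   # a 4th node joins

moved = sum(1 for k in keys if before[k] != after[k])
print(f"{moved} of 1000 keys changed owner")  # roughly 3 out of 4 keys
```

With consistent hashing, the same change would move only about a quarter of the keys (the new node's share).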

Project Setup

Bash
# Create project directory
mkdir districache
cd districache

# Create virtual environment
python -m venv venv
source venv/bin/activate

# Install dependencies (hashlib and asyncio ship with the standard library)
pip install aiohttp

Project Structure

File Structure
districache/
├── districache/
│   ├── __init__.py
│   ├── node.py
│   ├── hashing.py
│   ├── replication.py
│   ├── cluster.py
│   └── client.py
├── server.py
├── requirements.txt
└── README.md

Cache Node Implementation

Let's create the cache node that stores data in memory.

Python
# districache/node.py
import asyncio
import time
from typing import Any, Optional, Dict
import json
import threading

class CacheNode:
    def __init__(self, node_id: str, host: str = 'localhost', port: int = 7000):
        self.node_id = node_id
        self.host = host
        self.port = port
        self.store: Dict[str, tuple] = {}
        self.lock = threading.RLock()
        
        self.stats = {
            'hits': 0,
            'misses': 0,
            'sets': 0,
            'deletes': 0
        }
    
    def get(self, key: str) -> Optional[Any]:
        with self.lock:
            if key in self.store:
                value, expiry = self.store[key]
                
                if expiry is None or time.time() < expiry:
                    self.stats['hits'] += 1
                    return value
                else:
                    del self.store[key]
            
            self.stats['misses'] += 1
            return None
    
    def set(self, key: str, value: Any, ttl: Optional[int] = None):
        with self.lock:
            expiry = None
            if ttl:
                expiry = time.time() + ttl
            
            self.store[key] = (value, expiry)
            self.stats['sets'] += 1
    
    def delete(self, key: str) -> bool:
        with self.lock:
            if key in self.store:
                del self.store[key]
                self.stats['deletes'] += 1
                return True
            return False
    
    def exists(self, key: str) -> bool:
        with self.lock:
            if key in self.store:
                value, expiry = self.store[key]
                if expiry is None or time.time() < expiry:
                    return True
                del self.store[key]
            return False
    
    def get_stats(self) -> Dict:
        return self.stats.copy()
    
    def cleanup_expired(self):
        with self.lock:
            now = time.time()
            expired_keys = [
                k for k, (v, exp) in self.store.items()
                if exp is not None and now >= exp
            ]
            for key in expired_keys:
                del self.store[key]
    
    def to_dict(self) -> Dict:
        return {
            'node_id': self.node_id,
            'host': self.host,
            'port': self.port,
            'keys_count': len(self.store),
            'stats': self.get_stats()
        }
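
Before wiring nodes into a cluster, the TTL behavior is worth a sanity check. This condensed, standalone sketch mirrors the lazy-expiry logic in `CacheNode.get` and `CacheNode.set` (the key names are illustrative):

```python
# Condensed, standalone sketch of CacheNode's lazy TTL expiry:
# values are stored as (value, expiry) tuples and evicted on read.
import time

store = {}

def cache_set(key, value, ttl=None):
    expiry = time.time() + ttl if ttl else None
    store[key] = (value, expiry)

def cache_get(key):
    if key in store:
        value, expiry = store[key]
        if expiry is None or time.time() < expiry:
            return value
        del store[key]  # expired: evict lazily and report a miss
    return None

cache_set("user:1", {"name": "John"})     # no TTL, lives until deleted
cache_set("session:abc", "tok", ttl=0.2)  # expires after 200 ms

print(cache_get("user:1"))       # {'name': 'John'}
time.sleep(0.3)
print(cache_get("session:abc"))  # None: TTL elapsed, entry evicted
```

Lazy eviction on read is cheap, but entries that are never read again linger; that is why the node also has a `cleanup_expired` sweep.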

Consistent Hashing

Let's implement consistent hashing to distribute data across nodes.

Python
# districache/hashing.py
import hashlib
from bisect import bisect_left
from typing import List, Tuple, Optional

class ConsistentHashRing:
    def __init__(self, replicas: int = 150):
        self.replicas = replicas
        self.ring: List[int] = []
        self.ring_map: dict = {}
        self.nodes: set = set()
    
    def _hash(self, key: str) -> int:
        return int(hashlib.md5(key.encode()).hexdigest(), 16)
    
    def _hash_node(self, node: str, replica: int) -> int:
        return self._hash(f"{node}-replica-{replica}")
    
    def add_node(self, node: str):
        if node in self.nodes:
            return
        
        self.nodes.add(node)
        
        for i in range(self.replicas):
            hash_value = self._hash_node(node, i)
            self.ring.append(hash_value)
            self.ring_map[hash_value] = node
        
        self.ring.sort()
    
    def remove_node(self, node: str):
        if node not in self.nodes:
            return
        
        self.nodes.remove(node)
        
        for i in range(self.replicas):
            hash_value = self._hash_node(node, i)
            idx = bisect_left(self.ring, hash_value)
            if idx < len(self.ring) and self.ring[idx] == hash_value:
                self.ring.pop(idx)
                del self.ring_map[hash_value]
    
    def get_node(self, key: str) -> Optional[str]:
        if not self.ring:
            return None
        
        hash_value = self._hash(key)
        idx = bisect_left(self.ring, hash_value)
        
        if idx >= len(self.ring):
            idx = 0
        
        return self.ring_map[self.ring[idx]]
    
    def get_nodes(self, key: str, num_nodes: int = 2) -> List[str]:
        if not self.ring:
            return []
        
        hash_value = self._hash(key)
        idx = bisect_left(self.ring, hash_value)
        
        nodes = []
        seen_nodes = set()
        
        for i in range(len(self.ring)):
            curr_idx = (idx + i) % len(self.ring)
            node = self.ring_map[self.ring[curr_idx]]
            
            if node not in seen_nodes:
                nodes.append(node)
                seen_nodes.add(node)
                
                if len(nodes) >= num_nodes:
                    break
        
        return nodes
    
    def get_all_nodes(self) -> List[str]:
        return list(self.nodes)
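
A quick standalone check of the property that makes virtual replicas worthwhile: removing a node remaps only that node's share of keys. This sketch condenses the same MD5/virtual-replica scheme into a few functions (the helper names are ours, not part of the library):

```python
# Standalone check of the remapping property, using the same MD5 +
# virtual-replica scheme as ConsistentHashRing.
import hashlib
from bisect import bisect_left

def h(s: str) -> int:
    return int(hashlib.md5(s.encode()).hexdigest(), 16)

def build_ring(nodes, replicas=150):
    # Each node contributes `replicas` points on the ring
    return sorted((h(f"{n}-replica-{i}"), n) for n in nodes for i in range(replicas))

def owner(ring, key):
    idx = bisect_left(ring, (h(key),))
    return ring[idx % len(ring)][1]  # wrap past the last ring point

keys = [f"user:{i}" for i in range(1000)]
ring3 = build_ring(["node1", "node2", "node3"])
ring2 = build_ring(["node1", "node3"])  # node2 removed

moved = sum(1 for k in keys if owner(ring3, k) != owner(ring2, k))
print(f"{moved} of 1000 keys changed owner")  # only node2's share (~1/3)
```

Keys that were owned by node1 or node3 keep the same successor point on the ring, so they never move; only node2's keys are redistributed.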

Data Replication

Let's implement data replication across multiple nodes.

Python
# districache/replication.py
import asyncio
import aiohttp
from typing import List, Dict, Any, Optional
import json

class ReplicationManager:
    def __init__(self, replication_factor: int = 2):
        self.replication_factor = replication_factor
        self.nodes: Dict[str, dict] = {}
    
    def register_node(self, node_id: str, host: str, port: int):
        self.nodes[node_id] = {
            'host': host,
            'port': port,
            'status': 'active'
        }
    
    def remove_node(self, node_id: str):
        if node_id in self.nodes:
            del self.nodes[node_id]
    
    def get_active_nodes(self) -> List[str]:
        return [
            node_id for node_id, info in self.nodes.items()
            if info['status'] == 'active'
        ]
    
    def _get_url(self, node_id: str, path: str) -> str:
        if node_id not in self.nodes:
            raise ValueError(f"Unknown node: {node_id}")
        
        node = self.nodes[node_id]
        return f"http://{node['host']}:{node['port']}{path}"
    
    async def replicate_set(self, key: str, value: Any, ttl: Optional[int],
                            primary_node: str, replica_nodes: List[str]):
        tasks = []
        
        for node_id in replica_nodes:
            if node_id == primary_node:
                continue
            
            url = self._get_url(node_id, '/cache/set')
            payload = {
                'key': key,
                'value': value,
                'ttl': ttl
            }
            
            tasks.append(self._make_request('POST', url, payload))
        
        if tasks:
            results = await asyncio.gather(*tasks, return_exceptions=True)
            return results
        
        return []
    
    async def replicate_delete(self, key: str, primary_node: str,
                               replica_nodes: List[str]):
        tasks = []
        
        for node_id in replica_nodes:
            if node_id == primary_node:
                continue
            
            url = self._get_url(node_id, f'/cache/delete?key={key}')
            tasks.append(self._make_request('DELETE', url))
        
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)
    
    async def _make_request(self, method: str, url: str, data: dict = None):
        try:
            async with aiohttp.ClientSession() as session:
                timeout = aiohttp.ClientTimeout(total=5)
                async with session.request(method, url, json=data, timeout=timeout) as resp:
                    return await resp.text()
        except Exception as e:
            return str(e)
    
    async def get_from_replicas(self, key: str, replica_nodes: List[str]) -> Any:
        for node_id in replica_nodes:
            try:
                url = self._get_url(node_id, f'/cache/get?key={key}')
                async with aiohttp.ClientSession() as session:
                    timeout = aiohttp.ClientTimeout(total=5)
                    async with session.get(url, timeout=timeout) as resp:
                        if resp.status == 200:
                            data = await resp.json()
                            return data.get('value')
            except Exception:
                continue
        
        return None
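
The `asyncio.gather(*tasks, return_exceptions=True)` call above is what keeps one dead replica from failing the whole write. Here is a standalone demo of that pattern, with a simulated failing node:

```python
# Standalone demo of the fan-out pattern replicate_set uses:
# gather(..., return_exceptions=True) collects failures as values
# instead of cancelling the remaining replica writes.
import asyncio

async def write_to_replica(node_id: str, key: str):
    if node_id == "node2":
        raise ConnectionError(f"{node_id} unreachable")  # simulated failure
    return f"{node_id}: stored {key}"

async def fan_out(key: str):
    tasks = [write_to_replica(n, key) for n in ("node1", "node2", "node3")]
    return await asyncio.gather(*tasks, return_exceptions=True)

results = asyncio.run(fan_out("user:1"))
for r in results:
    print(r)  # node2's entry is the ConnectionError object, not a crash
```

Without `return_exceptions=True`, the first failure would propagate and cancel the other writes, so a single slow or dead replica would break every set operation.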

Cluster Management

Let's create the cluster manager that coordinates all nodes.

Python
# districache/cluster.py
import asyncio
import aiohttp
from typing import Any, List, Dict, Optional
from .hashing import ConsistentHashRing
from .replication import ReplicationManager
from .node import CacheNode

class CacheCluster:
    def __init__(self, replication_factor: int = 2):
        self.hash_ring = ConsistentHashRing()
        self.replication_manager = ReplicationManager(replication_factor)
        self.replication_factor = replication_factor
        self.local_nodes: Dict[str, CacheNode] = {}
        self.node_endpoints: Dict[str, str] = {}
    
    def add_node(self, node_id: str, host: str = 'localhost',
                 port: int = 7000, is_local: bool = True):
        self.hash_ring.add_node(node_id)
        self.node_endpoints[node_id] = f"http://{host}:{port}"
        
        self.replication_manager.register_node(node_id, host, port)
        
        if is_local:
            local_node = CacheNode(node_id, host, port)
            self.local_nodes[node_id] = local_node
        
        print(f"Added node: {node_id} at {host}:{port}")
    
    def remove_node(self, node_id: str):
        self.hash_ring.remove_node(node_id)
        self.replication_manager.remove_node(node_id)
        
        if node_id in self.local_nodes:
            del self.local_nodes[node_id]
        
        if node_id in self.node_endpoints:
            del self.node_endpoints[node_id]
        
        print(f"Removed node: {node_id}")
    
    def get_primary_node(self, key: str) -> Optional[str]:
        return self.hash_ring.get_node(key)
    
    def get_replica_nodes(self, key: str) -> List[str]:
        return self.hash_ring.get_nodes(key, self.replication_factor)
    
    def get(self, key: str) -> Optional[Any]:
        primary = self.get_primary_node(key)
        
        if primary and primary in self.local_nodes:
            return self.local_nodes[primary].get(key)
        
        return None
    
    def set(self, key: str, value: Any, ttl: Optional[int] = None):
        primary = self.get_primary_node(key)
        replicas = self.get_replica_nodes(key)
        
        if primary and primary in self.local_nodes:
            self.local_nodes[primary].set(key, value, ttl)
        
        return primary, replicas
    
    def delete(self, key: str):
        primary = self.get_primary_node(key)
        replicas = self.get_replica_nodes(key)
        
        if primary and primary in self.local_nodes:
            self.local_nodes[primary].delete(key)
        
        return True
    
    def get_stats(self) -> Dict:
        stats = {
            'total_nodes': len(self.node_endpoints),
            'replication_factor': self.replication_factor,
            'nodes': {}
        }
        
        for node_id, node in self.local_nodes.items():
            stats['nodes'][node_id] = node.get_stats()
        
        return stats
    
    def get_all_keys(self) -> set:
        all_keys = set()
        
        for node in self.local_nodes.values():
            all_keys.update(node.store.keys())
        
        return all_keys

Client Library

Let's create an easy-to-use client library.

Python
# districache/client.py
import asyncio
import aiohttp
from typing import Any, Optional, List
import json

class DistriCacheClient:
    def __init__(self, cluster_nodes: List[dict], replication_factor: int = 2):
        self.cluster_nodes = cluster_nodes
        self.replication_factor = replication_factor
        self.session: Optional[aiohttp.ClientSession] = None
        self.current_node_idx = 0
    
    async def connect(self):
        self.session = aiohttp.ClientSession()
    
    async def close(self):
        if self.session:
            await self.session.close()
    
    async def _get_url(self, path: str) -> str:
        node = self.cluster_nodes[self.current_node_idx]
        self.current_node_idx = (self.current_node_idx + 1) % len(self.cluster_nodes)
        return f"http://{node['host']}:{node['port']}{path}"
    
    async def get(self, key: str) -> Optional[Any]:
        if not self.session:
            await self.connect()
        
        # Try each node in round-robin order until one responds
        for _ in range(len(self.cluster_nodes)):
            url = await self._get_url(f'/cache/get?key={key}')
            try:
                timeout = aiohttp.ClientTimeout(total=5)
                async with self.session.get(url, timeout=timeout) as resp:
                    if resp.status == 200:
                        data = await resp.json()
                        return data.get('value')
            except Exception:
                continue
        
        return None
    
    async def set(self, key: str, value: Any, ttl: Optional[int] = None):
        if not self.session:
            await self.connect()
        
        url = await self._get_url('/cache/set')
        payload = {
            'key': key,
            'value': value,
            'ttl': ttl
        }
        
        await self.session.post(url, json=payload)
    
    async def delete(self, key: str):
        if not self.session:
            await self.connect()
        
        url = await self._get_url(f'/cache/delete?key={key}')
        
        await self.session.delete(url)
    
    async __aenter__(self):
        await self.connect()
        return self
    
    async __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()


# Usage example
# async def main():
#     nodes = [
#         {'host': 'localhost', 'port': 7001},
#         {'host': 'localhost', 'port': 7002},
#         {'host': 'localhost', 'port': 7003}
#     ]
#     
#     async with DistriCacheClient(nodes) as client:
#         await client.set('user:1', {'name': 'John'})
#         user = await client.get('user:1')
#         print(user)

Testing

Python
# test_districache.py
from districache.cluster import CacheCluster

# Create a cluster
cluster = CacheCluster(replication_factor=2)

# Add nodes
cluster.add_node("node1", "localhost", 7001)
cluster.add_node("node2", "localhost", 7002)
cluster.add_node("node3", "localhost", 7003)

# Test data distribution
print("\n--- Testing Data Distribution ---")

for i in range(10):
    key = f"user:{i}"
    primary = cluster.get_primary_node(key)
    replicas = cluster.get_replica_nodes(key)
    print(f"{key} -> Primary: {primary}, Replicas: {replicas}")

# Test cache operations
print("\n--- Testing Cache Operations ---")

cluster.set("user:1", {"name": "John", "email": "john@example.com"})
cluster.set("user:2", {"name": "Jane", "email": "jane@example.com"})
cluster.set("product:1", {"name": "Laptop", "price": 999})

print("user:1:", cluster.get("user:1"))
print("user:2:", cluster.get("user:2"))
print("product:1:", cluster.get("product:1"))

# Test deletion
cluster.delete("user:1")
print("After delete, user:1:", cluster.get("user:1"))

# Print stats
print("\n--- Cluster Stats ---")
print(cluster.get_stats())

# Test node removal
print("\n--- Testing Node Removal ---")
cluster.remove_node("node2")

for i in range(5):
    key = f"user:{i}"
    primary = cluster.get_primary_node(key)
    print(f"{key} -> {primary}")

Testing Checklist
  • Nodes are added to the cluster
  • Data is distributed using consistent hashing
  • Cache operations work correctly
  • Replica nodes are selected
  • Node removal redistributes keys

Summary

Congratulations! You've built a complete distributed caching system.

What You Built

  • Cache Node - In-memory storage with TTL
  • Consistent Hashing - Efficient data distribution
  • Replication Manager - Multi-node replication
  • Cluster Manager - Node coordination
  • Client Library - Easy-to-use API

Next Steps

  • Add automatic node discovery
  • Implement vector clocks for consistency
  • Add persistence to disk
  • Implement cache warming

Continue Learning

Try these tutorials:

  • Build a Message Queue
  • Build a Load Balancer
  • Build an API Gateway