Rate Limiting Strategies in Go
Implement sophisticated rate limiting algorithms and patterns in Go for API protection.
Rate Limiting Fundamentals
Rate limiting is a strategy to control the rate at which a user, service, or system can access a resource or perform operations. Before diving into complex implementations, let’s establish a solid understanding of the core concepts and basic approaches.
Core Concepts and Terminology
Rate limiting involves several key concepts:
- Request Rate: The number of requests per unit time (e.g., 100 requests per second)
- Burst: A temporary spike in request rate
- Quota: The maximum number of requests allowed in a given time window
- Throttling: The act of delaying or rejecting requests that exceed defined limits
- Rate Limiter: The component that enforces rate limits
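These terms map naturally onto a small Go interface. The sketch below is purely illustrative rather than taken from any particular library: Allow enforces a quota by rejecting excess requests, while Wait throttles by delaying them.
package ratelimit
import "context"
// Limiter is an illustrative interface tying the terminology together.
type Limiter interface {
	// Allow reports whether a request fits within the current quota,
	// rejecting (rather than delaying) requests that exceed the limit.
	Allow() bool
	// Wait throttles: it blocks until the request may proceed, or
	// returns an error if ctx is cancelled first.
	Wait(ctx context.Context) error
}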
Simple Counter-Based Rate Limiter
Let’s start with a basic implementation—a fixed window counter rate limiter:
package main
import (
"fmt"
"sync"
"time"
)
// FixedWindowLimiter implements a simple fixed window rate limiter
type FixedWindowLimiter struct {
mu sync.Mutex
requestCount int
windowSize time.Duration
limit int
windowStart time.Time
}
// NewFixedWindowLimiter creates a new fixed window rate limiter
func NewFixedWindowLimiter(limit int, windowSize time.Duration) *FixedWindowLimiter {
return &FixedWindowLimiter{
limit: limit,
windowSize: windowSize,
windowStart: time.Now(),
}
}
// Allow checks if a request should be allowed based on the current rate
func (l *FixedWindowLimiter) Allow() bool {
l.mu.Lock()
defer l.mu.Unlock()
now := time.Now()
// If the window has expired, reset the counter
if now.Sub(l.windowStart) >= l.windowSize {
l.requestCount = 0
l.windowStart = now
}
// Check if we've reached the limit
if l.requestCount >= l.limit {
return false
}
// Increment the counter and allow the request
l.requestCount++
return true
}
func main() {
// Create a rate limiter: 5 requests per second
limiter := NewFixedWindowLimiter(5, time.Second)
// Simulate 10 requests in quick succession
for i := 1; i <= 10; i++ {
allowed := limiter.Allow()
fmt.Printf("Request %d: %v\n", i, allowed)
}
// Wait for the window to reset
fmt.Println("Waiting for window reset...")
time.Sleep(time.Second)
// Try again after the window resets
for i := 11; i <= 15; i++ {
allowed := limiter.Allow()
fmt.Printf("Request %d: %v\n", i, allowed)
}
}
This simple implementation demonstrates the basic concept but has significant limitations:
- Edge Effects: All requests in a window are treated equally, regardless of their distribution within the window. This can lead to “bursts” at window boundaries.
- Memory Usage: For high-volume services with many clients, maintaining counters for each client can consume significant memory.
- Precision: The fixed window approach lacks granularity and can allow twice the intended rate at window boundaries.
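To make the boundary problem concrete: with a limit of 5 per second, a client can send 5 requests at the very end of one window and 5 more immediately after the reset, squeezing 10 requests into a few milliseconds. Here's a small sketch that assumes the FixedWindowLimiter defined above (the shortened window only speeds up the demonstration):
// demonstrateBoundaryBurst shows ~2x the intended rate at a window
// boundary; add it to the program above to run it.
func demonstrateBoundaryBurst() {
	limiter := NewFixedWindowLimiter(5, 100*time.Millisecond)
	time.Sleep(90 * time.Millisecond) // move near the end of the window
	allowed := 0
	for i := 0; i < 5; i++ {
		if limiter.Allow() {
			allowed++ // 5 allowed at the tail of the first window
		}
	}
	time.Sleep(20 * time.Millisecond) // cross the window boundary
	for i := 0; i < 5; i++ {
		if limiter.Allow() {
			allowed++ // 5 more allowed right after the reset
		}
	}
	fmt.Printf("10 requests in ~30ms, %d allowed\n", allowed)
}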
Rate Limiting vs. Throttling
While often used interchangeably, rate limiting and throttling have subtle differences:
- Rate Limiting: Typically involves rejecting requests that exceed a threshold
- Throttling: May involve delaying requests (queuing) rather than rejecting them outright
Here’s a simple throttling implementation that delays requests instead of rejecting them:
package main
import (
"fmt"
"sync"
"time"
)
// ThrottledLimiter implements a throttling rate limiter
type ThrottledLimiter struct {
mu sync.Mutex
lastRequest time.Time
minInterval time.Duration
}
// NewThrottledLimiter creates a new throttling rate limiter
func NewThrottledLimiter(requestsPerSecond int) *ThrottledLimiter {
return &ThrottledLimiter{
minInterval: time.Second / time.Duration(requestsPerSecond),
}
}
// Wait blocks until the request can be processed according to the rate limit
func (l *ThrottledLimiter) Wait() {
l.mu.Lock()
defer l.mu.Unlock()
now := time.Now()
// If this is not the first request, calculate how long to wait
if !l.lastRequest.IsZero() {
elapsed := now.Sub(l.lastRequest)
if elapsed < l.minInterval {
sleepDuration := l.minInterval - elapsed
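			// Note: we sleep while still holding the lock, which deliberately
			// serializes concurrent callers so the overall rate is preserved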
time.Sleep(sleepDuration)
now = time.Now() // Update now after sleeping
}
}
l.lastRequest = now
}
func main() {
// Create a throttled limiter: 2 requests per second
limiter := NewThrottledLimiter(2)
// Simulate 5 requests
for i := 1; i <= 5; i++ {
start := time.Now()
limiter.Wait()
elapsed := time.Since(start)
fmt.Printf("Request %d: waited %v\n", i, elapsed)
}
}
This throttling approach ensures a consistent rate by introducing delays between requests, which can be useful for scenarios where dropping requests is undesirable.
Advanced Rate Limiting Algorithms
While simple counters can work for basic use cases, more sophisticated algorithms provide better fairness, accuracy, and performance characteristics. Let’s explore the most widely used advanced rate limiting algorithms.
Token Bucket Algorithm
The token bucket algorithm is one of the most popular rate limiting approaches due to its simplicity and effectiveness. It models rate limiting as a bucket that continuously fills with tokens at a fixed rate. Each request consumes a token, and when the bucket is empty, requests are rejected.
Here’s a comprehensive implementation:
package main
import (
"fmt"
"sync"
"time"
)
// TokenBucket implements the token bucket algorithm for rate limiting
type TokenBucket struct {
mu sync.Mutex
tokens float64
maxTokens float64
refillRate float64 // tokens per second
lastRefillTime time.Time
}
// NewTokenBucket creates a new token bucket rate limiter
func NewTokenBucket(maxTokens, refillRate float64) *TokenBucket {
return &TokenBucket{
tokens: maxTokens,
maxTokens: maxTokens,
refillRate: refillRate,
lastRefillTime: time.Now(),
}
}
// refill adds tokens to the bucket based on elapsed time
func (tb *TokenBucket) refill() {
now := time.Now()
elapsed := now.Sub(tb.lastRefillTime).Seconds()
// Calculate tokens to add based on elapsed time
newTokens := elapsed * tb.refillRate
// Update token count, capped at maxTokens
tb.tokens = min(tb.tokens+newTokens, tb.maxTokens)
tb.lastRefillTime = now
}
// Allow checks if a request can proceed and consumes a token if allowed
func (tb *TokenBucket) Allow() bool {
tb.mu.Lock()
defer tb.mu.Unlock()
tb.refill()
if tb.tokens >= 1.0 {
tb.tokens--
return true
}
return false
}
// AllowN checks if N requests can proceed and consumes N tokens if allowed
func (tb *TokenBucket) AllowN(n float64) bool {
tb.mu.Lock()
defer tb.mu.Unlock()
tb.refill()
if tb.tokens >= n {
tb.tokens -= n
return true
}
return false
}
// min returns the minimum of two float64 values
func min(a, b float64) float64 {
if a < b {
return a
}
return b
}
func main() {
// Create a token bucket: 5 max tokens, refill rate of 2 tokens per second
limiter := NewTokenBucket(5, 2)
// Simulate burst of 5 requests (should all succeed due to initial token count)
for i := 1; i <= 5; i++ {
allowed := limiter.Allow()
fmt.Printf("Request %d: %v\n", i, allowed)
}
// The 6th request should be rejected (no tokens left)
allowed := limiter.Allow()
fmt.Printf("Request 6: %v\n", allowed)
// Wait for some tokens to refill
fmt.Println("Waiting for token refill...")
time.Sleep(2 * time.Second) // Wait for ~4 tokens to be refilled
// Try 4 more requests
for i := 7; i <= 10; i++ {
allowed := limiter.Allow()
fmt.Printf("Request %d: %v\n", i, allowed)
}
// Demonstrate AllowN for a request that needs multiple tokens
fmt.Println("Waiting for token refill...")
time.Sleep(3 * time.Second)
// Try to consume 3 tokens at once
allowed = limiter.AllowN(3)
fmt.Printf("Bulk request (3 tokens): %v\n", allowed)
}
Key advantages of the token bucket algorithm:
- Burst Handling: It naturally accommodates bursts up to the bucket capacity
- Smooth Rate Limiting: The continuous refill provides a smoother limiting effect
- Flexibility: Can be adapted for different token costs per request type
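In production Go services you rarely need to hand-roll this: the golang.org/x/time/rate package provides a well-tested token bucket with both reject-style and wait-style APIs. A minimal sketch:
package main
import (
	"context"
	"fmt"
	"golang.org/x/time/rate"
)
func main() {
	// Token bucket: refills 2 tokens per second, burst capacity of 5
	limiter := rate.NewLimiter(rate.Limit(2), 5)
	// Reject-style: Allow consumes a token if one is available
	for i := 1; i <= 7; i++ {
		fmt.Printf("Request %d: %v\n", i, limiter.Allow())
	}
	// Throttle-style: Wait blocks until a token becomes available
	if err := limiter.Wait(context.Background()); err == nil {
		fmt.Println("Request 8: allowed after waiting")
	}
}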
Advanced Patterns and Techniques
Leaky Bucket Algorithm
The leaky bucket algorithm models rate limiting as a bucket with a hole in the bottom. Requests fill the bucket, and they “leak” out at a constant rate. If the bucket overflows, new requests are rejected.
package main
import (
"fmt"
"sync"
"time"
)
// LeakyBucket implements the leaky bucket algorithm for rate limiting
type LeakyBucket struct {
mu sync.Mutex
capacity int // Maximum bucket capacity
remaining int // Current bucket level
leakRate float64 // Items per second
lastLeakTime time.Time
}
// NewLeakyBucket creates a new leaky bucket rate limiter
func NewLeakyBucket(capacity int, leakRate float64) *LeakyBucket {
return &LeakyBucket{
capacity: capacity,
remaining: 0,
leakRate: leakRate,
lastLeakTime: time.Now(),
}
}
// leak removes items from the bucket based on elapsed time
func (lb *LeakyBucket) leak() {
now := time.Now()
elapsed := now.Sub(lb.lastLeakTime).Seconds()
// Calculate items to leak based on elapsed time
leakAmount := int(elapsed * lb.leakRate)
// Update bucket level, ensuring it doesn't go below zero
if leakAmount > 0 {
lb.remaining = max(0, lb.remaining-leakAmount)
lb.lastLeakTime = now
}
}
// Allow checks if a request can be added to the bucket
func (lb *LeakyBucket) Allow() bool {
lb.mu.Lock()
defer lb.mu.Unlock()
lb.leak()
// If bucket is full, reject the request
if lb.remaining >= lb.capacity {
return false
}
// Add the request to the bucket
lb.remaining++
return true
}
// max returns the maximum of two integers
func max(a, b int) int {
if a > b {
return a
}
return b
}
func main() {
// Create a leaky bucket: capacity 5, leak rate 2 per second
limiter := NewLeakyBucket(5, 2)
// Simulate 7 requests in quick succession
for i := 1; i <= 7; i++ {
allowed := limiter.Allow()
fmt.Printf("Request %d: %v\n", i, allowed)
}
// Wait for some requests to leak out
fmt.Println("Waiting for leakage...")
time.Sleep(2 * time.Second) // Wait for ~4 requests to leak
// Try 5 more requests
for i := 8; i <= 12; i++ {
allowed := limiter.Allow()
fmt.Printf("Request %d: %v\n", i, allowed)
}
}
The leaky bucket algorithm is particularly useful for:
- Traffic Shaping: It enforces a consistent outflow rate
- Queue Management: It can be extended to queue requests rather than reject them (sketched after this list)
- Network Traffic: It’s well-suited for network packet shaping
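Here's a minimal sketch of that queueing variant: a bounded channel plays the role of the bucket, and a single worker drains it at a constant rate. The capacity and rates are illustrative.
package main
import (
	"fmt"
	"time"
)
func main() {
	queue := make(chan int, 5) // the bucket: at most 5 queued requests
	done := make(chan struct{})
	// The "leak": drain one queued request every 500ms (2 per second)
	go func() {
		defer close(done)
		ticker := time.NewTicker(500 * time.Millisecond)
		defer ticker.Stop()
		for id := range queue {
			<-ticker.C
			fmt.Printf("Processed request %d\n", id)
		}
	}()
	// Offer 8 requests in a burst; overflow is rejected instead of queued
	for i := 1; i <= 8; i++ {
		select {
		case queue <- i:
			fmt.Printf("Request %d queued\n", i)
		default:
			fmt.Printf("Request %d rejected (bucket full)\n", i)
		}
	}
	close(queue)
	<-done
}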
Sliding Window Counter
The sliding window counter algorithm addresses the edge effects of fixed windows by considering a weighted average of the current and previous windows:
package main
import (
"fmt"
"sync"
"time"
)
// SlidingWindowCounter implements a sliding window rate limiter
type SlidingWindowCounter struct {
mu sync.Mutex
limit int
windowSize time.Duration
previousCount int
currentCount int
windowStart time.Time
}
// NewSlidingWindowCounter creates a new sliding window counter rate limiter
func NewSlidingWindowCounter(limit int, windowSize time.Duration) *SlidingWindowCounter {
return &SlidingWindowCounter{
limit: limit,
windowSize: windowSize,
windowStart: time.Now(),
}
}
// Allow checks if a request should be allowed based on the sliding window calculation
func (sw *SlidingWindowCounter) Allow() bool {
sw.mu.Lock()
defer sw.mu.Unlock()
now := time.Now()
elapsed := now.Sub(sw.windowStart)
// If the window has expired, shift the window
if elapsed >= sw.windowSize {
// Calculate how many complete windows have passed
completeWindows := int(elapsed / sw.windowSize)
// If more than one window has passed, reset all counts
if completeWindows > 1 {
sw.previousCount = 0
sw.currentCount = 0
} else {
// Shift window by one
sw.previousCount = sw.currentCount
sw.currentCount = 0
}
// Update window start time
sw.windowStart = sw.windowStart.Add(time.Duration(completeWindows) * sw.windowSize)
elapsed = now.Sub(sw.windowStart)
}
// Calculate the position within the current window (0.0 to 1.0)
windowPosition := float64(elapsed) / float64(sw.windowSize)
// Calculate the weighted rate:
// (previous_count * (1 - position) + current_count)
weightedCount := float64(sw.previousCount) * (1 - windowPosition) + float64(sw.currentCount)
// Check if we've reached the limit
if int(weightedCount) >= sw.limit {
return false
}
// Increment the counter and allow the request
sw.currentCount++
return true
}
func main() {
// Create a sliding window counter: 10 requests per minute
limiter := NewSlidingWindowCounter(10, time.Minute)
// Simulate 8 requests in the first half of the window
for i := 1; i <= 8; i++ {
allowed := limiter.Allow()
fmt.Printf("Request %d: %v\n", i, allowed)
}
	// Simulate the window rolling over by backdating windowStart 90 seconds
	// (a test-only trick), leaving us 30 seconds into the next window
	limiter.windowStart = limiter.windowStart.Add(-90 * time.Second)
	fmt.Println("Time elapsed: 90 seconds (50% through the next window)")
	// Try 8 more requests.
	// The sliding window still counts 50% of the previous window's 8
	// requests (a weighted 4), so only about 6-7 of these should be
	// allowed before the weighted count reaches the limit of 10.
	for i := 9; i <= 16; i++ {
allowed := limiter.Allow()
fmt.Printf("Request %d: %v\n", i, allowed)
}
}
The sliding window counter provides a more accurate rate limiting approach by:
- Smoothing Boundaries: Eliminating the edge effects of fixed windows
- Accurate Limiting: Providing a more consistent rate limit across window boundaries
- Memory Efficiency: Requiring only two counters per client
Fixed Window Counter with Cell-Based Storage
For systems with many clients, memory usage becomes a concern. A cell-based approach can optimize storage:
package main
import (
"fmt"
"sync"
"time"
)
// Cell represents a time window cell with a request count
type Cell struct {
timestamp time.Time
count int
}
// CellBasedRateLimiter implements a memory-efficient rate limiter
type CellBasedRateLimiter struct {
mu sync.Mutex
clients map[string]*Cell
limit int
windowSize time.Duration
cleanupInterval time.Duration
lastCleanup time.Time
}
// NewCellBasedRateLimiter creates a new cell-based rate limiter
func NewCellBasedRateLimiter(limit int, windowSize time.Duration) *CellBasedRateLimiter {
return &CellBasedRateLimiter{
clients: make(map[string]*Cell),
limit: limit,
windowSize: windowSize,
cleanupInterval: windowSize * 2, // Clean up every 2 window periods
lastCleanup: time.Now(),
}
}
// Allow checks if a request from a specific client should be allowed
func (cb *CellBasedRateLimiter) Allow(clientID string) bool {
cb.mu.Lock()
defer cb.mu.Unlock()
now := time.Now()
// Periodically clean up expired client entries
if now.Sub(cb.lastCleanup) >= cb.cleanupInterval {
cb.cleanup(now)
cb.lastCleanup = now
}
// Get or create client cell
cell, exists := cb.clients[clientID]
if !exists || now.Sub(cell.timestamp) >= cb.windowSize {
// Create new cell or reset expired cell
cb.clients[clientID] = &Cell{
timestamp: now,
count: 1,
}
return true
}
// Check if client has reached the limit
if cell.count >= cb.limit {
return false
}
// Increment the counter and allow the request
cell.count++
return true
}
// cleanup removes expired client entries to free memory
func (cb *CellBasedRateLimiter) cleanup(now time.Time) {
for clientID, cell := range cb.clients {
if now.Sub(cell.timestamp) >= cb.windowSize {
delete(cb.clients, clientID)
}
}
}
func main() {
// Create a cell-based rate limiter: 5 requests per minute per client
limiter := NewCellBasedRateLimiter(5, time.Minute)
// Simulate requests from different clients
clients := []string{"client1", "client2", "client3"}
for _, client := range clients {
fmt.Printf("Testing client: %s\n", client)
// Try 7 requests per client
for i := 1; i <= 7; i++ {
allowed := limiter.Allow(client)
fmt.Printf(" Request %d: %v\n", i, allowed)
}
}
	// Simulate time passing by backdating both the cleanup timer and each
	// client's window; otherwise the cells would still be inside their window
	fmt.Println("\nSimulating time passing (2 minutes)...")
	limiter.lastCleanup = limiter.lastCleanup.Add(-2 * time.Minute)
	for _, cell := range limiter.clients {
		cell.timestamp = cell.timestamp.Add(-2 * time.Minute)
	}
// Try client1 again after cleanup
fmt.Println("Testing client1 again:")
for i := 1; i <= 3; i++ {
allowed := limiter.Allow("client1")
fmt.Printf(" Request %d: %v\n", i, allowed)
}
}
This implementation is particularly useful for systems with:
- Many Clients: A single small cell per client keeps per-client state compact
- Memory Constraints: Expired entries are cleaned up automatically to bound memory use
Note that the single mutex serializes all clients; for very high throughput, shard the client map or use sync.Map to reduce lock contention.
Implementation Strategies
Distributed Rate Limiting Patterns
In distributed systems, rate limiting becomes more complex as requests may be processed across multiple servers. Let’s explore patterns for coordinating rate limits across a distributed environment.
Redis-Based Distributed Rate Limiter
Redis is commonly used for distributed rate limiting due to its atomic operations and high performance. Here's an implementation of the token bucket algorithm backed by Redis:
package main
import (
"context"
"fmt"
"time"
"github.com/go-redis/redis/v8"
)
// RedisTokenBucket implements a distributed token bucket algorithm using Redis
type RedisTokenBucket struct {
client *redis.Client
keyPrefix string
maxTokens int
refillRate float64 // tokens per second
tokenExpiry time.Duration
}
// NewRedisTokenBucket creates a new Redis-based token bucket rate limiter
func NewRedisTokenBucket(client *redis.Client, keyPrefix string, maxTokens int, refillRate float64) *RedisTokenBucket {
return &RedisTokenBucket{
client: client,
keyPrefix: keyPrefix,
maxTokens: maxTokens,
refillRate: refillRate,
tokenExpiry: time.Hour, // Keys expire after 1 hour of inactivity
}
}
// Allow checks if a request from a specific client should be allowed
func (rtb *RedisTokenBucket) Allow(ctx context.Context, clientID string) (bool, error) {
// Keys for storing the token count and last refill time
tokenKey := fmt.Sprintf("%s:%s:tokens", rtb.keyPrefix, clientID)
timestampKey := fmt.Sprintf("%s:%s:ts", rtb.keyPrefix, clientID)
	// Track the decision made inside the transaction closure: the Watch
	// callback can only return an error, so we capture the outcome here
	allowed := false
	// Use Redis WATCH/MULTI/EXEC (optimistic locking) to ensure atomicity
	txf := func(tx *redis.Tx) error {
// Get current token count and last refill timestamp
tokensCmd := tx.Get(ctx, tokenKey)
timestampCmd := tx.Get(ctx, timestampKey)
var tokens float64
var lastRefillTime time.Time
// Handle token count
if tokensCmd.Err() == redis.Nil {
// Key doesn't exist, initialize with max tokens
tokens = float64(rtb.maxTokens)
} else if tokensCmd.Err() != nil {
return tokensCmd.Err()
} else {
// Parse existing token count
var err error
tokens, err = tokensCmd.Float64()
if err != nil {
return err
}
}
// Handle timestamp
now := time.Now()
if timestampCmd.Err() == redis.Nil {
// Key doesn't exist, use current time
lastRefillTime = now
} else if timestampCmd.Err() != nil {
return timestampCmd.Err()
} else {
// Parse existing timestamp
ts, err := timestampCmd.Int64()
if err != nil {
return err
}
lastRefillTime = time.Unix(0, ts)
}
// Calculate token refill
elapsed := now.Sub(lastRefillTime).Seconds()
newTokens := tokens + (elapsed * rtb.refillRate)
if newTokens > float64(rtb.maxTokens) {
newTokens = float64(rtb.maxTokens)
}
// Check if we have enough tokens
if newTokens < 1 {
			// Not enough tokens: persist the refill state and leave allowed=false
_, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Set(ctx, tokenKey, newTokens, rtb.tokenExpiry)
pipe.Set(ctx, timestampKey, now.UnixNano(), rtb.tokenExpiry)
return nil
})
if err != nil {
return err
}
return nil
}
		// We have enough tokens: consume one and record the decision
		newTokens--
		allowed = true
// Update values in Redis
_, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Set(ctx, tokenKey, newTokens, rtb.tokenExpiry)
pipe.Set(ctx, timestampKey, now.UnixNano(), rtb.tokenExpiry)
return nil
})
return err
}
// Execute the transaction with optimistic locking
for i := 0; i < 3; i++ { // Retry up to 3 times
err := rtb.client.Watch(ctx, txf, tokenKey, timestampKey)
if err == nil {
			return allowed, nil
}
if err != redis.TxFailedErr {
return false, err
}
// If we got TxFailedErr, retry
}
return false, fmt.Errorf("failed to execute Redis transaction after retries")
}
func main() {
// This is a demonstration - in a real application, you would configure Redis properly
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
})
defer rdb.Close()
// Create a Redis-based token bucket: 10 max tokens, refill rate of 1 token per second
limiter := NewRedisTokenBucket(rdb, "ratelimit", 10, 1)
ctx := context.Background()
// Simulate requests from a client
clientID := "user123"
// Try 12 requests in quick succession
for i := 1; i <= 12; i++ {
allowed, err := limiter.Allow(ctx, clientID)
if err != nil {
fmt.Printf("Error: %v\n", err)
continue
}
fmt.Printf("Request %d: %v\n", i, allowed)
}
// Wait for some tokens to refill
fmt.Println("Waiting for token refill...")
time.Sleep(5 * time.Second)
// Try 5 more requests
for i := 13; i <= 17; i++ {
allowed, err := limiter.Allow(ctx, clientID)
if err != nil {
fmt.Printf("Error: %v\n", err)
continue
}
fmt.Printf("Request %d: %v\n", i, allowed)
}
}
This Redis-based implementation provides several advantages for distributed environments:
- Consistency: All servers share the same rate limit state
- Scalability: Redis can handle high throughput and many clients
- Persistence: Rate limit state can survive service restarts
- Low Latency: Redis operations are typically sub-millisecond
Distributed Rate Limiting with Lua Scripts
For even better performance and atomicity, we can use Redis Lua scripts:
package main
import (
"context"
"fmt"
"time"
"github.com/go-redis/redis/v8"
)
// RedisScriptLimiter implements rate limiting using Redis Lua scripts
type RedisScriptLimiter struct {
client *redis.Client
keyPrefix string
windowSize time.Duration
limit int
luaScript *redis.Script
}
// NewRedisScriptLimiter creates a new Redis-based rate limiter using Lua scripts
func NewRedisScriptLimiter(client *redis.Client, keyPrefix string, limit int, windowSize time.Duration) *RedisScriptLimiter {
// Lua script for sliding window rate limiting
// KEYS[1] - The Redis key to use for this rate limit
// ARGV[1] - The current timestamp in milliseconds
// ARGV[2] - The window size in milliseconds
// ARGV[3] - The maximum number of requests allowed in the window
luaScript := redis.NewScript(`
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local limit = tonumber(ARGV[3])
-- Clean up old requests outside the current window
redis.call('ZREMRANGEBYSCORE', key, 0, now - window)
-- Count requests in the current window
local count = redis.call('ZCARD', key)
-- If we're under the limit, add the current request and return allowed
if count < limit then
redis.call('ZADD', key, now, now .. '-' .. math.random())
redis.call('EXPIRE', key, math.ceil(window/1000))
return 1
end
-- We're over the limit
return 0
`)
return &RedisScriptLimiter{
client: client,
keyPrefix: keyPrefix,
windowSize: windowSize,
limit: limit,
luaScript: luaScript,
}
}
// Allow checks if a request from a specific client should be allowed
func (rsl *RedisScriptLimiter) Allow(ctx context.Context, clientID string) (bool, error) {
key := fmt.Sprintf("%s:%s", rsl.keyPrefix, clientID)
now := time.Now().UnixNano() / int64(time.Millisecond)
windowMs := int64(rsl.windowSize / time.Millisecond)
// Execute the Lua script
result, err := rsl.luaScript.Run(ctx, rsl.client, []string{key}, now, windowMs, rsl.limit).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func main() {
// This is a demonstration - in a real application, you would configure Redis properly
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
})
defer rdb.Close()
// Create a Redis-based rate limiter: 5 requests per minute per client
limiter := NewRedisScriptLimiter(rdb, "ratelimit", 5, time.Minute)
ctx := context.Background()
// Simulate requests from a client
clientID := "user123"
// Try 7 requests in quick succession
for i := 1; i <= 7; i++ {
allowed, err := limiter.Allow(ctx, clientID)
if err != nil {
fmt.Printf("Error: %v\n", err)
continue
}
fmt.Printf("Request %d: %v\n", i, allowed)
}
}
Using Lua scripts provides several benefits:
- Atomicity: The entire rate limiting logic executes in a single atomic operation
- Performance: Reduced network round-trips between application and Redis
- Consistency: Logic runs entirely on the Redis server, eliminating race conditions
Performance and Optimization
Adaptive and Dynamic Rate Limiting
In real-world systems, traffic patterns can vary significantly over time. Adaptive rate limiting adjusts limits dynamically based on system conditions, providing better resource utilization and protection.
Load-Based Adaptive Rate Limiter
This implementation adjusts rate limits based on system load:
package main
import (
"fmt"
"math"
"runtime"
"sync"
"time"
)
// AdaptiveRateLimiter implements a rate limiter that adjusts based on system load
type AdaptiveRateLimiter struct {
mu sync.Mutex
baseLimit int
currentLimit int
windowSize time.Duration
requestCount int
windowStart time.Time
loadCheckPeriod time.Duration
lastLoadCheck time.Time
minLimit int
maxLimit int
}
// NewAdaptiveRateLimiter creates a new adaptive rate limiter
func NewAdaptiveRateLimiter(baseLimit, minLimit, maxLimit int, windowSize time.Duration) *AdaptiveRateLimiter {
limiter := &AdaptiveRateLimiter{
baseLimit: baseLimit,
currentLimit: baseLimit,
windowSize: windowSize,
windowStart: time.Now(),
loadCheckPeriod: 5 * time.Second,
lastLoadCheck: time.Now(),
minLimit: minLimit,
maxLimit: maxLimit,
}
// Start a goroutine to periodically adjust the limit based on system load
go limiter.adaptToLoad()
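	// Note: this goroutine runs for the limiter's lifetime; production code
	// should accept a context or expose a Stop method so it can be shut down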
return limiter
}
// adaptToLoad periodically checks system load and adjusts the rate limit
func (l *AdaptiveRateLimiter) adaptToLoad() {
ticker := time.NewTicker(l.loadCheckPeriod)
defer ticker.Stop()
for range ticker.C {
l.adjustLimit()
}
}
// adjustLimit modifies the current limit based on CPU utilization
func (l *AdaptiveRateLimiter) adjustLimit() {
l.mu.Lock()
defer l.mu.Unlock()
	// Approximate system load using the goroutine count as a crude proxy;
	// production code should use real CPU or latency metrics instead
	numCPU := float64(runtime.NumCPU())
	cpuUtilization := float64(runtime.NumGoroutine()) / numCPU / 10
// Adjust limit based on CPU utilization
// - High utilization: reduce limit
// - Low utilization: increase limit
adjustmentFactor := 1.0
if cpuUtilization > 0.7 { // High load
adjustmentFactor = 0.8 // Reduce by 20%
} else if cpuUtilization < 0.3 { // Low load
adjustmentFactor = 1.2 // Increase by 20%
}
// Apply adjustment with bounds
newLimit := int(float64(l.currentLimit) * adjustmentFactor)
l.currentLimit = int(math.Max(float64(l.minLimit), math.Min(float64(l.maxLimit), float64(newLimit))))
fmt.Printf("System load: %.2f, Adjusted limit: %d\n", cpuUtilization, l.currentLimit)
}
// Allow checks if a request should be allowed based on the current adaptive rate
func (l *AdaptiveRateLimiter) Allow() bool {
l.mu.Lock()
defer l.mu.Unlock()
now := time.Now()
// If the window has expired, reset the counter
if now.Sub(l.windowStart) >= l.windowSize {
l.requestCount = 0
l.windowStart = now
}
// Check if we've reached the current limit
if l.requestCount >= l.currentLimit {
return false
}
// Increment the counter and allow the request
l.requestCount++
return true
}
func main() {
// Create an adaptive rate limiter: base 100 RPS, min 50 RPS, max 200 RPS
limiter := NewAdaptiveRateLimiter(100, 50, 200, time.Second)
// Simulate requests with varying load
for i := 1; i <= 20; i++ {
// Create artificial load every 5 iterations
if i%5 == 0 {
fmt.Println("Creating artificial load...")
for j := 0; j < 1000; j++ {
go func() {
time.Sleep(2 * time.Second)
}()
}
}
allowed := limiter.Allow()
fmt.Printf("Request %d: %v (current limit: %d)\n", i, allowed, limiter.currentLimit)
time.Sleep(50 * time.Millisecond)
}
}
Integration with Circuit Breakers
Rate limiting can be combined with circuit breakers for enhanced resilience:
package main
import (
"errors"
"fmt"
"sync"
"time"
)
// CircuitState represents the state of a circuit breaker
type CircuitState int
const (
StateClosed CircuitState = iota // Normal operation, requests allowed
StateOpen // Circuit is open, requests are rejected
StateHalfOpen // Testing if the circuit can be closed again
)
// CircuitBreaker implements the circuit breaker pattern
type CircuitBreaker struct {
mu sync.Mutex
state CircuitState
failureThreshold int
failureCount int
resetTimeout time.Duration
lastStateChange time.Time
halfOpenMaxCalls int
halfOpenCallCount int
}
// NewCircuitBreaker creates a new circuit breaker
func NewCircuitBreaker(failureThreshold int, resetTimeout time.Duration) *CircuitBreaker {
return &CircuitBreaker{
state: StateClosed,
failureThreshold: failureThreshold,
resetTimeout: resetTimeout,
lastStateChange: time.Now(),
halfOpenMaxCalls: 3,
}
}
// Execute runs the given function if the circuit allows it
func (cb *CircuitBreaker) Execute(fn func() error) error {
cb.mu.Lock()
// Check if the circuit is open
if cb.state == StateOpen {
// Check if it's time to try half-open state
if time.Since(cb.lastStateChange) > cb.resetTimeout {
cb.toHalfOpen()
} else {
cb.mu.Unlock()
return errors.New("circuit breaker is open")
}
}
// If half-open, check if we've reached the call limit
if cb.state == StateHalfOpen && cb.halfOpenCallCount >= cb.halfOpenMaxCalls {
cb.mu.Unlock()
return errors.New("circuit breaker is half-open and at call limit")
}
// Increment call count for half-open state
if cb.state == StateHalfOpen {
cb.halfOpenCallCount++
}
cb.mu.Unlock()
// Execute the function
err := fn()
cb.mu.Lock()
defer cb.mu.Unlock()
// Handle the result
if err != nil {
// Record failure
cb.failureCount++
// Check if we need to open the circuit
if (cb.state == StateClosed && cb.failureCount >= cb.failureThreshold) ||
cb.state == StateHalfOpen {
cb.toOpen()
}
return err
}
// Success - if we're half-open, close the circuit
if cb.state == StateHalfOpen {
cb.toClosed()
}
// Reset failure count on success in closed state
if cb.state == StateClosed {
cb.failureCount = 0
}
return nil
}
// toOpen changes the circuit state to open
func (cb *CircuitBreaker) toOpen() {
cb.state = StateOpen
cb.lastStateChange = time.Now()
fmt.Println("Circuit breaker state changed to OPEN")
}
// toHalfOpen changes the circuit state to half-open
func (cb *CircuitBreaker) toHalfOpen() {
cb.state = StateHalfOpen
cb.lastStateChange = time.Now()
cb.halfOpenCallCount = 0
fmt.Println("Circuit breaker state changed to HALF-OPEN")
}
// toClosed changes the circuit state to closed
func (cb *CircuitBreaker) toClosed() {
cb.state = StateClosed
cb.lastStateChange = time.Now()
cb.failureCount = 0
fmt.Println("Circuit breaker state changed to CLOSED")
}
// RateLimitedCircuitBreaker combines rate limiting with circuit breaking
type RateLimitedCircuitBreaker struct {
rateLimiter *TokenBucket
circuitBreaker *CircuitBreaker
}
// NewRateLimitedCircuitBreaker creates a new rate-limited circuit breaker
func NewRateLimitedCircuitBreaker(
rps float64,
burst float64,
failureThreshold int,
resetTimeout time.Duration,
) *RateLimitedCircuitBreaker {
return &RateLimitedCircuitBreaker{
rateLimiter: NewTokenBucket(burst, rps),
circuitBreaker: NewCircuitBreaker(failureThreshold, resetTimeout),
}
}
// Execute runs the given function if both the rate limiter and circuit breaker allow it
func (rlcb *RateLimitedCircuitBreaker) Execute(fn func() error) error {
// First check rate limiter
if !rlcb.rateLimiter.Allow() {
return errors.New("rate limit exceeded")
}
// Then check circuit breaker
return rlcb.circuitBreaker.Execute(fn)
}
// TokenBucket implements the token bucket algorithm for rate limiting
type TokenBucket struct {
mu sync.Mutex
tokens float64
maxTokens float64
refillRate float64 // tokens per second
lastRefillTime time.Time
}
// NewTokenBucket creates a new token bucket rate limiter
func NewTokenBucket(maxTokens, refillRate float64) *TokenBucket {
return &TokenBucket{
tokens: maxTokens,
maxTokens: maxTokens,
refillRate: refillRate,
lastRefillTime: time.Now(),
}
}
// refill adds tokens to the bucket based on elapsed time
func (tb *TokenBucket) refill() {
now := time.Now()
elapsed := now.Sub(tb.lastRefillTime).Seconds()
// Calculate tokens to add based on elapsed time
newTokens := elapsed * tb.refillRate
// Update token count, capped at maxTokens
tb.tokens = min(tb.tokens+newTokens, tb.maxTokens)
tb.lastRefillTime = now
}
// Allow checks if a request can proceed and consumes a token if allowed
func (tb *TokenBucket) Allow() bool {
tb.mu.Lock()
defer tb.mu.Unlock()
tb.refill()
if tb.tokens >= 1.0 {
tb.tokens--
return true
}
return false
}
// min returns the minimum of two float64 values
func min(a, b float64) float64 {
if a < b {
return a
}
return b
}
func main() {
// Create a rate-limited circuit breaker:
// - 5 RPS
// - Burst of 10
// - Circuit opens after 3 failures
// - Circuit resets after 5 seconds
rlcb := NewRateLimitedCircuitBreaker(5, 10, 3, 5*time.Second)
// Simulate successful requests
fmt.Println("Simulating successful requests...")
for i := 1; i <= 5; i++ {
err := rlcb.Execute(func() error {
fmt.Printf("Request %d executed successfully\n", i)
return nil
})
if err != nil {
fmt.Printf("Request %d failed: %v\n", i, err)
}
time.Sleep(100 * time.Millisecond)
}
// Simulate failing requests
fmt.Println("\nSimulating failing requests...")
for i := 6; i <= 10; i++ {
err := rlcb.Execute(func() error {
fmt.Printf("Request %d would execute, but returning error\n", i)
return errors.New("simulated error")
})
if err != nil {
fmt.Printf("Request %d failed: %v\n", i, err)
}
time.Sleep(100 * time.Millisecond)
}
// Try more requests (should be rejected due to open circuit)
fmt.Println("\nTrying more requests (should be rejected)...")
for i := 11; i <= 15; i++ {
err := rlcb.Execute(func() error {
fmt.Printf("Request %d executed\n", i)
return nil
})
if err != nil {
fmt.Printf("Request %d failed: %v\n", i, err)
}
time.Sleep(100 * time.Millisecond)
}
// Wait for circuit to reset
fmt.Println("\nWaiting for circuit breaker timeout...")
time.Sleep(6 * time.Second)
// Try successful requests again
fmt.Println("\nTrying successful requests after timeout...")
for i := 16; i <= 20; i++ {
err := rlcb.Execute(func() error {
fmt.Printf("Request %d executed successfully\n", i)
return nil
})
if err != nil {
fmt.Printf("Request %d failed: %v\n", i, err)
}
time.Sleep(100 * time.Millisecond)
}
}
Production Best Practices
Monitoring and Observability
Effective rate limiting requires comprehensive monitoring to tune parameters and detect issues. Let’s implement a rate limiter with built-in metrics:
package main
import (
"fmt"
"sync"
"time"
)
// RateLimiterMetrics tracks statistics about rate limiter behavior
type RateLimiterMetrics struct {
mu sync.Mutex
totalRequests int64
allowedRequests int64
rejectedRequests int64
currentQPS float64
peakQPS float64
lastCalculated time.Time
requestCounts []int
requestCountsIndex int
requestCountsSize int
}
// NewRateLimiterMetrics creates a new metrics tracker
func NewRateLimiterMetrics() *RateLimiterMetrics {
metrics := &RateLimiterMetrics{
lastCalculated: time.Now(),
requestCountsSize: 60, // Track last 60 seconds
}
metrics.requestCounts = make([]int, metrics.requestCountsSize)
// Start a goroutine to periodically calculate QPS
go metrics.calculateQPS()
return metrics
}
// calculateQPS periodically calculates the current QPS
func (m *RateLimiterMetrics) calculateQPS() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for range ticker.C {
m.mu.Lock()
// Calculate QPS based on the sum of all request counts
totalRequests := 0
for _, count := range m.requestCounts {
totalRequests += count
}
m.currentQPS = float64(totalRequests) / float64(m.requestCountsSize)
if m.currentQPS > m.peakQPS {
m.peakQPS = m.currentQPS
}
		// Advance the ring buffer and clear the slot for the new current
		// second; the cleared slot held the count from 60 seconds ago
		m.requestCountsIndex = (m.requestCountsIndex + 1) % m.requestCountsSize
		m.requestCounts[m.requestCountsIndex] = 0
m.mu.Unlock()
}
}
// RecordRequest records a request and its outcome
func (m *RateLimiterMetrics) RecordRequest(allowed bool) {
m.mu.Lock()
defer m.mu.Unlock()
m.totalRequests++
m.requestCounts[m.requestCountsIndex]++
if allowed {
m.allowedRequests++
} else {
m.rejectedRequests++
}
}
// GetMetrics returns the current metrics
func (m *RateLimiterMetrics) GetMetrics() map[string]interface{} {
m.mu.Lock()
defer m.mu.Unlock()
rejectionRate := float64(0)
if m.totalRequests > 0 {
rejectionRate = float64(m.rejectedRequests) / float64(m.totalRequests)
}
return map[string]interface{}{
"total_requests": m.totalRequests,
"allowed_requests": m.allowedRequests,
"rejected_requests": m.rejectedRequests,
"rejection_rate": rejectionRate,
"current_qps": m.currentQPS,
"peak_qps": m.peakQPS,
}
}
// InstrumentedTokenBucket is a token bucket with metrics
type InstrumentedTokenBucket struct {
tokenBucket *TokenBucket
metrics *RateLimiterMetrics
}
// NewInstrumentedTokenBucket creates a new instrumented token bucket
func NewInstrumentedTokenBucket(maxTokens, refillRate float64) *InstrumentedTokenBucket {
return &InstrumentedTokenBucket{
tokenBucket: NewTokenBucket(maxTokens, refillRate),
metrics: NewRateLimiterMetrics(),
}
}
// Allow checks if a request can proceed and records metrics
func (itb *InstrumentedTokenBucket) Allow() bool {
allowed := itb.tokenBucket.Allow()
itb.metrics.RecordRequest(allowed)
return allowed
}
// GetMetrics returns the current metrics
func (itb *InstrumentedTokenBucket) GetMetrics() map[string]interface{} {
return itb.metrics.GetMetrics()
}
// TokenBucket implements the token bucket algorithm for rate limiting
type TokenBucket struct {
mu sync.Mutex
tokens float64
maxTokens float64
refillRate float64 // tokens per second
lastRefillTime time.Time
}
// NewTokenBucket creates a new token bucket rate limiter
func NewTokenBucket(maxTokens, refillRate float64) *TokenBucket {
return &TokenBucket{
tokens: maxTokens,
maxTokens: maxTokens,
refillRate: refillRate,
lastRefillTime: time.Now(),
}
}
// refill adds tokens to the bucket based on elapsed time
func (tb *TokenBucket) refill() {
now := time.Now()
elapsed := now.Sub(tb.lastRefillTime).Seconds()
// Calculate tokens to add based on elapsed time
newTokens := elapsed * tb.refillRate
// Update token count, capped at maxTokens
tb.tokens = min(tb.tokens+newTokens, tb.maxTokens)
tb.lastRefillTime = now
}
// Allow checks if a request can proceed and consumes a token if allowed
func (tb *TokenBucket) Allow() bool {
tb.mu.Lock()
defer tb.mu.Unlock()
tb.refill()
if tb.tokens >= 1.0 {
tb.tokens--
return true
}
return false
}
// min returns the minimum of two float64 values
func min(a, b float64) float64 {
if a < b {
return a
}
return b
}
func main() {
// Create an instrumented token bucket: 10 max tokens, refill rate of 5 tokens per second
limiter := NewInstrumentedTokenBucket(10, 5)
// Start a goroutine to periodically print metrics
go func() {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for range ticker.C {
metrics := limiter.GetMetrics()
fmt.Printf("\nCurrent Metrics:\n")
fmt.Printf(" Total Requests: %d\n", metrics["total_requests"])
fmt.Printf(" Allowed Requests: %d\n", metrics["allowed_requests"])
fmt.Printf(" Rejected Requests: %d\n", metrics["rejected_requests"])
fmt.Printf(" Rejection Rate: %.2f%%\n", metrics["rejection_rate"].(float64)*100)
fmt.Printf(" Current QPS: %.2f\n", metrics["current_qps"])
fmt.Printf(" Peak QPS: %.2f\n", metrics["peak_qps"])
}
}()
// Simulate normal traffic
fmt.Println("Simulating normal traffic...")
for i := 1; i <= 20; i++ {
allowed := limiter.Allow()
fmt.Printf("Request %d: %v\n", i, allowed)
time.Sleep(200 * time.Millisecond)
}
// Simulate burst traffic
fmt.Println("\nSimulating burst traffic...")
for i := 21; i <= 40; i++ {
allowed := limiter.Allow()
fmt.Printf("Request %d: %v\n", i, allowed)
time.Sleep(50 * time.Millisecond)
}
// Wait for metrics to be calculated
time.Sleep(4 * time.Second)
}
Prometheus Integration
For production systems, integrating with monitoring systems like Prometheus is essential:
package main
import (
"fmt"
"log"
"net/http"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
var (
// Define Prometheus metrics
requestsTotal = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "rate_limiter_requests_total",
Help: "The total number of requests",
},
[]string{"status"},
)
currentTokens = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "rate_limiter_current_tokens",
Help: "The current number of tokens in the bucket",
},
[]string{"limiter_id"},
)
requestLatency = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "rate_limiter_request_duration_seconds",
Help: "The latency of rate limiter decisions",
Buckets: prometheus.DefBuckets,
},
[]string{"status"},
)
)
// PrometheusTokenBucket is a token bucket with Prometheus metrics
type PrometheusTokenBucket struct {
mu sync.Mutex
tokens float64
maxTokens float64
refillRate float64 // tokens per second
lastRefillTime time.Time
limiterID string
}
// NewPrometheusTokenBucket creates a new token bucket with Prometheus metrics
func NewPrometheusTokenBucket(limiterID string, maxTokens, refillRate float64) *PrometheusTokenBucket {
tb := &PrometheusTokenBucket{
tokens: maxTokens,
maxTokens: maxTokens,
refillRate: refillRate,
lastRefillTime: time.Now(),
limiterID: limiterID,
}
// Update token gauge periodically
go func() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for range ticker.C {
tb.mu.Lock()
tb.refill()
currentTokens.WithLabelValues(tb.limiterID).Set(tb.tokens)
tb.mu.Unlock()
}
}()
return tb
}
// refill adds tokens to the bucket based on elapsed time
func (tb *PrometheusTokenBucket) refill() {
now := time.Now()
elapsed := now.Sub(tb.lastRefillTime).Seconds()
// Calculate tokens to add based on elapsed time
newTokens := elapsed * tb.refillRate
// Update token count, capped at maxTokens
tb.tokens = min(tb.tokens+newTokens, tb.maxTokens)
tb.lastRefillTime = now
}
// Allow checks if a request can proceed and records metrics
func (tb *PrometheusTokenBucket) Allow() bool {
	start := time.Now()
	tb.mu.Lock()
	defer tb.mu.Unlock()
	tb.refill()
	currentTokens.WithLabelValues(tb.limiterID).Set(tb.tokens)
	if tb.tokens >= 1.0 {
		tb.tokens--
		requestsTotal.WithLabelValues("allowed").Inc()
		requestLatency.WithLabelValues("allowed").Observe(time.Since(start).Seconds())
		return true
	}
	requestsTotal.WithLabelValues("rejected").Inc()
	requestLatency.WithLabelValues("rejected").Observe(time.Since(start).Seconds())
	return false
}
// min returns the minimum of two float64 values
func min(a, b float64) float64 {
if a < b {
return a
}
return b
}
func main() {
// Create a token bucket with Prometheus metrics
limiter := NewPrometheusTokenBucket("api_limiter", 10, 5)
// Create a simple handler with rate limiting
http.HandleFunc("/api", func(w http.ResponseWriter, r *http.Request) {
if !limiter.Allow() {
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
return
}
fmt.Fprintln(w, "API request successful")
})
// Expose Prometheus metrics
http.Handle("/metrics", promhttp.Handler())
// Start the server
fmt.Println("Server starting on :8080...")
fmt.Println("API endpoint: http://localhost:8080/api")
fmt.Println("Metrics endpoint: http://localhost:8080/metrics")
log.Fatal(http.ListenAndServe(":8080", nil))
}
Multi-Tier Rate Limiting
In many applications, different types of requests have different resource costs. Multi-tier rate limiting allows for more granular control:
package main
import (
"fmt"
"sync"
"time"
)
// RequestTier represents the resource cost tier of a request
type RequestTier int
const (
TierLow RequestTier = iota // Low-cost operations (e.g., reads)
TierMedium // Medium-cost operations (e.g., simple writes)
TierHigh // High-cost operations (e.g., complex queries)
)
// TierCost defines the token cost for each tier
var TierCost = map[RequestTier]float64{
TierLow: 1.0,
TierMedium: 5.0,
TierHigh: 10.0,
}
// MultiTierTokenBucket implements a token bucket with different costs per request tier
type MultiTierTokenBucket struct {
mu sync.Mutex
tokens float64
maxTokens float64
refillRate float64 // tokens per second
lastRefillTime time.Time
}
// NewMultiTierTokenBucket creates a new multi-tier token bucket rate limiter
func NewMultiTierTokenBucket(maxTokens, refillRate float64) *MultiTierTokenBucket {
return &MultiTierTokenBucket{
tokens: maxTokens,
maxTokens: maxTokens,
refillRate: refillRate,
lastRefillTime: time.Now(),
}
}
// refill adds tokens to the bucket based on elapsed time
func (tb *MultiTierTokenBucket) refill() {
now := time.Now()
elapsed := now.Sub(tb.lastRefillTime).Seconds()
// Calculate tokens to add based on elapsed time
newTokens := elapsed * tb.refillRate
// Update token count, capped at maxTokens
tb.tokens = min(tb.tokens+newTokens, tb.maxTokens)
tb.lastRefillTime = now
}
// Allow checks if a request of the specified tier can proceed
func (tb *MultiTierTokenBucket) Allow(tier RequestTier) bool {
tb.mu.Lock()
defer tb.mu.Unlock()
tb.refill()
	// Look up the cost for this tier (an unknown tier maps to zero cost,
	// so validate tiers at the call boundary in production code)
cost := TierCost[tier]
if tb.tokens >= cost {
tb.tokens -= cost
return true
}
return false
}
// min returns the minimum of two float64 values
func min(a, b float64) float64 {
if a < b {
return a
}
return b
}
func main() {
// Create a multi-tier token bucket: 20 max tokens, refill rate of 5 tokens per second
limiter := NewMultiTierTokenBucket(20, 5)
// Simulate different types of requests
requests := []struct {
ID int
Tier RequestTier
}{
{1, TierLow}, // Cost: 1
{2, TierMedium}, // Cost: 5
{3, TierLow}, // Cost: 1
{4, TierHigh}, // Cost: 10
{5, TierMedium}, // Cost: 5
{6, TierLow}, // Cost: 1
}
for _, req := range requests {
allowed := limiter.Allow(req.Tier)
fmt.Printf("Request %d (Tier: %v, Cost: %.1f): %v\n",
req.ID, req.Tier, TierCost[req.Tier], allowed)
// If this was rejected, wait a bit and try again
if !allowed {
fmt.Println("Waiting for token refill...")
time.Sleep(time.Second)
allowed = limiter.Allow(req.Tier)
fmt.Printf("Retry Request %d: %v\n", req.ID, allowed)
}
}
}
Integration with Web Services
Rate limiting is most commonly applied in web services. Let’s explore how to integrate rate limiting into HTTP servers and middleware.
HTTP Middleware for Rate Limiting
Here’s a complete implementation of rate limiting middleware for standard Go HTTP servers:
package main
import (
"fmt"
"log"
"net/http"
"sync"
"time"
)
// RateLimiterMiddleware provides rate limiting for HTTP handlers
type RateLimiterMiddleware struct {
limiter *IPRateLimiter
}
// NewRateLimiterMiddleware creates a new rate limiter middleware
func NewRateLimiterMiddleware(rps int, burst int) *RateLimiterMiddleware {
return &RateLimiterMiddleware{
limiter: NewIPRateLimiter(rps, burst),
}
}
// Middleware is the HTTP middleware function that applies rate limiting
func (m *RateLimiterMiddleware) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get client IP
ip := getClientIP(r)
// Check if the request is allowed
if !m.limiter.Allow(ip) {
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
return
}
// Call the next handler
next.ServeHTTP(w, r)
})
}
// getClientIP extracts the client IP from the request
func getClientIP(r *http.Request) string {
	// X-Forwarded-For may contain a comma-separated chain of addresses;
	// the original client is the first entry. Only trust these headers
	// when a proxy you control sets them, or clients can spoof their IP.
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		return strings.TrimSpace(strings.Split(xff, ",")[0])
	}
	// Try the X-Real-IP header next
	if ip := r.Header.Get("X-Real-IP"); ip != "" {
		return ip
	}
	// Fall back to RemoteAddr (note: this includes the port)
	return r.RemoteAddr
}
// IPRateLimiter manages rate limiters for different IP addresses
type IPRateLimiter struct {
mu sync.Mutex
limiters map[string]*TokenBucket
rps int // Requests per second
burst int // Maximum burst size
cleanupInterval time.Duration
}
// NewIPRateLimiter creates a new IP-based rate limiter
func NewIPRateLimiter(rps, burst int) *IPRateLimiter {
limiter := &IPRateLimiter{
limiters: make(map[string]*TokenBucket),
rps: rps,
burst: burst,
cleanupInterval: 10 * time.Minute,
}
// Start cleanup goroutine
go limiter.cleanup()
return limiter
}
// cleanup periodically removes inactive limiters to prevent memory leaks
func (i *IPRateLimiter) cleanup() {
ticker := time.NewTicker(i.cleanupInterval)
defer ticker.Stop()
for range ticker.C {
i.mu.Lock()
// Remove limiters that haven't been used recently
now := time.Now()
		for ip, limiter := range i.limiters {
			// Read lastUsed under the bucket's own lock to avoid a data race
			limiter.mu.Lock()
			idle := now.Sub(limiter.lastUsed) > 30*time.Minute
			limiter.mu.Unlock()
			if idle {
				delete(i.limiters, ip)
			}
		}
i.mu.Unlock()
}
}
// Allow checks if a request from the given IP should be allowed
func (i *IPRateLimiter) Allow(ip string) bool {
i.mu.Lock()
// Get or create limiter for this IP
limiter, exists := i.limiters[ip]
if !exists {
limiter = NewTokenBucket(float64(i.burst), float64(i.rps))
i.limiters[ip] = limiter
}
i.mu.Unlock()
return limiter.Allow()
}
// TokenBucket implements the token bucket algorithm for rate limiting
type TokenBucket struct {
mu sync.Mutex
tokens float64
maxTokens float64
refillRate float64 // tokens per second
lastRefillTime time.Time
lastUsed time.Time
}
// NewTokenBucket creates a new token bucket rate limiter
func NewTokenBucket(maxTokens, refillRate float64) *TokenBucket {
now := time.Now()
return &TokenBucket{
tokens: maxTokens,
maxTokens: maxTokens,
refillRate: refillRate,
lastRefillTime: now,
lastUsed: now,
}
}
// refill adds tokens to the bucket based on elapsed time
func (tb *TokenBucket) refill() {
now := time.Now()
elapsed := now.Sub(tb.lastRefillTime).Seconds()
// Calculate tokens to add based on elapsed time
newTokens := elapsed * tb.refillRate
// Update token count, capped at maxTokens
if newTokens > 0 {
tb.tokens = min(tb.tokens+newTokens, tb.maxTokens)
tb.lastRefillTime = now
}
}
// Allow checks if a request can proceed and consumes a token if allowed
func (tb *TokenBucket) Allow() bool {
tb.mu.Lock()
defer tb.mu.Unlock()
tb.refill()
tb.lastUsed = time.Now()
if tb.tokens >= 1.0 {
tb.tokens--
return true
}
return false
}
// min returns the minimum of two float64 values
func min(a, b float64) float64 {
if a < b {
return a
}
return b
}
func main() {
// Create a rate limiter middleware: 5 requests per second, burst of 10
rateLimiter := NewRateLimiterMiddleware(5, 10)
// Create a simple handler
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, World!")
})
// Apply the middleware
http.Handle("/", rateLimiter.Middleware(handler))
// Start the server
fmt.Println("Server starting on :8080...")
log.Fatal(http.ListenAndServe(":8080", nil))
}
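When rejecting a request, it's good practice to tell clients how to back off. Here's a hedged sketch of a rejection helper you could call from the middleware above; Retry-After is a standard HTTP header, while the X-RateLimit-Limit name follows common API convention rather than a formal standard, and strconv would need to be imported:
// rejectWithHeaders writes a 429 response with conventional headers so
// well-behaved clients know the limit and when it is safe to retry.
func rejectWithHeaders(w http.ResponseWriter, limit int, retryAfter time.Duration) {
	w.Header().Set("X-RateLimit-Limit", strconv.Itoa(limit))
	w.Header().Set("Retry-After", strconv.Itoa(int(retryAfter.Seconds())))
	http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
}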