The ABA Problem and Solutions

One of the challenges in lock-free programming is the ABA problem, where a value changes from A to B and back to A, potentially causing incorrect behavior in compare-and-swap operations:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// demonstrateABAProblem reproduces the ABA problem on a lock-free stack whose
// top is an atomic head pointer.
//
// Goroutine 1 starts a pop: it snapshots the head (node 1) and its successor
// (node 2) and is then stalled just before its compare-and-swap. While it is
// stalled, goroutine 2 pops node 1, pops node 2 and pushes node 1 back, so the
// head pointer is node 1 again although the stack below it has changed.
// Goroutine 1's CAS compares only the pointer, succeeds, and installs the
// stale successor: node 2, which goroutine 2 already removed, is back on the
// stack.
//
// The interleaving is forced with channels, so the result is the same on
// every run. Each step and the final (corrupted) stack are printed to
// standard output.
func demonstrateABAProblem() {
	// node is one element of the singly linked stack.
	type node struct {
		value int
		next  *node
	}

	// stallTimeout bounds how long goroutine 1 stays stalled, so the demo
	// cannot hang if goroutine 2 never signals.
	const stallTimeout = time.Second

	// head points at the top of the stack; nil means the stack is empty.
	var head atomic.Pointer[node]

	// Initial stack, top first: 1 -> 2 -> 3.
	nodeC := &node{value: 3}
	nodeB := &node{value: 2, next: nodeC}
	nodeA := &node{value: 1, next: nodeB}
	head.Store(nodeA)

	// label renders a node's value for printing and tolerates nil.
	label := func(n *node) string {
		if n == nil {
			return "<nil>"
		}
		return fmt.Sprint(n.value)
	}

	// popWith removes and returns the top node, or nil if the stack is empty.
	// beforeCAS, when non-nil, runs between taking the snapshot and the CAS;
	// that gap is the window in which ABA can occur.
	popWith := func(beforeCAS func(oldHead, newHead *node)) *node {
		for {
			oldHead := head.Load()
			if oldHead == nil {
				return nil
			}
			newHead := oldHead.next
			if beforeCAS != nil {
				beforeCAS(oldHead, newHead)
			}
			// The CAS only checks that head is the same pointer as before;
			// it cannot tell "still node A" from "node A again".
			if head.CompareAndSwap(oldHead, newHead) {
				return oldHead
			}
		}
	}

	// pop is the ordinary, unstalled pop.
	pop := func() *node { return popWith(nil) }

	// push places n on top of the stack.
	push := func(n *node) {
		for {
			oldHead := head.Load()
			n.next = oldHead
			if head.CompareAndSwap(oldHead, n) {
				return
			}
		}
	}

	snapshotTaken := make(chan struct{})    // closed once goroutine 1 has read head and next
	interferenceDone := make(chan struct{}) // closed once goroutine 2 has finished

	var wg sync.WaitGroup
	wg.Add(2)

	// Goroutine 1: a pop that is stalled between its snapshot and its CAS.
	go func() {
		defer wg.Done()

		stalled := false
		popped := popWith(func(oldHead, newHead *node) {
			if stalled {
				return // stall only the first attempt; retries run straight through
			}
			stalled = true
			fmt.Printf("Goroutine 1: Read head %s (next %s), stalled before CAS\n", label(oldHead), label(newHead))
			close(snapshotTaken)
			select {
			case <-interferenceDone:
			case <-time.After(stallTimeout):
				fmt.Println("Goroutine 1: Timed out waiting for goroutine 2")
			}
		})
		fmt.Printf("Goroutine 1: Resumed and popped %s\n", label(popped))
	}()

	// Goroutine 2: pop two nodes, then push the first one back, so that the
	// head pointer returns to its earlier value (A -> B -> A).
	go func() {
		defer wg.Done()
		defer close(interferenceDone)

		<-snapshotTaken

		first := pop()
		fmt.Printf("Goroutine 2: Popped %s\n", label(first))

		second := pop()
		fmt.Printf("Goroutine 2: Popped %s\n", label(second))

		if first != nil {
			push(first)
			fmt.Printf("Goroutine 2: Pushed %s back\n", label(first))
		}
	}()

	wg.Wait()

	// Print the final stack, top first.
	fmt.Println("Final stack:")
	for curr := head.Load(); curr != nil; curr = curr.next {
		fmt.Printf("%d ", curr.value)
	}
	fmt.Println()
	fmt.Println("Expected stack: 3 (node 2 was already popped by goroutine 2)")
}

// abaWithVersionCounters demonstrates the version-counter ("tagged pointer")
// fix for ABA on a small stack.
//
// The head is a single atomic 64-bit word: the upper 32 bits hold a version
// that is incremented on every successful push or pop, and the lower 32 bits
// hold the index of the top node in the nodes slice (nilIndex when the stack
// is empty). Because every modification changes the version, a CAS that was
// prepared against an older head fails even if the same index is on top
// again.
//
// Only the head word is updated atomically; the nodes slice itself is not
// synchronized, and this function exercises the stack from a single
// goroutine. Progress is printed to standard output.
func abaWithVersionCounters() {
	// node is one element of the singly linked stack.
	type node struct {
		value int
		next  *node
	}

	const (
		indexBits        = 32
		indexMask        = 1<<indexBits - 1
		nilIndex  uint64 = indexMask // index meaning "no node" / empty stack
	)

	// pack combines a version and a node index into one head word.
	pack := func(version, index uint64) uint64 {
		return version<<indexBits | index&indexMask
	}
	// unpack splits a head word into its version and node index.
	unpack := func(word uint64) (version, index uint64) {
		return word >> indexBits, word & indexMask
	}

	// Initial stack, top first: 1 -> 2 -> 3.
	nodes := []*node{{value: 1}, {value: 2}, {value: 3}}
	nodes[0].next = nodes[1]
	nodes[1].next = nodes[2]

	// indexOf returns the position of target in nodes, or nilIndex when
	// target is nil or not present.
	indexOf := func(target *node) uint64 {
		if target == nil {
			return nilIndex
		}
		for i, n := range nodes {
			if n == target {
				return uint64(i)
			}
		}
		return nilIndex
	}

	// Head word: version 1, index 0.
	var head atomic.Uint64
	head.Store(pack(1, 0))

	// pop removes and returns the top node, or nil if the stack is empty.
	pop := func() *node {
		for {
			oldHead := head.Load()
			version, index := unpack(oldHead)
			if index >= uint64(len(nodes)) {
				return nil
			}

			top := nodes[index]
			// The version bump is what defeats ABA: even if the same index
			// becomes the top again later, the head word will differ.
			newHead := pack(version+1, indexOf(top.next))
			if head.CompareAndSwap(oldHead, newHead) {
				return top
			}
		}
	}

	// push places a new node holding value on top of the stack.
	push := func(value int) {
		// Allocate and register the node once, outside the retry loop, so a
		// failed CAS does not append it to nodes again.
		newNode := &node{value: value}
		nodes = append(nodes, newNode)
		newIndex := uint64(len(nodes) - 1)

		for {
			oldHead := head.Load()
			version, index := unpack(oldHead)

			newNode.next = nil
			if index < uint64(len(nodes)) {
				newNode.next = nodes[index]
			}

			if head.CompareAndSwap(oldHead, pack(version+1, newIndex)) {
				return
			}
		}
	}

	// Test the versioned stack
	fmt.Println("Testing versioned stack (ABA solution):")

	// Pop the first node
	node1 := pop()
	fmt.Printf("Popped: %d\n", node1.value)

	// Pop the second node
	node2 := pop()
	fmt.Printf("Popped: %d\n", node2.value)

	// Push the first value back (would cause ABA in a naive implementation)
	push(node1.value)
	fmt.Printf("Pushed: %d\n", node1.value)

	// Push a new value
	push(4)
	fmt.Printf("Pushed: %d\n", 4)

	// Print the final stack, top first, tagged with the current head version.
	fmt.Println("Final stack:")
	version, index := unpack(head.Load())
	if index < uint64(len(nodes)) {
		for n := nodes[index]; n != nil; n = n.next {
			fmt.Printf("%d (v%d) ", n.value, version)
		}
	}
	fmt.Println()
}

// main runs the ABA demonstration followed by the version-counter solution,
// printing a banner before each one.
func main() {
	steps := []struct {
		banner string
		run    func()
	}{
		{"Demonstrating the ABA problem:", demonstrateABAProblem},
		{"\nDemonstrating ABA solution with version counters:", abaWithVersionCounters},
	}

	for _, step := range steps {
		fmt.Println(step.banner)
		step.run()
	}
}

This example demonstrates the ABA problem and a solution using version counters. The ABA problem occurs when a thread reads a value A, gets preempted, and another thread changes the value to B and then back to A. When the first thread resumes, its compare-and-swap succeeds because the value is A again, so it cannot detect that the underlying state changed in between. The solution is to pair the value with a version counter (a “tagged pointer”) that is incremented on every modification, so that even if the pointer value is the same, the version differs and the stale compare-and-swap fails.

Compare-and-Swap Patterns

Compare-and-swap (CAS) operations are the foundation of lock-free programming. Let’s explore common patterns and techniques for using CAS effectively.

Basic CAS Loop Pattern

The most common pattern in lock-free programming is the CAS loop:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// casLoopPattern demonstrates the basic compare-and-swap retry loop on an
// atomic counter: first with two sequential updates, then with 100 goroutines
// each adding 1. The counter is printed after every stage.
func casLoopPattern() {
	var counter atomic.Int64
	counter.Store(10)

	// add applies delta to the counter. Each attempt reads the counter,
	// derives the desired value from what it saw, and publishes it only if
	// the counter is still unchanged; otherwise another goroutine got in
	// first and the attempt is repeated.
	add := func(delta int64) {
		for swapped := false; !swapped; {
			seen := counter.Load()
			swapped = counter.CompareAndSwap(seen, seen+delta)
		}
	}

	// Sequential updates.
	fmt.Printf("Initial value: %d\n", counter.Load())

	add(5)
	fmt.Printf("After +5: %d\n", counter.Load())

	add(-3)
	fmt.Printf("After -3: %d\n", counter.Load())

	// Concurrent updates: every goroutine goes through the same retry loop.
	const workers = 100
	var wg sync.WaitGroup
	wg.Add(workers)
	for i := 0; i < workers; i++ {
		go func() {
			defer wg.Done()
			add(1)
		}()
	}
	wg.Wait()

	fmt.Printf("After 100 concurrent +1 operations: %d\n", counter.Load())
}

// main runs the basic CAS loop demo and then the more advanced patterns, in
// order.
func main() {
	demos := []func(){
		casLoopPattern,
		// More advanced patterns
		backoffPattern,
		conditionalUpdate,
		multipleFieldUpdate,
	}

	for _, run := range demos {
		run()
	}
}

This example demonstrates the basic CAS loop pattern, which is the foundation of most lock-free algorithms. The pattern involves reading the current value, computing a new value based on the current one, and then attempting to update the value using CAS. If the CAS fails, the loop retries with the updated current value.

Exponential Backoff Pattern

In high-contention scenarios, adding a backoff strategy can improve performance:

// backoffPattern demonstrates a CAS retry loop with capped exponential
// backoff. It starts 1000 goroutines that each add 1 to a shared atomic
// counter, then prints the final count and the elapsed time.
func backoffPattern() {
	const (
		initialBackoff = time.Microsecond // wait after the first failed CAS
		maxBackoff     = time.Millisecond // upper bound on a single wait
		updates        = 1000             // number of concurrent updaters
	)

	var value atomic.Int64

	// updateWithBackoff adds delta to value, sleeping for a doubling
	// interval (capped at maxBackoff) after each failed CAS.
	updateWithBackoff := func(delta int64) {
		backoff := initialBackoff

		for {
			// Read the current value and try to publish current+delta.
			current := value.Load()
			if value.CompareAndSwap(current, current+delta) {
				return
			}

			// The CAS lost a race. Sleep instead of spinning so the
			// goroutine gives up its processor while it backs off.
			time.Sleep(backoff)

			// Increase backoff exponentially, up to the maximum.
			backoff *= 2
			if backoff > maxBackoff {
				backoff = maxBackoff
			}
		}
	}

	// Test concurrent updates with backoff
	var wg sync.WaitGroup
	start := time.Now()

	wg.Add(updates)
	for i := 0; i < updates; i++ {
		go func() {
			defer wg.Done()
			updateWithBackoff(1)
		}()
	}

	wg.Wait()
	duration := time.Since(start)

	fmt.Printf("\nBackoff pattern:\n")
	fmt.Printf("  1000 concurrent updates with backoff: %d\n", value.Load())
	fmt.Printf("  Duration: %v\n", duration)
}

This example demonstrates a CAS loop with exponential backoff. Under high contention, waiting progressively longer after each failed CAS spreads the retries out over time, which reduces the number of failed CAS operations and can improve overall throughput.