Custom Synchronization Primitives

While Go’s standard library provides a solid foundation, some scenarios require specialized synchronization primitives. Let’s explore how to build custom primitives that address specific needs.

Countdown Latch

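A countdown latch lets one or more goroutines block until a fixed number of operations have completed. Each completed operation calls CountDown; once the count reaches zero the latch stays open and every waiter is released: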

package main

import (
	"fmt"
	"sync"
	"time"
)

// CountdownLatch is a synchronization aid that allows one or more goroutines
// to wait until a set of operations being performed in other goroutines completes.
//
// The latch starts at a fixed count. Each call to CountDown decrements it, and
// once it reaches zero every current and future call to Await or TryAwait
// returns immediately. The count is never reset. Create latches with
// NewCountdownLatch; the zero value is not usable.
type CountdownLatch struct {
	mu    sync.Mutex
	count int           // remaining count; guarded by mu
	done  chan struct{} // closed exactly once, when the latch opens
}

// NewCountdownLatch creates a new countdown latch initialized with the given count.
// A count of zero or less yields a latch that is already open.
func NewCountdownLatch(count int) *CountdownLatch {
	l := &CountdownLatch{
		count: count,
		done:  make(chan struct{}),
	}
	if count <= 0 {
		close(l.done)
	}
	return l
}

// CountDown decrements the count of the latch, releasing all waiting goroutines
// when the count reaches zero. Calls made after the count has reached zero
// are no-ops.
func (l *CountdownLatch) CountDown() {
	l.mu.Lock()
	defer l.mu.Unlock()

	if l.count <= 0 {
		return
	}

	l.count--
	if l.count == 0 {
		// Closing the channel wakes every receiver at once; the guard above
		// guarantees this happens only once.
		close(l.done)
	}
}

// Await causes the current goroutine to wait until the latch has counted down to zero.
func (l *CountdownLatch) Await() {
	<-l.done
}

// TryAwait causes the current goroutine to wait until the latch has counted
// down to zero or the specified timeout elapses, whichever happens first.
// It reports whether the latch reached zero. A non-positive timeout checks
// the latch without blocking.
func (l *CountdownLatch) TryAwait(timeout time.Duration) bool {
	// Fast path: an already-open latch succeeds regardless of timeout.
	select {
	case <-l.done:
		return true
	default:
	}
	if timeout <= 0 {
		return false
	}

	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case <-l.done:
		return true
	case <-timer.C:
		return false
	}
}

// GetCount returns the current count.
func (l *CountdownLatch) GetCount() int {
	l.mu.Lock()
	defer l.mu.Unlock()
	return l.count
}

// demonstrateCountdownLatch starts two goroutines that wait on a latch of
// count 3: one waits indefinitely, the other with a 2-second timeout. The
// calling goroutine then counts the latch down every 500ms. The function
// returns only after both waiters have finished printing their results.
func demonstrateCountdownLatch() {
	fmt.Println("\n=== Countdown Latch ===")

	// Create a countdown latch with count 3
	latch := NewCountdownLatch(3)

	// Track the waiter goroutines so the demo does not return (and, when
	// called from main, end the program) before they print. A fixed sleep
	// gives no such guarantee.
	var waiters sync.WaitGroup
	waiters.Add(2)

	// Start a goroutine that waits on the latch
	go func() {
		defer waiters.Done()
		fmt.Println("Waiter: Waiting for latch to reach zero")
		latch.Await()
		fmt.Println("Waiter: Latch reached zero, proceeding")
	}()

	// Start another goroutine that waits with timeout
	go func() {
		defer waiters.Done()
		fmt.Println("Timeout Waiter: Waiting with 2 second timeout")
		if latch.TryAwait(2 * time.Second) {
			fmt.Println("Timeout Waiter: Latch reached zero within timeout")
		} else {
			fmt.Println("Timeout Waiter: Timeout expired before latch reached zero")
		}
	}()

	// Simulate work being completed
	for i := 1; i <= 3; i++ {
		time.Sleep(500 * time.Millisecond)
		fmt.Printf("Main: Counting down latch (%d remaining)\n", latch.GetCount()-1)
		latch.CountDown()
	}

	waiters.Wait()
}

// main is the entry point of the countdown latch example; it simply runs
// the demonstration.
func main() {
	demonstrateCountdownLatch()
}

Cyclic Barrier

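A cyclic barrier lets a fixed group of goroutines wait for one another at a common execution point. When the last one arrives, an optional barrier action runs, all waiters are released, and the barrier resets itself so it can be reused for the next phase: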

package main

import (
	"fmt"
	"sync"
	"time"
)

// CyclicBarrier is a synchronization aid that allows a set of goroutines to all
// wait for each other to reach a common barrier point.
//
// The barrier is reusable. Once all parties have arrived, it runs the optional
// barrier action, releases every waiter and resets itself for the next
// generation. Create barriers with NewCyclicBarrier; the zero value is not
// usable.
type CyclicBarrier struct {
	parties       int    // goroutines that must arrive to trip the barrier
	count         int    // arrivals still needed this generation; guarded by mu
	generation    int    // incremented on every trip or Reset; guarded by mu
	barrierAction func() // run by the last arriving goroutine; may be nil
	mu            sync.Mutex
	cond          *sync.Cond
}

// NewCyclicBarrier creates a new cyclic barrier that will trip when the given
// number of parties are waiting upon it. If barrierAction is non-nil, it is
// run by the last goroutine to arrive, before any waiter is released.
//
// NewCyclicBarrier panics if parties is not positive, because such a barrier
// could never trip and every Await would block forever.
func NewCyclicBarrier(parties int, barrierAction func()) *CyclicBarrier {
	if parties <= 0 {
		panic("NewCyclicBarrier: parties must be positive")
	}
	barrier := &CyclicBarrier{
		parties:       parties,
		count:         parties,
		barrierAction: barrierAction,
	}
	barrier.cond = sync.NewCond(&barrier.mu)
	return barrier
}

// Await causes the current goroutine to wait until all parties have invoked
// Await on this barrier. If the current goroutine is the last to arrive, the
// barrier action is executed and the barrier is reset.
//
// It returns the arrival index within the current generation: 0 for the
// first goroutine to arrive, parties-1 for the last.
//
// The barrier action runs while the barrier's internal (non-reentrant) lock
// is held, so it must not call methods on the same barrier or it will
// deadlock. If the action panics, the barrier is still reset and the other
// waiters are released. The panic then propagates out of this Await call.
func (b *CyclicBarrier) Await() int {
	b.mu.Lock()
	defer b.mu.Unlock()

	generation := b.generation

	// Decrement count and check if we're the last to arrive
	b.count--
	index := b.parties - b.count - 1

	if b.count == 0 {
		// We're the last to arrive. The reset is deferred so a panicking
		// barrier action cannot leave the barrier stuck at count 0 with its
		// waiters stranded. It runs before the deferred Unlock above.
		defer b.trip()
		if b.barrierAction != nil {
			b.barrierAction()
		}
		return index
	}

	// Wait until the barrier is tripped or reset
	for generation == b.generation {
		b.cond.Wait()
	}

	return index
}

// trip starts a new generation. It restores the arrival count and wakes every
// goroutine waiting on the current generation. The caller must hold b.mu.
func (b *CyclicBarrier) trip() {
	b.count = b.parties
	b.generation++
	b.cond.Broadcast()
}

// Reset resets the barrier to its initial state. Goroutines currently blocked
// in Await are released and return as if the barrier had tripped; the barrier
// action is not run.
func (b *CyclicBarrier) Reset() {
	b.mu.Lock()
	defer b.mu.Unlock()
	b.trip()
}

// GetNumberWaiting returns the number of parties currently waiting at the barrier.
func (b *CyclicBarrier) GetNumberWaiting() int {
	b.mu.Lock()
	defer b.mu.Unlock()
	return b.parties - b.count
}

// demonstrateCyclicBarrier runs three workers through three phases each.
// Every phase ends at a shared CyclicBarrier, so no worker starts phase N+1
// until all workers have finished phase N. Worker id sleeps id*100ms per
// phase to simulate uneven work. The function returns once every worker is done.
func demonstrateCyclicBarrier() {
	fmt.Println("\n=== Cyclic Barrier ===")

	const (
		workerCount = 3
		phaseCount  = 3
	)

	barrier := NewCyclicBarrier(workerCount, func() {
		fmt.Println("Barrier Action: All workers reached the barrier, executing barrier action")
	})

	var wg sync.WaitGroup
	wg.Add(workerCount)

	// runWorker performs every phase for one worker, meeting the other
	// workers at the barrier after each phase.
	runWorker := func(id int) {
		defer wg.Done()

		workTime := time.Duration(id*100) * time.Millisecond
		for phase := 1; phase <= phaseCount; phase++ {
			fmt.Printf("Worker %d: Working on phase %d for %v\n", id, phase, workTime)
			time.Sleep(workTime)

			fmt.Printf("Worker %d: Reached barrier for phase %d\n", id, phase)
			arrival := barrier.Await()
			fmt.Printf("Worker %d: Crossed barrier for phase %d (arrival index: %d)\n", id, phase, arrival)

			// Brief pause before starting the next phase.
			time.Sleep(50 * time.Millisecond)
		}
	}

	for id := 1; id <= workerCount; id++ {
		go runWorker(id)
	}

	wg.Wait()
	fmt.Println("All workers completed all phases")
}

// main is the entry point of the cyclic barrier example; it simply runs
// the demonstration.
func main() {
	demonstrateCyclicBarrier()
}

These custom synchronization primitives demonstrate how Go’s basic primitives can be combined to create more specialized tools. Each addresses specific coordination patterns that aren’t directly provided by the standard library.