Custom Synchronization Primitives
While Go’s standard library provides a solid foundation, some scenarios require specialized synchronization primitives. Let’s explore how to build custom primitives that address specific needs.
Countdown Latch
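A countdown latch lets one or more goroutines block until a fixed number of operations, typically performed by other goroutines, have completed. Once the count reaches zero the latch stays open, so it is a one-shot primitive: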
package main
import (
"fmt"
"sync"
"time"
)
// CountdownLatch is a synchronization aid that allows one or more goroutines
// to wait until a set of operations being performed in other goroutines completes.
//
// The latch is one-shot: once the count reaches zero it stays open and further
// calls to CountDown have no effect. The zero value is an already-open latch.
type CountdownLatch struct {
	mu    sync.Mutex
	count int           // remaining count; guarded by mu
	done  chan struct{} // closed once count reaches zero; created lazily, guarded by mu
}

// NewCountdownLatch creates a new countdown latch initialized with the given count.
// A count of zero or less yields a latch that is already open.
func NewCountdownLatch(count int) *CountdownLatch {
	return &CountdownLatch{count: count}
}

// CountDown decrements the count of the latch, releasing all waiting goroutines
// when the count reaches zero. Calling it on an open latch is a no-op.
func (l *CountdownLatch) CountDown() {
	l.mu.Lock()
	defer l.mu.Unlock()
	if l.count <= 0 {
		return
	}
	l.count--
	if l.count == 0 && l.done != nil {
		// Closing the channel wakes every current waiter; future waiters see
		// a closed channel and return immediately.
		close(l.done)
	}
}

// doneChan returns a channel that is closed once the latch has counted down to
// zero. The channel is created on first use so the zero value needs no
// constructor; if the latch is already open it is returned pre-closed.
func (l *CountdownLatch) doneChan() <-chan struct{} {
	l.mu.Lock()
	defer l.mu.Unlock()
	if l.done == nil {
		l.done = make(chan struct{})
		if l.count <= 0 {
			close(l.done)
		}
	}
	return l.done
}

// Await causes the current goroutine to wait until the latch has counted down to zero.
func (l *CountdownLatch) Await() {
	<-l.doneChan()
}

// TryAwait causes the current goroutine to wait until the latch has counted down
// to zero or the specified timeout elapses, whichever comes first. It reports
// whether the latch reached zero. A non-positive timeout performs a
// non-blocking check.
func (l *CountdownLatch) TryAwait(timeout time.Duration) bool {
	done := l.doneChan()

	// Check first so an already-open latch always reports true, even when the
	// timer would fire at the same moment or the timeout is non-positive.
	select {
	case <-done:
		return true
	default:
	}
	if timeout <= 0 {
		return false
	}

	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case <-done:
		return true
	case <-timer.C:
		return false
	}
}

// GetCount returns the current count.
func (l *CountdownLatch) GetCount() int {
	l.mu.Lock()
	defer l.mu.Unlock()
	return l.count
}
// demonstrateCountdownLatch starts one goroutine blocking in Await and another
// using TryAwait with a 2-second timeout, then counts a latch down from 3 with
// 500ms between steps. It returns only after both waiters have printed their
// final message.
func demonstrateCountdownLatch() {
	fmt.Println("\n=== Countdown Latch ===")
	// Create a countdown latch with count 3
	latch := NewCountdownLatch(3)

	// waiters replaces a fixed sleep: without it, main could return (and the
	// program exit) before the waiter goroutines print their messages.
	var waiters sync.WaitGroup
	waiters.Add(2)

	// Start a goroutine that waits on the latch
	go func() {
		defer waiters.Done()
		fmt.Println("Waiter: Waiting for latch to reach zero")
		latch.Await()
		fmt.Println("Waiter: Latch reached zero, proceeding")
	}()

	// Start another goroutine that waits with timeout
	go func() {
		defer waiters.Done()
		fmt.Println("Timeout Waiter: Waiting with 2 second timeout")
		if latch.TryAwait(2 * time.Second) {
			fmt.Println("Timeout Waiter: Latch reached zero within timeout")
		} else {
			fmt.Println("Timeout Waiter: Timeout expired before latch reached zero")
		}
	}()

	// Simulate work being completed
	for i := 1; i <= 3; i++ {
		time.Sleep(500 * time.Millisecond)
		fmt.Printf("Main: Counting down latch (%d remaining)\n", latch.GetCount()-1)
		latch.CountDown()
	}

	waiters.Wait()
}

// main runs the countdown latch demonstration.
func main() {
	demonstrateCountdownLatch()
}
Cyclic Barrier
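A cyclic barrier makes a fixed group of goroutines wait for each other at a common execution point. Unlike a countdown latch, it resets automatically once every party has arrived, so the same barrier can be reused for successive phases: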
package main
import (
"fmt"
"sync"
"time"
)
// CyclicBarrier is a synchronization aid that allows a set of goroutines to all
// wait for each other to reach a common barrier point. When the last party
// arrives, the barrier trips, releases every waiter and resets itself, so it
// can be reused for the next round (generation).
type CyclicBarrier struct {
	parties       int    // number of Await calls needed to trip the barrier
	count         int    // parties still expected in the current generation; guarded by mu
	generation    int    // incremented on every trip or Reset; guarded by mu
	barrierAction func() // optional; run by the last arriving goroutine of each generation
	mu            sync.Mutex
	cond          *sync.Cond
}

// NewCyclicBarrier creates a new cyclic barrier that will trip when the given
// number of parties are waiting upon it. If barrierAction is non-nil, it is run
// once per trip by the last goroutine to arrive, before the others are released.
//
// NewCyclicBarrier panics if parties is less than 1, because such a barrier
// could never trip and every Await would block forever.
func NewCyclicBarrier(parties int, barrierAction func()) *CyclicBarrier {
	if parties < 1 {
		panic("NewCyclicBarrier: parties must be at least 1")
	}
	b := &CyclicBarrier{
		parties:       parties,
		count:         parties,
		barrierAction: barrierAction,
	}
	b.cond = sync.NewCond(&b.mu)
	return b
}

// Await causes the current goroutine to wait until all parties have invoked
// Await on this barrier. It returns the caller's arrival index within the
// current generation: 0 for the first goroutine to arrive, parties-1 for the last.
//
// The last goroutine to arrive runs the barrier action (if any) while holding
// the barrier's internal lock, so the action must not call methods on this
// barrier. If the action panics, the barrier is still reset and the other
// waiters are released before the panic propagates to that caller.
//
// Goroutines released by Reset return normally; they cannot distinguish a
// reset from a regular trip.
func (b *CyclicBarrier) Await() int {
	b.mu.Lock()
	defer b.mu.Unlock()

	generation := b.generation
	b.count--
	index := b.parties - b.count - 1

	if b.count == 0 {
		// We're the last to arrive. Deferred so the barrier is reset and the
		// waiters are woken even if barrierAction panics; it runs before the
		// deferred Unlock registered above.
		defer b.nextGeneration()
		if b.barrierAction != nil {
			b.barrierAction()
		}
		return index
	}

	// Wait until the barrier is tripped or reset; the generation check also
	// filters out spurious wakeups.
	for generation == b.generation {
		b.cond.Wait()
	}
	return index
}

// nextGeneration restores the full party count, starts a new generation and
// wakes every goroutine waiting on the previous one. The caller must hold b.mu.
func (b *CyclicBarrier) nextGeneration() {
	b.count = b.parties
	b.generation++
	b.cond.Broadcast()
}

// Reset resets the barrier to its initial state, releasing any goroutines
// currently waiting on it.
func (b *CyclicBarrier) Reset() {
	b.mu.Lock()
	defer b.mu.Unlock()
	b.nextGeneration()
}

// GetNumberWaiting returns the number of parties currently waiting at the barrier.
func (b *CyclicBarrier) GetNumberWaiting() int {
	b.mu.Lock()
	defer b.mu.Unlock()
	return b.parties - b.count
}
// demonstrateCyclicBarrier runs three workers through three phases each.
// Worker id spends id*100ms on every phase and then waits at a shared barrier,
// so no worker begins a new phase until all of them have finished the current
// one. The barrier action prints once per phase.
func demonstrateCyclicBarrier() {
	fmt.Println("\n=== Cyclic Barrier ===")

	const (
		numWorkers = 3
		numPhases  = 3
	)

	barrier := NewCyclicBarrier(numWorkers, func() {
		fmt.Println("Barrier Action: All workers reached the barrier, executing barrier action")
	})

	var wg sync.WaitGroup
	wg.Add(numWorkers)

	worker := func(id int) {
		defer wg.Done()
		// Each worker's simulated work time is fixed for all phases.
		workTime := time.Duration(id*100) * time.Millisecond
		for phase := 1; phase <= numPhases; phase++ {
			fmt.Printf("Worker %d: Working on phase %d for %v\n", id, phase, workTime)
			time.Sleep(workTime)

			fmt.Printf("Worker %d: Reached barrier for phase %d\n", id, phase)
			arrival := barrier.Await()
			fmt.Printf("Worker %d: Crossed barrier for phase %d (arrival index: %d)\n", id, phase, arrival)

			time.Sleep(50 * time.Millisecond) // brief pause before the next phase
		}
	}

	for id := 1; id <= numWorkers; id++ {
		go worker(id)
	}

	wg.Wait()
	fmt.Println("All workers completed all phases")
}

// main runs the cyclic barrier demonstration.
func main() {
	demonstrateCyclicBarrier()
}
These custom synchronization primitives demonstrate how Go’s basic primitives can be combined to create more specialized tools. Each addresses specific coordination patterns that aren’t directly provided by the standard library.