pkg/partmap/partmap.go (137 lines of code) (raw):

package partmap

import (
	"fmt"
	"sync"

	"github.com/cespare/xxhash"
)

// Map is a concurrent map with string keys and generic values, partitioned for
// better concurrency. Each key is assigned to exactly one partition, and each
// partition is guarded by its own lock.
type Map[V any] struct {
	partitions []*syncMap[V]
}

// syncMap is a thread-safe map with string keys and generic values. All access
// to m is expected to happen while holding mu.
type syncMap[V any] struct {
	mu sync.RWMutex
	m  map[string]V
}

// apply calls fn for each key-value pair in the map, stopping at and returning
// the first error fn reports. Iteration order is unspecified.
//
// The read lock is held for the whole iteration, so fn must not call a method
// that takes this syncMap's write lock (set, delete, mutate, or a getOrCreate
// that has to insert); doing so would deadlock.
func (s *syncMap[V]) apply(fn func(key string, value V) error) error {
	s.mu.RLock()
	defer s.mu.RUnlock()

	for key, value := range s.m {
		if err := fn(key, value); err != nil {
			return err
		}
	}
	return nil
}

// get retrieves the value for a given key. The boolean reports whether the key
// was present; when it is false the returned value is the zero value of V.
func (s *syncMap[V]) get(key string) (V, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	v, ok := s.m[key]
	return v, ok
}

// set stores value under key only if the key is not already present; an
// existing entry is never overwritten. It returns the value held in the map
// after the call: the existing value when the key was present, otherwise value.
func (s *syncMap[V]) set(key string, value V) V {
	s.mu.Lock()
	defer s.mu.Unlock()

	if existing, ok := s.m[key]; ok {
		return existing
	}
	s.m[key] = value
	return value
}

// delete removes the entry for a given key. It returns the removed value and
// true, or the zero value of V and false when the key was not present.
func (s *syncMap[V]) delete(key string) (V, bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// A failed lookup already yields the zero value of V, so v can be returned
	// as-is in both cases.
	v, ok := s.m[key]
	if ok {
		delete(s.m, key)
	}
	return v, ok
}

// getOrCreate retrieves the value for a given key, or creates it with fn and
// stores it when the key is missing.
//
// fn is invoked at most once, while the write lock is held, and only if the key
// is still missing at that point. If fn returns an error, nothing is stored and
// the zero value of V is returned with that error. A nil fn yields an error
// when a value has to be created; it is never called for an existing key.
func (s *syncMap[V]) getOrCreate(key string, fn func() (V, error)) (V, error) {
	// Fast path: an existing key only needs the read lock.
	s.mu.RLock()
	v, ok := s.m[key]
	s.mu.RUnlock()
	if ok {
		return v, nil
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// Re-check under the write lock: another goroutine may have inserted the
	// key between the RUnlock above and the Lock here.
	if v, ok := s.m[key]; ok {
		return v, nil
	}

	var zero V
	if fn == nil {
		// Same error as mutate, rather than a nil-function panic.
		return zero, fmt.Errorf("fn is nil")
	}

	created, err := fn()
	if err != nil {
		return zero, err
	}
	s.m[key] = created
	return created, nil
}

// mutate replaces the value stored under key with the result of fn applied to
// the current value. When the key is absent, fn receives the zero value of V.
//
// fn runs while the write lock is held. If fn returns an error the map is left
// unchanged and that error is returned. A nil fn is rejected with an error.
func (s *syncMap[V]) mutate(key string, fn func(value V) (V, error)) error {
	if fn == nil {
		return fmt.Errorf("fn is nil")
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// Indexing a missing key yields the zero value of V.
	newValue, err := fn(s.m[key])
	if err != nil {
		return err
	}
	s.m[key] = newValue
	return nil
}

// NewMap creates a new Map with the specified number of partitions.
func NewMap[V any](partitions int) *Map[V] { parts := make([]*syncMap[V], partitions) for i := range parts { parts[i] = &syncMap[V]{ m: make(map[string]V), } } return &Map[V]{ partitions: parts, } } // GetOrCreate retrieves the value for a given key or creates it using the provided function. func (m *Map[V]) GetOrCreate(key string, fn func() (V, error)) (V, error) { idx := m.partition(key) return m.partitions[idx].getOrCreate(key, fn) } // Get retrieves the value for a given key. func (m *Map[V]) Get(key string) (V, bool) { idx := m.partition(key) return m.partitions[idx].get(key) } // Set sets the value for a given key. func (m *Map[V]) Set(key string, value V) V { idx := m.partition(key) return m.partitions[idx].set(key, value) } // Delete removes the value for a given key. func (m *Map[V]) Delete(key string) (V, bool) { idx := m.partition(key) return m.partitions[idx].delete(key) } // Mutate applies a function to the value for a given key and sets the returned value atomically. func (m *Map[V]) Mutate(key string, fn func(value V) (V, error)) error { idx := m.partition(key) return m.partitions[idx].mutate(key, fn) } // partition calculates the partition index for a given key. func (m *Map[V]) partition(key string) uint64 { return xxhash.Sum64String(key) % uint64(len(m.partitions)) } // Each applies a function to each key-value pair in the map. func (m *Map[V]) Each(fn func(key string, value V) error) error { for _, partition := range m.partitions { if err := partition.apply(fn); err != nil { return err } } return nil } // Count returns the number of key-value pairs in the map. func (m *Map[V]) Count() int { var count int for _, partition := range m.partitions { partition.mu.RLock() count += len(partition.m) partition.mu.RUnlock() } return count }