sdk/data/azcosmos/async_cache.go (123 lines of code) (raw):
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package azcosmos
import (
"context"
"reflect"
"sync"
)
// invalidCacheValue is the error returned when an entry loaded from the cache
// cannot be converted to a cacheValue.
type invalidCacheValue struct{}

// Error implements the error interface.
func (invalidCacheValue) Error() string {
	return "Invalid cache value"
}
// asyncCache is a key/value cache whose entries are kept in a sync.Map.
// The methods on this type store cacheValue entries under caller-supplied keys.
type asyncCache struct {
	// values maps a cache key to its cacheValue entry.
	values sync.Map
}
// cacheValue is the entry stored in asyncCache for a single key. It carries
// both the last accepted value and the bookkeeping needed to (re)compute it.
type cacheValue struct {
	// value is the most recently accepted value for the key.
	value interface{}
	// obsoleteValue is the value the caller considers stale; a task result
	// that is deeply equal to it causes the task to be run again.
	obsoleteValue interface{}
	// complete is set once a task result has been accepted and stored.
	complete bool
	// fn is the task used to compute (or recompute) the value.
	fn cacheValueTask
	// ch delivers the result of the task currently associated with the entry.
	ch <-chan *cacheTaskResult
	// err is the error that accompanied the most recently accepted result.
	err error
}
type (
	// cacheValueTask computes a value for a cache entry and reports the
	// outcome as a cacheTaskResult.
	cacheValueTask func() *cacheTaskResult

	// cacheTaskResult is the outcome of a cacheValueTask: the computed value
	// together with the error, if any, produced while computing it.
	cacheTaskResult struct {
		value interface{}
		err   error
	}
)
// newAsyncCache returns an empty asyncCache that is ready for use.
func newAsyncCache() *asyncCache {
	return new(asyncCache)
}
// setValue stores value under key, replacing any existing entry. The new
// entry carries only the value: it has no task, no result channel, and is not
// flagged as complete.
func (ac *asyncCache) setValue(key interface{}, value interface{}) {
	entry := cacheValue{value: value}
	ac.values.Store(key, entry)
}
// set starts singleValueInit in the background, stores a pending entry for
// key that points at the running task, and then waits for the entry to be
// resolved.
//
// It returns whatever error awaitCacheValue reports, which includes ctx's
// error when ctx is done before the task delivers a result.
func (ac *asyncCache) set(key interface{}, singleValueInit cacheValueTask, ctx context.Context) error {
	pending := cacheValue{
		fn: singleValueInit,
		ch: ac.execCacheValueTask(singleValueInit),
	}
	ac.values.Store(key, pending)

	_, err := ac.awaitCacheValue(key, ctx)
	return err
}
// getValue returns the value currently stored for key without waiting for any
// running task. The boolean is false when key has no entry or when the stored
// entry is not a cacheValue.
func (ac *asyncCache) getValue(key interface{}) (interface{}, bool) {
	raw, found := ac.values.Load(key)
	if !found {
		return nil, false
	}
	if entry, isEntry := raw.(cacheValue); isEntry {
		return entry.value, true
	}
	return nil, false
}
// getAsync records obsoleteValue and singleValueInit on the existing entry for
// key and, when that entry is already complete, starts singleValueInit in the
// background to refresh it.
//
// It does nothing and returns nil when key has no entry. It returns
// invalidCacheValue when the stored entry is not a cacheValue. The refreshed
// value itself is not returned here; it is delivered on the entry's channel.
func (ac *asyncCache) getAsync(key interface{}, obsoleteValue interface{}, singleValueInit cacheValueTask) error {
	raw, found := ac.values.Load(key)
	if !found {
		return nil
	}
	entry, isEntry := raw.(cacheValue)
	if !isEntry {
		return invalidCacheValue{}
	}

	// Both the complete and the in-flight case remember the new task and the
	// value the caller regards as stale.
	entry.fn = singleValueInit
	entry.obsoleteValue = obsoleteValue

	// Only a complete entry needs a new task; an incomplete one keeps the
	// channel it already has.
	if entry.complete {
		entry.complete = false
		entry.ch = ac.execCacheValueTask(singleValueInit)
	}

	ac.values.Store(key, entry)
	return nil
}
// remove deletes the entry stored under key, if there is one.
func (ac *asyncCache) remove(key interface{}) {
	ac.values.Delete(key)
}
// clear walks every entry in the cache and deletes it.
func (ac *asyncCache) clear() {
	dropEntry := func(k, _ interface{}) bool {
		ac.values.Delete(k)
		// Returning true keeps Range iterating over the remaining keys.
		return true
	}
	ac.values.Range(dropEntry)
}
// execCacheValueTask runs t on a new goroutine and returns a channel that
// yields t's result exactly once and is then closed.
//
// The channel has a buffer of one so the goroutine can always deliver its
// result and exit, even when nobody ever receives from the channel (for
// example because the waiter returned early on a cancelled context). With an
// unbuffered channel such a goroutine would stay blocked on the send forever.
func (ac *asyncCache) execCacheValueTask(t cacheValueTask) <-chan *cacheTaskResult {
	ch := make(chan *cacheTaskResult, 1)
	go func() {
		// Close after the send so that receives following the first one
		// observe a nil result instead of blocking.
		defer close(ch)
		ch <- t()
	}()
	return ch
}
// awaitCacheValue waits for the task attached to key's entry to deliver a
// result and returns the value and error that end up stored for the key.
//
// Outcomes:
//   - key has no entry: returns (nil, nil).
//   - the stored entry is not a cacheValue: returns invalidCacheValue.
//   - the entry has no result channel (it was stored directly, without a
//     task): returns the stored value and error immediately, because a receive
//     on a nil channel would never complete.
//   - ctx is done before a result arrives: returns (nil, ctx.Err()).
//   - the channel is already drained and closed: returns the stored value and
//     error.
//   - a result arrives that carries an error, or whose value differs from the
//     entry's obsoleteValue: the result is stored, the entry is flagged
//     complete, and the result is returned.
//   - a successful result arrives whose value is deeply equal to
//     obsoleteValue: the entry's task is started again and the wait repeats.
func (ac *asyncCache) awaitCacheValue(key interface{}, ctx context.Context) (interface{}, error) {
	// Loop instead of recursing so repeated refreshes do not grow the stack.
	// The entry is reloaded on every pass, as each retry stores a new channel.
	for {
		raw, found := ac.values.Load(key)
		if !found {
			return nil, nil
		}
		entry, isEntry := raw.(cacheValue)
		if !isEntry {
			return nil, invalidCacheValue{}
		}
		if entry.ch == nil {
			// No task was ever attached to this entry; there is nothing to
			// wait for.
			return entry.value, entry.err
		}

		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case result := <-entry.ch:
			if result == nil {
				// The channel is closed and its single result was consumed
				// earlier; the entry already holds the outcome.
				return entry.value, entry.err
			}
			// A failed task is never treated as "still obsolete": comparing
			// its (typically nil) value with a nil obsoleteValue would
			// otherwise re-run the task in a tight loop and hide the error.
			if result.err != nil || !reflect.DeepEqual(entry.obsoleteValue, result.value) {
				entry.value = result.value
				entry.err = result.err
				entry.complete = true
				ac.values.Store(key, entry)
				return entry.value, entry.err
			}
			// The task produced the value the caller already has; run it
			// again and wait for the next result.
			entry.ch = ac.execCacheValueTask(entry.fn)
			ac.values.Store(key, entry)
		}
	}
}