linux/handler.go (117 lines of code) (raw):

// +build linux // Copyright (c) Facebook, Inc. and its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package linux import ( "fmt" "sort" "sync" "github.com/google/go-tpm/tpmutil" ) const ( minValidHandle tpmutil.Handle = TPMOrgPersistentMin + 0x0a maxValidHandle tpmutil.Handle = TPMOrgPersistentMin + 0xff ) // KeyHandler stores mapping between key label and key handle. type KeyHandler struct { lock sync.Mutex handles map[string]tpmutil.Handle inFlightLocks map[string]*sync.Mutex inFlightHandles map[string]tpmutil.Handle } // NewKeyHandler returns an instance of KeyHandler. func NewKeyHandler(handles map[string]tpmutil.Handle) KeyHandler { if handles == nil { handles = make(map[string]tpmutil.Handle) } return KeyHandler{ handles: handles, inFlightLocks: make(map[string]*sync.Mutex), inFlightHandles: make(map[string]tpmutil.Handle), } } // Get returns handle for given keyID if present, otherwise // return next available handle and callback which should be called // after tpm key initialization. success indicates whether // tpm key initialization was successful or not. func (h *KeyHandler) Get(keyID string) (tpmutil.Handle, func(success bool), error) { lock := h.lockKey(keyID) h.lock.Lock() defer h.lock.Unlock() handle, ok := h.handles[keyID] if ok { return handle, func(bool) { lock.Unlock() }, nil } next, err := h.nextAvailable() if err != nil { return 0, nil, err } h.inFlightHandles[keyID] = next flush := func(success bool) { h.lock.Lock() if success { h.handles[keyID] = next } delete(h.inFlightLocks, keyID) delete(h.inFlightHandles, keyID) lock.Unlock() h.lock.Unlock() } return next, flush, nil } func (h *KeyHandler) lockKey(keyID string) *sync.Mutex { var ret *sync.Mutex for stop := false; !stop; { h.lock.Lock() lock := h.inFlightLocks[keyID] if lock == nil { lock = new(sync.Mutex) h.inFlightLocks[keyID] = lock } h.lock.Unlock() ret = lock ret.Lock() h.lock.Lock() lock = h.inFlightLocks[keyID] if ret == lock { stop = true } else { ret.Unlock() } h.lock.Unlock() } return ret } // Remove deletes handle with given keyID from KeyHandler if present. func (h *KeyHandler) Remove(keyID string) func(success bool) { lock := h.lockKey(keyID) return func(success bool) { h.lock.Lock() if success { delete(h.handles, keyID) } delete(h.inFlightHandles, keyID) lock.Unlock() h.lock.Unlock() } } func handlesToSlice(m map[string]tpmutil.Handle, dst []tpmutil.Handle) []tpmutil.Handle { for _, key := range m { if key >= minValidHandle { dst = append(dst, key) } } return dst } // nextAvailable returns next key handle available in KeyHandler, // tries to fill gaps if possible. func (h *KeyHandler) nextAvailable() (tpmutil.Handle, error) { l := len(h.handles) + len(h.inFlightHandles) if l == 0 { return minValidHandle, nil } if tpmutil.Handle(l) >= maxValidHandle-minValidHandle+1 { return 0, fmt.Errorf("no more key handles available") } keys := handlesToSlice(h.handles, nil) keys = handlesToSlice(h.inFlightHandles, keys) sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] }) ret := keys[0] for _, key := range keys { if ret != key { return ret, nil } ret++ } return ret, nil }