oracle/pkg/database/lib/lro/server.go (247 lines of code) (raw):
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package lro contains an implementation of
// https://pkg.go.dev/google.golang.org/genproto/googleapis/longrunning#OperationsServer
package lro
import (
"context"
"errors"
"fmt"
"sort"
"sync"
"time"
opspb "google.golang.org/genproto/googleapis/longrunning"
"google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/grpc/codes"
grpcstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/emptypb"
log "k8s.io/klog/v2"
)
const (
	// defaultPageSize is the page size ListOperations uses when the request
	// leaves page_size unset.
	defaultPageSize int = 10

	// DefaultWaitOperationTimeOut is the WaitOperation timeout used when the
	// request does not carry one.
	DefaultWaitOperationTimeOut = time.Hour

	// ttlAfterDelete is how long a job marked deleted by DeleteOperation is
	// kept before the cleanup loop evicts it.
	ttlAfterDelete = 10 * time.Minute
	// ttlAfterComplete is how long a job is kept after the cleanup loop first
	// observes it as done.
	ttlAfterComplete = 12 * time.Hour

	// jobCleanupInterval is the period of the background cleanup loop.
	jobCleanupInterval = 1 * time.Minute
)
// job is the unit of work tracked by Server. Its methods may be invoked
// concurrently, both from RPC handlers and from the cleanup goroutine.
type job interface {
	// Name returns the job name for metrics/logging purposes.
	Name() string
	// Status reports whether the job is done, its result and its error. The
	// three values are returned from one call so they form a consistent,
	// thread-safe snapshot.
	Status() (bool, *anypb.Any, error)
	// IsDone reports whether the job has completed.
	IsDone() bool
	// Wait blocks until the job is done, bounded by timeout. Implementations
	// should wait asynchronously (wait groups or similar), not poll.
	Wait(timeout time.Duration) error
	// Cancel requests cancellation and returns an error if the job is not
	// cancelable.
	Cancel() error
	// Delete releases resources held by the job; it is called when the job
	// is removed from the server.
	Delete() error
}
// ttlJob wraps a job with the timestamps used to decide when the server
// evicts it.
type ttlJob struct {
	// job is the tracked work item.
	job job
	// startTime is when the job was added to the server.
	startTime time.Time
	// completeTime is stamped by DeleteExpiredJobs the first time it sees the
	// job as done; it is only touched while holding Server.mu.
	completeTime time.Time

	// mu guards deleteTime, which DeleteOperation writes and
	// isDeletedJobExpired reads.
	mu         sync.Mutex
	deleteTime time.Time
}
// Server is a gRPC based operation server implementing
// google/longrunning/operations.proto. Construct it with NewServer, which
// initializes the job table.
type Server struct {
	// mu guards jobs.
	mu sync.Mutex
	// jobs maps operation IDs to the jobs being tracked.
	jobs map[string]*ttlJob
}
// GetOperation returns the current state of the long running operation named
// in request. It implements GetOperation from
// google/longrunning/operations.proto.
func (s *Server) GetOperation(_ context.Context, request *opspb.GetOperationRequest) (*opspb.Operation, error) {
	name := request.GetName()
	tracked, err := s.validateAndGetOperation(name)
	if err != nil {
		return nil, err
	}
	return GetOperationData(name, tracked.job), nil
}
// CancelOperation asks the job behind the named operation to cancel itself.
// It implements CancelOperation from google/longrunning/operations.proto.
//
// If the job's Cancel fails (for example because the job is not cancelable),
// that error is returned unchanged and no response message is returned.
func (s *Server) CancelOperation(_ context.Context, request *opspb.CancelOperationRequest) (*emptypb.Empty, error) {
	tracked, err := s.validateAndGetOperation(request.GetName())
	if err != nil {
		return nil, err
	}
	// Follow the (T, error) convention: never hand back a response together
	// with a non-nil error.
	if err := tracked.job.Cancel(); err != nil {
		return nil, err
	}
	return &emptypb.Empty{}, nil
}
// ListOperations implements ListOperations from
// google/longrunning/operations.proto. It is not implemented fully yet: the
// request's name and filter fields are ignored.
//
// Operations are returned in ascending ID order. The page token is the ID of
// the first operation of the next page; an empty NextPageToken means there
// are no more operations. A page_size of 0 selects defaultPageSize, and a
// negative page_size is rejected with InvalidArgument.
func (s *Server) ListOperations(_ context.Context, request *opspb.ListOperationsRequest) (*opspb.ListOperationsResponse, error) {
	pageSize := int(request.GetPageSize())
	switch {
	case pageSize < 0:
		// A negative size used to yield an empty page whose next token was the
		// first ID, sending paging clients into an endless loop.
		return nil, grpcstatus.Errorf(codes.InvalidArgument, "bad request: negative page size %d", pageSize)
	case pageSize == 0:
		pageSize = defaultPageSize
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	ids := sortedMapKeys(s.jobs)
	// Page tokens are job IDs, so resume at the first ID >= the token. The
	// empty token sorts before every ID and therefore starts from the top.
	start := sort.SearchStrings(ids, request.GetPageToken())
	end := len(ids)
	var nextID string
	if start+pageSize < end {
		end = start + pageSize
		nextID = ids[end]
	}

	operations := make([]*opspb.Operation, 0, end-start)
	for _, id := range ids[start:end] {
		operations = append(operations, GetOperationData(id, s.jobs[id].job))
	}
	return &opspb.ListOperationsResponse{Operations: operations, NextPageToken: nextID}, nil
}
// DeleteOperation implements DeleteOperation from
// google/longrunning/operations.proto.
//
// The job is not removed immediately. It is marked deleted and evicted by
// DeleteExpiredJobs once the delete TTL passed to that method has elapsed.
// Repeated calls keep the first deletion time, so they cannot push eviction
// back indefinitely.
func (s *Server) DeleteOperation(_ context.Context, request *opspb.DeleteOperationRequest) (*emptypb.Empty, error) {
	tracked, err := s.validateAndGetOperation(request.GetName())
	if err != nil {
		return nil, err
	}
	tracked.mu.Lock()
	defer tracked.mu.Unlock()
	if tracked.deleteTime.IsZero() {
		tracked.deleteTime = time.Now()
	}
	return &emptypb.Empty{}, nil
}
// WaitOperation implements WaitOperation from
// google/longrunning/operations.proto. It waits for the named job via its
// Wait method and then returns the operation's state.
//
// The wait is bounded by the request timeout (DefaultWaitOperationTimeOut when
// unset). If ctx carries an earlier deadline, that deadline bounds the wait
// instead, so the server does not keep waiting for a caller that has already
// given up. If ctx is already done, a Canceled or DeadlineExceeded status is
// returned without waiting.
func (s *Server) WaitOperation(ctx context.Context, request *opspb.WaitOperationRequest) (*opspb.Operation, error) {
	job, err := s.validateAndGetOperation(request.GetName())
	if err != nil {
		return nil, err
	}
	duration := DefaultWaitOperationTimeOut
	if timeout := request.GetTimeout(); timeout != nil {
		if err := timeout.CheckValid(); err != nil {
			return nil, grpcstatus.Errorf(codes.InvalidArgument, "Invalid timeout %v for WaitOperation", timeout)
		}
		duration = timeout.AsDuration()
	}
	// Nobody is listening any more; don't start a wait.
	if err := ctx.Err(); err != nil {
		code := codes.Canceled
		if errors.Is(err, context.DeadlineExceeded) {
			code = codes.DeadlineExceeded
		}
		return nil, grpcstatus.Error(code, err.Error())
	}
	// Never wait past the caller's deadline: the response would be discarded
	// and this goroutine would be tied up for nothing.
	if deadline, ok := ctx.Deadline(); ok {
		if remaining := time.Until(deadline); remaining < duration {
			duration = remaining
		}
	}
	j := job.job
	// Wait for the operation to finish and then return the result.
	if err := j.Wait(duration); err != nil {
		// Error on the wait itself.
		log.Infof("WaitOperation: failed to wait for job %v error=%v", request.GetName(), err)
		return nil, err
	}
	return GetOperationData(request.GetName(), j), nil
}
// DeleteExpiredJobs evicts expired jobs from the server and calls Delete on
// each evicted job. A job has expired when either:
//   - DeleteOperation marked it deleted more than ttlAfterDelete ago, or
//   - it was first observed done more than ttlAfterComplete ago.
//
// The completion time is recorded the first time this method sees the job as
// done, so the completion TTL counts from that observation.
//
// Jobs are removed from the table while holding s.mu. Their Delete methods
// run after the lock is released, so slow cleanup cannot stall RPC handlers.
func (s *Server) DeleteExpiredJobs(ttlAfterDelete time.Duration, ttlAfterComplete time.Duration) {
	expired := s.evictExpiredJobs(time.Now(), ttlAfterDelete, ttlAfterComplete)
	for id, j := range expired {
		if err := j.Delete(); err != nil {
			// Warningf (not Warning) so the %v verbs are actually formatted.
			log.Warningf("Job %v deletion returned an error: %v", id, err)
			continue
		}
		log.Infof("Job %v has been deleted.", id)
	}
}

// evictExpiredJobs removes every job that has expired as of now from s.jobs
// (holding s.mu) and returns the removed jobs keyed by ID. It also stamps
// completeTime on jobs it observes done for the first time.
func (s *Server) evictExpiredJobs(now time.Time, ttlAfterDelete, ttlAfterComplete time.Duration) map[string]job {
	s.mu.Lock()
	defer s.mu.Unlock()
	expired := make(map[string]job)
	for id, tj := range s.jobs {
		// Explicitly deleted via DeleteOperation and past its grace period.
		deleteExpired := isDeletedJobExpired(tj, now, ttlAfterDelete)
		// Start the completion clock the first time the job is seen done.
		if tj.completeTime.IsZero() && tj.job.IsDone() {
			tj.completeTime = now
		}
		completeExpired := !tj.completeTime.IsZero() && now.Sub(tj.completeTime) > ttlAfterComplete
		if deleteExpired || completeExpired {
			// Deleting from a map while ranging over it is safe in Go.
			delete(s.jobs, id)
			expired[id] = tj.job
		}
	}
	return expired
}
// validateAndGetOperation returns the tracked job for operationID. An empty ID
// yields an InvalidArgument status, and an unknown ID yields NotFound.
func (s *Server) validateAndGetOperation(operationID string) (*ttlJob, error) {
	if len(operationID) == 0 {
		return nil, grpcstatus.Error(codes.InvalidArgument, "bad request: empty operation ID")
	}
	if tracked, found := s.getJob(operationID); found {
		return tracked, nil
	}
	return nil, grpcstatus.Errorf(codes.NotFound, "LRO with ID %q NOT found", operationID)
}
// AddJob registers job under id so the server tracks it. If id is already in
// use, it logs a warning and returns an AlreadyExists status without
// replacing the existing job.
func (s *Server) AddJob(id string, job job) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	_, taken := s.jobs[id]
	if taken {
		log.Warningf("Job %v already exists", id)
		return grpcstatus.Errorf(codes.AlreadyExists, "LRO with ID %q already exists", id)
	}
	// The ID is free, so the operation can be started.
	s.startOperation(job.Name())
	entry := &ttlJob{job: job}
	entry.startTime = time.Now()
	s.jobs[id] = entry
	return nil
}
// getJob looks up id in the job table while holding s.mu.
func (s *Server) getJob(id string) (j *ttlJob, ok bool) {
	s.mu.Lock()
	j, ok = s.jobs[id]
	s.mu.Unlock()
	return j, ok
}
// deleteJob removes id from the job table immediately. It does not call the
// job's Delete method.
func (s *Server) deleteJob(id string) {
	s.mu.Lock()
	delete(s.jobs, id)
	s.mu.Unlock()
	log.Infof("Job %v has been deleted.", id)
}
// cleanup calls lro.DeleteExpiredJobs every jobCleanupInterval, using the
// package TTLs, until ctx is done.
func cleanup(ctx context.Context, lro *Server) {
	log.Info("Starting cleanup goroutine.")
	ticker := time.NewTicker(jobCleanupInterval)
	defer ticker.Stop()
	stop := ctx.Done()
	for {
		select {
		case <-ticker.C:
			lro.DeleteExpiredJobs(ttlAfterDelete, ttlAfterComplete)
		case <-stop:
			return
		}
	}
}
// NewServer creates a long running operation server with an empty job table.
// It also starts a background goroutine that periodically evicts expired
// jobs; that goroutine exits when ctx is done.
func NewServer(ctx context.Context) *Server {
	srv := new(Server)
	srv.jobs = map[string]*ttlJob{}
	go cleanup(ctx, srv)
	return srv
}
// EndOperation logs the final status of the operation with the given id. The
// job itself is not modified. Unknown ids are silently ignored.
func (s *Server) EndOperation(id string, status string) {
	tracked, ok := s.getJob(id)
	if !ok {
		return
	}
	log.Infof("EndOperation: job %v status %v", tracked.job.Name(), status)
}
// WaitAndUnmarshalResult waits for the operation named opName to finish and
// then returns one of the following:
//   - an error wrapping the WaitOperation error if the wait itself failed;
//   - an error if the operation is still not done after the wait;
//   - an error carrying the operation's error message if it failed;
//   - otherwise nil, after unmarshaling the response into targetProto when
//     both the response and targetProto are non-nil.
func (s *Server) WaitAndUnmarshalResult(ctx context.Context, opName string, targetProto proto.Message) error {
	op, err := s.WaitOperation(ctx, &opspb.WaitOperationRequest{Name: opName})
	if err != nil {
		// %w keeps the underlying (gRPC status) error reachable via errors.Is/As.
		return fmt.Errorf("WaitOperation returns error: %w", err)
	}
	if opErr := op.GetError(); opErr != nil {
		return errors.New(opErr.GetMessage())
	}
	// Without this check an unfinished operation would be reported as success
	// with targetProto left unpopulated.
	if !op.GetDone() {
		return fmt.Errorf("operation %q did not finish within the wait timeout", opName)
	}
	resp := op.GetResponse()
	if resp == nil || targetProto == nil {
		return nil
	}
	return resp.UnmarshalTo(targetProto)
}
// startOperation logs that the job called jobName is being started. It is
// invoked by AddJob just before the job is inserted.
func (s *Server) startOperation(jobName string) {
	log.Infof("startOperation: job %v", jobName)
}
// GetOperationData snapshots j's status and converts it into an Operation
// message named id.
func GetOperationData(id string, j job) *opspb.Operation {
	isDone, res, statusErr := j.Status()
	return BuildOperation(id, isDone, res, statusErr)
}
// BuildOperation assembles an Operation named id from a job's status.
//
// An Operation holds either an error or a response, never both, so the
// result is chosen as follows:
//   - If e is non-nil, it becomes the error and result is dropped. A gRPC
//     status error keeps its code; any other error is reported as
//     codes.Unknown with e.Error() as the message.
//   - Otherwise, a non-nil result becomes the response.
//   - With neither, only Name and Done are set.
func BuildOperation(id string, done bool, result *anypb.Any, e error) *opspb.Operation {
	op := &opspb.Operation{Name: id, Done: done}
	switch {
	case e != nil:
		var errProto *status.Status
		if st, ok := grpcstatus.FromError(e); ok {
			errProto = st.Proto()
		} else {
			errProto = &status.Status{
				Code:    int32(codes.Unknown),
				Message: e.Error(),
			}
		}
		op.Result = &opspb.Operation_Error{Error: errProto}
	case result != nil:
		op.Result = &opspb.Operation_Response{Response: result}
	}
	return op
}
// sortedMapKeys returns the keys of m in ascending order. ListOperations
// uses it for a deterministic iteration order.
func sortedMapKeys(m map[string]*ttlJob) []string {
	out := make([]string, len(m))
	i := 0
	for k := range m {
		out[i] = k
		i++
	}
	sort.Strings(out)
	return out
}
// isDeletedJobExpired reports whether job has been marked deleted and more
// than ttl has passed between that mark and now. It reads deleteTime under
// job.mu.
func isDeletedJobExpired(job *ttlJob, now time.Time, ttl time.Duration) bool {
	job.mu.Lock()
	deletedAt := job.deleteTime
	job.mu.Unlock()
	if deletedAt.IsZero() {
		return false
	}
	return now.Sub(deletedAt) > ttl
}