functions.go (152 lines of code) (raw):
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rest
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strconv"
"sync"
"github.com/GoogleCloudPlatform/functions-framework-go/functions"
"github.com/goccy/go-json"
"github.com/googlecloudplatform/pi-delivery/gen/index"
"github.com/googlecloudplatform/pi-delivery/pkg/service"
"go.ajitem.com/zapdriver"
"go.uber.org/zap"
)
// Lazily created backend service shared by all requests; see getService.
var (
	_serv     *service.Service
	_servOnce sync.Once
)

// Runtime configuration. Defaults below may be overridden from the
// environment in init.
var (
	maxDigitsPerRequest = 1000
	bucketName          = index.BucketName
)

// Names of the environment variables consulted by init.
const (
	envMaxDigitsPerRequest = "PI_MAX_DIGITS_PER_REQUEST"
	envBucketName          = "PI_BUCKET_NAME"
)
// init registers the HTTP entry points with the Functions Framework,
// installs a zapdriver production logger as the global zap logger, and
// applies configuration overrides from the environment:
//   - PI_MAX_DIGITS_PER_REQUEST: positive integer cap on numberOfDigits.
//     Unparsable or non-positive values are logged and ignored.
//   - PI_BUCKET_NAME: bucket name passed to the service.
func init() {
	functions.HTTP("Get", Get)
	functions.HTTP("NotFound", NotFound)

	logger, err := zapdriver.NewProduction()
	if err != nil {
		// The global zap logger is still the default no-op logger at this
		// point, so a zap Fatal message would be discarded. Report on stderr
		// and exit with the same status zap's Fatal uses.
		fmt.Fprintf(os.Stderr, "zapdriver.NewProduction() failed: %v\n", err)
		os.Exit(1)
	}
	zap.ReplaceGlobals(logger)
	defer zap.S().Sync()

	// Read configurations from env.
	if s := os.Getenv(envMaxDigitsPerRequest); s != "" {
		i, convErr := strconv.Atoi(s)
		switch {
		case convErr != nil:
			zap.S().Errorw("invalid env value", "name", envMaxDigitsPerRequest, "value", s, "error", convErr)
		case i <= 0:
			// A non-positive cap would make every request fail validation.
			zap.S().Errorw("invalid env value", "name", envMaxDigitsPerRequest, "value", s)
		default:
			maxDigitsPerRequest = i
		}
	}
	if s := os.Getenv(envBucketName); s != "" {
		bucketName = s
	}
	// Use the "w" variant: plain Info would fmt.Sprint the key/value pairs
	// into the message instead of emitting structured fields.
	zap.S().Infow("Config",
		"maxDigitsPerRequest", maxDigitsPerRequest,
		"bucketName", bucketName,
	)
}
// getService returns the process-wide *service.Service, creating it on the
// first call. sync.Once ensures it is constructed exactly once even when
// requests arrive concurrently; every later call returns the cached value.
//
// ctx is kept for API compatibility but deliberately not handed to
// service.NewService. The service outlives the request that first triggers
// its creation, and an incoming request's context is canceled once its
// handler returns. Binding the singleton to that context would break any
// later use of it if NewService retains the context.
// NOTE(review): confirm NewService does not rely on a request deadline.
func getService(ctx context.Context) *service.Service {
	_servOnce.Do(func() {
		_serv = service.NewService(context.Background(), zap.S(), bucketName)
	})
	return _serv
}
// namedLogger returns a child of l, named name, whose entries carry the
// HTTP payload that zapdriver.NewHTTP builds from req (no response given).
func namedLogger(l *zap.SugaredLogger, name string, req *http.Request) *zap.SugaredLogger {
	httpField := zapdriver.HTTP(zapdriver.NewHTTP(req, nil))
	named := l.Named(name)
	return named.With(httpField)
}
// writeError logs s at error level with the status code, then sends s to
// the client as a plain-text body with that status.
// If writing the body fails, the error is logged. The status line has
// already been sent by then, so nothing else can be done.
func writeError(l *zap.SugaredLogger, res http.ResponseWriter, code int, s string) {
	l.Errorw(s, "code", code)
	// Set (not Add) so any Content-Type assigned earlier in the handler is
	// replaced rather than duplicated. Declare the charset, as NotFound does.
	res.Header().Set("Content-Type", "text/plain; charset=utf-8")
	res.WriteHeader(code)
	if _, err := io.WriteString(res, s); err != nil {
		l.Errorw("WriteString failed", "error", err)
	}
}
// getIntQueryParam parses query parameter name from q as a base-10 int64.
// A missing parameter, or one with an empty value, yields def.
// On a parse failure, the details are logged to l and a client-safe error
// naming only the parameter is returned.
func getIntQueryParam(l *zap.SugaredLogger, q url.Values, name string, def int64) (int64, error) {
	// TODO(yuryu): Use Has() when go 1.17 is available on Functions.
	raw := q.Get(name)
	if raw == "" {
		return def, nil
	}
	v, parseErr := strconv.ParseInt(raw, 10, 64)
	if parseErr == nil {
		return v, nil
	}
	l.Errorw("ParseInt failed", "error", parseErr, "param", name, "value", raw)
	return 0, fmt.Errorf("invalid request: %s", name)
}
// GetResponse is the JSON body that Get writes on success.
type GetResponse struct {
	// Content holds the requested digits as one string, for example
	// "31415926535897932384626433832795028841971693993".
	Content string `json:"content"`
}
// Get is the entrypoint for the API.
// It takes three optional parameters in the query string:
//   - start (int64): the digit position to read from. Default 0; must be
//     non-negative and not exceed the total digits of the selected radix.
//   - numberOfDigits (int64): number of digits to read. Default 100; must be
//     non-negative and at most maxDigitsPerRequest.
//   - radix (int): the radix of pi to read, 10 or 16. Default 10.
//
// On success it responds 200 with a JSON-encoded GetResponse. Invalid
// parameters yield 400 with a plain-text message. A failure in the backing
// service yields 500; the underlying error is logged, not returned.
func Get(res http.ResponseWriter, req *http.Request) {
	l := namedLogger(zap.S(), "Get", req)
	defer l.Sync()
	l.Info("Get start")

	res.Header().Set("Access-Control-Allow-Origin", "*")
	q := req.URL.Query()

	radix, err := getIntQueryParam(l, q, "radix", 10)
	if err != nil {
		writeError(l, res, http.StatusBadRequest, err.Error())
		return
	}
	if radix != 10 && radix != 16 {
		writeError(l, res, http.StatusBadRequest, "radix must be either 10 or 16")
		return
	}
	set := index.Decimal
	if radix == 16 {
		set = index.Hexadecimal
	}

	start, err := getIntQueryParam(l, q, "start", 0)
	if err != nil {
		writeError(l, res, http.StatusBadRequest, err.Error())
		return
	}
	if start < 0 {
		writeError(l, res, http.StatusBadRequest, "start is negative")
		return
	}
	if start > set.TotalDigits() {
		writeError(l, res, http.StatusBadRequest, "start out of range")
		return
	}

	numberOfDigits, err := getIntQueryParam(l, q, "numberOfDigits", 100)
	if err != nil {
		writeError(l, res, http.StatusBadRequest, err.Error())
		return
	}
	if numberOfDigits < 0 {
		writeError(l, res, http.StatusBadRequest, "numberOfDigits is negative")
		return
	}
	if numberOfDigits > int64(maxDigitsPerRequest) {
		writeError(l, res, http.StatusBadRequest, "numberOfDigits is too big")
		return
	}

	ctx := req.Context()
	unpacked, err := getService(ctx).Get(ctx, l, set, start, numberOfDigits)
	if err != nil {
		// Record the real cause for operators; the client only gets a
		// generic message so internal details are not leaked.
		l.Errorw("service Get failed", "error", err,
			"radix", radix, "start", start, "numberOfDigits", numberOfDigits)
		writeError(l, res, http.StatusInternalServerError, "Internal Server Error")
		return
	}

	res.Header().Set("Content-Type", "application/json")
	res.WriteHeader(http.StatusOK)
	// The 200 status is already committed, so an encode/write failure can
	// only be logged.
	err = json.NewEncoder(res).EncodeWithOption(
		&GetResponse{Content: string(unpacked)},
		json.DisableHTMLEscape(),
	)
	if err != nil {
		l.Errorw("json encode failed",
			"error", err)
	}
}
// NotFound returns 404 for all requests.
// This is necessary because LB can't return 404 by itself.
// https://issuetracker.google.com/160192483
//
// The request path is echoed back in a plain-text body. That path is
// client-controlled, so nosniff is set to stop browsers from reinterpreting
// the body as HTML.
func NotFound(res http.ResponseWriter, req *http.Request) {
	res.Header().Set("Content-Type", "text/plain; charset=utf-8")
	res.Header().Set("X-Content-Type-Options", "nosniff")
	res.WriteHeader(http.StatusNotFound)
	// Best effort: the status is already sent and this handler has no logger,
	// so a failed body write cannot be reported anywhere useful.
	_, _ = fmt.Fprintf(res, "The requested url %s was not found.\n", req.URL.Path)
}