plugins/teststeps/waitport/waitport.go (223 lines of code) (raw):
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
// Package waitport provides a test step that waits for a port on a remote host to be opened for listening.
package waitport
import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"strconv"
"strings"
"time"
"github.com/facebookincubator/contest/pkg/event"
"github.com/facebookincubator/contest/pkg/event/testevent"
"github.com/facebookincubator/contest/pkg/target"
"github.com/facebookincubator/contest/pkg/test"
"github.com/facebookincubator/contest/pkg/xcontext"
"github.com/facebookincubator/contest/plugins/teststeps"
)
// Name is the name used to look this plugin up.
const Name = "waitport"

// Event names emitted by this plugin. Both are derived from Name.
const (
	// EventCmdStart is emitted for a target before waiting for its port.
	EventCmdStart = event.Name(Name + "Start")
	// EventCmdEnd is emitted for a target once waiting for its port is over.
	EventCmdEnd = event.Name(Name + "End")
)

// Events lists the events that this TestStep is allowed to emit.
var Events = []event.Name{EventCmdStart, EventCmdEnd}
// WaitPort is a TestStep that waits until a port on a remote host can be
// connected to. It carries no state of its own; all of its configuration
// comes from the step parameters.
type WaitPort struct{}

// Name returns the plugin name, i.e. the package-level Name constant.
func (ts *WaitPort) Name() string {
	pluginName := Name
	return pluginName
}
// Run waits, for each target received on ch, until the configured port can be
// dialed or the configured timeout expires. A target fails when no dial
// succeeds in time, or when there is no address to dial for it at all.
//
// The address dialed is the expanded 'target' parameter if it is non-empty.
// Otherwise the target's FQDN, primary IPv4 and primary IPv6 addresses (those
// that are set) are tried, in that order, on every polling round.
//
// EventCmdStart (with the expanded per-target parameters as payload) is
// emitted before waiting, and EventCmdEnd after it. A failure to emit an
// event is only logged and does not fail the target.
//
// NOTE: for the udp* protocols dialing involves no handshake, so a successful
// dial does not prove that anything is listening on the remote port.
func (ts *WaitPort) Run(ctx xcontext.Context, ch test.TestStepChannels, inputParams test.TestStepParameters, ev testevent.Emitter, resumeState json.RawMessage) (json.RawMessage, error) {
	params, err := parseParameters(inputParams)
	if err != nil {
		return nil, err
	}

	f := func(ctx xcontext.Context, targetWithData *teststeps.TargetWithData) error {
		// Named tgt so it does not shadow the imported target package.
		tgt := targetWithData.Target
		targetParams, err := expandParameters(tgt, params)
		if err != nil {
			return err
		}

		// Emit EventCmdStart.
		// Can emit duplicate events on server restart / job resumption.
		payload, err := json.Marshal(targetParams)
		if err != nil {
			ctx.Warnf("Cannot encode payload for %T: %v", targetParams, err)
		} else {
			rm := json.RawMessage(payload)
			evData := testevent.Data{
				EventName: EventCmdStart,
				Target:    tgt,
				Payload:   &rm,
			}
			if err := ev.Emit(ctx, evData); err != nil {
				ctx.Warnf("Cannot emit event EventCmdStart: %v", err)
			}
		}

		addresses := dialAddresses(targetParams.Address, tgt.FQDN, tgt.PrimaryIPv4, tgt.PrimaryIPv6, targetParams.Port)

		var resultErr error
		if len(addresses) == 0 {
			// Fail fast: without any address the step would otherwise idle
			// until the timeout and report only a deadline error.
			resultErr = errors.New("no address to connect to: the 'target' parameter is empty and the target has no FQDN or primary IP")
		} else {
			resultErr = waitForPort(ctx, targetParams, addresses)
		}

		// Emit EventCmdEnd.
		evData := testevent.Data{
			EventName: EventCmdEnd,
			Target:    tgt,
			Payload:   nil,
		}
		if err := ev.Emit(ctx, evData); err != nil {
			ctx.Warnf("Cannot emit event EventCmdEnd: %v", err)
		}
		ctx.Infof("wait port plugin finished, err: '%v'", resultErr)
		return resultErr
	}
	return teststeps.ForEachTargetWithResume(ctx, ch, resumeState, 0, f)
}

// dialAddresses builds the "host:port" strings to dial. A non-empty explicit
// address takes precedence. Otherwise fqdn, ipv4 and ipv6 are used in that
// order, skipping the empty ones. The result is empty if nothing is set.
func dialAddresses(address, fqdn string, ipv4, ipv6 net.IP, port int) []string {
	portStr := strconv.Itoa(port)
	if len(address) > 0 {
		return []string{net.JoinHostPort(address, portStr)}
	}
	var result []string
	if len(fqdn) > 0 {
		result = append(result, net.JoinHostPort(fqdn, portStr))
	}
	for _, ip := range []net.IP{ipv4, ipv6} {
		if len(ip) > 0 {
			result = append(result, net.JoinHostPort(ip.String(), portStr))
		}
	}
	return result
}

// waitForPort dials addresses in order, in rounds separated by
// params.CheckInterval, until one dial succeeds (returns nil) or
// params.Timeout elapses / ctx is cancelled (returns the context error).
// The timeout restarts after a server restart/resume.
func waitForPort(ctx xcontext.Context, params *targetParameters, addresses []string) error {
	finishedContext, cancel := context.WithTimeout(ctx, params.Timeout)
	defer cancel()

	var d net.Dialer
	for {
		for _, addr := range addresses {
			conn, err := d.DialContext(finishedContext, params.Protocol, addr)
			if err == nil {
				ctx.Infof("successfully connected via %s", addr)
				if err := conn.Close(); err != nil {
					ctx.Warnf("failed to close opened connection: %v", err)
				}
				return nil
			}
			ctx.Warnf("failed to connect to '%s', err: '%v'", addr, err)
			if ctxErr := finishedContext.Err(); ctxErr != nil {
				return ctxErr
			}
		}

		ctx.Infof("wait for the next iteration")
		// A stoppable timer instead of time.After, so the timer is released
		// right away when the context finishes first.
		timer := time.NewTimer(params.CheckInterval)
		select {
		case <-finishedContext.Done():
			timer.Stop()
			return finishedContext.Err()
		case <-timer.C:
		}
	}
}
// ValidateParameters reports whether params can be parsed into a valid
// configuration for this step. It returns the parsing error, or nil.
func (ts *WaitPort) ValidateParameters(ctx xcontext.Context, params test.TestStepParameters) error {
	if _, parseErr := parseParameters(params); parseErr != nil {
		return parseErr
	}
	return nil
}
// New returns a new WaitPort test step. WaitPort has no fields, so every
// instance is equivalent.
func New() test.TestStep {
	step := &WaitPort{}
	return step
}

// Load returns the name, factory and events which are needed to register the step.
func Load() (string, test.TestStepFactory, []event.Name) {
	var factory test.TestStepFactory = New
	return Name, factory, Events
}
// protocolOptions lists the lowercase values accepted for the 'protocol'
// parameter. The chosen value is passed as the network name to
// net.Dialer.DialContext.
var protocolOptions = []string{"tcp", "tcp4", "tcp6", "udp", "udp4", "udp6"}
// parameters holds the validated step parameters before they are expanded
// for a particular target.
type parameters struct {
	// target is the optional 'target' parameter; nil when it was not given.
	target   *test.Param
	port     int
	protocol string

	checkInterval, timeout time.Duration
}
func parseParameters(params test.TestStepParameters) (*parameters, error) {
const maxPortValue = 65535
var target *test.Param
targets := params.Get("target")
switch len(targets) {
case 0:
case 1:
target = &targets[0]
default:
return nil, errors.New("0 or 1 'target' parameters should be provided")
}
port, err := params.GetInt("port")
if err != nil {
return nil, fmt.Errorf("failed to parse 'port' parameter: %w", err)
}
if port <= 0 || port > maxPortValue {
return nil, fmt.Errorf("'port' parameter should be a positive integer that is less than %d, got: '%d'", maxPortValue, port)
}
protocols := params.Get("protocol")
if len(protocols) != 1 {
return nil, errors.New("a single 'protocol' should be provided")
}
protocol := strings.ToLower(protocols[0].String())
isValidProtocol := func() bool {
for _, opt := range protocolOptions {
if opt == protocol {
return true
}
}
return false
}()
if !isValidProtocol {
return nil, fmt.Errorf("'protocol' should be one of [%s]", strings.Join(protocolOptions, ", "))
}
timeout, err := parseDurationParam(params, "timeout")
if err != nil {
return nil, err
}
checkInterval, err := parseDurationParam(params, "check_interval")
if err != nil {
return nil, err
}
return ¶meters{
target: target,
port: int(port),
protocol: protocol,
checkInterval: checkInterval,
timeout: timeout,
}, nil
}
// targetParameters holds the step parameters resolved for a single target.
// Run JSON-encodes it as the payload of EventCmdStart, so its exported field
// names and their order are visible in emitted events (encoding/json writes
// the time.Duration fields as integer nanoseconds).
type targetParameters struct {
	// Address is the expanded 'target' parameter. When it is empty, Run dials
	// the target's own FQDN / primary IP addresses instead.
	Address       string
	Port          int
	Protocol      string
	CheckInterval time.Duration
	Timeout       time.Duration
}
// expandParameters resolves the parameters for target t. The optional
// 'target' parameter is expanded against t (Address stays empty when it is
// absent). All other values are copied from params unchanged.
func expandParameters(t *target.Target, params *parameters) (*targetParameters, error) {
	var address string
	if params.target != nil {
		var err error
		address, err = params.target.Expand(t)
		if err != nil {
			return nil, fmt.Errorf("cannot expand target parameter '%s': %w", params.target.String(), err)
		}
	}
	return &targetParameters{
		Address:  address,
		Port:     params.port,
		Protocol: params.protocol,
		// Must be propagated: a zero CheckInterval makes the polling loop
		// retry without any pause between rounds.
		CheckInterval: params.checkInterval,
		Timeout:       params.timeout,
	}, nil
}
// parseDurationParam returns the single parameter named parameterName, parsed
// with time.ParseDuration (e.g. "500ms", "1m30s"). It fails if the parameter
// is missing, given more than once, or is not a valid duration. The error
// names the offending parameter and wraps the parse error.
func parseDurationParam(teststepParams test.TestStepParameters, parameterName string) (time.Duration, error) {
	params := teststepParams.Get(parameterName)
	if len(params) != 1 {
		return 0, fmt.Errorf("a single '%s' should be provided", parameterName)
	}
	paramStr := params[0].String()
	result, err := time.ParseDuration(paramStr)
	if err != nil {
		return 0, fmt.Errorf("failed to convert '%s' parameter value '%s' to duration: %w", parameterName, paramStr, err)
	}
	return result, nil
}