src/sessionmanagerplugin/session/portsession/portsession.go (102 lines of code) (raw):
// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may not
// use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
// either express or implied. See the License for the specific language governing
// permissions and limitations under the License.
// Package portsession starts port session.
package portsession
import (
"github.com/aws/session-manager-plugin/src/config"
"github.com/aws/session-manager-plugin/src/jsonutil"
"github.com/aws/session-manager-plugin/src/log"
"github.com/aws/session-manager-plugin/src/message"
"github.com/aws/session-manager-plugin/src/sessionmanagerplugin/session"
"github.com/aws/session-manager-plugin/src/version"
)
const (
LocalPortForwardingType = "LocalPortForwarding"
)
type PortSession struct {
session.Session
portParameters PortParameters
portSessionType IPortSession
}
type IPortSession interface {
IsStreamNotSet() (status bool)
InitializeStreams(log log.T, agentVersion string) (err error)
ReadStream(log log.T) (err error)
WriteStream(outputMessage message.ClientMessage) (err error)
Stop()
}
type PortParameters struct {
PortNumber string `json:"portNumber"`
LocalPortNumber string `json:"localPortNumber"`
LocalUnixSocket string `json:"localUnixSocket"`
LocalConnectionType string `json:"localConnectionType"`
Type string `json:"type"`
}
func init() {
session.Register(&PortSession{})
}
// Name is the session name used inputStream the plugin
func (PortSession) Name() string {
return config.PortPluginName
}
func (s *PortSession) Initialize(log log.T, sessionVar *session.Session) {
s.Session = *sessionVar
if err := jsonutil.Remarshal(s.SessionProperties, &s.portParameters); err != nil {
log.Errorf("Invalid format: %v", err)
}
if s.portParameters.Type == LocalPortForwardingType {
if version.DoesAgentSupportTCPMultiplexing(log, s.DataChannel.GetAgentVersion()) {
s.portSessionType = &MuxPortForwarding{
sessionId: s.SessionId,
portParameters: s.portParameters,
session: s.Session,
}
} else {
s.portSessionType = &BasicPortForwarding{
sessionId: s.SessionId,
portParameters: s.portParameters,
session: s.Session,
}
}
} else {
s.portSessionType = &StandardStreamForwarding{
portParameters: s.portParameters,
session: s.Session,
}
}
s.DataChannel.RegisterOutputStreamHandler(s.ProcessStreamMessagePayload, true)
s.DataChannel.GetWsChannel().SetOnMessage(func(input []byte) {
if s.portSessionType.IsStreamNotSet() {
outputMessage := &message.ClientMessage{}
if err := outputMessage.DeserializeClientMessage(log, input); err != nil {
log.Debugf("Ignore message deserialize error while stream connection had not set.")
return
}
if outputMessage.MessageType == message.OutputStreamMessage {
log.Debugf("Waiting for user to establish connection before processing incoming messages.")
return
} else {
log.Infof("Received %s message while establishing connection", outputMessage.MessageType)
}
}
s.DataChannel.OutputMessageHandler(log, s.Stop, s.SessionId, input)
})
log.Infof("Connected to instance[%s] on port: %s", sessionVar.TargetId, s.portParameters.PortNumber)
}
func (s *PortSession) Stop() {
s.portSessionType.Stop()
}
// StartSession redirects inputStream/outputStream data to datachannel.
func (s *PortSession) SetSessionHandlers(log log.T) (err error) {
if err = s.portSessionType.InitializeStreams(log, s.DataChannel.GetAgentVersion()); err != nil {
return err
}
if err = s.portSessionType.ReadStream(log); err != nil {
return err
}
return
}
// ProcessStreamMessagePayload writes messages received on datachannel to stdout
func (s *PortSession) ProcessStreamMessagePayload(log log.T, outputMessage message.ClientMessage) (isHandlerReady bool, err error) {
if s.portSessionType.IsStreamNotSet() {
log.Debugf("Waiting for streams to be established before processing incoming messages.")
return false, nil
}
log.Tracef("Received payload of size %d from datachannel.", outputMessage.PayloadLength)
err = s.portSessionType.WriteStream(outputMessage)
return true, err
}