go/mqtt/env.go (180 lines of code) (raw):
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
package mqtt
import (
"context"
"crypto/tls"
"os"
"strconv"
"strings"
"github.com/Azure/iot-operations-sdks/go/mqtt/auth"
)
type (
// Env provides all session client parameters parsed from well-known
// environment variables.
Env struct {
// ClientID is the MQTT client ID read from AIO_MQTT_CLIENT_ID.
ClientID string
// ConnectionProvider is assembled from the broker hostname, port, and
// TLS variables. It is left nil when none of those variables are set.
ConnectionProvider ConnectionProvider
// SessionClientOptions carries the remaining settings populated from the
// environment: clean start, keep-alive, session expiry, username,
// password, and auth provider.
*SessionClientOptions
}
// connectionProviderBuilder collects the connection-related environment
// values while the environment is scanned, so that they can be validated
// and turned into a ConnectionProvider together once all are known.
connectionProviderBuilder struct {
// hostname is the broker hostname (AIO_BROKER_HOSTNAME).
hostname string
// port is the broker TCP port (AIO_BROKER_TCP_PORT); zero means unset.
port uint16
// useTLS is a pointer so that an unset AIO_MQTT_USE_TLS (nil) can be told
// apart from an explicit false.
useTLS *bool
// caFile is the CA file path (AIO_TLS_CA_FILE).
caFile string
// certFile is the client certificate file path (AIO_TLS_CERT_FILE).
certFile string
// keyFile is the client key file path (AIO_TLS_KEY_FILE).
keyFile string
// passFile is the key password file path (AIO_TLS_KEY_PASSWORD_FILE).
passFile string
}
)
// SessionClientConfigFromEnv parses a session client configuration from
// well-known environment variables. Note that this will only return an error if
// the environment variables parse incorrectly or the connection settings are
// inconsistent with each other (e.g. TLS files without a hostname); it will not
// return an error if required parameters (e.g. for the connection provider) are
// missing, to allow optional parameters to be specified from environment
// independently.
//
// The recognized variables are:
//
//   - AIO_BROKER_HOSTNAME, AIO_BROKER_TCP_PORT, AIO_MQTT_USE_TLS
//   - AIO_MQTT_CLEAN_START, AIO_MQTT_KEEP_ALIVE, AIO_MQTT_CLIENT_ID,
//     AIO_MQTT_SESSION_EXPIRY
//   - AIO_MQTT_USERNAME, AIO_MQTT_PASSWORD_FILE, AIO_SAT_FILE
//   - AIO_TLS_CA_FILE, AIO_TLS_CERT_FILE, AIO_TLS_KEY_FILE,
//     AIO_TLS_KEY_PASSWORD_FILE
//
// Values that fail to parse are reported as *InvalidArgumentError wrapping the
// underlying parse error. All other environment entries are ignored.
func SessionClientConfigFromEnv() (*Env, error) {
	opts := &Env{SessionClientOptions: &SessionClientOptions{}}
	conn := connectionProviderBuilder{}

	for _, env := range os.Environ() {
		// An entry with no '=' separator yields idx == -1, which would make
		// the slice expressions below panic. Such an entry cannot be one of
		// the variables handled here, so skip it.
		idx := strings.IndexByte(env, '=')
		if idx < 0 {
			continue
		}
		key := env[:idx]
		val := env[idx+1:]

		switch key {
		case "AIO_BROKER_HOSTNAME":
			conn.hostname = val

		case "AIO_BROKER_TCP_PORT":
			port, err := strconv.ParseUint(val, 10, 16)
			if err != nil {
				return nil, &InvalidArgumentError{
					message: "could not parse broker TCP port",
					wrapped: err,
				}
			}
			conn.port = uint16(port)

		case "AIO_MQTT_USE_TLS":
			useTLS, err := strconv.ParseBool(val)
			if err != nil {
				return nil, &InvalidArgumentError{
					message: "could not parse MQTT use TLS",
					wrapped: err,
				}
			}
			conn.useTLS = &useTLS

		case "AIO_MQTT_CLEAN_START":
			cleanStart, err := strconv.ParseBool(val)
			if err != nil {
				return nil, &InvalidArgumentError{
					message: "could not parse MQTT clean start",
					wrapped: err,
				}
			}
			opts.CleanStart = cleanStart

		case "AIO_MQTT_KEEP_ALIVE":
			keepAlive, err := strconv.ParseUint(val, 10, 16)
			if err != nil {
				return nil, &InvalidArgumentError{
					message: "could not parse MQTT keep-alive",
					wrapped: err,
				}
			}
			opts.KeepAlive = uint16(keepAlive)

		case "AIO_MQTT_CLIENT_ID":
			opts.ClientID = val

		case "AIO_MQTT_SESSION_EXPIRY":
			sessionExpiry, err := strconv.ParseUint(val, 10, 32)
			if err != nil {
				return nil, &InvalidArgumentError{
					message: "could not parse MQTT session expiry",
					wrapped: err,
				}
			}
			opts.SessionExpiry = uint32(sessionExpiry)

		case "AIO_MQTT_USERNAME":
			opts.Username = ConstantUsername(val)

		case "AIO_MQTT_PASSWORD_FILE":
			opts.Password = FilePassword(val)

		case "AIO_SAT_FILE":
			satAuth, err := auth.NewAIOServiceAccountToken(val)
			if err != nil {
				return nil, &InvalidArgumentError{
					message: "error setting up the AIO SAT auth provider",
					wrapped: err,
				}
			}
			opts.Auth = satAuth

		case "AIO_TLS_CA_FILE":
			conn.caFile = val

		case "AIO_TLS_CERT_FILE":
			conn.certFile = val

		case "AIO_TLS_KEY_FILE":
			conn.keyFile = val

		case "AIO_TLS_KEY_PASSWORD_FILE":
			conn.passFile = val
		}
	}

	// The connection settings are validated as a group only after the whole
	// environment has been scanned, since their order in it is arbitrary.
	var err error
	opts.ConnectionProvider, err = conn.build()
	if err != nil {
		return nil, err
	}
	return opts, nil
}
// NewSessionClientFromEnv constructs a session client from the configuration
// produced by SessionClientConfigFromEnv. The given options are applied to
// that configuration before the client is created. Any error from parsing the
// environment is returned as-is.
func NewSessionClientFromEnv(
	opt ...SessionClientOption,
) (*SessionClient, error) {
	env, err := SessionClientConfigFromEnv()
	if err != nil {
		return nil, err
	}

	env.Apply(opt)

	return NewSessionClient(env.ClientID, env.ConnectionProvider, env)
}
// build validates the collected connection settings and turns them into a
// ConnectionProvider.
//
// With no hostname, it returns a nil provider and nil error when nothing else
// was configured, and an error when a port, a TLS toggle, or any TLS file was
// given. With a hostname, an unset port defaults to 8883. A plain TCP
// connection is returned only when TLS was explicitly disabled; otherwise a
// TLS connection is returned. It is an error to supply TLS files while TLS is
// disabled, or to supply only one of the certificate and key files.
func (b *connectionProviderBuilder) build() (ConnectionProvider, error) {
	if b.hostname == "" {
		// Nothing at all was configured: not an error, just no provider.
		if b.port == 0 && b.useTLS == nil && !b.hasTLS() {
			return nil, nil
		}
		return nil, &InvalidArgumentError{
			message: "connection configuration provided without hostname",
		}
	}

	if b.port == 0 {
		b.port = 8883
	}

	// TLS is used unless it was explicitly switched off.
	tlsDisabled := b.useTLS != nil && !*b.useTLS
	if tlsDisabled {
		if !b.hasTLS() {
			return TCPConnection(b.hostname, b.port), nil
		}
		return nil, &InvalidArgumentError{
			message: "TLS configuration provided but not using TLS",
		}
	}

	withCert, withKey := b.certFile != "", b.keyFile != ""
	if withCert != withKey {
		return nil, &InvalidArgumentError{
			message: "certificate file and key file must be provided together",
		}
	}

	var tlsOpts []TLSOption

	// Bypasses hostname check in TLS config when deliberately connecting to
	// localhost.
	if b.hostname == "localhost" {
		skipVerify := func(_ context.Context, cfg *tls.Config) error {
			cfg.InsecureSkipVerify = true // #nosec G402
			return nil
		}
		tlsOpts = append(tlsOpts, skipVerify)
	}

	// Client certificate, using the password file for the key when one is set.
	switch {
	case !withCert:
		// No client certificate configured.
	case b.passFile != "":
		tlsOpts = append(
			tlsOpts,
			WithEncryptedX509(b.certFile, b.keyFile, b.passFile),
		)
	default:
		tlsOpts = append(tlsOpts, WithX509(b.certFile, b.keyFile))
	}

	if b.caFile != "" {
		tlsOpts = append(tlsOpts, WithCA(b.caFile))
	}

	return TLSConnection(b.hostname, b.port, tlsOpts...), nil
}
// hasTLS reports whether any TLS-related file path (CA, certificate, key, or
// key password) has been set on the builder.
func (b *connectionProviderBuilder) hasTLS() bool {
	paths := [...]string{b.caFile, b.certFile, b.keyFile, b.passFile}
	for _, p := range paths {
		if p != "" {
			return true
		}
	}
	return false
}