azurelinuxagent/common/protocol/util.py (207 lines of code) (raw):
# Microsoft Azure Linux Agent
#
# Copyright 2018 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.6+ and Openssl 1.0+
#
import errno
import os
import re
import time
import threading
import azurelinuxagent.common.conf as conf
import azurelinuxagent.common.logger as logger
import azurelinuxagent.common.utils.fileutil as fileutil
from azurelinuxagent.common.singletonperthread import SingletonPerThread
from azurelinuxagent.common.exception import ProtocolError, OSUtilError, \
ProtocolNotFoundError, DhcpError
from azurelinuxagent.common.future import ustr
from azurelinuxagent.common.osutil import get_osutil
from azurelinuxagent.common.dhcp import get_dhcp_handler
from azurelinuxagent.common.protocol.metadata_server_migration_util import cleanup_metadata_server_artifacts, \
is_metadata_server_artifact_present
from azurelinuxagent.common.protocol.ovfenv import OvfEnv
from azurelinuxagent.common.protocol.wire import WireProtocol
from azurelinuxagent.common.utils.restutil import KNOWN_WIRESERVER_IP, \
IOErrorCounter
# File names used by this module. All of them live under conf.get_lib_dir();
# ovf-env.xml is additionally read from the provisioning DVD mount point.
OVF_FILE_NAME = "ovf-env.xml"
PROTOCOL_FILE_NAME = "Protocol"
ENDPOINT_FILE_NAME = "WireServerEndpoint"

# Protocol detection: up to MAX_RETRY probes, sleeping PROBE_INTERVAL seconds
# between consecutive failed probes.
MAX_RETRY = 360
PROBE_INTERVAL = 10

# Regex/replacement pair used to mask the user password before ovf-env.xml is
# copied to disk.
PASSWORD_PATTERN = "<UserPassword>.*?<"
PASSWORD_REPLACEMENT = "<UserPassword>*<"

# Marker written to the protocol file once WireProtocol has been detected.
WIRE_PROTOCOL_NAME = "WireProtocol"
def get_protocol_util():
    """
    Return a ProtocolUtil instance.

    ProtocolUtil derives from SingletonPerThread, so repeated calls from the
    same thread are expected to yield the same object.
    """
    util_instance = ProtocolUtil()
    return util_instance
class ProtocolUtil(SingletonPerThread):
    """
    ProtocolUtil handles initialization of the protocol instance (WireProtocol)
    used to talk to the WireServer. Artifacts left behind by the older
    MetadataServer protocol are cleaned up when the protocol is obtained.

    Note: ProtocolUtil is a subclass of SingletonPerThread, which means there is
    only a single instance of ProtocolUtil object per thread.
    """

    def __init__(self):
        # Reentrant because locked methods call other locked methods
        # (e.g. get_protocol -> _detect_protocol -> clear_protocol).
        self._lock = threading.RLock()  # protects the files on disk created during protocol detection
        self._protocol = None
        self.endpoint = None
        self.osutil = get_osutil()
        self.dhcp_handler = get_dhcp_handler()

    def copy_ovf_env(self):
        """
        Copy ovf-env.xml from the provisioning DVD into the agent's lib directory.
        The user password is masked (PASSWORD_PATTERN -> PASSWORD_REPLACEMENT)
        before the copy is written to disk.

        :returns: OvfEnv parsed from the original (unmasked) file on the DVD
        :raises ProtocolError: if the DVD cannot be mounted, or the file cannot be
            read from the DVD or written to the lib directory
        """
        dvd_mount_point = conf.get_dvd_mount_point()
        ovf_file_path_on_dvd = os.path.join(dvd_mount_point, OVF_FILE_NAME)
        ovf_file_path = os.path.join(conf.get_lib_dir(), OVF_FILE_NAME)

        try:
            self.osutil.mount_dvd()
        except OSUtilError as e:
            raise ProtocolError("[CopyOvfEnv] Error mounting dvd: "
                                "{0}".format(ustr(e)))

        try:
            ovfxml = fileutil.read_file(ovf_file_path_on_dvd, remove_bom=True)
            ovfenv = OvfEnv(ovfxml)
        except (IOError, OSError) as e:
            raise ProtocolError("[CopyOvfEnv] Error reading file "
                                "{0}: {1}".format(ovf_file_path_on_dvd,
                                                  ustr(e)))

        try:
            # Only the on-disk copy is masked; ovfenv above was parsed from the original text.
            ovfxml = re.sub(PASSWORD_PATTERN,
                            PASSWORD_REPLACEMENT,
                            ovfxml)
            fileutil.write_file(ovf_file_path, ovfxml)
        except (IOError, OSError) as e:
            raise ProtocolError("[CopyOvfEnv] Error writing file "
                                "{0}: {1}".format(ovf_file_path,
                                                  ustr(e)))

        # Unmount/eject only after a successful copy; failures there are only logged.
        self._cleanup_ovf_dvd()

        return ovfenv

    def _cleanup_ovf_dvd(self):
        """
        Unmount and eject the DVD. OSUtilError is logged as a warning, not raised.
        """
        try:
            self.osutil.umount_dvd()
            self.osutil.eject_dvd()
        except OSUtilError as e:
            logger.warn(ustr(e))

    def get_ovf_env(self):
        """
        Load the ovf-env.xml previously saved in the lib directory.

        :returns: OvfEnv parsed from the saved (password-masked) file
        :raises ProtocolError: if the saved file does not exist
        """
        ovf_file_path = os.path.join(conf.get_lib_dir(), OVF_FILE_NAME)
        if os.path.isfile(ovf_file_path):
            xml_text = fileutil.read_file(ovf_file_path)
            return OvfEnv(xml_text)
        else:
            raise ProtocolError(
                "ovf-env.xml is missing from {0}".format(ovf_file_path))

    def _get_protocol_file_path(self):
        return os.path.join(
            conf.get_lib_dir(),
            PROTOCOL_FILE_NAME)

    def _get_wireserver_endpoint_file_path(self):
        return os.path.join(
            conf.get_lib_dir(),
            ENDPOINT_FILE_NAME)

    def get_wireserver_endpoint(self):
        """
        Return the WireServer endpoint.

        Resolution order: the value cached on this instance, then the endpoint
        file in the lib directory, then KNOWN_WIRESERVER_IP if the file is
        missing, empty or unreadable. The resolved value is cached in self.endpoint.
        """
        with self._lock:
            if self.endpoint:
                return self.endpoint

            file_path = self._get_wireserver_endpoint_file_path()
            if os.path.isfile(file_path):
                try:
                    self.endpoint = fileutil.read_file(file_path)

                    if self.endpoint:
                        logger.info("WireServer endpoint {0} read from file", self.endpoint)
                        return self.endpoint

                    logger.error("[GetWireserverEndpoint] Unexpected empty file {0}", file_path)
                except (IOError, OSError) as e:
                    # ustr (not str) so non-ASCII messages don't fail under Python 2
                    logger.error("[GetWireserverEndpoint] Error reading file {0}: {1}", file_path, ustr(e))
            else:
                logger.error("[GetWireserverEndpoint] Missing file {0}", file_path)

            self.endpoint = KNOWN_WIRESERVER_IP
            logger.info("Using hardcoded Wireserver endpoint {0}", self.endpoint)

            return self.endpoint

    def _set_wireserver_endpoint(self, endpoint):
        """
        Cache the endpoint on this instance and persist it to the endpoint file.

        :raises OSUtilError: if the endpoint file cannot be written
        """
        try:
            self.endpoint = endpoint
            file_path = self._get_wireserver_endpoint_file_path()
            fileutil.write_file(file_path, endpoint)
        except (IOError, OSError) as e:
            raise OSUtilError(ustr(e))

    def _clear_wireserver_endpoint(self):
        """
        Cleanup previous saved wireserver endpoint (both the cached value and the file).
        Removal failures other than "file not found" are logged, not raised.
        """
        self.endpoint = None
        endpoint_file_path = self._get_wireserver_endpoint_file_path()
        if not os.path.isfile(endpoint_file_path):
            return

        try:
            os.remove(endpoint_file_path)
        except (IOError, OSError) as e:
            # Ignore file-not-found errors (since the file is being removed)
            if e.errno == errno.ENOENT:
                return
            logger.error("Failed to clear wireserver endpoint: {0}", e)

    def _detect_protocol(self, save_to_history, init_goal_state=True):
        """
        Probe the WireServer endpoint, making up to MAX_RETRY attempts with
        PROBE_INTERVAL seconds between failed attempts.

        Any previously saved protocol and endpoint are cleared first. The endpoint
        comes from the DHCP handler (re-running DHCP when available) or, when DHCP
        is not available, from get_wireserver_endpoint().

        :param save_to_history: forwarded to WireProtocol.detect()
        :param init_goal_state: forwarded to WireProtocol.detect()
        :returns: a WireProtocol on which detect() succeeded; its endpoint is persisted
        :raises ProtocolNotFoundError: if every attempt fails with ProtocolError
        :raises OSUtilError: if the detected endpoint cannot be persisted
        """
        self.clear_protocol()

        for retry in range(MAX_RETRY):
            try:
                endpoint = self.dhcp_handler.endpoint
                if endpoint is None:
                    # Check if DHCP can be used to get the wire protocol endpoint
                    dhcp_available = self.osutil.is_dhcp_available()
                    if dhcp_available:
                        logger.info("WireServer endpoint is not found. Rerun dhcp handler")
                        try:
                            self.dhcp_handler.run()
                        except DhcpError as e:
                            raise ProtocolError(ustr(e))
                        endpoint = self.dhcp_handler.endpoint
                    else:
                        logger.info("_detect_protocol: DHCP not available")
                        endpoint = self.get_wireserver_endpoint()

                try:
                    protocol = WireProtocol(endpoint)
                    protocol.detect(init_goal_state=init_goal_state, save_to_history=save_to_history)
                    self._set_wireserver_endpoint(endpoint)
                    return protocol
                except ProtocolError:
                    # Force the next attempt to re-resolve the endpoint instead of
                    # reusing the one that just failed.
                    logger.info("WireServer is not responding. Reset dhcp endpoint")
                    self.dhcp_handler.endpoint = None
                    self.dhcp_handler.skip_cache = True
                    # Bare raise keeps the original traceback (``raise e`` drops it on Python 2)
                    raise

            except ProtocolError as e:
                logger.info("Protocol endpoint not found: {0}", e)

            if retry < MAX_RETRY - 1:
                logger.info("Retry detect protocol: retry={0}", retry)
                time.sleep(PROBE_INTERVAL)

        raise ProtocolNotFoundError("No protocol found.")

    def _save_protocol(self, protocol_name):
        """
        Persist the detected protocol name to the protocol file.
        Write failures are logged, not raised.
        """
        protocol_file_path = self._get_protocol_file_path()
        try:
            fileutil.write_file(protocol_file_path, protocol_name)
        except (IOError, OSError) as e:
            logger.error("Failed to save protocol endpoint: {0}", e)

    def clear_protocol(self):
        """
        Cleanup previous saved protocol endpoint: the cached protocol instance, the
        wireserver endpoint (cache and file) and the protocol file. Removal failures
        other than "file not found" are logged, not raised.
        """
        with self._lock:
            logger.info("Clean protocol and wireserver endpoint")
            self._clear_wireserver_endpoint()
            self._protocol = None
            protocol_file_path = self._get_protocol_file_path()
            if not os.path.isfile(protocol_file_path):
                return

            try:
                os.remove(protocol_file_path)
            except (IOError, OSError) as e:
                # Ignore file-not-found errors (since the file is being removed)
                if e.errno == errno.ENOENT:
                    return
                logger.error("Failed to clear protocol endpoint: {0}", e)

    def get_protocol(self, init_goal_state=True, save_to_history=False):
        """
        Detect protocol by endpoint.

        Returns the cached instance if there is one. Otherwise, if the protocol file
        records WireProtocol, builds a WireProtocol for the saved endpoint without
        probing. As a last resort, runs full detection and saves the result.
        MetadataServer artifacts are cleaned up on both non-cached paths.

        :param init_goal_state: forwarded to WireProtocol.detect() (full detection only)
        :param save_to_history: forwarded to WireProtocol.detect() (full detection only)
        :returns: protocol instance
        :raises ProtocolNotFoundError: if detection fails on every attempt
        """
        with self._lock:
            if self._protocol is not None:
                return self._protocol

            # If the protocol file contains MetadataProtocol we need to fall through to
            # _detect_protocol so that we can generate the WireServer transport certificates.
            protocol_file_path = self._get_protocol_file_path()
            if os.path.isfile(protocol_file_path) and fileutil.read_file(protocol_file_path) == WIRE_PROTOCOL_NAME:
                endpoint = self.get_wireserver_endpoint()
                # Keep IO error accounting consistent with the detection path below
                IOErrorCounter.set_protocol_endpoint(endpoint=endpoint)
                self._protocol = WireProtocol(endpoint)

                # If metadataserver certificates are present we clean certificates
                # and remove MetadataServer firewall rule. It is possible
                # there was a previous intermediate upgrade before 2.2.48 but metadata artifacts
                # were not cleaned up (intermediate updated agent does not have cleanup
                # logic but we transitioned from Metadata to Wire protocol)
                if is_metadata_server_artifact_present():
                    cleanup_metadata_server_artifacts()
                return self._protocol

            logger.info("Detect protocol endpoint")

            protocol = self._detect_protocol(save_to_history=save_to_history, init_goal_state=init_goal_state)

            IOErrorCounter.set_protocol_endpoint(endpoint=protocol.get_endpoint())
            self._save_protocol(WIRE_PROTOCOL_NAME)

            self._protocol = protocol

            # Need to clean up MDS artifacts only after _detect_protocol so that we don't
            # delete MDS certificates if we can't reach WireServer and have to roll back
            # the update
            if is_metadata_server_artifact_present():
                cleanup_metadata_server_artifacts()

            return self._protocol