ad-joining/register-computer/ad/domain.py (356 lines of code) (raw):

#
# Copyright 2019 Google LLC
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#

"""Active Directory access helpers: LDAP connection, computer, group and OU objects."""

import json
import logging
import ssl
import time

import dns.resolver
import ldap3
import ldap3.core.exceptions
import ldap3.utils.conv
from ldap3 import Tls
from ldap3.core.exceptions import LDAPException, LDAPStrongerAuthRequiredResult


class LdapException(Exception):
    """Base exception for LDAP/Active Directory errors raised by this module."""
    pass


class NoSuchObjectException(LdapException):
    """Raised when the targeted directory object does not exist."""
    pass


class AlreadyExistsException(LdapException):
    """Raised when the object or attribute value to be created already exists."""
    pass


class DomainControllerLookupException(LdapException):
    """Raised when no domain controller could be located via DNS."""
    pass


class ActiveDirectoryConnection(object):
    """Wrapper around a bound ldap3 connection to an Active Directory domain controller."""

    LDAP_ATTRIBUTE_PROJECT_ID = "msDS-cloudExtensionAttribute1"
    LDAP_ATTRIBUTE_ZONE = "msDS-cloudExtensionAttribute2"
    LDAP_ATTRIBUTE_INSTANCE_NAME = "msDS-cloudExtensionAttribute3"
    LDAP_ATTRIBUTE_GROUP_DATA = "msDS-AzApplicationData"

    # groupType flag values; add_group adds them together to create the group.
    # The security flag is 0x80000000 expressed as a signed 32-bit integer.
    ACTIVE_DIRECTORY_GROUP_TYPE_DOMAIN_LOCAL = 4
    ACTIVE_DIRECTORY_GROUP_TYPE_SECURITY = -2147483648

    # Maximum number of attempts for operations that are retried when the
    # LDAP endpoint reports that it is busy.
    LDAP_OPERATION_RETRIES = 5

    # OID of the LDAP paged results control and the page size used by searches.
    _PAGED_RESULTS_CONTROL_OID = "1.2.840.113556.1.4.319"
    _PAGE_SIZE = 100

    def __init__(self, domain_controller, connection, base_dn):
        """Wrap an already bound ldap3.Connection.

        Args:
            domain_controller: Host name of the domain controller connected to.
            connection: Bound ldap3.Connection used for all operations.
            base_dn: Base DN of the domain, e.g. "DC=example,DC=com".
        """
        assert isinstance(connection, ldap3.Connection)
        self.__connection = connection
        self.__base_dn = base_dn
        self.__domain_controller = domain_controller

    def __to_scalar(self, value):
        """Return an attribute value as str, or None when it is empty/unset."""
        if not value:
            return None
        return str(value)

    def __get_paged_cookie(self):
        """Return the paged-results cookie of the last search, or None if absent.

        An empty/missing cookie means the server has no further pages.
        """
        controls = self.__connection.result.get("controls") or {}
        paged_control = controls.get(self._PAGED_RESULTS_CONTROL_OID)
        if not paged_control:
            return None
        return paged_control.get("value", {}).get("cookie")

    def __find(self, converter, search_filter, search_base_dn, search_scope, attributes):
        """Run a paged LDAP search and convert every returned entry.

        Args:
            converter: Callable turning an ldap3 entry into a domain object.
            search_filter: LDAP filter string.
            search_base_dn: DN to start the search from.
            search_scope: ldap3 scope (BASE, LEVEL or SUBTREE).
            attributes: Attribute names to retrieve.

        Returns:
            List with one converted object per entry, across all pages.
        """
        results = []
        cookie = None
        while True:
            # The first request has no cookie; subsequent requests pass the
            # cookie returned by the previous page.
            self.__connection.search(
                search_filter=search_filter,
                search_base=search_base_dn,
                search_scope=search_scope,
                attributes=attributes,
                paged_size=self._PAGE_SIZE,
                paged_cookie=cookie)

            results.extend(converter(entry) for entry in self.__connection.entries)

            cookie = self.__get_paged_cookie()
            if not cookie:
                return results

    def __to_ou(self, entry):
        """Convert an LDAP entry to an OrganizationalUnit."""
        return OrganizationalUnit(
            entry.entry_dn,
            self.__to_scalar(entry["name"])
        )

    def __to_computer(self, entry):
        """Convert an LDAP entry to a Computer."""
        return Computer(
            entry.entry_dn,
            self.__to_scalar(entry["name"]),
            self.__to_scalar(entry[self.LDAP_ATTRIBUTE_PROJECT_ID]),
            self.__to_scalar(entry[self.LDAP_ATTRIBUTE_ZONE]),
            self.__to_scalar(entry[self.LDAP_ATTRIBUTE_INSTANCE_NAME]),
            self.__to_scalar(entry["dNSHostName"])
        )

    def __to_group(self, entry):
        """Convert an LDAP entry to a Group."""
        return Group(
            entry.entry_dn,
            self.__to_scalar(entry["name"]),
            self.__to_scalar(entry[self.LDAP_ATTRIBUTE_GROUP_DATA])
        )

    @staticmethod
    def locate_domain_controllers(domain_name, site_name):
        """Look up domain controllers of a domain via DNS SRV records.

        Args:
            domain_name: DNS name of the Active Directory domain.
            site_name: Optional AD site name. When non-empty, the site-specific
                SRV record is queried instead of the domain-wide one.

        Returns:
            List of DC host names without trailing dot, most preferred first:
            ascending SRV priority, then descending weight, then target name.

        Raises:
            DomainControllerLookupException: if the query returns no records.
        """
        query = "_ldap._tcp"

        # Use site-awareness if site was provided
        if site_name:
            query += f".{site_name}._sites"
            logging.info(f"Using site-awareness to select closest DC for site '{site_name}'")

        query += f".dc._msdcs.{domain_name}"

        records = dns.resolver.query(query, "SRV")
        if len(records) == 0:
            raise DomainControllerLookupException("No SRV records found for %s" % domain_name)

        # RFC 2782: clients must try the lowest priority value first; among
        # records of equal priority, a higher weight should be preferred.
        records_sorted = sorted(records, key=lambda r: (r.priority, -r.weight, r.target))

        host_names = []
        for record in records_sorted:
            target = str(record.target)
            host_names.append(target[:-1] if target.endswith(".") else target)
        return host_names

    @staticmethod
    def connect(domain_controller, base_dn, user, password, use_ldaps=False, certificate_data=None):
        """Bind to a domain controller using NTLM authentication.

        Args:
            domain_controller: Host name of the domain controller.
            base_dn: Base DN of the domain.
            user: User name to bind as.
            password: Password of that user.
            use_ldaps: If True, connect over LDAPS (port 636) and require a
                valid server certificate; otherwise use plain LDAP (port 389).
            certificate_data: Optional CA certificate data used to validate
                the server certificate when use_ldaps is True.

        Returns:
            An ActiveDirectoryConnection wrapping the bound connection.

        Raises:
            LdapException: if binding fails.
        """
        logging.info("Connecting to LDAP endpoint of '%s' as '%s'" % (domain_controller, user))

        if use_ldaps:
            logging.info("Using LDAP over SSL/TLS")
            if certificate_data is not None:
                logging.debug("Using CA certificate data from Secret Manager")

            # Tls' first positional parameter is local_private_key_file, so all
            # settings are passed by keyword.
            tls_configuration = Tls(
                validate=ssl.CERT_REQUIRED,
                ca_certs_data=certificate_data)
            server = ldap3.Server(domain_controller, port=636, connect_timeout=5, use_ssl=True, tls=tls_configuration)
        else:
            server = ldap3.Server(domain_controller, port=389, connect_timeout=5, use_ssl=False)

        connection = ldap3.Connection(server, user=user, password=password,
                                      authentication=ldap3.NTLM, raise_exceptions=True, receive_timeout=20)

        try:
            if connection.bind():
                return ActiveDirectoryConnection(domain_controller, connection, base_dn)
        except LDAPStrongerAuthRequiredResult:
            logging.exception("Failed to connect to LDAP endpoint: Active Directory requires LDAPS for NTLM binds")
        except LDAPException as e:
            logging.warning("Failed to connect to LDAP endpoint: %s" % e)

        # LDAP connection could not be established, raise exception
        raise LdapException("Connecting to LDAP endpoint of '%s' as '%s' failed" % (domain_controller, user))

    def get_domain_controller(self):
        """Return the host name of the domain controller this connection uses."""
        return self.__domain_controller

    def find_ou(self, search_base_dn, name=None, includeDescendants=True):
        """Find organizational units.

        Args:
            search_base_dn: DN to search from.
            name: Optional OU name to match (escaped before use in the filter).
            includeDescendants: If True, search the whole subtree; otherwise
                only the base entry itself is examined.

        Returns:
            List of OrganizationalUnit objects; empty if the base DN does not exist.
        """
        if name:
            search_filter = f"(&(objectClass=organizationalUnit)(name={ldap3.utils.conv.escape_filter_chars(name)}))"
        else:
            search_filter = "(objectClass=organizationalUnit)"

        search_scope = ldap3.SUBTREE if includeDescendants else ldap3.BASE

        try:
            return self.__find(
                converter=self.__to_ou,
                search_filter=search_filter,
                search_base_dn=search_base_dn,
                search_scope=search_scope,
                attributes=[
                    "distinguishedName",
                    "name"
                ]
            )
        except ldap3.core.exceptions.LDAPNoSuchObjectResult:
            # In case OU was not found, return an empty array instead of raising an exception
            return []

    def find_computer(self, search_base_dn):
        """Find computer objects.

        Args:
            search_base_dn: Either the DN of a specific computer (starting with
                "CN=") or a container DN whose direct children are searched.

        Returns:
            List of Computer objects.

        Raises:
            NoSuchObjectException: if search_base_dn does not exist.
        """
        # Search either for the specific computer or in the base DN (but not its descendants)
        search_scope = ldap3.BASE if search_base_dn.startswith("CN=") else ldap3.LEVEL

        try:
            return self.__find(
                converter=self.__to_computer,
                search_filter="(objectClass=computer)",
                search_base_dn=search_base_dn,
                search_scope=search_scope,
                attributes=[
                    "distinguishedName",
                    "name",
                    ActiveDirectoryConnection.LDAP_ATTRIBUTE_PROJECT_ID,
                    ActiveDirectoryConnection.LDAP_ATTRIBUTE_ZONE,
                    ActiveDirectoryConnection.LDAP_ATTRIBUTE_INSTANCE_NAME,
                    "dNSHostName"
                ]
            )
        except ldap3.core.exceptions.LDAPNoSuchObjectResult as e:
            raise NoSuchObjectException(e)

    def add_computer(self, ou, computer_name, upn, project_id, zone, instance_name):
        """Create a computer account in the given OU.

        Args:
            ou: DN of the OU to create the account in.
            computer_name: Name of the computer (sAMAccountName gets a "$" suffix).
            upn: userPrincipalName to assign.
            project_id: Value stored in LDAP_ATTRIBUTE_PROJECT_ID.
            zone: Value stored in LDAP_ATTRIBUTE_ZONE.
            instance_name: Value stored in LDAP_ATTRIBUTE_INSTANCE_NAME.

        Returns:
            DN of the created computer object.

        Raises:
            AlreadyExistsException: if the object already exists.
        """
        WORKSTATION_TRUST_ACCOUNT = 0x1000
        PASSWD_NOTREQD = 0x20
        dn = "CN=%s,%s" % (computer_name, ou)

        try:
            self.__connection.add(
                dn,
                [
                    # objectClass
                    "computer",
                    "organizationalPerson",
                    "person",
                    "user",
                    "top"
                ],
                {
                    # Mandatory attributes for a computer object.
                    "objectClass": "computer",
                    "sAMAccountName": computer_name + "$",
                    "userPrincipalName": upn,
                    "userAccountControl": WORKSTATION_TRUST_ACCOUNT | PASSWD_NOTREQD,
                    ActiveDirectoryConnection.LDAP_ATTRIBUTE_PROJECT_ID: project_id,
                    ActiveDirectoryConnection.LDAP_ATTRIBUTE_ZONE: zone,
                    ActiveDirectoryConnection.LDAP_ATTRIBUTE_INSTANCE_NAME: instance_name
                })
            return dn
        except ldap3.core.exceptions.LDAPEntryAlreadyExistsResult as e:
            raise AlreadyExistsException(e)

    def remove_computer_upn(self, ou, computer_name):
        """Delete the userPrincipalName attribute of computer CN=<computer_name>,<ou>."""
        try:
            self.__connection.modify(
                "CN=%s,%s" % (computer_name, ou),
                {
                    "userPrincipalName": [(ldap3.MODIFY_DELETE, [])]
                })
        except ldap3.core.exceptions.LDAPAttributeOrValueExistsResult as e:
            raise AlreadyExistsException(e)

    def set_computer_upn(self, ou, computer_name, upn):
        """Replace the userPrincipalName of computer CN=<computer_name>,<ou>."""
        try:
            self.__connection.modify(
                "CN=%s,%s" % (computer_name, ou),
                {
                    "userPrincipalName": [(ldap3.MODIFY_REPLACE, [upn])]
                })
        except ldap3.core.exceptions.LDAPAttributeOrValueExistsResult as e:
            raise AlreadyExistsException(e)

    def set_computer_zone(self, ou, computer_name, zone):
        """Replace the zone attribute of computer CN=<computer_name>,<ou>."""
        try:
            self.__connection.modify(
                "CN=%s,%s" % (computer_name, ou),
                {
                    ActiveDirectoryConnection.LDAP_ATTRIBUTE_ZONE: [(ldap3.MODIFY_REPLACE, [zone])]
                })
        except ldap3.core.exceptions.LDAPAttributeOrValueExistsResult as e:
            raise AlreadyExistsException(e)

    def delete_computer(self, computer_dn):
        """Delete a computer object including its children.

        Raises:
            NoSuchObjectException: if the object does not exist.
        """
        try:
            # Computer accounts can have children. Use LDAP_SERVER_TREE_DELETE_OID
            # to perform a recursive delete operation (with criticality = True).
            recursive_delete = ('1.2.840.113556.1.4.805', True, None)
            self.__connection.delete(computer_dn, controls=[recursive_delete])
        except ldap3.core.exceptions.LDAPNoSuchObjectResult as e:
            raise NoSuchObjectException(e)

    def delete_dns_record(self, dns_record_dn):
        """Delete a DNS record object.

        Raises:
            NoSuchObjectException: if the object does not exist.
        """
        try:
            self.__connection.delete(dns_record_dn)
        except ldap3.core.exceptions.LDAPNoSuchObjectResult as e:
            raise NoSuchObjectException(e)

    def get_netbios_name(self):
        """Return the NetBIOS name found under CN=Partitions,CN=Configuration,<base DN>.

        Raises:
            LdapException: if no entry with a nETBIOSNAME is found.
        """
        self.__connection.search(
            search_filter="(nETBIOSNAME=*)",
            search_base="CN=Partitions,CN=Configuration," + self.__base_dn,
            attributes=["nETBIOSNAME"])

        if len(self.__connection.entries) == 0:
            raise LdapException("Partitions information not found in directory")
        return self.__to_scalar(self.__connection.entries[0]["nETBIOSNAME"])

    def get_upn_by_samaccountname(self, samaccountname):
        """Return the userPrincipalName of the user with the given sAMAccountName.

        A "DOMAIN\\user" style name is accepted; the part after the first
        backslash is used.

        Raises:
            LdapException: if no matching user is found.
        """
        if "\\" in samaccountname:
            samaccountname = samaccountname.split("\\")[1]

        self.__connection.search(
            search_filter="(&(objectClass=user)(sAMAccountName=%s))" % ldap3.utils.conv.escape_filter_chars(samaccountname),
            search_base=self.__base_dn,
            attributes=["userPrincipalName"])

        if len(self.__connection.entries) == 0:
            raise LdapException("User '%s' not found in directory" % samaccountname)
        return self.__to_scalar(self.__connection.entries[0]["userPrincipalName"])

    def find_group(self, search_base_dn):
        """Find group objects.

        Args:
            search_base_dn: Either the DN of a specific group (starting with
                "CN=") or a container DN whose direct children are searched.

        Returns:
            List of Group objects.

        Raises:
            NoSuchObjectException: if search_base_dn does not exist.
        """
        # Search either for the specific group or in the base DN (but not its descendants)
        search_scope = ldap3.BASE if search_base_dn.startswith("CN=") else ldap3.LEVEL

        try:
            return self.__find(
                converter=self.__to_group,
                search_filter="(&(objectClass=group))",
                search_base_dn=search_base_dn,
                search_scope=search_scope,
                attributes=[
                    "distinguishedName",
                    "name",
                    ActiveDirectoryConnection.LDAP_ATTRIBUTE_GROUP_DATA
                ]
            )
        except ldap3.core.exceptions.LDAPNoSuchObjectResult as e:
            raise NoSuchObjectException(e)

    def add_group(self, ou, group_name, project_id, zone, region):
        """Create a domain-local security group and store its metadata as JSON.

        Returns:
            DN of the created group.

        Raises:
            AlreadyExistsException: if the group already exists.
        """
        metadata = {
            "project_id": project_id,
            "zone": zone,
            "region": region
        }
        group_metadata = json.dumps(metadata)
        dn = "CN=%s,%s" % (group_name, ou)

        try:
            self.__connection.add(
                dn,
                [
                    "group",
                    "top"
                ],
                {
                    # Mandatory attributes for a group object.
                    "groupType": self.ACTIVE_DIRECTORY_GROUP_TYPE_DOMAIN_LOCAL + self.ACTIVE_DIRECTORY_GROUP_TYPE_SECURITY,
                    "objectClass": "group",
                    "name": group_name,
                    "description": "Group for computers of MIG '%s'" % (group_name),
                    ActiveDirectoryConnection.LDAP_ATTRIBUTE_GROUP_DATA: group_metadata
                })
            return dn
        except ldap3.core.exceptions.LDAPEntryAlreadyExistsResult as e:
            raise AlreadyExistsException(e)

    def add_member_to_group(self, ou, group_name, computer_dn):
        """Add computer_dn as a member of group CN=<group_name>,<ou>.

        Busy results are retried up to LDAP_OPERATION_RETRIES times with a one
        second pause. An existing membership is treated as success. If all
        attempts report busy, an error is logged and no exception is raised.
        """
        group_dn = "CN=%s,%s" % (group_name, ou)

        for _ in range(self.LDAP_OPERATION_RETRIES):
            try:
                self.__connection.modify(
                    group_dn,
                    {
                        'member': [(ldap3.MODIFY_ADD, [computer_dn])]
                    })
                return
            except ldap3.core.exceptions.LDAPBusyResult:
                logging.warning(f"LDAP endpoint is busy, retrying operation 'add_member_to_group' for '{computer_dn}'")
                time.sleep(1)
            except ldap3.core.exceptions.LDAPEntryAlreadyExistsResult:
                # Already a member: the desired state is reached, stop retrying.
                logging.info(f"'{computer_dn}' already part of '{group_name}'")
                return

        logging.error(
            f"Giving up operation 'add_member_to_group' for '{computer_dn}' "
            f"after {self.LDAP_OPERATION_RETRIES} attempts")

    def delete_group(self, group_dn):
        """Delete the group with the given DN."""
        self.__connection.delete(group_dn)

    def get_user(self):
        """Return the user name the connection is bound as."""
        return self.__connection.user


class NamedObject(object):
    """Directory object identified by its DN and name."""

    def __init__(self, dn, name):
        self.__dn = dn
        self.__name = name

    def get_dn(self):
        """Return the distinguished name."""
        return self.__dn

    def get_name(self):
        """Return the object name."""
        return self.__name


class OrganizationalUnit(NamedObject):
    """An organizational unit."""
    pass


class Computer(NamedObject):
    """A computer account together with the GCE instance attributes stored on it."""

    def __init__(self, dn, name, project_id, zone, instance_name, dns_hostname):
        super(Computer, self).__init__(dn, name)
        self.__project_id = project_id
        self.__zone = zone
        self.__instance_name = instance_name
        self.__dns_hostname = dns_hostname

    def get_instance_name(self):
        """Return the stored instance name, or None."""
        return self.__instance_name

    def get_zone(self):
        """Return the stored zone, or None."""
        return self.__zone

    def get_project_id(self):
        """Return the stored project ID, or None."""
        return self.__project_id

    def get_dns_record_dn(self):
        """
        DN of corresponding DNS record, for example:
        DC=host,DC=domain.tld,CN=MicrosoftDNS,DC=DomainDnsZones,DC=domain,DC=tld

        Returns None if the computer has no DNS host name, or if the host name
        has a single label (no zone part to derive a DN from).
        """
        if not self.__dns_hostname:
            # Some computer objects might not have a DNS hostname
            return None

        hostname, _, zone = self.__dns_hostname.lower().partition('.')
        if not zone:
            # Without a zone the DN would be malformed ("DC=host,DC=,...").
            return None

        zone_dn = ','.join("DC=" + label for label in zone.split('.'))
        return "DC=%s,DC=%s,CN=MicrosoftDNS,DC=DomainDnsZones,%s" % (hostname, zone, zone_dn)


class Group(NamedObject):
    """A group whose project/zone/region metadata is stored as a JSON document."""

    def __init__(self, dn, name, group_metadata):
        super(Group, self).__init__(dn, name)

        metadata = json.loads(group_metadata) if group_metadata else {}

        self.__region = metadata.get("region")
        self.__zone = metadata.get("zone")
        self.__project_id = metadata.get("project_id")

    def get_project_id(self):
        """Return the project ID from the metadata, or None."""
        return self.__project_id

    def get_region(self):
        """Return the region from the metadata, or None."""
        return self.__region

    def get_zone(self):
        """Return the zone from the metadata, or None."""
        return self.__zone