Skip to content

Commit

Permalink
starting unit tests for aws_ec2
Browse files Browse the repository at this point in the history
  • Loading branch information
abikouo committed Nov 8, 2022
1 parent 30c14eb commit a5309df
Show file tree
Hide file tree
Showing 6 changed files with 705 additions and 495 deletions.
261 changes: 131 additions & 130 deletions plugins/inventory/aws_ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,20 +269,16 @@
except ImportError:
pass # will be captured by imported HAS_BOTO3

from ansible.errors import AnsibleError
from ansible.module_utils._text import to_native
from ansible.module_utils._text import to_text
from ansible.module_utils.basic import missing_required_lib

from ansible.template import Templar
from ansible.plugins.inventory import BaseInventoryPlugin
from ansible.plugins.inventory import Cacheable
from ansible.plugins.inventory import Constructable
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import HAS_BOTO3
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import ansible_dict_to_boto3_filter_list
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import boto3_tag_list_to_ansible_dict
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import camel_dict_to_snake_dict
from ansible_collections.amazon.aws.plugins.module_utils.inventory import AnsibleAWSInventory
from ansible_collections.amazon.aws.plugins.plugin_utils.inventory import AWSInventoryBase


# The mappings give an array of keys to get from the filter name to the value
Expand Down Expand Up @@ -378,63 +374,113 @@
}


class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable, AnsibleAWSInventory):
def _get_tag_hostname(preference, instance):
tag_hostnames = preference.split('tag:', 1)[1]
if ',' in tag_hostnames:
tag_hostnames = tag_hostnames.split(',')
else:
tag_hostnames = [tag_hostnames]

tags = boto3_tag_list_to_ansible_dict(instance.get('Tags', []))
tag_values = []
for v in tag_hostnames:
if '=' in v:
tag_name, tag_value = v.split('=')
if tags.get(tag_name) == tag_value:
tag_values.append(to_text(tag_name) + "_" + to_text(tag_value))
else:
tag_value = tags.get(v)
if tag_value:
tag_values.append(to_text(tag_value))
return tag_values

NAME = 'amazon.aws.aws_ec2'

def __init__(self):
def _prepare_host_vars(original_host_vars, hostvars_prefix=None, hostvars_suffix=None,
use_contrib_script_compatible_ec2_tag_keys=False):
host_vars = camel_dict_to_snake_dict(original_host_vars, ignore_list=['Tags'])
host_vars['tags'] = boto3_tag_list_to_ansible_dict(original_host_vars.get('Tags', []))

super(InventoryModule, self).__init__()
# Allow easier grouping by region
host_vars['placement']['region'] = host_vars['placement']['availability_zone'][:-1]

self.group_prefix = 'aws_ec2_'
if use_contrib_script_compatible_ec2_tag_keys:
for k, v in host_vars['tags'].items():
host_vars["ec2_tag_%s" % k] = v

AnsibleAWSInventory.__init__(self)
if hostvars_prefix or hostvars_suffix:
for hostvar, hostval in host_vars.copy().items():
del host_vars[hostvar]
if hostvars_prefix:
hostvar = hostvars_prefix + hostvar
if hostvars_suffix:
hostvar = hostvar + hostvars_suffix
host_vars[hostvar] = hostval

def _compile_values(self, obj, attr):
'''
:param obj: A list or dict of instance attributes
:param attr: A key
:return The value(s) found via the attr
'''
if obj is None:
return
return host_vars

temp_obj = []

if isinstance(obj, list) or isinstance(obj, tuple):
for each in obj:
value = self._compile_values(each, attr)
if value:
temp_obj.append(value)
else:
temp_obj = obj.get(attr)
def _compile_values(obj, attr):
'''
:param obj: A list or dict of instance attributes
:param attr: A key
:return The value(s) found via the attr
'''
if obj is None:
return

has_indexes = any([isinstance(temp_obj, list), isinstance(temp_obj, tuple)])
if has_indexes and len(temp_obj) == 1:
return temp_obj[0]
temp_obj = []

return temp_obj
if isinstance(obj, list) or isinstance(obj, tuple):
for each in obj:
value = _compile_values(each, attr)
if value:
temp_obj.append(value)
else:
temp_obj = obj.get(attr)

def _get_boto_attr_chain(self, filter_name, instance):
'''
:param filter_name: The filter
:param instance: instance dict returned by boto3 ec2 describe_instances()
'''
allowed_filters = sorted(list(instance_data_filter_to_boto_attr.keys()) + list(instance_meta_filter_to_boto_attr.keys()))
has_indexes = any([isinstance(temp_obj, list), isinstance(temp_obj, tuple)])
if has_indexes and len(temp_obj) == 1:
return temp_obj[0]

# If filter not in allow_filters -> use it as a literal string
if filter_name not in allowed_filters:
return filter_name
return temp_obj


def _get_boto_attr_chain(filter_name, instance):
'''
:param filter_name: The filter
:param instance: instance dict returned by boto3 ec2 describe_instances()
'''
allowed_filters = sorted(list(instance_data_filter_to_boto_attr.keys()) + list(instance_meta_filter_to_boto_attr.keys()))

# If filter not in allow_filters -> use it as a literal string
if filter_name not in allowed_filters:
return filter_name

if filter_name in instance_data_filter_to_boto_attr:
boto_attr_list = instance_data_filter_to_boto_attr[filter_name]
else:
boto_attr_list = instance_meta_filter_to_boto_attr[filter_name]

instance_value = instance
for attribute in boto_attr_list:
instance_value = _compile_values(instance_value, attribute)
return instance_value


def _describe_ec2_instances(connection, filters):
paginator = connection.get_paginator('describe_instances')
return paginator.paginate(Filters=filters).build_full_result()

if filter_name in instance_data_filter_to_boto_attr:
boto_attr_list = instance_data_filter_to_boto_attr[filter_name]
else:
boto_attr_list = instance_meta_filter_to_boto_attr[filter_name]

instance_value = instance
for attribute in boto_attr_list:
instance_value = self._compile_values(instance_value, attribute)
return instance_value
class InventoryModule(AWSInventoryBase):

NAME = 'amazon.aws.aws_ec2'

def __init__(self):

super(InventoryModule, self).__init__()

self.group_prefix = 'aws_ec2_'

def _get_instances_by_region(self, regions, filters, strict_permissions):
'''
Expand All @@ -445,59 +491,36 @@ def _get_instances_by_region(self, regions, filters, strict_permissions):
'''
all_instances = []

# By default find non-terminated/terminating instances
if not any(f['Name'] == 'instance-state-name' for f in filters):
filters.append({'Name': 'instance-state-name', 'Values': ['running', 'pending', 'stopping', 'stopped']})

for connection, _region in self._boto3_conn(regions, "ec2"):
try:
# By default find non-terminated/terminating instances
if not any(f['Name'] == 'instance-state-name' for f in filters):
filters.append({'Name': 'instance-state-name', 'Values': ['running', 'pending', 'stopping', 'stopped']})
paginator = connection.get_paginator('describe_instances')
reservations = paginator.paginate(Filters=filters).build_full_result().get('Reservations')
reservations = _describe_ec2_instances(connection, filters).get('Reservations')
instances = []
for r in reservations:
new_instances = r['Instances']
reservation_details = {
'OwnerId': r['OwnerId'],
'RequesterId': r.get('RequesterId', ''),
'ReservationId': r['ReservationId']
}
for instance in new_instances:
instance.update(self._get_reservation_details(r))
instance.update(reservation_details)
instances.extend(new_instances)
except botocore.exceptions.ClientError as e:
if e.response['ResponseMetadata']['HTTPStatusCode'] == 403 and not strict_permissions:
instances = []
else:
raise AnsibleError("Failed to describe instances: %s" % to_native(e))
self.fail_aws("Failed to describe instances: %s" % to_native(e))
except botocore.exceptions.BotoCoreError as e:
raise AnsibleError("Failed to describe instances: %s" % to_native(e))
self.fail_aws("Failed to describe instances: %s" % to_native(e))

all_instances.extend(instances)

return all_instances

def _get_reservation_details(self, reservation):
return {
'OwnerId': reservation['OwnerId'],
'RequesterId': reservation.get('RequesterId', ''),
'ReservationId': reservation['ReservationId']
}

@classmethod
def _get_tag_hostname(cls, preference, instance):
tag_hostnames = preference.split('tag:', 1)[1]
if ',' in tag_hostnames:
tag_hostnames = tag_hostnames.split(',')
else:
tag_hostnames = [tag_hostnames]

tags = boto3_tag_list_to_ansible_dict(instance.get('Tags', []))
tag_values = []
for v in tag_hostnames:
if '=' in v:
tag_name, tag_value = v.split('=')
if tags.get(tag_name) == tag_value:
tag_values.append(to_text(tag_name) + "_" + to_text(tag_value))
else:
tag_value = tags.get(v)
if tag_value:
tag_values.append(to_text(tag_value))
return tag_values

def _sanitize_hostname(self, hostname):
if ':' in to_text(hostname):
return self._sanitize_group_name(to_text(hostname))
Expand All @@ -517,23 +540,25 @@ def _get_preferred_hostname(self, instance, hostnames):
for preference in hostnames:
if isinstance(preference, dict):
if 'name' not in preference:
raise AnsibleError("A 'name' key must be defined in a hostnames dictionary.")
self.fail_aws("A 'name' key must be defined in a hostnames dictionary.")
hostname = self._get_preferred_hostname(instance, [preference["name"]])
hostname_from_prefix = self._get_preferred_hostname(instance, [preference["prefix"]])
hostname_from_prefix = None
if "prefix" in preference:
hostname_from_prefix = self._get_preferred_hostname(instance, [preference["prefix"]])
separator = preference.get("separator", "_")
if hostname and hostname_from_prefix and 'prefix' in preference:
hostname = hostname_from_prefix + separator + hostname
elif preference.startswith('tag:'):
tags = self._get_tag_hostname(preference, instance)
tags = _get_tag_hostname(preference, instance)
hostname = tags[0] if tags else None
else:
hostname = self._get_boto_attr_chain(preference, instance)
hostname = _get_boto_attr_chain(preference, instance)
if hostname:
break
if hostname:
return self._sanitize_hostname(hostname)

def get_all_hostnames(self, instance, hostnames):
def _get_all_hostnames(self, instance, hostnames):
'''
:param instance: an instance dict returned by boto3 ec2 describe_instances()
:param hostnames: a list of hostname destination variables
Expand All @@ -547,16 +572,18 @@ def get_all_hostnames(self, instance, hostnames):
for preference in hostnames:
if isinstance(preference, dict):
if 'name' not in preference:
raise AnsibleError("A 'name' key must be defined in a hostnames dictionary.")
hostname = self.get_all_hostnames(instance, [preference["name"]])
hostname_from_prefix = self.get_all_hostnames(instance, [preference["prefix"]])
self.fail_aws("A 'name' key must be defined in a hostnames dictionary.")
hostname = self._get_all_hostnames(instance, [preference["name"]])
hostname_from_prefix = None
if 'prefix' in preference:
hostname_from_prefix = self._get_all_hostnames(instance, [preference["prefix"]])
separator = preference.get("separator", "_")
if hostname and hostname_from_prefix and 'prefix' in preference:
hostname = hostname_from_prefix[0] + separator + hostname[0]
elif preference.startswith('tag:'):
hostname = self._get_tag_hostname(preference, instance)
hostname = _get_tag_hostname(preference, instance)
else:
hostname = self._get_boto_attr_chain(preference, instance)
hostname = _get_boto_attr_chain(preference, instance)

if hostname:
if isinstance(hostname, list):
Expand Down Expand Up @@ -611,41 +638,17 @@ def _populate(self, groups, hostnames, allow_duplicated_hosts=False,
use_contrib_script_compatible_ec2_tag_keys=use_contrib_script_compatible_ec2_tag_keys)
self.inventory.add_child('all', group)

@classmethod
def prepare_host_vars(cls, original_host_vars, hostvars_prefix=None, hostvars_suffix=None,
use_contrib_script_compatible_ec2_tag_keys=False):
host_vars = camel_dict_to_snake_dict(original_host_vars, ignore_list=['Tags'])
host_vars['tags'] = boto3_tag_list_to_ansible_dict(original_host_vars.get('Tags', []))

# Allow easier grouping by region
host_vars['placement']['region'] = host_vars['placement']['availability_zone'][:-1]

if use_contrib_script_compatible_ec2_tag_keys:
for k, v in host_vars['tags'].items():
host_vars["ec2_tag_%s" % k] = v

if hostvars_prefix or hostvars_suffix:
for hostvar, hostval in host_vars.copy().items():
del host_vars[hostvar]
if hostvars_prefix:
hostvar = hostvars_prefix + hostvar
if hostvars_suffix:
hostvar = hostvar + hostvars_suffix
host_vars[hostvar] = hostval

return host_vars

def iter_entry(self, hosts, hostnames, allow_duplicated_hosts=False, hostvars_prefix=None,
hostvars_suffix=None, use_contrib_script_compatible_ec2_tag_keys=False):
for host in hosts:
if allow_duplicated_hosts:
hostname_list = self.get_all_hostnames(host, hostnames)
hostname_list = self._get_all_hostnames(host, hostnames)
else:
hostname_list = [self._get_preferred_hostname(host, hostnames)]
if not hostname_list or hostname_list[0] is None:
continue

host_vars = self.prepare_host_vars(
host_vars = _prepare_host_vars(
host,
hostvars_prefix,
hostvars_suffix,
Expand Down Expand Up @@ -694,34 +697,32 @@ def verify_file(self, path):
:param path: the path to the inventory config file
:return the contents of the config file
'''
inventory_file_suffix = ('aws_ec2.yml', 'aws_ec2.yaml')
if super(InventoryModule, self).verify_file(path):
if path.endswith(('aws_ec2.yml', 'aws_ec2.yaml')):
if path.endswith(inventory_file_suffix):
return True
self.display.debug("aws_ec2 inventory filename must end with 'aws_ec2.yml' or 'aws_ec2.yaml'")
self.display.debug(f"aws_ec2 inventory filename must end with {inventory_file_suffix}")
return False

def build_include_filters(self):
result = self.get_option('include_filters')
if self.get_option('filters'):
return [self.get_option('filters')] + self.get_option('include_filters')
elif self.get_option('include_filters'):
return self.get_option('include_filters')
else: # no filter
return [{}]
result = [self.get_option('filters')] + result
return result or [{}]

def parse(self, inventory, loader, path, cache=True):

super(InventoryModule, self).parse(inventory, loader, path)

if not HAS_BOTO3:
raise AnsibleError(missing_required_lib('botocore and boto3'))
self.fail_aws(missing_required_lib('botocore and boto3'))

self._read_config_data(path)

if self.get_option('use_contrib_script_compatible_sanitization'):
self._sanitize_group_name = self._legacy_script_compatible_group_sanitization

templar = Templar(loader=loader)
self._set_credentials(templar)
self._set_credentials(loader)

# get user specifications
regions = self.get_option('regions')
Expand Down
Loading

0 comments on commit a5309df

Please sign in to comment.