Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
aws
GitHub Repository: aws/aws-cli
Path: blob/develop/tests/__init__.py
1566 views
1
import collections
2
import copy
3
import os
4
import sys
5
import unittest
6
7
# nose and py.test both put the nearest parent directory that lacks an
# __init__.py onto sys.path. For this package that directory is the
# repository root, so ``import awscli`` resolves to the source checkout
# rather than to an installed distribution. Setting the environment
# variable TESTS_REMOVE_REPO_ROOT_FROM_PATH strips the repository root back
# out of sys.path so the tests run against the installed distribution.
if os.environ.get('TESTS_REMOVE_REPO_ROOT_FROM_PATH'):
    # Two levels up from this file: tests/__init__.py -> tests/ -> repo root.
    rootdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    # Drop only entries that are existing directories referring to the same
    # filesystem location as the repository root; everything else is kept.
    sys.path = [
        entry
        for entry in sys.path
        if not (os.path.isdir(entry) and os.path.samefile(entry, rootdir))
    ]
20
21
import awscli
22
from awscli.clidriver import create_clidriver
23
from awscli.compat import collections_abc
24
from awscli.testutils import mock, capture_output
25
26
import botocore.awsrequest
27
import botocore.loaders
28
import botocore.model
29
import botocore.serialize
30
import botocore.validate
31
from botocore.compat import HAS_CRT
32
33
# Module-wide botocore Loader used by the helpers below. Loading service
# models through this one Loader, rather than through a session, keeps the
# loads independent of any session's configuration and lets repeated loads
# benefit from the Loader's caching, which keeps the tests fast.
_LOADER = botocore.loaders.Loader()
37
38
39
class CLIRunner(object):
    """Runs CLI commands in a stubbed environment.

    Every ``run`` builds a fresh CLI driver, swaps ``os.environ`` for
    ``self.env``, and captures stdout/stderr. AWS traffic is intercepted by a
    ``SessionStubber``, which answers requests with the responses queued via
    ``add_response``.
    """

    def __init__(self, env=None, session_stubber=None):
        # Fall back to a minimal environment with dummy credentials and a
        # private stubber when the caller does not provide them.
        self.env = self._get_default_env() if env is None else env
        self._session_stubber = (
            SessionStubber() if session_stubber is None else session_stubber
        )

    def run(self, cmdline):
        """Execute ``cmdline`` and return a ``CLIRunnerResult``.

        The result holds the driver's return code, the captured stdout and
        stderr text, and a shallow copy of the stubber's recorded AWS
        requests (the stubber's list is never cleared, so it includes
        requests from earlier runs that used the same stubber).
        """
        with mock.patch('os.environ', self.env), capture_output() as captured:
            result = self._do_run(cmdline)
            result.stdout = captured.stdout.getvalue()
            result.stderr = captured.stderr.getvalue()
            return result

    def add_response(self, response):
        """Queue ``response`` to answer the next HTTP request sent."""
        self._session_stubber.add_response(response)

    def _get_default_env(self):
        # awscli/__init__.py injects AWS_DATA_PATH at import time so that
        # cli.json can be found. Because os.environ is replaced wholesale
        # during a run, that variable is recreated here, pointing at the
        # package's "data" directory.
        package_dir = os.path.dirname(os.path.abspath(awscli.__file__))
        return {
            'AWS_DATA_PATH': os.path.join(package_dir, 'data'),
            'AWS_DEFAULT_REGION': 'us-west-2',
            'AWS_ACCESS_KEY_ID': 'access_key',
            'AWS_SECRET_ACCESS_KEY': 'secret_key',
            # Blank paths, presumably so the developer's own config and
            # credentials files are not picked up -- verify if relied upon.
            'AWS_CONFIG_FILE': '',
            'AWS_SHARED_CREDENTIALS_FILE': '',
        }

    def _do_run(self, cmdline):
        stubber = self._session_stubber
        driver = create_clidriver()
        stubber.register(driver.session)
        return_code = driver.main(cmdline)
        # Every queued response must have been consumed by the command.
        stubber.assert_no_remaining_responses()
        result = CLIRunnerResult(return_code)
        result.aws_requests = copy.copy(stubber.received_aws_requests)
        return result
88
89
90
class SessionStubber(object):
    """Record a session's AWS calls and answer them with queued responses.

    ``register`` hooks three events on a session's event emitter. The first
    records each operation call as an ``AWSRequest``. The second records
    each created HTTP request as an ``HTTPRequest`` on that call. The third
    returns the next queued response from the ``before-send`` handler.
    """

    def __init__(self):
        # Every AWSRequest captured so far, in call order; never cleared.
        self.received_aws_requests = []
        # FIFO of response objects (BaseResponse-like) waiting to be served.
        self._responses = collections.deque()

    def register(self, session):
        """Attach the recording/stubbing handlers to ``session``."""
        events = session.get_component('event_emitter')
        events.register_first(
            'before-parameter-build.*.*', self._capture_aws_request,
        )
        # Registered last, presumably so the recorded request reflects any
        # changes made by the other request-created handlers.
        events.register_last(
            'request-created', self._capture_http_request
        )
        events.register_first(
            'before-send.*.*', self._return_queued_http_response,
        )

    def add_response(self, response):
        """Queue ``response`` to answer the next HTTP request sent."""
        self._responses.append(response)

    def assert_no_remaining_responses(self):
        """Raise AssertionError if any queued response was never served."""
        if self._responses:
            raise AssertionError(
                "The following queued responses are remaining: %s" %
                self._responses
            )

    def _capture_aws_request(self, params, model, context, **kwargs):
        aws_request = AWSRequest(
            service_name=model.service_model.service_name,
            operation_name=model.name,
            params=params,
        )
        self.received_aws_requests.append(aws_request)
        # Stash the record in the request context so _capture_http_request
        # can attach the HTTP request(s) made for this call.
        context['current_aws_request'] = aws_request

    def _capture_http_request(self, request, **kwargs):
        request.context['current_aws_request'].http_requests.append(
            HTTPRequest(
                method=request.method,
                url=request.url,
                headers=request.headers,
                body=request.body,
            )
        )

    def _return_queued_http_response(self, request, **kwargs):
        """Serve the oldest queued response for ``request``.

        Raises:
            AssertionError: if nothing is queued, i.e. the command sent more
                HTTP requests than the test queued responses for. Without
                this check the failure surfaces as an opaque
                ``IndexError: pop from an empty deque``.
        """
        if not self._responses:
            raise AssertionError(
                'No queued responses remain to answer the request: %r'
                % (request,)
            )
        response = self._responses.popleft()
        return response.on_http_request_sent(request)
139
140
141
class BaseResponse(object):
    """Interface for canned responses queued on a SessionStubber.

    Subclasses implement ``on_http_request_sent``. Its return value is what
    the stubber's ``before-send`` handler hands back for the request.
    """

    def on_http_request_sent(self, request):
        """Return the response for ``request``; subclasses must override."""
        hook_name = 'on_http_request_sent'
        raise NotImplementedError(hook_name)
144
145
146
class AWSResponse(BaseResponse):
    """Canned response described by an operation's parsed output.

    The parsed response is optionally validated against the operation's
    output shape at construction. When a request is sent, it is serialized
    into an ``HTTPResponse`` using the service protocol's serializer.
    """

    def __init__(self, service_name, operation_name, parsed_response,
                 validate=True):
        self._service_name = service_name
        self._operation_name = operation_name
        self._parsed_response = parsed_response
        self._service_model = self._get_service_model()
        self._operation_model = self._service_model.operation_model(
            operation_name)
        if validate:
            self._validate_parsed_response()

    def on_http_request_sent(self, request):
        # The outgoing request is not inspected; the reply depends only on
        # the parsed response supplied at construction.
        return self._generate_http_response()

    def __repr__(self):
        fields = (self._service_name, self._operation_name,
                  self._parsed_response)
        return ('AWSResponse(service_name=%r, operation_name=%r, '
                'parsed_response=%r)' % fields)

    def _get_service_model(self):
        # Load through the shared module-level loader (cached), then wrap
        # the raw model description in a fresh ServiceModel.
        model_data = _LOADER.load_service_model(
            service_name=self._service_name, type_name='service-2'
        )
        return botocore.model.ServiceModel(
            model_data, service_name=self._service_name)

    def _validate_parsed_response(self):
        output_shape = self._operation_model.output_shape
        # Operations without an output shape have nothing to validate.
        if output_shape:
            botocore.validate.validate_parameters(
                self._parsed_response, output_shape)

    def _generate_http_response(self):
        serialized = self._reverse_serialize_parsed_response()
        return HTTPResponse(
            headers=serialized['headers'], body=serialized['body'])

    def _reverse_serialize_parsed_response(self):
        # NOTE: Deliberately hacky. The operation model's input shape is
        # pointed at its output shape, so the request serializer "reverses"
        # the parsing logic and emits a raw HTTP response body/headers
        # instead of a request. This permanently mutates
        # self._operation_model.
        #
        # This covers many cases (e.g. JSON protocols) but known edge cases
        # remain unhandled (e.g. the query protocol); the logic will need to
        # grow as more tests adopt it.
        serializer = botocore.serialize.create_serializer(
            protocol_name=self._service_model.metadata['protocol'],
            include_validation=False,
        )
        operation_model = self._operation_model
        operation_model.input_shape = operation_model.output_shape
        return serializer.serialize_to_request(
            self._parsed_response, operation_model)
206
207
208
class HTTPResponse(BaseResponse):
    """Canned raw HTTP response that is served as-is when a request is sent."""

    def __init__(self, status_code=200, headers=None, body=b''):
        self.status_code = status_code
        self.headers = {} if headers is None else headers
        # Botocore's interface reads ``content`` rather than ``body``, so
        # both attributes are bound to the same object.
        self.body = self.content = body

    def on_http_request_sent(self, request):
        # Already a complete response: hand back this very object.
        return self
221
222
223
class CLIRunnerResult(object):
    """Plain container for the outcome of one CLI invocation.

    Attributes:
        rc: the return code the CLI produced.
        stdout, stderr: captured output, or None until filled in.
        aws_requests: recorded AWS requests; starts as a new empty list.
    """

    def __init__(self, rc, stdout=None, stderr=None):
        self.rc, self.stdout, self.stderr = rc, stdout, stderr
        self.aws_requests = []
229
230
231
class AWSRequest(object):
    """Record of one AWS operation call.

    Equality compares the service name, operation name and parameters only;
    ``http_requests`` is not part of the comparison.
    """

    def __init__(self, service_name, operation_name, params):
        self.service_name = service_name
        self.operation_name = operation_name
        self.params = params
        # HTTPRequest records for this call, appended as they are captured.
        self.http_requests = []

    def __repr__(self):
        return (
            'AWSRequest(service_name=%r, operation_name=%r, params=%r)' %
            (self.service_name, self.operation_name, self.params)
        )

    def __eq__(self, other):
        # Duck-typed: any object exposing the three attributes can compare
        # equal. Unrelated objects (None, strings, ...) yield NotImplemented
        # so Python falls back to its default result instead of raising
        # AttributeError.
        try:
            return (
                self.service_name == other.service_name and
                self.operation_name == other.operation_name and
                self.params == other.params
            )
        except AttributeError:
            return NotImplemented

    def __ne__(self, other):
        # Propagate NotImplemented rather than negating it (truth-testing
        # NotImplemented is deprecated).
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return NotImplemented
        return not equal
253
254
255
class HTTPRequest(object):
    """Snapshot of one outgoing HTTP request (method, URL, headers, body)."""

    def __init__(self, method, url, headers, body):
        self.method = method
        self.url = url
        self.headers = headers
        self.body = body

    def __repr__(self):
        return (
            'HTTPRequest(method=%r, url=%r, headers=%r, body=%r)' %
            (self.method, self.url, self.headers, self.body)
        )

    def __eq__(self, other):
        # Duck-typed field-by-field comparison; unrelated objects yield
        # NotImplemented instead of raising AttributeError.
        try:
            return (
                self.method == other.method and
                self.url == other.url and
                self.headers == other.headers and
                self.body == other.body
            )
        except AttributeError:
            return NotImplemented

    def __ne__(self, other):
        # Propagate NotImplemented rather than negating it (truth-testing
        # NotImplemented is deprecated).
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return NotImplemented
        return not equal
278
279
280
# CaseInsensitiveDict copied from requests; this copy must be serializable.
281
class CaseInsensitiveDict(collections_abc.MutableMapping):
    """Mapping whose lookups ignore key case but remember key spelling.

    Each entry is stored under ``key.lower()`` together with the spelling
    used when it was last set, so iteration yields that original spelling.
    Keys are expected to provide ``.lower()`` (i.e. strings).
    """

    def __init__(self, data=None, **kwargs):
        self._store = {}
        self.update({} if data is None else data, **kwargs)

    def __setitem__(self, key, value):
        # Lowercased key for lookups; the caller's spelling rides along.
        self._store[key.lower()] = (key, value)

    def __getitem__(self, key):
        _, value = self._store[key.lower()]
        return value

    def __delitem__(self, key):
        del self._store[key.lower()]

    def __iter__(self):
        return (original for original, _ in self._store.values())

    def __len__(self):
        return len(self._store)

    def lower_items(self):
        """Return a generator of (lowercased key, value) pairs."""
        return ((folded, pair[1]) for folded, pair in self._store.items())

    def __eq__(self, other):
        if not isinstance(other, collections_abc.Mapping):
            return NotImplemented
        other = CaseInsensitiveDict(other)
        # Case-insensitive comparison via the lowercased keys.
        return dict(self.lower_items()) == dict(other.lower_items())

    # Copy is required
    def copy(self):
        # The stored (original key, value) pairs rebuild an equal instance.
        return CaseInsensitiveDict(self._store.values())

    def __repr__(self):
        return repr(dict(self.items()))
327
328
329
def requires_crt(reason=None):
    """Build a decorator that skips a test unless awscrt is available.

    :param reason: skip message; None selects a default message.
    :return: a decorator applying ``unittest.skipIf(not HAS_CRT, ...)``.
    """
    skip_reason = (
        reason if reason is not None
        else "Test requires awscrt to be installed"
    )

    def decorator(func):
        skipper = unittest.skipIf(not HAS_CRT, skip_reason)
        return skipper(func)

    return decorator
337
338