Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
aws
GitHub Repository: aws/aws-cli
Path: blob/develop/awscli/testutils.py
1566 views
1
# Copyright 2012-2014 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License"). You
4
# may not use this file except in compliance with the License. A copy of
5
# the License is located at
6
#
7
# http://aws.amazon.com/apache2.0/
8
#
9
# or in the "license" file accompanying this file. This file is
10
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
# ANY KIND, either express or implied. See the License for the specific
12
# language governing permissions and limitations under the License.
13
"""Test utilities for the AWS CLI.
14
15
This module includes various classes/functions that help in writing
16
CLI unit/integration tests. This module should not be imported by
17
any module **except** for test code. This is included in the CLI
18
package so that code that is not part of the CLI can still take
19
advantage of all the testing utilities we provide.
20
21
"""
22
23
import binascii
24
import contextlib
25
import copy
26
import json
27
import logging
28
import os
29
import platform
30
import shutil
31
import sys
32
import tempfile
33
import time
34
import unittest
35
import uuid
36
from pprint import pformat
37
from subprocess import PIPE, Popen
38
from unittest import mock
39
40
import botocore.loaders
41
from botocore.awsrequest import AWSResponse
42
from botocore.exceptions import ClientError, WaiterError
43
44
import awscli.clidriver
45
from awscli.compat import BytesIO, StringIO
46
from awscli.utils import create_nested_client
47
48
_LOADER = botocore.loaders.Loader()
49
INTEG_LOG = logging.getLogger('awscli.tests.integration')
50
AWS_CMD = None
51
52
53
def skip_if_windows(reason):
    """Decorator to skip tests that should not be run on windows.

    Example usage:

        @skip_if_windows("Not valid")
        def test_some_non_windows_stuff(self):
            self.assertEqual(...)

    """

    def decorator(func):
        # Only Darwin and Linux are considered non-windows platforms here.
        on_supported_platform = platform.system() in ['Darwin', 'Linux']
        return unittest.skipIf(not on_supported_platform, reason)(func)

    return decorator
70
71
72
def create_clidriver():
    """Create a CLI driver whose session shares the module-level loader.

    Any ``data_path`` configured on the session is appended to the shared
    ``_LOADER`` search paths before the loader is registered on the session.

    :returns: A fully constructed awscli clidriver instance.
    """
    driver = awscli.clidriver.create_clidriver()
    session = driver.session
    raw_data_path = session.get_config_variable('data_path')
    # ''.split(os.pathsep) yields [''], which previously made the
    # "if not data_path" guard dead code and extended the loader's
    # search paths with an empty entry; only split when a value is set.
    if raw_data_path:
        data_path = raw_data_path.split(os.pathsep)
    else:
        data_path = []
    _LOADER.search_paths.extend(data_path)
    session.register_component('data_loader', _LOADER)
    return driver
81
82
83
def get_aws_cmd():
    """Return the path to the "aws" executable, caching the result.

    Prefers ``bin/aws`` next to the awscli package, then falls back to
    searching PATH.  Raises ValueError when no executable can be found.
    """
    global AWS_CMD
    import awscli

    if AWS_CMD is not None:
        return AWS_CMD
    # Try <repo>/bin/aws
    package_dir = os.path.dirname(os.path.abspath(awscli.__file__))
    candidate = os.path.join(package_dir, 'bin', 'aws')
    if not os.path.isfile(candidate):
        candidate = _search_path_for_cmd('aws')
        if candidate is None:
            raise ValueError(
                'Could not find "aws" executable. Either '
                'make sure it is on your PATH, or you can '
                'explicitly set this value using '
                '"set_aws_cmd()"'
            )
    AWS_CMD = candidate
    return AWS_CMD
102
103
104
def _search_path_for_cmd(cmd_name):
105
for path in os.environ.get('PATH', '').split(os.pathsep):
106
full_cmd_path = os.path.join(path, cmd_name)
107
if os.path.isfile(full_cmd_path):
108
return full_cmd_path
109
return None
110
111
112
def set_aws_cmd(aws_cmd):
    """Explicitly set the "aws" executable path used by get_aws_cmd()."""
    global AWS_CMD
    AWS_CMD = aws_cmd
115
116
117
@contextlib.contextmanager
def temporary_file(mode):
    """This is a cross platform temporary file creation.

    tempfile.NamedTemporary file on windows creates a secure temp file
    that can't be read by other processes and can't be opened a second time.

    For tests, we generally *want* them to be read multiple times.
    The test fixture writes the temp file contents, the test reads the
    temp file.

    """
    scratch_dir = tempfile.mkdtemp()
    filename = os.path.join(scratch_dir, 'tmpfile-%s' % str(random_chars(8)))
    # Touch the file so it exists before it is reopened in the caller's mode.
    open(filename, 'w').close()
    try:
        with open(filename, mode) as f:
            yield f
    finally:
        shutil.rmtree(scratch_dir)
138
139
140
def create_bucket(session, name=None, region=None):
    """
    Creates a bucket
    :returns: the name of the bucket created
    """
    region = region or 'us-west-2'
    client = create_nested_client(session, 's3', region_name=region)
    bucket_name = name if name else random_bucket_name()
    create_kwargs = {'Bucket': bucket_name, 'ObjectOwnership': 'ObjectWriter'}
    if region != 'us-east-1':
        create_kwargs['CreateBucketConfiguration'] = {
            'LocationConstraint': region
        }
    try:
        client.create_bucket(**create_kwargs)
    except ClientError as e:
        if e.response['Error'].get('Code') != 'BucketAlreadyOwnedByYou':
            raise
        # This can happen in the retried request, when the first one
        # succeeded on S3 but somehow the response never comes back.
        # We still got a bucket ready for test anyway.
    return bucket_name
166
167
168
def create_dir_bucket(session, name=None, location=None):
    """
    Creates a S3 directory bucket
    :returns: the name of the bucket created
    """
    region, az = location if location else ('us-west-2', 'usw2-az1')
    client = create_nested_client(session, 's3', region_name=region)
    bucket_name = name if name else f"{random_bucket_name()}--{az}--x-s3"
    bucket_config = {
        'Location': {'Type': 'AvailabilityZone', 'Name': az},
        'Bucket': {
            'Type': 'Directory',
            'DataRedundancy': 'SingleAvailabilityZone',
        },
    }
    try:
        client.create_bucket(
            Bucket=bucket_name, CreateBucketConfiguration=bucket_config
        )
    except ClientError as e:
        if e.response['Error'].get('Code') != 'BucketAlreadyOwnedByYou':
            raise
        # This can happen in the retried request, when the first one
        # succeeded on S3 but somehow the response never comes back.
        # We still got a bucket ready for test anyway.
    return bucket_name
202
203
204
def random_chars(num_chars):
    """Return ``num_chars`` random hex characters.

    Useful for creating resources with random names.

    :param num_chars: Number of hex characters to return.
    :returns: A string of exactly ``num_chars`` lowercase hex characters.
    """
    # hexlify doubles the byte count, so round the byte count up and trim;
    # the old int(num_chars / 2) silently returned one character short
    # for odd values of num_chars.
    num_bytes = (num_chars + 1) // 2
    return binascii.hexlify(os.urandom(num_bytes)).decode('ascii')[:num_chars]
211
212
213
def random_bucket_name(prefix='awscli-s3integ', num_random=15):
    """Generate a random S3 bucket name.

    :param prefix: A prefix to use in the bucket name. Useful
        for tracking resources. This default value makes it easy
        to see which buckets were created from CLI integ tests.
    :param num_random: Number of random chars to include in the bucket name.

    :returns: The name of a randomly generated bucket name as a string.

    """
    timestamp = int(time.time())
    return f"{prefix}-{random_chars(num_random)}-{timestamp}"
225
226
227
class BaseCLIDriverTest(unittest.TestCase):
    """Base unittest that use clidriver.

    This will load all the default plugins as well so it
    will simulate the behavior the user will see.
    """

    def setUp(self):
        # Replace os.environ wholesale so tests see fixed credentials and
        # region regardless of the host environment.  AWS_DATA_PATH is
        # copied from the real environment because awscli/__init__.py
        # injects it at import time (see BaseAWSCommandParamsTest.setUp).
        self.environ = {
            'AWS_DATA_PATH': os.environ['AWS_DATA_PATH'],
            'AWS_DEFAULT_REGION': 'us-east-1',
            'AWS_ACCESS_KEY_ID': 'access_key',
            'AWS_SECRET_ACCESS_KEY': 'secret_key',
            # Empty config file name so the user's real config can't leak in.
            'AWS_CONFIG_FILE': '',
        }
        self.environ_patch = mock.patch('os.environ', self.environ)
        self.environ_patch.start()
        self.driver = create_clidriver()
        self.session = self.driver.session

    def tearDown(self):
        # Restore the real os.environ.
        self.environ_patch.stop()
249
250
251
class BaseAWSHelpOutputTest(BaseCLIDriverTest):
    """Base class for making assertions about rendered help output.

    Replaces the help renderer with a ``CapturedRenderer`` so tests can
    inspect the exact text that would have been displayed.
    """

    def setUp(self):
        super().setUp()
        self.renderer_patch = mock.patch('awscli.help.get_renderer')
        self.renderer_mock = self.renderer_patch.start()
        self.renderer = CapturedRenderer()
        self.renderer_mock.return_value = self.renderer

    def tearDown(self):
        super().tearDown()
        self.renderer_patch.stop()

    def assert_contains(self, contains):
        """Fail unless ``contains`` appears in the rendered contents."""
        if contains not in self.renderer.rendered_contents:
            self.fail(
                "The expected contents:\n%s\nwere not in the "
                "actual rendered contents:\n%s"
                % (contains, self.renderer.rendered_contents)
            )

    def assert_contains_with_count(self, contains, count):
        """Fail unless ``contains`` appears exactly ``count`` times."""
        r_count = self.renderer.rendered_contents.count(contains)
        if r_count != count:
            self.fail(
                "The expected contents:\n%s\n, with the "
                "count:\n%d\nwere not in the actual rendered "
                " contents:\n%s\nwith count:\n%d"
                % (contains, count, self.renderer.rendered_contents, r_count)
            )

    def assert_not_contains(self, contents):
        """Fail if ``contents`` appears in the rendered contents."""
        if contents in self.renderer.rendered_contents:
            self.fail(
                "The contents:\n%s\nwere not suppose to be in the "
                "actual rendered contents:\n%s"
                % (contents, self.renderer.rendered_contents)
            )

    def assert_text_order(self, *args, **kwargs):
        """Fail unless the ``args`` strings appear in order.

        :param starting_from: Required keyword argument; searching for
            each string in ``args`` begins at this anchor text.
        """
        # First we need to find where the SYNOPSIS section starts.
        starting_from = kwargs.pop('starting_from')
        args = list(args)
        contents = self.renderer.rendered_contents
        self.assertIn(starting_from, contents)
        start_index = contents.find(starting_from)
        arg_indices = [contents.find(arg, start_index) for arg in args]
        previous = None
        for i, index in enumerate(arg_indices):
            if index == -1:
                # Bug fix: this previously formatted args[index] -- where
                # index is a str.find() result (-1), i.e. args[-1] -- so
                # the wrong string was reported, and the first argument
                # was never checked for being missing at all.
                self.fail(
                    'The string %r was not found in the contents: %s'
                    % (args[i], contents)
                )
            if previous is not None and index < previous:
                self.fail(
                    'The string %r came before %r, but was suppose to come '
                    'after it.\n%s' % (args[i], args[i - 1], contents)
                )
            previous = index
310
311
312
class CapturedRenderer:
    """Help-renderer stand-in that records output instead of paging it."""

    def __init__(self):
        self.rendered_contents = ''

    def render(self, contents):
        # Rendered contents arrive as bytes; store them decoded so tests
        # can make plain string assertions.
        self.rendered_contents = contents.decode('utf-8')
318
319
320
class CapturedOutput:
    """Simple holder pairing captured stdout and stderr streams."""

    def __init__(self, stdout, stderr):
        self.stdout = stdout
        self.stderr = stderr
324
325
326
@contextlib.contextmanager
def capture_output():
    """Patch sys.stdout/sys.stderr and yield the captured streams."""
    stdout, stderr = StringIO(), StringIO()
    with mock.patch('sys.stderr', stderr), mock.patch('sys.stdout', stdout):
        yield CapturedOutput(stdout, stderr)
333
334
335
@contextlib.contextmanager
def capture_input(input_bytes=b''):
    """Patch sys.stdin with a mock whose ``buffer`` holds ``input_bytes``.

    Yields the underlying BytesIO so tests can inspect what was read.
    """
    stdin_buffer = BytesIO(input_bytes)
    fake_stdin = mock.Mock()
    fake_stdin.buffer = stdin_buffer

    with mock.patch('sys.stdin', fake_stdin):
        yield stdin_buffer
343
344
345
class BaseAWSCommandParamsTest(unittest.TestCase):
    """Base class for asserting the parameters a CLI command serializes.

    Patches ``botocore.endpoint.Endpoint.make_request`` so no real HTTP
    request is made, and records the parameters seen at the
    ``before-call`` and ``before-parameter-build`` events.
    """

    maxDiff = None

    def setUp(self):
        # last_params holds the serialized request body; last_kwargs holds
        # the pre-serialization parameters (see before_parameter_build).
        self.last_params = {}
        self.last_kwargs = None
        # awscli/__init__.py injects AWS_DATA_PATH at import time
        # so that we can find cli.json. This might be fixed in the
        # future, but for now we just grab that value out of the real
        # os.environ so the patched os.environ has this data and
        # the CLI works.
        self.environ = {
            'AWS_DATA_PATH': os.environ['AWS_DATA_PATH'],
            'AWS_DEFAULT_REGION': 'us-east-1',
            'AWS_ACCESS_KEY_ID': 'access_key',
            'AWS_SECRET_ACCESS_KEY': 'secret_key',
            'AWS_CONFIG_FILE': '',
            'AWS_SHARED_CREDENTIALS_FILE': '',
        }
        # Keep ComSpec (Windows shell) if present so subprocesses work.
        if os.environ.get('ComSpec'):
            self.environ['ComSpec'] = os.environ['ComSpec']
        self.environ_patch = mock.patch('os.environ', self.environ)
        self.environ_patch.start()
        # Canned HTTP 200 response returned by the patched make_request.
        self.http_response = AWSResponse(None, 200, {}, None)
        self.parsed_response = {}
        self.make_request_patch = mock.patch(
            'botocore.endpoint.Endpoint.make_request'
        )
        # Track patch state explicitly; see patch_make_request() for why.
        self.make_request_is_patched = False
        self.operations_called = []
        # When set to a list, responses are popped one per request.
        self.parsed_responses = None
        self.driver = create_clidriver()

    def tearDown(self):
        # This clears all the previous registrations.
        self.environ_patch.stop()
        if self.make_request_is_patched:
            self.make_request_patch.stop()
            self.make_request_is_patched = False

    def before_call(self, params, **kwargs):
        # Hooked to botocore's "before-call" event in run_cmd().
        self._store_params(params)

    def _store_params(self, params):
        # Keep both the full request dict and just its serialized body.
        self.last_request_dict = params
        self.last_params = params['body']

    def patch_make_request(self):
        # If you do not stop a previously started patch,
        # it can never be stopped if you call start() again on the same
        # patch again...
        # So stop the current patch before calling start() on it again.
        if self.make_request_is_patched:
            self.make_request_patch.stop()
            self.make_request_is_patched = False
        make_request_patch = self.make_request_patch.start()
        if self.parsed_responses is not None:
            # Consume one canned response per request, in order.
            make_request_patch.side_effect = lambda *args, **kwargs: (
                self.http_response,
                self.parsed_responses.pop(0),
            )
        else:
            make_request_patch.return_value = (
                self.http_response,
                self.parsed_response,
            )
        self.make_request_is_patched = True

    def assert_params_for_cmd(
        self,
        cmd,
        params=None,
        expected_rc=0,
        stderr_contains=None,
        ignore_params=None,
    ):
        """Run ``cmd`` and assert the operation parameters it produced.

        :param params: Expected pre-serialization parameters; compared
            against the params seen at before-parameter-build.
        :param expected_rc: Expected CLI return code.
        :param stderr_contains: Optional substring expected in stderr.
        :param ignore_params: Keys to remove before comparing.
        :returns: (stdout, stderr, rc) from the run.
        """
        stdout, stderr, rc = self.run_cmd(cmd, expected_rc)
        if stderr_contains is not None:
            self.assertIn(stderr_contains, stderr)
        if params is not None:
            # The last kwargs of Operation.call() in botocore.
            last_kwargs = copy.copy(self.last_kwargs)
            if ignore_params is not None:
                for key in ignore_params:
                    try:
                        del last_kwargs[key]
                    except KeyError:
                        pass
            if params != last_kwargs:
                self.fail(
                    "Actual params did not match expected params.\n"
                    "Expected:\n\n"
                    "%s\n"
                    "Actual:\n\n%s\n" % (pformat(params), pformat(last_kwargs))
                )
        return stdout, stderr, rc

    def before_parameter_build(self, params, model, **kwargs):
        # Hooked to "before-parameter-build.*.*"; records every operation
        # invoked along with a snapshot of its parameters.
        self.last_kwargs = params
        self.operations_called.append((model, params.copy()))

    def run_cmd(self, cmd, expected_rc=0):
        """Run the CLI command, asserting its return code.

        :param cmd: A command string (split on whitespace) or arg list.
        :returns: (stdout, stderr, rc).
        """
        logging.debug("Calling cmd: %s", cmd)
        self.patch_make_request()
        event_emitter = self.driver.session.get_component('event_emitter')
        event_emitter.register('before-call', self.before_call)
        event_emitter.register_first(
            'before-parameter-build.*.*', self.before_parameter_build
        )
        if not isinstance(cmd, list):
            cmdlist = cmd.split()
        else:
            cmdlist = cmd

        with capture_output() as captured:
            try:
                rc = self.driver.main(cmdlist)
            except SystemExit as e:
                # We need to catch SystemExit so that we
                # can get a proper rc and still present the
                # stdout/stderr to the test runner so we can
                # figure out what went wrong.
                rc = e.code
        stderr = captured.stderr.getvalue()
        stdout = captured.stdout.getvalue()
        self.assertEqual(
            rc,
            expected_rc,
            "Unexpected rc (expected: %s, actual: %s) for command: %s\n"
            "stdout:\n%sstderr:\n%s" % (expected_rc, rc, cmd, stdout, stderr),
        )
        return stdout, stderr, rc
477
478
479
class BaseAWSPreviewCommandParamsTest(BaseAWSCommandParamsTest):
    """Params test base for commands gated behind the preview feature."""

    def setUp(self):
        # Start the patch before super().setUp(), which creates the
        # clidriver, so preview gating is disabled for the driver.
        self.preview_patch = mock.patch(
            'awscli.customizations.preview.mark_as_preview'
        )
        self.preview_patch.start()
        super().setUp()

    def tearDown(self):
        self.preview_patch.stop()
        super().tearDown()
490
491
492
class BaseCLIWireResponseTest(unittest.TestCase):
    """Base class for tests that stub the raw HTTP wire response.

    Patches ``botocore.endpoint.Endpoint._send`` so each test controls
    the exact status code, headers, and body bytes the CLI parses.
    """

    def setUp(self):
        # Replace os.environ wholesale with fixed credentials/region;
        # AWS_DATA_PATH is carried over from the real environment.
        self.environ = {
            'AWS_DATA_PATH': os.environ['AWS_DATA_PATH'],
            'AWS_DEFAULT_REGION': 'us-east-1',
            'AWS_ACCESS_KEY_ID': 'access_key',
            'AWS_SECRET_ACCESS_KEY': 'secret_key',
            'AWS_CONFIG_FILE': '',
        }
        self.environ_patch = mock.patch('os.environ', self.environ)
        self.environ_patch.start()
        # TODO: fix this patch when we have a better way to stub out responses
        self.send_patch = mock.patch('botocore.endpoint.Endpoint._send')
        self.send_is_patched = False
        self.driver = create_clidriver()

    def tearDown(self):
        self.environ_patch.stop()
        if self.send_is_patched:
            self.send_patch.stop()
            self.send_is_patched = False

    def patch_send(self, status_code=200, headers=None, content=b''):
        """Stub the wire response returned by Endpoint._send.

        :param status_code: HTTP status code of the stubbed response.
        :param headers: Response headers (defaults to an empty dict).
        :param content: Raw response body bytes.
        """
        # Bug fix: the default used to be a mutable ``headers={}``;
        # use None so a single dict is not shared across calls.
        if headers is None:
            headers = {}
        # Stop any previously started patch first -- calling start()
        # twice on the same patch makes it impossible to stop.
        if self.send_is_patched:
            self.send_patch.stop()
            self.send_is_patched = False
        send_patch = self.send_patch.start()
        send_patch.return_value = mock.Mock(
            status_code=status_code, headers=headers, content=content
        )
        self.send_is_patched = True

    def run_cmd(self, cmd, expected_rc=0):
        """Run the CLI command, asserting its return code.

        :param cmd: A command string (split on whitespace) or arg list.
        :returns: (stdout, stderr, rc).
        """
        if not isinstance(cmd, list):
            cmdlist = cmd.split()
        else:
            cmdlist = cmd
        with capture_output() as captured:
            try:
                rc = self.driver.main(cmdlist)
            except SystemExit as e:
                # Catch SystemExit so we can report rc plus the captured
                # stdout/stderr instead of aborting the test runner.
                rc = e.code
        stderr = captured.stderr.getvalue()
        stdout = captured.stdout.getvalue()
        self.assertEqual(
            rc,
            expected_rc,
            "Unexpected rc (expected: %s, actual: %s) for command: %s\n"
            "stdout:\n%sstderr:\n%s" % (expected_rc, rc, cmd, stdout, stderr),
        )
        return stdout, stderr, rc
543
544
545
class FileCreator:
    """Create scratch files under a throwaway temporary directory."""

    def __init__(self):
        self.rootdir = tempfile.mkdtemp()

    def remove_all(self):
        """Delete the temp directory and everything inside it."""
        if os.path.exists(self.rootdir):
            shutil.rmtree(self.rootdir)

    def create_file(self, filename, contents, mtime=None, mode='w'):
        """Creates a file in a tmpdir

        ``filename`` should be a relative path, e.g. "foo/bar/baz.txt"
        It will be translated into a full path in a tmp dir.

        If the ``mtime`` argument is provided, then the file's
        mtime will be set to the provided value (must be an epoch time).
        Otherwise the mtime is rewound roughly three years into the past
        so the file never looks freshly modified.

        ``mode`` is the mode the file should be opened either as ``w`` or
        `wb``.

        Returns the full path to the file.

        """
        full_path = self.full_path(filename)
        self._ensure_parent_dir(full_path)
        with open(full_path, mode) as f:
            f.write(contents)
        current_time = os.path.getmtime(full_path)
        # Subtract a few years off the last modification date.
        os.utime(full_path, (current_time, current_time - 100000000))
        if mtime is not None:
            os.utime(full_path, (mtime, mtime))
        return full_path

    def append_file(self, filename, contents):
        """Append contents to a file

        ``filename`` should be a relative path, e.g. "foo/bar/baz.txt"
        It will be translated into a full path in a tmp dir.

        Returns the full path to the file.
        """
        full_path = self.full_path(filename)
        self._ensure_parent_dir(full_path)
        with open(full_path, 'a') as f:
            f.write(contents)
        return full_path

    def full_path(self, filename):
        """Translate relative path to full path in temp dir.

        f.full_path('foo/bar.txt') -> /tmp/asdfasd/foo/bar.txt
        """
        return os.path.join(self.rootdir, filename)

    def _ensure_parent_dir(self, full_path):
        # Create any missing intermediate directories for full_path.
        parent = os.path.dirname(full_path)
        if not os.path.isdir(parent):
            os.makedirs(parent)
602
603
604
class ProcessTerminatedError(Exception):
    """Raised when a process exits before its memory stats can be read."""
606
607
608
class Result:
    """Outcome of a CLI run: rc, decoded stdout/stderr, memory samples."""

    def __init__(self, rc, stdout, stderr, memory_usage=None):
        self.rc = rc
        self.stdout = stdout
        self.stderr = stderr
        INTEG_LOG.debug("rc: %s", rc)
        INTEG_LOG.debug("stdout: %s", stdout)
        INTEG_LOG.debug("stderr: %s", stderr)
        self.memory_usage = memory_usage if memory_usage is not None else []

    @property
    def json(self):
        """Parse stdout as JSON and return the result."""
        return json.loads(self.stdout)
623
624
625
def _escape_quotes(command):
626
# For windows we have different rules for escaping.
627
# First, double quotes must be escaped.
628
command = command.replace('"', '\\"')
629
# Second, single quotes do nothing, to quote a value we need
630
# to use double quotes.
631
command = command.replace("'", '"')
632
return command
633
634
635
def aws(
    command,
    collect_memory=False,
    env_vars=None,
    wait_for_finish=True,
    input_data=None,
    input_file=None,
):
    """Run an aws command.

    This help function abstracts the differences of running the "aws"
    command on different platforms.

    If collect_memory is ``True`` the Result object will have a list
    of memory usage taken at 2 second intervals. The memory usage
    will be in bytes.

    If env_vars is not None, it *replaces* the entire environment passed
    to the aws process (the default is a copy of os.environ, with
    AWS_DEFAULT_REGION defaulted to us-east-1 if unset).

    If wait_for_finish is False, then the Process object is returned
    to the caller. It is then the caller's responsibility to ensure
    proper cleanup. This can be useful if you want to test timeout's
    or how the CLI responds to various signals.

    :type input_data: string
    :param input_data: This string will be communicated to the process through
        the stdin of the process. It essentially allows the user to
        avoid having to use a file handle to pass information to the process.
        Note that this string is not passed on creation of the process, but
        rather communicated to the process.

    :type input_file: a file handle
    :param input_file: This is a file handle that will act as the
        the stdin of the process immediately on creation. Essentially
        any data written to the file will be read from stdin of the
        process. This is needed if you plan to stream data into stdin while
        collecting memory.
    """
    if platform.system() == 'Windows':
        command = _escape_quotes(command)
    # AWS_TEST_COMMAND overrides how the executable is invoked entirely.
    if 'AWS_TEST_COMMAND' in os.environ:
        aws_command = os.environ['AWS_TEST_COMMAND']
    else:
        aws_command = 'python %s' % get_aws_cmd()
    full_command = '%s %s' % (aws_command, command)
    stdout_encoding = get_stdout_encoding()
    INTEG_LOG.debug("Running command: %s", full_command)
    env = os.environ.copy()
    if 'AWS_DEFAULT_REGION' not in env:
        env['AWS_DEFAULT_REGION'] = "us-east-1"
    if env_vars is not None:
        # Note: replaces the environment wholesale, no merging.
        env = env_vars
    if input_file is None:
        input_file = PIPE
    # shell=True so the "python <path> <args>" string is parsed by the shell.
    process = Popen(
        full_command,
        stdout=PIPE,
        stderr=PIPE,
        stdin=input_file,
        shell=True,
        env=env,
    )
    if not wait_for_finish:
        return process
    memory = None
    if not collect_memory:
        kwargs = {}
        if input_data:
            kwargs = {'input': input_data}
        stdout, stderr = process.communicate(**kwargs)
    else:
        stdout, stderr, memory = _wait_and_collect_mem(process)
    return Result(
        process.returncode,
        stdout.decode(stdout_encoding),
        stderr.decode(stdout_encoding),
        memory,
    )
714
715
716
def get_stdout_encoding():
    """Return the encoding of the real stdout, defaulting to utf-8."""
    encoding = getattr(sys.__stdout__, 'encoding', None)
    return 'utf-8' if encoding is None else encoding
721
722
723
def _wait_and_collect_mem(process):
    """Poll ``process`` until it exits, sampling its memory usage.

    Returns (stdout, stderr, list_of_memory_samples_in_bytes).
    """
    # We only know how to collect memory on mac/linux.
    if platform.system() in ('Darwin', 'Linux'):
        get_memory = _get_memory_with_ps
    else:
        raise ValueError(
            f"Can't collect memory for process on platform {platform.system()}."
        )
    memory = []
    while process.poll() is None:
        try:
            memory.append(get_memory(process.pid))
        except ProcessTerminatedError:
            # It's possible the process terminated between .poll()
            # and get_memory().
            break
    stdout, stderr = process.communicate()
    return stdout, stderr, memory
744
745
746
def _get_memory_with_ps(pid):
    """Return the RSS of ``pid`` in bytes, parsed from ``ps u`` output."""
    # It's probably possible to do with proc_pidinfo and ctypes on a Mac,
    # but we'll do it the easy way with parsing ps output.
    ps_args = 'ps u -p'.split()
    ps_args.append(str(pid))
    ps_process = Popen(ps_args, stdout=PIPE)
    output = ps_process.communicate()[0]
    if ps_process.returncode != 0:
        # ps failing means the target process is gone.
        raise ProcessTerminatedError(str(pid))
    # Get the RSS from output that looks like this:
    # USER  PID %CPU %MEM     VSZ  RSS  TT STAT STARTED    TIME COMMAND
    # user 4710  0.0  0.1 2437000 4496 s002  S+  7:04PM 0:00.12 python2.6
    # RSS is column 5 and is reported in KiB, so scale to bytes.
    return int(output.splitlines()[1].split()[5]) * 1024
760
761
762
class BaseS3CLICommand(unittest.TestCase):
    """Base class for aws s3 command.

    This contains convenience functions to make writing these tests easier
    and more streamlined.

    """

    # Extra args on put_object that must also be passed to head_object
    # when waiting for the key to exist (e.g. SSE-C keys).
    _PUT_HEAD_SHARED_EXTRAS = [
        'SSECustomerAlgorithm',
        'SSECustomerKey',
        'SSECustomerKeyMD5',
        'RequestPayer',
    ]

    def setUp(self):
        self.files = FileCreator()
        # NOTE(review): relies on botocore.session being an importable
        # attribute; only botocore.loaders is imported at the top of this
        # module -- confirm botocore.session is in scope at runtime.
        self.session = botocore.session.get_session()
        # Maps bucket name -> region so per-bucket clients use the right
        # region (see create_client_for_bucket).
        self.regions = {}
        self.region = 'us-west-2'
        self.client = create_nested_client(self.session, 's3', region_name=self.region)
        self.extra_setup()

    def extra_setup(self):
        # Subclasses can use this to define extra setup steps.
        pass

    def tearDown(self):
        self.files.remove_all()
        self.extra_teardown()

    def extra_teardown(self):
        # Subclasses can use this to define extra teardown steps.
        pass

    def override_parser(self, **kwargs):
        # Override botocore response-parser defaults (e.g. blob decoding).
        factory = self.session.get_component('response_parser_factory')
        factory.set_parser_defaults(**kwargs)

    def create_client_for_bucket(self, bucket_name):
        # Use the region the bucket was created in, falling back to the
        # test's default region.
        region = self.regions.get(bucket_name, self.region)
        client = create_nested_client(self.session, 's3', region_name=region)
        return client

    def assert_key_contents_equal(self, bucket, key, expected_contents):
        """Assert the S3 object's contents equal ``expected_contents``."""
        self.wait_until_key_exists(bucket, key)
        if isinstance(expected_contents, BytesIO):
            expected_contents = expected_contents.getvalue().decode('utf-8')
        actual_contents = self.get_key_contents(bucket, key)
        # The contents can be huge so we try to give helpful error messages
        # without necessarily printing the actual contents.
        self.assertEqual(len(actual_contents), len(expected_contents))
        if actual_contents != expected_contents:
            self.fail(
                f"Contents for {bucket}/{key} do not match (but they "
                "have the same length)"
            )

    def delete_public_access_block(self, bucket_name):
        client = self.create_client_for_bucket(bucket_name)
        client.delete_public_access_block(Bucket=bucket_name)

    def create_bucket(self, name=None, region=None):
        """Create a bucket, register cleanup, and wait until it exists."""
        if not region:
            region = self.region
        bucket_name = create_bucket(self.session, name, region)
        self.regions[bucket_name] = region
        self.addCleanup(self.delete_bucket, bucket_name)

        # Wait for the bucket to exist before letting it be used.
        self.wait_bucket_exists(bucket_name)
        self.delete_public_access_block(bucket_name)
        return bucket_name

    def create_dir_bucket(self, name=None, location=None):
        """Create a directory bucket, register cleanup, and wait for it."""
        if location:
            region, _ = location
        else:
            region = self.region
        bucket_name = create_dir_bucket(self.session, name, location)
        self.regions[bucket_name] = region
        self.addCleanup(self.delete_bucket, bucket_name)

        # Wait for the bucket to exist before letting it be used.
        self.wait_bucket_exists(bucket_name)
        return bucket_name

    def put_object(self, bucket_name, key_name, contents='', extra_args=None):
        """Put an object, register cleanup, and wait until it is readable."""
        client = self.create_client_for_bucket(bucket_name)
        call_args = {'Bucket': bucket_name, 'Key': key_name, 'Body': contents}
        if extra_args is not None:
            call_args.update(extra_args)
        response = client.put_object(**call_args)
        self.addCleanup(self.delete_key, bucket_name, key_name)
        # Forward SSE-C / RequestPayer extras to the head_object waiter,
        # since those requests would otherwise fail.
        extra_head_params = {}
        if extra_args:
            extra_head_params = dict(
                (k, v)
                for (k, v) in extra_args.items()
                if k in self._PUT_HEAD_SHARED_EXTRAS
            )
        self.wait_until_key_exists(
            bucket_name,
            key_name,
            extra_params=extra_head_params,
        )
        return response

    def delete_bucket(self, bucket_name, attempts=5, delay=5):
        """Empty and delete the bucket, retrying on consistency errors."""
        self.remove_all_objects(bucket_name)
        client = self.create_client_for_bucket(bucket_name)

        # There's a chance that, even though the bucket has been used
        # several times, the delete will fail due to eventual consistency
        # issues.
        attempts_remaining = attempts
        while True:
            attempts_remaining -= 1
            try:
                client.delete_bucket(Bucket=bucket_name)
                break
            except client.exceptions.NoSuchBucket:
                if self.bucket_not_exists(bucket_name):
                    # Fast fail when the NoSuchBucket error is real.
                    break
                if attempts_remaining <= 0:
                    raise
                time.sleep(delay)

        self.regions.pop(bucket_name, None)

    def remove_all_objects(self, bucket_name):
        # Delete every object currently listed in the bucket.
        client = self.create_client_for_bucket(bucket_name)
        paginator = client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=bucket_name)
        key_names = []
        for page in pages:
            key_names += [obj['Key'] for obj in page.get('Contents', [])]
        for key_name in key_names:
            self.delete_key(bucket_name, key_name)

    def delete_key(self, bucket_name, key_name):
        client = self.create_client_for_bucket(bucket_name)
        response = client.delete_object(Bucket=bucket_name, Key=key_name)

    def get_key_contents(self, bucket_name, key_name):
        """Return the object body decoded as UTF-8."""
        self.wait_until_key_exists(bucket_name, key_name)
        client = self.create_client_for_bucket(bucket_name)
        response = client.get_object(Bucket=bucket_name, Key=key_name)
        return response['Body'].read().decode('utf-8')

    def wait_bucket_exists(self, bucket_name, min_successes=3):
        # Require the bucket_exists waiter to pass several times in a row
        # to smooth over eventual consistency.
        client = self.create_client_for_bucket(bucket_name)
        waiter = client.get_waiter('bucket_exists')
        consistency_waiter = ConsistencyWaiter(
            min_successes=min_successes, delay_initial_poll=True
        )
        consistency_waiter.wait(
            lambda: waiter.wait(Bucket=bucket_name) is None
        )

    def bucket_not_exists(self, bucket_name):
        # NOTE(review): ClientError puts the error code under
        # response['Error']['Code'], so response.get('Code') here likely
        # always returns None and the 404 branch never triggers -- confirm.
        # Also the True/False senses look inverted relative to the method
        # name (head_bucket succeeding means the bucket DOES exist);
        # verify against callers (delete_bucket) before changing.
        client = self.create_client_for_bucket(bucket_name)
        try:
            client.head_bucket(Bucket=bucket_name)
            return True
        except ClientError as error:
            if error.response.get('Code') == '404':
                return False
            raise

    def key_exists(self, bucket_name, key_name, min_successes=3):
        """Return True if the key becomes visible, False on timeout/error."""
        try:
            self.wait_until_key_exists(
                bucket_name, key_name, min_successes=min_successes
            )
            return True
        except (ClientError, WaiterError):
            return False

    def key_not_exists(self, bucket_name, key_name, min_successes=3):
        """Return True if the key disappears, False on timeout/error."""
        try:
            self.wait_until_key_not_exists(
                bucket_name, key_name, min_successes=min_successes
            )
            return True
        except (ClientError, WaiterError):
            return False

    def list_buckets(self):
        response = self.client.list_buckets()
        return response['Buckets']

    def content_type_for_key(self, bucket_name, key_name):
        parsed = self.head_object(bucket_name, key_name)
        return parsed['ContentType']

    def head_object(self, bucket_name, key_name):
        client = self.create_client_for_bucket(bucket_name)
        response = client.head_object(Bucket=bucket_name, Key=key_name)
        return response

    def wait_until_key_exists(
        self, bucket_name, key_name, extra_params=None, min_successes=3
    ):
        self._wait_for_key(
            bucket_name, key_name, extra_params, min_successes, exists=True
        )

    def wait_until_key_not_exists(
        self, bucket_name, key_name, extra_params=None, min_successes=3
    ):
        self._wait_for_key(
            bucket_name, key_name, extra_params, min_successes, exists=False
        )

    def _wait_for_key(
        self,
        bucket_name,
        key_name,
        extra_params=None,
        min_successes=3,
        exists=True,
    ):
        # Run the object_(not_)exists waiter min_successes times to
        # smooth over eventual consistency.
        client = self.create_client_for_bucket(bucket_name)
        if exists:
            waiter = client.get_waiter('object_exists')
        else:
            waiter = client.get_waiter('object_not_exists')
        params = {'Bucket': bucket_name, 'Key': key_name}
        if extra_params is not None:
            params.update(extra_params)
        for _ in range(min_successes):
            waiter.wait(**params)

    def assert_no_errors(self, p):
        """Assert a Result has rc 0 and no error text on stderr."""
        self.assertEqual(
            p.rc,
            0,
            "Non zero rc (%s) received: %s" % (p.rc, p.stdout + p.stderr),
        )
        self.assertNotIn("Error:", p.stderr)
        self.assertNotIn("failed:", p.stderr)
        self.assertNotIn("client error", p.stderr)
        self.assertNotIn("server error", p.stderr)
1007
1008
1009
class StringIOWithFileNo(StringIO):
    """StringIO variant that claims fd 0 for APIs requiring fileno()."""

    def fileno(self):
        return 0
1012
1013
1014
class TestEventHandler:
    """Records whether an event fired, optionally delegating to a handler."""

    def __init__(self, handler=None):
        self._handler = handler
        self._called = False
        # Tell test discovery tools this class is not itself a test case.
        self.__test__ = False

    @property
    def called(self):
        """True once handler() has been invoked at least once."""
        return self._called

    def handler(self, **kwargs):
        self._called = True
        if self._handler is not None:
            self._handler(**kwargs)
1028
1029
1030
class ConsistencyWaiterException(Exception):
    """Raised when a ConsistencyWaiter exhausts its attempts."""
1032
1033
1034
class ConsistencyWaiter:
    """
    A waiter class for some check to reach a consistent state.

    :type min_successes: int
    :param min_successes: The minimum number of successful check calls to
        treat the check as stable. Default of 1 success.

    :type max_attempts: int
    :param max_attempts: The maximum number of times to attempt calling
        the check. Default of 20 attempts.

    :type delay: int
    :param delay: The number of seconds to delay the next API call after a
        failed check call. Default of 5 seconds.
    """

    def __init__(
        self,
        min_successes=1,
        max_attempts=20,
        delay=5,
        delay_initial_poll=False,
    ):
        self.min_successes = min_successes
        self.max_attempts = max_attempts
        self.delay = delay
        self.delay_initial_poll = delay_initial_poll

    def wait(self, check, *args, **kwargs):
        """
        Wait until the check succeeds the configured number of times

        :type check: callable
        :param check: A callable that returns True or False to indicate
            if the check succeeded or failed.

        :type args: list
        :param args: Any ordered arguments to be passed to the check.

        :type kwargs: dict
        :param kwargs: Any keyword arguments to be passed to the check.
        """
        attempts = 0
        successes = 0
        if self.delay_initial_poll:
            time.sleep(self.delay)
        while attempts < self.max_attempts:
            attempts += 1
            if not check(*args, **kwargs):
                # Failed check: pause before polling again.
                time.sleep(self.delay)
                continue
            successes += 1
            if successes >= self.min_successes:
                return
        raise ConsistencyWaiterException(
            self._fail_message(attempts, successes)
        )

    def _fail_message(self, attempts, successes):
        return 'Failed after %s attempts, only had %s successes' % (
            attempts,
            successes,
        )
1095
1096
1097
@contextlib.contextmanager
def cd(path):
    """Temporarily chdir into ``path``, restoring the old cwd on exit."""
    previous_dir = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(previous_dir)
1105
1106