Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
aws
GitHub Repository: aws/aws-cli
Path: blob/develop/awscli/testutils.py
2634 views
1
# Copyright 2012-2014 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License"). You
4
# may not use this file except in compliance with the License. A copy of
5
# the License is located at
6
#
7
# http://aws.amazon.com/apache2.0/
8
#
9
# or in the "license" file accompanying this file. This file is
10
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
# ANY KIND, either express or implied. See the License for the specific
12
# language governing permissions and limitations under the License.
13
"""Test utilities for the AWS CLI.
14
15
This module includes various classes/functions that help in writing
16
CLI unit/integration tests. This module should not be imported by
17
any module **except** for test code. This is included in the CLI
18
package so that code that is not part of the CLI can still take
19
advantage of all the testing utilities we provide.
20
21
"""
22
23
import binascii
24
import contextlib
25
import copy
26
import json
27
import logging
28
import os
29
import platform
30
import random
31
import shutil
32
import string
33
import sys
34
import tempfile
35
import time
36
import unittest
37
import uuid
38
from pprint import pformat
39
from subprocess import PIPE, Popen
40
from unittest import mock
41
from pathlib import Path
42
43
import botocore.loaders
44
from botocore.awsrequest import AWSResponse
45
from botocore.exceptions import ClientError, WaiterError
46
47
import awscli.clidriver
48
from awscli.compat import BytesIO, StringIO
49
from awscli.utils import create_nested_client
50
51
# Shared botocore data loader; extra data paths registered by
# create_clidriver() accumulate here across driver instances.
_LOADER = botocore.loaders.Loader()
# Logger used by the integration-test helpers (aws(), Result, ...).
INTEG_LOG = logging.getLogger('awscli.tests.integration')
# Cached path to the "aws" executable; populated lazily by get_aws_cmd()
# and overridable via set_aws_cmd().
AWS_CMD = None

# Detect at import time whether the filesystem backing the temp dir is
# case-insensitive: create a lower-case file, then probe for its
# upper-case twin.  Consumed by skip_if_case_sensitive().
with tempfile.TemporaryDirectory() as tmpdir:
    with open(Path(tmpdir) / 'aws-cli-tmp-file', 'w') as f:
        pass
    CASE_INSENSITIVE = (Path(tmpdir) / 'AWS-CLI-TMP-FILE').exists()
59
60
61
def skip_if_windows(reason):
    """Decorator to skip tests that should not be run on windows.

    Example usage:

        @skip_if_windows("Not valid")
        def test_some_non_windows_stuff(self):
            self.assertEqual(...)

    """

    def decorator(func):
        # Only Darwin and Linux are considered non-Windows here.
        on_posix = platform.system() in ('Darwin', 'Linux')
        return unittest.skipIf(not on_posix, reason)(func)

    return decorator
78
79
80
def skip_if_case_sensitive():
    """Decorator to skip tests that need a case-insensitive filesystem."""

    def decorator(func):
        return unittest.skipIf(
            not CASE_INSENSITIVE,
            "This test requires a case-insensitive filesystem."
        )(func)

    return decorator
87
88
89
def create_clidriver():
    """Create a CLIDriver wired to the shared module-level data loader.

    Any ``data_path`` directories configured on the session are appended
    to ``_LOADER``'s search paths, and the loader is registered on the
    session so every driver created here shares one loader (and cache).

    :returns: A new ``awscli.clidriver.CLIDriver`` instance.
    """
    driver = awscli.clidriver.create_clidriver()
    session = driver.session
    # The original guard ``if not data_path`` could never fire:
    # str.split() always returns a non-empty list (''.split(sep) == ['']),
    # so an empty config value used to add '' to the search paths.
    # Guard the raw value instead and drop empty entries.
    raw_data_path = session.get_config_variable('data_path')
    if raw_data_path:
        data_path = [p for p in raw_data_path.split(os.pathsep) if p]
    else:
        data_path = []
    _LOADER.search_paths.extend(data_path)
    session.register_component('data_loader', _LOADER)
    return driver
98
99
100
def get_aws_cmd():
    """Locate the "aws" executable, caching the result in ``AWS_CMD``.

    Prefers ``bin/aws`` next to the installed ``awscli`` package, then
    falls back to searching PATH.

    :raises ValueError: If no executable can be found.
    """
    global AWS_CMD
    import awscli

    if AWS_CMD is not None:
        return AWS_CMD
    # Try <repo>/bin/aws
    package_dir = os.path.dirname(os.path.abspath(awscli.__file__))
    candidate = os.path.join(package_dir, 'bin', 'aws')
    if not os.path.isfile(candidate):
        candidate = _search_path_for_cmd('aws')
        if candidate is None:
            raise ValueError(
                'Could not find "aws" executable. Either '
                'make sure it is on your PATH, or you can '
                'explicitly set this value using '
                '"set_aws_cmd()"'
            )
    AWS_CMD = candidate
    return AWS_CMD
119
120
121
def _search_path_for_cmd(cmd_name):
122
for path in os.environ.get('PATH', '').split(os.pathsep):
123
full_cmd_path = os.path.join(path, cmd_name)
124
if os.path.isfile(full_cmd_path):
125
return full_cmd_path
126
return None
127
128
129
def set_aws_cmd(aws_cmd):
    """Explicitly set the cached "aws" executable used by get_aws_cmd()."""
    global AWS_CMD
    AWS_CMD = aws_cmd
132
133
134
@contextlib.contextmanager
def temporary_file(mode):
    """This is a cross platform temporary file creation.

    tempfile.NamedTemporary file on windows creates a secure temp file
    that can't be read by other processes and can't be opened a second time.

    For tests, we generally *want* them to be read multiple times.
    The test fixture writes the temp file contents, the test reads the
    temp file.

    """
    tmp_dir = tempfile.mkdtemp()
    file_path = os.path.join(tmp_dir, 'tmpfile-%s' % str(random_chars(8)))
    # Pre-create the file so read modes ('r', 'rb') work immediately.
    open(file_path, 'w').close()
    try:
        with open(file_path, mode) as handle:
            yield handle
    finally:
        shutil.rmtree(tmp_dir)
155
156
157
def create_bucket(session, name=None, region=None):
    """
    Creates a bucket
    :returns: the name of the bucket created
    """
    region = region or 'us-west-2'
    client = create_nested_client(session, 's3', region_name=region)
    bucket_name = name if name else random_bucket_name()
    params = {'Bucket': bucket_name, 'ObjectOwnership': 'ObjectWriter'}
    if region != 'us-east-1':
        # Outside us-east-1, S3 requires an explicit location constraint.
        params['CreateBucketConfiguration'] = {'LocationConstraint': region}
    try:
        client.create_bucket(**params)
    except ClientError as e:
        # This can happen in the retried request, when the first one
        # succeeded on S3 but somehow the response never comes back.
        # We still got a bucket ready for test anyway.
        if e.response['Error'].get('Code') != 'BucketAlreadyOwnedByYou':
            raise
    return bucket_name
183
184
185
def create_dir_bucket(session, name=None, location=None):
    """
    Creates a S3 directory bucket
    :returns: the name of the bucket created
    """
    region, az = location if location else ('us-west-2', 'usw2-az1')
    client = create_nested_client(session, 's3', region_name=region)
    # Directory bucket names must carry the AZ id and an "--x-s3" suffix.
    bucket_name = name if name else f"{random_bucket_name()}--{az}--x-s3"
    params = {
        'Bucket': bucket_name,
        'CreateBucketConfiguration': {
            'Location': {'Type': 'AvailabilityZone', 'Name': az},
            'Bucket': {
                'Type': 'Directory',
                'DataRedundancy': 'SingleAvailabilityZone',
            },
        },
    }
    try:
        client.create_bucket(**params)
    except ClientError as e:
        # A retried request may find the bucket already created by the
        # first attempt; the bucket is still usable for the test.
        if e.response['Error'].get('Code') != 'BucketAlreadyOwnedByYou':
            raise
    return bucket_name
219
220
221
def random_chars(num_chars):
    """Returns random hex characters.

    Useful for creating resources with random names.

    :param num_chars: Number of hex characters to return.
    :returns: A string of exactly ``num_chars`` random hex characters.
    """
    # hexlify() yields two characters per input byte, so the original
    # ``os.urandom(int(num_chars / 2))`` came up one character short for
    # odd requests (e.g. 7 -> 6 chars); round up and trim instead.
    num_bytes = (num_chars + 1) // 2
    return binascii.hexlify(os.urandom(num_bytes)).decode('ascii')[:num_chars]
228
229
230
def random_bucket_name(prefix='awscli-s3integ', num_random=15):
    """Generate a random S3 bucket name.

    :param prefix: A prefix to use in the bucket name. Useful
        for tracking resources. This default value makes it easy
        to see which buckets were created from CLI integ tests.
    :param num_random: Number of random chars to include in the bucket name.

    :returns: The name of a randomly generated bucket name as a string.

    """
    suffix = random_chars(num_random)
    timestamp = int(time.time())
    return f"{prefix}-{suffix}-{timestamp}"
242
243
244
class BaseCLIDriverTest(unittest.TestCase):
    """Base unittest that use clidriver.

    This will load all the default plugins as well so it
    will simulate the behavior the user will see.
    """

    def setUp(self):
        # Replace os.environ wholesale with a minimal, deterministic
        # environment so tests never pick up the developer's real
        # credentials or config files.
        self.environ = {
            'AWS_DATA_PATH': os.environ['AWS_DATA_PATH'],
            'AWS_DEFAULT_REGION': 'us-east-1',
            'AWS_ACCESS_KEY_ID': 'access_key',
            'AWS_SECRET_ACCESS_KEY': 'secret_key',
            'AWS_CONFIG_FILE': '',
        }
        self.environ_patch = mock.patch('os.environ', self.environ)
        self.environ_patch.start()
        # Driver/session are created after the environment is patched so
        # they are configured from the stub environment above.
        self.driver = create_clidriver()
        self.session = self.driver.session

    def tearDown(self):
        self.environ_patch.stop()
266
267
268
class BaseAWSHelpOutputTest(BaseCLIDriverTest):
    """Base class for asserting over rendered help output.

    Replaces the help renderer with a ``CapturedRenderer`` so tests can
    make assertions against the text that would have been displayed.
    """

    def setUp(self):
        super().setUp()
        self.renderer_patch = mock.patch('awscli.help.get_renderer')
        self.renderer_mock = self.renderer_patch.start()
        self.renderer = CapturedRenderer()
        self.renderer_mock.return_value = self.renderer

    def tearDown(self):
        super().tearDown()
        self.renderer_patch.stop()

    def assert_contains(self, contains):
        """Fail unless ``contains`` appears in the rendered help output."""
        if contains not in self.renderer.rendered_contents:
            self.fail(
                "The expected contents:\n%s\nwere not in the "
                "actual rendered contents:\n%s"
                % (contains, self.renderer.rendered_contents)
            )

    def assert_contains_with_count(self, contains, count):
        """Fail unless ``contains`` occurs exactly ``count`` times."""
        r_count = self.renderer.rendered_contents.count(contains)
        if r_count != count:
            self.fail(
                "The expected contents:\n%s\n, with the "
                "count:\n%d\nwere not in the actual rendered "
                " contents:\n%s\nwith count:\n%d"
                % (contains, count, self.renderer.rendered_contents, r_count)
            )

    def assert_not_contains(self, contents):
        """Fail if ``contents`` appears in the rendered help output."""
        if contents in self.renderer.rendered_contents:
            self.fail(
                "The contents:\n%s\nwere not suppose to be in the "
                "actual rendered contents:\n%s"
                % (contents, self.renderer.rendered_contents)
            )

    def assert_text_order(self, *args, **kwargs):
        """Fail unless ``args`` appear, in order, after ``starting_from``."""
        # First we need to find where the SYNOPSIS section starts.
        starting_from = kwargs.pop('starting_from')
        args = list(args)
        contents = self.renderer.rendered_contents
        self.assertIn(starting_from, contents)
        start_index = contents.find(starting_from)
        arg_indices = [contents.find(arg, start_index) for arg in args]
        previous = arg_indices[0]
        # Bug fix: the original iterated arg_indices[1:] and, on a missing
        # string, reported ``args[index]`` with index == -1 (i.e. always
        # the *last* argument).  It also never flagged the first argument
        # as missing.  Iterate everything and report args[i].
        for i, index in enumerate(arg_indices):
            if index == -1:
                self.fail(
                    'The string %r was not found in the contents: %s'
                    % (args[i], contents)
                )
            if index < previous:
                self.fail(
                    'The string %r came before %r, but was suppose to come '
                    'after it.\n%s' % (args[i], args[i - 1], contents)
                )
            previous = index
327
328
329
class CapturedRenderer:
    """Help-renderer stand-in that records output instead of paging it."""

    def __init__(self):
        self.rendered_contents = ''

    def render(self, contents):
        # Help contents arrive as UTF-8 bytes; keep them decoded so tests
        # can assert against plain strings.
        self.rendered_contents = contents.decode('utf-8')
335
336
337
class CapturedOutput:
    """Simple value object pairing captured stdout and stderr buffers."""

    def __init__(self, stdout, stderr):
        self.stdout = stdout
        self.stderr = stderr
341
342
343
@contextlib.contextmanager
def capture_output():
    """Swap sys.stdout/sys.stderr for in-memory buffers for the duration.

    Yields a ``CapturedOutput`` whose ``stdout``/``stderr`` attributes
    are the StringIO buffers receiving everything written while active.
    """
    fake_stdout = StringIO()
    fake_stderr = StringIO()
    with mock.patch('sys.stderr', fake_stderr), mock.patch(
        'sys.stdout', fake_stdout
    ):
        yield CapturedOutput(fake_stdout, fake_stderr)
350
351
352
@contextlib.contextmanager
def capture_input(input_bytes=b''):
    """Patch sys.stdin with a mock whose ``buffer`` serves ``input_bytes``."""
    stdin_buffer = BytesIO(input_bytes)
    fake_stdin = mock.Mock()
    fake_stdin.buffer = stdin_buffer
    with mock.patch('sys.stdin', fake_stdin):
        yield stdin_buffer
360
361
362
class BaseAWSCommandParamsTest(unittest.TestCase):
    """Base class for asserting the parameters a CLI command would send.

    Patches botocore's ``Endpoint.make_request`` so no HTTP request is
    actually made, records every operation invoked, and provides
    ``assert_params_for_cmd`` for comparing against expected parameters.
    """

    maxDiff = None

    def setUp(self):
        # Serialized body of the most recent request (set in before_call).
        self.last_params = {}
        # Parameters of the most recent operation, captured pre-build.
        self.last_kwargs = None
        # awscli/__init__.py injects AWS_DATA_PATH at import time
        # so that we can find cli.json. This might be fixed in the
        # future, but for now we just grab that value out of the real
        # os.environ so the patched os.environ has this data and
        # the CLI works.
        self.environ = {
            'AWS_DATA_PATH': os.environ['AWS_DATA_PATH'],
            'AWS_DEFAULT_REGION': 'us-east-1',
            'AWS_ACCESS_KEY_ID': 'access_key',
            'AWS_SECRET_ACCESS_KEY': 'secret_key',
            'AWS_CONFIG_FILE': '',
            'AWS_SHARED_CREDENTIALS_FILE': '',
        }
        # On Windows the shell needs ComSpec; keep it if present.
        if os.environ.get('ComSpec'):
            self.environ['ComSpec'] = os.environ['ComSpec']
        self.environ_patch = mock.patch('os.environ', self.environ)
        self.environ_patch.start()
        self.http_response = AWSResponse(None, 200, {}, None)
        # Single canned parsed response (used when parsed_responses is None).
        self.parsed_response = {}
        self.make_request_patch = mock.patch(
            'botocore.endpoint.Endpoint.make_request'
        )
        self.make_request_is_patched = False
        # List of (operation_model, params) tuples, one per operation call.
        self.operations_called = []
        # Optional list of parsed responses consumed one per request.
        self.parsed_responses = None
        self.driver = create_clidriver()

    def tearDown(self):
        # This clears all the previous registrations.
        self.environ_patch.stop()
        if self.make_request_is_patched:
            self.make_request_patch.stop()
            self.make_request_is_patched = False

    def before_call(self, params, **kwargs):
        # 'before-call' event handler: capture the fully serialized request.
        self._store_params(params)

    def _store_params(self, params):
        self.last_request_dict = params
        self.last_params = params['body']

    def patch_make_request(self):
        # If you do not stop a previously started patch,
        # it can never be stopped if you call start() again on the same
        # patch again...
        # So stop the current patch before calling start() on it again.
        if self.make_request_is_patched:
            self.make_request_patch.stop()
            self.make_request_is_patched = False
        make_request_patch = self.make_request_patch.start()
        if self.parsed_responses is not None:
            # Serve queued responses in order, one per request.
            make_request_patch.side_effect = lambda *args, **kwargs: (
                self.http_response,
                self.parsed_responses.pop(0),
            )
        else:
            make_request_patch.return_value = (
                self.http_response,
                self.parsed_response,
            )
        self.make_request_is_patched = True

    def assert_params_for_cmd(
        self,
        cmd,
        params=None,
        expected_rc=0,
        stderr_contains=None,
        ignore_params=None,
    ):
        """Run ``cmd`` and assert the operation params match ``params``.

        :param cmd: CLI command string or argv list.
        :param params: Expected operation parameters (compared against the
            last operation's kwargs); skipped when None.
        :param expected_rc: Expected CLI return code.
        :param stderr_contains: Substring that must appear on stderr.
        :param ignore_params: Keys removed from the actuals before compare.
        :returns: (stdout, stderr, rc) from running the command.
        """
        stdout, stderr, rc = self.run_cmd(cmd, expected_rc)
        if stderr_contains is not None:
            self.assertIn(stderr_contains, stderr)
        if params is not None:
            # The last kwargs of Operation.call() in botocore.
            last_kwargs = copy.copy(self.last_kwargs)
            if ignore_params is not None:
                for key in ignore_params:
                    try:
                        del last_kwargs[key]
                    except KeyError:
                        pass
            if params != last_kwargs:
                self.fail(
                    "Actual params did not match expected params.\n"
                    "Expected:\n\n"
                    "%s\n"
                    "Actual:\n\n%s\n" % (pformat(params), pformat(last_kwargs))
                )
        return stdout, stderr, rc

    def before_parameter_build(self, params, model, **kwargs):
        # 'before-parameter-build' handler: record each operation invoked,
        # copying params before botocore mutates them during build.
        self.last_kwargs = params
        self.operations_called.append((model, params.copy()))

    def run_cmd(self, cmd, expected_rc=0):
        """Run a CLI command with stubbed HTTP; assert its return code."""
        logging.debug("Calling cmd: %s", cmd)
        self.patch_make_request()
        event_emitter = self.driver.session.get_component('event_emitter')
        event_emitter.register('before-call', self.before_call)
        # register_first so our recorder runs before other handlers
        # can mutate the params.
        event_emitter.register_first(
            'before-parameter-build.*.*', self.before_parameter_build
        )
        if not isinstance(cmd, list):
            cmdlist = cmd.split()
        else:
            cmdlist = cmd

        with capture_output() as captured:
            try:
                rc = self.driver.main(cmdlist)
            except SystemExit as e:
                # We need to catch SystemExit so that we
                # can get a proper rc and still present the
                # stdout/stderr to the test runner so we can
                # figure out what went wrong.
                rc = e.code
        stderr = captured.stderr.getvalue()
        stdout = captured.stdout.getvalue()
        self.assertEqual(
            rc,
            expected_rc,
            "Unexpected rc (expected: %s, actual: %s) for command: %s\n"
            "stdout:\n%sstderr:\n%s" % (expected_rc, rc, cmd, stdout, stderr),
        )
        return stdout, stderr, rc
494
495
496
class BaseAWSPreviewCommandParamsTest(BaseAWSCommandParamsTest):
    """Params test base with preview-gating disabled.

    Patches out ``mark_as_preview`` so preview commands behave like
    fully released commands during tests.
    """

    def setUp(self):
        # The patch must be started *before* super().setUp() creates the
        # clidriver, so the driver is built without preview gating.
        self.preview_patch = mock.patch(
            'awscli.customizations.preview.mark_as_preview'
        )
        self.preview_patch.start()
        super().setUp()

    def tearDown(self):
        self.preview_patch.stop()
        super().tearDown()
507
508
509
class BaseCLIWireResponseTest(unittest.TestCase):
    """Base class for tests that stub responses at the wire (HTTP) level.

    Unlike ``BaseAWSCommandParamsTest`` (which patches ``make_request``),
    this patches ``Endpoint._send`` so the full serialization and
    response-parsing stack runs against canned status/headers/content.
    """

    def setUp(self):
        # Deterministic environment so no real credentials/config leak in.
        self.environ = {
            'AWS_DATA_PATH': os.environ['AWS_DATA_PATH'],
            'AWS_DEFAULT_REGION': 'us-east-1',
            'AWS_ACCESS_KEY_ID': 'access_key',
            'AWS_SECRET_ACCESS_KEY': 'secret_key',
            'AWS_CONFIG_FILE': '',
        }
        self.environ_patch = mock.patch('os.environ', self.environ)
        self.environ_patch.start()
        # TODO: fix this patch when we have a better way to stub out responses
        self.send_patch = mock.patch('botocore.endpoint.Endpoint._send')
        self.send_is_patched = False
        self.driver = create_clidriver()

    def tearDown(self):
        self.environ_patch.stop()
        if self.send_is_patched:
            self.send_patch.stop()
            self.send_is_patched = False

    def patch_send(self, status_code=200, headers=None, content=b''):
        """Stub ``Endpoint._send`` to return a canned HTTP response.

        :param status_code: HTTP status code for the fake response.
        :param headers: Response headers; defaults to an empty dict.
        :param content: Raw response body bytes.
        """
        # Bug fix: ``headers={}`` was a mutable default argument, shared
        # across all calls (and all tests); use the None-sentinel idiom.
        if headers is None:
            headers = {}
        # Stop any previously started patch first, otherwise a second
        # start() would make it impossible to ever stop the first one.
        if self.send_is_patched:
            self.send_patch.stop()
            self.send_is_patched = False
        send_patch = self.send_patch.start()
        send_patch.return_value = mock.Mock(
            status_code=status_code, headers=headers, content=content
        )
        self.send_is_patched = True

    def run_cmd(self, cmd, expected_rc=0):
        """Run a CLI command with output captured; assert its return code."""
        if not isinstance(cmd, list):
            cmdlist = cmd.split()
        else:
            cmdlist = cmd
        with capture_output() as captured:
            try:
                rc = self.driver.main(cmdlist)
            except SystemExit as e:
                # Catch SystemExit so we can report rc plus captured output.
                rc = e.code
        stderr = captured.stderr.getvalue()
        stdout = captured.stdout.getvalue()
        self.assertEqual(
            rc,
            expected_rc,
            "Unexpected rc (expected: %s, actual: %s) for command: %s\n"
            "stdout:\n%sstderr:\n%s" % (expected_rc, rc, cmd, stdout, stderr),
        )
        return stdout, stderr, rc
560
561
562
class FileCreator:
    """Create files under a throwaway temp directory for tests."""

    def __init__(self):
        self.rootdir = tempfile.mkdtemp()

    def remove_all(self):
        """Delete the temp directory and everything beneath it."""
        if os.path.exists(self.rootdir):
            shutil.rmtree(self.rootdir)

    def _ensure_parent_dir(self, full_path):
        # Create intermediate directories for nested relative paths.
        parent = os.path.dirname(full_path)
        if not os.path.isdir(parent):
            os.makedirs(parent)

    def create_file(self, filename, contents, mtime=None, mode='w'):
        """Creates a file in a tmpdir

        ``filename`` should be a relative path, e.g. "foo/bar/baz.txt"
        It will be translated into a full path in a tmp dir.

        If the ``mtime`` argument is provided, then the file's
        mtime will be set to the provided value (must be an epoch time).
        Otherwise the mtime is left untouched.

        ``mode`` is the mode the file should be opened either as ``w`` or
        `wb``.

        Returns the full path to the file.

        """
        full_path = self.full_path(filename)
        self._ensure_parent_dir(full_path)
        with open(full_path, mode) as f:
            f.write(contents)
        if mtime is None:
            # Subtract a few years off the last modification date so
            # freshly created fixtures read as "old" files by default.
            current_time = os.path.getmtime(full_path)
            os.utime(full_path, (current_time, current_time - 100000000))
        else:
            os.utime(full_path, (mtime, mtime))
        return full_path

    def append_file(self, filename, contents):
        """Append contents to a file

        ``filename`` should be a relative path, e.g. "foo/bar/baz.txt"
        It will be translated into a full path in a tmp dir.

        Returns the full path to the file.
        """
        full_path = self.full_path(filename)
        self._ensure_parent_dir(full_path)
        with open(full_path, 'a') as f:
            f.write(contents)
        return full_path

    def full_path(self, filename):
        """Translate relative path to full path in temp dir.

        f.full_path('foo/bar.txt') -> /tmp/asdfasd/foo/bar.txt
        """
        return os.path.join(self.rootdir, filename)
619
620
621
class ProcessTerminatedError(Exception):
    """Raised when a monitored child process exits mid-observation."""
623
624
625
class Result:
    """Outcome of one ``aws(...)`` run: rc, output, and memory samples."""

    def __init__(self, rc, stdout, stderr, memory_usage=None):
        self.rc = rc
        self.stdout = stdout
        self.stderr = stderr
        INTEG_LOG.debug("rc: %s", rc)
        INTEG_LOG.debug("stdout: %s", stdout)
        INTEG_LOG.debug("stderr: %s", stderr)
        # Avoid a shared mutable default; each Result gets its own list.
        self.memory_usage = [] if memory_usage is None else memory_usage

    @property
    def json(self):
        """The stdout payload parsed as JSON."""
        return json.loads(self.stdout)
640
641
642
def _escape_quotes(command):
643
# For windows we have different rules for escaping.
644
# First, double quotes must be escaped.
645
command = command.replace('"', '\\"')
646
# Second, single quotes do nothing, to quote a value we need
647
# to use double quotes.
648
command = command.replace("'", '"')
649
return command
650
651
652
def aws(
    command,
    collect_memory=False,
    env_vars=None,
    wait_for_finish=True,
    input_data=None,
    input_file=None,
):
    """Run an aws command.

    This helper function abstracts the differences of running the "aws"
    command on different platforms.

    If collect_memory is ``True`` then the Result object will have a list
    of memory usage taken at 2 second intervals. The memory usage
    will be in bytes.

    If env_vars is not None, it is used as-is as the environment for the
    aws process (instead of a copy of this process's environment).

    If wait_for_finish is False, then the Process object is returned
    to the caller. It is then the caller's responsibility to ensure
    proper cleanup. This can be useful if you want to test timeout's
    or how the CLI responds to various signals.

    :type input_data: string
    :param input_data: This string will be communicated to the process through
        the stdin of the process. It essentially allows the user to
        avoid having to use a file handle to pass information to the process.
        Note that this string is not passed on creation of the process, but
        rather communicated to the process.

    :type input_file: a file handle
    :param input_file: This is a file handle that will act as the
        the stdin of the process immediately on creation. Essentially
        any data written to the file will be read from stdin of the
        process. This is needed if you plan to stream data into stdin while
        collecting memory.
    """
    if platform.system() == 'Windows':
        command = _escape_quotes(command)
    # AWS_TEST_COMMAND lets the environment override how the CLI binary
    # is invoked (e.g. to test an installed executable).
    if 'AWS_TEST_COMMAND' in os.environ:
        aws_command = os.environ['AWS_TEST_COMMAND']
    else:
        aws_command = 'python %s' % get_aws_cmd()
    full_command = '%s %s' % (aws_command, command)
    stdout_encoding = get_stdout_encoding()
    INTEG_LOG.debug("Running command: %s", full_command)
    env = os.environ.copy()
    if 'AWS_DEFAULT_REGION' not in env:
        env['AWS_DEFAULT_REGION'] = "us-east-1"
    if env_vars is not None:
        env = env_vars
    if input_file is None:
        input_file = PIPE
    process = Popen(
        full_command,
        stdout=PIPE,
        stderr=PIPE,
        stdin=input_file,
        shell=True,
        env=env,
    )
    if not wait_for_finish:
        return process
    memory = None
    if not collect_memory:
        kwargs = {}
        if input_data:
            kwargs = {'input': input_data}
        stdout, stderr = process.communicate(**kwargs)
    else:
        # Poll memory until the process exits (see _wait_and_collect_mem).
        stdout, stderr, memory = _wait_and_collect_mem(process)
    return Result(
        process.returncode,
        stdout.decode(stdout_encoding),
        stderr.decode(stdout_encoding),
        memory,
    )
731
732
733
def get_stdout_encoding():
    """Return the encoding of the real stdout, defaulting to utf-8."""
    encoding = getattr(sys.__stdout__, 'encoding', None)
    return 'utf-8' if encoding is None else encoding
738
739
740
def _wait_and_collect_mem(process):
    """Poll memory of ``process`` until it exits.

    :returns: (stdout, stderr, memory_samples) for the finished process.
    :raises ValueError: On platforms where memory sampling is unsupported.
    """
    # We only know how to collect memory on mac/linux.
    if platform.system() not in ('Darwin', 'Linux'):
        raise ValueError(
            f"Can't collect memory for process on platform {platform.system()}."
        )
    get_memory = _get_memory_with_ps
    memory = []
    while process.poll() is None:
        try:
            memory.append(get_memory(process.pid))
        except ProcessTerminatedError:
            # It's possible the process terminated between .poll()
            # and get_memory().
            break
    stdout, stderr = process.communicate()
    return stdout, stderr, memory
761
762
763
def _get_memory_with_ps(pid):
    """Return the RSS (in bytes) of ``pid`` as reported by ``ps u``.

    :raises ProcessTerminatedError: If ``ps`` exits non-zero (the process
        is gone).
    """
    # It's probably possible to do with proc_pidinfo and ctypes on a Mac,
    # but we'll do it the easy way with parsing ps output.
    ps_cmd = 'ps u -p'.split()
    ps_cmd.append(str(pid))
    ps_process = Popen(ps_cmd, stdout=PIPE)
    ps_output = ps_process.communicate()[0]
    if ps_process.returncode != 0:
        raise ProcessTerminatedError(str(pid))
    # Get the RSS from output that looks like this:
    # USER PID %CPU %MEM VSZ RSS TT STAT STARTED TIME COMMAND
    # user 47102 0.0 0.1 2437000 4496 s002 S+ 7:04PM 0:00.12 python2.6
    rss_kb = int(ps_output.splitlines()[1].split()[5])
    return rss_kb * 1024
777
778
779
class BaseS3CLICommand(unittest.TestCase):
    """Base class for aws s3 command.

    This contains convenience functions to make writing these tests easier
    and more streamlined.

    """

    # put_object extra args that must be replayed on the follow-up
    # head_object consistency check.
    _PUT_HEAD_SHARED_EXTRAS = [
        'SSECustomerAlgorithm',
        'SSECustomerKey',
        'SSECustomerKeyMD5',
        'RequestPayer',
    ]

    def setUp(self):
        self.files = FileCreator()
        # NOTE(review): relies on ``botocore.session`` being reachable as
        # an attribute of ``botocore``; only botocore.loaders (and
        # submodules) are imported at the top of this file — confirm.
        self.session = botocore.session.get_session()
        # Map of bucket name -> region, so per-bucket clients hit the
        # region the bucket was created in.
        self.regions = {}
        self.region = 'us-west-2'
        self.client = create_nested_client(self.session, 's3', region_name=self.region)
        self.extra_setup()

    def extra_setup(self):
        # Subclasses can use this to define extra setup steps.
        pass

    def tearDown(self):
        self.files.remove_all()
        self.extra_teardown()

    def extra_teardown(self):
        # Subclasses can use this to define extra teardown steps.
        pass

    def override_parser(self, **kwargs):
        # Tweak botocore response parsing (e.g. blob/timestamp handling).
        factory = self.session.get_component('response_parser_factory')
        factory.set_parser_defaults(**kwargs)

    def create_client_for_bucket(self, bucket_name):
        # Use the bucket's recorded region, falling back to the default.
        region = self.regions.get(bucket_name, self.region)
        client = create_nested_client(self.session, 's3', region_name=region)
        return client

    def assert_key_contents_equal(self, bucket, key, expected_contents):
        """Assert the S3 object's body equals ``expected_contents``."""
        self.wait_until_key_exists(bucket, key)
        if isinstance(expected_contents, BytesIO):
            expected_contents = expected_contents.getvalue().decode('utf-8')
        actual_contents = self.get_key_contents(bucket, key)
        # The contents can be huge so we try to give helpful error messages
        # without necessarily printing the actual contents.
        self.assertEqual(len(actual_contents), len(expected_contents))
        if actual_contents != expected_contents:
            self.fail(
                f"Contents for {bucket}/{key} do not match (but they "
                "have the same length)"
            )

    def delete_public_access_block(self, bucket_name):
        client = self.create_client_for_bucket(bucket_name)
        client.delete_public_access_block(Bucket=bucket_name)

    def create_bucket(self, name=None, region=None):
        """Create a bucket, register cleanup, and wait until it's usable."""
        if not region:
            region = self.region
        bucket_name = create_bucket(self.session, name, region)
        self.regions[bucket_name] = region
        self.addCleanup(self.delete_bucket, bucket_name)

        # Wait for the bucket to exist before letting it be used.
        self.wait_bucket_exists(bucket_name)
        self.delete_public_access_block(bucket_name)
        return bucket_name

    def create_dir_bucket(self, name=None, location=None):
        """Create an S3 directory bucket and register cleanup."""
        if location:
            region, _ = location
        else:
            region = self.region
        bucket_name = create_dir_bucket(self.session, name, location)
        self.regions[bucket_name] = region
        self.addCleanup(self.delete_bucket, bucket_name)

        # Wait for the bucket to exist before letting it be used.
        self.wait_bucket_exists(bucket_name)
        return bucket_name

    def put_object(self, bucket_name, key_name, contents='', extra_args=None):
        """Put an object, register cleanup, and wait for read consistency."""
        client = self.create_client_for_bucket(bucket_name)
        call_args = {'Bucket': bucket_name, 'Key': key_name, 'Body': contents}
        if extra_args is not None:
            call_args.update(extra_args)
        response = client.put_object(**call_args)
        self.addCleanup(self.delete_key, bucket_name, key_name)
        # SSE-C style args must also be sent on the head_object used by
        # the consistency waiter, or the HEAD would 400.
        extra_head_params = {}
        if extra_args:
            extra_head_params = dict(
                (k, v)
                for (k, v) in extra_args.items()
                if k in self._PUT_HEAD_SHARED_EXTRAS
            )
        self.wait_until_key_exists(
            bucket_name,
            key_name,
            extra_params=extra_head_params,
        )
        return response

    def delete_bucket(self, bucket_name, attempts=5, delay=5):
        """Empty and delete a bucket, retrying on eventual consistency."""
        self.remove_all_objects(bucket_name)
        client = self.create_client_for_bucket(bucket_name)

        # There's a chance that, even though the bucket has been used
        # several times, the delete will fail due to eventual consistency
        # issues.
        attempts_remaining = attempts
        while True:
            attempts_remaining -= 1
            try:
                client.delete_bucket(Bucket=bucket_name)
                break
            except client.exceptions.NoSuchBucket:
                if self.bucket_not_exists(bucket_name):
                    # Fast fail when the NoSuchBucket error is real.
                    break
                if attempts_remaining <= 0:
                    raise
                time.sleep(delay)

        self.regions.pop(bucket_name, None)

    def remove_all_objects(self, bucket_name):
        client = self.create_client_for_bucket(bucket_name)
        paginator = client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=bucket_name)
        key_names = []
        for page in pages:
            key_names += [obj['Key'] for obj in page.get('Contents', [])]
        for key_name in key_names:
            self.delete_key(bucket_name, key_name)

    def delete_key(self, bucket_name, key_name):
        client = self.create_client_for_bucket(bucket_name)
        response = client.delete_object(Bucket=bucket_name, Key=key_name)

    def get_key_contents(self, bucket_name, key_name):
        """Return the object's body decoded as UTF-8."""
        self.wait_until_key_exists(bucket_name, key_name)
        client = self.create_client_for_bucket(bucket_name)
        response = client.get_object(Bucket=bucket_name, Key=key_name)
        return response['Body'].read().decode('utf-8')

    def wait_bucket_exists(self, bucket_name, min_successes=3):
        # Require several consecutive successful waiter runs so list/read
        # operations that follow don't hit a stale negative.
        client = self.create_client_for_bucket(bucket_name)
        waiter = client.get_waiter('bucket_exists')
        consistency_waiter = ConsistencyWaiter(
            min_successes=min_successes, delay_initial_poll=True
        )
        consistency_waiter.wait(
            lambda: waiter.wait(Bucket=bucket_name) is None
        )

    def bucket_not_exists(self, bucket_name):
        # NOTE(review): the return values look inverted relative to the
        # name — a successful head_bucket (bucket exists) returns True
        # and a 404 returns False.  Also, ClientError codes usually live
        # under error.response['Error']['Code'], not error.response['Code'];
        # confirm against the delete_bucket caller before changing.
        client = self.create_client_for_bucket(bucket_name)
        try:
            client.head_bucket(Bucket=bucket_name)
            return True
        except ClientError as error:
            if error.response.get('Code') == '404':
                return False
            raise

    def key_exists(self, bucket_name, key_name, min_successes=3):
        """Return True if the key becomes visible; False on waiter failure."""
        try:
            self.wait_until_key_exists(
                bucket_name, key_name, min_successes=min_successes
            )
            return True
        except (ClientError, WaiterError):
            return False

    def key_not_exists(self, bucket_name, key_name, min_successes=3):
        """Return True if the key disappears; False on waiter failure."""
        try:
            self.wait_until_key_not_exists(
                bucket_name, key_name, min_successes=min_successes
            )
            return True
        except (ClientError, WaiterError):
            return False

    def list_buckets(self):
        response = self.client.list_buckets()
        return response['Buckets']

    def content_type_for_key(self, bucket_name, key_name):
        parsed = self.head_object(bucket_name, key_name)
        return parsed['ContentType']

    def head_object(self, bucket_name, key_name):
        client = self.create_client_for_bucket(bucket_name)
        response = client.head_object(Bucket=bucket_name, Key=key_name)
        return response

    def wait_until_key_exists(
        self, bucket_name, key_name, extra_params=None, min_successes=3
    ):
        self._wait_for_key(
            bucket_name, key_name, extra_params, min_successes, exists=True
        )

    def wait_until_key_not_exists(
        self, bucket_name, key_name, extra_params=None, min_successes=3
    ):
        self._wait_for_key(
            bucket_name, key_name, extra_params, min_successes, exists=False
        )

    def _wait_for_key(
        self,
        bucket_name,
        key_name,
        extra_params=None,
        min_successes=3,
        exists=True,
    ):
        # Shared implementation for both key waiters; runs the botocore
        # waiter min_successes times for read-after-write confidence.
        client = self.create_client_for_bucket(bucket_name)
        if exists:
            waiter = client.get_waiter('object_exists')
        else:
            waiter = client.get_waiter('object_not_exists')
        params = {'Bucket': bucket_name, 'Key': key_name}
        if extra_params is not None:
            params.update(extra_params)
        for _ in range(min_successes):
            waiter.wait(**params)

    def assert_no_errors(self, p):
        """Assert a Result has rc 0 and no error markers on stderr."""
        self.assertEqual(
            p.rc,
            0,
            "Non zero rc (%s) received: %s" % (p.rc, p.stdout + p.stderr),
        )
        self.assertNotIn("Error:", p.stderr)
        self.assertNotIn("failed:", p.stderr)
        self.assertNotIn("client error", p.stderr)
        self.assertNotIn("server error", p.stderr)
1024
1025
1026
class StringIOWithFileNo(StringIO):
    """StringIO that reports fileno 0 for APIs that require a descriptor."""

    def fileno(self):
        return 0
1029
1030
1031
class TestEventHandler:
    """Records whether an event fired, optionally delegating to a handler."""

    def __init__(self, handler=None):
        self._handler = handler
        self._called = False
        # Tell pytest/nose this class is not itself a test case.
        self.__test__ = False

    @property
    def called(self):
        """True once handler() has been invoked at least once."""
        return self._called

    def handler(self, **kwargs):
        self._called = True
        if self._handler is not None:
            self._handler(**kwargs)
1045
1046
1047
class ConsistencyWaiterException(Exception):
    """Raised when a ConsistencyWaiter exhausts its attempts."""
1049
1050
1051
class ConsistencyWaiter:
    """
    A waiter class for some check to reach a consistent state.

    :type min_successes: int
    :param min_successes: The minimum number of successful check calls to
        treat the check as stable. Default of 1 success.

    :type max_attempts: int
    :param max_attempts: The maximum number of times to attempt calling
        the check. Default of 20 attempts.

    :type delay: int
    :param delay: The number of seconds to delay the next API call after a
        failed check call. Default of 5 seconds.

    :type delay_initial_poll: bool
    :param delay_initial_poll: If True, sleep for ``delay`` seconds before
        the first check call. Default of False.
    """

    # Doc fix: the class docstring previously labeled the max_attempts and
    # delay entries as ":param min_successes:", and delay_initial_poll was
    # undocumented.

    def __init__(
        self,
        min_successes=1,
        max_attempts=20,
        delay=5,
        delay_initial_poll=False,
    ):
        self.min_successes = min_successes
        self.max_attempts = max_attempts
        self.delay = delay
        self.delay_initial_poll = delay_initial_poll

    def wait(self, check, *args, **kwargs):
        """
        Wait until the check succeeds the configured number of times

        :type check: callable
        :param check: A callable that returns True or False to indicate
            if the check succeeded or failed.

        :type args: list
        :param args: Any ordered arguments to be passed to the check.

        :type kwargs: dict
        :param kwargs: Any keyword arguments to be passed to the check.

        :raises ConsistencyWaiterException: If max_attempts is exhausted
            before min_successes successes are observed.
        """
        attempts = 0
        successes = 0
        if self.delay_initial_poll:
            time.sleep(self.delay)
        while attempts < self.max_attempts:
            attempts += 1
            if check(*args, **kwargs):
                successes += 1
                if successes >= self.min_successes:
                    return
            else:
                # Only failed checks are followed by a delay; consecutive
                # successes are polled back-to-back.
                time.sleep(self.delay)
        fail_msg = self._fail_message(attempts, successes)
        raise ConsistencyWaiterException(fail_msg)

    def _fail_message(self, attempts, successes):
        # Build the failure summary used in the raised exception.
        format_args = (attempts, successes)
        return 'Failed after %s attempts, only had %s successes' % format_args
1112
1113
1114
@contextlib.contextmanager
def cd(path):
    """Temporarily change the working directory to ``path``."""
    try:
        starting_dir = os.getcwd()
        os.chdir(path)
        yield
    finally:
        # Always restore the original working directory.
        os.chdir(starting_dir)
1122
1123