Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
aws
GitHub Repository: aws/aws-cli
Path: blob/develop/tests/unit/customizations/cloudtrail/test_validation.py
1569 views
1
# Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the 'License'). You
4
# may not use this file except in compliance with the License. A copy of
5
# the License is located at
6
#
7
# http://aws.amazon.com/apache2.0/
8
#
9
# or in the 'license' file accompanying this file. This file is
10
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
# ANY KIND, either express or implied. See the License for the specific
12
# language governing permissions and limitations under the License.
13
import binascii
14
import base64
15
import hashlib
16
import json
17
import gzip
18
from datetime import datetime, timedelta
19
from dateutil import parser, tz
20
21
import rsa
22
from argparse import Namespace
23
24
from awscli.testutils import BaseAWSCommandParamsTest
25
from awscli.customizations.cloudtrail.validation import DigestError, \
26
extract_digest_key_date, normalize_date, format_date, DigestProvider, \
27
DigestTraverser, create_digest_traverser, PublicKeyProvider, \
28
Sha256RSADigestValidator, DATE_FORMAT, CloudTrailValidateLogs, \
29
parse_date, assert_cloudtrail_arn_is_valid, DigestSignatureError, \
30
InvalidDigestFormat, S3ClientProvider
31
from awscli.compat import BytesIO
32
from botocore.exceptions import ClientError
33
from awscli.testutils import mock, unittest
34
from awscli.schema import ParameterRequiredError
35
36
37
# Date window shared by the listing and traversal tests.
START_DATE = parser.parse('20140810T000000Z')
END_DATE = parser.parse('20150810T000000Z')

# Account that owns the test trail, and the ARN of that trail.
TEST_ACCOUNT_ID = '123456789012'
TEST_TRAIL_ARN = 'arn:aws:cloudtrail:us-east-1:{0}:trail/foo'.format(
    TEST_ACCOUNT_ID)

# Key text handed to Sha256RSADigestValidator.validate as the public key
# in the signature-verification failure test.
VALID_TEST_KEY = ('MIIBCgKCAQEAn11L2YZ9h7onug2ILi1MWyHiMRsTQjfWE+pHVRLk1QjfW'
                  'hirG+lpOa8NrwQ/r7Ah5bNL6HepznOU9XTDSfmmnP97mqyc7z/upfZdS/'
                  'AHhYcGaz7n6Wc/RRBU6VmiPCrAUojuSk6/GjvA8iOPFsYDuBtviXarvuL'
                  'PlrT9kAd4Lb+rFfR5peEgBEkhlzc5HuWO7S0y+KunqxX6jQBnXGMtxmPB'
                  'PP0FylgWGNdFtks/4YSKcgqwH0YDcawP9GGGDAeCIqPWIXDLG1jOjRRzW'
                  'fCmD0iJUkz8vTsn4hq/5ZxRFE7UBAUiVcGbdnDdvVfhF9C3dQiDq3k7ad'
                  'QIziLT0cShgQIDAQAB')

# Identifiers used by the organization-trail tests.
TEST_ORGANIZATION_ACCOUNT_ID = '987654321098'
TEST_ORGANIZATION_ID = 'o-12345'
50
51
52
def create_mock_key_provider(key_list):
    """Build a mock key provider serving one key per entry of key_list.

    The returned mock's ``get_public_keys`` yields a dict that maps every
    fingerprint in ``key_list`` to a record holding that fingerprint and
    the placeholder value ``'ffaa00'``.
    """
    provider = mock.Mock()
    provider.get_public_keys.return_value = {
        fingerprint: {'Fingerprint': fingerprint, 'Value': 'ffaa00'}
        for fingerprint in key_list
    }
    return provider
61
62
63
def create_scenario(actions, logs=None):
    """Build the collaborators needed to traverse a scripted digest chain.

    ``actions`` holds one entry per digest. "gap" means the digest has no
    previous link, "invalid" means the digest should fail validation,
    "missing" means the digest cannot be fetched from S3, and
    "bucket_change" means the digest links to a previous digest stored in
    a different bucket. Any other value is treated as an ordinary link.

    Returns a ``(key_provider, digest_provider, digest_validator)`` tuple.
    """
    # One fingerprint per action: '0', '1', ...
    fingerprints = [str(index) for index, _ in enumerate(actions)]

    def validate(bucket, key, public_key, digest_data, digest_str):
        # Digests flagged by MockDigestProvider.create_link are rejected.
        if '_invalid' in digest_data:
            raise DigestError('invalid error')

    key_provider = create_mock_key_provider(fingerprints)
    digest_provider = MockDigestProvider(actions, logs)
    digest_validator = mock.Mock()
    digest_validator.validate = validate
    return key_provider, digest_provider, digest_validator
84
85
86
def collecting_callback():
    """Return a ``(callback, calls)`` pair for recording invocations.

    Every call to the callback appends its keyword arguments, as a dict,
    to the ``calls`` list.
    """
    recorded = []

    def cb(**kwargs):
        recorded.append(kwargs)

    return cb, recorded
94
95
96
class MockDigestProvider(object):
    """Scripted digest provider used by the traverser tests.

    One digest key is generated per entry of ``actions``, spaced an hour
    apart starting at START_DATE. ``calls`` records how the provider was
    used so that tests can assert on it.
    """

    def __init__(self, actions, logs=None):
        self.logs = logs or []
        self.actions = actions
        self.calls = {'fetch_digest': [], 'load_digest_keys_in_range': []}
        self.digests = [
            self.get_key_at_position(index)
            for index, _ in enumerate(self.actions)]

    def get_key_at_position(self, position):
        """Return the digest key for the given position in the chain."""
        moment = START_DATE + timedelta(hours=position)
        template = (
            'AWSLogs/{account}/CloudTrail-Digest/us-east-1/{ymd}/{account}_'
            'CloudTrail-Digest_us-east-1_foo_us-east-1_{date}.json.gz')
        return template.format(
            account=TEST_ACCOUNT_ID,
            ymd=moment.strftime('%Y/%m/%d'),
            date=moment.strftime(DATE_FORMAT))

    @staticmethod
    def create_digest(fingerprint, start_date, key, bucket, next_key=None,
                      next_bucket=None, logs=None):
        """Return a digest dict starting at start_date.

        The digest end time is ninety minutes after ``start_date``.
        """
        finish = start_date + timedelta(hours=1, minutes=30)
        # Key order is kept stable because callers serialize this dict.
        return {
            'digestPublicKeyFingerprint': fingerprint,
            'digestEndTime': finish.strftime(DATE_FORMAT),
            'digestStartTime': start_date.strftime(DATE_FORMAT),
            'previousDigestS3Bucket': next_bucket,
            'previousDigestS3Object': next_key,
            'digestS3Bucket': bucket,
            'digestS3Object': key,
            'awsAccountId': TEST_ACCOUNT_ID,
            'previousDigestSignature': 'abcd',
            'logFiles': logs or [],
        }

    @staticmethod
    def create_link(key, next_key, next_bucket, position, action, logs,
                    bucket):
        """Creates a link in a digest chain for testing.

        Returns the digest dict together with its JSON serialization.
        """
        digest_logs = logs[position] if len(logs) > position else []
        end_date = parse_date(extract_digest_key_date(key))
        digest_kwargs = dict(
            key=key, bucket=bucket, fingerprint=str(position),
            start_date=end_date, logs=digest_logs)
        # Only non-gap digests point back at a previous digest.
        if action != 'gap':
            digest_kwargs.update(next_bucket=next_bucket, next_key=next_key)
        digest = MockDigestProvider.create_digest(**digest_kwargs)
        # Flag the digest so the scenario validator rejects it.
        if action == 'invalid':
            digest['_invalid'] = True
        return digest, json.dumps(digest)

    def load_digest_keys_in_range(self, bucket, prefix, start_date, end_date):
        """Record the call arguments and return a copy of all digest keys."""
        # locals() must stay the first statement so that it captures only
        # the call arguments (including self).
        self.calls['load_digest_keys_in_range'].append(locals())
        return list(self.digests)

    def fetch_digest(self, bucket, key):
        """Return the scripted ``(digest, json_text)`` pair for key.

        Raises a NoSuchKey ClientError when the action for this key is
        "missing".
        """
        self.calls['fetch_digest'].append(key)
        position = self.digests.index(key)
        action = self.actions[position]
        # Simulate a digest missing from S3
        if action == 'missing':
            error_response = {
                'Error': {'Code': 'NoSuchKey', 'Message': 'foo'}}
            raise ClientError(error_response, 'GetObject')
        next_key = self.get_key_at_position(position - 1)
        next_bucket = int(bucket)
        # Bucket names are numeric strings; a change moves to the next one.
        if action == 'bucket_change':
            next_bucket = next_bucket + 1
        return self.create_link(
            key, next_key, str(next_bucket), position, action, self.logs,
            bucket)
171
172
173
class TestValidation(unittest.TestCase):
    """Tests for the date and ARN helpers and create_digest_traverser."""

    def test_formats_dates(self):
        midnight_utc = datetime(2015, 8, 21, tzinfo=tz.tzutc())
        formatted = format_date(midnight_utc)
        self.assertEqual('20150821T000000Z', formatted)

    def test_parses_dates_with_better_error_message(self):
        try:
            parse_date('foo')
        except ValueError as e:
            self.assertIn('Unable to parse date value: foo', str(e))
        else:
            self.fail('Should have failed to parse')

    def test_parses_dates(self):
        parsed = parse_date('August 25, 2015 00:00:00 UTC')
        expected = datetime(2015, 8, 25, tzinfo=tz.tzutc())
        self.assertEqual(parsed, expected)

    def test_ensures_cloudtrail_arns_are_valid(self):
        try:
            assert_cloudtrail_arn_is_valid('foo:bar:baz')
        except ValueError as e:
            self.assertIn('Invalid trail ARN provided: foo:bar:baz', str(e))
        else:
            self.fail('Should have failed')

    def test_ensures_cloudtrail_arns_are_valid_when_missing_resource(self):
        try:
            # The resource part is 'foo' rather than 'trail/<name>'.
            assert_cloudtrail_arn_is_valid(
                'arn:aws:cloudtrail:us-east-1:%s:foo' % TEST_ACCOUNT_ID)
        except ValueError as e:
            self.assertIn('Invalid trail ARN provided', str(e))
        else:
            self.fail('Should have failed')

    def test_allows_valid_arns(self):
        valid_arn = (
            'arn:aws:cloudtrail:us-east-1:%s:trail/foo' % TEST_ACCOUNT_ID)
        assert_cloudtrail_arn_is_valid(valid_arn)

    def test_normalizes_date_timezones(self):
        local_date = datetime(2015, 8, 21, tzinfo=tz.tzlocal())
        result = normalize_date(local_date)
        self.assertEqual(tz.tzutc(), result.tzinfo)

    def test_extracts_dates_from_digest_keys(self):
        digest_key = (
            'AWSLogs/{account}/CloudTrail-Digest/us-east-1/2015/08/'
            '16/{account}_CloudTrail-Digest_us-east-1_foo_us-east-1_'
            '20150816T230550Z.json.gz').format(account=TEST_ACCOUNT_ID)
        self.assertEqual(
            '20150816T230550Z', extract_digest_key_date(digest_key))

    def test_creates_traverser(self):
        s3_provider = mock.Mock()
        traverser = create_digest_traverser(
            trail_arn=TEST_TRAIL_ARN,
            cloudtrail_client=mock.Mock(),
            organization_client=mock.Mock(),
            trail_source_region='us-east-1',
            s3_client_provider=s3_provider,
            bucket='bucket',
            prefix='prefix')
        self.assertEqual('bucket', traverser.starting_bucket)
        self.assertEqual('prefix', traverser.starting_prefix)
        provider = traverser.digest_provider
        self.assertEqual('us-east-1', provider.trail_home_region)
        self.assertEqual('foo', provider.trail_name)

    def test_creates_traverser_account_id(self):
        s3_provider = mock.Mock()
        traverser = create_digest_traverser(
            trail_arn=TEST_TRAIL_ARN,
            cloudtrail_client=mock.Mock(),
            organization_client=mock.Mock(),
            trail_source_region='us-east-1',
            s3_client_provider=s3_provider,
            bucket='bucket',
            prefix='prefix',
            account_id=TEST_ORGANIZATION_ACCOUNT_ID)
        self.assertEqual('bucket', traverser.starting_bucket)
        self.assertEqual('prefix', traverser.starting_prefix)
        provider = traverser.digest_provider
        self.assertEqual('us-east-1', provider.trail_home_region)
        self.assertEqual('foo', provider.trail_name)
        self.assertEqual(TEST_ORGANIZATION_ACCOUNT_ID, provider.account_id)

    def test_creates_traverser_and_gets_trail_by_arn(self):
        trail = {
            'TrailARN': TEST_TRAIL_ARN,
            'S3BucketName': 'bucket', 'S3KeyPrefix': 'prefix',
            'IsOrganizationTrail': False,
        }
        cloudtrail_client = mock.Mock()
        cloudtrail_client.describe_trails.return_value = {
            'trailList': [trail]}
        traverser = create_digest_traverser(
            trail_arn=TEST_TRAIL_ARN,
            trail_source_region='us-east-1',
            cloudtrail_client=cloudtrail_client,
            organization_client=mock.Mock(),
            s3_client_provider=mock.Mock())
        self.assertEqual('bucket', traverser.starting_bucket)
        self.assertEqual('prefix', traverser.starting_prefix)
        provider = traverser.digest_provider
        self.assertEqual('us-east-1', provider.trail_home_region)
        self.assertEqual('foo', provider.trail_name)
        self.assertEqual(TEST_ACCOUNT_ID, provider.account_id)

    def test_create_traverser_organizational_trail_not_launched(self):
        # The described trail carries no 'IsOrganizationTrail' entry at all.
        trail = {
            'TrailARN': TEST_TRAIL_ARN,
            'S3BucketName': 'bucket', 'S3KeyPrefix': 'prefix',
        }
        cloudtrail_client = mock.Mock()
        cloudtrail_client.describe_trails.return_value = {
            'trailList': [trail]}
        traverser = create_digest_traverser(
            trail_arn=TEST_TRAIL_ARN,
            trail_source_region='us-east-1',
            cloudtrail_client=cloudtrail_client,
            organization_client=mock.Mock(),
            s3_client_provider=mock.Mock())
        self.assertEqual('bucket', traverser.starting_bucket)
        self.assertEqual('prefix', traverser.starting_prefix)
        provider = traverser.digest_provider
        self.assertEqual('us-east-1', provider.trail_home_region)
        self.assertEqual('foo', provider.trail_name)
        self.assertEqual(TEST_ACCOUNT_ID, provider.account_id)

    def test_creates_traverser_and_gets_trail_by_arn_s3_bucket_specified(self):
        cloudtrail_client = mock.Mock()
        traverser = create_digest_traverser(
            trail_arn=TEST_TRAIL_ARN,
            trail_source_region='us-east-1',
            cloudtrail_client=cloudtrail_client,
            organization_client=mock.Mock(),
            s3_client_provider=mock.Mock(),
            bucket="bucket")
        self.assertEqual('bucket', traverser.starting_bucket)
        provider = traverser.digest_provider
        self.assertEqual('us-east-1', provider.trail_home_region)
        self.assertEqual('foo', provider.trail_name)
        self.assertEqual(TEST_ACCOUNT_ID, provider.account_id)

    def test_creates_traverser_and_gets_organization_id(self):
        trail = {
            'TrailARN': TEST_TRAIL_ARN,
            'S3BucketName': 'bucket', 'S3KeyPrefix': 'prefix',
            'IsOrganizationTrail': True,
        }
        cloudtrail_client = mock.Mock()
        cloudtrail_client.describe_trails.return_value = {
            'trailList': [trail]}
        organization_client = mock.Mock()
        organization_client.describe_organization.return_value = {
            "Organization": {
                "MasterAccountId": TEST_ACCOUNT_ID,
                "Id": TEST_ORGANIZATION_ID,
            }
        }
        traverser = create_digest_traverser(
            trail_arn=TEST_TRAIL_ARN,
            trail_source_region='us-east-1',
            cloudtrail_client=cloudtrail_client,
            organization_client=organization_client,
            s3_client_provider=mock.Mock(),
            account_id=TEST_ACCOUNT_ID)
        self.assertEqual('bucket', traverser.starting_bucket)
        self.assertEqual('prefix', traverser.starting_prefix)
        provider = traverser.digest_provider
        self.assertEqual('us-east-1', provider.trail_home_region)
        self.assertEqual('foo', provider.trail_name)
        self.assertEqual(TEST_ORGANIZATION_ID, provider.organization_id)

    def test_creates_traverser_organization_trail_missing_account_id(self):
        trail = {
            'TrailARN': TEST_TRAIL_ARN,
            'S3BucketName': 'bucket', 'S3KeyPrefix': 'prefix',
            'IsOrganizationTrail': True,
        }
        cloudtrail_client = mock.Mock()
        cloudtrail_client.describe_trails.return_value = {
            'trailList': [trail]}
        organization_client = mock.Mock()
        organization_client.describe_organization.return_value = {
            "Organization": {
                "MasterAccountId": TEST_ACCOUNT_ID,
                "Id": TEST_ORGANIZATION_ID,
            }
        }
        # No account_id is passed for an organization trail.
        with self.assertRaises(ParameterRequiredError):
            create_digest_traverser(
                trail_arn=TEST_TRAIL_ARN,
                trail_source_region='us-east-1',
                cloudtrail_client=cloudtrail_client,
                organization_client=organization_client,
                s3_client_provider=mock.Mock())
347
348
349
class TestPublicKeyProvider(unittest.TestCase):
    """Tests for PublicKeyProvider.get_public_keys."""

    def test_returns_public_key_in_range(self):
        def key_record(name):
            # A fresh dict per call so that the expected and returned
            # records never share objects.
            return {'Fingerprint': name, 'OtherData': name, 'Value': name}

        names = ('a', 'b', 'c')
        cloudtrail_client = mock.Mock()
        cloudtrail_client.list_public_keys.return_value = {
            'PublicKeyList': [key_record(name) for name in names]}
        provider = PublicKeyProvider(cloudtrail_client)
        start_date = START_DATE
        end_date = start_date + timedelta(days=2)
        keys = provider.get_public_keys(start_date, end_date)
        # Keys come back indexed by fingerprint.
        expected = {name: key_record(name) for name in names}
        self.assertEqual(expected, keys)
        cloudtrail_client.list_public_keys.assert_has_calls(
            [mock.call(EndTime=end_date, StartTime=start_date)])
368
369
370
class TestSha256RSADigestValidator(unittest.TestCase):
    """Tests for Sha256RSADigestValidator."""

    def setUp(self):
        digest = {'digestStartTime': 'baz',
                  'digestEndTime': 'foo',
                  'awsAccountId': 'account',
                  'digestPublicKeyFingerprint': 'abc',
                  'digestS3Bucket': 'bucket',
                  'digestS3Object': 'object',
                  'previousDigestSignature': 'xyz'}
        # Serialize before attaching the signature so that the inflated
        # digest does not contain the '_signature' entry.
        self._inflated_digest = json.dumps(digest).encode()
        digest['_signature'] = 'aeff'
        self._digest_data = digest

    def test_validates_digests(self):
        (public_key, private_key) = rsa.newkeys(512)
        digest_hash = hashlib.sha256(self._inflated_digest).hexdigest()
        data = self._digest_data
        string_to_sign = '\n'.join([
            data['digestEndTime'],
            '%s/%s' % (data['digestS3Bucket'], data['digestS3Object']),
            digest_hash,
            data['previousDigestSignature'],
        ])
        signature = rsa.sign(string_to_sign.encode(), private_key, 'SHA-256')
        data['_signature'] = binascii.hexlify(signature)
        der_key = public_key.save_pkcs1(format='DER')
        public_key_b64 = base64.b64encode(der_key)
        validator = Sha256RSADigestValidator()
        validator.validate(
            'b', 'k', public_key_b64, data, self._inflated_digest)

    def test_does_not_expose_underlying_key_decoding_error(self):
        validator = Sha256RSADigestValidator()
        expected = ('Digest file\ts3://b/k\tINVALID: Unable to load '
                    'PKCS #1 key with fingerprint abc')
        try:
            validator.validate(
                'b', 'k', 'YQo=', self._digest_data, 'invalid'.encode())
        except DigestError as e:
            self.assertEqual(expected, str(e))
        else:
            self.fail('Should have failed')

    def test_does_not_expose_underlying_validation_error(self):
        validator = Sha256RSADigestValidator()
        expected = ('Digest file\ts3://b/k\tINVALID: signature '
                    'verification failed')
        try:
            validator.validate(
                'b', 'k', VALID_TEST_KEY, self._digest_data,
                'invalid'.encode())
        except DigestSignatureError as e:
            self.assertEqual(expected, str(e))
        else:
            self.fail('Should have failed')

    def test_properly_signs_when_no_previous_signature(self):
        digest_data = {
            'digestEndTime': 'a',
            'digestS3Bucket': 'b',
            'digestS3Object': 'c',
            'previousDigestSignature': None}
        validator = Sha256RSADigestValidator()
        signed = validator._create_string_to_sign(digest_data, 'abc'.encode())
        # A missing previous signature is expected to appear as 'null'.
        expected = (
            'a\nb/c\nba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff6'
            '1f20015ad\nnull').encode()
        self.assertEqual(expected, signed)
430
431
432
class TestDigestProvider(BaseAWSCommandParamsTest):
    """Tests for DigestProvider key listing and digest fetching."""

    @staticmethod
    def _gzip_text(text):
        # Returns the gzip-compressed bytes of the encoded text.
        buffer = BytesIO()
        with gzip.GzipFile(fileobj=buffer, mode="wb") as stream:
            stream.write(text.encode())
        return buffer.getvalue()

    def _fake_key(self, date):
        parsed = parser.parse(date)
        template = (
            'prefix/AWSLogs/{account}/CloudTrail-Digest/us-east-1/{year}/'
            '{month}/{day}/{account}_CloudTrail-Digest_us-east-1_foo_'
            'us-east-1_{date}.json.gz')
        return template.format(
            account=TEST_ACCOUNT_ID,
            date=date,
            year=parsed.year,
            month=parsed.month,
            day=parsed.day)

    def _get_mock_provider(self, s3_client):
        client_provider = mock.Mock()
        client_provider.get_client.return_value = s3_client
        return DigestProvider(
            client_provider, TEST_ACCOUNT_ID, 'foo', 'us-east-1')

    def test_initializes_public_properties(self):
        provider = DigestProvider(
            mock.Mock(), TEST_ACCOUNT_ID, 'foo', 'us-east-1')
        self.assertEqual(TEST_ACCOUNT_ID, provider.account_id)
        self.assertEqual('foo', provider.trail_name)
        self.assertEqual('us-east-1', provider.trail_home_region)

    def test_returns_digests_in_range(self):
        s3_client = self.driver.session.create_client('s3')
        key_dates = [
            START_DATE - timedelta(days=1),
            START_DATE + timedelta(days=1),
            START_DATE + timedelta(days=2),
            START_DATE + timedelta(days=3),
            END_DATE + timedelta(hours=1),
            END_DATE + timedelta(days=1),
        ]
        keys = [self._fake_key(format_date(when)) for when in key_dates]
        # Create a key that looks similar but for a different trail.
        bad_name = keys[3].replace('foo', 'baz')
        # Create a key that looks similar but is from a different trail source
        # region (e.g., CloudTrail-Digest/us-west-2).
        bad_region = keys[3].replace(
            'CloudTrail-Digest/us-east-1', 'CloudTrail-Digest/us-west-2')
        bad_region = bad_region.replace(
            'CloudTrail-Digest_us-east-1', 'CloudTrail-Digest_us-west-2')
        listed_keys = [
            keys[0],        # skip (date <)
            keys[1],
            keys[2],
            'foo/baz/bar',  # skip (regex (bogus))
            bad_name,       # skip (regex (trail name))
            bad_region,     # skip (regex (source))
            keys[3],
            keys[4],        # hour is +1, but keep
            keys[5],        # skip (date >)
        ]
        self.parsed_responses = [
            {"Contents": [{"Key": listed} for listed in listed_keys]}]
        self.patch_make_request()
        provider = self._get_mock_provider(s3_client)
        digests = provider.load_digest_keys_in_range(
            'foo', 'prefix', START_DATE, END_DATE)
        self.assertNotIn(bad_name, digests)
        self.assertNotIn(bad_region, digests)
        # keys[1] through keys[4] are expected, in listing order.
        for index, expected in enumerate(keys[1:5]):
            self.assertEqual(expected, digests[index])

    def test_calls_list_objects_correctly(self):
        s3_client = mock.Mock()
        mock_paginate = s3_client.get_paginator.return_value.paginate
        mock_paginate.return_value.search.return_value = []
        provider = self._get_mock_provider(s3_client)
        provider.load_digest_keys_in_range(
            '1', 'prefix', START_DATE, END_DATE)
        marker = ('prefix/AWSLogs/{account}/CloudTrail-Digest/us-east-1/'
                  '2014/08/09/{account}_CloudTrail-Digest_us-east-1_foo_'
                  'us-east-1_20140809T235900Z.json.gz')
        expected_marker = marker.format(account=TEST_ACCOUNT_ID)
        mock_paginate.assert_called_once_with(
            Bucket='1', Marker=expected_marker)

    def test_calls_list_objects_correctly_org_trails(self):
        s3_client = mock.Mock()
        client_provider = mock.Mock()
        mock_paginate = s3_client.get_paginator.return_value.paginate
        mock_paginate.return_value.search.return_value = []
        client_provider.get_client.return_value = s3_client
        provider = DigestProvider(
            client_provider, TEST_ORGANIZATION_ACCOUNT_ID,
            'foo', 'us-east-1', 'us-east-1',
            TEST_ORGANIZATION_ID)
        provider.load_digest_keys_in_range(
            '1', 'prefix', START_DATE, END_DATE)
        marker = (
            'prefix/AWSLogs/{organization_id}/{member_account}/'
            'CloudTrail-Digest/us-east-1/'
            '2014/08/09/{member_account}_CloudTrail-Digest_us-east-1_foo_'
            'us-east-1_20140809T235900Z.json.gz'
        )
        expected_marker = marker.format(
            member_account=TEST_ORGANIZATION_ACCOUNT_ID,
            organization_id=TEST_ORGANIZATION_ID)
        mock_paginate.assert_called_once_with(
            Bucket='1', Marker=expected_marker)

    def test_ensures_digest_has_proper_metadata(self):
        gzipped_data = self._gzip_text('{"foo":"bar"}')
        s3_client = mock.Mock()
        # Valid gzip and JSON, but no signature metadata on the object.
        s3_client.get_object.return_value = {
            'Body': BytesIO(gzipped_data),
            'Metadata': {}}
        provider = self._get_mock_provider(s3_client)
        with self.assertRaises(DigestSignatureError):
            provider.fetch_digest('bucket', 'key')

    def test_ensures_digest_can_be_gzip_inflated(self):
        s3_client = mock.Mock()
        # The body is plain bytes rather than gzip data.
        s3_client.get_object.return_value = {
            'Body': BytesIO('foo'.encode()),
            'Metadata': {}}
        provider = self._get_mock_provider(s3_client)
        with self.assertRaises(InvalidDigestFormat):
            provider.fetch_digest('bucket', 'key')

    def test_ensures_digests_can_be_json_parsed(self):
        gzipped_data = self._gzip_text('{{{')
        s3_client = mock.Mock()
        s3_client.get_object.return_value = {
            'Body': BytesIO(gzipped_data),
            'Metadata': {'signature': 'abc', 'signature-algorithm': 'SHA256'}}
        provider = self._get_mock_provider(s3_client)
        with self.assertRaises(InvalidDigestFormat):
            provider.fetch_digest('bucket', 'key')

    def test_fetches_digests(self):
        json_str = '{"foo":"bar"}'
        gzipped_data = self._gzip_text(json_str)
        s3_client = mock.Mock()
        s3_client.get_object.return_value = {
            'Body': BytesIO(gzipped_data),
            'Metadata': {'signature': 'abc', 'signature-algorithm': 'SHA256'}}
        provider = self._get_mock_provider(s3_client)
        result = provider.fetch_digest('bucket', 'key')
        expected_digest = {'foo': 'bar', '_signature': 'abc',
                           '_signature_algorithm': 'SHA256'}
        self.assertEqual(expected_digest, result[0])
        self.assertEqual(json_str.encode(), result[1])
588
589
590
class TestDigestTraverser(unittest.TestCase):
591
def test_initializes_with_default_validator(self):
    # No digest_validator is passed; construction must still succeed
    # and expose the constructor arguments as attributes.
    digest_provider = mock.Mock()
    traverser = DigestTraverser(
        digest_provider=digest_provider,
        starting_bucket='1',
        starting_prefix='baz',
        public_key_provider=mock.Mock())
    self.assertEqual('1', traverser.starting_bucket)
    self.assertEqual('baz', traverser.starting_prefix)
    self.assertEqual(digest_provider, traverser.digest_provider)
599
600
def test_ensures_public_keys_are_loaded(self):
    # The key provider returns no keys, so the first step of the
    # traversal is expected to raise RuntimeError.
    key_provider = mock.Mock()
    key_provider.get_public_keys.return_value = []
    traverser = DigestTraverser(
        digest_provider=mock.Mock(),
        starting_bucket='1',
        starting_prefix='baz',
        public_key_provider=key_provider)
    digest_iter = traverser.traverse(START_DATE, END_DATE)
    with self.assertRaises(RuntimeError):
        next(digest_iter)
    key_provider.get_public_keys.assert_called_with(START_DATE, END_DATE)
614
615
def test_ensures_public_key_is_found(self):
    region = 'us-west-2'
    key_name = END_DATE.strftime(DATE_FORMAT) + '.json.gz'
    # The digest asks for fingerprint 'abc'; the key provider only
    # offers an entry for fingerprint 'a'.
    digest = {'digestEndTime': 'foo',
              'digestStartTime': 'foo',
              'awsAccountId': 'account',
              'digestPublicKeyFingerprint': 'abc',
              'digestS3Bucket': '1',
              'digestS3Object': key_name,
              'previousDigestSignature': 'xyz'}
    digest_provider = mock.Mock()
    digest_provider.trail_home_region = region
    digest_provider.load_digest_keys_in_range.return_value = [key_name]
    digest_provider.fetch_digest.return_value = (digest, 'abc')
    key_provider = mock.Mock()
    key_provider.get_public_keys.return_value = [{'Fingerprint': 'a'}]
    on_invalid, calls = collecting_callback()
    traverser = DigestTraverser(
        digest_provider=digest_provider,
        starting_bucket='1',
        starting_prefix='baz',
        public_key_provider=key_provider,
        on_invalid=on_invalid)
    digest_iter = traverser.traverse(START_DATE, END_DATE)
    with self.assertRaises(StopIteration):
        next(digest_iter)
    self.assertEqual(1, len(calls))
    expected_message = (
        'Digest file\ts3://1/%s\tINVALID: public key not '
        'found in region %s for fingerprint abc' % (key_name, region))
    self.assertEqual(expected_message, calls[0]['message'])
648
649
def test_invokes_digest_validator(self):
    end_date = END_DATE
    key_name = end_date.strftime(DATE_FORMAT) + '.json.gz'
    one_hour_earlier = end_date - timedelta(hours=1)
    digest = {'digestPublicKeyFingerprint': 'a',
              'digestS3Bucket': '1',
              'digestS3Object': key_name,
              'previousDigestSignature': '...',
              'digestStartTime': one_hour_earlier.strftime(DATE_FORMAT),
              'digestEndTime': end_date.strftime(DATE_FORMAT)}
    digest_provider = mock.Mock()
    digest_provider.load_digest_keys_in_range.return_value = [key_name]
    # The key name doubles as the raw digest string in this test.
    digest_provider.fetch_digest.return_value = (digest, key_name)
    public_keys = {'a': {'Fingerprint': 'a', 'Value': 'a'}}
    key_provider = mock.Mock()
    key_provider.get_public_keys.return_value = public_keys
    digest_validator = mock.Mock()
    traverser = DigestTraverser(
        digest_provider=digest_provider,
        starting_bucket='1',
        starting_prefix='baz',
        public_key_provider=key_provider,
        digest_validator=digest_validator)
    digest_iter = traverser.traverse(START_DATE, end_date)
    self.assertEqual(digest, next(digest_iter))
    digest_validator.validate.assert_called_with(
        '1', key_name, public_keys['a']['Value'], digest, key_name)
676
677
def test_ensures_digest_from_same_location_as_json_contents(self):
    end_date = END_DATE
    callback, collected = collecting_callback()
    key_name = end_date.strftime(DATE_FORMAT) + '.json.gz'
    # The digest claims bucket 'not_same' while traversal starts in '1'.
    digest = {'digestPublicKeyFingerprint': 'a',
              'digestS3Bucket': 'not_same',
              'digestS3Object': key_name,
              'digestEndTime': end_date.strftime(DATE_FORMAT)}
    digest_provider = mock.Mock()
    digest_provider.load_digest_keys_in_range.return_value = [key_name]
    digest_provider.fetch_digest.return_value = (digest, key_name)
    traverser = DigestTraverser(
        digest_provider=digest_provider,
        starting_bucket='1',
        starting_prefix='baz',
        public_key_provider=mock.Mock(),
        digest_validator=mock.Mock(),
        on_invalid=callback)
    digest_iter = traverser.traverse(START_DATE, end_date)
    # Nothing is yielded; the problem is reported through the callback.
    self.assertIsNone(next(digest_iter, None))
    self.assertEqual(1, len(collected))
    expected_message = (
        'Digest file\ts3://1/%s\tINVALID: invalid format' % key_name)
    self.assertEqual(expected_message, collected[0]['message'])
701
702
def test_loads_digests_in_range(self):
    range_end = START_DATE + timedelta(hours=5)
    key_provider, digest_provider, validator = create_scenario(
        ['gap', 'link', 'link', 'link'])
    traverser = DigestTraverser(
        digest_provider=digest_provider,
        starting_bucket='1',
        starting_prefix='baz',
        public_key_provider=key_provider,
        digest_validator=validator)
    collected = list(traverser.traverse(START_DATE, range_end))
    provider_calls = digest_provider.calls
    # Keys and the listing are loaded once; every digest is fetched.
    self.assertEqual(1, key_provider.get_public_keys.call_count)
    self.assertEqual(1, len(provider_calls['load_digest_keys_in_range']))
    self.assertEqual(4, len(provider_calls['fetch_digest']))
    self.assertEqual(4, len(collected))
717
718
def test_invokes_cb_and_continues_when_missing(self):
    key_provider, digest_provider, validator = create_scenario(
        ['gap', 'link', 'missing', 'link'])
    on_missing, missing_calls = collecting_callback()
    traverser = DigestTraverser(
        digest_provider=digest_provider,
        starting_bucket='1',
        starting_prefix='baz',
        public_key_provider=key_provider,
        digest_validator=validator,
        on_missing=on_missing)
    collected = list(traverser.traverse(START_DATE, END_DATE))
    # Three of the four digests are yielded; one is reported missing.
    self.assertEqual(3, len(collected))
    self.assertEqual(1, key_provider.get_public_keys.call_count)
    self.assertEqual(1, len(missing_calls))
    report = missing_calls[0]
    # Ensure it was invoked with the kwargs we expected.
    self.assertIn('bucket', report)
    self.assertIn('next_end_date', report)
    # Ensure the keys were provided in the correct order.
    self.assertEqual(digest_provider.digests[1], report['next_key'])
    self.assertEqual(digest_provider.digests[2], report['last_key'])
    # Ensure the provider was called correctly
    provider_calls = digest_provider.calls
    self.assertEqual(1, key_provider.get_public_keys.call_count)
    self.assertEqual(1, len(provider_calls['load_digest_keys_in_range']))
    self.assertEqual(4, len(provider_calls['fetch_digest']))
745
746
def test_invokes_cb_and_continues_when_invalid(self):
    key_provider, digest_provider, validator = create_scenario(
        ['gap', 'link', 'invalid', 'link', 'invalid'])
    on_invalid, invalid_calls = collecting_callback()
    traverser = DigestTraverser(
        digest_provider=digest_provider,
        starting_bucket='1',
        starting_prefix='baz',
        public_key_provider=key_provider,
        digest_validator=validator,
        on_invalid=on_invalid)
    collected = list(traverser.traverse(START_DATE, END_DATE))
    # Three of the five digests are yielded; two are reported invalid.
    self.assertEqual(3, len(collected))
    self.assertEqual(1, key_provider.get_public_keys.call_count)
    self.assertEqual(2, len(invalid_calls))
    first_report, second_report = invalid_calls[0], invalid_calls[1]
    # Ensure it was invoked with all the kwargs we expected.
    self.assertIn('bucket', first_report)
    self.assertIn('next_end_date', first_report)
    # Ensure the keys were provided in the correct order.
    digest_keys = digest_provider.digests
    self.assertEqual(digest_keys[4], first_report['last_key'])
    self.assertEqual(digest_keys[3], first_report['next_key'])
    self.assertEqual(digest_keys[2], second_report['last_key'])
    self.assertEqual(digest_keys[1], second_report['next_key'])
    # Ensure the provider was called correctly
    provider_calls = digest_provider.calls
    self.assertEqual(1, key_provider.get_public_keys.call_count)
    self.assertEqual(1, len(provider_calls['load_digest_keys_in_range']))
    self.assertEqual(5, len(provider_calls['fetch_digest']))
777
778
def test_invokes_cb_and_continues_when_gap(self):
    # The scenario has two 'gap' entries after the leading one. Each
    # should be reported through on_gap, and all four digests should
    # still be yielded.
    key_provider, digest_provider, validator = create_scenario(
        ['gap', 'link', 'gap', 'gap'])
    on_gap, gap_calls = collecting_callback()
    traverser = DigestTraverser(
        digest_provider=digest_provider,
        starting_bucket='1',
        starting_prefix='baz',
        public_key_provider=key_provider,
        digest_validator=validator,
        on_gap=on_gap,
    )
    collected = list(traverser.traverse(START_DATE, END_DATE))

    self.assertEqual(4, len(collected))
    self.assertEqual(1, key_provider.get_public_keys.call_count)
    self.assertEqual(2, len(gap_calls))

    # Ensure it was invoked with all the kwargs we expected.
    expected_kwargs = (
        'bucket',
        'next_key',
        'next_end_date',
        'last_key',
        'last_start_date',
    )
    for expected_kwarg in expected_kwargs:
        self.assertIn(expected_kwarg, gap_calls[0])

    # Ensure the keys were provided in the correct order.
    self.assertEqual(
        digest_provider.digests[3], gap_calls[0]['last_key'])
    self.assertEqual(
        digest_provider.digests[2], gap_calls[0]['next_key'])
    self.assertEqual(
        digest_provider.digests[2], gap_calls[1]['last_key'])
    self.assertEqual(
        digest_provider.digests[1], gap_calls[1]['next_key'])

    # Ensure the provider was called correctly
    self.assertEqual(1, key_provider.get_public_keys.call_count)
    range_loads = digest_provider.calls['load_digest_keys_in_range']
    self.assertEqual(1, len(range_loads))
    self.assertEqual(4, len(digest_provider.calls['fetch_digest']))
808
809
def test_reloads_objects_on_bucket_change(self):
    # A 'bucket_change' entry in the middle of the chain should cause
    # the digest key listing to be loaded a second time, with the
    # yielded digests reporting bucket '1' first and bucket '2' after.
    key_provider, digest_provider, validator = create_scenario(
        ['gap', 'link', 'bucket_change', 'link'])
    traverser = DigestTraverser(
        digest_provider=digest_provider,
        starting_bucket='1',
        starting_prefix='baz',
        public_key_provider=key_provider,
        digest_validator=validator,
    )
    collected = list(traverser.traverse(START_DATE, END_DATE))

    self.assertEqual(4, len(collected))
    self.assertEqual(1, key_provider.get_public_keys.call_count)

    # Ensure the provider was called correctly
    self.assertEqual(1, key_provider.get_public_keys.call_count)
    range_loads = digest_provider.calls['load_digest_keys_in_range']
    self.assertEqual(2, len(range_loads))
    seen_buckets = [entry['digestS3Bucket'] for entry in collected]
    self.assertEqual(['1', '1', '2', '2'], seen_buckets)
827
828
def test_does_not_hard_fail_on_invalid_signature(self):
    # Feed a real Sha256RSADigestValidator a digest whose key and
    # signature are placeholder strings. The traverser should report
    # the problem through on_invalid instead of raising.
    end_timestamp = END_DATE.strftime(DATE_FORMAT) + '.json.gz'
    one_hour_before_end = END_DATE - timedelta(hours=1)
    digest = {
        'digestPublicKeyFingerprint': 'a',
        'digestS3Bucket': '1',
        'digestS3Object': end_timestamp,
        'previousDigestSignature': '...',
        'digestStartTime': one_hour_before_end.strftime(DATE_FORMAT),
        'digestEndTime': end_timestamp,
        '_signature': '123',
    }

    # Both providers are mocks that hand back the single digest above.
    digest_provider = mock.Mock()
    digest_provider.load_digest_keys_in_range.return_value = [
        end_timestamp]
    digest_provider.fetch_digest.return_value = (digest, end_timestamp)
    key_provider = mock.Mock()
    key_provider.get_public_keys.return_value = {
        'a': {'Fingerprint': 'a', 'Value': 'a'}}

    digest_validator = Sha256RSADigestValidator()
    on_invalid, calls = collecting_callback()
    traverser = DigestTraverser(
        digest_provider=digest_provider,
        starting_bucket='1',
        starting_prefix='baz',
        public_key_provider=key_provider,
        digest_validator=digest_validator,
        on_invalid=on_invalid,
    )

    # Only the first step of the traversal is needed to trigger
    # validation; the default keeps next() from raising StopIteration.
    next(traverser.traverse(START_DATE, END_DATE), None)

    expected_prefix = (
        'Digest file\ts3://1/%s\tINVALID: ' % end_timestamp)
    self.assertIn(expected_prefix, calls[0]['message'])
858
859
860
class TestCloudTrailCommand(BaseAWSCommandParamsTest):
    """Checks on CloudTrailValidateLogs client setup and arg handling."""

    def test_s3_client_created_lazily(self):
        # setup_services should create only the organizations and
        # cloudtrail clients; no 's3' client is expected at this point.
        session = mock.Mock()
        command = CloudTrailValidateLogs(session)
        parsed_globals = mock.Mock(
            region=None, verify_ssl=None, endpoint_url=None)
        command.setup_services(parsed_globals)

        expected_calls = [
            mock.call('organizations', verify=None, region_name=None),
            mock.call('cloudtrail', verify=None, region_name=None),
        ]
        self.assertEqual(
            session.create_client.call_args_list, expected_calls)

    def test_endpoint_url_is_used_for_cloudtrail(self):
        endpoint_url = 'https://mycloudtrail.aws.amazon.com/'
        session = mock.Mock()
        command = CloudTrailValidateLogs(session)
        parsed_globals = mock.Mock(
            region='foo', verify_ssl=None, endpoint_url=endpoint_url)
        command.setup_services(parsed_globals)

        expected_calls = [
            mock.call('organizations', verify=None, region_name='foo'),
            # Here we should inject the endpoint_url only for cloudtrail.
            mock.call(
                'cloudtrail',
                verify=None,
                region_name='foo',
                endpoint_url=endpoint_url,
            ),
        ]
        self.assertEqual(
            session.create_client.call_args_list, expected_calls)

    def test_initializes_args(self):
        # handle_args should copy the parsed values onto the command.
        # end_time is passed as None and is expected to come back as a
        # non-None value later than start_time.
        session = mock.Mock()
        command = CloudTrailValidateLogs(session)
        start_date = START_DATE.strftime(DATE_FORMAT)
        args = Namespace(
            trail_arn='abc',
            verbose=True,
            start_time=start_date,
            s3_bucket='bucket',
            s3_prefix='prefix',
            end_time=None,
            account_id=None,
        )
        command.handle_args(args)

        self.assertEqual('abc', command.trail_arn)
        self.assertEqual(True, command.is_verbose)
        self.assertEqual('bucket', command.s3_bucket)
        self.assertEqual('prefix', command.s3_prefix)
        formatted_start = command.start_time.strftime(DATE_FORMAT)
        self.assertEqual(start_date, formatted_start)
        self.assertIsNotNone(command.end_time)
        self.assertGreater(command.end_time, command.start_time)
908
909
910
class TestS3ClientProvider(BaseAWSCommandParamsTest):
    """Checks on S3ClientProvider client creation and caching."""

    @staticmethod
    def _stub_session(location_constraint=''):
        # Build a mock session whose create_client always returns the
        # same mock S3 client, with get_bucket_location preset to the
        # given LocationConstraint. Returns (session, s3_client).
        s3_client = mock.Mock()
        s3_client.get_bucket_location.return_value = {
            'LocationConstraint': location_constraint}
        session = mock.Mock()
        session.create_client.return_value = s3_client
        return session, s3_client

    def test_creates_clients_for_buckets_in_us_east_1(self):
        # An empty LocationConstraint should yield exactly one client,
        # created for us-east-1.
        session, s3_client = self._stub_session()
        provider = S3ClientProvider(session)
        created_client = provider.get_client('foo')

        self.assertEqual(s3_client, created_client)
        expected_calls = [mock.call('s3', region_name='us-east-1')]
        self.assertEqual(
            session.create_client.call_args_list, expected_calls)
        self.assertEqual(1, s3_client.get_bucket_location.call_count)

    def test_creates_clients_for_buckets_outside_us_east_1(self):
        # The provider is built with 'us-west-1' but the bucket reports
        # 'us-west-2', so a client is expected for each region in turn.
        session, s3_client = self._stub_session('us-west-2')
        provider = S3ClientProvider(session, 'us-west-1')
        created_client = provider.get_client('foo')

        self.assertEqual(s3_client, created_client)
        expected_calls = [
            mock.call('s3', region_name='us-west-1'),
            mock.call('s3', region_name='us-west-2'),
        ]
        self.assertEqual(
            session.create_client.call_args_list, expected_calls)
        self.assertEqual(1, s3_client.get_bucket_location.call_count)

    def test_caches_previously_loaded_bucket_regions(self):
        # get_bucket_location should be called once per distinct
        # bucket, no matter how often that bucket is requested.
        session, s3_client = self._stub_session()
        provider = S3ClientProvider(session)

        provider.get_client('foo')
        self.assertEqual(1, s3_client.get_bucket_location.call_count)
        provider.get_client('foo')
        self.assertEqual(1, s3_client.get_bucket_location.call_count)
        provider.get_client('bar')
        self.assertEqual(2, s3_client.get_bucket_location.call_count)
        provider.get_client('bar')
        self.assertEqual(2, s3_client.get_bucket_location.call_count)

    def test_caches_previously_loaded_clients(self):
        # A second lookup of the same bucket should return the same
        # client without another create_client call.
        session, _ = self._stub_session()
        provider = S3ClientProvider(session)

        client = provider.get_client('foo')
        self.assertEqual(1, session.create_client.call_count)
        self.assertEqual(client, provider.get_client('foo'))
        self.assertEqual(1, session.create_client.call_count)

    def test_removes_cli_error_events(self):
        # We should also remove the error handler for S3.
        # This can be removed once the client switchover is done.
        # NOTE(review): no assertion is made here; as written the test
        # only verifies that get_client completes without raising.
        session, _ = self._stub_session()
        provider = S3ClientProvider(session)
        provider.get_client('foo')
974
975