Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
aws
GitHub Repository: aws/aws-cli
Path: blob/develop/awscli/customizations/emr/emrutils.py
2634 views
1
# Copyright 2014 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License"). You
4
# may not use this file except in compliance with the License. A copy of
5
# the License is located at
6
#
7
# http://aws.amazon.com/apache2.0/
8
#
9
# or in the "license" file accompanying this file. This file is
10
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
# ANY KIND, either express or implied. See the License for the specific
12
# language governing permissions and limitations under the License.
13
14
import json
15
import logging
16
import os
17
18
from botocore.exceptions import NoCredentialsError
19
20
from awscli.clidriver import CLIOperationCaller
21
from awscli.customizations.emr import constants, exceptions
22
23
# Module-level logger; `call` uses it to log the operation name at DEBUG level.
LOG = logging.getLogger(__name__)
24
25
26
def parse_tags(raw_tags_list):
    """Convert ``key=value`` tag strings into EMR tag dictionaries.

    Each entry is split on its first ``=``; an entry that contains no ``=``
    becomes a tag whose value is the empty string.

    :param raw_tags_list: iterable of tag strings, or a falsy value.
    :return: list of ``{'Key': ..., 'Value': ...}`` dicts; empty when
        ``raw_tags_list`` is falsy.
    """
    if not raw_tags_list:
        return []

    parsed = []
    for raw_tag in raw_tags_list:
        # partition() splits on the first '=' only and yields '' as the
        # value when no '=' is present.
        tag_key, _, tag_value = raw_tag.partition('=')
        parsed.append({'Key': tag_key, 'Value': tag_value})
    return parsed
37
38
39
def parse_key_value_string(key_value_string):
    """Parse a comma-separated list of ``key=value`` items.

    Example input: ``"k1=v1,k2='v 2',k3,k4"``. The string is split on every
    comma, and each item is split on its first ``=``. Items without ``=``
    get an empty value. Quotes are not interpreted: a quoted value keeps
    its quote characters, and a comma inside quotes still separates items.

    :param key_value_string: the raw string, or ``None``.
    :return: list of ``{'Key': ..., 'Value': ...}`` dicts, or ``None`` when
        ``key_value_string`` is ``None``.
    """
    if key_value_string is None:
        return None

    pairs = []
    for item in key_value_string.split(','):
        item_key, _, item_value = item.partition('=')
        pairs.append({'Key': item_key, 'Value': item_value})
    return pairs
54
55
56
def apply_boolean_options(
    true_option, true_option_name, false_option, false_option_name
):
    """Collapse a pair of mutually exclusive options into one boolean.

    :param true_option: option whose truthiness makes the result ``True``.
    :param true_option_name: name of that option, used in the error message.
    :param false_option: conflicting option; only checked for the
        "both set" error.
    :param false_option_name: name of that option, used in the error message.
    :return: ``True`` when ``true_option`` is truthy, otherwise ``False``.
    :raises ValueError: when both ``true_option`` and ``false_option`` are
        truthy.
    """
    if not (true_option and false_option):
        return bool(true_option)

    raise ValueError(
        ''.join(
            (
                'aws: error: cannot use both ',
                true_option_name,
                ' and ',
                false_option_name,
                ' options together.',
            )
        )
    )
72
73
74
def apply(params, key, value):
    """Set ``params[key] = value`` if ``value`` is truthy; return ``params``.

    Same behavior as ``apply_dict``.
    """
    # Deprecate. Rename to apply_dict (which behaves identically).
    if not value:
        return params
    params[key] = value
    return params
80
81
82
def apply_dict(params, key, value):
    """Set ``params[key]`` to ``value`` only if ``value`` is truthy.

    Falsy values (``None``, ``''``, ``0``, empty containers, ...) leave
    ``params`` untouched.

    :return: ``params`` itself, mutated in place.
    """
    should_set = bool(value)
    if should_set:
        params[key] = value
    return params
87
88
89
def apply_params(src_params, src_key, dest_params, dest_key):
    """Copy ``src_params[src_key]`` into ``dest_params[dest_key]``.

    The copy happens only when ``src_key`` is present in ``src_params`` and
    its value is truthy; otherwise ``dest_params`` is left unchanged.

    :return: ``dest_params``, mutated in place.
    """
    candidate = src_params.get(src_key)
    if candidate:
        dest_params[dest_key] = candidate
    return dest_params
94
95
96
def build_step(
    jar,
    name='Step',
    action_on_failure=constants.DEFAULT_FAILURE_ACTION,
    args=None,
    main_class=None,
    properties=None,
    log_uri=None,
    encryption_key_arn=None,
):
    """Build an EMR step definition that runs a Hadoop JAR.

    Optional fields are included only when their value is truthy.
    ``StepMonitoringConfiguration`` (wrapping ``S3MonitoringConfiguration``)
    is added only when ``log_uri`` or ``encryption_key_arn`` is truthy.

    :param jar: JAR location; required.
    :return: the step as a dict.
    :raises exceptions.MissingParametersError: via ``check_required_field``
        when ``jar`` is falsy.
    """
    check_required_field(structure='HadoopJarStep', name='Jar', value=jar)

    def _present(*pairs):
        # Keep (key, value) pairs whose value is truthy, in the given order.
        return {field: val for field, val in pairs if val}

    step = _present(('Name', name), ('ActionOnFailure', action_on_failure))

    hadoop_jar_step = {'Jar': jar}
    hadoop_jar_step.update(
        _present(
            ('Args', args),
            ('MainClass', main_class),
            ('Properties', properties),
        )
    )
    step['HadoopJarStep'] = hadoop_jar_step

    s3_monitoring = _present(
        ('LogUri', log_uri), ('EncryptionKeyArn', encryption_key_arn)
    )
    if s3_monitoring:
        step['StepMonitoringConfiguration'] = {
            'S3MonitoringConfiguration': s3_monitoring
        }

    return step
130
131
132
def build_bootstrap_action(path, name='Bootstrap Action', args=None):
    """Build a bootstrap-action definition that runs the script at ``path``.

    :param path: script location; must not be ``None``.
    :param name: action name; omitted from the result if falsy.
    :param args: script arguments; omitted from the result if falsy.
    :return: dict with optional ``'Name'`` and a ``'ScriptBootstrapAction'``
        entry holding optional ``'Args'`` and ``'Path'``.
    :raises exceptions.MissingParametersError: if ``path`` is ``None``.
    """
    if path is None:
        raise exceptions.MissingParametersError(
            object_name='ScriptBootstrapActionConfig', missing='Path'
        )

    script_action = {}
    if args:
        script_action['Args'] = args
    script_action['Path'] = path

    action = {}
    if name:
        action['Name'] = name
    # script_action always contains 'Path', so it is always attached.
    action['ScriptBootstrapAction'] = script_action
    return action
145
146
147
def build_s3_link(relative_path='', region='us-east-1'):
    """Return ``s3://<region>.elasticmapreduce<relative_path>``.

    ``region=None`` is treated as ``'us-east-1'``. ``relative_path`` is
    appended verbatim; no separator is inserted.
    """
    effective_region = 'us-east-1' if region is None else region
    return f's3://{effective_region}.elasticmapreduce{relative_path}'
151
152
153
def get_script_runner(region='us-east-1'):
    """Return the S3 link for ``constants.SCRIPT_RUNNER_PATH`` in ``region``.

    ``region=None`` is treated as ``'us-east-1'``.
    """
    return build_s3_link(
        relative_path=constants.SCRIPT_RUNNER_PATH,
        region='us-east-1' if region is None else region,
    )
159
160
161
def check_required_field(structure, name, value):
    """Ensure that a required field has a truthy value.

    :param structure: name of the containing structure, for the error.
    :param name: name of the field being checked, for the error.
    :raises exceptions.MissingParametersError: if ``value`` is falsy.
    """
    if value:
        return
    raise exceptions.MissingParametersError(
        object_name=structure, missing=name
    )
166
167
168
def check_empty_string_list(name, value):
    """Reject an empty list or a list holding a single blank string.

    :param name: parameter name reported in the error.
    :raises exceptions.EmptyListError: if ``value`` is falsy, or has exactly
        one element that is empty or whitespace-only.
    """
    # Short-circuit on falsy ``value`` so len() is never called on None.
    if value and not (len(value) == 1 and value[0].strip() == ""):
        return
    raise exceptions.EmptyListError(param=name)
171
172
173
def call(
    session,
    operation_name,
    parameters,
    region_name=None,
    endpoint_url=None,
    verify=None,
):
    """Invoke ``operation_name`` on an EMR client and return its response.

    :param session: session used for credentials and client creation.
    :param operation_name: name of the client method to call.
    :param parameters: mapping passed as keyword arguments to that method.
    :raises NoCredentialsError: if the session has no credentials.
    """
    # Creating the client may fail (e.g. in get_endpoint()) when no region
    # is configured; check credentials first so that case yields the more
    # useful credentials error.
    if session.get_credentials() is None:
        raise NoCredentialsError()

    emr_client = session.create_client(
        'emr',
        region_name=region_name,
        endpoint_url=endpoint_url,
        verify=verify,
    )
    LOG.debug('Calling ' + str(operation_name))
    operation = getattr(emr_client, operation_name)
    return operation(**parameters)
195
196
197
def get_example_file(command):
    """Open the EMR example file ``awscli/examples/emr/<command>.rst``.

    The path is relative, so it is resolved against the current working
    directory. The caller owns the returned text-mode file object and must
    close it (e.g. by using it in a ``with`` statement).

    The file is decoded as UTF-8 explicitly so that reading it does not
    depend on the platform's locale encoding.
    """
    example_path = os.path.join('awscli', 'examples', 'emr', command + '.rst')
    return open(example_path, encoding='utf-8')
199
200
201
def dict_to_string(dict, indent=2):
    """Serialize ``dict`` to a JSON string indented by ``indent``.

    The parameter name shadows the ``dict`` builtin; it is kept because
    callers may pass it by keyword.
    """
    serialized = json.dumps(dict, indent=indent)
    return serialized
203
204
205
def get_client(session, parsed_globals):
    """Create an EMR client configured from ``parsed_globals``.

    The region comes from ``get_region`` (``parsed_globals.region``, falling
    back to the session's ``region`` config variable). The endpoint URL and
    SSL verification come from ``parsed_globals.endpoint_url`` and
    ``parsed_globals.verify_ssl``.
    """
    region_name = get_region(session, parsed_globals)
    return session.create_client(
        'emr',
        region_name=region_name,
        endpoint_url=parsed_globals.endpoint_url,
        verify=parsed_globals.verify_ssl,
    )
212
213
214
def get_cluster_state(session, parsed_globals, cluster_id):
    """Return ``Cluster.Status.State`` from ``describe_cluster``.

    :raises KeyError: if the response lacks any of those keys.
    """
    emr = get_client(session, parsed_globals)
    response = emr.describe_cluster(ClusterId=cluster_id)
    status = response['Cluster']['Status']
    return status['State']
218
219
220
def find_master_dns(session, parsed_globals, cluster_id):
    """Return the master instance's public DNS name for ``cluster_id``.

    Reads ``Cluster.MasterPublicDnsName`` from ``describe_cluster``. A
    ``KeyError`` propagates if the response does not contain that field.
    """
    emr = get_client(session, parsed_globals)
    cluster = emr.describe_cluster(ClusterId=cluster_id)['Cluster']
    return cluster['MasterPublicDnsName']
227
228
229
def which(program):
    """Locate an executable named ``program`` on the search path.

    Each ``PATH`` entry has surrounding double quotes stripped before
    ``program`` is joined to it (quoted entries can appear on Windows).
    If ``PATH`` is unset, ``os.defpath`` is searched instead of failing
    with ``KeyError``, mirroring ``shutil.which``.

    :return: the first candidate that is a regular file with execute
        permission, or ``None`` if there is none.
    """
    search_path = os.environ.get('PATH', os.defpath)
    for directory in search_path.split(os.pathsep):
        directory = directory.strip('"')
        candidate = os.path.join(directory, program)
        if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
            return candidate
    return None
237
238
239
def call_and_display_response(
    session, operation_name, parameters, parsed_globals
):
    """Invoke an EMR operation through ``CLIOperationCaller.invoke``.

    Any output handling is done by ``CLIOperationCaller`` (not visible
    here). Its return value is discarded, so this function returns ``None``.
    """
    CLIOperationCaller(session).invoke(
        'emr', operation_name, parameters, parsed_globals
    )
246
247
248
def display_response(session, operation_name, result, parsed_globals):
    """Display ``result`` via ``CLIOperationCaller._display_response``.

    NOTE: this calls a private method of ``CLIOperationCaller``; it should
    change once that functionality is moved outside ``CLIOperationCaller``.
    """
    operation_caller = CLIOperationCaller(session)
    operation_caller._display_response(operation_name, result, parsed_globals)
255
256
257
def get_region(session, parsed_globals):
    """Return ``parsed_globals.region``, or the session's region config.

    ``session.get_config_variable('region')`` is consulted only when
    ``parsed_globals.region`` is ``None``, and its value is returned as-is.
    """
    explicit_region = parsed_globals.region
    if explicit_region is not None:
        return explicit_region
    return session.get_config_variable('region')
262
263
264
def join(values, separator=',', lastSeparator='and'):
    """Join ``values`` into a human-readable string.

    Every item is converted with ``str``. All items but the last are joined
    with ``separator`` followed by a space, and the last item is attached
    with ``lastSeparator`` surrounded by spaces::

        [1, 2, 3] -> '1, 2 and 3'
        [1]       -> '1'
        []        -> ''
    """
    words = [str(item) for item in values]
    if not words:
        return ""
    if len(words) == 1:
        return words[0]

    *leading, final = words
    head = ('%s ' % separator).join(leading)
    return ' '.join((head, lastSeparator, final))
279
280
281
def split_to_key_value(string):
    """Split ``string`` on its first ``=`` into a ``(key, value)`` tuple.

    A string with no ``=`` yields ``(string, '')``. The result is always a
    2-tuple, whether or not ``=`` is present, so it can be unpacked or
    indexed uniformly.
    """
    key, _, value = string.partition('=')
    return key, value
286
287
288
def get_cluster(cluster_id, session, region, endpoint_url, verify_ssl):
    """Describe ``cluster_id`` and return the response's ``'Cluster'`` entry.

    :return: ``response.get('Cluster')`` from ``describe_cluster``, or
        ``None`` if ``call`` returned ``None``.
    """
    response = call(
        session,
        'describe_cluster',
        {'ClusterId': cluster_id},
        region,
        endpoint_url,
        verify_ssl,
    )
    if response is None:
        return None
    return response.get('Cluster')
301
302
303
def get_release_label(cluster_id, session, region, endpoint_url, verify_ssl):
    """Return the cluster's ``ReleaseLabel``, or ``None`` if unavailable.

    ``None`` is returned when ``get_cluster`` yields ``None`` or when the
    cluster description has no ``'ReleaseLabel'`` key.
    """
    cluster = get_cluster(
        cluster_id, session, region, endpoint_url, verify_ssl
    )
    return None if cluster is None else cluster.get('ReleaseLabel')
309
310