Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/vscode
Path: blob/main/cli/src/commands/agent.rs
13383 views
1
/*---------------------------------------------------------------------------------------------
2
* Copyright (c) Microsoft Corporation. All rights reserved.
3
* Licensed under the MIT License. See License.txt in the project root for license information.
4
*--------------------------------------------------------------------------------------------*/
5
6
use std::fs;
7
8
use ahp::{Client, Transport, TransportError, TransportMessage};
9
use ahp_types::commands::{AuthenticateParams, AuthenticateResult};
10
use ahp_types::errors::ahp_error_codes;
11
use ahp_types::state::ProtectedResourceMetadata;
12
use ahp_types::PROTOCOL_VERSION;
13
use futures::{SinkExt, StreamExt};
14
use tokio_tungstenite::tungstenite::Message;
15
16
use crate::auth::{Auth, AuthProvider};
17
use crate::constants::AGENT_HOST_PORT;
18
use crate::log;
19
use crate::tunnels::dev_tunnels::DevTunnels;
20
use crate::util::errors::{wrap, AnyError, CodeError};
21
use crate::util::machine::process_exists;
22
23
use super::agent_host::AgentHostLockData;
24
use super::CommandContext;
25
26
/// Connects to an agent host, initializes the AHP session, and returns
27
/// the ready-to-use client. If an explicit `address` is given it is used
28
/// directly; if `tunnel_name` is given, the tunnel is looked up via the
29
/// dev tunnels API; otherwise the lockfile written by `code agent host`
30
/// is read to discover the local instance.
31
///
32
/// The returned client has been initialized but **not** authenticated.
33
/// Use [`request_with_auth`] to issue commands that may require auth.
34
pub async fn connect(
35
ctx: &CommandContext,
36
address: Option<&str>,
37
tunnel_name: Option<&str>,
38
) -> Result<Client, AnyError> {
39
let client = match (address, tunnel_name) {
40
(Some(addr), _) => connect_ws(addr).await?,
41
(None, Some(name)) => connect_via_tunnel(ctx, name).await?,
42
(None, None) => {
43
let addr = resolve_address_from_lockfile(ctx)?;
44
connect_ws(&addr).await?
45
}
46
};
47
48
client
49
.initialize("code-cli".into(), PROTOCOL_VERSION as i64, vec![])
50
.await
51
.map_err(|e| wrap(e, "AHP initialize failed"))?;
52
53
Ok(client)
54
}
55
56
/// Opens a WebSocket connection and creates an AHP client.
57
async fn connect_ws(address: &str) -> Result<Client, AnyError> {
58
let transport = ahp_ws::WebSocketTransport::connect(address)
59
.await
60
.map_err(|e| wrap(e, format!("Failed to connect to agent host at {address}")))?;
61
62
Client::connect(transport, ahp::ClientConfig::default())
63
.await
64
.map_err(|e| wrap(e, "Failed to establish AHP session").into())
65
}
66
67
/// Connects to an agent host over a dev tunnel relay. Looks up the tunnel
68
/// by name, opens a direct-tcpip channel to the agent host port, performs
69
/// a WebSocket handshake over the raw stream, then creates an AHP client.
70
async fn connect_via_tunnel(ctx: &CommandContext, name: &str) -> Result<Client, AnyError> {
71
let auth = Auth::new(&ctx.paths, ctx.log.clone());
72
let mut dt = DevTunnels::new_remote_tunnel(&ctx.log, auth, &ctx.paths);
73
74
let (port_conn, _relay_handle) = dt.connect_to_tunnel_port(name, AGENT_HOST_PORT).await?;
75
76
let rw = port_conn.into_rw();
77
let (ws_stream, _) = tokio_tungstenite::client_async("ws://localhost/", rw)
78
.await
79
.map_err(|e| wrap(e, "WebSocket handshake over tunnel failed"))?;
80
81
let transport = TunnelWsTransport {
82
inner: ws_stream,
83
// Keep the relay handle alive so the SSH session isn't dropped.
84
_relay_handle,
85
};
86
87
Client::connect(transport, ahp::ClientConfig::default())
88
.await
89
.map_err(|e| wrap(e, "Failed to establish AHP session over tunnel").into())
90
}
91
92
/// A [`Transport`] backed by a WebSocket stream running over a tunnel
93
/// relay channel (via `PortConnectionRW`).
94
struct TunnelWsTransport {
95
inner: tokio_tungstenite::WebSocketStream<tunnels::connections::PortConnectionRW>,
96
/// Prevent the relay handle from being dropped, which would close the
97
/// underlying SSH session.
98
_relay_handle: tunnels::connections::ClientRelayHandle,
99
}
100
101
impl Transport for TunnelWsTransport {
102
async fn send(&mut self, msg: TransportMessage) -> Result<(), TransportError> {
103
let frame = match msg {
104
TransportMessage::Parsed(m) => {
105
let s = serde_json::to_string(&m)
106
.map_err(|e| TransportError::Protocol(e.to_string()))?;
107
Message::Text(s.into())
108
}
109
TransportMessage::Text(s) => Message::Text(s.into()),
110
TransportMessage::Binary(b) => Message::Binary(b.into()),
111
};
112
self.inner
113
.send(frame)
114
.await
115
.map_err(|e| TransportError::Io(e.to_string()))
116
}
117
118
async fn recv(&mut self) -> Result<Option<TransportMessage>, TransportError> {
119
loop {
120
match self.inner.next().await {
121
None => return Ok(None),
122
Some(Err(e)) => return Err(TransportError::Io(e.to_string())),
123
Some(Ok(Message::Text(s))) => {
124
return Ok(Some(TransportMessage::Text(s.to_string())))
125
}
126
Some(Ok(Message::Binary(b))) => {
127
return Ok(Some(TransportMessage::Binary(b.to_vec())))
128
}
129
Some(Ok(Message::Close(_))) => return Ok(None),
130
Some(Ok(_)) => continue,
131
}
132
}
133
}
134
135
async fn close(&mut self) -> Result<(), TransportError> {
136
self.inner
137
.close(None)
138
.await
139
.map_err(|e| TransportError::Io(e.to_string()))
140
}
141
}
142
143
/// Issues a JSON-RPC request, automatically handling `-32007` auth errors
144
/// by running the device-flow login and retrying once.
145
pub async fn request_with_auth<P, R>(
146
ctx: &CommandContext,
147
client: &Client,
148
method: &str,
149
params: P,
150
) -> Result<R, AnyError>
151
where
152
P: serde::Serialize + Clone,
153
R: serde::de::DeserializeOwned,
154
{
155
match client.request::<P, R>(method, params.clone()).await {
156
Ok(r) => Ok(r),
157
Err(ref e) if is_auth_required(e) => {
158
debug!(
159
ctx.log,
160
"Server requires authentication, starting login flow..."
161
);
162
authenticate_from_error(ctx, client, e).await?;
163
client
164
.request::<P, R>(method, params)
165
.await
166
.map_err(|e| wrap(e, format!("Failed after authentication: {method}")).into())
167
}
168
Err(e) => Err(wrap(e, format!("Request failed: {method}")).into()),
169
}
170
}
171
172
fn is_auth_required(err: &ahp::ClientError) -> bool {
173
matches!(err, ahp::ClientError::Rpc(e) if e.code == ahp_error_codes::AUTH_REQUIRED)
174
}
175
176
fn parse_protected_resources(err: &ahp::ClientError) -> Vec<ProtectedResourceMetadata> {
177
if let ahp::ClientError::Rpc(e) = err {
178
if let Some(data) = &e.data {
179
if let Ok(resources) =
180
serde_json::from_value::<Vec<ProtectedResourceMetadata>>(data.clone())
181
{
182
return resources;
183
}
184
}
185
}
186
Vec::new()
187
}
188
189
fn provider_for_resource(resource: &ProtectedResourceMetadata) -> Option<AuthProvider> {
190
for server in resource
191
.authorization_servers
192
.as_deref()
193
.unwrap_or_default()
194
{
195
if server.contains("github.com") {
196
return Some(AuthProvider::Github);
197
}
198
if server.contains("microsoftonline.com") || server.contains("login.microsoft.com") {
199
return Some(AuthProvider::Microsoft);
200
}
201
}
202
None
203
}
204
205
async fn authenticate_from_error(
206
ctx: &CommandContext,
207
client: &Client,
208
err: &ahp::ClientError,
209
) -> Result<(), AnyError> {
210
let resources = parse_protected_resources(err);
211
if resources.is_empty() {
212
return Err(wrap(
213
"Server returned AuthRequired but did not include protected resource metadata",
214
"Cannot determine authentication provider",
215
)
216
.into());
217
}
218
219
let auth = Auth::with_namespace(&ctx.paths, ctx.log.clone(), Some("agent-host".into()));
220
221
for resource in &resources {
222
let provider = provider_for_resource(resource);
223
let scopes = resource.scopes_supported.as_ref().map(|s| s.join("+"));
224
225
// Reuse a stored credential from the namespace if one exists; only
226
// start a device-flow login when there is nothing cached.
227
let credential = match auth.get_current_credential() {
228
Ok(Some(existing)) => existing,
229
_ => match provider {
230
Some(p) => auth.login_with_scopes(p, scopes).await?,
231
None => auth.get_credential().await?,
232
},
233
};
234
235
let _: AuthenticateResult = client
236
.request(
237
"authenticate",
238
AuthenticateParams {
239
resource: resource.resource.clone(),
240
token: credential.access_token().to_string(),
241
},
242
)
243
.await
244
.map_err(|e| {
245
wrap(
246
e,
247
format!("AHP authenticate failed for {}", resource.resource),
248
)
249
})?;
250
}
251
252
Ok(())
253
}
254
255
fn resolve_address_from_lockfile(ctx: &CommandContext) -> Result<String, AnyError> {
256
let lockfile_path = ctx.paths.agent_host_lockfile();
257
258
let data = fs::read_to_string(&lockfile_path).map_err(|e| {
259
wrap(
260
e,
261
"No running agent host found. Start one with `code agent host` or specify --address",
262
)
263
})?;
264
265
let lock: AgentHostLockData = serde_json::from_str(&data).map_err(|e| {
266
wrap(
267
e,
268
format!("Corrupt agent host lockfile at {}", lockfile_path.display()),
269
)
270
})?;
271
272
if !process_exists(lock.pid) {
273
let _ = fs::remove_file(&lockfile_path);
274
return Err(CodeError::NoRunningAgentHost.into());
275
}
276
277
let mut url = lock.address;
278
if let Some(token) = &lock.connection_token {
279
url.push_str(&format!("?tkn={token}"));
280
}
281
Ok(url)
282
}
283
284