use crate::{
ui_transform::UiGlobalTransform, ComputedNode, ComputedUiTargetCamera, Node, OverrideClip,
UiStack,
};
use bevy_camera::{visibility::InheritedVisibility, Camera, NormalizedRenderTarget};
use bevy_ecs::{
change_detection::DetectChangesMut,
entity::{ContainsEntity, Entity},
hierarchy::ChildOf,
prelude::{Component, With},
query::{QueryData, Without},
reflect::ReflectComponent,
system::{Local, Query, Res},
};
use bevy_input::{mouse::MouseButton, touch::Touches, ButtonInput};
use bevy_math::Vec2;
use bevy_platform::collections::HashMap;
use bevy_reflect::{std_traits::ReflectDefault, Reflect};
use bevy_window::{PrimaryWindow, Window};
use smallvec::SmallVec;
#[cfg(feature = "serialize")]
use bevy_reflect::{ReflectDeserialize, ReflectSerialize};
/// Interaction state of a UI node, maintained by [`ui_focus_system`].
///
/// The default value is [`Interaction::None`].
#[derive(Component, Copy, Clone, Eq, PartialEq, Debug, Reflect)]
#[reflect(Component, Default, PartialEq, Debug, Clone)]
#[cfg_attr(
feature = "serialize",
derive(serde::Serialize, serde::Deserialize),
reflect(Serialize, Deserialize)
)]
pub enum Interaction {
/// A left-click or touch press started while the cursor was over the node.
///
/// [`ui_focus_system`] turns this back into [`Interaction::None`] once the left mouse
/// button or a touch is released.
Pressed,
/// The cursor is over the node and the node is not pressed.
Hovered,
/// The node is neither hovered nor pressed.
None,
}
impl Interaction {
const DEFAULT: Self = Self::None;
}
impl Default for Interaction {
fn default() -> Self {
Self::DEFAULT
}
}
/// Cursor position expressed relative to a UI node.
///
/// [`ui_focus_system`] refreshes this component (only writing when the value actually
/// differs) for visible nodes that have a target camera. Hidden nodes are skipped and keep
/// their previous value.
#[derive(Component, Copy, Clone, Default, PartialEq, Debug, Reflect)]
#[reflect(Component, Default, PartialEq, Debug, Clone)]
#[cfg_attr(
feature = "serialize",
derive(serde::Serialize, serde::Deserialize),
reflect(Serialize, Deserialize)
)]
pub struct RelativeCursorPosition {
/// `true` when the cursor lies inside the node and is not clipped away by any ancestor
/// (see [`clip_check_recursive`]).
pub cursor_over: bool,
/// Cursor position as returned by `ComputedNode::normalize_point` for this node.
///
/// `None` when no cursor position is known for the node's target camera, or when
/// `normalize_point` itself returns `None`.
pub normalized: Option<Vec2>,
}
impl RelativeCursorPosition {
    /// Convenience accessor: reports whether the cursor is currently over the node, i.e. the
    /// value of the `cursor_over` field.
    pub fn cursor_over(&self) -> bool {
        let Self { cursor_over, .. } = *self;
        cursor_over
    }
}
/// Controls whether a node under the cursor lets nodes further down the UI stack receive
/// the interaction as well.
///
/// The default value is [`FocusPolicy::Pass`]. Note that [`ui_focus_system`] treats a node
/// *without* this component as if it were [`FocusPolicy::Block`].
#[derive(Component, Copy, Clone, Eq, PartialEq, Debug, Reflect)]
#[reflect(Component, Default, PartialEq, Debug, Clone)]
#[cfg_attr(
feature = "serialize",
derive(serde::Serialize, serde::Deserialize),
reflect(Serialize, Deserialize)
)]
pub enum FocusPolicy {
/// The node captures the interaction: nodes visited after it are not hovered or pressed.
Block,
/// The interaction is also offered to the next node under the cursor.
Pass,
}
impl FocusPolicy {
const DEFAULT: Self = Self::Pass;
}
impl Default for FocusPolicy {
fn default() -> Self {
Self::DEFAULT
}
}
/// Per-system local state of [`ui_focus_system`].
#[derive(Default)]
pub struct State {
// Entities that became `Pressed` in a run where a release was also detected.
// They are drained and set back to `Interaction::None` at the start of the next run.
entities_to_reset: SmallVec<[Entity; 1]>,
}
/// Query data fetched for every UI node by [`ui_focus_system`].
#[derive(QueryData)]
#[query_data(mutable)]
pub struct NodeQuery {
// The node's entity id, recorded when the node must be reset on the next run.
entity: Entity,
// Computed layout data, used for the cursor hit test and for normalizing the cursor position.
node: &'static ComputedNode,
// Global UI transform passed to the hit test / normalization.
transform: &'static UiGlobalTransform,
// Interaction state to update; nodes without it can still block nodes below them.
interaction: Option<&'static mut Interaction>,
// Optional relative cursor position to refresh.
relative_cursor_position: Option<&'static mut RelativeCursorPosition>,
// Optional focus policy; a missing policy is handled as `FocusPolicy::Block`.
focus_policy: Option<&'static FocusPolicy>,
// Nodes without this component are skipped; nodes whose value is `false` are reset to `None`.
inherited_visibility: Option<&'static InheritedVisibility>,
// Camera whose cursor position is used for this node.
target_camera: &'static ComputedUiTargetCamera,
}
/// Derives [`Interaction`] and [`RelativeCursorPosition`] for UI nodes from the current
/// mouse / touch state.
///
/// Each run performs, in order:
///
/// 1. Nodes queued in [`State`] by the previous run (pressed while a release was detected in
///    the same run) are set back to [`Interaction::None`].
/// 2. If the left mouse button or any touch was just released, every
///    [`Interaction::Pressed`] node becomes [`Interaction::None`].
/// 3. A cursor position is resolved for every camera that renders to a window: the window's
///    physical cursor position (falling back to the first pressed touch multiplied by the
///    window's scale factor), minus the origin of the camera's physical viewport.
/// 4. The [`UiStack`] is walked in reverse order. Nodes without [`InheritedVisibility`] or
///    without a target camera are skipped, hidden nodes have their interaction cleared, and
///    the remaining nodes get their [`RelativeCursorPosition`] refreshed. Nodes that contain
///    the cursor and are not clipped (see [`clip_check_recursive`]) are collected.
/// 5. The collected nodes are marked `Pressed` (on a fresh click) or `Hovered`, in order, up
///    to and including the first node whose [`FocusPolicy`] is `Block` or that has no
///    [`FocusPolicy`] at all.
/// 6. Collected nodes after that point are reset to `None` unless they are `Pressed`.
pub fn ui_focus_system(
    mut state: Local<State>,
    camera_query: Query<(Entity, &Camera)>,
    primary_window: Query<Entity, With<PrimaryWindow>>,
    windows: Query<&Window>,
    mouse_button_input: Res<ButtonInput<MouseButton>>,
    touches_input: Res<Touches>,
    ui_stack: Res<UiStack>,
    mut node_query: Query<NodeQuery>,
    clipping_query: Query<(&ComputedNode, &UiGlobalTransform, &Node)>,
    child_of_query: Query<&ChildOf, Without<OverrideClip>>,
) {
    let primary_window = primary_window.iter().next();

    // Step 1: undo the `Pressed` state of nodes queued by the previous run.
    for entity in state.entities_to_reset.drain(..) {
        let Ok(queued) = node_query.get_mut(entity) else {
            continue;
        };
        if let Some(mut interaction) = queued.interaction {
            *interaction = Interaction::None;
        }
    }

    // Step 2: a release ends every ongoing press.
    let mouse_released =
        mouse_button_input.just_released(MouseButton::Left) || touches_input.any_just_released();
    if mouse_released {
        for node in &mut node_query {
            if let Some(mut interaction) = node.interaction {
                if *interaction == Interaction::Pressed {
                    *interaction = Interaction::None;
                }
            }
        }
    }

    let mouse_clicked =
        mouse_button_input.just_pressed(MouseButton::Left) || touches_input.any_just_pressed();

    // Step 3: cursor position per camera, relative to that camera's viewport origin.
    let mut camera_cursor_positions: HashMap<Entity, Vec2> = HashMap::default();
    for (camera_entity, camera) in camera_query.iter() {
        // Only cameras whose render target resolves to a window take part.
        let Some(NormalizedRenderTarget::Window(window_ref)) =
            camera.target.normalize(primary_window)
        else {
            continue;
        };
        let Ok(window) = windows.get(window_ref.entity()) else {
            continue;
        };

        let viewport_origin = camera
            .physical_viewport_rect()
            .map_or(Vec2::ZERO, |rect| rect.min.as_vec2());
        let window_cursor = window.physical_cursor_position().or_else(|| {
            touches_input
                .first_pressed_position()
                .map(|pos| pos * window.scale_factor())
        });

        if let Some(window_cursor) = window_cursor {
            camera_cursor_positions.insert(camera_entity, window_cursor - viewport_origin);
        }
    }

    // Step 4: gather the nodes under the cursor, visiting the UI stack back to front of the
    // stored order (i.e. reversed).
    let mut under_cursor: Vec<Entity> = Vec::new();
    for &entity in ui_stack.uinodes.iter().rev() {
        let Ok(node) = node_query.get_mut(entity) else {
            continue;
        };
        let Some(inherited_visibility) = node.inherited_visibility else {
            continue;
        };
        if !inherited_visibility.get() {
            // Hidden nodes cannot be interacted with. `set_if_neq` avoids flagging a change
            // on every run while the node stays hidden.
            if let Some(mut interaction) = node.interaction {
                interaction.set_if_neq(Interaction::None);
            }
            continue;
        }
        let Some(camera_entity) = node.target_camera.get() else {
            continue;
        };

        let cursor_position = camera_cursor_positions.get(&camera_entity).copied();

        let contains_cursor = match cursor_position {
            Some(point) => {
                node.node.contains_point(*node.transform, point)
                    && clip_check_recursive(point, entity, &clipping_query, &child_of_query)
            }
            None => false,
        };
        let normalized =
            cursor_position.and_then(|point| node.node.normalize_point(*node.transform, point));

        if let Some(mut relative_cursor_position) = node.relative_cursor_position {
            // Only write when the value differs so change detection stays quiet.
            relative_cursor_position.set_if_neq(RelativeCursorPosition {
                cursor_over: contains_cursor,
                normalized,
            });
        }

        if contains_cursor {
            under_cursor.push(entity);
            continue;
        }

        // Not under the cursor: drop a stale hover, and drop any state at all when the cursor
        // position could not be normalized.
        if let Some(mut interaction) = node.interaction {
            if normalized.is_none() || *interaction == Interaction::Hovered {
                interaction.set_if_neq(Interaction::None);
            }
        }
    }

    // Step 5: hand the interaction to the collected nodes until one of them blocks it. The
    // entity iterator is only borrowed here so that step 6 can resume where this loop stops.
    let mut remaining = under_cursor.into_iter();
    {
        let mut receiving = node_query.iter_many_mut(remaining.by_ref());
        while let Some(node) = receiving.fetch_next() {
            if let Some(mut interaction) = node.interaction {
                match (mouse_clicked, *interaction) {
                    (true, Interaction::Hovered | Interaction::None) => {
                        *interaction = Interaction::Pressed;
                        // Pressed and released within the same run: the release pass above
                        // has already happened, so schedule the reset for the next run.
                        if mouse_released {
                            state.entities_to_reset.push(node.entity);
                        }
                    }
                    (false, Interaction::None) => *interaction = Interaction::Hovered,
                    (true, Interaction::Pressed)
                    | (false, Interaction::Pressed | Interaction::Hovered) => {}
                }
            }

            let blocks_lower_nodes = match node.focus_policy {
                Some(FocusPolicy::Pass) => false,
                // A missing policy behaves like `Block`.
                Some(FocusPolicy::Block) | None => true,
            };
            if blocks_lower_nodes {
                break;
            }
        }
    }

    // Step 6: every collected node that was not reached above loses its hover; pressed nodes
    // are left alone because the release pass takes care of them.
    let mut blocked = node_query.iter_many_mut(remaining);
    while let Some(node) = blocked.fetch_next() {
        let Some(mut interaction) = node.interaction else {
            continue;
        };
        if *interaction != Interaction::Pressed {
            interaction.set_if_neq(Interaction::None);
        }
    }
}
pub fn clip_check_recursive(
point: Vec2,
entity: Entity,
clipping_query: &Query<'_, '_, (&ComputedNode, &UiGlobalTransform, &Node)>,
child_of_query: &Query<&ChildOf, Without<OverrideClip>>,
) -> bool {
if let Ok(child_of) = child_of_query.get(entity) {
let parent = child_of.0;
if let Ok((computed_node, transform, node)) = clipping_query.get(parent)
&& !computed_node
.resolve_clip_rect(node.overflow, node.overflow_clip_margin)
.contains(transform.inverse().transform_point2(point))
{
return false;
}
return clip_check_recursive(point, parent, clipping_query, child_of_query);
}
true
}