Skip to main content

streamocracy/polls/
mod.rs

1//! Poll functionality for the Streamocracy bot
2
3use serenity::all::{CommandInteraction, Context, CreateEmbed, ReactionType, UserId};
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::sync::LazyLock;
7use std::time::Duration;
8use tokio::sync::Mutex;
9use tokio::time::sleep;
10use tracing::{error, info, warn};
11
12pub mod votekick;
13
/// Bookkeeping kept for a poll while it is still open.
///
/// Stored in the active-poll registry under the poll message's ID so the
/// completion step can find the channel it must fetch the message from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PollInfo {
    /// Raw Discord ID of the channel the poll message was posted in.
    pub channel_id: u64,
}
20
21/// Thread-safe storage for active polls
22type ActivePolls = Arc<Mutex<HashMap<u64, PollInfo>>>;
23
24static ACTIVE_POLLS: LazyLock<ActivePolls> = LazyLock::new(|| Arc::new(Mutex::new(HashMap::new())));
25
26/// Trait for reaction-based polls.
27#[serenity::async_trait]
28pub trait Poll: Send + Sync {
29    /// The poll title displayed in the embed.
30    fn title(&self) -> String;
31
32    /// The poll description/question.
33    fn description(&self) -> String;
34
35    /// Duration of the poll in seconds.
36    fn duration(&self) -> u64;
37
38    /// The yes/positive reaction emoji.
39    fn yes_reaction(&self) -> ReactionType {
40        ReactionType::Unicode("✅".to_string())
41    }
42
43    /// The no/negative reaction emoji.
44    fn no_reaction(&self) -> ReactionType {
45        ReactionType::Unicode("❌".to_string())
46    }
47
48    /// Build the embed shown for the poll.
49    fn build_embed(&self) -> CreateEmbed {
50        CreateEmbed::default()
51            .title(self.title())
52            .description(self.description())
53            .field("Duration", format!("{} seconds", self.duration()), false)
54    }
55
56    /// Called when the poll ends with results.
57    /// yes_votes and no_votes are counts excluding the bot.
58    async fn on_complete(&self, ctx: &Context, message_id: u64, yes_votes: u32, no_votes: u32);
59
60    /// Start the poll by sending the embed and adding reactions.
61    /// Returns the message ID of the created poll.
62    async fn start(&self, ctx: &Context, command: &CommandInteraction) -> anyhow::Result<u64> {
63        let embed = self.build_embed();
64
65        command
66            .create_response(
67                &ctx.http,
68                serenity::all::CreateInteractionResponse::Message(
69                    serenity::all::CreateInteractionResponseMessage::new().embed(embed),
70                ),
71            )
72            .await?;
73
74        let message = command.get_response(&ctx.http).await?;
75        let yes = self.yes_reaction();
76        let no = self.no_reaction();
77
78        if let Err(e) = message.react(&ctx.http, yes).await {
79            error!("Failed to add yes reaction: {}", e);
80        }
81        if let Err(e) = message.react(&ctx.http, no).await {
82            error!("Failed to add no reaction: {}", e);
83        }
84
85        let message_id = message.id.get();
86
87        {
88            let mut active = ACTIVE_POLLS.lock().await;
89            active.insert(
90                message_id,
91                PollInfo {
92                    channel_id: message.channel_id.get(),
93                },
94            );
95        }
96
97        info!("Poll started (message_id: {})", message_id);
98        Ok(message_id)
99    }
100}
101
102/// Schedule a poll to complete after its duration.
103pub async fn schedule_poll_completion<P: Poll + 'static>(
104    poll: P,
105    ctx: Context,
106    message_id: u64,
107    duration_secs: u64,
108) {
109    let ctx_clone = ctx.clone();
110    tokio::spawn(async move {
111        sleep(Duration::from_secs(duration_secs)).await;
112        complete_poll(&poll, &ctx_clone, message_id).await;
113    });
114}
115
116/// Complete a poll by counting votes and calling on_complete.
117async fn complete_poll<P: Poll>(poll: &P, ctx: &Context, message_id: u64) {
118    let poll_info = {
119        let mut active = ACTIVE_POLLS.lock().await;
120        match active.remove(&message_id) {
121            Some(info) => info,
122            None => {
123                warn!("No active poll found for message {}", message_id);
124                return;
125            }
126        }
127    };
128
129    let channel_id = serenity::all::ChannelId::new(poll_info.channel_id);
130    let message = match channel_id.message(&ctx.http, message_id).await {
131        Ok(m) => m,
132        Err(e) => {
133            error!("Failed to fetch poll message: {}", e);
134            return;
135        }
136    };
137
138    let yes_reaction = poll.yes_reaction();
139    let no_reaction = poll.no_reaction();
140    let yes_votes = get_reaction_count(&ctx.http, &message, &yes_reaction).await;
141    let no_votes = get_reaction_count(&ctx.http, &message, &no_reaction).await;
142
143    info!(
144        "Poll results for message {}: Yes={}, No={}",
145        message_id, yes_votes, no_votes
146    );
147
148    if let Err(e) = channel_id.delete_message(&ctx.http, message_id).await {
149        warn!("Failed to delete poll message: {}", e);
150    }
151
152    poll.on_complete(ctx, message_id, yes_votes, no_votes).await;
153}
154
155/// Count users who reacted with a specific emoji, excluding the bot.
156async fn get_reaction_count(
157    http: &serenity::all::Http,
158    message: &serenity::all::Message,
159    reaction_type: &ReactionType,
160) -> u32 {
161    let mut count = 0u32;
162    let mut after: Option<UserId> = None;
163
164    loop {
165        let users = match message
166            .reaction_users(http, reaction_type.clone(), Some(100u8), after)
167            .await
168        {
169            Ok(u) => u,
170            Err(e) => {
171                error!("Failed to get reaction users: {}", e);
172                break;
173            }
174        };
175
176        if users.is_empty() {
177            break;
178        }
179
180        for user in &users {
181            if user.id != message.author.id {
182                count += 1;
183            }
184        }
185
186        if users.len() < 100 {
187            break;
188        }
189
190        after = users.last().map(|u| u.id);
191    }
192
193    count
194}
195
196/// Send a message that auto-deletes after a specified number of seconds.
197pub async fn send_temporary_message(
198    ctx: &Context,
199    channel_id: serenity::all::ChannelId,
200    content: impl Into<String>,
201    delete_after_secs: u64,
202) {
203    let content = content.into();
204    let http = ctx.http.clone();
205
206    match channel_id.say(&http, content).await {
207        Ok(message) => {
208            let message_id = message.id;
209            tokio::spawn(async move {
210                sleep(Duration::from_secs(delete_after_secs)).await;
211                let _ = channel_id.delete_message(&http, message_id).await;
212            });
213        }
214        Err(e) => {
215            error!("Failed to send temporary message: {}", e);
216        }
217    }
218}