Skip to content

Commit

Permalink
fix(query_row_optional): do not treat rows with NULL as missing rows
Browse files Browse the repository at this point in the history
Instead of treating NULL type error
as absence of the row,
handle NULL values with SQL.
Previously we sometimes
accidentally treated a single column
being NULL as the lack of the whole row.
  • Loading branch information
link2xt committed Oct 4, 2024
1 parent 46922d4 commit 5af3c26
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 19 deletions.
25 changes: 20 additions & 5 deletions src/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1042,7 +1042,13 @@ impl ChatId {
pub(crate) async fn get_timestamp(self, context: &Context) -> Result<Option<i64>> {
let timestamp = context
.sql
.query_get_value("SELECT MAX(timestamp) FROM msgs WHERE chat_id=?", (self,))
.query_get_value(
"SELECT MAX(timestamp)
FROM msgs
WHERE chat_id=?
HAVING COUNT(*) > 0",
(self,),
)
.await?;
Ok(timestamp)
}
Expand Down Expand Up @@ -1251,7 +1257,7 @@ impl ChatId {
) -> Result<Option<(String, String, String)>> {
self.parent_query(
context,
"rfc724_mid, mime_in_reply_to, mime_references",
"rfc724_mid, mime_in_reply_to, IFNULL(mime_references, '')",
state_out_min,
|row: &rusqlite::Row| {
let rfc724_mid: String = row.get(0)?;
Expand Down Expand Up @@ -1405,7 +1411,10 @@ impl ChatId {
context
.sql
.query_get_value(
"SELECT MAX(timestamp) FROM msgs WHERE chat_id=? AND state!=?",
"SELECT MAX(timestamp)
FROM msgs
WHERE chat_id=? AND state!=?
HAVING COUNT(*) > 0",
(self, MessageState::OutDraft),
)
.await?
Expand All @@ -1421,7 +1430,10 @@ impl ChatId {
context
.sql
.query_get_value(
"SELECT MAX(timestamp) FROM msgs WHERE chat_id=? AND hidden=0 AND state>?",
"SELECT MAX(timestamp)
FROM msgs
WHERE chat_id=? AND hidden=0 AND state>?
HAVING COUNT(*) > 0",
(self, MessageState::InFresh),
)
.await?
Expand Down Expand Up @@ -4345,7 +4357,10 @@ pub async fn add_device_msg_with_importance(
if let Some(last_msg_time) = context
.sql
.query_get_value(
"SELECT MAX(timestamp) FROM msgs WHERE chat_id=?",
"SELECT MAX(timestamp)
FROM msgs
WHERE chat_id=?
HAVING COUNT(*) > 0",
(chat_id,),
)
.await?
Expand Down
27 changes: 23 additions & 4 deletions src/ephemeral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,16 @@ impl ChatId {
pub async fn get_ephemeral_timer(self, context: &Context) -> Result<Timer> {
let timer = context
.sql
.query_get_value("SELECT ephemeral_timer FROM chats WHERE id=?;", (self,))
.query_row(
"SELECT IFNULL(ephemeral_timer, 0) FROM chats WHERE id=?",
(self,),
|row| {
let timer: Timer = row.get(0)?;
Ok(timer)
},
)
.await?;
Ok(timer.unwrap_or_default())
Ok(timer)
}

/// Set ephemeral timer value without sending a message.
Expand Down Expand Up @@ -509,7 +516,8 @@ async fn next_delete_device_after_timestamp(context: &Context) -> Result<Option<
FROM msgs
WHERE chat_id > ?
AND chat_id != ?
AND chat_id != ?;
AND chat_id != ?
HAVING count(*) > 0
"#,
(DC_CHAT_ID_TRASH, self_chat_id, device_chat_id),
)
Expand All @@ -533,7 +541,8 @@ async fn next_expiration_timestamp(context: &Context) -> Option<i64> {
SELECT min(ephemeral_timestamp)
FROM msgs
WHERE ephemeral_timestamp != 0
AND chat_id != ?;
AND chat_id != ?
HAVING count(*) > 0
"#,
(DC_CHAT_ID_TRASH,), // Trash contains already deleted messages, skip them
)
Expand Down Expand Up @@ -1410,4 +1419,14 @@ mod tests {

Ok(())
}

/// Tests that `.get_ephemeral_timer()` returns an error for invalid chat ID.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_get_ephemeral_timer_wrong_chat_id() -> Result<()> {
let context = TestContext::new().await;
let chat_id = ChatId::new(12345);
assert!(chat_id.get_ephemeral_timer(&context).await.is_err());

Ok(())
}
}
26 changes: 20 additions & 6 deletions src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,19 @@ impl MsgId {
}

/// Returns information about hops of a message, used for message info
pub async fn hop_info(self, context: &Context) -> Result<Option<String>> {
context
pub async fn hop_info(self, context: &Context) -> Result<String> {
let hop_info = context
.sql
.query_get_value("SELECT hop_info FROM msgs WHERE id=?", (self,))
.await
.query_row(
"SELECT IFNULL(hop_info, '') FROM msgs WHERE id=?",
(self,),
|row| {
let hop_info: String = row.get(0)?;
Ok(hop_info)
},
)
.await?;
Ok(hop_info)
}

/// Returns detailed message information in a multi-line text form.
Expand Down Expand Up @@ -366,7 +374,11 @@ impl MsgId {
let hop_info = self.hop_info(context).await?;

ret += "\n\n";
ret += &hop_info.unwrap_or_else(|| "No Hop Info".to_owned());
if hop_info.is_empty() {
ret += "No Hop Info";
} else {
ret += &hop_info;
}

Ok(ret)
}
Expand Down Expand Up @@ -1998,7 +2010,9 @@ pub(crate) async fn rfc724_mid_exists_ex(
.query_row_optional(
&("SELECT id, timestamp_sent, MIN(".to_string()
+ expr
+ ") FROM msgs WHERE rfc724_mid=? ORDER BY timestamp_sent DESC"),
+ ") FROM msgs WHERE rfc724_mid=?
HAVING COUNT(*) > 0 -- Prevent MIN(expr) from returning NULL when there are no rows.
ORDER BY timestamp_sent DESC"),
(rfc724_mid,),
|row| {
let msg_id: MsgId = row.get(0)?;
Expand Down
3 changes: 2 additions & 1 deletion src/mimefactory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ impl MimeFactory {
let (in_reply_to, references) = context
.sql
.query_row(
"SELECT mime_in_reply_to, mime_references FROM msgs WHERE id=?",
"SELECT mime_in_reply_to, IFNULL(mime_references, '')
FROM msgs WHERE id=?",
(msg.id,),
|row| {
let in_reply_to: String = row.get(0)?;
Expand Down
4 changes: 1 addition & 3 deletions src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -558,15 +558,13 @@ impl Sql {
self.call(move |conn| match conn.query_row(sql.as_ref(), params, f) {
Ok(res) => Ok(Some(res)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(rusqlite::Error::InvalidColumnType(_, _, rusqlite::types::Type::Null)) => Ok(None),
Err(err) => Err(err.into()),
})
.await
}

/// Executes a query which is expected to return one row and one
/// column. If the query does not return a value or returns SQL
/// `NULL`, returns `Ok(None)`.
/// column. If the query does not return any rows, returns `Ok(None)`.
pub async fn query_get_value<T>(
&self,
query: &str,
Expand Down

0 comments on commit 5af3c26

Please sign in to comment.