dave: return tool errors back to the ai

So that it can correct itself
This commit is contained in:
William Casarin
2025-04-22 16:04:54 -07:00
parent 9692b6b9fe
commit 56f5151739
4 changed files with 90 additions and 9 deletions

View File

@@ -9,6 +9,7 @@ use futures::StreamExt;
use nostrdb::Transaction; use nostrdb::Transaction;
use notedeck::AppContext; use notedeck::AppContext;
use std::collections::HashMap; use std::collections::HashMap;
use std::string::ToString;
use std::sync::mpsc::{self, Receiver}; use std::sync::mpsc::{self, Receiver};
use std::sync::Arc; use std::sync::Arc;
@@ -137,6 +138,15 @@ You are an AI agent for the nostr protocol called Dave, created by Damus. nostr
should_send = true; should_send = true;
} }
ToolCalls::Invalid(invalid) => {
should_send = true;
self.chat.push(Message::tool_error(
call.id().to_string(),
invalid.error.clone(),
));
}
ToolCalls::Query(search_call) => { ToolCalls::Query(search_call) => {
should_send = true; should_send = true;
@@ -270,12 +280,23 @@ You are an AI agent for the nostr protocol called Dave, created by Damus. nostr
parsed_tool_calls.push(tool_call); parsed_tool_calls.push(tool_call);
} }
Err(err) => { Err(err) => {
// TODO: we should be
tracing::error!( tracing::error!(
"failed to parse tool call {:?}: {:?}", "failed to parse tool call {:?}: {}",
unknown_tool_call, unknown_tool_call,
err, err,
); );
// TODO: return error to user
if let Some(id) = partial.id() {
// we have an id, so we can communicate the error
// back to the ai
parsed_tool_calls.push(ToolCall::invalid(
id.to_string(),
partial.name,
partial.arguments,
err.to_string(),
));
}
} }
}; };
} }

View File

@@ -19,6 +19,10 @@ pub enum DaveApiResponse {
} }
impl Message { impl Message {
pub fn tool_error(id: String, msg: String) -> Self {
Self::ToolResponse(ToolResponse::error(id, msg))
}
pub fn to_api_msg(&self, txn: &Transaction, ndb: &Ndb) -> ChatCompletionRequestMessage { pub fn to_api_msg(&self, txn: &Transaction, ndb: &Ndb) -> ChatCompletionRequestMessage {
match self { match self {
Message::User(msg) => { Message::User(msg) => {

View File

@@ -4,7 +4,7 @@ use enostr::{NoteId, Pubkey};
use nostrdb::{Ndb, Note, NoteKey, Transaction}; use nostrdb::{Ndb, Note, NoteKey, Transaction};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{json, Value}; use serde_json::{json, Value};
use std::collections::HashMap; use std::{collections::HashMap, fmt};
/// A tool /// A tool
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -18,6 +18,22 @@ impl ToolCall {
&self.id &self.id
} }
pub fn invalid(
id: String,
name: Option<String>,
arguments: Option<String>,
error: String,
) -> Self {
Self {
id,
typ: ToolCalls::Invalid(InvalidToolCall {
name,
arguments,
error,
}),
}
}
pub fn calls(&self) -> &ToolCalls { pub fn calls(&self) -> &ToolCalls {
&self.typ &self.typ
} }
@@ -34,11 +50,11 @@ impl ToolCall {
/// On streaming APIs, tool calls are incremental. We use this /// On streaming APIs, tool calls are incremental. We use this
/// to represent tool calls that are in the process of returning. /// to represent tool calls that are in the process of returning.
/// These eventually just become [`ToolCall`]'s /// These eventually just become [`ToolCall`]'s
#[derive(Default, Debug, Clone)] #[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct PartialToolCall { pub struct PartialToolCall {
id: Option<String>, pub id: Option<String>,
name: Option<String>, pub name: Option<String>,
arguments: Option<String>, pub arguments: Option<String>,
} }
impl PartialToolCall { impl PartialToolCall {
@@ -75,6 +91,7 @@ pub struct QueryResponse {
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ToolResponses { pub enum ToolResponses {
Error(String),
Query(QueryResponse), Query(QueryResponse),
PresentNotes, PresentNotes,
} }
@@ -110,6 +127,13 @@ impl PartialToolCall {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvalidToolCall {
pub error: String,
pub name: Option<String>,
pub arguments: Option<String>,
}
/// An enumeration of the possible tool calls that /// An enumeration of the possible tool calls that
/// can be parsed from Dave responses. When adding /// can be parsed from Dave responses. When adding
/// new tools, this needs to be updated so that we can /// new tools, this needs to be updated so that we can
@@ -118,6 +142,12 @@ impl PartialToolCall {
pub enum ToolCalls { pub enum ToolCalls {
Query(QueryCall), Query(QueryCall),
PresentNotes(PresentNotesCall), PresentNotes(PresentNotesCall),
Invalid(InvalidToolCall),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ErrorCall {
error: String,
} }
impl ToolCalls { impl ToolCalls {
@@ -131,6 +161,7 @@ impl ToolCalls {
fn name(&self) -> &'static str { fn name(&self) -> &'static str {
match self { match self {
Self::Query(_) => "search", Self::Query(_) => "search",
Self::Invalid(_) => "error",
Self::PresentNotes(_) => "present", Self::PresentNotes(_) => "present",
} }
} }
@@ -138,6 +169,7 @@ impl ToolCalls {
fn arguments(&self) -> String { fn arguments(&self) -> String {
match self { match self {
Self::Query(search) => serde_json::to_string(search).unwrap(), Self::Query(search) => serde_json::to_string(search).unwrap(),
Self::Invalid(partial) => serde_json::to_string(partial).unwrap(),
Self::PresentNotes(call) => serde_json::to_string(&call.to_simple()).unwrap(), Self::PresentNotes(call) => serde_json::to_string(&call.to_simple()).unwrap(),
} }
} }
@@ -151,6 +183,19 @@ pub enum ToolCallError {
ArgParseFailure(String), ArgParseFailure(String),
} }
impl fmt::Display for ToolCallError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ToolCallError::EmptyName => write!(f, "the tool name was empty"),
ToolCallError::EmptyArgs => write!(f, "no arguments were provided"),
ToolCallError::NotFound(ref name) => write!(f, "tool '{}' not found", name),
ToolCallError::ArgParseFailure(ref msg) => {
write!(f, "failed to parse arguments: {}", msg)
}
}
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
enum ArgType { enum ArgType {
String, String,
@@ -276,6 +321,13 @@ impl ToolResponse {
Self { id, typ: responses } Self { id, typ: responses }
} }
pub fn error(id: String, msg: String) -> Self {
Self {
id,
typ: ToolResponses::Error(msg),
}
}
pub fn responses(&self) -> &ToolResponses { pub fn responses(&self) -> &ToolResponses {
&self.typ &self.typ
} }
@@ -323,7 +375,7 @@ impl PresentNotesCall {
Ok(ToolCalls::PresentNotes(PresentNotesCall { note_ids })) Ok(ToolCalls::PresentNotes(PresentNotesCall { note_ids }))
} }
Err(e) => Err(ToolCallError::ArgParseFailure(format!( Err(e) => Err(ToolCallError::ArgParseFailure(format!(
"Failed to parse args: '{}', error: {}", "{}, error: {}",
args, e args, e
))), ))),
} }
@@ -424,7 +476,7 @@ impl QueryCall {
match serde_json::from_str::<QueryCall>(args) { match serde_json::from_str::<QueryCall>(args) {
Ok(call) => Ok(ToolCalls::Query(call)), Ok(call) => Ok(ToolCalls::Query(call)),
Err(e) => Err(ToolCallError::ArgParseFailure(format!( Err(e) => Err(ToolCallError::ArgParseFailure(format!(
"Failed to parse args: '{}', error: {}", "{}, error: {}",
args, e args, e
))), ))),
} }
@@ -448,6 +500,7 @@ struct SimpleNote {
fn format_tool_response_for_ai(txn: &Transaction, ndb: &Ndb, resp: &ToolResponses) -> String { fn format_tool_response_for_ai(txn: &Transaction, ndb: &Ndb, resp: &ToolResponses) -> String {
match resp { match resp {
ToolResponses::PresentNotes => "".to_string(), ToolResponses::PresentNotes => "".to_string(),
ToolResponses::Error(s) => format!("error: {}", &s),
ToolResponses::Query(search_r) => { ToolResponses::Query(search_r) => {
let simple_notes: Vec<SimpleNote> = search_r let simple_notes: Vec<SimpleNote> = search_r

View File

@@ -204,6 +204,9 @@ impl<'a> DaveUi<'a> {
for call in toolcalls { for call in toolcalls {
match call.calls() { match call.calls() {
ToolCalls::PresentNotes(call) => Self::present_notes_ui(ctx, call, ui), ToolCalls::PresentNotes(call) => Self::present_notes_ui(ctx, call, ui),
ToolCalls::Invalid(err) => {
ui.label(format!("invalid tool call: {:?}", err));
}
ToolCalls::Query(search_call) => { ToolCalls::Query(search_call) => {
ui.allocate_ui_with_layout( ui.allocate_ui_with_layout(
egui::vec2(ui.available_size().x, 32.0), egui::vec2(ui.available_size().x, 32.0),