dave: return tool errors back to the ai
So that it can correct itself
This commit is contained in:
@@ -9,6 +9,7 @@ use futures::StreamExt;
|
||||
use nostrdb::Transaction;
|
||||
use notedeck::AppContext;
|
||||
use std::collections::HashMap;
|
||||
use std::string::ToString;
|
||||
use std::sync::mpsc::{self, Receiver};
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -137,6 +138,15 @@ You are an AI agent for the nostr protocol called Dave, created by Damus. nostr
|
||||
should_send = true;
|
||||
}
|
||||
|
||||
ToolCalls::Invalid(invalid) => {
|
||||
should_send = true;
|
||||
|
||||
self.chat.push(Message::tool_error(
|
||||
call.id().to_string(),
|
||||
invalid.error.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
ToolCalls::Query(search_call) => {
|
||||
should_send = true;
|
||||
|
||||
@@ -270,12 +280,23 @@ You are an AI agent for the nostr protocol called Dave, created by Damus. nostr
|
||||
parsed_tool_calls.push(tool_call);
|
||||
}
|
||||
Err(err) => {
|
||||
// TODO: we should be
|
||||
tracing::error!(
|
||||
"failed to parse tool call {:?}: {:?}",
|
||||
"failed to parse tool call {:?}: {}",
|
||||
unknown_tool_call,
|
||||
err,
|
||||
);
|
||||
// TODO: return error to user
|
||||
|
||||
if let Some(id) = partial.id() {
|
||||
// we have an id, so we can communicate the error
|
||||
// back to the ai
|
||||
parsed_tool_calls.push(ToolCall::invalid(
|
||||
id.to_string(),
|
||||
partial.name,
|
||||
partial.arguments,
|
||||
err.to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -19,6 +19,10 @@ pub enum DaveApiResponse {
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn tool_error(id: String, msg: String) -> Self {
|
||||
Self::ToolResponse(ToolResponse::error(id, msg))
|
||||
}
|
||||
|
||||
pub fn to_api_msg(&self, txn: &Transaction, ndb: &Ndb) -> ChatCompletionRequestMessage {
|
||||
match self {
|
||||
Message::User(msg) => {
|
||||
|
||||
@@ -4,7 +4,7 @@ use enostr::{NoteId, Pubkey};
|
||||
use nostrdb::{Ndb, Note, NoteKey, Transaction};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
use std::{collections::HashMap, fmt};
|
||||
|
||||
/// A tool
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -18,6 +18,22 @@ impl ToolCall {
|
||||
&self.id
|
||||
}
|
||||
|
||||
pub fn invalid(
|
||||
id: String,
|
||||
name: Option<String>,
|
||||
arguments: Option<String>,
|
||||
error: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
id,
|
||||
typ: ToolCalls::Invalid(InvalidToolCall {
|
||||
name,
|
||||
arguments,
|
||||
error,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn calls(&self) -> &ToolCalls {
|
||||
&self.typ
|
||||
}
|
||||
@@ -34,11 +50,11 @@ impl ToolCall {
|
||||
/// On streaming APIs, tool calls are incremental. We use this
|
||||
/// to represent tool calls that are in the process of returning.
|
||||
/// These eventually just become [`ToolCall`]'s
|
||||
#[derive(Default, Debug, Clone)]
|
||||
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PartialToolCall {
|
||||
id: Option<String>,
|
||||
name: Option<String>,
|
||||
arguments: Option<String>,
|
||||
pub id: Option<String>,
|
||||
pub name: Option<String>,
|
||||
pub arguments: Option<String>,
|
||||
}
|
||||
|
||||
impl PartialToolCall {
|
||||
@@ -75,6 +91,7 @@ pub struct QueryResponse {
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ToolResponses {
|
||||
Error(String),
|
||||
Query(QueryResponse),
|
||||
PresentNotes,
|
||||
}
|
||||
@@ -110,6 +127,13 @@ impl PartialToolCall {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct InvalidToolCall {
|
||||
pub error: String,
|
||||
pub name: Option<String>,
|
||||
pub arguments: Option<String>,
|
||||
}
|
||||
|
||||
/// An enumeration of the possible tool calls that
|
||||
/// can be parsed from Dave responses. When adding
|
||||
/// new tools, this needs to be updated so that we can
|
||||
@@ -118,6 +142,12 @@ impl PartialToolCall {
|
||||
pub enum ToolCalls {
|
||||
Query(QueryCall),
|
||||
PresentNotes(PresentNotesCall),
|
||||
Invalid(InvalidToolCall),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct ErrorCall {
|
||||
error: String,
|
||||
}
|
||||
|
||||
impl ToolCalls {
|
||||
@@ -131,6 +161,7 @@ impl ToolCalls {
|
||||
fn name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Query(_) => "search",
|
||||
Self::Invalid(_) => "error",
|
||||
Self::PresentNotes(_) => "present",
|
||||
}
|
||||
}
|
||||
@@ -138,6 +169,7 @@ impl ToolCalls {
|
||||
fn arguments(&self) -> String {
|
||||
match self {
|
||||
Self::Query(search) => serde_json::to_string(search).unwrap(),
|
||||
Self::Invalid(partial) => serde_json::to_string(partial).unwrap(),
|
||||
Self::PresentNotes(call) => serde_json::to_string(&call.to_simple()).unwrap(),
|
||||
}
|
||||
}
|
||||
@@ -151,6 +183,19 @@ pub enum ToolCallError {
|
||||
ArgParseFailure(String),
|
||||
}
|
||||
|
||||
impl fmt::Display for ToolCallError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
ToolCallError::EmptyName => write!(f, "the tool name was empty"),
|
||||
ToolCallError::EmptyArgs => write!(f, "no arguments were provided"),
|
||||
ToolCallError::NotFound(ref name) => write!(f, "tool '{}' not found", name),
|
||||
ToolCallError::ArgParseFailure(ref msg) => {
|
||||
write!(f, "failed to parse arguments: {}", msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum ArgType {
|
||||
String,
|
||||
@@ -276,6 +321,13 @@ impl ToolResponse {
|
||||
Self { id, typ: responses }
|
||||
}
|
||||
|
||||
pub fn error(id: String, msg: String) -> Self {
|
||||
Self {
|
||||
id,
|
||||
typ: ToolResponses::Error(msg),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn responses(&self) -> &ToolResponses {
|
||||
&self.typ
|
||||
}
|
||||
@@ -323,7 +375,7 @@ impl PresentNotesCall {
|
||||
Ok(ToolCalls::PresentNotes(PresentNotesCall { note_ids }))
|
||||
}
|
||||
Err(e) => Err(ToolCallError::ArgParseFailure(format!(
|
||||
"Failed to parse args: '{}', error: {}",
|
||||
"{}, error: {}",
|
||||
args, e
|
||||
))),
|
||||
}
|
||||
@@ -424,7 +476,7 @@ impl QueryCall {
|
||||
match serde_json::from_str::<QueryCall>(args) {
|
||||
Ok(call) => Ok(ToolCalls::Query(call)),
|
||||
Err(e) => Err(ToolCallError::ArgParseFailure(format!(
|
||||
"Failed to parse args: '{}', error: {}",
|
||||
"{}, error: {}",
|
||||
args, e
|
||||
))),
|
||||
}
|
||||
@@ -448,6 +500,7 @@ struct SimpleNote {
|
||||
fn format_tool_response_for_ai(txn: &Transaction, ndb: &Ndb, resp: &ToolResponses) -> String {
|
||||
match resp {
|
||||
ToolResponses::PresentNotes => "".to_string(),
|
||||
ToolResponses::Error(s) => format!("error: {}", &s),
|
||||
|
||||
ToolResponses::Query(search_r) => {
|
||||
let simple_notes: Vec<SimpleNote> = search_r
|
||||
|
||||
@@ -204,6 +204,9 @@ impl<'a> DaveUi<'a> {
|
||||
for call in toolcalls {
|
||||
match call.calls() {
|
||||
ToolCalls::PresentNotes(call) => Self::present_notes_ui(ctx, call, ui),
|
||||
ToolCalls::Invalid(err) => {
|
||||
ui.label(format!("invalid tool call: {:?}", err));
|
||||
}
|
||||
ToolCalls::Query(search_call) => {
|
||||
ui.allocate_ui_with_layout(
|
||||
egui::vec2(ui.available_size().x, 32.0),
|
||||
|
||||
Reference in New Issue
Block a user