dave: toolcall parsing

Signed-off-by: William Casarin <jb55@jb55.com>
This commit is contained in:
William Casarin
2025-03-25 16:45:22 -07:00
parent 6e2c4cb695
commit 4dfb013d6a
3 changed files with 267 additions and 18 deletions

2
Cargo.lock generated
View File

@@ -3304,6 +3304,8 @@ dependencies = [
"futures",
"notedeck",
"reqwest",
"serde",
"serde_json",
"tokio",
"tracing",
]

View File

@@ -11,6 +11,8 @@ eframe = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
egui-wgpu = { workspace = true }
serde_json = { workspace = true }
serde = { workspace = true }
bytemuck = "1.22.0"
futures = "0.3.31"
reqwest = "0.12.15"

View File

@@ -4,13 +4,18 @@ use async_openai::{
ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageContent,
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage,
ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest,
ChatCompletionRequestUserMessageContent, ChatCompletionTool, ChatCompletionToolType,
CreateChatCompletionRequest, FunctionObject,
},
Client,
};
use futures::StreamExt;
use notedeck::AppContext;
use serde::Deserialize;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::mpsc::{self, Receiver};
use std::sync::Arc;
use avatar::DaveAvatar;
use egui::{Rect, Vec2};
@@ -54,19 +59,56 @@ impl Message {
}
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SearchContext {
Home,
Profile,
Any,
}
#[derive(Debug, Deserialize)]
pub struct SearchCall {
context: SearchContext,
query: String,
}
impl SearchCall {
pub fn parse(args: &str) -> Result<ToolCall, ToolCallError> {
match serde_json::from_str::<SearchCall>(args) {
Ok(call) => Ok(ToolCall::Search(call)),
Err(e) => Err(ToolCallError::ArgParseFailure(format!(
"Failed to parse args: '{}', error: {}",
args, e
))),
}
}
}
#[derive(Debug)]
pub enum ToolCall {
Search(SearchCall),
}
pub enum DaveResponse {
ToolCall(ToolCall),
Token(String),
}
pub struct Dave {
chat: Vec<Message>,
/// A 3d representation of dave.
avatar: Option<DaveAvatar>,
input: String,
pubkey: String,
tools: Arc<HashMap<String, Tool>>,
client: async_openai::Client<OpenAIConfig>,
incoming_tokens: Option<Receiver<String>>,
incoming_tokens: Option<Receiver<DaveResponse>>,
}
impl Dave {
pub fn new(render_state: Option<&RenderState>) -> Self {
let mut config = OpenAIConfig::new().with_api_base("http://ollama.jb55.com/v1");
let mut config = OpenAIConfig::new(); //.with_api_base("http://ollama.jb55.com/v1");
if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
config = config.with_api_key(api_key);
}
@@ -76,12 +118,17 @@ impl Dave {
let input = "".to_string();
let pubkey = "test_pubkey".to_string();
let avatar = render_state.map(DaveAvatar::new);
let mut tools: HashMap<String, Tool> = HashMap::new();
for tool in dave_tools() {
tools.insert(tool.name.to_string(), tool);
}
Dave {
client,
pubkey,
avatar,
incoming_tokens: None,
tools: Arc::new(tools),
input,
chat: vec![
Message::System("You are an ai agent for the nostr protocol. You have access to tools that can query the network, so you can help find content for users (TODO: actually implement this)".to_string()),
@@ -91,11 +138,17 @@ impl Dave {
fn render(&mut self, ui: &mut egui::Ui) {
if let Some(recvr) = &self.incoming_tokens {
if let Ok(token) = recvr.try_recv() {
match self.chat.last_mut() {
Some(Message::Assistant(msg)) => *msg = msg.clone() + &token,
Some(_) => self.chat.push(Message::Assistant(token)),
None => {}
while let Ok(res) = recvr.try_recv() {
match res {
DaveResponse::Token(token) => match self.chat.last_mut() {
Some(Message::Assistant(msg)) => *msg = msg.clone() + &token,
Some(_) => self.chat.push(Message::Assistant(token)),
None => {}
},
DaveResponse::ToolCall(tool) => {
tracing::info!("got tool call: {:?}", tool);
}
}
}
}
@@ -171,14 +224,16 @@ impl Dave {
self.incoming_tokens = Some(rx);
let ctx = ctx.clone();
let client = self.client.clone();
let tools = self.tools.clone();
tokio::spawn(async move {
let mut token_stream = match client
.chat()
.create_stream(CreateChatCompletionRequest {
model: "gpt-4o".to_string(),
//model: "gpt-4o".to_string(),
model: "llama3.1:latest".to_string(),
stream: Some(true),
messages,
tools: Some(dave_tools().iter().map(|t| t.to_api()).collect()),
user: Some(pubkey),
..Default::default()
})
@@ -192,7 +247,8 @@ impl Dave {
Ok(stream) => stream,
};
tracing::info!("got stream!");
let mut tool_call_name: Option<String> = None;
let mut tool_call_chunks: Vec<String> = vec![];
while let Some(token) = token_stream.next().await {
let token = match token {
@@ -202,16 +258,61 @@ impl Dave {
return;
}
};
let Some(choice) = token.choices.first() else {
return;
};
let Some(content) = &choice.delta.content else {
return;
};
tx.send(content.to_owned()).unwrap();
ctx.request_repaint();
for choice in &token.choices {
let resp = &choice.delta;
// if we have tool call arg chunks, collect them here
if let Some(tool_calls) = &resp.tool_calls {
for tool in tool_calls {
let Some(fcall) = &tool.function else {
continue;
};
if let Some(name) = &fcall.name {
tool_call_name = Some(name.clone());
}
let Some(argchunk) = &fcall.arguments else {
continue;
};
tool_call_chunks.push(argchunk.clone());
}
}
if let Some(content) = &resp.content {
tx.send(DaveResponse::Token(content.to_owned())).unwrap();
ctx.request_repaint();
}
}
}
if let Some(tool_name) = tool_call_name {
if !tool_call_chunks.is_empty() {
let args = tool_call_chunks.join("");
match parse_tool_call(&tools, &tool_name, &args) {
Ok(tool_call) => {
tx.send(DaveResponse::ToolCall(tool_call)).unwrap();
ctx.request_repaint();
}
Err(err) => {
tracing::error!(
"failed to parse tool call err({:?}): name({:?}) args({:?})",
err,
tool_name,
args,
);
// TODO: return error to user
}
};
} else {
// TODO: return error to user
tracing::error!("got tool call '{}' with no arguments?", tool_name);
}
}
tracing::debug!("stream closed");
});
}
}
@@ -228,3 +329,147 @@ impl notedeck::App for Dave {
self.render(ui);
}
}
#[derive(Debug, Clone)]
enum ArgType {
String,
Number,
Enum(Vec<&'static str>),
}
impl ArgType {
pub fn type_string(&self) -> &'static str {
match self {
Self::String => "string",
Self::Number => "number",
Self::Enum(_) => "string",
}
}
}
#[derive(Debug, Clone)]
struct ToolArg {
typ: ArgType,
name: &'static str,
required: bool,
description: &'static str,
}
#[derive(Debug, Clone)]
pub struct Tool {
parse_call: fn(&str) -> Result<ToolCall, ToolCallError>,
name: &'static str,
description: &'static str,
arguments: Vec<ToolArg>,
}
impl Tool {
pub fn to_function_object(&self) -> FunctionObject {
let required_args = self
.arguments
.iter()
.filter_map(|arg| {
if arg.required {
Some(Value::String(arg.name.to_owned()))
} else {
None
}
})
.collect();
let mut parameters: serde_json::Map<String, Value> = serde_json::Map::new();
parameters.insert("type".to_string(), Value::String("object".to_string()));
parameters.insert("required".to_string(), Value::Array(required_args));
parameters.insert("additionalProperties".to_string(), Value::Bool(false));
let mut properties: serde_json::Map<String, Value> = serde_json::Map::new();
for arg in &self.arguments {
let mut props: serde_json::Map<String, Value> = serde_json::Map::new();
props.insert(
"type".to_string(),
Value::String(arg.typ.type_string().to_string()),
);
props.insert(
"description".to_string(),
Value::String(arg.description.to_owned()),
);
if let ArgType::Enum(enums) = &arg.typ {
props.insert(
"enum".to_string(),
Value::Array(
enums
.into_iter()
.map(|s| Value::String((*s).to_owned()))
.collect(),
),
);
}
properties.insert(arg.name.to_owned(), Value::Object(props));
}
parameters.insert("properties".to_string(), Value::Object(properties));
FunctionObject {
name: self.name.to_owned(),
description: Some(self.description.to_owned()),
strict: Some(true),
parameters: Some(Value::Object(parameters)),
}
}
pub fn to_api(&self) -> ChatCompletionTool {
ChatCompletionTool {
r#type: ChatCompletionToolType::Function,
function: self.to_function_object(),
}
}
}
fn search_tool() -> Tool {
Tool {
name: "search",
parse_call: SearchCall::parse,
description: "Full-text search functionality. Used for finding individual notes with specific terms. Queries with multiple words will only return results with notes that have all of those words.",
arguments: vec![
ToolArg {
name: "query",
typ: ArgType::String,
required: true,
description: "The search query",
},
ToolArg {
name: "context",
typ: ArgType::Enum(vec!["home", "profile", "any"]),
required: true,
description: "The context in which the search is occuring. valid options are 'home', 'profile', 'any'",
}
]
}
}
#[derive(Debug)]
pub enum ToolCallError {
EmptyName,
EmptyArgs,
NotFound(String),
ArgParseFailure(String),
}
fn parse_tool_call(
tools: &HashMap<String, Tool>,
name: &str,
args: &str,
) -> Result<ToolCall, ToolCallError> {
let Some(tool) = tools.get(name) else {
return Err(ToolCallError::NotFound(name.to_owned()));
};
(tool.parse_call)(&args)
}
fn dave_tools() -> Vec<Tool> {
vec![search_tool()]
}