2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -3304,6 +3304,8 @@ dependencies = [
|
|||||||
"futures",
|
"futures",
|
||||||
"notedeck",
|
"notedeck",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ eframe = { workspace = true }
|
|||||||
tokio = { workspace = true }
|
tokio = { workspace = true }
|
||||||
tracing = { workspace = true }
|
tracing = { workspace = true }
|
||||||
egui-wgpu = { workspace = true }
|
egui-wgpu = { workspace = true }
|
||||||
|
serde_json = { workspace = true }
|
||||||
|
serde = { workspace = true }
|
||||||
bytemuck = "1.22.0"
|
bytemuck = "1.22.0"
|
||||||
futures = "0.3.31"
|
futures = "0.3.31"
|
||||||
reqwest = "0.12.15"
|
reqwest = "0.12.15"
|
||||||
|
|||||||
@@ -4,13 +4,18 @@ use async_openai::{
|
|||||||
ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageContent,
|
ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageContent,
|
||||||
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
|
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
|
||||||
ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage,
|
ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage,
|
||||||
ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest,
|
ChatCompletionRequestUserMessageContent, ChatCompletionTool, ChatCompletionToolType,
|
||||||
|
CreateChatCompletionRequest, FunctionObject,
|
||||||
},
|
},
|
||||||
Client,
|
Client,
|
||||||
};
|
};
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use notedeck::AppContext;
|
use notedeck::AppContext;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::sync::mpsc::{self, Receiver};
|
use std::sync::mpsc::{self, Receiver};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use avatar::DaveAvatar;
|
use avatar::DaveAvatar;
|
||||||
use egui::{Rect, Vec2};
|
use egui::{Rect, Vec2};
|
||||||
@@ -54,19 +59,56 @@ impl Message {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum SearchContext {
|
||||||
|
Home,
|
||||||
|
Profile,
|
||||||
|
Any,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct SearchCall {
|
||||||
|
context: SearchContext,
|
||||||
|
query: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SearchCall {
|
||||||
|
pub fn parse(args: &str) -> Result<ToolCall, ToolCallError> {
|
||||||
|
match serde_json::from_str::<SearchCall>(args) {
|
||||||
|
Ok(call) => Ok(ToolCall::Search(call)),
|
||||||
|
Err(e) => Err(ToolCallError::ArgParseFailure(format!(
|
||||||
|
"Failed to parse args: '{}', error: {}",
|
||||||
|
args, e
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum ToolCall {
|
||||||
|
Search(SearchCall),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum DaveResponse {
|
||||||
|
ToolCall(ToolCall),
|
||||||
|
Token(String),
|
||||||
|
}
|
||||||
|
|
||||||
pub struct Dave {
|
pub struct Dave {
|
||||||
chat: Vec<Message>,
|
chat: Vec<Message>,
|
||||||
/// A 3d representation of dave.
|
/// A 3d representation of dave.
|
||||||
avatar: Option<DaveAvatar>,
|
avatar: Option<DaveAvatar>,
|
||||||
input: String,
|
input: String,
|
||||||
pubkey: String,
|
pubkey: String,
|
||||||
|
tools: Arc<HashMap<String, Tool>>,
|
||||||
client: async_openai::Client<OpenAIConfig>,
|
client: async_openai::Client<OpenAIConfig>,
|
||||||
incoming_tokens: Option<Receiver<String>>,
|
incoming_tokens: Option<Receiver<DaveResponse>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Dave {
|
impl Dave {
|
||||||
pub fn new(render_state: Option<&RenderState>) -> Self {
|
pub fn new(render_state: Option<&RenderState>) -> Self {
|
||||||
let mut config = OpenAIConfig::new().with_api_base("http://ollama.jb55.com/v1");
|
let mut config = OpenAIConfig::new(); //.with_api_base("http://ollama.jb55.com/v1");
|
||||||
if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
|
if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
|
||||||
config = config.with_api_key(api_key);
|
config = config.with_api_key(api_key);
|
||||||
}
|
}
|
||||||
@@ -76,12 +118,17 @@ impl Dave {
|
|||||||
let input = "".to_string();
|
let input = "".to_string();
|
||||||
let pubkey = "test_pubkey".to_string();
|
let pubkey = "test_pubkey".to_string();
|
||||||
let avatar = render_state.map(DaveAvatar::new);
|
let avatar = render_state.map(DaveAvatar::new);
|
||||||
|
let mut tools: HashMap<String, Tool> = HashMap::new();
|
||||||
|
for tool in dave_tools() {
|
||||||
|
tools.insert(tool.name.to_string(), tool);
|
||||||
|
}
|
||||||
|
|
||||||
Dave {
|
Dave {
|
||||||
client,
|
client,
|
||||||
pubkey,
|
pubkey,
|
||||||
avatar,
|
avatar,
|
||||||
incoming_tokens: None,
|
incoming_tokens: None,
|
||||||
|
tools: Arc::new(tools),
|
||||||
input,
|
input,
|
||||||
chat: vec![
|
chat: vec![
|
||||||
Message::System("You are an ai agent for the nostr protocol. You have access to tools that can query the network, so you can help find content for users (TODO: actually implement this)".to_string()),
|
Message::System("You are an ai agent for the nostr protocol. You have access to tools that can query the network, so you can help find content for users (TODO: actually implement this)".to_string()),
|
||||||
@@ -91,11 +138,17 @@ impl Dave {
|
|||||||
|
|
||||||
fn render(&mut self, ui: &mut egui::Ui) {
|
fn render(&mut self, ui: &mut egui::Ui) {
|
||||||
if let Some(recvr) = &self.incoming_tokens {
|
if let Some(recvr) = &self.incoming_tokens {
|
||||||
if let Ok(token) = recvr.try_recv() {
|
while let Ok(res) = recvr.try_recv() {
|
||||||
match self.chat.last_mut() {
|
match res {
|
||||||
Some(Message::Assistant(msg)) => *msg = msg.clone() + &token,
|
DaveResponse::Token(token) => match self.chat.last_mut() {
|
||||||
Some(_) => self.chat.push(Message::Assistant(token)),
|
Some(Message::Assistant(msg)) => *msg = msg.clone() + &token,
|
||||||
None => {}
|
Some(_) => self.chat.push(Message::Assistant(token)),
|
||||||
|
None => {}
|
||||||
|
},
|
||||||
|
|
||||||
|
DaveResponse::ToolCall(tool) => {
|
||||||
|
tracing::info!("got tool call: {:?}", tool);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -171,14 +224,16 @@ impl Dave {
|
|||||||
self.incoming_tokens = Some(rx);
|
self.incoming_tokens = Some(rx);
|
||||||
let ctx = ctx.clone();
|
let ctx = ctx.clone();
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
|
let tools = self.tools.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let mut token_stream = match client
|
let mut token_stream = match client
|
||||||
.chat()
|
.chat()
|
||||||
.create_stream(CreateChatCompletionRequest {
|
.create_stream(CreateChatCompletionRequest {
|
||||||
|
model: "gpt-4o".to_string(),
|
||||||
//model: "gpt-4o".to_string(),
|
//model: "gpt-4o".to_string(),
|
||||||
model: "llama3.1:latest".to_string(),
|
|
||||||
stream: Some(true),
|
stream: Some(true),
|
||||||
messages,
|
messages,
|
||||||
|
tools: Some(dave_tools().iter().map(|t| t.to_api()).collect()),
|
||||||
user: Some(pubkey),
|
user: Some(pubkey),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
})
|
})
|
||||||
@@ -192,7 +247,8 @@ impl Dave {
|
|||||||
Ok(stream) => stream,
|
Ok(stream) => stream,
|
||||||
};
|
};
|
||||||
|
|
||||||
tracing::info!("got stream!");
|
let mut tool_call_name: Option<String> = None;
|
||||||
|
let mut tool_call_chunks: Vec<String> = vec![];
|
||||||
|
|
||||||
while let Some(token) = token_stream.next().await {
|
while let Some(token) = token_stream.next().await {
|
||||||
let token = match token {
|
let token = match token {
|
||||||
@@ -202,16 +258,61 @@ impl Dave {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let Some(choice) = token.choices.first() else {
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
let Some(content) = &choice.delta.content else {
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
|
|
||||||
tx.send(content.to_owned()).unwrap();
|
for choice in &token.choices {
|
||||||
ctx.request_repaint();
|
let resp = &choice.delta;
|
||||||
|
|
||||||
|
// if we have tool call arg chunks, collect them here
|
||||||
|
if let Some(tool_calls) = &resp.tool_calls {
|
||||||
|
for tool in tool_calls {
|
||||||
|
let Some(fcall) = &tool.function else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(name) = &fcall.name {
|
||||||
|
tool_call_name = Some(name.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(argchunk) = &fcall.arguments else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
tool_call_chunks.push(argchunk.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(content) = &resp.content {
|
||||||
|
tx.send(DaveResponse::Token(content.to_owned())).unwrap();
|
||||||
|
ctx.request_repaint();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(tool_name) = tool_call_name {
|
||||||
|
if !tool_call_chunks.is_empty() {
|
||||||
|
let args = tool_call_chunks.join("");
|
||||||
|
match parse_tool_call(&tools, &tool_name, &args) {
|
||||||
|
Ok(tool_call) => {
|
||||||
|
tx.send(DaveResponse::ToolCall(tool_call)).unwrap();
|
||||||
|
ctx.request_repaint();
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
tracing::error!(
|
||||||
|
"failed to parse tool call err({:?}): name({:?}) args({:?})",
|
||||||
|
err,
|
||||||
|
tool_name,
|
||||||
|
args,
|
||||||
|
);
|
||||||
|
// TODO: return error to user
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
// TODO: return error to user
|
||||||
|
tracing::error!("got tool call '{}' with no arguments?", tool_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::debug!("stream closed");
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -228,3 +329,147 @@ impl notedeck::App for Dave {
|
|||||||
self.render(ui);
|
self.render(ui);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
enum ArgType {
|
||||||
|
String,
|
||||||
|
Number,
|
||||||
|
Enum(Vec<&'static str>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ArgType {
|
||||||
|
pub fn type_string(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::String => "string",
|
||||||
|
Self::Number => "number",
|
||||||
|
Self::Enum(_) => "string",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct ToolArg {
|
||||||
|
typ: ArgType,
|
||||||
|
name: &'static str,
|
||||||
|
required: bool,
|
||||||
|
description: &'static str,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Tool {
|
||||||
|
parse_call: fn(&str) -> Result<ToolCall, ToolCallError>,
|
||||||
|
name: &'static str,
|
||||||
|
description: &'static str,
|
||||||
|
arguments: Vec<ToolArg>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Tool {
|
||||||
|
pub fn to_function_object(&self) -> FunctionObject {
|
||||||
|
let required_args = self
|
||||||
|
.arguments
|
||||||
|
.iter()
|
||||||
|
.filter_map(|arg| {
|
||||||
|
if arg.required {
|
||||||
|
Some(Value::String(arg.name.to_owned()))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut parameters: serde_json::Map<String, Value> = serde_json::Map::new();
|
||||||
|
parameters.insert("type".to_string(), Value::String("object".to_string()));
|
||||||
|
parameters.insert("required".to_string(), Value::Array(required_args));
|
||||||
|
parameters.insert("additionalProperties".to_string(), Value::Bool(false));
|
||||||
|
|
||||||
|
let mut properties: serde_json::Map<String, Value> = serde_json::Map::new();
|
||||||
|
|
||||||
|
for arg in &self.arguments {
|
||||||
|
let mut props: serde_json::Map<String, Value> = serde_json::Map::new();
|
||||||
|
props.insert(
|
||||||
|
"type".to_string(),
|
||||||
|
Value::String(arg.typ.type_string().to_string()),
|
||||||
|
);
|
||||||
|
props.insert(
|
||||||
|
"description".to_string(),
|
||||||
|
Value::String(arg.description.to_owned()),
|
||||||
|
);
|
||||||
|
if let ArgType::Enum(enums) = &arg.typ {
|
||||||
|
props.insert(
|
||||||
|
"enum".to_string(),
|
||||||
|
Value::Array(
|
||||||
|
enums
|
||||||
|
.into_iter()
|
||||||
|
.map(|s| Value::String((*s).to_owned()))
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
properties.insert(arg.name.to_owned(), Value::Object(props));
|
||||||
|
}
|
||||||
|
|
||||||
|
parameters.insert("properties".to_string(), Value::Object(properties));
|
||||||
|
|
||||||
|
FunctionObject {
|
||||||
|
name: self.name.to_owned(),
|
||||||
|
description: Some(self.description.to_owned()),
|
||||||
|
strict: Some(true),
|
||||||
|
parameters: Some(Value::Object(parameters)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_api(&self) -> ChatCompletionTool {
|
||||||
|
ChatCompletionTool {
|
||||||
|
r#type: ChatCompletionToolType::Function,
|
||||||
|
function: self.to_function_object(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn search_tool() -> Tool {
|
||||||
|
Tool {
|
||||||
|
name: "search",
|
||||||
|
parse_call: SearchCall::parse,
|
||||||
|
description: "Full-text search functionality. Used for finding individual notes with specific terms. Queries with multiple words will only return results with notes that have all of those words.",
|
||||||
|
arguments: vec![
|
||||||
|
ToolArg {
|
||||||
|
name: "query",
|
||||||
|
typ: ArgType::String,
|
||||||
|
required: true,
|
||||||
|
description: "The search query",
|
||||||
|
},
|
||||||
|
|
||||||
|
ToolArg {
|
||||||
|
name: "context",
|
||||||
|
typ: ArgType::Enum(vec!["home", "profile", "any"]),
|
||||||
|
required: true,
|
||||||
|
description: "The context in which the search is occuring. valid options are 'home', 'profile', 'any'",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum ToolCallError {
|
||||||
|
EmptyName,
|
||||||
|
EmptyArgs,
|
||||||
|
NotFound(String),
|
||||||
|
ArgParseFailure(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_tool_call(
|
||||||
|
tools: &HashMap<String, Tool>,
|
||||||
|
name: &str,
|
||||||
|
args: &str,
|
||||||
|
) -> Result<ToolCall, ToolCallError> {
|
||||||
|
let Some(tool) = tools.get(name) else {
|
||||||
|
return Err(ToolCallError::NotFound(name.to_owned()));
|
||||||
|
};
|
||||||
|
|
||||||
|
(tool.parse_call)(&args)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dave_tools() -> Vec<Tool> {
|
||||||
|
vec![search_tool()]
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user