@@ -11,6 +11,8 @@ eframe = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
egui-wgpu = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
bytemuck = "1.22.0"
|
||||
futures = "0.3.31"
|
||||
reqwest = "0.12.15"
|
||||
|
||||
@@ -4,13 +4,18 @@ use async_openai::{
|
||||
ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageContent,
|
||||
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
|
||||
ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage,
|
||||
ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest,
|
||||
ChatCompletionRequestUserMessageContent, ChatCompletionTool, ChatCompletionToolType,
|
||||
CreateChatCompletionRequest, FunctionObject,
|
||||
},
|
||||
Client,
|
||||
};
|
||||
use futures::StreamExt;
|
||||
use notedeck::AppContext;
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::mpsc::{self, Receiver};
|
||||
use std::sync::Arc;
|
||||
|
||||
use avatar::DaveAvatar;
|
||||
use egui::{Rect, Vec2};
|
||||
@@ -54,19 +59,56 @@ impl Message {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum SearchContext {
|
||||
Home,
|
||||
Profile,
|
||||
Any,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SearchCall {
|
||||
context: SearchContext,
|
||||
query: String,
|
||||
}
|
||||
|
||||
impl SearchCall {
|
||||
pub fn parse(args: &str) -> Result<ToolCall, ToolCallError> {
|
||||
match serde_json::from_str::<SearchCall>(args) {
|
||||
Ok(call) => Ok(ToolCall::Search(call)),
|
||||
Err(e) => Err(ToolCallError::ArgParseFailure(format!(
|
||||
"Failed to parse args: '{}', error: {}",
|
||||
args, e
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ToolCall {
|
||||
Search(SearchCall),
|
||||
}
|
||||
|
||||
pub enum DaveResponse {
|
||||
ToolCall(ToolCall),
|
||||
Token(String),
|
||||
}
|
||||
|
||||
pub struct Dave {
|
||||
chat: Vec<Message>,
|
||||
/// A 3d representation of dave.
|
||||
avatar: Option<DaveAvatar>,
|
||||
input: String,
|
||||
pubkey: String,
|
||||
tools: Arc<HashMap<String, Tool>>,
|
||||
client: async_openai::Client<OpenAIConfig>,
|
||||
incoming_tokens: Option<Receiver<String>>,
|
||||
incoming_tokens: Option<Receiver<DaveResponse>>,
|
||||
}
|
||||
|
||||
impl Dave {
|
||||
pub fn new(render_state: Option<&RenderState>) -> Self {
|
||||
let mut config = OpenAIConfig::new().with_api_base("http://ollama.jb55.com/v1");
|
||||
let mut config = OpenAIConfig::new(); //.with_api_base("http://ollama.jb55.com/v1");
|
||||
if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
|
||||
config = config.with_api_key(api_key);
|
||||
}
|
||||
@@ -76,12 +118,17 @@ impl Dave {
|
||||
let input = "".to_string();
|
||||
let pubkey = "test_pubkey".to_string();
|
||||
let avatar = render_state.map(DaveAvatar::new);
|
||||
let mut tools: HashMap<String, Tool> = HashMap::new();
|
||||
for tool in dave_tools() {
|
||||
tools.insert(tool.name.to_string(), tool);
|
||||
}
|
||||
|
||||
Dave {
|
||||
client,
|
||||
pubkey,
|
||||
avatar,
|
||||
incoming_tokens: None,
|
||||
tools: Arc::new(tools),
|
||||
input,
|
||||
chat: vec![
|
||||
Message::System("You are an ai agent for the nostr protocol. You have access to tools that can query the network, so you can help find content for users (TODO: actually implement this)".to_string()),
|
||||
@@ -91,11 +138,17 @@ impl Dave {
|
||||
|
||||
fn render(&mut self, ui: &mut egui::Ui) {
|
||||
if let Some(recvr) = &self.incoming_tokens {
|
||||
if let Ok(token) = recvr.try_recv() {
|
||||
match self.chat.last_mut() {
|
||||
Some(Message::Assistant(msg)) => *msg = msg.clone() + &token,
|
||||
Some(_) => self.chat.push(Message::Assistant(token)),
|
||||
None => {}
|
||||
while let Ok(res) = recvr.try_recv() {
|
||||
match res {
|
||||
DaveResponse::Token(token) => match self.chat.last_mut() {
|
||||
Some(Message::Assistant(msg)) => *msg = msg.clone() + &token,
|
||||
Some(_) => self.chat.push(Message::Assistant(token)),
|
||||
None => {}
|
||||
},
|
||||
|
||||
DaveResponse::ToolCall(tool) => {
|
||||
tracing::info!("got tool call: {:?}", tool);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -171,14 +224,16 @@ impl Dave {
|
||||
self.incoming_tokens = Some(rx);
|
||||
let ctx = ctx.clone();
|
||||
let client = self.client.clone();
|
||||
let tools = self.tools.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut token_stream = match client
|
||||
.chat()
|
||||
.create_stream(CreateChatCompletionRequest {
|
||||
model: "gpt-4o".to_string(),
|
||||
//model: "gpt-4o".to_string(),
|
||||
model: "llama3.1:latest".to_string(),
|
||||
stream: Some(true),
|
||||
messages,
|
||||
tools: Some(dave_tools().iter().map(|t| t.to_api()).collect()),
|
||||
user: Some(pubkey),
|
||||
..Default::default()
|
||||
})
|
||||
@@ -192,7 +247,8 @@ impl Dave {
|
||||
Ok(stream) => stream,
|
||||
};
|
||||
|
||||
tracing::info!("got stream!");
|
||||
let mut tool_call_name: Option<String> = None;
|
||||
let mut tool_call_chunks: Vec<String> = vec![];
|
||||
|
||||
while let Some(token) = token_stream.next().await {
|
||||
let token = match token {
|
||||
@@ -202,16 +258,61 @@ impl Dave {
|
||||
return;
|
||||
}
|
||||
};
|
||||
let Some(choice) = token.choices.first() else {
|
||||
return;
|
||||
};
|
||||
let Some(content) = &choice.delta.content else {
|
||||
return;
|
||||
};
|
||||
|
||||
tx.send(content.to_owned()).unwrap();
|
||||
ctx.request_repaint();
|
||||
for choice in &token.choices {
|
||||
let resp = &choice.delta;
|
||||
|
||||
// if we have tool call arg chunks, collect them here
|
||||
if let Some(tool_calls) = &resp.tool_calls {
|
||||
for tool in tool_calls {
|
||||
let Some(fcall) = &tool.function else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if let Some(name) = &fcall.name {
|
||||
tool_call_name = Some(name.clone());
|
||||
}
|
||||
|
||||
let Some(argchunk) = &fcall.arguments else {
|
||||
continue;
|
||||
};
|
||||
|
||||
tool_call_chunks.push(argchunk.clone());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(content) = &resp.content {
|
||||
tx.send(DaveResponse::Token(content.to_owned())).unwrap();
|
||||
ctx.request_repaint();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(tool_name) = tool_call_name {
|
||||
if !tool_call_chunks.is_empty() {
|
||||
let args = tool_call_chunks.join("");
|
||||
match parse_tool_call(&tools, &tool_name, &args) {
|
||||
Ok(tool_call) => {
|
||||
tx.send(DaveResponse::ToolCall(tool_call)).unwrap();
|
||||
ctx.request_repaint();
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::error!(
|
||||
"failed to parse tool call err({:?}): name({:?}) args({:?})",
|
||||
err,
|
||||
tool_name,
|
||||
args,
|
||||
);
|
||||
// TODO: return error to user
|
||||
}
|
||||
};
|
||||
} else {
|
||||
// TODO: return error to user
|
||||
tracing::error!("got tool call '{}' with no arguments?", tool_name);
|
||||
}
|
||||
}
|
||||
|
||||
tracing::debug!("stream closed");
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -228,3 +329,147 @@ impl notedeck::App for Dave {
|
||||
self.render(ui);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum ArgType {
|
||||
String,
|
||||
Number,
|
||||
Enum(Vec<&'static str>),
|
||||
}
|
||||
|
||||
impl ArgType {
|
||||
pub fn type_string(&self) -> &'static str {
|
||||
match self {
|
||||
Self::String => "string",
|
||||
Self::Number => "number",
|
||||
Self::Enum(_) => "string",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct ToolArg {
|
||||
typ: ArgType,
|
||||
name: &'static str,
|
||||
required: bool,
|
||||
description: &'static str,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Tool {
|
||||
parse_call: fn(&str) -> Result<ToolCall, ToolCallError>,
|
||||
name: &'static str,
|
||||
description: &'static str,
|
||||
arguments: Vec<ToolArg>,
|
||||
}
|
||||
|
||||
impl Tool {
|
||||
pub fn to_function_object(&self) -> FunctionObject {
|
||||
let required_args = self
|
||||
.arguments
|
||||
.iter()
|
||||
.filter_map(|arg| {
|
||||
if arg.required {
|
||||
Some(Value::String(arg.name.to_owned()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut parameters: serde_json::Map<String, Value> = serde_json::Map::new();
|
||||
parameters.insert("type".to_string(), Value::String("object".to_string()));
|
||||
parameters.insert("required".to_string(), Value::Array(required_args));
|
||||
parameters.insert("additionalProperties".to_string(), Value::Bool(false));
|
||||
|
||||
let mut properties: serde_json::Map<String, Value> = serde_json::Map::new();
|
||||
|
||||
for arg in &self.arguments {
|
||||
let mut props: serde_json::Map<String, Value> = serde_json::Map::new();
|
||||
props.insert(
|
||||
"type".to_string(),
|
||||
Value::String(arg.typ.type_string().to_string()),
|
||||
);
|
||||
props.insert(
|
||||
"description".to_string(),
|
||||
Value::String(arg.description.to_owned()),
|
||||
);
|
||||
if let ArgType::Enum(enums) = &arg.typ {
|
||||
props.insert(
|
||||
"enum".to_string(),
|
||||
Value::Array(
|
||||
enums
|
||||
.into_iter()
|
||||
.map(|s| Value::String((*s).to_owned()))
|
||||
.collect(),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
properties.insert(arg.name.to_owned(), Value::Object(props));
|
||||
}
|
||||
|
||||
parameters.insert("properties".to_string(), Value::Object(properties));
|
||||
|
||||
FunctionObject {
|
||||
name: self.name.to_owned(),
|
||||
description: Some(self.description.to_owned()),
|
||||
strict: Some(true),
|
||||
parameters: Some(Value::Object(parameters)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_api(&self) -> ChatCompletionTool {
|
||||
ChatCompletionTool {
|
||||
r#type: ChatCompletionToolType::Function,
|
||||
function: self.to_function_object(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn search_tool() -> Tool {
|
||||
Tool {
|
||||
name: "search",
|
||||
parse_call: SearchCall::parse,
|
||||
description: "Full-text search functionality. Used for finding individual notes with specific terms. Queries with multiple words will only return results with notes that have all of those words.",
|
||||
arguments: vec![
|
||||
ToolArg {
|
||||
name: "query",
|
||||
typ: ArgType::String,
|
||||
required: true,
|
||||
description: "The search query",
|
||||
},
|
||||
|
||||
ToolArg {
|
||||
name: "context",
|
||||
typ: ArgType::Enum(vec!["home", "profile", "any"]),
|
||||
required: true,
|
||||
description: "The context in which the search is occuring. valid options are 'home', 'profile', 'any'",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ToolCallError {
|
||||
EmptyName,
|
||||
EmptyArgs,
|
||||
NotFound(String),
|
||||
ArgParseFailure(String),
|
||||
}
|
||||
|
||||
fn parse_tool_call(
|
||||
tools: &HashMap<String, Tool>,
|
||||
name: &str,
|
||||
args: &str,
|
||||
) -> Result<ToolCall, ToolCallError> {
|
||||
let Some(tool) = tools.get(name) else {
|
||||
return Err(ToolCallError::NotFound(name.to_owned()));
|
||||
};
|
||||
|
||||
(tool.parse_call)(&args)
|
||||
}
|
||||
|
||||
fn dave_tools() -> Vec<Tool> {
|
||||
vec![search_tool()]
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user