feat: add multi-provider image generation
This commit is contained in:
@@ -26,6 +26,9 @@ reqwest = { version = "0.12", default-features = false, features = ["json", "mul
|
||||
thiserror = "2"
|
||||
async-trait = "0.1"
|
||||
base64 = "0.22"
|
||||
hex = "0.4"
|
||||
hmac = "0.12"
|
||||
sha2 = "0.10"
|
||||
|
||||
[features]
|
||||
default = ["custom-protocol"]
|
||||
|
||||
354
src-tauri/src/ai/dashscope.rs
Normal file
354
src-tauri/src/ai/dashscope.rs
Normal file
@@ -0,0 +1,354 @@
|
||||
use async_trait::async_trait;
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use reqwest::Client;
|
||||
use serde_json::{json, Value};
|
||||
use std::{fs, path::Path, time::Duration};
|
||||
use tokio::time::sleep;
|
||||
|
||||
use super::provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest, ImageResult};
|
||||
use crate::AppError;
|
||||
|
||||
pub struct DashScopeProvider {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
api_key: String,
|
||||
}
|
||||
|
||||
impl DashScopeProvider {
|
||||
pub fn new(base_url: String, api_key: String) -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_qwen_image_model(model: &str) -> bool {
|
||||
model.starts_with("qwen-image")
|
||||
}
|
||||
|
||||
fn is_sync_multimodal_model(model: &str) -> bool {
|
||||
is_qwen_image_model(model) || model.starts_with("z-image")
|
||||
}
|
||||
|
||||
fn dashscope_size(size: Option<&str>) -> Option<String> {
|
||||
size.map(|value| value.replace('x', "*"))
|
||||
}
|
||||
|
||||
fn mime_for_path(path: &str) -> &'static str {
|
||||
match Path::new(path)
|
||||
.extension()
|
||||
.and_then(|extension| extension.to_str())
|
||||
.map(|extension| extension.to_ascii_lowercase())
|
||||
.as_deref()
|
||||
{
|
||||
Some("jpg" | "jpeg") => "image/jpeg",
|
||||
Some("webp") => "image/webp",
|
||||
_ => "image/png",
|
||||
}
|
||||
}
|
||||
|
||||
fn path_to_data_url(path: &str) -> Result<String, AppError> {
|
||||
let bytes = fs::read(path)?;
|
||||
Ok(format!(
|
||||
"data:{};base64,{}",
|
||||
mime_for_path(path),
|
||||
general_purpose::STANDARD.encode(bytes)
|
||||
))
|
||||
}
|
||||
|
||||
fn parse_dashscope_error(response_text: &str, fallback: &str) -> String {
|
||||
serde_json::from_str::<Value>(response_text)
|
||||
.ok()
|
||||
.and_then(|value| {
|
||||
let code = value
|
||||
.get("code")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or_default();
|
||||
let message = value
|
||||
.get("message")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or_default();
|
||||
if code.is_empty() && message.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(format!("{code}: {message}"))
|
||||
}
|
||||
})
|
||||
.unwrap_or_else(|| fallback.to_string())
|
||||
}
|
||||
|
||||
fn image_url_from_choices(value: &Value) -> Option<String> {
|
||||
value
|
||||
.get("output")?
|
||||
.get("choices")?
|
||||
.as_array()?
|
||||
.iter()
|
||||
.flat_map(|choice| {
|
||||
choice
|
||||
.get("message")
|
||||
.and_then(|message| message.get("content"))
|
||||
.and_then(Value::as_array)
|
||||
.into_iter()
|
||||
.flatten()
|
||||
})
|
||||
.find_map(|content| {
|
||||
content
|
||||
.get("image")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
})
|
||||
}
|
||||
|
||||
fn image_url_from_results(value: &Value) -> Option<String> {
|
||||
value
|
||||
.get("output")?
|
||||
.get("results")?
|
||||
.as_array()?
|
||||
.iter()
|
||||
.find_map(|item| item.get("url").and_then(Value::as_str).map(str::to_string))
|
||||
}
|
||||
|
||||
fn parse_image_url(value: &Value) -> Option<String> {
|
||||
image_url_from_choices(value).or_else(|| image_url_from_results(value))
|
||||
}
|
||||
|
||||
fn build_messages(prompt: &str, image_paths: &[String]) -> Result<Vec<Value>, AppError> {
|
||||
let mut content = image_paths
|
||||
.iter()
|
||||
.map(|path| path_to_data_url(path).map(|image| json!({ "image": image })))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
content.push(json!({ "text": prompt }));
|
||||
|
||||
Ok(vec![json!({
|
||||
"role": "user",
|
||||
"content": content,
|
||||
})])
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AiProvider for DashScopeProvider {
|
||||
async fn generate_image(&self, request: ImageGenerateRequest) -> Result<ImageResult, AppError> {
|
||||
if is_sync_multimodal_model(&request.model) {
|
||||
return self
|
||||
.run_qwen_image(
|
||||
&request.prompt,
|
||||
&request.model,
|
||||
request.size.as_deref(),
|
||||
&[],
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
self.run_async_image_generation(
|
||||
&request.prompt,
|
||||
&request.model,
|
||||
request.size.as_deref(),
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn edit_image(&self, request: ImageEditRequest) -> Result<ImageResult, AppError> {
|
||||
if is_sync_multimodal_model(&request.model) {
|
||||
return self
|
||||
.run_qwen_image(
|
||||
&request.prompt,
|
||||
&request.model,
|
||||
request.size.as_deref(),
|
||||
&request.image_paths,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
self.run_async_image_generation(
|
||||
&request.prompt,
|
||||
&request.model,
|
||||
request.size.as_deref(),
|
||||
&request.image_paths,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl DashScopeProvider {
|
||||
async fn run_qwen_image(
|
||||
&self,
|
||||
prompt: &str,
|
||||
model: &str,
|
||||
size: Option<&str>,
|
||||
image_paths: &[String],
|
||||
) -> Result<ImageResult, AppError> {
|
||||
let mut parameters = json!({
|
||||
"n": 1,
|
||||
"watermark": false,
|
||||
"prompt_extend": true,
|
||||
});
|
||||
if is_qwen_image_model(model) {
|
||||
parameters["negative_prompt"] = json!(" ");
|
||||
}
|
||||
if let Some(size) = dashscope_size(size) {
|
||||
parameters["size"] = json!(size);
|
||||
}
|
||||
|
||||
let body = json!({
|
||||
"model": model,
|
||||
"input": {
|
||||
"messages": build_messages(prompt, image_paths)?,
|
||||
},
|
||||
"parameters": parameters,
|
||||
});
|
||||
let response = self
|
||||
.client
|
||||
.post(format!(
|
||||
"{}/services/aigc/multimodal-generation/generation",
|
||||
self.base_url
|
||||
))
|
||||
.bearer_auth(&self.api_key)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
let status = response.status();
|
||||
let response_text = response.text().await?;
|
||||
if !status.is_success() {
|
||||
return Err(AppError::Provider(format!(
|
||||
"DashScope image generation failed ({status}): {}",
|
||||
parse_dashscope_error(&response_text, &response_text)
|
||||
)));
|
||||
}
|
||||
|
||||
let response_body = serde_json::from_str::<Value>(&response_text).map_err(|error| {
|
||||
AppError::Provider(format!(
|
||||
"failed to decode DashScope response: {error}; response body: {response_text}"
|
||||
))
|
||||
})?;
|
||||
let image_url = parse_image_url(&response_body).ok_or_else(|| {
|
||||
AppError::Provider(format!(
|
||||
"DashScope response did not include image url: {response_text}"
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(ImageResult {
|
||||
mime_type: "image/png".to_string(),
|
||||
data: ImageData::Url(image_url),
|
||||
})
|
||||
}
|
||||
|
||||
async fn run_async_image_generation(
|
||||
&self,
|
||||
prompt: &str,
|
||||
model: &str,
|
||||
size: Option<&str>,
|
||||
image_paths: &[String],
|
||||
) -> Result<ImageResult, AppError> {
|
||||
let mut parameters = json!({
|
||||
"n": 1,
|
||||
"watermark": false,
|
||||
});
|
||||
if let Some(size) = dashscope_size(size) {
|
||||
parameters["size"] = json!(size);
|
||||
}
|
||||
if model.starts_with("wan2.7-image") {
|
||||
parameters["thinking_mode"] = json!(true);
|
||||
} else {
|
||||
parameters["prompt_extend"] = json!(true);
|
||||
}
|
||||
|
||||
let body = json!({
|
||||
"model": model,
|
||||
"input": {
|
||||
"messages": build_messages(prompt, image_paths)?,
|
||||
},
|
||||
"parameters": parameters,
|
||||
});
|
||||
let response = self
|
||||
.client
|
||||
.post(format!(
|
||||
"{}/services/aigc/image-generation/generation",
|
||||
self.base_url
|
||||
))
|
||||
.bearer_auth(&self.api_key)
|
||||
.header("X-DashScope-Async", "enable")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
let status = response.status();
|
||||
let response_text = response.text().await?;
|
||||
if !status.is_success() {
|
||||
return Err(AppError::Provider(format!(
|
||||
"DashScope async image task failed ({status}): {}",
|
||||
parse_dashscope_error(&response_text, &response_text)
|
||||
)));
|
||||
}
|
||||
|
||||
let response_body = serde_json::from_str::<Value>(&response_text).map_err(|error| {
|
||||
AppError::Provider(format!(
|
||||
"failed to decode DashScope task response: {error}; response body: {response_text}"
|
||||
))
|
||||
})?;
|
||||
let task_id = response_body
|
||||
.get("output")
|
||||
.and_then(|output| output.get("task_id"))
|
||||
.and_then(Value::as_str)
|
||||
.ok_or_else(|| {
|
||||
AppError::Provider(format!(
|
||||
"DashScope task response did not include task_id: {response_text}"
|
||||
))
|
||||
})?;
|
||||
|
||||
self.wait_async_task(task_id).await
|
||||
}
|
||||
|
||||
async fn wait_async_task(&self, task_id: &str) -> Result<ImageResult, AppError> {
|
||||
for _ in 0..40 {
|
||||
sleep(Duration::from_secs(3)).await;
|
||||
let response = self
|
||||
.client
|
||||
.get(format!("{}/tasks/{task_id}", self.base_url))
|
||||
.bearer_auth(&self.api_key)
|
||||
.send()
|
||||
.await?;
|
||||
let status = response.status();
|
||||
let response_text = response.text().await?;
|
||||
if !status.is_success() {
|
||||
return Err(AppError::Provider(format!(
|
||||
"DashScope task polling failed ({status}): {}",
|
||||
parse_dashscope_error(&response_text, &response_text)
|
||||
)));
|
||||
}
|
||||
|
||||
let response_body = serde_json::from_str::<Value>(&response_text).map_err(|error| {
|
||||
AppError::Provider(format!(
|
||||
"failed to decode DashScope task result: {error}; response body: {response_text}"
|
||||
))
|
||||
})?;
|
||||
let task_status = response_body
|
||||
.get("output")
|
||||
.and_then(|output| output.get("task_status"))
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or_default();
|
||||
if task_status == "SUCCEEDED" {
|
||||
let image_url = parse_image_url(&response_body).ok_or_else(|| {
|
||||
AppError::Provider(format!(
|
||||
"DashScope task succeeded but did not include image url: {response_text}"
|
||||
))
|
||||
})?;
|
||||
return Ok(ImageResult {
|
||||
mime_type: "image/png".to_string(),
|
||||
data: ImageData::Url(image_url),
|
||||
});
|
||||
}
|
||||
if matches!(task_status, "FAILED" | "CANCELED" | "UNKNOWN") {
|
||||
return Err(AppError::Provider(format!(
|
||||
"DashScope task ended with status {task_status}: {}",
|
||||
parse_dashscope_error(&response_text, &response_text)
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Err(AppError::Provider(
|
||||
"DashScope image task polling timed out".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
189
src-tauri/src/ai/google_gemini.rs
Normal file
189
src-tauri/src/ai/google_gemini.rs
Normal file
@@ -0,0 +1,189 @@
|
||||
use async_trait::async_trait;
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use reqwest::Client;
|
||||
use serde_json::{json, Value};
|
||||
use std::{fs, path::Path};
|
||||
|
||||
use super::provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest, ImageResult};
|
||||
use crate::AppError;
|
||||
|
||||
pub struct GoogleGeminiProvider {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
api_key: String,
|
||||
}
|
||||
|
||||
impl GoogleGeminiProvider {
|
||||
pub fn new(base_url: String, api_key: String) -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn mime_for_path(path: &str) -> &'static str {
|
||||
match Path::new(path)
|
||||
.extension()
|
||||
.and_then(|extension| extension.to_str())
|
||||
.map(|extension| extension.to_ascii_lowercase())
|
||||
.as_deref()
|
||||
{
|
||||
Some("jpg" | "jpeg") => "image/jpeg",
|
||||
Some("webp") => "image/webp",
|
||||
_ => "image/png",
|
||||
}
|
||||
}
|
||||
|
||||
fn size_to_aspect_ratio(size: Option<&str>) -> Option<&'static str> {
|
||||
let (width, height) = size?.split_once('x')?;
|
||||
let width = width.parse::<f32>().ok()?;
|
||||
let height = height.parse::<f32>().ok()?;
|
||||
if width <= 0.0 || height <= 0.0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let ratio = width / height;
|
||||
if (ratio - 1.0).abs() < 0.08 {
|
||||
Some("1:1")
|
||||
} else if (ratio - 16.0 / 9.0).abs() < 0.12 {
|
||||
Some("16:9")
|
||||
} else if (ratio - 9.0 / 16.0).abs() < 0.12 {
|
||||
Some("9:16")
|
||||
} else if (ratio - 4.0 / 3.0).abs() < 0.12 {
|
||||
Some("4:3")
|
||||
} else if (ratio - 3.0 / 4.0).abs() < 0.12 {
|
||||
Some("3:4")
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn image_part(path: &str) -> Result<Value, AppError> {
|
||||
let bytes = fs::read(path)?;
|
||||
Ok(json!({
|
||||
"inline_data": {
|
||||
"mime_type": mime_for_path(path),
|
||||
"data": general_purpose::STANDARD.encode(bytes),
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
fn inline_data_from_part(part: &Value) -> Option<(&str, &str)> {
|
||||
let inline_data = part.get("inlineData").or_else(|| part.get("inline_data"))?;
|
||||
let data = inline_data.get("data").and_then(Value::as_str)?;
|
||||
let mime_type = inline_data
|
||||
.get("mimeType")
|
||||
.or_else(|| inline_data.get("mime_type"))
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or("image/png");
|
||||
Some((mime_type, data))
|
||||
}
|
||||
|
||||
fn parse_gemini_response(response_text: &str) -> Result<ImageResult, AppError> {
|
||||
let response_body = serde_json::from_str::<Value>(response_text).map_err(|error| {
|
||||
AppError::Provider(format!(
|
||||
"failed to decode Gemini image response: {error}; response body: {response_text}"
|
||||
))
|
||||
})?;
|
||||
let parts = response_body
|
||||
.get("candidates")
|
||||
.and_then(Value::as_array)
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.flat_map(|candidate| {
|
||||
candidate
|
||||
.get("content")
|
||||
.and_then(|content| content.get("parts"))
|
||||
.and_then(Value::as_array)
|
||||
.into_iter()
|
||||
.flatten()
|
||||
});
|
||||
|
||||
for part in parts {
|
||||
if let Some((mime_type, data)) = inline_data_from_part(part) {
|
||||
return Ok(ImageResult {
|
||||
mime_type: mime_type.to_string(),
|
||||
data: ImageData::Base64(data.to_string()),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Err(AppError::Provider(format!(
|
||||
"Gemini response did not include image data: {response_text}"
|
||||
)))
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AiProvider for GoogleGeminiProvider {
|
||||
async fn generate_image(&self, request: ImageGenerateRequest) -> Result<ImageResult, AppError> {
|
||||
self.generate_with_parts(
|
||||
&request.prompt,
|
||||
&request.model,
|
||||
request.size.as_deref(),
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn edit_image(&self, request: ImageEditRequest) -> Result<ImageResult, AppError> {
|
||||
self.generate_with_parts(
|
||||
&request.prompt,
|
||||
&request.model,
|
||||
request.size.as_deref(),
|
||||
&request.image_paths,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl GoogleGeminiProvider {
|
||||
async fn generate_with_parts(
|
||||
&self,
|
||||
prompt: &str,
|
||||
model: &str,
|
||||
size: Option<&str>,
|
||||
image_paths: &[String],
|
||||
) -> Result<ImageResult, AppError> {
|
||||
let mut parts = vec![json!({ "text": prompt })];
|
||||
parts.extend(
|
||||
image_paths
|
||||
.iter()
|
||||
.map(|path| image_part(path))
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
);
|
||||
let mut generation_config = json!({
|
||||
"responseModalities": ["TEXT", "IMAGE"],
|
||||
});
|
||||
if let Some(aspect_ratio) = size_to_aspect_ratio(size) {
|
||||
generation_config["imageConfig"] = json!({
|
||||
"aspectRatio": aspect_ratio,
|
||||
});
|
||||
}
|
||||
|
||||
let body = json!({
|
||||
"contents": [{
|
||||
"role": "user",
|
||||
"parts": parts,
|
||||
}],
|
||||
"generationConfig": generation_config,
|
||||
});
|
||||
let response = self
|
||||
.client
|
||||
.post(format!("{}/models/{model}:generateContent", self.base_url))
|
||||
.header("x-goog-api-key", &self.api_key)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
let status = response.status();
|
||||
let response_text = response.text().await?;
|
||||
if !status.is_success() {
|
||||
return Err(AppError::Provider(format!(
|
||||
"Gemini image generation failed ({status}): {response_text}"
|
||||
)));
|
||||
}
|
||||
|
||||
parse_gemini_response(&response_text)
|
||||
}
|
||||
}
|
||||
@@ -1,2 +1,6 @@
|
||||
pub mod dashscope;
|
||||
pub mod google_gemini;
|
||||
pub mod openai_compatible;
|
||||
pub mod provider;
|
||||
pub mod seedream;
|
||||
pub mod tencent_hunyuan;
|
||||
|
||||
@@ -32,11 +32,12 @@ fn parse_image_response(response_text: &str) -> Result<ImageResult, AppError> {
|
||||
));
|
||||
}
|
||||
|
||||
let response_body: ImageResponseBody = serde_json::from_str(response_text).map_err(|error| {
|
||||
AppError::Provider(format!(
|
||||
"failed to decode image response: {error}; response body: {response_text}"
|
||||
))
|
||||
})?;
|
||||
let response_body: ImageResponseBody =
|
||||
serde_json::from_str(response_text).map_err(|error| {
|
||||
AppError::Provider(format!(
|
||||
"failed to decode image response: {error}; response body: {response_text}"
|
||||
))
|
||||
})?;
|
||||
let image_data = response_body
|
||||
.data
|
||||
.into_iter()
|
||||
@@ -45,7 +46,9 @@ fn parse_image_response(response_text: &str) -> Result<ImageResult, AppError> {
|
||||
.map(ImageData::Base64)
|
||||
.or_else(|| item.url.map(ImageData::Url))
|
||||
})
|
||||
.ok_or_else(|| AppError::Provider("image response did not include b64_json or url".to_string()))?;
|
||||
.ok_or_else(|| {
|
||||
AppError::Provider("image response did not include b64_json or url".to_string())
|
||||
})?;
|
||||
|
||||
Ok(ImageResult {
|
||||
mime_type: "image/png".to_string(),
|
||||
@@ -96,8 +99,13 @@ impl AiProvider for OpenAiCompatibleProvider {
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let message = response.text().await.unwrap_or_else(|_| "request failed".to_string());
|
||||
return Err(AppError::Provider(format!("image generation failed ({status}): {message}")));
|
||||
let message = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "request failed".to_string());
|
||||
return Err(AppError::Provider(format!(
|
||||
"image generation failed ({status}): {message}"
|
||||
)));
|
||||
}
|
||||
|
||||
let response_text = response.text().await?;
|
||||
@@ -134,7 +142,9 @@ impl AiProvider for OpenAiCompatibleProvider {
|
||||
Some("webp") => "image/webp",
|
||||
_ => "image/png",
|
||||
};
|
||||
let part = multipart::Part::bytes(bytes).file_name(file_name).mime_str(mime)?;
|
||||
let part = multipart::Part::bytes(bytes)
|
||||
.file_name(file_name)
|
||||
.mime_str(mime)?;
|
||||
form = form.part("image[]", part);
|
||||
}
|
||||
|
||||
@@ -148,8 +158,13 @@ impl AiProvider for OpenAiCompatibleProvider {
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let message = response.text().await.unwrap_or_else(|_| "request failed".to_string());
|
||||
return Err(AppError::Provider(format!("image edit failed ({status}): {message}")));
|
||||
let message = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "request failed".to_string());
|
||||
return Err(AppError::Provider(format!(
|
||||
"image edit failed ({status}): {message}"
|
||||
)));
|
||||
}
|
||||
|
||||
let response_text = response.text().await?;
|
||||
|
||||
180
src-tauri/src/ai/seedream.rs
Normal file
180
src-tauri/src/ai/seedream.rs
Normal file
@@ -0,0 +1,180 @@
|
||||
use async_trait::async_trait;
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{fs, path::Path};
|
||||
|
||||
use super::provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest, ImageResult};
|
||||
use crate::AppError;
|
||||
|
||||
pub struct SeedreamProvider {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
api_key: String,
|
||||
}
|
||||
|
||||
impl SeedreamProvider {
|
||||
pub fn new(base_url: String, api_key: String) -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct SeedreamRequestBody<'a> {
|
||||
model: &'a str,
|
||||
prompt: &'a str,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
size: Option<&'a str>,
|
||||
response_format: &'a str,
|
||||
stream: bool,
|
||||
watermark: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
image: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SeedreamResponseBody {
|
||||
data: Vec<SeedreamResponseItem>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SeedreamResponseItem {
|
||||
b64_json: Option<String>,
|
||||
url: Option<String>,
|
||||
}
|
||||
|
||||
fn mime_for_path(path: &str) -> &'static str {
|
||||
match Path::new(path)
|
||||
.extension()
|
||||
.and_then(|extension| extension.to_str())
|
||||
.map(|extension| extension.to_ascii_lowercase())
|
||||
.as_deref()
|
||||
{
|
||||
Some("jpg" | "jpeg") => "image/jpeg",
|
||||
Some("webp") => "image/webp",
|
||||
_ => "image/png",
|
||||
}
|
||||
}
|
||||
|
||||
fn path_to_data_url(path: &str) -> Result<String, AppError> {
|
||||
let bytes = fs::read(path)?;
|
||||
Ok(format!(
|
||||
"data:{};base64,{}",
|
||||
mime_for_path(path),
|
||||
general_purpose::STANDARD.encode(bytes)
|
||||
))
|
||||
}
|
||||
|
||||
fn parse_seedream_response(response_text: &str) -> Result<ImageResult, AppError> {
|
||||
if response_text.trim_start().starts_with("<!doctype html")
|
||||
|| response_text.trim_start().starts_with("<html")
|
||||
{
|
||||
return Err(AppError::Provider(
|
||||
"火山方舟返回了 HTML 页面,不是 API JSON 响应。请检查 Base URL 是否为 API 地址。"
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let response_body: SeedreamResponseBody =
|
||||
serde_json::from_str(response_text).map_err(|error| {
|
||||
AppError::Provider(format!(
|
||||
"failed to decode Seedream response: {error}; response body: {response_text}"
|
||||
))
|
||||
})?;
|
||||
let image_data = response_body
|
||||
.data
|
||||
.into_iter()
|
||||
.find_map(|item| {
|
||||
item.b64_json
|
||||
.map(ImageData::Base64)
|
||||
.or_else(|| item.url.map(ImageData::Url))
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
AppError::Provider("Seedream response did not include b64_json or url".to_string())
|
||||
})?;
|
||||
|
||||
Ok(ImageResult {
|
||||
mime_type: "image/png".to_string(),
|
||||
data: image_data,
|
||||
})
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AiProvider for SeedreamProvider {
|
||||
async fn generate_image(&self, request: ImageGenerateRequest) -> Result<ImageResult, AppError> {
|
||||
let body = SeedreamRequestBody {
|
||||
model: &request.model,
|
||||
prompt: &request.prompt,
|
||||
size: request.size.as_deref(),
|
||||
response_format: "b64_json",
|
||||
stream: false,
|
||||
watermark: false,
|
||||
image: None,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(format!("{}/images/generations", self.base_url))
|
||||
.bearer_auth(&self.api_key)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let message = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "request failed".to_string());
|
||||
return Err(AppError::Provider(format!(
|
||||
"Seedream image generation failed ({status}): {message}"
|
||||
)));
|
||||
}
|
||||
|
||||
let response_text = response.text().await?;
|
||||
parse_seedream_response(&response_text)
|
||||
}
|
||||
|
||||
async fn edit_image(&self, request: ImageEditRequest) -> Result<ImageResult, AppError> {
|
||||
let images = request
|
||||
.image_paths
|
||||
.iter()
|
||||
.map(|path| path_to_data_url(path))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let body = SeedreamRequestBody {
|
||||
model: &request.model,
|
||||
prompt: &request.prompt,
|
||||
size: request.size.as_deref(),
|
||||
response_format: "b64_json",
|
||||
stream: false,
|
||||
watermark: false,
|
||||
image: Some(images),
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(format!("{}/images/generations", self.base_url))
|
||||
.bearer_auth(&self.api_key)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let message = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "request failed".to_string());
|
||||
return Err(AppError::Provider(format!(
|
||||
"Seedream image edit failed ({status}): {message}"
|
||||
)));
|
||||
}
|
||||
|
||||
let response_text = response.text().await?;
|
||||
parse_seedream_response(&response_text)
|
||||
}
|
||||
}
|
||||
463
src-tauri/src/ai/tencent_hunyuan.rs
Normal file
463
src-tauri/src/ai/tencent_hunyuan.rs
Normal file
@@ -0,0 +1,463 @@
|
||||
use async_trait::async_trait;
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use chrono::{TimeZone, Utc};
|
||||
use hmac::{Hmac, Mac};
|
||||
use reqwest::{Client, Url};
|
||||
use serde_json::{json, Value};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::{fs, time::Duration};
|
||||
use tokio::time::sleep;
|
||||
|
||||
use super::provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest, ImageResult};
|
||||
use crate::AppError;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
pub struct TencentHunyuanProvider {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
secret_id: String,
|
||||
secret_key: String,
|
||||
}
|
||||
|
||||
struct TencentApiConfig {
|
||||
service: &'static str,
|
||||
version: &'static str,
|
||||
lite_action: &'static str,
|
||||
rapid_action: &'static str,
|
||||
submit_action: &'static str,
|
||||
query_action: &'static str,
|
||||
}
|
||||
|
||||
impl TencentHunyuanProvider {
|
||||
pub fn new(base_url: String, api_key: String) -> Result<Self, AppError> {
|
||||
let (secret_id, secret_key) = parse_secret_pair(&api_key)?;
|
||||
Ok(Self {
|
||||
client: Client::new(),
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
secret_id,
|
||||
secret_key,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_secret_pair(api_key: &str) -> Result<(String, String), AppError> {
|
||||
let (secret_id, secret_key) = api_key.split_once(':').ok_or_else(|| {
|
||||
AppError::Provider("腾讯云 API Key 需要填写为 SecretId:SecretKey".to_string())
|
||||
})?;
|
||||
let secret_id = secret_id.trim();
|
||||
let secret_key = secret_key.trim();
|
||||
if secret_id.is_empty() || secret_key.is_empty() {
|
||||
return Err(AppError::Provider(
|
||||
"腾讯云 API Key 需要填写为 SecretId:SecretKey".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok((secret_id.to_string(), secret_key.to_string()))
|
||||
}
|
||||
|
||||
fn hmac_sha256(key: &[u8], message: &str) -> Result<Vec<u8>, AppError> {
|
||||
let mut mac = HmacSha256::new_from_slice(key)
|
||||
.map_err(|error| AppError::Provider(format!("failed to create HMAC: {error}")))?;
|
||||
mac.update(message.as_bytes());
|
||||
Ok(mac.finalize().into_bytes().to_vec())
|
||||
}
|
||||
|
||||
fn sha256_hex(value: &str) -> String {
|
||||
hex::encode(Sha256::digest(value.as_bytes()))
|
||||
}
|
||||
|
||||
fn host_from_base_url(base_url: &str) -> Result<String, AppError> {
|
||||
let url = Url::parse(base_url)
|
||||
.map_err(|error| AppError::Provider(format!("腾讯云 Base URL 不是有效 URL: {error}")))?;
|
||||
url.host_str()
|
||||
.map(str::to_string)
|
||||
.ok_or_else(|| AppError::Provider("腾讯云 Base URL 缺少 host".to_string()))
|
||||
}
|
||||
|
||||
fn api_config(host: &str) -> TencentApiConfig {
|
||||
if host == "aiart.tencentcloudapi.com" {
|
||||
TencentApiConfig {
|
||||
service: "aiart",
|
||||
version: "2022-12-29",
|
||||
lite_action: "TextToImageLite",
|
||||
rapid_action: "TextToImageRapid",
|
||||
submit_action: "SubmitTextToImageJob",
|
||||
query_action: "QueryTextToImageJob",
|
||||
}
|
||||
} else {
|
||||
TencentApiConfig {
|
||||
service: "hunyuan",
|
||||
version: "2023-09-01",
|
||||
lite_action: "TextToImageLite",
|
||||
rapid_action: "TextToImageLite",
|
||||
submit_action: "SubmitHunyuanImageJob",
|
||||
query_action: "QueryHunyuanImageJob",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn hunyuan_resolution(size: Option<&str>, lite: bool, has_reference: bool) -> Option<String> {
|
||||
let size = size?;
|
||||
let (width, height) = size.split_once('x')?;
|
||||
let width = width.parse::<f32>().ok()?;
|
||||
let height = height.parse::<f32>().ok()?;
|
||||
if width <= 0.0 || height <= 0.0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let ratio = width / height;
|
||||
let value = if (ratio - 1.0).abs() < 0.12 {
|
||||
"1024:1024"
|
||||
} else if (ratio - 0.75).abs() < 0.12 {
|
||||
"768:1024"
|
||||
} else if (ratio - 1.333).abs() < 0.12 {
|
||||
"1024:768"
|
||||
} else if ratio < 0.65 {
|
||||
if lite && !has_reference {
|
||||
"1080:1920"
|
||||
} else {
|
||||
"720:1280"
|
||||
}
|
||||
} else if ratio > 1.55 {
|
||||
if lite && !has_reference {
|
||||
"1920:1080"
|
||||
} else {
|
||||
"1280:720"
|
||||
}
|
||||
} else if width < height {
|
||||
"768:1024"
|
||||
} else {
|
||||
"1024:768"
|
||||
};
|
||||
|
||||
if has_reference && !matches!(value, "1024:1024" | "768:1024" | "1024:768") {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(value.to_string())
|
||||
}
|
||||
|
||||
fn first_image_base64(image_paths: &[String]) -> Result<Option<String>, AppError> {
|
||||
image_paths
|
||||
.first()
|
||||
.map(|path| fs::read(path).map(|bytes| general_purpose::STANDARD.encode(bytes)))
|
||||
.transpose()
|
||||
.map_err(AppError::from)
|
||||
}
|
||||
|
||||
fn parse_tencent_error(response: &Value) -> Option<String> {
|
||||
let error = response.get("Error")?;
|
||||
let code = error
|
||||
.get("Code")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or_default();
|
||||
let message = error
|
||||
.get("Message")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or_default();
|
||||
Some(format!("{code}: {message}"))
|
||||
}
|
||||
|
||||
fn result_image_url(response: &Value) -> Option<String> {
|
||||
let result = response.get("ResultImage")?;
|
||||
if let Some(url) = result.as_str().filter(|value| value.starts_with("http")) {
|
||||
return Some(url.to_string());
|
||||
}
|
||||
|
||||
result.as_array()?.first().and_then(|image| {
|
||||
image
|
||||
.as_str()
|
||||
.filter(|value| value.starts_with("http"))
|
||||
.map(str::to_string)
|
||||
.or_else(|| {
|
||||
image
|
||||
.get("Url")
|
||||
.or_else(|| image.get("URL"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn result_image_base64(response: &Value) -> Option<String> {
|
||||
let result = response.get("ResultImage")?;
|
||||
if let Some(value) = result.as_str().filter(|value| !value.starts_with("http")) {
|
||||
return Some(value.to_string());
|
||||
}
|
||||
|
||||
result.as_array()?.first().and_then(|image| {
|
||||
image
|
||||
.as_str()
|
||||
.filter(|value| !value.starts_with("http"))
|
||||
.map(str::to_string)
|
||||
.or_else(|| {
|
||||
image
|
||||
.get("Base64")
|
||||
.or_else(|| image.get("ImageBase64"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AiProvider for TencentHunyuanProvider {
|
||||
async fn generate_image(&self, request: ImageGenerateRequest) -> Result<ImageResult, AppError> {
|
||||
if request.model == "hunyuan-image-lite" {
|
||||
return self
|
||||
.run_text_to_image_lite(&request.prompt, request.size.as_deref())
|
||||
.await;
|
||||
}
|
||||
if request.model == "hunyuan-image-2.0" {
|
||||
return self
|
||||
.run_text_to_image_rapid(&request.prompt, request.size.as_deref(), &[])
|
||||
.await;
|
||||
}
|
||||
|
||||
self.run_async_image_job(&request.prompt, request.size.as_deref(), &[])
|
||||
.await
|
||||
}
|
||||
|
||||
async fn edit_image(&self, request: ImageEditRequest) -> Result<ImageResult, AppError> {
|
||||
if request.model == "hunyuan-image-2.0" {
|
||||
return self
|
||||
.run_text_to_image_rapid(
|
||||
&request.prompt,
|
||||
request.size.as_deref(),
|
||||
&request.image_paths,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
self.run_async_image_job(
|
||||
&request.prompt,
|
||||
request.size.as_deref(),
|
||||
&request.image_paths,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl TencentHunyuanProvider {
|
||||
async fn call(&self, action: &str, payload: Value) -> Result<Value, AppError> {
|
||||
let region = "ap-guangzhou";
|
||||
let host = host_from_base_url(&self.base_url)?;
|
||||
let config = api_config(&host);
|
||||
let timestamp = Utc::now().timestamp();
|
||||
let date = Utc
|
||||
.timestamp_opt(timestamp, 0)
|
||||
.single()
|
||||
.ok_or_else(|| AppError::Provider("invalid Tencent Cloud timestamp".to_string()))?
|
||||
.format("%Y-%m-%d")
|
||||
.to_string();
|
||||
let payload_text = serde_json::to_string(&payload).map_err(|error| {
|
||||
AppError::Provider(format!("failed to encode Tencent Cloud request: {error}"))
|
||||
})?;
|
||||
let content_type = "application/json; charset=utf-8";
|
||||
let signed_headers = "content-type;host";
|
||||
let canonical_headers = format!("content-type:{content_type}\nhost:{host}\n");
|
||||
let canonical_request = format!(
|
||||
"POST\n/\n\n{canonical_headers}\n{signed_headers}\n{}",
|
||||
sha256_hex(&payload_text)
|
||||
);
|
||||
let credential_scope = format!("{}/{}/tc3_request", date, config.service);
|
||||
let string_to_sign = format!(
|
||||
"TC3-HMAC-SHA256\n{timestamp}\n{credential_scope}\n{}",
|
||||
sha256_hex(&canonical_request)
|
||||
);
|
||||
let secret_date = hmac_sha256(format!("TC3{}", self.secret_key).as_bytes(), &date)?;
|
||||
let secret_service = hmac_sha256(&secret_date, config.service)?;
|
||||
let secret_signing = hmac_sha256(&secret_service, "tc3_request")?;
|
||||
let signature = hex::encode(hmac_sha256(&secret_signing, &string_to_sign)?);
|
||||
let authorization = format!(
|
||||
"TC3-HMAC-SHA256 Credential={}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}",
|
||||
self.secret_id
|
||||
);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&self.base_url)
|
||||
.header("Authorization", authorization)
|
||||
.header("Content-Type", content_type)
|
||||
.header("Host", host)
|
||||
.header("X-TC-Action", action)
|
||||
.header("X-TC-Timestamp", timestamp.to_string())
|
||||
.header("X-TC-Version", config.version)
|
||||
.header("X-TC-Region", region)
|
||||
.body(payload_text)
|
||||
.send()
|
||||
.await?;
|
||||
let status = response.status();
|
||||
let response_text = response.text().await?;
|
||||
if !status.is_success() {
|
||||
return Err(AppError::Provider(format!(
|
||||
"Tencent Hunyuan request failed ({status}): {response_text}"
|
||||
)));
|
||||
}
|
||||
|
||||
let response_body = serde_json::from_str::<Value>(&response_text).map_err(|error| {
|
||||
AppError::Provider(format!(
|
||||
"failed to decode Tencent Hunyuan response: {error}; response body: {response_text}"
|
||||
))
|
||||
})?;
|
||||
let inner = response_body.get("Response").ok_or_else(|| {
|
||||
AppError::Provider(format!(
|
||||
"Tencent Hunyuan response missing Response: {response_text}"
|
||||
))
|
||||
})?;
|
||||
if let Some(message) = parse_tencent_error(inner) {
|
||||
return Err(AppError::Provider(format!(
|
||||
"Tencent Hunyuan {action} failed: {message}"
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(inner.clone())
|
||||
}
|
||||
|
||||
async fn run_text_to_image_lite(
|
||||
&self,
|
||||
prompt: &str,
|
||||
size: Option<&str>,
|
||||
) -> Result<ImageResult, AppError> {
|
||||
let mut payload = json!({
|
||||
"Prompt": prompt,
|
||||
"RspImgType": "url",
|
||||
"LogoAdd": 0,
|
||||
});
|
||||
if let Some(resolution) = hunyuan_resolution(size, true, false) {
|
||||
payload["Resolution"] = json!(resolution);
|
||||
}
|
||||
let host = host_from_base_url(&self.base_url)?;
|
||||
let action = api_config(&host).lite_action;
|
||||
let response = self.call(action, payload).await?;
|
||||
let image_url = result_image_url(&response).ok_or_else(|| {
|
||||
AppError::Provider(format!(
|
||||
"Tencent Hunyuan TextToImageLite did not include image url: {response}"
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(ImageResult {
|
||||
mime_type: "image/png".to_string(),
|
||||
data: ImageData::Url(image_url),
|
||||
})
|
||||
}
|
||||
|
||||
async fn run_text_to_image_rapid(
|
||||
&self,
|
||||
prompt: &str,
|
||||
size: Option<&str>,
|
||||
image_paths: &[String],
|
||||
) -> Result<ImageResult, AppError> {
|
||||
let has_reference = !image_paths.is_empty();
|
||||
let mut payload = json!({
|
||||
"Prompt": prompt,
|
||||
"RspImgType": "url",
|
||||
"LogoAdd": 0,
|
||||
});
|
||||
if let Some(resolution) = hunyuan_resolution(size, false, has_reference) {
|
||||
payload["Resolution"] = json!(resolution);
|
||||
}
|
||||
if let Some(image_base64) = first_image_base64(image_paths)? {
|
||||
payload["Image"] = json!({
|
||||
"Base64": image_base64,
|
||||
});
|
||||
}
|
||||
let host = host_from_base_url(&self.base_url)?;
|
||||
let action = api_config(&host).rapid_action;
|
||||
let response = self.call(action, payload).await?;
|
||||
let image_url = result_image_url(&response).ok_or_else(|| {
|
||||
AppError::Provider(format!(
|
||||
"Tencent Hunyuan TextToImageRapid did not include image url: {response}"
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(ImageResult {
|
||||
mime_type: "image/png".to_string(),
|
||||
data: ImageData::Url(image_url),
|
||||
})
|
||||
}
|
||||
|
||||
async fn run_async_image_job(
|
||||
&self,
|
||||
prompt: &str,
|
||||
size: Option<&str>,
|
||||
image_paths: &[String],
|
||||
) -> Result<ImageResult, AppError> {
|
||||
let has_reference = !image_paths.is_empty();
|
||||
let mut payload = json!({
|
||||
"Prompt": prompt,
|
||||
"Num": 1,
|
||||
"Revise": 1,
|
||||
"LogoAdd": 0,
|
||||
});
|
||||
if let Some(resolution) = hunyuan_resolution(size, false, has_reference) {
|
||||
payload["Resolution"] = json!(resolution);
|
||||
}
|
||||
let host = host_from_base_url(&self.base_url)?;
|
||||
let config = api_config(&host);
|
||||
if let Some(image_base64) = first_image_base64(image_paths)? {
|
||||
if config.service == "aiart" {
|
||||
payload["Images"] = json!([image_base64]);
|
||||
} else {
|
||||
payload["ContentImage"] = json!({
|
||||
"ImageBase64": image_base64,
|
||||
});
|
||||
}
|
||||
}
|
||||
let response = self.call(config.submit_action, payload).await?;
|
||||
let job_id = response
|
||||
.get("JobId")
|
||||
.and_then(Value::as_str)
|
||||
.ok_or_else(|| {
|
||||
AppError::Provider(format!(
|
||||
"Tencent Hunyuan submit response did not include JobId: {response}"
|
||||
))
|
||||
})?;
|
||||
|
||||
self.wait_image_job(job_id).await
|
||||
}
|
||||
|
||||
async fn wait_image_job(&self, job_id: &str) -> Result<ImageResult, AppError> {
|
||||
for _ in 0..40 {
|
||||
sleep(Duration::from_secs(3)).await;
|
||||
let host = host_from_base_url(&self.base_url)?;
|
||||
let action = api_config(&host).query_action;
|
||||
let response = self.call(action, json!({ "JobId": job_id })).await?;
|
||||
let job_status = response
|
||||
.get("JobStatusCode")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or_default();
|
||||
if job_status == "5" {
|
||||
if let Some(image_base64) = result_image_base64(&response) {
|
||||
return Ok(ImageResult {
|
||||
mime_type: "image/png".to_string(),
|
||||
data: ImageData::Base64(image_base64),
|
||||
});
|
||||
}
|
||||
let image_url = result_image_url(&response).ok_or_else(|| {
|
||||
AppError::Provider(format!(
|
||||
"Tencent Hunyuan job succeeded but did not include image: {response}"
|
||||
))
|
||||
})?;
|
||||
return Ok(ImageResult {
|
||||
mime_type: "image/png".to_string(),
|
||||
data: ImageData::Url(image_url),
|
||||
});
|
||||
}
|
||||
if job_status == "4" {
|
||||
let message = response
|
||||
.get("JobErrorMsg")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or("unknown error");
|
||||
return Err(AppError::Provider(format!(
|
||||
"Tencent Hunyuan image job failed: {message}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Err(AppError::Provider(
|
||||
"Tencent Hunyuan image job polling timed out".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
@@ -11,7 +11,11 @@ pub async fn pick_material_images(app: tauri::AppHandle) -> Result<Vec<String>,
|
||||
let paths = files
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.filter_map(|file_path| file_path.as_path().map(|path| path.to_string_lossy().to_string()))
|
||||
.filter_map(|file_path| {
|
||||
file_path
|
||||
.as_path()
|
||||
.map(|path| path.to_string_lossy().to_string())
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(paths)
|
||||
|
||||
@@ -2,16 +2,21 @@ use tauri::{AppHandle, State};
|
||||
|
||||
use crate::{
|
||||
ai::{
|
||||
dashscope::DashScopeProvider,
|
||||
google_gemini::GoogleGeminiProvider,
|
||||
openai_compatible::OpenAiCompatibleProvider,
|
||||
provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest},
|
||||
seedream::SeedreamProvider,
|
||||
tencent_hunyuan::TencentHunyuanProvider,
|
||||
},
|
||||
db::{
|
||||
models::{CreateGenerationTaskInput, GenerateImageInput, GenerateImageOutput, GenerationTask},
|
||||
models::{
|
||||
CreateGenerationTaskInput, GenerateImageInput, GenerateImageOutput, GenerationTask,
|
||||
},
|
||||
repository,
|
||||
},
|
||||
state::AppState,
|
||||
storage,
|
||||
AppError,
|
||||
storage, AppError,
|
||||
};
|
||||
|
||||
#[tauri::command]
|
||||
@@ -30,21 +35,66 @@ pub async fn generate_image(
|
||||
) -> Result<GenerateImageOutput, AppError> {
|
||||
let provider = repository::get_provider_secret(&state.db, &input.provider_id).await?;
|
||||
if !provider.enabled {
|
||||
return Err(AppError::Provider(format!("provider {} is disabled", provider.name)));
|
||||
return Err(AppError::Provider(format!(
|
||||
"provider {} is disabled",
|
||||
provider.name
|
||||
)));
|
||||
}
|
||||
|
||||
if provider.kind != "openai-compatible" {
|
||||
if provider.kind != "openai"
|
||||
&& provider.kind != "openai-compatible"
|
||||
&& provider.kind != "volcengine-ark"
|
||||
&& provider.kind != "dashscope"
|
||||
&& provider.kind != "tencent-hunyuan"
|
||||
&& provider.kind != "google-gemini"
|
||||
{
|
||||
return Err(AppError::Provider(format!(
|
||||
"provider kind {} is not supported yet",
|
||||
provider.kind
|
||||
)));
|
||||
}
|
||||
|
||||
if !provider.base_url.trim_end_matches('/').ends_with("/v1") {
|
||||
if (provider.kind == "openai" || provider.kind == "openai-compatible")
|
||||
&& !provider.base_url.trim_end_matches('/').ends_with("/v1")
|
||||
{
|
||||
return Err(AppError::Provider(
|
||||
"Base URL 看起来不是 API 地址。OpenAI-compatible 地址通常需要以 /v1 结尾,例如 https://api.openai.com/v1 或 https://你的中转站域名/v1".to_string(),
|
||||
));
|
||||
}
|
||||
if provider.kind == "volcengine-ark"
|
||||
&& !provider.base_url.trim_end_matches('/').ends_with("/api/v3")
|
||||
{
|
||||
return Err(AppError::Provider(
|
||||
"火山方舟 Seedream 的 Base URL 通常需要以 /api/v3 结尾,例如 https://ark.cn-beijing.volces.com/api/v3".to_string(),
|
||||
));
|
||||
}
|
||||
if provider.kind == "dashscope" && !provider.base_url.trim_end_matches('/').ends_with("/api/v1")
|
||||
{
|
||||
return Err(AppError::Provider(
|
||||
"阿里云百炼 DashScope 的 Base URL 通常需要以 /api/v1 结尾,例如 https://dashscope.aliyuncs.com/api/v1".to_string(),
|
||||
));
|
||||
}
|
||||
if provider.kind == "google-gemini"
|
||||
&& !provider.base_url.trim_end_matches('/').ends_with("/v1beta")
|
||||
{
|
||||
return Err(AppError::Provider(
|
||||
"Google Gemini / Nano Banana 的 Base URL 通常填写 https://generativelanguage.googleapis.com/v1beta".to_string(),
|
||||
));
|
||||
}
|
||||
if provider.kind == "tencent-hunyuan"
|
||||
&& !provider
|
||||
.base_url
|
||||
.trim_end_matches('/')
|
||||
.ends_with("aiart.tencentcloudapi.com")
|
||||
&& !provider
|
||||
.base_url
|
||||
.trim_end_matches('/')
|
||||
.ends_with("hunyuan.tencentcloudapi.com")
|
||||
{
|
||||
return Err(AppError::Provider(
|
||||
"腾讯混元图像的 Base URL 通常填写 https://aiart.tencentcloudapi.com".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let api_key = provider
|
||||
.api_key_encrypted
|
||||
@@ -76,30 +126,121 @@ pub async fn generate_image(
|
||||
)
|
||||
.await?;
|
||||
|
||||
let ai_provider = OpenAiCompatibleProvider::new(provider.base_url, api_key);
|
||||
let image_result = match if input.image_paths.is_empty() {
|
||||
ai_provider
|
||||
.generate_image(ImageGenerateRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
})
|
||||
.await
|
||||
let image_result = match if provider.kind == "volcengine-ark" {
|
||||
let ai_provider = SeedreamProvider::new(provider.base_url, api_key);
|
||||
if input.image_paths.is_empty() {
|
||||
ai_provider
|
||||
.generate_image(ImageGenerateRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
})
|
||||
.await
|
||||
} else {
|
||||
ai_provider
|
||||
.edit_image(ImageEditRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
image_paths: input.image_paths,
|
||||
})
|
||||
.await
|
||||
}
|
||||
} else if provider.kind == "dashscope" {
|
||||
let ai_provider = DashScopeProvider::new(provider.base_url, api_key);
|
||||
if input.image_paths.is_empty() {
|
||||
ai_provider
|
||||
.generate_image(ImageGenerateRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
})
|
||||
.await
|
||||
} else {
|
||||
ai_provider
|
||||
.edit_image(ImageEditRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
image_paths: input.image_paths,
|
||||
})
|
||||
.await
|
||||
}
|
||||
} else if provider.kind == "tencent-hunyuan" {
|
||||
let ai_provider = TencentHunyuanProvider::new(provider.base_url, api_key)?;
|
||||
if input.image_paths.is_empty() {
|
||||
ai_provider
|
||||
.generate_image(ImageGenerateRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
})
|
||||
.await
|
||||
} else {
|
||||
ai_provider
|
||||
.edit_image(ImageEditRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
image_paths: input.image_paths,
|
||||
})
|
||||
.await
|
||||
}
|
||||
} else if provider.kind == "google-gemini" {
|
||||
let ai_provider = GoogleGeminiProvider::new(provider.base_url, api_key);
|
||||
if input.image_paths.is_empty() {
|
||||
ai_provider
|
||||
.generate_image(ImageGenerateRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
})
|
||||
.await
|
||||
} else {
|
||||
ai_provider
|
||||
.edit_image(ImageEditRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
image_paths: input.image_paths,
|
||||
})
|
||||
.await
|
||||
}
|
||||
} else {
|
||||
ai_provider
|
||||
.edit_image(ImageEditRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
image_paths: input.image_paths,
|
||||
})
|
||||
.await
|
||||
let ai_provider = OpenAiCompatibleProvider::new(provider.base_url, api_key);
|
||||
if input.image_paths.is_empty() {
|
||||
ai_provider
|
||||
.generate_image(ImageGenerateRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
})
|
||||
.await
|
||||
} else {
|
||||
ai_provider
|
||||
.edit_image(ImageEditRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
image_paths: input.image_paths,
|
||||
})
|
||||
.await
|
||||
}
|
||||
} {
|
||||
Ok(result) => result,
|
||||
Err(error) => {
|
||||
repository::mark_generation_task_failed(&state.db, &task.id, &error.to_string()).await?;
|
||||
repository::mark_generation_task_failed(&state.db, &task.id, &error.to_string())
|
||||
.await?;
|
||||
return Err(error);
|
||||
}
|
||||
};
|
||||
@@ -108,7 +249,8 @@ pub async fn generate_image(
|
||||
ImageData::Base64(data_base64) => storage::decode_base64_image(data_base64)?,
|
||||
ImageData::Url(url) => reqwest::get(url).await?.bytes().await?.to_vec(),
|
||||
};
|
||||
let stored_image = storage::save_generated_image_bytes(&app, &image_bytes, &image_result.mime_type)?;
|
||||
let stored_image =
|
||||
storage::save_generated_image_bytes(&app, &image_bytes, &image_result.mime_type)?;
|
||||
let file_path = stored_image.file_path.to_string_lossy().to_string();
|
||||
let asset = repository::create_image_asset(
|
||||
&state.db,
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use tauri::State;
|
||||
|
||||
use crate::{
|
||||
@@ -7,12 +10,29 @@ use crate::{
|
||||
};
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn list_providers(state: State<'_, AppState>) -> Result<Vec<crate::db::models::ProviderConfig>, AppError> {
|
||||
pub async fn list_providers(
|
||||
state: State<'_, AppState>,
|
||||
) -> Result<Vec<crate::db::models::ProviderConfig>, AppError> {
|
||||
repository::list_providers(&state.db).await
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn upsert_provider(state: State<'_, AppState>, input: UpsertProviderInput) -> Result<(), AppError> {
|
||||
pub async fn upsert_provider(
|
||||
state: State<'_, AppState>,
|
||||
input: UpsertProviderInput,
|
||||
) -> Result<(), AppError> {
|
||||
if input
|
||||
.api_key
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.unwrap_or_default()
|
||||
.is_empty()
|
||||
{
|
||||
return Err(AppError::Provider(
|
||||
"请先填写 API Key,再保存配置".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
repository::upsert_provider(&state.db, input).await
|
||||
}
|
||||
|
||||
@@ -20,3 +40,306 @@ pub async fn upsert_provider(state: State<'_, AppState>, input: UpsertProviderIn
|
||||
pub async fn delete_provider(state: State<'_, AppState>, id: String) -> Result<(), AppError> {
|
||||
repository::delete_provider(&state.db, &id).await
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ProviderModel {
|
||||
pub id: String,
|
||||
pub owned_by: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ModelsResponse {
|
||||
data: Vec<ModelItem>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ModelItem {
|
||||
id: String,
|
||||
owned_by: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ProviderCapabilities {
|
||||
image_models: Option<Vec<ProviderModel>>,
|
||||
selected_image_models: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
fn is_image_model(model_id: &str) -> bool {
|
||||
let id = model_id.to_ascii_lowercase();
|
||||
[
|
||||
"gpt-image",
|
||||
"image",
|
||||
"dall-e",
|
||||
"dalle",
|
||||
"imagen",
|
||||
"flux",
|
||||
"qwen-image",
|
||||
"wan",
|
||||
"z-image",
|
||||
"hunyuan-image",
|
||||
"gemini",
|
||||
"seedream",
|
||||
"seededit",
|
||||
"stable-diffusion",
|
||||
"sd-",
|
||||
"midjourney",
|
||||
"recraft",
|
||||
]
|
||||
.iter()
|
||||
.any(|marker| id.contains(marker))
|
||||
}
|
||||
|
||||
fn default_seedream_models() -> Vec<ProviderModel> {
|
||||
vec![
|
||||
ProviderModel {
|
||||
id: "doubao-seedream-4-5-251128".to_string(),
|
||||
owned_by: Some("volcengine".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "doubao-seedream-4-0-250828".to_string(),
|
||||
owned_by: Some("volcengine".to_string()),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
fn default_dashscope_models() -> Vec<ProviderModel> {
|
||||
vec![
|
||||
ProviderModel {
|
||||
id: "qwen-image-2.0-pro".to_string(),
|
||||
owned_by: Some("alibaba-cloud".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "qwen-image-2.0".to_string(),
|
||||
owned_by: Some("alibaba-cloud".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "qwen-image-plus".to_string(),
|
||||
owned_by: Some("alibaba-cloud".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "qwen-image".to_string(),
|
||||
owned_by: Some("alibaba-cloud".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "wan2.7-image-pro".to_string(),
|
||||
owned_by: Some("alibaba-cloud".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "wan2.7-image".to_string(),
|
||||
owned_by: Some("alibaba-cloud".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "z-image-turbo".to_string(),
|
||||
owned_by: Some("alibaba-cloud".to_string()),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
fn default_tencent_hunyuan_models() -> Vec<ProviderModel> {
|
||||
vec![
|
||||
ProviderModel {
|
||||
id: "hunyuan-image-3.0".to_string(),
|
||||
owned_by: Some("tencent-cloud".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "hunyuan-image-2.0".to_string(),
|
||||
owned_by: Some("tencent-cloud".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "hunyuan-image-lite".to_string(),
|
||||
owned_by: Some("tencent-cloud".to_string()),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
fn default_google_gemini_models() -> Vec<ProviderModel> {
|
||||
vec![
|
||||
ProviderModel {
|
||||
id: "gemini-2.5-flash-image".to_string(),
|
||||
owned_by: Some("google".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "gemini-3.1-flash-image-preview".to_string(),
|
||||
owned_by: Some("google".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "gemini-3-pro-image-preview".to_string(),
|
||||
owned_by: Some("google".to_string()),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn fetch_provider_models(
|
||||
state: State<'_, AppState>,
|
||||
input: UpsertProviderInput,
|
||||
) -> Result<Vec<ProviderModel>, AppError> {
|
||||
if input.kind != "openai"
|
||||
&& input.kind != "openai-compatible"
|
||||
&& input.kind != "volcengine-ark"
|
||||
&& input.kind != "dashscope"
|
||||
&& input.kind != "tencent-hunyuan"
|
||||
&& input.kind != "google-gemini"
|
||||
{
|
||||
return Err(AppError::Provider(format!(
|
||||
"API 分类 {} 暂未接入模型列表获取",
|
||||
input.kind
|
||||
)));
|
||||
}
|
||||
|
||||
if (input.kind == "openai" || input.kind == "openai-compatible")
|
||||
&& !input.base_url.trim_end_matches('/').ends_with("/v1")
|
||||
{
|
||||
return Err(AppError::Provider(
|
||||
"Base URL 看起来不是 API 地址。OpenAI-compatible 地址通常需要以 /v1 结尾。".to_string(),
|
||||
));
|
||||
}
|
||||
if input.kind == "volcengine-ark" && !input.base_url.trim_end_matches('/').ends_with("/api/v3")
|
||||
{
|
||||
return Err(AppError::Provider(
|
||||
"火山方舟 Seedream 的 Base URL 通常需要以 /api/v3 结尾。".to_string(),
|
||||
));
|
||||
}
|
||||
if input.kind == "dashscope" && !input.base_url.trim_end_matches('/').ends_with("/api/v1") {
|
||||
return Err(AppError::Provider(
|
||||
"阿里云百炼 DashScope 的 Base URL 通常需要以 /api/v1 结尾。".to_string(),
|
||||
));
|
||||
}
|
||||
if input.kind == "google-gemini" && !input.base_url.trim_end_matches('/').ends_with("/v1beta") {
|
||||
return Err(AppError::Provider(
|
||||
"Google Gemini / Nano Banana 的 Base URL 通常填写 https://generativelanguage.googleapis.com/v1beta".to_string(),
|
||||
));
|
||||
}
|
||||
if input.kind == "tencent-hunyuan"
|
||||
&& !input
|
||||
.base_url
|
||||
.trim_end_matches('/')
|
||||
.ends_with("aiart.tencentcloudapi.com")
|
||||
&& !input
|
||||
.base_url
|
||||
.trim_end_matches('/')
|
||||
.ends_with("hunyuan.tencentcloudapi.com")
|
||||
{
|
||||
return Err(AppError::Provider(
|
||||
"腾讯混元图像的 Base URL 通常填写 https://aiart.tencentcloudapi.com".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let saved_provider = repository::get_provider_secret(&state.db, &input.id)
|
||||
.await
|
||||
.ok();
|
||||
let api_key = match input.api_key.clone() {
|
||||
Some(key) if !key.trim().is_empty() => Some(key),
|
||||
Some(_) => None,
|
||||
None => saved_provider.and_then(|provider| provider.api_key_encrypted),
|
||||
};
|
||||
let Some(api_key) = api_key else {
|
||||
return Err(AppError::Provider(
|
||||
"API Key 为空,无法获取模型列表".to_string(),
|
||||
));
|
||||
};
|
||||
|
||||
if input.kind == "dashscope" {
|
||||
return save_provider_models(&state, input, default_dashscope_models()).await;
|
||||
}
|
||||
if input.kind == "tencent-hunyuan" {
|
||||
return save_provider_models(&state, input, default_tencent_hunyuan_models()).await;
|
||||
}
|
||||
if input.kind == "google-gemini" {
|
||||
return save_provider_models(&state, input, default_google_gemini_models()).await;
|
||||
}
|
||||
|
||||
let response = Client::new()
|
||||
.get(format!("{}/models", input.base_url.trim_end_matches('/')))
|
||||
.bearer_auth(api_key)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
if input.kind == "volcengine-ark" && matches!(status.as_u16(), 404 | 405 | 501) {
|
||||
let fetched_models = default_seedream_models();
|
||||
return save_provider_models(&state, input, fetched_models).await;
|
||||
}
|
||||
let message = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "request failed".to_string());
|
||||
return Err(AppError::Provider(format!(
|
||||
"获取模型列表失败 ({status}): {message}"
|
||||
)));
|
||||
}
|
||||
|
||||
let mut fetched_models: Vec<ProviderModel> = response
|
||||
.json::<ModelsResponse>()
|
||||
.await?
|
||||
.data
|
||||
.into_iter()
|
||||
.filter(|model| is_image_model(&model.id))
|
||||
.map(|model| ProviderModel {
|
||||
id: model.id,
|
||||
owned_by: model.owned_by,
|
||||
})
|
||||
.collect();
|
||||
if input.kind == "volcengine-ark" && fetched_models.is_empty() {
|
||||
fetched_models = default_seedream_models();
|
||||
}
|
||||
|
||||
save_provider_models(&state, input, fetched_models).await
|
||||
}
|
||||
|
||||
async fn save_provider_models(
|
||||
state: &State<'_, AppState>,
|
||||
input: UpsertProviderInput,
|
||||
fetched_models: Vec<ProviderModel>,
|
||||
) -> Result<Vec<ProviderModel>, AppError> {
|
||||
let saved_providers = repository::list_providers(&state.db)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
let saved_capabilities = saved_providers
|
||||
.iter()
|
||||
.find(|provider| provider.id == input.id)
|
||||
.and_then(|provider| provider.capabilities.as_deref())
|
||||
.and_then(|value| serde_json::from_str::<ProviderCapabilities>(value).ok());
|
||||
let mut models = saved_capabilities
|
||||
.as_ref()
|
||||
.and_then(|capabilities| capabilities.image_models.clone())
|
||||
.unwrap_or_default();
|
||||
models.extend(fetched_models);
|
||||
models.sort_by(|left, right| left.id.cmp(&right.id));
|
||||
models.dedup_by(|left, right| left.id == right.id);
|
||||
let model_ids = models
|
||||
.iter()
|
||||
.map(|model| model.id.clone())
|
||||
.collect::<Vec<_>>();
|
||||
let mut selected_model_ids = saved_capabilities
|
||||
.and_then(|capabilities| capabilities.selected_image_models)
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.filter(|model_id| model_ids.contains(model_id))
|
||||
.collect::<Vec<_>>();
|
||||
for model_id in &model_ids {
|
||||
if !selected_model_ids.contains(model_id) {
|
||||
selected_model_ids.push(model_id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let capabilities = json!({
|
||||
"responses_api": true,
|
||||
"images_api": true,
|
||||
"chat_completions": true,
|
||||
"image_edit": true,
|
||||
"image_models": models,
|
||||
"selected_image_models": selected_model_ids,
|
||||
});
|
||||
repository::upsert_provider(
|
||||
&state.db,
|
||||
UpsertProviderInput {
|
||||
capabilities: Some(capabilities.to_string()),
|
||||
..input
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(models)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::fs;
|
||||
|
||||
use sqlx::{sqlite::SqliteConnectOptions, SqlitePool};
|
||||
use sqlx::{sqlite::SqliteConnectOptions, Row, SqlitePool};
|
||||
use tauri::{AppHandle, Manager};
|
||||
|
||||
use crate::AppError;
|
||||
@@ -24,6 +24,38 @@ pub async fn init(app: &AppHandle) -> Result<SqlitePool, AppError> {
|
||||
sqlx::query(statement).execute(&pool).await?;
|
||||
}
|
||||
}
|
||||
ensure_column(
|
||||
&pool,
|
||||
"providers",
|
||||
"capabilities",
|
||||
"TEXT NOT NULL DEFAULT '{}'",
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
async fn ensure_column(
|
||||
pool: &SqlitePool,
|
||||
table: &str,
|
||||
column: &str,
|
||||
definition: &str,
|
||||
) -> Result<(), AppError> {
|
||||
let rows = sqlx::query(&format!("PRAGMA table_info({table})"))
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
let has_column = rows.iter().any(|row| {
|
||||
row.try_get::<String, _>("name")
|
||||
.map(|name| name == column)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
if !has_column {
|
||||
sqlx::query(&format!(
|
||||
"ALTER TABLE {table} ADD COLUMN {column} {definition}"
|
||||
))
|
||||
.execute(pool)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ pub struct ProviderConfig {
|
||||
pub api_key: Option<String>,
|
||||
pub text_model: Option<String>,
|
||||
pub image_model: Option<String>,
|
||||
pub capabilities: Option<String>,
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
@@ -33,6 +34,7 @@ pub struct UpsertProviderInput {
|
||||
pub api_key: Option<String>,
|
||||
pub text_model: Option<String>,
|
||||
pub image_model: Option<String>,
|
||||
pub capabilities: Option<String>,
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
|
||||
@@ -19,8 +19,10 @@ pub async fn list_providers(pool: &SqlitePool) -> Result<Vec<ProviderConfig>, Ap
|
||||
api_key_encrypted AS api_key,
|
||||
text_model,
|
||||
image_model,
|
||||
capabilities,
|
||||
enabled != 0 AS enabled
|
||||
FROM providers
|
||||
WHERE enabled != 0
|
||||
ORDER BY updated_at DESC
|
||||
"#,
|
||||
)
|
||||
@@ -30,24 +32,34 @@ pub async fn list_providers(pool: &SqlitePool) -> Result<Vec<ProviderConfig>, Ap
|
||||
Ok(providers)
|
||||
}
|
||||
|
||||
pub async fn upsert_provider(pool: &SqlitePool, input: UpsertProviderInput) -> Result<(), AppError> {
|
||||
pub async fn upsert_provider(
|
||||
pool: &SqlitePool,
|
||||
input: UpsertProviderInput,
|
||||
) -> Result<(), AppError> {
|
||||
let now = Utc::now().to_rfc3339();
|
||||
|
||||
let existing_api_key: Option<String> = sqlx::query_scalar(
|
||||
let existing: Option<(Option<String>, Option<String>)> = sqlx::query_as(
|
||||
r#"
|
||||
SELECT api_key_encrypted
|
||||
SELECT api_key_encrypted, capabilities
|
||||
FROM providers
|
||||
WHERE id = ?1
|
||||
"#,
|
||||
)
|
||||
.bind(&input.id)
|
||||
.fetch_optional(pool)
|
||||
.await?
|
||||
.flatten();
|
||||
let api_key = input
|
||||
.api_key
|
||||
.filter(|key| !key.trim().is_empty())
|
||||
.or(existing_api_key);
|
||||
.await?;
|
||||
let api_key = match input.api_key {
|
||||
Some(key) if key.trim().is_empty() => None,
|
||||
Some(key) => Some(key),
|
||||
None => existing.as_ref().and_then(|item| item.0.clone()),
|
||||
};
|
||||
let capabilities = input
|
||||
.capabilities
|
||||
.filter(|value| !value.trim().is_empty())
|
||||
.or_else(|| existing.and_then(|item| item.1))
|
||||
.unwrap_or_else(|| {
|
||||
r#"{"responses_api":true,"images_api":true,"chat_completions":true,"image_edit":true,"image_models":[],"selected_image_models":[]}"#.to_string()
|
||||
});
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
@@ -62,6 +74,7 @@ pub async fn upsert_provider(pool: &SqlitePool, input: UpsertProviderInput) -> R
|
||||
api_key_encrypted = excluded.api_key_encrypted,
|
||||
text_model = excluded.text_model,
|
||||
image_model = excluded.image_model,
|
||||
capabilities = excluded.capabilities,
|
||||
enabled = excluded.enabled,
|
||||
updated_at = excluded.updated_at
|
||||
"#,
|
||||
@@ -73,7 +86,7 @@ pub async fn upsert_provider(pool: &SqlitePool, input: UpsertProviderInput) -> R
|
||||
.bind(api_key)
|
||||
.bind(input.text_model)
|
||||
.bind(input.image_model)
|
||||
.bind(r#"{"responses_api":true,"images_api":true,"chat_completions":true,"image_edit":true}"#)
|
||||
.bind(capabilities)
|
||||
.bind(input.enabled)
|
||||
.bind(now)
|
||||
.execute(pool)
|
||||
@@ -98,6 +111,34 @@ pub async fn get_provider_secret(pool: &SqlitePool, id: &str) -> Result<Provider
|
||||
}
|
||||
|
||||
pub async fn delete_provider(pool: &SqlitePool, id: &str) -> Result<(), AppError> {
|
||||
let reference_count: i64 = sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT
|
||||
(SELECT COUNT(*) FROM generation_tasks WHERE provider_id = ?1) +
|
||||
(SELECT COUNT(*) FROM conversations WHERE provider_id = ?1) +
|
||||
(SELECT COUNT(*) FROM ai_request_logs WHERE provider_id = ?1)
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
if reference_count > 0 {
|
||||
let now = Utc::now().to_rfc3339();
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE providers
|
||||
SET enabled = 0, updated_at = ?2
|
||||
WHERE id = ?1
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.bind(now)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
sqlx::query("DELETE FROM providers WHERE id = ?1")
|
||||
.bind(id)
|
||||
.execute(pool)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
mod ai;
|
||||
mod commands;
|
||||
mod db;
|
||||
mod storage;
|
||||
mod state;
|
||||
mod storage;
|
||||
|
||||
use serde::Serialize;
|
||||
use state::AppState;
|
||||
@@ -53,6 +53,7 @@ pub fn run() {
|
||||
commands::provider::list_providers,
|
||||
commands::provider::upsert_provider,
|
||||
commands::provider::delete_provider,
|
||||
commands::provider::fetch_provider_models,
|
||||
commands::dialog::pick_material_images,
|
||||
commands::file::reveal_path,
|
||||
commands::file::open_generated_dir,
|
||||
|
||||
Reference in New Issue
Block a user