feat: add multi-provider image generation

This commit is contained in:
2026-04-29 15:47:18 +08:00
parent 20bd25e136
commit a4198f29d2
18 changed files with 3397 additions and 652 deletions

View File

@@ -26,6 +26,9 @@ reqwest = { version = "0.12", default-features = false, features = ["json", "mul
thiserror = "2"
async-trait = "0.1"
base64 = "0.22"
hex = "0.4"
hmac = "0.12"
sha2 = "0.10"
[features]
default = ["custom-protocol"]

View File

@@ -0,0 +1,354 @@
use async_trait::async_trait;
use base64::{engine::general_purpose, Engine as _};
use reqwest::Client;
use serde_json::{json, Value};
use std::{fs, path::Path, time::Duration};
use tokio::time::sleep;
use super::provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest, ImageResult};
use crate::AppError;
/// Image provider backed by Alibaba Cloud DashScope.
pub struct DashScopeProvider {
    /// Shared HTTP client reused across all requests.
    client: Client,
    /// API root with any trailing '/' removed so path joins stay predictable.
    base_url: String,
    /// DashScope API key, sent as a bearer token on every request.
    api_key: String,
}

impl DashScopeProvider {
    /// Creates a provider for the given API root and key.
    /// Trailing slashes are stripped from `base_url` at construction time.
    pub fn new(base_url: String, api_key: String) -> Self {
        Self {
            client: Client::new(),
            base_url: base_url.trim_end_matches('/').to_string(),
            api_key,
        }
    }
}
/// True when `model` belongs to the qwen-image family.
fn is_qwen_image_model(model: &str) -> bool {
    model.strip_prefix("qwen-image").is_some()
}
/// True for model families served by the synchronous multimodal-generation
/// endpoint (qwen-image and z-image); everything else goes through the
/// asynchronous task flow.
fn is_sync_multimodal_model(model: &str) -> bool {
    ["qwen-image", "z-image"]
        .iter()
        .any(|prefix| model.starts_with(prefix))
}
/// Converts the UI's "WxH" size string into DashScope's "W*H" form.
/// Passes `None` through untouched.
fn dashscope_size(size: Option<&str>) -> Option<String> {
    match size {
        Some(value) => Some(value.replace('x', "*")),
        None => None,
    }
}
/// Guesses the MIME type from a file extension (case-insensitive).
/// Anything unrecognized — including no extension — falls back to PNG.
fn mime_for_path(path: &str) -> &'static str {
    let lowered = Path::new(path)
        .extension()
        .and_then(|ext| ext.to_str())
        .map(str::to_ascii_lowercase);
    match lowered.as_deref() {
        Some("jpg") | Some("jpeg") => "image/jpeg",
        Some("webp") => "image/webp",
        _ => "image/png",
    }
}
/// Reads a local file and encodes it as a `data:` URL with a MIME type
/// inferred from the file extension. Fails if the file cannot be read.
fn path_to_data_url(path: &str) -> Result<String, AppError> {
    let encoded = general_purpose::STANDARD.encode(fs::read(path)?);
    let mime = mime_for_path(path);
    Ok(format!("data:{mime};base64,{encoded}"))
}
/// Extracts "code: message" from a DashScope JSON error body.
/// Returns `fallback` when the body is not JSON or carries neither field.
fn parse_dashscope_error(response_text: &str, fallback: &str) -> String {
    let parsed: Option<Value> = serde_json::from_str(response_text).ok();
    if let Some(value) = parsed {
        let code = value.get("code").and_then(Value::as_str).unwrap_or_default();
        let message = value
            .get("message")
            .and_then(Value::as_str)
            .unwrap_or_default();
        if !(code.is_empty() && message.is_empty()) {
            return format!("{code}: {message}");
        }
    }
    fallback.to_string()
}
/// Finds the first image URL inside `output.choices[*].message.content[*].image`
/// (the multimodal-generation response shape).
fn image_url_from_choices(value: &Value) -> Option<String> {
    let choices = value.get("output")?.get("choices")?.as_array()?;
    for choice in choices {
        let contents = choice
            .get("message")
            .and_then(|message| message.get("content"))
            .and_then(Value::as_array);
        let Some(contents) = contents else { continue };
        for content in contents {
            if let Some(image) = content.get("image").and_then(Value::as_str) {
                return Some(image.to_string());
            }
        }
    }
    None
}
fn image_url_from_results(value: &Value) -> Option<String> {
value
.get("output")?
.get("results")?
.as_array()?
.iter()
.find_map(|item| item.get("url").and_then(Value::as_str).map(str::to_string))
}
/// Extracts an image URL from either DashScope response shape:
/// choices-based (sync multimodal) first, then results-based (async task).
fn parse_image_url(value: &Value) -> Option<String> {
    match image_url_from_choices(value) {
        Some(url) => Some(url),
        None => image_url_from_results(value),
    }
}
fn build_messages(prompt: &str, image_paths: &[String]) -> Result<Vec<Value>, AppError> {
let mut content = image_paths
.iter()
.map(|path| path_to_data_url(path).map(|image| json!({ "image": image })))
.collect::<Result<Vec<_>, _>>()?;
content.push(json!({ "text": prompt }));
Ok(vec![json!({
"role": "user",
"content": content,
})])
}
#[async_trait]
impl AiProvider for DashScopeProvider {
    /// Text-to-image. Synchronous multimodal models (qwen-image / z-image)
    /// are served inline; every other model goes through the async task flow.
    async fn generate_image(&self, request: ImageGenerateRequest) -> Result<ImageResult, AppError> {
        if is_sync_multimodal_model(&request.model) {
            return self
                .run_qwen_image(
                    &request.prompt,
                    &request.model,
                    request.size.as_deref(),
                    &[], // no reference images for pure generation
                )
                .await;
        }
        self.run_async_image_generation(
            &request.prompt,
            &request.model,
            request.size.as_deref(),
            &[],
        )
        .await
    }

    /// Image editing: same endpoint routing as `generate_image`, but forwards
    /// the caller-supplied reference image paths.
    async fn edit_image(&self, request: ImageEditRequest) -> Result<ImageResult, AppError> {
        if is_sync_multimodal_model(&request.model) {
            return self
                .run_qwen_image(
                    &request.prompt,
                    &request.model,
                    request.size.as_deref(),
                    &request.image_paths,
                )
                .await;
        }
        self.run_async_image_generation(
            &request.prompt,
            &request.model,
            request.size.as_deref(),
            &request.image_paths,
        )
        .await
    }
}
impl DashScopeProvider {
    /// Calls the synchronous `multimodal-generation` endpoint used by the
    /// qwen-image / z-image model families and returns the hosted image URL.
    ///
    /// `image_paths` are embedded in the message as base64 data URLs; an
    /// empty slice makes this a plain text-to-image request.
    async fn run_qwen_image(
        &self,
        prompt: &str,
        model: &str,
        size: Option<&str>,
        image_paths: &[String],
    ) -> Result<ImageResult, AppError> {
        let mut parameters = json!({
            "n": 1,
            "watermark": false,
            "prompt_extend": true,
        });
        // NOTE(review): a single-space negative prompt is sent only for
        // qwen-image models — presumably to override a server-side default;
        // confirm against the DashScope docs.
        if is_qwen_image_model(model) {
            parameters["negative_prompt"] = json!(" ");
        }
        // DashScope wants "W*H"; the caller supplies "WxH".
        if let Some(size) = dashscope_size(size) {
            parameters["size"] = json!(size);
        }
        let body = json!({
            "model": model,
            "input": {
                "messages": build_messages(prompt, image_paths)?,
            },
            "parameters": parameters,
        });
        let response = self
            .client
            .post(format!(
                "{}/services/aigc/multimodal-generation/generation",
                self.base_url
            ))
            .bearer_auth(&self.api_key)
            .json(&body)
            .send()
            .await?;
        let status = response.status();
        // Read the body before branching so error reports can include it.
        let response_text = response.text().await?;
        if !status.is_success() {
            return Err(AppError::Provider(format!(
                "DashScope image generation failed ({status}): {}",
                parse_dashscope_error(&response_text, &response_text)
            )));
        }
        let response_body = serde_json::from_str::<Value>(&response_text).map_err(|error| {
            AppError::Provider(format!(
                "failed to decode DashScope response: {error}; response body: {response_text}"
            ))
        })?;
        let image_url = parse_image_url(&response_body).ok_or_else(|| {
            AppError::Provider(format!(
                "DashScope response did not include image url: {response_text}"
            ))
        })?;
        Ok(ImageResult {
            // NOTE(review): MIME type is assumed here; the hosted URL may
            // actually serve a different format.
            mime_type: "image/png".to_string(),
            data: ImageData::Url(image_url),
        })
    }

    /// Submits an asynchronous image-generation task (non-synchronous models)
    /// and blocks on `wait_async_task` until it resolves.
    async fn run_async_image_generation(
        &self,
        prompt: &str,
        model: &str,
        size: Option<&str>,
        image_paths: &[String],
    ) -> Result<ImageResult, AppError> {
        let mut parameters = json!({
            "n": 1,
            "watermark": false,
        });
        if let Some(size) = dashscope_size(size) {
            parameters["size"] = json!(size);
        }
        // wan2.7-image models get thinking_mode; every other async model gets
        // prompt_extend instead.
        if model.starts_with("wan2.7-image") {
            parameters["thinking_mode"] = json!(true);
        } else {
            parameters["prompt_extend"] = json!(true);
        }
        // NOTE(review): this endpoint is sent the `messages` input schema —
        // confirm the async image-generation API accepts it for all routed models.
        let body = json!({
            "model": model,
            "input": {
                "messages": build_messages(prompt, image_paths)?,
            },
            "parameters": parameters,
        });
        let response = self
            .client
            .post(format!(
                "{}/services/aigc/image-generation/generation",
                self.base_url
            ))
            .bearer_auth(&self.api_key)
            // Header that switches the request into async-task mode.
            .header("X-DashScope-Async", "enable")
            .json(&body)
            .send()
            .await?;
        let status = response.status();
        let response_text = response.text().await?;
        if !status.is_success() {
            return Err(AppError::Provider(format!(
                "DashScope async image task failed ({status}): {}",
                parse_dashscope_error(&response_text, &response_text)
            )));
        }
        let response_body = serde_json::from_str::<Value>(&response_text).map_err(|error| {
            AppError::Provider(format!(
                "failed to decode DashScope task response: {error}; response body: {response_text}"
            ))
        })?;
        let task_id = response_body
            .get("output")
            .and_then(|output| output.get("task_id"))
            .and_then(Value::as_str)
            .ok_or_else(|| {
                AppError::Provider(format!(
                    "DashScope task response did not include task_id: {response_text}"
                ))
            })?;
        self.wait_async_task(task_id).await
    }

    /// Polls `GET {base_url}/tasks/{task_id}` every 3 s, up to 40 times
    /// (~2 minutes), until the task succeeds or reaches a terminal state.
    async fn wait_async_task(&self, task_id: &str) -> Result<ImageResult, AppError> {
        for _ in 0..40 {
            sleep(Duration::from_secs(3)).await;
            let response = self
                .client
                .get(format!("{}/tasks/{task_id}", self.base_url))
                .bearer_auth(&self.api_key)
                .send()
                .await?;
            let status = response.status();
            let response_text = response.text().await?;
            if !status.is_success() {
                return Err(AppError::Provider(format!(
                    "DashScope task polling failed ({status}): {}",
                    parse_dashscope_error(&response_text, &response_text)
                )));
            }
            let response_body = serde_json::from_str::<Value>(&response_text).map_err(|error| {
                AppError::Provider(format!(
                    "failed to decode DashScope task result: {error}; response body: {response_text}"
                ))
            })?;
            let task_status = response_body
                .get("output")
                .and_then(|output| output.get("task_status"))
                .and_then(Value::as_str)
                .unwrap_or_default();
            if task_status == "SUCCEEDED" {
                let image_url = parse_image_url(&response_body).ok_or_else(|| {
                    AppError::Provider(format!(
                        "DashScope task succeeded but did not include image url: {response_text}"
                    ))
                })?;
                return Ok(ImageResult {
                    mime_type: "image/png".to_string(),
                    data: ImageData::Url(image_url),
                });
            }
            // Terminal failure states stop polling immediately; any other
            // status (e.g. PENDING/RUNNING) means keep waiting.
            if matches!(task_status, "FAILED" | "CANCELED" | "UNKNOWN") {
                return Err(AppError::Provider(format!(
                    "DashScope task ended with status {task_status}: {}",
                    parse_dashscope_error(&response_text, &response_text)
                )));
            }
        }
        Err(AppError::Provider(
            "DashScope image task polling timed out".to_string(),
        ))
    }
}

View File

@@ -0,0 +1,189 @@
use async_trait::async_trait;
use base64::{engine::general_purpose, Engine as _};
use reqwest::Client;
use serde_json::{json, Value};
use std::{fs, path::Path};
use super::provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest, ImageResult};
use crate::AppError;
/// Image provider backed by the Google Gemini `generateContent` API.
pub struct GoogleGeminiProvider {
    /// Shared HTTP client reused across all requests.
    client: Client,
    /// API root with any trailing '/' removed.
    base_url: String,
    /// API key sent via the `x-goog-api-key` header.
    api_key: String,
}

impl GoogleGeminiProvider {
    /// Creates a provider for the given API root and key.
    /// Trailing slashes are stripped from `base_url` at construction time.
    pub fn new(base_url: String, api_key: String) -> Self {
        Self {
            client: Client::new(),
            base_url: base_url.trim_end_matches('/').to_string(),
            api_key,
        }
    }
}
/// Maps a file extension (case-insensitive) to a MIME type,
/// defaulting to PNG for unknown or missing extensions.
fn mime_for_path(path: &str) -> &'static str {
    let extension = match Path::new(path).extension().and_then(|ext| ext.to_str()) {
        Some(ext) => ext.to_ascii_lowercase(),
        None => return "image/png",
    };
    if extension == "jpg" || extension == "jpeg" {
        "image/jpeg"
    } else if extension == "webp" {
        "image/webp"
    } else {
        "image/png"
    }
}
/// Maps a "WxH" size string to the nearest Gemini aspect-ratio label.
///
/// Square sizes use a tighter tolerance (0.08) than the rectangular ratios
/// (0.12); anything that matches no candidate — including non-numeric or
/// non-positive dimensions — yields `None` so no aspect hint is sent.
fn size_to_aspect_ratio(size: Option<&str>) -> Option<&'static str> {
    let (width, height) = size?.split_once('x')?;
    let width: f32 = width.parse().ok()?;
    let height: f32 = height.parse().ok()?;
    if width <= 0.0 || height <= 0.0 {
        return None;
    }
    let ratio = width / height;
    // Candidates are checked in order: (target ratio, tolerance, label).
    const CANDIDATES: [(f32, f32, &str); 5] = [
        (1.0, 0.08, "1:1"),
        (16.0 / 9.0, 0.12, "16:9"),
        (9.0 / 16.0, 0.12, "9:16"),
        (4.0 / 3.0, 0.12, "4:3"),
        (3.0 / 4.0, 0.12, "3:4"),
    ];
    CANDIDATES
        .into_iter()
        .find(|(target, tolerance, _)| (ratio - target).abs() < *tolerance)
        .map(|(_, _, label)| label)
}
fn image_part(path: &str) -> Result<Value, AppError> {
let bytes = fs::read(path)?;
Ok(json!({
"inline_data": {
"mime_type": mime_for_path(path),
"data": general_purpose::STANDARD.encode(bytes),
}
}))
}
/// Pulls `(mime_type, base64 data)` out of a response part, accepting both
/// camelCase and snake_case field spellings. Missing MIME defaults to PNG;
/// missing data means this part carries no image.
fn inline_data_from_part(part: &Value) -> Option<(&str, &str)> {
    let inline_data = match part.get("inlineData") {
        Some(found) => found,
        None => part.get("inline_data")?,
    };
    let data = inline_data.get("data")?.as_str()?;
    let mime_type = inline_data
        .get("mimeType")
        .or_else(|| inline_data.get("mime_type"))
        .and_then(Value::as_str)
        .unwrap_or("image/png");
    Some((mime_type, data))
}
fn parse_gemini_response(response_text: &str) -> Result<ImageResult, AppError> {
let response_body = serde_json::from_str::<Value>(response_text).map_err(|error| {
AppError::Provider(format!(
"failed to decode Gemini image response: {error}; response body: {response_text}"
))
})?;
let parts = response_body
.get("candidates")
.and_then(Value::as_array)
.into_iter()
.flatten()
.flat_map(|candidate| {
candidate
.get("content")
.and_then(|content| content.get("parts"))
.and_then(Value::as_array)
.into_iter()
.flatten()
});
for part in parts {
if let Some((mime_type, data)) = inline_data_from_part(part) {
return Ok(ImageResult {
mime_type: mime_type.to_string(),
data: ImageData::Base64(data.to_string()),
});
}
}
Err(AppError::Provider(format!(
"Gemini response did not include image data: {response_text}"
)))
}
#[async_trait]
impl AiProvider for GoogleGeminiProvider {
    /// Text-to-image: delegates to the shared `generateContent` path with no
    /// reference images.
    async fn generate_image(&self, request: ImageGenerateRequest) -> Result<ImageResult, AppError> {
        self.generate_with_parts(
            &request.prompt,
            &request.model,
            request.size.as_deref(),
            &[],
        )
        .await
    }

    /// Image editing: same path, but forwards the reference image paths.
    async fn edit_image(&self, request: ImageEditRequest) -> Result<ImageResult, AppError> {
        self.generate_with_parts(
            &request.prompt,
            &request.model,
            request.size.as_deref(),
            &request.image_paths,
        )
        .await
    }
}
impl GoogleGeminiProvider {
    /// Shared implementation for generation and editing.
    ///
    /// Builds a single user turn — text prompt first, then every reference
    /// image as inline base64 data — and POSTs it to
    /// `models/{model}:generateContent`. The requested size is translated to
    /// an aspect-ratio hint when it maps onto a supported ratio.
    async fn generate_with_parts(
        &self,
        prompt: &str,
        model: &str,
        size: Option<&str>,
        image_paths: &[String],
    ) -> Result<ImageResult, AppError> {
        let mut parts = vec![json!({ "text": prompt })];
        parts.extend(
            image_paths
                .iter()
                .map(|path| image_part(path))
                .collect::<Result<Vec<_>, _>>()?,
        );
        // Ask for both text and image modalities in the response.
        let mut generation_config = json!({
            "responseModalities": ["TEXT", "IMAGE"],
        });
        if let Some(aspect_ratio) = size_to_aspect_ratio(size) {
            generation_config["imageConfig"] = json!({
                "aspectRatio": aspect_ratio,
            });
        }
        let body = json!({
            "contents": [{
                "role": "user",
                "parts": parts,
            }],
            "generationConfig": generation_config,
        });
        let response = self
            .client
            .post(format!("{}/models/{model}:generateContent", self.base_url))
            .header("x-goog-api-key", &self.api_key)
            .json(&body)
            .send()
            .await?;
        let status = response.status();
        // Read the body before branching so the error can include it.
        let response_text = response.text().await?;
        if !status.is_success() {
            return Err(AppError::Provider(format!(
                "Gemini image generation failed ({status}): {response_text}"
            )));
        }
        parse_gemini_response(&response_text)
    }
}

View File

@@ -1,2 +1,6 @@
/// Alibaba Cloud DashScope image provider.
pub mod dashscope;
/// Google Gemini image provider.
pub mod google_gemini;
/// Generic OpenAI-compatible image provider.
pub mod openai_compatible;
/// Shared provider trait and request/response types.
pub mod provider;
/// Volcengine Ark Seedream image provider.
pub mod seedream;
/// Tencent Cloud Hunyuan image provider.
pub mod tencent_hunyuan;

View File

@@ -32,11 +32,12 @@ fn parse_image_response(response_text: &str) -> Result<ImageResult, AppError> {
));
}
let response_body: ImageResponseBody = serde_json::from_str(response_text).map_err(|error| {
AppError::Provider(format!(
"failed to decode image response: {error}; response body: {response_text}"
))
})?;
let response_body: ImageResponseBody =
serde_json::from_str(response_text).map_err(|error| {
AppError::Provider(format!(
"failed to decode image response: {error}; response body: {response_text}"
))
})?;
let image_data = response_body
.data
.into_iter()
@@ -45,7 +46,9 @@ fn parse_image_response(response_text: &str) -> Result<ImageResult, AppError> {
.map(ImageData::Base64)
.or_else(|| item.url.map(ImageData::Url))
})
.ok_or_else(|| AppError::Provider("image response did not include b64_json or url".to_string()))?;
.ok_or_else(|| {
AppError::Provider("image response did not include b64_json or url".to_string())
})?;
Ok(ImageResult {
mime_type: "image/png".to_string(),
@@ -96,8 +99,13 @@ impl AiProvider for OpenAiCompatibleProvider {
let status = response.status();
if !status.is_success() {
let message = response.text().await.unwrap_or_else(|_| "request failed".to_string());
return Err(AppError::Provider(format!("image generation failed ({status}): {message}")));
let message = response
.text()
.await
.unwrap_or_else(|_| "request failed".to_string());
return Err(AppError::Provider(format!(
"image generation failed ({status}): {message}"
)));
}
let response_text = response.text().await?;
@@ -134,7 +142,9 @@ impl AiProvider for OpenAiCompatibleProvider {
Some("webp") => "image/webp",
_ => "image/png",
};
let part = multipart::Part::bytes(bytes).file_name(file_name).mime_str(mime)?;
let part = multipart::Part::bytes(bytes)
.file_name(file_name)
.mime_str(mime)?;
form = form.part("image[]", part);
}
@@ -148,8 +158,13 @@ impl AiProvider for OpenAiCompatibleProvider {
let status = response.status();
if !status.is_success() {
let message = response.text().await.unwrap_or_else(|_| "request failed".to_string());
return Err(AppError::Provider(format!("image edit failed ({status}): {message}")));
let message = response
.text()
.await
.unwrap_or_else(|_| "request failed".to_string());
return Err(AppError::Provider(format!(
"image edit failed ({status}): {message}"
)));
}
let response_text = response.text().await?;

View File

@@ -0,0 +1,180 @@
use async_trait::async_trait;
use base64::{engine::general_purpose, Engine as _};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::{fs, path::Path};
use super::provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest, ImageResult};
use crate::AppError;
/// Image provider backed by Volcengine Ark's Seedream API
/// (OpenAI-style `/images/generations` endpoint).
pub struct SeedreamProvider {
    /// Shared HTTP client reused across all requests.
    client: Client,
    /// API root with any trailing '/' removed.
    base_url: String,
    /// API key sent as a bearer token.
    api_key: String,
}

impl SeedreamProvider {
    /// Creates a provider for the given API root and key.
    /// Trailing slashes are stripped from `base_url` at construction time.
    pub fn new(base_url: String, api_key: String) -> Self {
        Self {
            client: Client::new(),
            base_url: base_url.trim_end_matches('/').to_string(),
            api_key,
        }
    }
}
/// Request payload for Seedream `/images/generations`.
#[derive(Debug, Serialize)]
struct SeedreamRequestBody<'a> {
    model: &'a str,
    prompt: &'a str,
    /// Optional "WxH" size; omitted from JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    size: Option<&'a str>,
    /// Always "b64_json" in this client.
    response_format: &'a str,
    stream: bool,
    watermark: bool,
    /// Reference images as data URLs; present only for edit requests.
    #[serde(skip_serializing_if = "Option::is_none")]
    image: Option<Vec<String>>,
}

/// Successful Seedream response: a list of generated image items.
#[derive(Debug, Deserialize)]
struct SeedreamResponseBody {
    data: Vec<SeedreamResponseItem>,
}

/// One generated image; the API returns either inline base64 or a URL.
#[derive(Debug, Deserialize)]
struct SeedreamResponseItem {
    b64_json: Option<String>,
    url: Option<String>,
}
/// Infers a MIME type from the file extension, case-insensitively.
/// Unknown or missing extensions default to PNG.
fn mime_for_path(path: &str) -> &'static str {
    let extension = Path::new(path)
        .extension()
        .and_then(|ext| ext.to_str())
        .map(|ext| ext.to_ascii_lowercase());
    match extension.as_deref() {
        Some("jpg") | Some("jpeg") => "image/jpeg",
        Some("webp") => "image/webp",
        Some(_) | None => "image/png",
    }
}
/// Reads a local file and returns it as a base64 `data:` URL whose MIME type
/// is inferred from the file extension.
fn path_to_data_url(path: &str) -> Result<String, AppError> {
    let bytes = fs::read(path)?;
    let mime = mime_for_path(path);
    let encoded = general_purpose::STANDARD.encode(bytes);
    Ok(format!("data:{mime};base64,{encoded}"))
}
/// Decodes a Seedream (Volcengine Ark) image response into an `ImageResult`.
///
/// Returns a provider error when the endpoint replies with an HTML page
/// (a symptom of a mis-configured base URL), when the body is not valid
/// JSON, or when no item carries either `b64_json` or `url`.
fn parse_seedream_response(response_text: &str) -> Result<ImageResult, AppError> {
    // HTML sniffing was case-sensitive and missed the canonical
    // "<!DOCTYPE html" / "<HTML" spellings; compare ASCII case-insensitively.
    // `get(..)` handles bodies shorter than the prefix (and any non-ASCII
    // boundary) by simply not matching.
    let trimmed = response_text.trim_start();
    let looks_like_html = trimmed
        .get(..14)
        .map_or(false, |prefix| prefix.eq_ignore_ascii_case("<!doctype html"))
        || trimmed
            .get(..5)
            .map_or(false, |prefix| prefix.eq_ignore_ascii_case("<html"));
    if looks_like_html {
        return Err(AppError::Provider(
            "火山方舟返回了 HTML 页面,不是 API JSON 响应。请检查 Base URL 是否为 API 地址。"
                .to_string(),
        ));
    }
    let response_body: SeedreamResponseBody =
        serde_json::from_str(response_text).map_err(|error| {
            AppError::Provider(format!(
                "failed to decode Seedream response: {error}; response body: {response_text}"
            ))
        })?;
    // Prefer inline base64 over a hosted URL when both are present.
    let image_data = response_body
        .data
        .into_iter()
        .find_map(|item| {
            item.b64_json
                .map(ImageData::Base64)
                .or_else(|| item.url.map(ImageData::Url))
        })
        .ok_or_else(|| {
            AppError::Provider("Seedream response did not include b64_json or url".to_string())
        })?;
    Ok(ImageResult {
        // NOTE(review): MIME type is assumed; the API does not echo one back.
        mime_type: "image/png".to_string(),
        data: image_data,
    })
}
#[async_trait]
impl AiProvider for SeedreamProvider {
    /// Text-to-image via `POST {base_url}/images/generations`, requesting an
    /// inline base64 payload (no streaming, no watermark).
    async fn generate_image(&self, request: ImageGenerateRequest) -> Result<ImageResult, AppError> {
        let body = SeedreamRequestBody {
            model: &request.model,
            prompt: &request.prompt,
            size: request.size.as_deref(),
            response_format: "b64_json",
            stream: false,
            watermark: false,
            image: None, // no reference images for pure generation
        };
        let response = self
            .client
            .post(format!("{}/images/generations", self.base_url))
            .bearer_auth(&self.api_key)
            .json(&body)
            .send()
            .await?;
        let status = response.status();
        if !status.is_success() {
            // Best effort: fall back to a generic message if the error body
            // itself cannot be read.
            let message = response
                .text()
                .await
                .unwrap_or_else(|_| "request failed".to_string());
            return Err(AppError::Provider(format!(
                "Seedream image generation failed ({status}): {message}"
            )));
        }
        let response_text = response.text().await?;
        parse_seedream_response(&response_text)
    }

    /// Image editing: same endpoint as generation, with the reference images
    /// embedded as base64 data URLs in the `image` field.
    async fn edit_image(&self, request: ImageEditRequest) -> Result<ImageResult, AppError> {
        let images = request
            .image_paths
            .iter()
            .map(|path| path_to_data_url(path))
            .collect::<Result<Vec<_>, _>>()?;
        let body = SeedreamRequestBody {
            model: &request.model,
            prompt: &request.prompt,
            size: request.size.as_deref(),
            response_format: "b64_json",
            stream: false,
            watermark: false,
            image: Some(images),
        };
        let response = self
            .client
            .post(format!("{}/images/generations", self.base_url))
            .bearer_auth(&self.api_key)
            .json(&body)
            .send()
            .await?;
        let status = response.status();
        if !status.is_success() {
            let message = response
                .text()
                .await
                .unwrap_or_else(|_| "request failed".to_string());
            return Err(AppError::Provider(format!(
                "Seedream image edit failed ({status}): {message}"
            )));
        }
        let response_text = response.text().await?;
        parse_seedream_response(&response_text)
    }
}

View File

@@ -0,0 +1,463 @@
use async_trait::async_trait;
use base64::{engine::general_purpose, Engine as _};
use chrono::{TimeZone, Utc};
use hmac::{Hmac, Mac};
use reqwest::{Client, Url};
use serde_json::{json, Value};
use sha2::{Digest, Sha256};
use std::{fs, time::Duration};
use tokio::time::sleep;
use super::provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest, ImageResult};
use crate::AppError;
// HMAC-SHA256, used by the TC3 request-signing scheme.
type HmacSha256 = Hmac<Sha256>;

/// Image provider backed by Tencent Cloud (Hunyuan / aiart services).
pub struct TencentHunyuanProvider {
    /// Shared HTTP client reused across all requests.
    client: Client,
    /// API root with any trailing '/' removed; its host selects the service.
    base_url: String,
    /// Tencent Cloud SecretId (first half of the configured key).
    secret_id: String,
    /// Tencent Cloud SecretKey (second half of the configured key).
    secret_key: String,
}

/// Per-service API wiring: the signing service name, API version, and the
/// action names for each generation mode.
struct TencentApiConfig {
    service: &'static str,
    version: &'static str,
    lite_action: &'static str,
    rapid_action: &'static str,
    submit_action: &'static str,
    query_action: &'static str,
}
impl TencentHunyuanProvider {
    /// Creates a provider from a "SecretId:SecretKey" pair.
    ///
    /// Fails with a provider error when the key is not in that form.
    /// Trailing slashes are stripped from `base_url` at construction time.
    pub fn new(base_url: String, api_key: String) -> Result<Self, AppError> {
        let (secret_id, secret_key) = parse_secret_pair(&api_key)?;
        Ok(Self {
            client: Client::new(),
            base_url: base_url.trim_end_matches('/').to_string(),
            secret_id,
            secret_key,
        })
    }
}
/// Splits a configured "SecretId:SecretKey" string into its trimmed halves.
/// Rejects keys without a ':' or with an empty half.
fn parse_secret_pair(api_key: &str) -> Result<(String, String), AppError> {
    let invalid =
        || AppError::Provider("腾讯云 API Key 需要填写为 SecretId:SecretKey".to_string());
    let (secret_id, secret_key) = api_key.split_once(':').ok_or_else(invalid)?;
    let (secret_id, secret_key) = (secret_id.trim(), secret_key.trim());
    if secret_id.is_empty() || secret_key.is_empty() {
        return Err(invalid());
    }
    Ok((secret_id.to_string(), secret_key.to_string()))
}
fn hmac_sha256(key: &[u8], message: &str) -> Result<Vec<u8>, AppError> {
let mut mac = HmacSha256::new_from_slice(key)
.map_err(|error| AppError::Provider(format!("failed to create HMAC: {error}")))?;
mac.update(message.as_bytes());
Ok(mac.finalize().into_bytes().to_vec())
}
/// SHA-256 of `value`, rendered as a lowercase hex string.
fn sha256_hex(value: &str) -> String {
    let digest = Sha256::digest(value.as_bytes());
    hex::encode(digest)
}
/// Extracts the host portion of the configured base URL; the host decides
/// which Tencent service (aiart vs hunyuan) the provider talks to.
fn host_from_base_url(base_url: &str) -> Result<String, AppError> {
    let url = match Url::parse(base_url) {
        Ok(url) => url,
        Err(error) => {
            return Err(AppError::Provider(format!(
                "腾讯云 Base URL 不是有效 URL: {error}"
            )))
        }
    };
    match url.host_str() {
        Some(host) => Ok(host.to_string()),
        None => Err(AppError::Provider("腾讯云 Base URL 缺少 host".to_string())),
    }
}
/// Chooses service wiring by host: the aiart endpoint gets its own action
/// names and version; every other host is treated as the hunyuan service
/// (where the lite and rapid paths share the TextToImageLite action).
fn api_config(host: &str) -> TencentApiConfig {
    match host {
        "aiart.tencentcloudapi.com" => TencentApiConfig {
            service: "aiart",
            version: "2022-12-29",
            lite_action: "TextToImageLite",
            rapid_action: "TextToImageRapid",
            submit_action: "SubmitTextToImageJob",
            query_action: "QueryTextToImageJob",
        },
        _ => TencentApiConfig {
            service: "hunyuan",
            version: "2023-09-01",
            lite_action: "TextToImageLite",
            rapid_action: "TextToImageLite",
            submit_action: "SubmitHunyuanImageJob",
            query_action: "QueryHunyuanImageJob",
        },
    }
}
/// Maps a "WxH" size to the closest Hunyuan resolution string.
///
/// Near-square / 3:4 / 4:3 sizes snap to the 1024-based presets; extreme
/// portrait (<0.65) or landscape (>1.55) ratios get tall/wide presets, with
/// the larger 1080p variants only on the lite path without a reference
/// image. When a reference image is present, only the three 1024-based
/// presets are allowed — anything else yields `None` (no Resolution sent).
fn hunyuan_resolution(size: Option<&str>, lite: bool, has_reference: bool) -> Option<String> {
    let (raw_width, raw_height) = size?.split_once('x')?;
    let width: f32 = raw_width.parse().ok()?;
    let height: f32 = raw_height.parse().ok()?;
    if width <= 0.0 || height <= 0.0 {
        return None;
    }
    let ratio = width / height;
    let near = |target: f32| (ratio - target).abs() < 0.12;
    let resolution = if near(1.0) {
        "1024:1024"
    } else if near(0.75) {
        "768:1024"
    } else if near(1.333) {
        "1024:768"
    } else if ratio < 0.65 {
        if lite && !has_reference {
            "1080:1920"
        } else {
            "720:1280"
        }
    } else if ratio > 1.55 {
        if lite && !has_reference {
            "1920:1080"
        } else {
            "1280:720"
        }
    } else if width < height {
        "768:1024"
    } else {
        "1024:768"
    };
    let reference_compatible = matches!(resolution, "1024:1024" | "768:1024" | "1024:768");
    if has_reference && !reference_compatible {
        return None;
    }
    Some(resolution.to_string())
}
/// Reads the first reference image (if any) and returns it base64-encoded.
/// `Ok(None)` for an empty slice; an unreadable file is an error.
/// Only the first path is used — the Tencent APIs take a single reference.
fn first_image_base64(image_paths: &[String]) -> Result<Option<String>, AppError> {
    let Some(path) = image_paths.first() else {
        return Ok(None);
    };
    let bytes = fs::read(path)?;
    Ok(Some(general_purpose::STANDARD.encode(bytes)))
}
fn parse_tencent_error(response: &Value) -> Option<String> {
let error = response.get("Error")?;
let code = error
.get("Code")
.and_then(Value::as_str)
.unwrap_or_default();
let message = error
.get("Message")
.and_then(Value::as_str)
.unwrap_or_default();
Some(format!("{code}: {message}"))
}
/// Pulls an image URL out of `ResultImage`, which the APIs return either as
/// a plain string, an array of strings, or an array of objects with a
/// `Url`/`URL` field. Only values starting with "http" count as URLs.
fn result_image_url(response: &Value) -> Option<String> {
    let result = response.get("ResultImage")?;
    if let Some(text) = result.as_str() {
        if text.starts_with("http") {
            return Some(text.to_string());
        }
    }
    let first = result.as_array()?.first()?;
    if let Some(text) = first.as_str() {
        if text.starts_with("http") {
            return Some(text.to_string());
        }
    }
    first
        .get("Url")
        .or_else(|| first.get("URL"))
        .and_then(Value::as_str)
        .map(str::to_string)
}
/// Pulls inline base64 image data out of `ResultImage`: a plain non-URL
/// string, an array of non-URL strings, or an array of objects carrying a
/// `Base64`/`ImageBase64` field. Mirrors `result_image_url` with the
/// "http" test inverted.
fn result_image_base64(response: &Value) -> Option<String> {
    let result = response.get("ResultImage")?;
    if let Some(text) = result.as_str() {
        if !text.starts_with("http") {
            return Some(text.to_string());
        }
    }
    let first = result.as_array()?.first()?;
    if let Some(text) = first.as_str() {
        if !text.starts_with("http") {
            return Some(text.to_string());
        }
    }
    first
        .get("Base64")
        .or_else(|| first.get("ImageBase64"))
        .and_then(Value::as_str)
        .map(str::to_string)
}
#[async_trait]
impl AiProvider for TencentHunyuanProvider {
    /// Text-to-image, routed by model name:
    /// - `hunyuan-image-lite` → synchronous lite action
    /// - `hunyuan-image-2.0` → synchronous rapid action
    /// - everything else → asynchronous submit/poll job flow
    async fn generate_image(&self, request: ImageGenerateRequest) -> Result<ImageResult, AppError> {
        if request.model == "hunyuan-image-lite" {
            return self
                .run_text_to_image_lite(&request.prompt, request.size.as_deref())
                .await;
        }
        if request.model == "hunyuan-image-2.0" {
            return self
                .run_text_to_image_rapid(&request.prompt, request.size.as_deref(), &[])
                .await;
        }
        self.run_async_image_job(&request.prompt, request.size.as_deref(), &[])
            .await
    }

    /// Image editing: `hunyuan-image-2.0` uses the synchronous rapid action
    /// with a reference image; all other models go through the async job flow.
    /// (The lite action has no edit path.)
    async fn edit_image(&self, request: ImageEditRequest) -> Result<ImageResult, AppError> {
        if request.model == "hunyuan-image-2.0" {
            return self
                .run_text_to_image_rapid(
                    &request.prompt,
                    request.size.as_deref(),
                    &request.image_paths,
                )
                .await;
        }
        self.run_async_image_job(
            &request.prompt,
            request.size.as_deref(),
            &request.image_paths,
        )
        .await
    }
}
impl TencentHunyuanProvider {
    /// Signs and sends one Tencent Cloud API request (TC3-HMAC-SHA256) and
    /// returns the inner `Response` object.
    ///
    /// Follows the documented TC3 scheme: hash the canonical request, derive
    /// the signing key from date/service, and carry the credential scope in
    /// the `Authorization` header. Errors cover transport failures, non-2xx
    /// statuses, undecodable bodies, and API-level `Error` objects embedded
    /// in an otherwise-successful response.
    async fn call(&self, action: &str, payload: Value) -> Result<Value, AppError> {
        // NOTE(review): region is hard-coded; confirm all routed models are
        // available in ap-guangzhou.
        let region = "ap-guangzhou";
        let host = host_from_base_url(&self.base_url)?;
        let config = api_config(&host);
        // One clock reading feeds both the X-TC-Timestamp header and the
        // credential-scope date, so the two can never disagree (the previous
        // timestamp_opt round-trip reconstructed the same instant fallibly).
        let now = Utc::now();
        let timestamp = now.timestamp();
        let date = now.format("%Y-%m-%d").to_string();
        let payload_text = serde_json::to_string(&payload).map_err(|error| {
            AppError::Provider(format!("failed to encode Tencent Cloud request: {error}"))
        })?;
        let content_type = "application/json; charset=utf-8";
        let signed_headers = "content-type;host";
        // Canonical headers must each end with '\n'; the empty line in the
        // canonical request below is the (empty) query string.
        let canonical_headers = format!("content-type:{content_type}\nhost:{host}\n");
        let canonical_request = format!(
            "POST\n/\n\n{canonical_headers}\n{signed_headers}\n{}",
            sha256_hex(&payload_text)
        );
        let credential_scope = format!("{}/{}/tc3_request", date, config.service);
        let string_to_sign = format!(
            "TC3-HMAC-SHA256\n{timestamp}\n{credential_scope}\n{}",
            sha256_hex(&canonical_request)
        );
        // Key derivation chain: TC3{SecretKey} -> date -> service -> tc3_request.
        let secret_date = hmac_sha256(format!("TC3{}", self.secret_key).as_bytes(), &date)?;
        let secret_service = hmac_sha256(&secret_date, config.service)?;
        let secret_signing = hmac_sha256(&secret_service, "tc3_request")?;
        let signature = hex::encode(hmac_sha256(&secret_signing, &string_to_sign)?);
        let authorization = format!(
            "TC3-HMAC-SHA256 Credential={}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}",
            self.secret_id
        );
        let response = self
            .client
            .post(&self.base_url)
            .header("Authorization", authorization)
            .header("Content-Type", content_type)
            .header("Host", host)
            .header("X-TC-Action", action)
            .header("X-TC-Timestamp", timestamp.to_string())
            .header("X-TC-Version", config.version)
            .header("X-TC-Region", region)
            .body(payload_text)
            .send()
            .await?;
        let status = response.status();
        let response_text = response.text().await?;
        if !status.is_success() {
            return Err(AppError::Provider(format!(
                "Tencent Hunyuan request failed ({status}): {response_text}"
            )));
        }
        let response_body = serde_json::from_str::<Value>(&response_text).map_err(|error| {
            AppError::Provider(format!(
                "failed to decode Tencent Hunyuan response: {error}; response body: {response_text}"
            ))
        })?;
        let inner = response_body.get("Response").ok_or_else(|| {
            AppError::Provider(format!(
                "Tencent Hunyuan response missing Response: {response_text}"
            ))
        })?;
        // Tencent reports API errors inside a 200 body; surface them here.
        if let Some(message) = parse_tencent_error(inner) {
            return Err(AppError::Provider(format!(
                "Tencent Hunyuan {action} failed: {message}"
            )));
        }
        Ok(inner.clone())
    }

    /// Synchronous text-to-image via the service's lite action.
    /// Asks for a URL response and no logo watermark.
    async fn run_text_to_image_lite(
        &self,
        prompt: &str,
        size: Option<&str>,
    ) -> Result<ImageResult, AppError> {
        let mut payload = json!({
            "Prompt": prompt,
            "RspImgType": "url",
            "LogoAdd": 0,
        });
        if let Some(resolution) = hunyuan_resolution(size, true, false) {
            payload["Resolution"] = json!(resolution);
        }
        let host = host_from_base_url(&self.base_url)?;
        let action = api_config(&host).lite_action;
        let response = self.call(action, payload).await?;
        let image_url = result_image_url(&response).ok_or_else(|| {
            AppError::Provider(format!(
                "Tencent Hunyuan TextToImageLite did not include image url: {response}"
            ))
        })?;
        Ok(ImageResult {
            // NOTE(review): MIME type is assumed; the hosted URL may serve a
            // different format.
            mime_type: "image/png".to_string(),
            data: ImageData::Url(image_url),
        })
    }

    /// Synchronous text-to-image (or single-reference edit) via the rapid
    /// action. At most one reference image is sent as inline base64.
    async fn run_text_to_image_rapid(
        &self,
        prompt: &str,
        size: Option<&str>,
        image_paths: &[String],
    ) -> Result<ImageResult, AppError> {
        let has_reference = !image_paths.is_empty();
        let mut payload = json!({
            "Prompt": prompt,
            "RspImgType": "url",
            "LogoAdd": 0,
        });
        if let Some(resolution) = hunyuan_resolution(size, false, has_reference) {
            payload["Resolution"] = json!(resolution);
        }
        if let Some(image_base64) = first_image_base64(image_paths)? {
            payload["Image"] = json!({
                "Base64": image_base64,
            });
        }
        let host = host_from_base_url(&self.base_url)?;
        let action = api_config(&host).rapid_action;
        let response = self.call(action, payload).await?;
        let image_url = result_image_url(&response).ok_or_else(|| {
            AppError::Provider(format!(
                "Tencent Hunyuan TextToImageRapid did not include image url: {response}"
            ))
        })?;
        Ok(ImageResult {
            mime_type: "image/png".to_string(),
            data: ImageData::Url(image_url),
        })
    }

    /// Submits an asynchronous image job, attaching at most one reference
    /// image (field name differs per service), then polls for the result.
    async fn run_async_image_job(
        &self,
        prompt: &str,
        size: Option<&str>,
        image_paths: &[String],
    ) -> Result<ImageResult, AppError> {
        let has_reference = !image_paths.is_empty();
        let mut payload = json!({
            "Prompt": prompt,
            "Num": 1,
            "Revise": 1,
            "LogoAdd": 0,
        });
        if let Some(resolution) = hunyuan_resolution(size, false, has_reference) {
            payload["Resolution"] = json!(resolution);
        }
        let host = host_from_base_url(&self.base_url)?;
        let config = api_config(&host);
        if let Some(image_base64) = first_image_base64(image_paths)? {
            // The aiart service takes an `Images` array; hunyuan takes a
            // `ContentImage` object.
            if config.service == "aiart" {
                payload["Images"] = json!([image_base64]);
            } else {
                payload["ContentImage"] = json!({
                    "ImageBase64": image_base64,
                });
            }
        }
        let response = self.call(config.submit_action, payload).await?;
        let job_id = response
            .get("JobId")
            .and_then(Value::as_str)
            .ok_or_else(|| {
                AppError::Provider(format!(
                    "Tencent Hunyuan submit response did not include JobId: {response}"
                ))
            })?;
        self.wait_image_job(job_id).await
    }

    /// Polls a submitted job every 3 s, up to 40 times (~2 minutes), until it
    /// succeeds (status "5") or fails (status "4").
    async fn wait_image_job(&self, job_id: &str) -> Result<ImageResult, AppError> {
        // The query action never changes between polls, so resolve it once
        // instead of re-parsing the base URL on every loop iteration.
        let host = host_from_base_url(&self.base_url)?;
        let action = api_config(&host).query_action;
        for _ in 0..40 {
            sleep(Duration::from_secs(3)).await;
            let response = self.call(action, json!({ "JobId": job_id })).await?;
            let job_status = response
                .get("JobStatusCode")
                .and_then(Value::as_str)
                .unwrap_or_default();
            // NOTE(review): "5" = succeeded, "4" = failed per the observed
            // handling here; confirm against the Tencent job-status docs.
            if job_status == "5" {
                // Prefer inline base64 when the job result carries it.
                if let Some(image_base64) = result_image_base64(&response) {
                    return Ok(ImageResult {
                        mime_type: "image/png".to_string(),
                        data: ImageData::Base64(image_base64),
                    });
                }
                let image_url = result_image_url(&response).ok_or_else(|| {
                    AppError::Provider(format!(
                        "Tencent Hunyuan job succeeded but did not include image: {response}"
                    ))
                })?;
                return Ok(ImageResult {
                    mime_type: "image/png".to_string(),
                    data: ImageData::Url(image_url),
                });
            }
            if job_status == "4" {
                let message = response
                    .get("JobErrorMsg")
                    .and_then(Value::as_str)
                    .unwrap_or("unknown error");
                return Err(AppError::Provider(format!(
                    "Tencent Hunyuan image job failed: {message}"
                )));
            }
        }
        Err(AppError::Provider(
            "Tencent Hunyuan image job polling timed out".to_string(),
        ))
    }
}

View File

@@ -11,7 +11,11 @@ pub async fn pick_material_images(app: tauri::AppHandle) -> Result<Vec<String>,
let paths = files
.unwrap_or_default()
.into_iter()
.filter_map(|file_path| file_path.as_path().map(|path| path.to_string_lossy().to_string()))
.filter_map(|file_path| {
file_path
.as_path()
.map(|path| path.to_string_lossy().to_string())
})
.collect();
Ok(paths)

View File

@@ -2,16 +2,21 @@ use tauri::{AppHandle, State};
use crate::{
ai::{
dashscope::DashScopeProvider,
google_gemini::GoogleGeminiProvider,
openai_compatible::OpenAiCompatibleProvider,
provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest},
seedream::SeedreamProvider,
tencent_hunyuan::TencentHunyuanProvider,
},
db::{
models::{CreateGenerationTaskInput, GenerateImageInput, GenerateImageOutput, GenerationTask},
models::{
CreateGenerationTaskInput, GenerateImageInput, GenerateImageOutput, GenerationTask,
},
repository,
},
state::AppState,
storage,
AppError,
storage, AppError,
};
#[tauri::command]
@@ -30,21 +35,66 @@ pub async fn generate_image(
) -> Result<GenerateImageOutput, AppError> {
let provider = repository::get_provider_secret(&state.db, &input.provider_id).await?;
if !provider.enabled {
return Err(AppError::Provider(format!("provider {} is disabled", provider.name)));
return Err(AppError::Provider(format!(
"provider {} is disabled",
provider.name
)));
}
if provider.kind != "openai-compatible" {
if provider.kind != "openai"
&& provider.kind != "openai-compatible"
&& provider.kind != "volcengine-ark"
&& provider.kind != "dashscope"
&& provider.kind != "tencent-hunyuan"
&& provider.kind != "google-gemini"
{
return Err(AppError::Provider(format!(
"provider kind {} is not supported yet",
provider.kind
)));
}
if !provider.base_url.trim_end_matches('/').ends_with("/v1") {
if (provider.kind == "openai" || provider.kind == "openai-compatible")
&& !provider.base_url.trim_end_matches('/').ends_with("/v1")
{
return Err(AppError::Provider(
"Base URL 看起来不是 API 地址。OpenAI-compatible 地址通常需要以 /v1 结尾,例如 https://api.openai.com/v1 或 https://你的中转站域名/v1".to_string(),
));
}
if provider.kind == "volcengine-ark"
&& !provider.base_url.trim_end_matches('/').ends_with("/api/v3")
{
return Err(AppError::Provider(
"火山方舟 Seedream 的 Base URL 通常需要以 /api/v3 结尾,例如 https://ark.cn-beijing.volces.com/api/v3".to_string(),
));
}
if provider.kind == "dashscope" && !provider.base_url.trim_end_matches('/').ends_with("/api/v1")
{
return Err(AppError::Provider(
"阿里云百炼 DashScope 的 Base URL 通常需要以 /api/v1 结尾,例如 https://dashscope.aliyuncs.com/api/v1".to_string(),
));
}
if provider.kind == "google-gemini"
&& !provider.base_url.trim_end_matches('/').ends_with("/v1beta")
{
return Err(AppError::Provider(
"Google Gemini / Nano Banana 的 Base URL 通常填写 https://generativelanguage.googleapis.com/v1beta".to_string(),
));
}
if provider.kind == "tencent-hunyuan"
&& !provider
.base_url
.trim_end_matches('/')
.ends_with("aiart.tencentcloudapi.com")
&& !provider
.base_url
.trim_end_matches('/')
.ends_with("hunyuan.tencentcloudapi.com")
{
return Err(AppError::Provider(
"腾讯混元图像的 Base URL 通常填写 https://aiart.tencentcloudapi.com".to_string(),
));
}
let api_key = provider
.api_key_encrypted
@@ -76,30 +126,121 @@ pub async fn generate_image(
)
.await?;
let ai_provider = OpenAiCompatibleProvider::new(provider.base_url, api_key);
let image_result = match if input.image_paths.is_empty() {
ai_provider
.generate_image(ImageGenerateRequest {
prompt: input.prompt,
model,
size: input.size,
quality: input.quality,
})
.await
let image_result = match if provider.kind == "volcengine-ark" {
let ai_provider = SeedreamProvider::new(provider.base_url, api_key);
if input.image_paths.is_empty() {
ai_provider
.generate_image(ImageGenerateRequest {
prompt: input.prompt,
model,
size: input.size,
quality: input.quality,
})
.await
} else {
ai_provider
.edit_image(ImageEditRequest {
prompt: input.prompt,
model,
size: input.size,
quality: input.quality,
image_paths: input.image_paths,
})
.await
}
} else if provider.kind == "dashscope" {
let ai_provider = DashScopeProvider::new(provider.base_url, api_key);
if input.image_paths.is_empty() {
ai_provider
.generate_image(ImageGenerateRequest {
prompt: input.prompt,
model,
size: input.size,
quality: input.quality,
})
.await
} else {
ai_provider
.edit_image(ImageEditRequest {
prompt: input.prompt,
model,
size: input.size,
quality: input.quality,
image_paths: input.image_paths,
})
.await
}
} else if provider.kind == "tencent-hunyuan" {
let ai_provider = TencentHunyuanProvider::new(provider.base_url, api_key)?;
if input.image_paths.is_empty() {
ai_provider
.generate_image(ImageGenerateRequest {
prompt: input.prompt,
model,
size: input.size,
quality: input.quality,
})
.await
} else {
ai_provider
.edit_image(ImageEditRequest {
prompt: input.prompt,
model,
size: input.size,
quality: input.quality,
image_paths: input.image_paths,
})
.await
}
} else if provider.kind == "google-gemini" {
let ai_provider = GoogleGeminiProvider::new(provider.base_url, api_key);
if input.image_paths.is_empty() {
ai_provider
.generate_image(ImageGenerateRequest {
prompt: input.prompt,
model,
size: input.size,
quality: input.quality,
})
.await
} else {
ai_provider
.edit_image(ImageEditRequest {
prompt: input.prompt,
model,
size: input.size,
quality: input.quality,
image_paths: input.image_paths,
})
.await
}
} else {
ai_provider
.edit_image(ImageEditRequest {
prompt: input.prompt,
model,
size: input.size,
quality: input.quality,
image_paths: input.image_paths,
})
.await
let ai_provider = OpenAiCompatibleProvider::new(provider.base_url, api_key);
if input.image_paths.is_empty() {
ai_provider
.generate_image(ImageGenerateRequest {
prompt: input.prompt,
model,
size: input.size,
quality: input.quality,
})
.await
} else {
ai_provider
.edit_image(ImageEditRequest {
prompt: input.prompt,
model,
size: input.size,
quality: input.quality,
image_paths: input.image_paths,
})
.await
}
} {
Ok(result) => result,
Err(error) => {
repository::mark_generation_task_failed(&state.db, &task.id, &error.to_string()).await?;
repository::mark_generation_task_failed(&state.db, &task.id, &error.to_string())
.await?;
return Err(error);
}
};
@@ -108,7 +249,8 @@ pub async fn generate_image(
ImageData::Base64(data_base64) => storage::decode_base64_image(data_base64)?,
ImageData::Url(url) => reqwest::get(url).await?.bytes().await?.to_vec(),
};
let stored_image = storage::save_generated_image_bytes(&app, &image_bytes, &image_result.mime_type)?;
let stored_image =
storage::save_generated_image_bytes(&app, &image_bytes, &image_result.mime_type)?;
let file_path = stored_image.file_path.to_string_lossy().to_string();
let asset = repository::create_image_asset(
&state.db,

View File

@@ -1,3 +1,6 @@
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tauri::State;
use crate::{
@@ -7,12 +10,29 @@ use crate::{
};
#[tauri::command]
pub async fn list_providers(state: State<'_, AppState>) -> Result<Vec<crate::db::models::ProviderConfig>, AppError> {
pub async fn list_providers(
state: State<'_, AppState>,
) -> Result<Vec<crate::db::models::ProviderConfig>, AppError> {
repository::list_providers(&state.db).await
}
#[tauri::command]
pub async fn upsert_provider(state: State<'_, AppState>, input: UpsertProviderInput) -> Result<(), AppError> {
pub async fn upsert_provider(
state: State<'_, AppState>,
input: UpsertProviderInput,
) -> Result<(), AppError> {
if input
.api_key
.as_deref()
.map(str::trim)
.unwrap_or_default()
.is_empty()
{
return Err(AppError::Provider(
"请先填写 API Key再保存配置".to_string(),
));
}
repository::upsert_provider(&state.db, input).await
}
@@ -20,3 +40,306 @@ pub async fn upsert_provider(state: State<'_, AppState>, input: UpsertProviderIn
pub async fn delete_provider(state: State<'_, AppState>, id: String) -> Result<(), AppError> {
repository::delete_provider(&state.db, &id).await
}
/// A single model entry exposed to the frontend model picker.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ProviderModel {
    /// Model identifier as reported by (or hard-coded for) the provider.
    pub id: String,
    /// Owning organisation, when the provider reports one.
    pub owned_by: Option<String>,
}
/// Top-level shape of an OpenAI-compatible `GET /models` response body.
#[derive(Debug, Deserialize)]
struct ModelsResponse {
    data: Vec<ModelItem>,
}
/// One entry in an OpenAI-compatible `GET /models` listing.
#[derive(Debug, Deserialize)]
struct ModelItem {
    id: String,
    // Optional because some compatible gateways omit this field.
    owned_by: Option<String>,
}
/// Subset of the persisted provider `capabilities` JSON that model fetching
/// reads back (previously fetched models and the user's selection).
#[derive(Debug, Deserialize)]
struct ProviderCapabilities {
    image_models: Option<Vec<ProviderModel>>,
    selected_image_models: Option<Vec<String>>,
}
/// Heuristically decides whether a model id refers to an image model.
///
/// Matching is case-insensitive and substring based, covering the naming
/// markers used by the providers this app supports.
fn is_image_model(model_id: &str) -> bool {
    const MARKERS: [&str; 17] = [
        "gpt-image",
        "image",
        "dall-e",
        "dalle",
        "imagen",
        "flux",
        "qwen-image",
        "wan",
        "z-image",
        "hunyuan-image",
        "gemini",
        "seedream",
        "seededit",
        "stable-diffusion",
        "sd-",
        "midjourney",
        "recraft",
    ];
    let lowercase_id = model_id.to_ascii_lowercase();
    MARKERS.iter().any(|marker| lowercase_id.contains(marker))
}
/// Fallback Seedream model list used when Volcengine Ark does not expose a
/// `/models` endpoint (or returns an empty listing).
fn default_seedream_models() -> Vec<ProviderModel> {
    ["doubao-seedream-4-5-251128", "doubao-seedream-4-0-250828"]
        .into_iter()
        .map(|id| ProviderModel {
            id: id.to_string(),
            owned_by: Some("volcengine".to_string()),
        })
        .collect()
}
/// Built-in model catalogue for Alibaba Cloud DashScope, which has no
/// OpenAI-style model-listing endpoint.
fn default_dashscope_models() -> Vec<ProviderModel> {
    [
        "qwen-image-2.0-pro",
        "qwen-image-2.0",
        "qwen-image-plus",
        "qwen-image",
        "wan2.7-image-pro",
        "wan2.7-image",
        "z-image-turbo",
    ]
    .into_iter()
    .map(|id| ProviderModel {
        id: id.to_string(),
        owned_by: Some("alibaba-cloud".to_string()),
    })
    .collect()
}
/// Built-in model catalogue for Tencent Hunyuan image generation, which has
/// no OpenAI-style model-listing endpoint.
fn default_tencent_hunyuan_models() -> Vec<ProviderModel> {
    ["hunyuan-image-3.0", "hunyuan-image-2.0", "hunyuan-image-lite"]
        .into_iter()
        .map(|id| ProviderModel {
            id: id.to_string(),
            owned_by: Some("tencent-cloud".to_string()),
        })
        .collect()
}
/// Built-in model catalogue for Google Gemini image generation, used instead
/// of querying a model-listing endpoint.
fn default_google_gemini_models() -> Vec<ProviderModel> {
    [
        "gemini-2.5-flash-image",
        "gemini-3.1-flash-image-preview",
        "gemini-3-pro-image-preview",
    ]
    .into_iter()
    .map(|id| ProviderModel {
        id: id.to_string(),
        owned_by: Some("google".to_string()),
    })
    .collect()
}
/// Tauri command: fetch (or synthesize) the image-model list for a provider,
/// persist it into the provider's capabilities, and return it.
///
/// Flow: validate the provider kind and its base URL shape, resolve an API
/// key (from the input, falling back to the saved provider), then either use
/// a built-in catalogue (DashScope / Tencent Hunyuan / Google Gemini) or hit
/// the OpenAI-compatible `GET /models` endpoint and filter for image models.
/// Errors are surfaced as `AppError::Provider` with user-facing (Chinese)
/// messages.
#[tauri::command]
pub async fn fetch_provider_models(
    state: State<'_, AppState>,
    input: UpsertProviderInput,
) -> Result<Vec<ProviderModel>, AppError> {
    // Only these provider kinds are wired up for model discovery.
    if input.kind != "openai"
        && input.kind != "openai-compatible"
        && input.kind != "volcengine-ark"
        && input.kind != "dashscope"
        && input.kind != "tencent-hunyuan"
        && input.kind != "google-gemini"
    {
        return Err(AppError::Provider(format!(
            "API 分类 {} 暂未接入模型列表获取",
            input.kind
        )));
    }
    // Per-kind sanity checks on the base URL, to catch the common mistake of
    // pasting a console/web URL instead of the API endpoint.
    if (input.kind == "openai" || input.kind == "openai-compatible")
        && !input.base_url.trim_end_matches('/').ends_with("/v1")
    {
        return Err(AppError::Provider(
            "Base URL 看起来不是 API 地址。OpenAI-compatible 地址通常需要以 /v1 结尾。".to_string(),
        ));
    }
    if input.kind == "volcengine-ark" && !input.base_url.trim_end_matches('/').ends_with("/api/v3")
    {
        return Err(AppError::Provider(
            "火山方舟 Seedream 的 Base URL 通常需要以 /api/v3 结尾。".to_string(),
        ));
    }
    if input.kind == "dashscope" && !input.base_url.trim_end_matches('/').ends_with("/api/v1") {
        return Err(AppError::Provider(
            "阿里云百炼 DashScope 的 Base URL 通常需要以 /api/v1 结尾。".to_string(),
        ));
    }
    if input.kind == "google-gemini" && !input.base_url.trim_end_matches('/').ends_with("/v1beta") {
        return Err(AppError::Provider(
            "Google Gemini / Nano Banana 的 Base URL 通常填写 https://generativelanguage.googleapis.com/v1beta".to_string(),
        ));
    }
    // Tencent accepts either the aiart or the hunyuan API host.
    if input.kind == "tencent-hunyuan"
        && !input
            .base_url
            .trim_end_matches('/')
            .ends_with("aiart.tencentcloudapi.com")
        && !input
            .base_url
            .trim_end_matches('/')
            .ends_with("hunyuan.tencentcloudapi.com")
    {
        return Err(AppError::Provider(
            "腾讯混元图像的 Base URL 通常填写 https://aiart.tencentcloudapi.com".to_string(),
        ));
    }
    // Fall back to the stored secret when no key accompanies the request.
    let saved_provider = repository::get_provider_secret(&state.db, &input.id)
        .await
        .ok();
    // NOTE(review): a whitespace-only key yields None here rather than
    // falling back to the saved key — presumably "user cleared the field"
    // is meant to invalidate it; confirm this is intentional.
    let api_key = match input.api_key.clone() {
        Some(key) if !key.trim().is_empty() => Some(key),
        Some(_) => None,
        None => saved_provider.and_then(|provider| provider.api_key_encrypted),
    };
    let Some(api_key) = api_key else {
        return Err(AppError::Provider(
            "API Key 为空,无法获取模型列表".to_string(),
        ));
    };
    // These kinds have no listing endpoint: persist the built-in catalogue.
    if input.kind == "dashscope" {
        return save_provider_models(&state, input, default_dashscope_models()).await;
    }
    if input.kind == "tencent-hunyuan" {
        return save_provider_models(&state, input, default_tencent_hunyuan_models()).await;
    }
    if input.kind == "google-gemini" {
        return save_provider_models(&state, input, default_google_gemini_models()).await;
    }
    // OpenAI-compatible (and Ark) path: query the /models endpoint.
    let response = Client::new()
        .get(format!("{}/models", input.base_url.trim_end_matches('/')))
        .bearer_auth(api_key)
        .send()
        .await?;
    let status = response.status();
    if !status.is_success() {
        // Ark deployments often lack /models entirely; treat "endpoint not
        // there" statuses as a cue to use the built-in Seedream list.
        if input.kind == "volcengine-ark" && matches!(status.as_u16(), 404 | 405 | 501) {
            let fetched_models = default_seedream_models();
            return save_provider_models(&state, input, fetched_models).await;
        }
        let message = response
            .text()
            .await
            .unwrap_or_else(|_| "request failed".to_string());
        return Err(AppError::Provider(format!(
            "获取模型列表失败 ({status}): {message}"
        )));
    }
    // Keep only ids that look like image models.
    let mut fetched_models: Vec<ProviderModel> = response
        .json::<ModelsResponse>()
        .await?
        .data
        .into_iter()
        .filter(|model| is_image_model(&model.id))
        .map(|model| ProviderModel {
            id: model.id,
            owned_by: model.owned_by,
        })
        .collect();
    // An Ark listing with no image models also falls back to the defaults.
    if input.kind == "volcengine-ark" && fetched_models.is_empty() {
        fetched_models = default_seedream_models();
    }
    save_provider_models(&state, input, fetched_models).await
}
/// Merges freshly fetched models into the provider's persisted capabilities
/// and upserts the provider; returns the merged, sorted model list.
///
/// Previously saved models are kept (union with `fetched_models`), and the
/// user's previous selection is preserved where it still refers to an
/// existing model; any new models are appended as selected.
async fn save_provider_models(
    state: &State<'_, AppState>,
    input: UpsertProviderInput,
    fetched_models: Vec<ProviderModel>,
) -> Result<Vec<ProviderModel>, AppError> {
    // Read back what is currently stored for this provider, if anything.
    let saved_providers = repository::list_providers(&state.db)
        .await
        .unwrap_or_default();
    let saved_capabilities = saved_providers
        .iter()
        .find(|provider| provider.id == input.id)
        .and_then(|provider| provider.capabilities.as_deref())
        .and_then(|value| serde_json::from_str::<ProviderCapabilities>(value).ok());
    // Union of saved + fetched models, deduplicated by id.
    let mut models = saved_capabilities
        .as_ref()
        .and_then(|capabilities| capabilities.image_models.clone())
        .unwrap_or_default();
    models.extend(fetched_models);
    models.sort_by(|left, right| left.id.cmp(&right.id));
    // dedup_by requires the sort above: it only removes adjacent duplicates.
    models.dedup_by(|left, right| left.id == right.id);
    let model_ids = models
        .iter()
        .map(|model| model.id.clone())
        .collect::<Vec<_>>();
    // Keep the previous selection (dropping ids that no longer exist) …
    let mut selected_model_ids = saved_capabilities
        .and_then(|capabilities| capabilities.selected_image_models)
        .unwrap_or_default()
        .into_iter()
        .filter(|model_id| model_ids.contains(model_id))
        .collect::<Vec<_>>();
    // … then select every model not already selected (new ones default on).
    for model_id in &model_ids {
        if !selected_model_ids.contains(model_id) {
            selected_model_ids.push(model_id.clone());
        }
    }
    let capabilities = json!({
        "responses_api": true,
        "images_api": true,
        "chat_completions": true,
        "image_edit": true,
        "image_models": models,
        "selected_image_models": selected_model_ids,
    });
    repository::upsert_provider(
        &state.db,
        UpsertProviderInput {
            capabilities: Some(capabilities.to_string()),
            ..input
        },
    )
    .await?;
    Ok(models)
}

View File

@@ -1,6 +1,6 @@
use std::fs;
use sqlx::{sqlite::SqliteConnectOptions, SqlitePool};
use sqlx::{sqlite::SqliteConnectOptions, Row, SqlitePool};
use tauri::{AppHandle, Manager};
use crate::AppError;
@@ -24,6 +24,38 @@ pub async fn init(app: &AppHandle) -> Result<SqlitePool, AppError> {
sqlx::query(statement).execute(&pool).await?;
}
}
ensure_column(
&pool,
"providers",
"capabilities",
"TEXT NOT NULL DEFAULT '{}'",
)
.await?;
Ok(pool)
}
/// Adds `column` to `table` when it does not exist yet — a minimal in-place
/// schema migration for SQLite.
///
/// Inspects `PRAGMA table_info` and issues `ALTER TABLE … ADD COLUMN` only if
/// the column is missing. `table`, `column` and `definition` are interpolated
/// directly into the SQL, so they must come from trusted code paths, never
/// from user input.
async fn ensure_column(
    pool: &SqlitePool,
    table: &str,
    column: &str,
    definition: &str,
) -> Result<(), AppError> {
    let table_info = sqlx::query(&format!("PRAGMA table_info({table})"))
        .fetch_all(pool)
        .await?;
    let already_present = table_info
        .iter()
        .any(|row| matches!(row.try_get::<String, _>("name"), Ok(name) if name == column));
    if already_present {
        return Ok(());
    }
    let alter_statement = format!("ALTER TABLE {table} ADD COLUMN {column} {definition}");
    sqlx::query(&alter_statement).execute(pool).await?;
    Ok(())
}

View File

@@ -10,6 +10,7 @@ pub struct ProviderConfig {
pub api_key: Option<String>,
pub text_model: Option<String>,
pub image_model: Option<String>,
pub capabilities: Option<String>,
pub enabled: bool,
}
@@ -33,6 +34,7 @@ pub struct UpsertProviderInput {
pub api_key: Option<String>,
pub text_model: Option<String>,
pub image_model: Option<String>,
pub capabilities: Option<String>,
pub enabled: bool,
}

View File

@@ -19,8 +19,10 @@ pub async fn list_providers(pool: &SqlitePool) -> Result<Vec<ProviderConfig>, Ap
api_key_encrypted AS api_key,
text_model,
image_model,
capabilities,
enabled != 0 AS enabled
FROM providers
WHERE enabled != 0
ORDER BY updated_at DESC
"#,
)
@@ -30,24 +32,34 @@ pub async fn list_providers(pool: &SqlitePool) -> Result<Vec<ProviderConfig>, Ap
Ok(providers)
}
pub async fn upsert_provider(pool: &SqlitePool, input: UpsertProviderInput) -> Result<(), AppError> {
pub async fn upsert_provider(
pool: &SqlitePool,
input: UpsertProviderInput,
) -> Result<(), AppError> {
let now = Utc::now().to_rfc3339();
let existing_api_key: Option<String> = sqlx::query_scalar(
let existing: Option<(Option<String>, Option<String>)> = sqlx::query_as(
r#"
SELECT api_key_encrypted
SELECT api_key_encrypted, capabilities
FROM providers
WHERE id = ?1
"#,
)
.bind(&input.id)
.fetch_optional(pool)
.await?
.flatten();
let api_key = input
.api_key
.filter(|key| !key.trim().is_empty())
.or(existing_api_key);
.await?;
let api_key = match input.api_key {
Some(key) if key.trim().is_empty() => None,
Some(key) => Some(key),
None => existing.as_ref().and_then(|item| item.0.clone()),
};
let capabilities = input
.capabilities
.filter(|value| !value.trim().is_empty())
.or_else(|| existing.and_then(|item| item.1))
.unwrap_or_else(|| {
r#"{"responses_api":true,"images_api":true,"chat_completions":true,"image_edit":true,"image_models":[],"selected_image_models":[]}"#.to_string()
});
sqlx::query(
r#"
@@ -62,6 +74,7 @@ pub async fn upsert_provider(pool: &SqlitePool, input: UpsertProviderInput) -> R
api_key_encrypted = excluded.api_key_encrypted,
text_model = excluded.text_model,
image_model = excluded.image_model,
capabilities = excluded.capabilities,
enabled = excluded.enabled,
updated_at = excluded.updated_at
"#,
@@ -73,7 +86,7 @@ pub async fn upsert_provider(pool: &SqlitePool, input: UpsertProviderInput) -> R
.bind(api_key)
.bind(input.text_model)
.bind(input.image_model)
.bind(r#"{"responses_api":true,"images_api":true,"chat_completions":true,"image_edit":true}"#)
.bind(capabilities)
.bind(input.enabled)
.bind(now)
.execute(pool)
@@ -98,6 +111,34 @@ pub async fn get_provider_secret(pool: &SqlitePool, id: &str) -> Result<Provider
}
pub async fn delete_provider(pool: &SqlitePool, id: &str) -> Result<(), AppError> {
let reference_count: i64 = sqlx::query_scalar(
r#"
SELECT
(SELECT COUNT(*) FROM generation_tasks WHERE provider_id = ?1) +
(SELECT COUNT(*) FROM conversations WHERE provider_id = ?1) +
(SELECT COUNT(*) FROM ai_request_logs WHERE provider_id = ?1)
"#,
)
.bind(id)
.fetch_one(pool)
.await?;
if reference_count > 0 {
let now = Utc::now().to_rfc3339();
sqlx::query(
r#"
UPDATE providers
SET enabled = 0, updated_at = ?2
WHERE id = ?1
"#,
)
.bind(id)
.bind(now)
.execute(pool)
.await?;
return Ok(());
}
sqlx::query("DELETE FROM providers WHERE id = ?1")
.bind(id)
.execute(pool)

View File

@@ -1,8 +1,8 @@
mod ai;
mod commands;
mod db;
mod storage;
mod state;
mod storage;
use serde::Serialize;
use state::AppState;
@@ -53,6 +53,7 @@ pub fn run() {
commands::provider::list_providers,
commands::provider::upsert_provider,
commands::provider::delete_provider,
commands::provider::fetch_provider_models,
commands::dialog::pick_material_images,
commands::file::reveal_path,
commands::file::open_generated_dir,