feat: add multi-provider image generation

This commit is contained in:
2026-04-29 15:47:18 +08:00
parent 20bd25e136
commit a4198f29d2
18 changed files with 3397 additions and 652 deletions

View File

@@ -16,17 +16,18 @@
"build:win": "tauri build"
},
"dependencies": {
"@ant-design/icons": "^6.2.2",
"@tauri-apps/api": "^2.0.0",
"@tauri-apps/plugin-dialog": "^2.0.0",
"@vitejs/plugin-react": "^5.0.0",
"vite": "^7.0.0",
"typescript": "^5.0.0",
"react": "^19.0.0",
"react-dom": "^19.0.0"
"react-dom": "^19.0.0",
"typescript": "^5.0.0",
"vite": "^7.0.0"
},
"devDependencies": {
"@tauri-apps/cli": "^2.0.0",
"@types/react": "^19.0.0",
"@types/react-dom": "^19.0.0",
"@tauri-apps/cli": "^2.0.0"
"@types/react-dom": "^19.0.0"
}
}

66
pnpm-lock.yaml generated
View File

@@ -8,6 +8,9 @@ importers:
.:
dependencies:
'@ant-design/icons':
specifier: ^6.2.2
version: 6.2.2(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
'@tauri-apps/api':
specifier: ^2.0.0
version: 2.10.1
@@ -42,6 +45,23 @@ importers:
packages:
'@ant-design/colors@8.0.1':
resolution: {integrity: sha512-foPVl0+SWIslGUtD/xBr1p9U4AKzPhNYEseXYRRo5QSzGACYZrQbe11AYJbYfAWnWSpGBx6JjBmSeugUsD9vqQ==}
'@ant-design/fast-color@3.0.1':
resolution: {integrity: sha512-esKJegpW4nckh0o6kV3Tkb7NPIZYbPnnFxmQDUmL08ukXZAvV85TZBr70eGuke/CIArLaP6aw8lt9KILjnWuOw==}
engines: {node: '>=8.x'}
'@ant-design/icons-svg@4.4.2':
resolution: {integrity: sha512-vHbT+zJEVzllwP+CM+ul7reTEfBR0vgxFe7+lREAsAA7YGsYpboiq2sQNeQeRvh09GfQgs/GyFEvZpJ9cLXpXA==}
'@ant-design/icons@6.2.2':
resolution: {integrity: sha512-zlJtE7AMbG12TeYVPhtBXwNpFInNy8mjLzcIm+0BPw16/b8ODG87YJ1G37VIF5VFscdgfsf6EweAFPTobu/3iQ==}
engines: {node: '>=8'}
peerDependencies:
react: '>=16.0.0'
react-dom: '>=16.0.0'
'@babel/code-frame@7.29.0':
resolution: {integrity: sha512-9NhCeYjq9+3uxgdtp20LSiJXJvN0FeCtNGpJxuMFZ1Kv3cWUNb6DOhJwUvcVCzKGR66cw4njwM6hrJLqgOwbcw==}
engines: {node: '>=6.9.0'}
@@ -297,6 +317,12 @@ packages:
'@jridgewell/trace-mapping@0.3.31':
resolution: {integrity: sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw==}
'@rc-component/util@1.10.1':
resolution: {integrity: sha512-q++9S6rUa5Idb/xIBNz6jtvumw5+O5YV5V0g4iK9mn9jWs4oGJheE3ZN1kAnE723AXyaD8v95yeOASmdk8Jnng==}
peerDependencies:
react: '>=18.0.0'
react-dom: '>=18.0.0'
'@rolldown/pluginutils@1.0.0-rc.3':
resolution: {integrity: sha512-eybk3TjzzzV97Dlj5c+XrBFW57eTNhzod66y9HrBlzJ6NsCrWCp/2kaPS3K9wJmurBC0Tdw4yPjXKZqlznim3Q==}
@@ -544,6 +570,10 @@ packages:
caniuse-lite@1.0.30001790:
resolution: {integrity: sha512-bOoxfJPyYo+ds6W0YfptaCWbFnJYjh2Y1Eow5lRv+vI2u8ganPZqNm1JwNh0t2ELQCqIWg4B3dWEusgAmsoyOw==}
clsx@2.1.1:
resolution: {integrity: sha512-eYm0QWBtUrBWZWG0d386OGAw16Z995PiOVo2B7bjWSbHedGl5e0ZWaq65kOGgUSNesEIDkB9ISbTg/JK9dhCZA==}
engines: {node: '>=6'}
convert-source-map@2.0.0:
resolution: {integrity: sha512-Kvp459HrV2FEJ1CAsi1Ku+MY3kasH19TFykTz2xWmMeq6bk2NU3XXvfJ+Q61m0xktWwt+1HSYf3JZsTms3aRJg==}
@@ -589,6 +619,9 @@ packages:
resolution: {integrity: sha512-3hN7NaskYvMDLQY55gnW3NQ+mesEAepTqlg+VEbj7zzqEMBVNhzcGYYeqFo/TlYz6eQiFcp1HcsCZO+nGgS8zg==}
engines: {node: '>=6.9.0'}
is-mobile@5.0.0:
resolution: {integrity: sha512-Tz/yndySvLAEXh+Uk8liFCxOwVH6YutuR74utvOcu7I9Di+DwM0mtdPVZNaVvvBUM2OXxne/NhOs1zAO7riusQ==}
js-tokens@4.0.0:
resolution: {integrity: sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==}
@@ -632,6 +665,9 @@ packages:
peerDependencies:
react: ^19.2.5
react-is@18.3.1:
resolution: {integrity: sha512-/LLMVyas0ljjAtoYiPqYiL8VWXzUUdThrmU5+n20DZv+a+ClRoevUzw5JxU+Ieh5/c87ytoTBV9G1FiKfNJdmg==}
react-refresh@0.18.0:
resolution: {integrity: sha512-QgT5//D3jfjJb6Gsjxv0Slpj23ip+HtOpnNgnb2S5zU3CB26G/IDPGoy4RJB42wzFE46DRsstbW6tKHoKbhAxw==}
engines: {node: '>=0.10.0'}
@@ -716,6 +752,23 @@ packages:
snapshots:
'@ant-design/colors@8.0.1':
dependencies:
'@ant-design/fast-color': 3.0.1
'@ant-design/fast-color@3.0.1': {}
'@ant-design/icons-svg@4.4.2': {}
'@ant-design/icons@6.2.2(react-dom@19.2.5(react@19.2.5))(react@19.2.5)':
dependencies:
'@ant-design/colors': 8.0.1
'@ant-design/icons-svg': 4.4.2
'@rc-component/util': 1.10.1(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
clsx: 2.1.1
react: 19.2.5
react-dom: 19.2.5(react@19.2.5)
'@babel/code-frame@7.29.0':
dependencies:
'@babel/helper-validator-identifier': 7.28.5
@@ -925,6 +978,13 @@ snapshots:
'@jridgewell/resolve-uri': 3.1.2
'@jridgewell/sourcemap-codec': 1.5.5
'@rc-component/util@1.10.1(react-dom@19.2.5(react@19.2.5))(react@19.2.5)':
dependencies:
is-mobile: 5.0.0
react: 19.2.5
react-dom: 19.2.5(react@19.2.5)
react-is: 18.3.1
'@rolldown/pluginutils@1.0.0-rc.3': {}
'@rollup/rollup-android-arm-eabi@4.60.2':
@@ -1110,6 +1170,8 @@ snapshots:
caniuse-lite@1.0.30001790: {}
clsx@2.1.1: {}
convert-source-map@2.0.0: {}
csstype@3.2.3: {}
@@ -1160,6 +1222,8 @@ snapshots:
gensync@1.0.0-beta.2: {}
is-mobile@5.0.0: {}
js-tokens@4.0.0: {}
jsesc@3.1.0: {}
@@ -1191,6 +1255,8 @@ snapshots:
react: 19.2.5
scheduler: 0.27.0
react-is@18.3.1: {}
react-refresh@0.18.0: {}
react@19.2.5: {}

View File

@@ -26,6 +26,9 @@ reqwest = { version = "0.12", default-features = false, features = ["json", "mul
thiserror = "2"
async-trait = "0.1"
base64 = "0.22"
hex = "0.4"
hmac = "0.12"
sha2 = "0.10"
[features]
default = ["custom-protocol"]

View File

@@ -0,0 +1,354 @@
use async_trait::async_trait;
use base64::{engine::general_purpose, Engine as _};
use reqwest::Client;
use serde_json::{json, Value};
use std::{fs, path::Path, time::Duration};
use tokio::time::sleep;
use super::provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest, ImageResult};
use crate::AppError;
/// Image provider that talks to a DashScope-style HTTP API.
pub struct DashScopeProvider {
    /// Shared HTTP client used for every request.
    client: Client,
    /// API root without a trailing slash.
    base_url: String,
    /// Credential sent as a bearer token.
    api_key: String,
}

impl DashScopeProvider {
    /// Builds a provider for `base_url`, stripping any trailing `/` so that
    /// endpoint paths can be appended with a single separator.
    pub fn new(base_url: String, api_key: String) -> Self {
        let normalized_url = base_url.trim_end_matches('/').to_string();
        let client = Client::new();
        Self {
            client,
            base_url: normalized_url,
            api_key,
        }
    }
}
/// Returns `true` when the model name begins with `qwen-image`.
fn is_qwen_image_model(model: &str) -> bool {
    model.strip_prefix("qwen-image").is_some()
}
/// Returns `true` for models routed to the synchronous multimodal endpoint:
/// names starting with `z-image`, or anything `is_qwen_image_model` accepts.
fn is_sync_multimodal_model(model: &str) -> bool {
    if model.starts_with("z-image") {
        return true;
    }
    is_qwen_image_model(model)
}
/// Rewrites a `WIDTHxHEIGHT` size string into the `WIDTH*HEIGHT` form by
/// replacing every lowercase `x` with `*`. `None` stays `None`.
fn dashscope_size(size: Option<&str>) -> Option<String> {
    let requested = size?;
    let pieces: Vec<&str> = requested.split('x').collect();
    Some(pieces.join("*"))
}
/// Guesses an image MIME type from the file extension of `path`
/// (ASCII case-insensitive). Anything other than jpg/jpeg/webp, including a
/// missing or non-UTF-8 extension, is reported as `image/png`.
fn mime_for_path(path: &str) -> &'static str {
    let extension = Path::new(path).extension().and_then(|raw| raw.to_str());
    match extension {
        Some(ext) if ext.eq_ignore_ascii_case("jpg") || ext.eq_ignore_ascii_case("jpeg") => {
            "image/jpeg"
        }
        Some(ext) if ext.eq_ignore_ascii_case("webp") => "image/webp",
        _ => "image/png",
    }
}
/// Reads the file at `path` and returns it as a `data:<mime>;base64,<payload>`
/// URL. The MIME type comes from `mime_for_path`; read failures are returned
/// as `AppError`.
fn path_to_data_url(path: &str) -> Result<String, AppError> {
    let encoded = general_purpose::STANDARD.encode(fs::read(path)?);
    let mime = mime_for_path(path);
    Ok(format!("data:{};base64,{}", mime, encoded))
}
/// Extracts a short `code: message` description from a DashScope error body.
///
/// The `code` / `message` pair is looked up at the top level of the JSON
/// document first and then inside the `output` object, so that errors
/// reported in either place are summarized instead of dumping the whole body.
/// When only one of the two fields is present it is returned on its own.
///
/// Returns `fallback` when `response_text` is not JSON or carries neither
/// field in either location.
fn parse_dashscope_error(response_text: &str, fallback: &str) -> String {
    /// Reads `key` from `scope` as a string, treating missing/non-string as empty.
    fn text_field<'a>(scope: &'a Value, key: &str) -> &'a str {
        scope.get(key).and_then(Value::as_str).unwrap_or_default()
    }

    let body = match serde_json::from_str::<Value>(response_text) {
        Ok(body) => body,
        Err(_) => return fallback.to_string(),
    };

    // Top-level fields win; `output` is only consulted when they are absent.
    Some(&body)
        .into_iter()
        .chain(body.get("output"))
        .find_map(|scope| {
            let code = text_field(scope, "code");
            let message = text_field(scope, "message");
            match (code.is_empty(), message.is_empty()) {
                (true, true) => None,
                (false, true) => Some(code.to_string()),
                (true, false) => Some(message.to_string()),
                (false, false) => Some(format!("{code}: {message}")),
            }
        })
        .unwrap_or_else(|| fallback.to_string())
}
/// Finds the first string `image` entry under
/// `output.choices[*].message.content[*]`, scanning choices in order.
fn image_url_from_choices(value: &Value) -> Option<String> {
    let choices = value.get("output")?.get("choices")?.as_array()?;
    for choice in choices {
        let content = choice
            .get("message")
            .and_then(|message| message.get("content"))
            .and_then(Value::as_array);
        // Choices without a content array are skipped rather than aborting the search.
        if let Some(entries) = content {
            for entry in entries {
                if let Some(url) = entry.get("image").and_then(Value::as_str) {
                    return Some(url.to_string());
                }
            }
        }
    }
    None
}
/// Finds the first string `url` entry under `output.results[*]`.
fn image_url_from_results(value: &Value) -> Option<String> {
    let results = value.get("output")?.get("results")?.as_array()?;
    for item in results {
        if let Some(url) = item.get("url").and_then(Value::as_str) {
            return Some(url.to_string());
        }
    }
    None
}
/// Extracts an image URL from a response, preferring the `choices` layout and
/// falling back to the `results` layout.
fn parse_image_url(value: &Value) -> Option<String> {
    match image_url_from_choices(value) {
        Some(url) => Some(url),
        None => image_url_from_results(value),
    }
}
/// Builds the single-element `messages` array for a request: one user message
/// whose content lists every input image (as a data URL) followed by the text
/// prompt. Fails on the first image that cannot be read.
fn build_messages(prompt: &str, image_paths: &[String]) -> Result<Vec<Value>, AppError> {
    let mut content = Vec::with_capacity(image_paths.len() + 1);
    for path in image_paths {
        let image = path_to_data_url(path)?;
        content.push(json!({ "image": image }));
    }
    content.push(json!({ "text": prompt }));

    let message = json!({
        "role": "user",
        "content": content,
    });
    Ok(vec![message])
}
#[async_trait]
impl AiProvider for DashScopeProvider {
    /// Generates an image from a prompt only. Models accepted by
    /// `is_sync_multimodal_model` use the synchronous endpoint; all others go
    /// through the asynchronous task flow.
    async fn generate_image(&self, request: ImageGenerateRequest) -> Result<ImageResult, AppError> {
        let size = request.size.as_deref();
        if is_sync_multimodal_model(&request.model) {
            self.run_qwen_image(&request.prompt, &request.model, size, &[])
                .await
        } else {
            self.run_async_image_generation(&request.prompt, &request.model, size, &[])
                .await
        }
    }

    /// Same routing as `generate_image`, but forwards the request's input
    /// image paths along with the prompt.
    async fn edit_image(&self, request: ImageEditRequest) -> Result<ImageResult, AppError> {
        let size = request.size.as_deref();
        let images = &request.image_paths;
        if is_sync_multimodal_model(&request.model) {
            self.run_qwen_image(&request.prompt, &request.model, size, images)
                .await
        } else {
            self.run_async_image_generation(&request.prompt, &request.model, size, images)
                .await
        }
    }
}
impl DashScopeProvider {
async fn run_qwen_image(
&self,
prompt: &str,
model: &str,
size: Option<&str>,
image_paths: &[String],
) -> Result<ImageResult, AppError> {
let mut parameters = json!({
"n": 1,
"watermark": false,
"prompt_extend": true,
});
if is_qwen_image_model(model) {
parameters["negative_prompt"] = json!(" ");
}
if let Some(size) = dashscope_size(size) {
parameters["size"] = json!(size);
}
let body = json!({
"model": model,
"input": {
"messages": build_messages(prompt, image_paths)?,
},
"parameters": parameters,
});
let response = self
.client
.post(format!(
"{}/services/aigc/multimodal-generation/generation",
self.base_url
))
.bearer_auth(&self.api_key)
.json(&body)
.send()
.await?;
let status = response.status();
let response_text = response.text().await?;
if !status.is_success() {
return Err(AppError::Provider(format!(
"DashScope image generation failed ({status}): {}",
parse_dashscope_error(&response_text, &response_text)
)));
}
let response_body = serde_json::from_str::<Value>(&response_text).map_err(|error| {
AppError::Provider(format!(
"failed to decode DashScope response: {error}; response body: {response_text}"
))
})?;
let image_url = parse_image_url(&response_body).ok_or_else(|| {
AppError::Provider(format!(
"DashScope response did not include image url: {response_text}"
))
})?;
Ok(ImageResult {
mime_type: "image/png".to_string(),
data: ImageData::Url(image_url),
})
}
async fn run_async_image_generation(
&self,
prompt: &str,
model: &str,
size: Option<&str>,
image_paths: &[String],
) -> Result<ImageResult, AppError> {
let mut parameters = json!({
"n": 1,
"watermark": false,
});
if let Some(size) = dashscope_size(size) {
parameters["size"] = json!(size);
}
if model.starts_with("wan2.7-image") {
parameters["thinking_mode"] = json!(true);
} else {
parameters["prompt_extend"] = json!(true);
}
let body = json!({
"model": model,
"input": {
"messages": build_messages(prompt, image_paths)?,
},
"parameters": parameters,
});
let response = self
.client
.post(format!(
"{}/services/aigc/image-generation/generation",
self.base_url
))
.bearer_auth(&self.api_key)
.header("X-DashScope-Async", "enable")
.json(&body)
.send()
.await?;
let status = response.status();
let response_text = response.text().await?;
if !status.is_success() {
return Err(AppError::Provider(format!(
"DashScope async image task failed ({status}): {}",
parse_dashscope_error(&response_text, &response_text)
)));
}
let response_body = serde_json::from_str::<Value>(&response_text).map_err(|error| {
AppError::Provider(format!(
"failed to decode DashScope task response: {error}; response body: {response_text}"
))
})?;
let task_id = response_body
.get("output")
.and_then(|output| output.get("task_id"))
.and_then(Value::as_str)
.ok_or_else(|| {
AppError::Provider(format!(
"DashScope task response did not include task_id: {response_text}"
))
})?;
self.wait_async_task(task_id).await
}
async fn wait_async_task(&self, task_id: &str) -> Result<ImageResult, AppError> {
for _ in 0..40 {
sleep(Duration::from_secs(3)).await;
let response = self
.client
.get(format!("{}/tasks/{task_id}", self.base_url))
.bearer_auth(&self.api_key)
.send()
.await?;
let status = response.status();
let response_text = response.text().await?;
if !status.is_success() {
return Err(AppError::Provider(format!(
"DashScope task polling failed ({status}): {}",
parse_dashscope_error(&response_text, &response_text)
)));
}
let response_body = serde_json::from_str::<Value>(&response_text).map_err(|error| {
AppError::Provider(format!(
"failed to decode DashScope task result: {error}; response body: {response_text}"
))
})?;
let task_status = response_body
.get("output")
.and_then(|output| output.get("task_status"))
.and_then(Value::as_str)
.unwrap_or_default();
if task_status == "SUCCEEDED" {
let image_url = parse_image_url(&response_body).ok_or_else(|| {
AppError::Provider(format!(
"DashScope task succeeded but did not include image url: {response_text}"
))
})?;
return Ok(ImageResult {
mime_type: "image/png".to_string(),
data: ImageData::Url(image_url),
});
}
if matches!(task_status, "FAILED" | "CANCELED" | "UNKNOWN") {
return Err(AppError::Provider(format!(
"DashScope task ended with status {task_status}: {}",
parse_dashscope_error(&response_text, &response_text)
)));
}
}
Err(AppError::Provider(
"DashScope image task polling timed out".to_string(),
))
}
}

View File

@@ -0,0 +1,189 @@
use async_trait::async_trait;
use base64::{engine::general_purpose, Engine as _};
use reqwest::Client;
use serde_json::{json, Value};
use std::{fs, path::Path};
use super::provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest, ImageResult};
use crate::AppError;
/// Image provider that calls a Gemini-style `generateContent` HTTP API.
pub struct GoogleGeminiProvider {
    /// Shared HTTP client used for every request.
    client: Client,
    /// API root without a trailing slash.
    base_url: String,
    /// Credential sent in the `x-goog-api-key` header.
    api_key: String,
}

impl GoogleGeminiProvider {
    /// Builds a provider for `base_url`, dropping any trailing `/` so that
    /// `/models/...` can be appended directly.
    pub fn new(base_url: String, api_key: String) -> Self {
        let normalized_url = base_url.trim_end_matches('/').to_string();
        let client = Client::new();
        Self {
            client,
            base_url: normalized_url,
            api_key,
        }
    }
}
/// Maps the file extension of `path` to an image MIME type, ignoring ASCII
/// case. Unknown, missing or non-UTF-8 extensions default to `image/png`.
fn mime_for_path(path: &str) -> &'static str {
    let lowered = Path::new(path)
        .extension()
        .and_then(|raw| raw.to_str())
        .map(str::to_ascii_lowercase);
    match lowered.as_deref() {
        Some("jpeg") | Some("jpg") => "image/jpeg",
        Some("webp") => "image/webp",
        Some(_) | None => "image/png",
    }
}
/// Converts a `WIDTHxHEIGHT` size string into the closest supported aspect
/// ratio label (`1:1`, `16:9`, `9:16`, `4:3` or `3:4`).
///
/// A candidate is eligible only when the requested ratio lies within its
/// tolerance (0.08 for `1:1`, 0.12 for the others). Because the `9:16` and
/// `3:4` windows overlap, the *nearest* eligible candidate is chosen rather
/// than the first one checked; e.g. `1000x1500` resolves to `3:4`, not `9:16`.
///
/// Returns `None` when `size` is absent, is not two numbers separated by a
/// lowercase `x`, has a non-positive dimension, or matches no candidate.
fn size_to_aspect_ratio(size: Option<&str>) -> Option<&'static str> {
    // (label, numeric ratio, tolerance)
    const CANDIDATES: [(&str, f32, f32); 5] = [
        ("1:1", 1.0, 0.08),
        ("16:9", 16.0 / 9.0, 0.12),
        ("9:16", 9.0 / 16.0, 0.12),
        ("4:3", 4.0 / 3.0, 0.12),
        ("3:4", 3.0 / 4.0, 0.12),
    ];

    let (width, height) = size?.split_once('x')?;
    let width = width.parse::<f32>().ok()?;
    let height = height.parse::<f32>().ok()?;
    if width <= 0.0 || height <= 0.0 {
        return None;
    }

    let ratio = width / height;
    CANDIDATES
        .iter()
        .map(|&(label, target, tolerance)| (label, (ratio - target).abs(), tolerance))
        // A NaN distance fails this comparison, so non-finite input yields `None`.
        .filter(|&(_, distance, tolerance)| distance < tolerance)
        // `min_by` keeps the first of equally-near candidates, preserving list priority.
        .min_by(|left, right| left.1.total_cmp(&right.1))
        .map(|(label, _, _)| label)
}
fn image_part(path: &str) -> Result<Value, AppError> {
let bytes = fs::read(path)?;
Ok(json!({
"inline_data": {
"mime_type": mime_for_path(path),
"data": general_purpose::STANDARD.encode(bytes),
}
}))
}
/// Extracts `(mime_type, base64_data)` from a response part.
///
/// Both camelCase (`inlineData` / `mimeType`) and snake_case (`inline_data` /
/// `mime_type`) keys are accepted, camelCase first. A part without string
/// `data` yields `None`; a missing or non-string MIME type becomes `image/png`.
fn inline_data_from_part(part: &Value) -> Option<(&str, &str)> {
    let inline_data = match part.get("inlineData") {
        Some(found) => found,
        None => part.get("inline_data")?,
    };
    let data = inline_data.get("data")?.as_str()?;

    let mime_value = inline_data
        .get("mimeType")
        .or_else(|| inline_data.get("mime_type"));
    let mime_type = match mime_value.and_then(Value::as_str) {
        Some(mime) => mime,
        None => "image/png",
    };
    Some((mime_type, data))
}
fn parse_gemini_response(response_text: &str) -> Result<ImageResult, AppError> {
let response_body = serde_json::from_str::<Value>(response_text).map_err(|error| {
AppError::Provider(format!(
"failed to decode Gemini image response: {error}; response body: {response_text}"
))
})?;
let parts = response_body
.get("candidates")
.and_then(Value::as_array)
.into_iter()
.flatten()
.flat_map(|candidate| {
candidate
.get("content")
.and_then(|content| content.get("parts"))
.and_then(Value::as_array)
.into_iter()
.flatten()
});
for part in parts {
if let Some((mime_type, data)) = inline_data_from_part(part) {
return Ok(ImageResult {
mime_type: mime_type.to_string(),
data: ImageData::Base64(data.to_string()),
});
}
}
Err(AppError::Provider(format!(
"Gemini response did not include image data: {response_text}"
)))
}
#[async_trait]
impl AiProvider for GoogleGeminiProvider {
    /// Generates an image from the prompt alone (no input images).
    async fn generate_image(&self, request: ImageGenerateRequest) -> Result<ImageResult, AppError> {
        let size = request.size.as_deref();
        self.generate_with_parts(&request.prompt, &request.model, size, &[])
            .await
    }

    /// Generates an image from the prompt plus the request's input images.
    async fn edit_image(&self, request: ImageEditRequest) -> Result<ImageResult, AppError> {
        let size = request.size.as_deref();
        let images = &request.image_paths;
        self.generate_with_parts(&request.prompt, &request.model, size, images)
            .await
    }
}
impl GoogleGeminiProvider {
    /// Sends one `generateContent` request containing the text prompt followed
    /// by every input image, and returns the first image in the response.
    ///
    /// An `imageConfig.aspectRatio` hint is included only when
    /// `size_to_aspect_ratio` recognises the requested size.
    ///
    /// # Errors
    /// Fails when an input image cannot be read, the HTTP call fails, the
    /// status is not a success, or the response holds no image data.
    async fn generate_with_parts(
        &self,
        prompt: &str,
        model: &str,
        size: Option<&str>,
        image_paths: &[String],
    ) -> Result<ImageResult, AppError> {
        let mut parts = Vec::with_capacity(image_paths.len() + 1);
        parts.push(json!({ "text": prompt }));
        for path in image_paths {
            parts.push(image_part(path)?);
        }

        let mut generation_config = json!({
            "responseModalities": ["TEXT", "IMAGE"],
        });
        if let Some(aspect_ratio) = size_to_aspect_ratio(size) {
            generation_config["imageConfig"] = json!({
                "aspectRatio": aspect_ratio,
            });
        }

        let user_turn = json!({
            "role": "user",
            "parts": parts,
        });
        let body = json!({
            "contents": [user_turn],
            "generationConfig": generation_config,
        });

        let endpoint = format!("{}/models/{model}:generateContent", self.base_url);
        let response = self
            .client
            .post(endpoint)
            .header("x-goog-api-key", &self.api_key)
            .json(&body)
            .send()
            .await?;

        let status = response.status();
        let response_text = response.text().await?;
        if status.is_success() {
            parse_gemini_response(&response_text)
        } else {
            Err(AppError::Provider(format!(
                "Gemini image generation failed ({status}): {response_text}"
            )))
        }
    }
}

View File

@@ -1,2 +1,6 @@
pub mod dashscope;
pub mod google_gemini;
pub mod openai_compatible;
pub mod provider;
pub mod seedream;
pub mod tencent_hunyuan;

View File

@@ -32,7 +32,8 @@ fn parse_image_response(response_text: &str) -> Result<ImageResult, AppError> {
));
}
let response_body: ImageResponseBody = serde_json::from_str(response_text).map_err(|error| {
let response_body: ImageResponseBody =
serde_json::from_str(response_text).map_err(|error| {
AppError::Provider(format!(
"failed to decode image response: {error}; response body: {response_text}"
))
@@ -45,7 +46,9 @@ fn parse_image_response(response_text: &str) -> Result<ImageResult, AppError> {
.map(ImageData::Base64)
.or_else(|| item.url.map(ImageData::Url))
})
.ok_or_else(|| AppError::Provider("image response did not include b64_json or url".to_string()))?;
.ok_or_else(|| {
AppError::Provider("image response did not include b64_json or url".to_string())
})?;
Ok(ImageResult {
mime_type: "image/png".to_string(),
@@ -96,8 +99,13 @@ impl AiProvider for OpenAiCompatibleProvider {
let status = response.status();
if !status.is_success() {
let message = response.text().await.unwrap_or_else(|_| "request failed".to_string());
return Err(AppError::Provider(format!("image generation failed ({status}): {message}")));
let message = response
.text()
.await
.unwrap_or_else(|_| "request failed".to_string());
return Err(AppError::Provider(format!(
"image generation failed ({status}): {message}"
)));
}
let response_text = response.text().await?;
@@ -134,7 +142,9 @@ impl AiProvider for OpenAiCompatibleProvider {
Some("webp") => "image/webp",
_ => "image/png",
};
let part = multipart::Part::bytes(bytes).file_name(file_name).mime_str(mime)?;
let part = multipart::Part::bytes(bytes)
.file_name(file_name)
.mime_str(mime)?;
form = form.part("image[]", part);
}
@@ -148,8 +158,13 @@ impl AiProvider for OpenAiCompatibleProvider {
let status = response.status();
if !status.is_success() {
let message = response.text().await.unwrap_or_else(|_| "request failed".to_string());
return Err(AppError::Provider(format!("image edit failed ({status}): {message}")));
let message = response
.text()
.await
.unwrap_or_else(|_| "request failed".to_string());
return Err(AppError::Provider(format!(
"image edit failed ({status}): {message}"
)));
}
let response_text = response.text().await?;

View File

@@ -0,0 +1,180 @@
use async_trait::async_trait;
use base64::{engine::general_purpose, Engine as _};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::{fs, path::Path};
use super::provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest, ImageResult};
use crate::AppError;
/// Image provider for the Seedream `images/generations` HTTP API.
pub struct SeedreamProvider {
    /// Shared HTTP client used for every request.
    client: Client,
    /// API root without a trailing slash.
    base_url: String,
    /// Credential sent as a bearer token.
    api_key: String,
}

impl SeedreamProvider {
    /// Builds a provider for `base_url`, dropping any trailing `/` so endpoint
    /// paths can be appended directly.
    pub fn new(base_url: String, api_key: String) -> Self {
        let normalized_url = base_url.trim_end_matches('/').to_string();
        let client = Client::new();
        Self {
            client,
            base_url: normalized_url,
            api_key,
        }
    }
}
/// JSON request body posted to the `images/generations` endpoint.
/// Borrowed string fields point into the caller's request for the duration of the call.
#[derive(Debug, Serialize)]
struct SeedreamRequestBody<'a> {
// Model identifier, forwarded verbatim from the request.
model: &'a str,
// Text prompt, forwarded verbatim from the request.
prompt: &'a str,
// Requested size; left out of the JSON entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
size: Option<&'a str>,
// Callers in this file always send "b64_json".
response_format: &'a str,
// Callers in this file always send `false`.
stream: bool,
// Callers in this file always send `false`.
watermark: bool,
// Input images as data URLs (edit requests only); left out of the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
image: Option<Vec<String>>,
}
/// Top-level shape of a successful response: a required `data` array of image items.
#[derive(Debug, Deserialize)]
struct SeedreamResponseBody {
// Deserialization fails if this field is missing from the JSON.
data: Vec<SeedreamResponseItem>,
}
/// One generated image in the response; either field may be absent.
#[derive(Debug, Deserialize)]
struct SeedreamResponseItem {
// Base64-encoded image payload, preferred over `url` by the response parser.
b64_json: Option<String>,
// Image URL, used when `b64_json` is absent.
url: Option<String>,
}
/// Picks an image MIME type from the extension of `path` (ASCII
/// case-insensitive). Everything that is not jpg/jpeg/webp, including paths
/// without a usable extension, is treated as `image/png`.
fn mime_for_path(path: &str) -> &'static str {
    let extension = Path::new(path).extension().and_then(|raw| raw.to_str());
    if let Some(extension) = extension {
        if extension.eq_ignore_ascii_case("jpg") || extension.eq_ignore_ascii_case("jpeg") {
            return "image/jpeg";
        }
        if extension.eq_ignore_ascii_case("webp") {
            return "image/webp";
        }
    }
    "image/png"
}
/// Loads the file at `path` and encodes it as a
/// `data:<mime>;base64,<payload>` URL. Read failures surface as `AppError`.
fn path_to_data_url(path: &str) -> Result<String, AppError> {
    let contents = fs::read(path)?;
    let mime = mime_for_path(path);
    let payload = general_purpose::STANDARD.encode(contents);
    Ok(format!("data:{};base64,{}", mime, payload))
}
/// Parses a Seedream response body into an [`ImageResult`].
///
/// The first item in `data` that carries `b64_json` (preferred) or `url` is
/// returned; the MIME type is always reported as `image/png`.
///
/// # Errors
/// Returns `AppError::Provider` when the body looks like an HTML page (a sign
/// that the base URL points at a website rather than the API), when it is not
/// valid JSON of the expected shape, or when no item contains image data.
fn parse_seedream_response(response_text: &str) -> Result<ImageResult, AppError> {
    // HTML tag and doctype names are case-insensitive, and `<!DOCTYPE html>` is
    // the usual spelling, so the prefix check must ignore ASCII case.
    let head = response_text.trim_start();
    let looks_like_html = ["<!doctype html", "<html"].iter().any(|marker| {
        head.get(..marker.len())
            .map_or(false, |prefix| prefix.eq_ignore_ascii_case(marker))
    });
    if looks_like_html {
        return Err(AppError::Provider(
            "火山方舟返回了 HTML 页面,不是 API JSON 响应。请检查 Base URL 是否为 API 地址。"
                .to_string(),
        ));
    }

    let response_body: SeedreamResponseBody =
        serde_json::from_str(response_text).map_err(|error| {
            AppError::Provider(format!(
                "failed to decode Seedream response: {error}; response body: {response_text}"
            ))
        })?;

    let image_data = response_body
        .data
        .into_iter()
        .find_map(|item| {
            item.b64_json
                .map(ImageData::Base64)
                .or_else(|| item.url.map(ImageData::Url))
        })
        .ok_or_else(|| {
            AppError::Provider("Seedream response did not include b64_json or url".to_string())
        })?;

    Ok(ImageResult {
        mime_type: "image/png".to_string(),
        data: image_data,
    })
}
#[async_trait]
impl AiProvider for SeedreamProvider {
    /// Requests a new image from the prompt alone and returns the parsed result.
    async fn generate_image(&self, request: ImageGenerateRequest) -> Result<ImageResult, AppError> {
        let payload = SeedreamRequestBody {
            model: &request.model,
            prompt: &request.prompt,
            size: request.size.as_deref(),
            response_format: "b64_json",
            stream: false,
            watermark: false,
            image: None,
        };

        let endpoint = format!("{}/images/generations", self.base_url);
        let response = self
            .client
            .post(endpoint)
            .bearer_auth(&self.api_key)
            .json(&payload)
            .send()
            .await?;

        let status = response.status();
        if status.is_success() {
            let response_text = response.text().await?;
            return parse_seedream_response(&response_text);
        }

        // Best effort: if the error body cannot be read, report a generic message.
        let message = match response.text().await {
            Ok(text) => text,
            Err(_) => "request failed".to_string(),
        };
        Err(AppError::Provider(format!(
            "Seedream image generation failed ({status}): {message}"
        )))
    }

    /// Requests an image from the prompt plus the input images, which are
    /// inlined as data URLs. Uses the same `images/generations` endpoint as
    /// `generate_image`.
    async fn edit_image(&self, request: ImageEditRequest) -> Result<ImageResult, AppError> {
        let mut images = Vec::with_capacity(request.image_paths.len());
        for path in &request.image_paths {
            images.push(path_to_data_url(path)?);
        }

        let payload = SeedreamRequestBody {
            model: &request.model,
            prompt: &request.prompt,
            size: request.size.as_deref(),
            response_format: "b64_json",
            stream: false,
            watermark: false,
            image: Some(images),
        };

        let endpoint = format!("{}/images/generations", self.base_url);
        let response = self
            .client
            .post(endpoint)
            .bearer_auth(&self.api_key)
            .json(&payload)
            .send()
            .await?;

        let status = response.status();
        if status.is_success() {
            let response_text = response.text().await?;
            return parse_seedream_response(&response_text);
        }

        // Best effort: if the error body cannot be read, report a generic message.
        let message = match response.text().await {
            Ok(text) => text,
            Err(_) => "request failed".to_string(),
        };
        Err(AppError::Provider(format!(
            "Seedream image edit failed ({status}): {message}"
        )))
    }
}

View File

@@ -0,0 +1,463 @@
use async_trait::async_trait;
use base64::{engine::general_purpose, Engine as _};
use chrono::{TimeZone, Utc};
use hmac::{Hmac, Mac};
use reqwest::{Client, Url};
use serde_json::{json, Value};
use sha2::{Digest, Sha256};
use std::{fs, time::Duration};
use tokio::time::sleep;
use super::provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest, ImageResult};
use crate::AppError;
/// HMAC keyed with SHA-256; used by `hmac_sha256` for request signing.
type HmacSha256 = Hmac<Sha256>;
/// Image provider for the Tencent Cloud image-generation APIs.
/// The credential pair is parsed from a single `SecretId:SecretKey` string in `new`.
pub struct TencentHunyuanProvider {
// Shared HTTP client used for every request.
client: Client,
// Endpoint URL with trailing slashes removed; its host also selects the API flavour.
base_url: String,
// Public half of the credential, placed in the Authorization header.
secret_id: String,
// Private half of the credential, used only to derive the request signature.
secret_key: String,
}
/// Per-host API constants: the service name and version used for signing and
/// headers, plus the action names for each kind of call. Built by `api_config`.
struct TencentApiConfig {
// Service name used in the credential scope and signing-key derivation.
service: &'static str,
// Value sent in the `X-TC-Version` header.
version: &'static str,
// Action name used by `run_text_to_image_lite`.
lite_action: &'static str,
// Action name used by `run_text_to_image_rapid`.
rapid_action: &'static str,
// NOTE(review): presumably the action that submits an async image job — its use is outside the visible source.
submit_action: &'static str,
// NOTE(review): presumably the action that queries an async image job — its use is outside the visible source.
query_action: &'static str,
}
impl TencentHunyuanProvider {
    /// Builds a provider from an endpoint URL and a combined
    /// `SecretId:SecretKey` credential string.
    ///
    /// # Errors
    /// Returns the error from `parse_secret_pair` when `api_key` is not in the
    /// expected format.
    pub fn new(base_url: String, api_key: String) -> Result<Self, AppError> {
        parse_secret_pair(&api_key).map(|(secret_id, secret_key)| Self {
            client: Client::new(),
            base_url: base_url.trim_end_matches('/').to_string(),
            secret_id,
            secret_key,
        })
    }
}
/// Splits a `SecretId:SecretKey` string at the first `:` and trims both
/// halves.
///
/// # Errors
/// Returns `AppError::Provider` when there is no `:` or when either half is
/// empty after trimming.
fn parse_secret_pair(api_key: &str) -> Result<(String, String), AppError> {
    let trimmed_halves = api_key
        .split_once(':')
        .map(|(id, key)| (id.trim(), key.trim()));
    match trimmed_halves {
        Some((id, key)) if !id.is_empty() && !key.is_empty() => {
            Ok((id.to_string(), key.to_string()))
        }
        _ => Err(AppError::Provider(
            "腾讯云 API Key 需要填写为 SecretId:SecretKey".to_string(),
        )),
    }
}
fn hmac_sha256(key: &[u8], message: &str) -> Result<Vec<u8>, AppError> {
let mut mac = HmacSha256::new_from_slice(key)
.map_err(|error| AppError::Provider(format!("failed to create HMAC: {error}")))?;
mac.update(message.as_bytes());
Ok(mac.finalize().into_bytes().to_vec())
}
/// Returns the SHA-256 digest of `value`'s bytes as a hex string.
fn sha256_hex(value: &str) -> String {
    let digest = Sha256::digest(value.as_bytes());
    hex::encode(digest)
}
/// Parses `base_url` and returns its host component as a string.
///
/// # Errors
/// Returns `AppError::Provider` when the URL cannot be parsed or has no host.
fn host_from_base_url(base_url: &str) -> Result<String, AppError> {
    let parsed = match Url::parse(base_url) {
        Ok(parsed) => parsed,
        Err(error) => {
            return Err(AppError::Provider(format!(
                "腾讯云 Base URL 不是有效 URL: {error}"
            )))
        }
    };
    match parsed.host_str() {
        Some(host) => Ok(host.to_string()),
        None => Err(AppError::Provider("腾讯云 Base URL 缺少 host".to_string())),
    }
}
/// Selects the API constants for `host`: the `aiart` configuration for
/// `aiart.tencentcloudapi.com`, and the `hunyuan` configuration for every
/// other host.
fn api_config(host: &str) -> TencentApiConfig {
    match host {
        "aiart.tencentcloudapi.com" => TencentApiConfig {
            service: "aiart",
            version: "2022-12-29",
            lite_action: "TextToImageLite",
            rapid_action: "TextToImageRapid",
            submit_action: "SubmitTextToImageJob",
            query_action: "QueryTextToImageJob",
        },
        _ => TencentApiConfig {
            service: "hunyuan",
            version: "2023-09-01",
            lite_action: "TextToImageLite",
            // This configuration reuses the lite action for "rapid" calls.
            rapid_action: "TextToImageLite",
            submit_action: "SubmitHunyuanImageJob",
            query_action: "QueryHunyuanImageJob",
        },
    }
}
/// Maps a `WIDTHxHEIGHT` size string onto one of the fixed `W:H` resolution
/// strings, chosen by aspect ratio.
///
/// Ratios near 1, 0.75 and 1.333 (within 0.12) map to `1024:1024`,
/// `768:1024` and `1024:768`. Taller (`< 0.65`) and wider (`> 1.55`) ratios
/// map to the 1080p variants when `lite` is set and no reference image is
/// used, otherwise to the 720p variants. Anything left falls back to
/// `768:1024` or `1024:768` depending on orientation.
///
/// Returns `None` when `size` is absent or malformed, a dimension is not
/// positive, or `has_reference` is set and the chosen resolution is not one of
/// the three base resolutions.
fn hunyuan_resolution(size: Option<&str>, lite: bool, has_reference: bool) -> Option<String> {
    let (width, height) = size?.split_once('x')?;
    let width: f32 = width.parse().ok()?;
    let height: f32 = height.parse().ok()?;
    if width <= 0.0 || height <= 0.0 {
        return None;
    }

    let full_hd = lite && !has_reference;
    let value = match width / height {
        ratio if (ratio - 1.0).abs() < 0.12 => "1024:1024",
        ratio if (ratio - 0.75).abs() < 0.12 => "768:1024",
        ratio if (ratio - 1.333).abs() < 0.12 => "1024:768",
        ratio if ratio < 0.65 => {
            if full_hd {
                "1080:1920"
            } else {
                "720:1280"
            }
        }
        ratio if ratio > 1.55 => {
            if full_hd {
                "1920:1080"
            } else {
                "1280:720"
            }
        }
        _ if width < height => "768:1024",
        _ => "1024:768",
    };

    let reference_safe = matches!(value, "1024:1024" | "768:1024" | "1024:768");
    if has_reference && !reference_safe {
        None
    } else {
        Some(value.to_string())
    }
}
/// Reads the first path in `image_paths` and returns its contents
/// base64-encoded; any further paths are ignored. Returns `Ok(None)` for an
/// empty list and an `AppError` when the file cannot be read.
fn first_image_base64(image_paths: &[String]) -> Result<Option<String>, AppError> {
    match image_paths.first() {
        None => Ok(None),
        Some(path) => {
            let bytes = fs::read(path)?;
            Ok(Some(general_purpose::STANDARD.encode(bytes)))
        }
    }
}
/// Formats the `Error` object of a response as `Code: Message`.
/// Returns `None` when there is no `Error` key; a missing or non-string
/// `Code` / `Message` is rendered as an empty string.
fn parse_tencent_error(response: &Value) -> Option<String> {
    /// Reads `key` from `scope` as a string, treating missing/non-string as empty.
    fn text_field<'a>(scope: &'a Value, key: &str) -> &'a str {
        scope.get(key).and_then(Value::as_str).unwrap_or_default()
    }

    let error = response.get("Error")?;
    let code = text_field(error, "Code");
    let message = text_field(error, "Message");
    Some(format!("{code}: {message}"))
}
/// Extracts an image URL from `ResultImage`.
///
/// Accepts a string starting with `http`, or an array whose first element is
/// either such a string or an object with a string `Url` / `URL` field. Only
/// the first array element is inspected.
fn result_image_url(response: &Value) -> Option<String> {
    fn is_http(candidate: &&str) -> bool {
        candidate.starts_with("http")
    }

    let result = response.get("ResultImage")?;
    if let Some(url) = result.as_str().filter(is_http) {
        return Some(url.to_string());
    }

    let first = result.as_array()?.first()?;
    match first.as_str().filter(is_http) {
        Some(url) => Some(url.to_string()),
        None => first
            .get("Url")
            .or_else(|| first.get("URL"))
            .and_then(Value::as_str)
            .map(str::to_string),
    }
}
/// Extracts base64 image data from `ResultImage`.
///
/// Accepts a string that does not start with `http`, or an array whose first
/// element is either such a string or an object with a string `Base64` /
/// `ImageBase64` field. Only the first array element is inspected.
fn result_image_base64(response: &Value) -> Option<String> {
    fn is_not_http(candidate: &&str) -> bool {
        !candidate.starts_with("http")
    }

    let result = response.get("ResultImage")?;
    if let Some(encoded) = result.as_str().filter(is_not_http) {
        return Some(encoded.to_string());
    }

    let first = result.as_array()?.first()?;
    match first.as_str().filter(is_not_http) {
        Some(encoded) => Some(encoded.to_string()),
        None => first
            .get("Base64")
            .or_else(|| first.get("ImageBase64"))
            .and_then(Value::as_str)
            .map(str::to_string),
    }
}
#[async_trait]
impl AiProvider for TencentHunyuanProvider {
    /// Generates an image from a prompt, dispatching on the model name:
    /// `hunyuan-image-lite` uses the lite call, `hunyuan-image-2.0` the rapid
    /// call, and every other model the asynchronous job flow.
    async fn generate_image(&self, request: ImageGenerateRequest) -> Result<ImageResult, AppError> {
        let size = request.size.as_deref();
        match request.model.as_str() {
            "hunyuan-image-lite" => self.run_text_to_image_lite(&request.prompt, size).await,
            "hunyuan-image-2.0" => {
                self.run_text_to_image_rapid(&request.prompt, size, &[])
                    .await
            }
            _ => self.run_async_image_job(&request.prompt, size, &[]).await,
        }
    }

    /// Generates an image from a prompt plus input images:
    /// `hunyuan-image-2.0` uses the rapid call, every other model the
    /// asynchronous job flow.
    async fn edit_image(&self, request: ImageEditRequest) -> Result<ImageResult, AppError> {
        let size = request.size.as_deref();
        let images = &request.image_paths;
        match request.model.as_str() {
            "hunyuan-image-2.0" => {
                self.run_text_to_image_rapid(&request.prompt, size, images)
                    .await
            }
            _ => {
                self.run_async_image_job(&request.prompt, size, images)
                    .await
            }
        }
    }
}
impl TencentHunyuanProvider {
/// Sends one signed POST to `self.base_url` for `action` with `payload` as the JSON body,
/// and returns a clone of the `Response` object from the reply.
///
/// The `Authorization` header uses the `TC3-HMAC-SHA256` scheme: a canonical request
/// (method, path `/`, empty query, `content-type` and `host` headers, payload hash) is
/// hashed into a string-to-sign, which is signed with a key derived from the secret key,
/// the UTC date, the service name and the literal `tc3_request`.
///
/// Errors: returns `AppError::Provider` when the base URL has no host, the HTTP status is
/// not a success, the body is not JSON, `Response` is missing, or `Response.Error` is set.
async fn call(&self, action: &str, payload: Value) -> Result<Value, AppError> {
// The region is fixed for every request.
let region = "ap-guangzhou";
let host = host_from_base_url(&self.base_url)?;
let config = api_config(&host);
// The same timestamp feeds both the signed date and the X-TC-Timestamp header.
let timestamp = Utc::now().timestamp();
let date = Utc
.timestamp_opt(timestamp, 0)
.single()
.ok_or_else(|| AppError::Provider("invalid Tencent Cloud timestamp".to_string()))?
.format("%Y-%m-%d")
.to_string();
// Serialized once: these exact bytes are both hashed for the signature and sent as the body.
let payload_text = serde_json::to_string(&payload).map_err(|error| {
AppError::Provider(format!("failed to encode Tencent Cloud request: {error}"))
})?;
let content_type = "application/json; charset=utf-8";
let signed_headers = "content-type;host";
let canonical_headers = format!("content-type:{content_type}\nhost:{host}\n");
// Layout: method, path, (empty) query string, headers, signed header names, payload hash.
let canonical_request = format!(
"POST\n/\n\n{canonical_headers}\n{signed_headers}\n{}",
sha256_hex(&payload_text)
);
let credential_scope = format!("{}/{}/tc3_request", date, config.service);
let string_to_sign = format!(
"TC3-HMAC-SHA256\n{timestamp}\n{credential_scope}\n{}",
sha256_hex(&canonical_request)
);
// Signing key chain: "TC3"+secret -> date -> service -> "tc3_request".
let secret_date = hmac_sha256(format!("TC3{}", self.secret_key).as_bytes(), &date)?;
let secret_service = hmac_sha256(&secret_date, config.service)?;
let secret_signing = hmac_sha256(&secret_service, "tc3_request")?;
let signature = hex::encode(hmac_sha256(&secret_signing, &string_to_sign)?);
let authorization = format!(
"TC3-HMAC-SHA256 Credential={}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}",
self.secret_id
);
// Content-Type and Host are sent with the same values that were signed above.
let response = self
.client
.post(&self.base_url)
.header("Authorization", authorization)
.header("Content-Type", content_type)
.header("Host", host)
.header("X-TC-Action", action)
.header("X-TC-Timestamp", timestamp.to_string())
.header("X-TC-Version", config.version)
.header("X-TC-Region", region)
.body(payload_text)
.send()
.await?;
// Read the body before checking the status so failures can quote it.
let status = response.status();
let response_text = response.text().await?;
if !status.is_success() {
return Err(AppError::Provider(format!(
"Tencent Hunyuan request failed ({status}): {response_text}"
)));
}
let response_body = serde_json::from_str::<Value>(&response_text).map_err(|error| {
AppError::Provider(format!(
"failed to decode Tencent Hunyuan response: {error}; response body: {response_text}"
))
})?;
let inner = response_body.get("Response").ok_or_else(|| {
AppError::Provider(format!(
"Tencent Hunyuan response missing Response: {response_text}"
))
})?;
// A successful HTTP status can still carry an API-level error inside `Response`.
if let Some(message) = parse_tencent_error(inner) {
return Err(AppError::Provider(format!(
"Tencent Hunyuan {action} failed: {message}"
)));
}
Ok(inner.clone())
}
/// Generates one image from `prompt` through the provider's "lite" action.
///
/// `size` is mapped by `hunyuan_resolution`; when that yields a value it is sent as
/// `Resolution`, otherwise the field is omitted. The result is returned as a URL.
///
/// # Errors
/// Propagates failures from `host_from_base_url` and `call`, and returns
/// `AppError::Provider` when the response carries no image URL.
async fn run_text_to_image_lite(
    &self,
    prompt: &str,
    size: Option<&str>,
) -> Result<ImageResult, AppError> {
    let mut body = json!({
        "Prompt": prompt,
        "RspImgType": "url",
        "LogoAdd": 0,
    });
    let resolution = hunyuan_resolution(size, true, false);
    if let Some(value) = resolution {
        body["Resolution"] = json!(value);
    }

    let host = host_from_base_url(&self.base_url)?;
    let lite_action = api_config(&host).lite_action;
    let response = self.call(lite_action, body).await?;

    match result_image_url(&response) {
        Some(url) => Ok(ImageResult {
            mime_type: "image/png".to_string(),
            data: ImageData::Url(url),
        }),
        None => Err(AppError::Provider(format!(
            "Tencent Hunyuan TextToImageLite did not include image url: {response}"
        ))),
    }
}
/// Generates one image through the provider's "rapid" action, optionally guided by
/// a reference image.
///
/// When `image_paths` is non-empty, `first_image_base64` supplies the reference
/// image, which is sent as `Image.Base64`. `size` is mapped by `hunyuan_resolution`
/// and only sent when that yields a value. The result is returned as a URL.
///
/// # Errors
/// Propagates failures from `first_image_base64`, `host_from_base_url` and `call`,
/// and returns `AppError::Provider` when the response carries no image URL.
async fn run_text_to_image_rapid(
    &self,
    prompt: &str,
    size: Option<&str>,
    image_paths: &[String],
) -> Result<ImageResult, AppError> {
    let with_reference = !image_paths.is_empty();
    let mut body = json!({
        "Prompt": prompt,
        "RspImgType": "url",
        "LogoAdd": 0,
    });
    let resolution = hunyuan_resolution(size, false, with_reference);
    if let Some(value) = resolution {
        body["Resolution"] = json!(value);
    }
    let reference = first_image_base64(image_paths)?;
    if let Some(encoded) = reference {
        body["Image"] = json!({
            "Base64": encoded,
        });
    }

    let host = host_from_base_url(&self.base_url)?;
    let rapid_action = api_config(&host).rapid_action;
    let response = self.call(rapid_action, body).await?;

    match result_image_url(&response) {
        Some(url) => Ok(ImageResult {
            mime_type: "image/png".to_string(),
            data: ImageData::Url(url),
        }),
        None => Err(AppError::Provider(format!(
            "Tencent Hunyuan TextToImageRapid did not include image url: {response}"
        ))),
    }
}
/// Submits an asynchronous image job and waits for its result.
///
/// The reference image (if any) is attached under a field that depends on the
/// configured service: `Images` (a one-element array) for `aiart`, otherwise
/// `ContentImage.ImageBase64`. After submission the returned `JobId` is handed to
/// `wait_image_job`.
///
/// # Errors
/// Propagates failures from `host_from_base_url`, `first_image_base64`, `call` and
/// `wait_image_job`, and returns `AppError::Provider` when the submit response has
/// no string `JobId`.
async fn run_async_image_job(
    &self,
    prompt: &str,
    size: Option<&str>,
    image_paths: &[String],
) -> Result<ImageResult, AppError> {
    let with_reference = !image_paths.is_empty();
    let mut body = json!({
        "Prompt": prompt,
        "Num": 1,
        "Revise": 1,
        "LogoAdd": 0,
    });
    let resolution = hunyuan_resolution(size, false, with_reference);
    if let Some(value) = resolution {
        body["Resolution"] = json!(value);
    }

    let host = host_from_base_url(&self.base_url)?;
    let config = api_config(&host);

    let reference = first_image_base64(image_paths)?;
    if let Some(encoded) = reference {
        // The two services expect the reference image under different keys.
        let (field, value) = if config.service == "aiart" {
            ("Images", json!([encoded]))
        } else {
            (
                "ContentImage",
                json!({
                    "ImageBase64": encoded,
                }),
            )
        };
        body[field] = value;
    }

    let response = self.call(config.submit_action, body).await?;
    let Some(job_id) = response.get("JobId").and_then(Value::as_str) else {
        return Err(AppError::Provider(format!(
            "Tencent Hunyuan submit response did not include JobId: {response}"
        )));
    };
    self.wait_image_job(job_id).await
}
/// Polls the job-query action until the image job finishes, fails, or times out.
///
/// The job is queried up to `MAX_POLLS` times, sleeping `POLL_INTERVAL` before each
/// query (so the first query also happens after one interval). A `JobStatusCode` of
/// `"5"` is treated as success and `"4"` as failure; any other value, including a
/// missing or non-string one, keeps polling.
///
/// On success an inline base64 image is preferred; otherwise the image URL is used.
///
/// # Errors
/// Propagates failures from `host_from_base_url` and `call`. Returns
/// `AppError::Provider` when the job reports failure, when it succeeds without any
/// image, or when all polls are used up.
async fn wait_image_job(&self, job_id: &str) -> Result<ImageResult, AppError> {
    const MAX_POLLS: u32 = 40;
    const POLL_INTERVAL: Duration = Duration::from_secs(3);

    // The host and query action depend only on `self.base_url`, so resolve them once
    // instead of on every poll.
    let host = host_from_base_url(&self.base_url)?;
    let action = api_config(&host).query_action;

    for _ in 0..MAX_POLLS {
        sleep(POLL_INTERVAL).await;
        let response = self.call(action, json!({ "JobId": job_id })).await?;
        let job_status = response
            .get("JobStatusCode")
            .and_then(Value::as_str)
            .unwrap_or_default();
        match job_status {
            "5" => {
                if let Some(image_base64) = result_image_base64(&response) {
                    return Ok(ImageResult {
                        mime_type: "image/png".to_string(),
                        data: ImageData::Base64(image_base64),
                    });
                }
                let image_url = result_image_url(&response).ok_or_else(|| {
                    AppError::Provider(format!(
                        "Tencent Hunyuan job succeeded but did not include image: {response}"
                    ))
                })?;
                return Ok(ImageResult {
                    mime_type: "image/png".to_string(),
                    data: ImageData::Url(image_url),
                });
            }
            "4" => {
                let message = response
                    .get("JobErrorMsg")
                    .and_then(Value::as_str)
                    .unwrap_or("unknown error");
                return Err(AppError::Provider(format!(
                    "Tencent Hunyuan image job failed: {message}"
                )));
            }
            // Still queued or running (or an unrecognized status): poll again.
            _ => {}
        }
    }
    Err(AppError::Provider(
        "Tencent Hunyuan image job polling timed out".to_string(),
    ))
}
}

View File

@@ -11,7 +11,11 @@ pub async fn pick_material_images(app: tauri::AppHandle) -> Result<Vec<String>,
let paths = files
.unwrap_or_default()
.into_iter()
.filter_map(|file_path| file_path.as_path().map(|path| path.to_string_lossy().to_string()))
.filter_map(|file_path| {
file_path
.as_path()
.map(|path| path.to_string_lossy().to_string())
})
.collect();
Ok(paths)

View File

@@ -2,16 +2,21 @@ use tauri::{AppHandle, State};
use crate::{
ai::{
dashscope::DashScopeProvider,
google_gemini::GoogleGeminiProvider,
openai_compatible::OpenAiCompatibleProvider,
provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest},
seedream::SeedreamProvider,
tencent_hunyuan::TencentHunyuanProvider,
},
db::{
models::{CreateGenerationTaskInput, GenerateImageInput, GenerateImageOutput, GenerationTask},
models::{
CreateGenerationTaskInput, GenerateImageInput, GenerateImageOutput, GenerationTask,
},
repository,
},
state::AppState,
storage,
AppError,
storage, AppError,
};
#[tauri::command]
@@ -30,21 +35,66 @@ pub async fn generate_image(
) -> Result<GenerateImageOutput, AppError> {
let provider = repository::get_provider_secret(&state.db, &input.provider_id).await?;
if !provider.enabled {
return Err(AppError::Provider(format!("provider {} is disabled", provider.name)));
return Err(AppError::Provider(format!(
"provider {} is disabled",
provider.name
)));
}
if provider.kind != "openai-compatible" {
if provider.kind != "openai"
&& provider.kind != "openai-compatible"
&& provider.kind != "volcengine-ark"
&& provider.kind != "dashscope"
&& provider.kind != "tencent-hunyuan"
&& provider.kind != "google-gemini"
{
return Err(AppError::Provider(format!(
"provider kind {} is not supported yet",
provider.kind
)));
}
if !provider.base_url.trim_end_matches('/').ends_with("/v1") {
if (provider.kind == "openai" || provider.kind == "openai-compatible")
&& !provider.base_url.trim_end_matches('/').ends_with("/v1")
{
return Err(AppError::Provider(
"Base URL 看起来不是 API 地址。OpenAI-compatible 地址通常需要以 /v1 结尾,例如 https://api.openai.com/v1 或 https://你的中转站域名/v1".to_string(),
));
}
if provider.kind == "volcengine-ark"
&& !provider.base_url.trim_end_matches('/').ends_with("/api/v3")
{
return Err(AppError::Provider(
"火山方舟 Seedream 的 Base URL 通常需要以 /api/v3 结尾,例如 https://ark.cn-beijing.volces.com/api/v3".to_string(),
));
}
if provider.kind == "dashscope" && !provider.base_url.trim_end_matches('/').ends_with("/api/v1")
{
return Err(AppError::Provider(
"阿里云百炼 DashScope 的 Base URL 通常需要以 /api/v1 结尾,例如 https://dashscope.aliyuncs.com/api/v1".to_string(),
));
}
if provider.kind == "google-gemini"
&& !provider.base_url.trim_end_matches('/').ends_with("/v1beta")
{
return Err(AppError::Provider(
"Google Gemini / Nano Banana 的 Base URL 通常填写 https://generativelanguage.googleapis.com/v1beta".to_string(),
));
}
if provider.kind == "tencent-hunyuan"
&& !provider
.base_url
.trim_end_matches('/')
.ends_with("aiart.tencentcloudapi.com")
&& !provider
.base_url
.trim_end_matches('/')
.ends_with("hunyuan.tencentcloudapi.com")
{
return Err(AppError::Provider(
"腾讯混元图像的 Base URL 通常填写 https://aiart.tencentcloudapi.com".to_string(),
));
}
let api_key = provider
.api_key_encrypted
@@ -76,8 +126,9 @@ pub async fn generate_image(
)
.await?;
let ai_provider = OpenAiCompatibleProvider::new(provider.base_url, api_key);
let image_result = match if input.image_paths.is_empty() {
let image_result = match if provider.kind == "volcengine-ark" {
let ai_provider = SeedreamProvider::new(provider.base_url, api_key);
if input.image_paths.is_empty() {
ai_provider
.generate_image(ImageGenerateRequest {
prompt: input.prompt,
@@ -96,10 +147,100 @@ pub async fn generate_image(
image_paths: input.image_paths,
})
.await
}
} else if provider.kind == "dashscope" {
let ai_provider = DashScopeProvider::new(provider.base_url, api_key);
if input.image_paths.is_empty() {
ai_provider
.generate_image(ImageGenerateRequest {
prompt: input.prompt,
model,
size: input.size,
quality: input.quality,
})
.await
} else {
ai_provider
.edit_image(ImageEditRequest {
prompt: input.prompt,
model,
size: input.size,
quality: input.quality,
image_paths: input.image_paths,
})
.await
}
} else if provider.kind == "tencent-hunyuan" {
let ai_provider = TencentHunyuanProvider::new(provider.base_url, api_key)?;
if input.image_paths.is_empty() {
ai_provider
.generate_image(ImageGenerateRequest {
prompt: input.prompt,
model,
size: input.size,
quality: input.quality,
})
.await
} else {
ai_provider
.edit_image(ImageEditRequest {
prompt: input.prompt,
model,
size: input.size,
quality: input.quality,
image_paths: input.image_paths,
})
.await
}
} else if provider.kind == "google-gemini" {
let ai_provider = GoogleGeminiProvider::new(provider.base_url, api_key);
if input.image_paths.is_empty() {
ai_provider
.generate_image(ImageGenerateRequest {
prompt: input.prompt,
model,
size: input.size,
quality: input.quality,
})
.await
} else {
ai_provider
.edit_image(ImageEditRequest {
prompt: input.prompt,
model,
size: input.size,
quality: input.quality,
image_paths: input.image_paths,
})
.await
}
} else {
let ai_provider = OpenAiCompatibleProvider::new(provider.base_url, api_key);
if input.image_paths.is_empty() {
ai_provider
.generate_image(ImageGenerateRequest {
prompt: input.prompt,
model,
size: input.size,
quality: input.quality,
})
.await
} else {
ai_provider
.edit_image(ImageEditRequest {
prompt: input.prompt,
model,
size: input.size,
quality: input.quality,
image_paths: input.image_paths,
})
.await
}
} {
Ok(result) => result,
Err(error) => {
repository::mark_generation_task_failed(&state.db, &task.id, &error.to_string()).await?;
repository::mark_generation_task_failed(&state.db, &task.id, &error.to_string())
.await?;
return Err(error);
}
};
@@ -108,7 +249,8 @@ pub async fn generate_image(
ImageData::Base64(data_base64) => storage::decode_base64_image(data_base64)?,
ImageData::Url(url) => reqwest::get(url).await?.bytes().await?.to_vec(),
};
let stored_image = storage::save_generated_image_bytes(&app, &image_bytes, &image_result.mime_type)?;
let stored_image =
storage::save_generated_image_bytes(&app, &image_bytes, &image_result.mime_type)?;
let file_path = stored_image.file_path.to_string_lossy().to_string();
let asset = repository::create_image_asset(
&state.db,

View File

@@ -1,3 +1,6 @@
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tauri::State;
use crate::{
@@ -7,12 +10,29 @@ use crate::{
};
#[tauri::command]
pub async fn list_providers(state: State<'_, AppState>) -> Result<Vec<crate::db::models::ProviderConfig>, AppError> {
/// Tauri command: returns the provider configurations stored in the database.
///
/// This is a thin wrapper over `repository::list_providers`; any error from the
/// repository is returned unchanged.
pub async fn list_providers(
state: State<'_, AppState>,
) -> Result<Vec<crate::db::models::ProviderConfig>, AppError> {
repository::list_providers(&state.db).await
}
#[tauri::command]
pub async fn upsert_provider(state: State<'_, AppState>, input: UpsertProviderInput) -> Result<(), AppError> {
/// Tauri command: validates and saves a provider configuration.
///
/// The save is rejected when `input.api_key` is absent or contains only whitespace;
/// otherwise the input is passed to `repository::upsert_provider` unchanged.
///
/// # Errors
/// Returns `AppError::Provider` for a missing or blank API key, or whatever error the
/// repository reports.
pub async fn upsert_provider(
    state: State<'_, AppState>,
    input: UpsertProviderInput,
) -> Result<(), AppError> {
    let key_is_blank = match input.api_key.as_deref() {
        Some(key) => key.trim().is_empty(),
        None => true,
    };
    if key_is_blank {
        return Err(AppError::Provider(
            "请先填写 API Key再保存配置".to_string(),
        ));
    }

    repository::upsert_provider(&state.db, input).await
}
@@ -20,3 +40,306 @@ pub async fn upsert_provider(state: State<'_, AppState>, input: UpsertProviderIn
/// Tauri command: removes the provider identified by `id`.
///
/// Delegates to `repository::delete_provider` and returns its result unchanged.
pub async fn delete_provider(state: State<'_, AppState>, id: String) -> Result<(), AppError> {
repository::delete_provider(&state.db, &id).await
}
/// An image model offered by a provider.
///
/// Returned to the frontend by `fetch_provider_models` and stored inside the
/// provider's capabilities JSON under `image_models`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ProviderModel {
// Model identifier as reported by the provider or listed in the built-in defaults.
pub id: String,
// Owner/vendor label; optional because a models listing may omit it.
pub owned_by: Option<String>,
}
/// Body of the `GET {base_url}/models` response as read by `fetch_provider_models`.
///
/// Only the `data` array is deserialized.
#[derive(Debug, Deserialize)]
struct ModelsResponse {
data: Vec<ModelItem>,
}
/// One entry of `ModelsResponse::data`; mapped field-for-field into `ProviderModel`.
#[derive(Debug, Deserialize)]
struct ModelItem {
id: String,
owned_by: Option<String>,
}
/// The subset of a provider's stored capabilities JSON that `save_provider_models`
/// reads back.
///
/// Both fields are optional, so capabilities JSON that lacks them still parses.
#[derive(Debug, Deserialize)]
struct ProviderCapabilities {
// Every image model known for the provider.
image_models: Option<Vec<ProviderModel>>,
// Ids of the models currently selected for the provider.
selected_image_models: Option<Vec<String>>,
}
/// Heuristically decides whether a model id names an image model.
///
/// The id is lower-cased (ASCII only) and accepted when it contains any of the
/// known marker substrings. This is a substring match, so an id that merely embeds
/// a marker (for example `wan` or `gemini`) is accepted as well.
fn is_image_model(model_id: &str) -> bool {
    const MARKERS: &[&str] = &[
        "gpt-image",
        "image",
        "dall-e",
        "dalle",
        "imagen",
        "flux",
        "qwen-image",
        "wan",
        "z-image",
        "hunyuan-image",
        "gemini",
        "seedream",
        "seededit",
        "stable-diffusion",
        "sd-",
        "midjourney",
        "recraft",
    ];

    let lowered = model_id.to_ascii_lowercase();
    for marker in MARKERS.iter() {
        if lowered.contains(*marker) {
            return true;
        }
    }
    false
}
/// Built-in Seedream model list, used when the Volcengine Ark models endpoint is
/// unavailable or lists no image models.
fn default_seedream_models() -> Vec<ProviderModel> {
    const MODEL_IDS: &[&str] = &["doubao-seedream-4-5-251128", "doubao-seedream-4-0-250828"];

    MODEL_IDS
        .iter()
        .map(|model_id| ProviderModel {
            id: (*model_id).to_string(),
            owned_by: Some("volcengine".to_string()),
        })
        .collect()
}
/// Built-in DashScope image model list; `fetch_provider_models` returns it instead
/// of querying a remote endpoint.
fn default_dashscope_models() -> Vec<ProviderModel> {
    const MODEL_IDS: &[&str] = &[
        "qwen-image-2.0-pro",
        "qwen-image-2.0",
        "qwen-image-plus",
        "qwen-image",
        "wan2.7-image-pro",
        "wan2.7-image",
        "z-image-turbo",
    ];

    MODEL_IDS
        .iter()
        .map(|model_id| ProviderModel {
            id: (*model_id).to_string(),
            owned_by: Some("alibaba-cloud".to_string()),
        })
        .collect()
}
/// Built-in Tencent Hunyuan image model list; `fetch_provider_models` returns it
/// instead of querying a remote endpoint.
fn default_tencent_hunyuan_models() -> Vec<ProviderModel> {
    const MODEL_IDS: &[&str] = &[
        "hunyuan-image-3.0",
        "hunyuan-image-2.0",
        "hunyuan-image-lite",
    ];

    MODEL_IDS
        .iter()
        .map(|model_id| ProviderModel {
            id: (*model_id).to_string(),
            owned_by: Some("tencent-cloud".to_string()),
        })
        .collect()
}
/// Built-in Google Gemini image model list; `fetch_provider_models` returns it
/// instead of querying a remote endpoint.
fn default_google_gemini_models() -> Vec<ProviderModel> {
    const MODEL_IDS: &[&str] = &[
        "gemini-2.5-flash-image",
        "gemini-3.1-flash-image-preview",
        "gemini-3-pro-image-preview",
    ];

    MODEL_IDS
        .iter()
        .map(|model_id| ProviderModel {
            id: (*model_id).to_string(),
            owned_by: Some("google".to_string()),
        })
        .collect()
}
/// Tauri command: discovers the image models of a provider and stores them in the
/// provider's capabilities.
///
/// Steps, in order:
/// 1. Reject provider kinds that are not wired up.
/// 2. Check that the base URL has the suffix expected for the kind.
/// 3. Resolve the API key: a non-blank key from `input` wins; a blank key in `input`
///    counts as "no key"; an absent key falls back to the saved provider's key.
/// 4. DashScope, Tencent Hunyuan and Google Gemini use built-in model lists.
/// 5. Other kinds query `GET {base_url}/models` with bearer auth and keep the ids
///    accepted by `is_image_model`. Volcengine Ark falls back to the built-in
///    Seedream list on HTTP 404/405/501 or when no image model is listed.
///
/// The resulting list is merged and persisted by `save_provider_models`, whose
/// return value is returned here.
///
/// # Errors
/// Returns `AppError::Provider` for an unsupported kind, an unexpected base URL, a
/// missing API key, or a non-success HTTP status; HTTP and JSON decoding failures
/// are propagated with `?`.
#[tauri::command]
pub async fn fetch_provider_models(
    state: State<'_, AppState>,
    input: UpsertProviderInput,
) -> Result<Vec<ProviderModel>, AppError> {
    let url_problem = {
        let kind = input.kind.as_str();
        let is_supported = matches!(
            kind,
            "openai"
                | "openai-compatible"
                | "volcengine-ark"
                | "dashscope"
                | "tencent-hunyuan"
                | "google-gemini"
        );
        if !is_supported {
            return Err(AppError::Provider(format!(
                "API 分类 {} 暂未接入模型列表获取",
                input.kind
            )));
        }

        let base = input.base_url.trim_end_matches('/');
        match kind {
            "openai" | "openai-compatible" if !base.ends_with("/v1") => Some(
                "Base URL 看起来不是 API 地址。OpenAI-compatible 地址通常需要以 /v1 结尾。",
            ),
            "volcengine-ark" if !base.ends_with("/api/v3") => {
                Some("火山方舟 Seedream 的 Base URL 通常需要以 /api/v3 结尾。")
            }
            "dashscope" if !base.ends_with("/api/v1") => {
                Some("阿里云百炼 DashScope 的 Base URL 通常需要以 /api/v1 结尾。")
            }
            "google-gemini" if !base.ends_with("/v1beta") => Some(
                "Google Gemini / Nano Banana 的 Base URL 通常填写 https://generativelanguage.googleapis.com/v1beta",
            ),
            "tencent-hunyuan"
                if !base.ends_with("aiart.tencentcloudapi.com")
                    && !base.ends_with("hunyuan.tencentcloudapi.com") =>
            {
                Some("腾讯混元图像的 Base URL 通常填写 https://aiart.tencentcloudapi.com")
            }
            _ => None,
        }
    };
    if let Some(message) = url_problem {
        return Err(AppError::Provider(message.to_string()));
    }

    // A lookup failure is treated the same as "no saved provider".
    let stored_provider = repository::get_provider_secret(&state.db, &input.id)
        .await
        .ok();
    let resolved_key = match input.api_key.as_deref() {
        Some(key) if key.trim().is_empty() => None,
        Some(key) => Some(key.to_owned()),
        None => stored_provider.and_then(|provider| provider.api_key_encrypted),
    };
    let api_key = resolved_key.ok_or_else(|| {
        AppError::Provider("API Key 为空,无法获取模型列表".to_string())
    })?;

    // These kinds ship a fixed catalogue and are never queried remotely.
    let builtin_catalogue = match input.kind.as_str() {
        "dashscope" => Some(default_dashscope_models()),
        "tencent-hunyuan" => Some(default_tencent_hunyuan_models()),
        "google-gemini" => Some(default_google_gemini_models()),
        _ => None,
    };
    if let Some(models) = builtin_catalogue {
        return save_provider_models(&state, input, models).await;
    }

    let models_url = format!("{}/models", input.base_url.trim_end_matches('/'));
    let response = Client::new()
        .get(models_url)
        .bearer_auth(api_key)
        .send()
        .await?;
    let status = response.status();
    let is_ark = input.kind == "volcengine-ark";

    if !status.is_success() {
        if is_ark && matches!(status.as_u16(), 404 | 405 | 501) {
            return save_provider_models(&state, input, default_seedream_models()).await;
        }
        let message = match response.text().await {
            Ok(text) => text,
            Err(_) => "request failed".to_string(),
        };
        return Err(AppError::Provider(format!(
            "获取模型列表失败 ({status}): {message}"
        )));
    }

    let listing = response.json::<ModelsResponse>().await?;
    let mut image_models: Vec<ProviderModel> = Vec::new();
    for item in listing.data {
        if is_image_model(&item.id) {
            image_models.push(ProviderModel {
                id: item.id,
                owned_by: item.owned_by,
            });
        }
    }
    if is_ark && image_models.is_empty() {
        image_models = default_seedream_models();
    }

    save_provider_models(&state, input, image_models).await
}
/// Merges freshly fetched models into the provider's stored capabilities and saves
/// the provider.
///
/// * Models already stored for `input.id` are combined with `fetched_models`,
///   sorted by id and de-duplicated by id (the stored entry wins on a tie).
/// * The stored selection is kept for ids that still exist, and every id that is
///   not yet selected is appended in sorted order.
/// * The capabilities JSON is rebuilt with these two lists and written through
///   `repository::upsert_provider`, with all other fields taken from `input`.
///
/// A failure to list providers, or stored capabilities that do not parse, is
/// treated as "nothing stored". Returns the merged model list.
///
/// # Errors
/// Returns the repository error when the upsert fails.
async fn save_provider_models(
    state: &State<'_, AppState>,
    input: UpsertProviderInput,
    fetched_models: Vec<ProviderModel>,
) -> Result<Vec<ProviderModel>, AppError> {
    let stored_capabilities = repository::list_providers(&state.db)
        .await
        .unwrap_or_default()
        .into_iter()
        .find(|provider| provider.id == input.id)
        .and_then(|provider| provider.capabilities)
        .and_then(|raw| serde_json::from_str::<ProviderCapabilities>(&raw).ok());

    let (mut models, previous_selection) = match stored_capabilities {
        Some(stored) => (
            stored.image_models.unwrap_or_default(),
            stored.selected_image_models.unwrap_or_default(),
        ),
        None => (Vec::new(), Vec::new()),
    };

    // Stored entries come first and the sort is stable, so de-duplication keeps the
    // stored entry when an id appears in both lists.
    models.extend(fetched_models);
    models.sort_by(|a, b| a.id.cmp(&b.id));
    models.dedup_by(|later, earlier| later.id == earlier.id);

    let model_ids: Vec<String> = models.iter().map(|model| model.id.clone()).collect();

    let mut selected_model_ids: Vec<String> = previous_selection
        .into_iter()
        .filter(|model_id| model_ids.contains(model_id))
        .collect();
    let not_yet_selected: Vec<String> = model_ids
        .iter()
        .filter(|model_id| !selected_model_ids.contains(*model_id))
        .cloned()
        .collect();
    selected_model_ids.extend(not_yet_selected);

    let capabilities = json!({
        "responses_api": true,
        "images_api": true,
        "chat_completions": true,
        "image_edit": true,
        "image_models": models,
        "selected_image_models": selected_model_ids,
    });

    let updated_input = UpsertProviderInput {
        capabilities: Some(capabilities.to_string()),
        ..input
    };
    repository::upsert_provider(&state.db, updated_input).await?;

    Ok(models)
}

View File

@@ -1,6 +1,6 @@
use std::fs;
use sqlx::{sqlite::SqliteConnectOptions, SqlitePool};
use sqlx::{sqlite::SqliteConnectOptions, Row, SqlitePool};
use tauri::{AppHandle, Manager};
use crate::AppError;
@@ -24,6 +24,38 @@ pub async fn init(app: &AppHandle) -> Result<SqlitePool, AppError> {
sqlx::query(statement).execute(&pool).await?;
}
}
ensure_column(
&pool,
"providers",
"capabilities",
"TEXT NOT NULL DEFAULT '{}'",
)
.await?;
Ok(pool)
}
/// Adds `column` to `table` when the table does not have it yet.
///
/// The existing columns are read with `PRAGMA table_info`; if none is named
/// `column`, an `ALTER TABLE ... ADD COLUMN` statement is executed with the given
/// `definition`. A row whose `name` cannot be read as a string counts as "not this
/// column".
///
/// `table`, `column` and `definition` are interpolated into the SQL text directly,
/// so they must come from trusted code, never from user input.
///
/// # Errors
/// Propagates database errors from either statement.
async fn ensure_column(
    pool: &SqlitePool,
    table: &str,
    column: &str,
    definition: &str,
) -> Result<(), AppError> {
    let pragma = format!("PRAGMA table_info({table})");
    let rows = sqlx::query(&pragma).fetch_all(pool).await?;

    for row in &rows {
        if let Ok(name) = row.try_get::<String, _>("name") {
            if name == column {
                return Ok(());
            }
        }
    }

    let statement = format!("ALTER TABLE {table} ADD COLUMN {column} {definition}");
    sqlx::query(&statement).execute(pool).await?;
    Ok(())
}

View File

@@ -10,6 +10,7 @@ pub struct ProviderConfig {
pub api_key: Option<String>,
pub text_model: Option<String>,
pub image_model: Option<String>,
pub capabilities: Option<String>,
pub enabled: bool,
}
@@ -33,6 +34,7 @@ pub struct UpsertProviderInput {
pub api_key: Option<String>,
pub text_model: Option<String>,
pub image_model: Option<String>,
pub capabilities: Option<String>,
pub enabled: bool,
}

View File

@@ -19,8 +19,10 @@ pub async fn list_providers(pool: &SqlitePool) -> Result<Vec<ProviderConfig>, Ap
api_key_encrypted AS api_key,
text_model,
image_model,
capabilities,
enabled != 0 AS enabled
FROM providers
WHERE enabled != 0
ORDER BY updated_at DESC
"#,
)
@@ -30,24 +32,34 @@ pub async fn list_providers(pool: &SqlitePool) -> Result<Vec<ProviderConfig>, Ap
Ok(providers)
}
pub async fn upsert_provider(pool: &SqlitePool, input: UpsertProviderInput) -> Result<(), AppError> {
pub async fn upsert_provider(
pool: &SqlitePool,
input: UpsertProviderInput,
) -> Result<(), AppError> {
let now = Utc::now().to_rfc3339();
let existing_api_key: Option<String> = sqlx::query_scalar(
let existing: Option<(Option<String>, Option<String>)> = sqlx::query_as(
r#"
SELECT api_key_encrypted
SELECT api_key_encrypted, capabilities
FROM providers
WHERE id = ?1
"#,
)
.bind(&input.id)
.fetch_optional(pool)
.await?
.flatten();
let api_key = input
.api_key
.filter(|key| !key.trim().is_empty())
.or(existing_api_key);
.await?;
let api_key = match input.api_key {
Some(key) if key.trim().is_empty() => None,
Some(key) => Some(key),
None => existing.as_ref().and_then(|item| item.0.clone()),
};
let capabilities = input
.capabilities
.filter(|value| !value.trim().is_empty())
.or_else(|| existing.and_then(|item| item.1))
.unwrap_or_else(|| {
r#"{"responses_api":true,"images_api":true,"chat_completions":true,"image_edit":true,"image_models":[],"selected_image_models":[]}"#.to_string()
});
sqlx::query(
r#"
@@ -62,6 +74,7 @@ pub async fn upsert_provider(pool: &SqlitePool, input: UpsertProviderInput) -> R
api_key_encrypted = excluded.api_key_encrypted,
text_model = excluded.text_model,
image_model = excluded.image_model,
capabilities = excluded.capabilities,
enabled = excluded.enabled,
updated_at = excluded.updated_at
"#,
@@ -73,7 +86,7 @@ pub async fn upsert_provider(pool: &SqlitePool, input: UpsertProviderInput) -> R
.bind(api_key)
.bind(input.text_model)
.bind(input.image_model)
.bind(r#"{"responses_api":true,"images_api":true,"chat_completions":true,"image_edit":true}"#)
.bind(capabilities)
.bind(input.enabled)
.bind(now)
.execute(pool)
@@ -98,6 +111,34 @@ pub async fn get_provider_secret(pool: &SqlitePool, id: &str) -> Result<Provider
}
pub async fn delete_provider(pool: &SqlitePool, id: &str) -> Result<(), AppError> {
let reference_count: i64 = sqlx::query_scalar(
r#"
SELECT
(SELECT COUNT(*) FROM generation_tasks WHERE provider_id = ?1) +
(SELECT COUNT(*) FROM conversations WHERE provider_id = ?1) +
(SELECT COUNT(*) FROM ai_request_logs WHERE provider_id = ?1)
"#,
)
.bind(id)
.fetch_one(pool)
.await?;
if reference_count > 0 {
let now = Utc::now().to_rfc3339();
sqlx::query(
r#"
UPDATE providers
SET enabled = 0, updated_at = ?2
WHERE id = ?1
"#,
)
.bind(id)
.bind(now)
.execute(pool)
.await?;
return Ok(());
}
sqlx::query("DELETE FROM providers WHERE id = ?1")
.bind(id)
.execute(pool)

View File

@@ -1,8 +1,8 @@
mod ai;
mod commands;
mod db;
mod storage;
mod state;
mod storage;
use serde::Serialize;
use state::AppState;
@@ -53,6 +53,7 @@ pub fn run() {
commands::provider::list_providers,
commands::provider::upsert_provider,
commands::provider::delete_provider,
commands::provider::fetch_provider_models,
commands::dialog::pick_material_images,
commands::file::reveal_path,
commands::file::open_generated_dir,

View File

@@ -1,5 +1,6 @@
import React, { useEffect, useState } from 'react';
import ReactDOM from 'react-dom/client';
import { FolderOpenOutlined, PictureOutlined, RobotOutlined, SettingOutlined } from '@ant-design/icons';
import { convertFileSrc, invoke } from '@tauri-apps/api/core';
import appLogo from './assets/logo.svg';
import './styles.css';
@@ -12,6 +13,7 @@ type ProviderConfig = {
api_key?: string | null;
text_model?: string | null;
image_model?: string | null;
capabilities?: string | null;
enabled: boolean;
};
@@ -44,15 +46,55 @@ type SessionImage = {
created_at: string;
};
/** An image model entry as returned by the `fetch_provider_models` command. */
type ProviderModel = {
id: string;
owned_by?: string | null;
};
/**
 * The parts of a provider's `capabilities` JSON string that the UI reads.
 * Both fields are optional because the stored JSON may not contain them.
 */
type ProviderCapabilities = {
image_models?: ProviderModel[];
selected_image_models?: string[];
};
/** One row of the generation progress list: a label plus its current state. */
type GenerationStep = {
label: string;
status: 'pending' | 'active' | 'done' | 'error';
};
function formatError(error: unknown) {
if (typeof error === 'string') return error;
if (error instanceof Error) return error.message;
return JSON.stringify(error);
const message =
typeof error === 'string' ? error : error instanceof Error ? error.message : JSON.stringify(error);
if (message.includes('502 Bad Gateway') || message.includes('upstream_error')) {
return '上游模型服务返回 502。通常是中转站或模型供应商临时失败不是本地程序错误可以换模型、降低分辨率或稍后/换供应商重试。';
}
if (message.includes('503') || message.includes('504')) {
return '上游模型服务暂时不可用或超时。可以稍后重试,或切换模型/供应商。';
}
return message;
}
/**
 * Tells whether a status message should be presented as an error.
 * A message counts as an error when it starts with '请先' or contains
 * '失败' or '为空'.
 */
function isErrorStatus(message: string) {
  if (message.startsWith('请先')) {
    return true;
  }
  const errorMarkers = ['失败', '为空'];
  return errorMarkers.some((marker) => message.includes(marker));
}
/**
 * Parses a provider's stored `capabilities` JSON string.
 * Returns an empty object when the value is missing, empty, or not valid JSON.
 */
function parseProviderCapabilities(value?: string | null): ProviderCapabilities {
  let parsed: ProviderCapabilities = {};
  if (value) {
    try {
      parsed = JSON.parse(value) as ProviderCapabilities;
    } catch {
      parsed = {};
    }
  }
  return parsed;
}
/**
 * Serializes the capabilities JSON stored with a provider.
 * The four feature flags are always written as `true`; only the model list and
 * the selection vary.
 */
function buildProviderCapabilities(models: ProviderModel[], selectedModels: string[]) {
  const capabilities = {
    responses_api: true,
    images_api: true,
    chat_completions: true,
    image_edit: true,
    image_models: models,
    selected_image_models: selectedModels,
  };
  return JSON.stringify(capabilities);
}
const defaultProviderForm: ProviderForm = {
@@ -66,32 +108,207 @@ const defaultProviderForm: ProviderForm = {
enabled: true,
};
// Provider kinds offered in the settings form.
// `value` is the kind string stored on the provider. For supported kinds,
// `sampleId`, `sampleName` and `baseUrl` are the values used to prefill the form
// when the kind is chosen. Entries with `supported: false` carry only a label.
const apiKindOptions = [
{
value: 'openai',
label: 'OpenAI 官方',
sampleId: 'openai',
sampleName: 'OpenAI 官方',
baseUrl: 'https://api.openai.com/v1',
supported: true,
},
{
value: 'openai-compatible',
label: 'OpenAI-compatible / 中转站',
sampleId: 'openai-compatible',
sampleName: 'OpenAI-compatible / 中转站',
baseUrl: 'https://api.openai.com/v1',
supported: true,
},
{
value: 'volcengine-ark',
label: '火山方舟 / Seedream',
sampleId: 'volcengine-seedream',
sampleName: '火山方舟 / Seedream',
baseUrl: 'https://ark.cn-beijing.volces.com/api/v3',
supported: true,
},
{
value: 'dashscope',
label: '阿里云百炼 / 通义万相',
sampleId: 'dashscope-image',
sampleName: '阿里云百炼 / 通义万相',
baseUrl: 'https://dashscope.aliyuncs.com/api/v1',
supported: true,
},
{
value: 'tencent-hunyuan',
label: '腾讯混元图像',
sampleId: 'tencent-hunyuan-image',
sampleName: '腾讯混元图像',
baseUrl: 'https://aiart.tencentcloudapi.com',
supported: true,
},
{
value: 'google-gemini',
label: 'Google Gemini / Nano Banana',
sampleId: 'google-nano-banana',
sampleName: 'Google Gemini / Nano Banana',
baseUrl: 'https://generativelanguage.googleapis.com/v1beta',
supported: true,
},
// Placeholders for kinds that are listed but not wired up yet.
{ value: 'stability-ai', label: 'Stability AI待接入', supported: false },
{ value: 'replicate', label: 'Replicate待接入', supported: false },
{ value: 'fal-ai', label: 'fal.ai待接入', supported: false },
];
// Initial state of the generation progress list: every step starts as 'pending'.
const initialGenerationSteps: GenerationStep[] = [
{ label: '保存配置', status: 'pending' },
{ label: '检查配置', status: 'pending' },
{ label: '提交任务', status: 'pending' },
{ label: '等待模型返回', status: 'pending' },
{ label: '保存到应用文件夹', status: 'pending' },
{ label: '更新结果列表', status: 'pending' },
];
const imageModelOptions = ['gpt-image-1', 'gpt-image-1.5', 'gpt-image-2'];
const imageSizeOptions = ['1024x1024', '1024x1536', '1536x1024'];
const defaultImageModelOptions: string[] = [];
const imageQualityOptions = ['auto', 'high', 'medium', 'low'];
// Aspect-ratio presets and the pixel sizes offered for each.
// `value` is the preset name, `sizes` lists the selectable `WIDTHxHEIGHT` strings,
// and `defaultSize` is the size applied when the preset is picked. In every entry
// below, `defaultSize` equals the first item of `sizes`.
const imageAspectRatioOptions = [
{
value: '1:1',
defaultSize: '1024x1024',
sizes: [
{ value: '1024x1024', label: '1024x1024' },
{ value: '2048x2048', label: '2048x2048' },
{ value: '4096x4096', label: '4096x4096 4K' },
],
},
{
value: '1:2',
defaultSize: '1024x2048',
sizes: [
{ value: '1024x2048', label: '1024x2048' },
{ value: '1536x3072', label: '1536x3072' },
{ value: '2048x4096', label: '2048x4096 4K' },
],
},
{
value: '2:1',
defaultSize: '2048x1024',
sizes: [
{ value: '2048x1024', label: '2048x1024' },
{ value: '3072x1536', label: '3072x1536' },
{ value: '4096x2048', label: '4096x2048 4K' },
],
},
{
value: '9:16',
defaultSize: '1080x1920',
sizes: [
{ value: '1080x1920', label: '1080x1920' },
{ value: '1440x2560', label: '1440x2560' },
{ value: '2160x3840', label: '2160x3840 4K' },
],
},
{
value: '16:9',
defaultSize: '1920x1080',
sizes: [
{ value: '1920x1080', label: '1920x1080' },
{ value: '2560x1440', label: '2560x1440' },
{ value: '3840x2160', label: '3840x2160 4K' },
],
},
{
value: '3:4',
defaultSize: '1536x2048',
sizes: [
{ value: '1536x2048', label: '1536x2048' },
{ value: '2304x3072', label: '2304x3072' },
{ value: '3072x4096', label: '3072x4096 4K' },
],
},
{
value: '4:3',
defaultSize: '2048x1536',
sizes: [
{ value: '2048x1536', label: '2048x1536' },
{ value: '3072x2304', label: '3072x2304' },
{ value: '4096x3072', label: '4096x3072 4K' },
],
},
// Business-card presets (landscape, then portrait).
{
value: '名片横版',
defaultSize: '1050x600',
sizes: [
{ value: '1050x600', label: '1050x600' },
{ value: '2100x1200', label: '2100x1200' },
{ value: '3500x2000', label: '3500x2000' },
],
},
{
value: '名片竖版',
defaultSize: '600x1050',
sizes: [
{ value: '600x1050', label: '600x1050' },
{ value: '1200x2100', label: '1200x2100' },
{ value: '2000x3500', label: '2000x3500' },
],
},
];
/**
 * Looks up the aspect-ratio preset named `value`.
 * Falls back to the first preset when no preset has that name.
 */
function getAspectRatioOption(value: string) {
  for (const option of imageAspectRatioOptions) {
    if (option.value === value) {
      return option;
    }
  }
  return imageAspectRatioOptions[0];
}
/**
 * Finds the first aspect-ratio preset that offers the pixel size `value`.
 * Returns `undefined` when no preset lists that size.
 */
function getAspectRatioForSize(value: string) {
  for (const option of imageAspectRatioOptions) {
    for (const size of option.sizes) {
      if (size.value === value) {
        return option;
      }
    }
  }
  return undefined;
}
/**
 * Placeholder text for the API key input, chosen by provider kind.
 * Kinds without a dedicated hint get the generic key placeholder.
 */
function apiKeyPlaceholder(kind: string) {
  switch (kind) {
    case 'tencent-hunyuan':
      return 'SecretId:SecretKey';
    case 'google-gemini':
      return 'Google AI Studio API Key';
    case 'dashscope':
      return 'sk-... 或阿里云百炼 API Key';
    default:
      return 'sk-... 或中转站 key';
  }
}
/**
 * Short help text shown in the provider settings, chosen by provider kind.
 * Kinds without a dedicated tip get the generic base-URL hint.
 */
function providerSettingsTip(kind: string) {
  switch (kind) {
    case 'tencent-hunyuan':
      return '腾讯云使用 API 3.0 签名API Key 填 SecretId:SecretKey。';
    case 'dashscope':
      return '阿里云百炼 Base URL 默认即可API Key 填 DashScope Key。';
    case 'google-gemini':
      return 'Google Gemini 图像模型也叫 Nano BananaAPI Key 填 Google AI Studio Key。';
    default:
      return 'Base URL 填 API 地址,通常以 /v1 结尾。';
  }
}
function App() {
const [providers, setProviders] = useState<ProviderConfig[]>([]);
const [providerForm, setProviderForm] = useState<ProviderForm>(defaultProviderForm);
const [prompt, setPrompt] = useState('一只赛博朋克风格的橘猫坐在霓虹灯下');
const [selectedImageModel, setSelectedImageModel] = useState('gpt-image-2');
const [selectedImageModel, setSelectedImageModel] = useState('');
const [fetchedImageModels, setFetchedImageModels] = useState<ProviderModel[]>([]);
const [selectedImageModels, setSelectedImageModels] = useState<string[]>(defaultImageModelOptions);
const [imageAspectRatio, setImageAspectRatio] = useState('1:1');
const [imageSize, setImageSize] = useState('1024x1024');
const [imageQuality, setImageQuality] = useState('auto');
const [status, setStatus] = useState('准备就绪');
const [settingsStatus, setSettingsStatus] = useState('');
const [isBusy, setIsBusy] = useState(false);
const [isFetchingModels, setIsFetchingModels] = useState(false);
const [isSettingsOpen, setIsSettingsOpen] = useState(false);
const [previewImage, setPreviewImage] = useState<SessionImage | null>(null);
const [sessionImages, setSessionImages] = useState<SessionImage[]>([]);
const [materialPaths, setMaterialPaths] = useState<string[]>([]);
const [generationSteps, setGenerationSteps] = useState<GenerationStep[]>(initialGenerationSteps);
const selectedAspectRatioOption = getAspectRatioOption(imageAspectRatio);
const visibleImageModelOptions =
selectedImageModels.length > 0 ? selectedImageModels : defaultImageModelOptions;
const activeProviderName =
providers.find((provider) => provider.id === providerForm.id)?.name ?? providerForm.name;
const activeMode = materialPaths.length > 0 ? '图像编辑' : '文字生成';
function setStep(index: number, status: GenerationStep['status']) {
setGenerationSteps((steps) =>
@@ -113,18 +330,114 @@ function App() {
setProviderForm((current) => ({ ...current, [key]: value }));
}
/**
 * Switches the provider form to a different API kind.
 *
 * Clears the fetched model list, the checked model list, the active image
 * model and the settings status line, then rebuilds the form from
 * `defaultProviderForm` using the matching entry of `apiKindOptions` (if any).
 *
 * @param kind API kind value picked in the settings form.
 */
function updateProviderKind(kind: string) {
  const preset = apiKindOptions.find((candidate) => candidate.value === kind);
  const presetBaseUrl = preset?.baseUrl || '';
  const presetId = preset?.sampleId || `${kind}-provider`;

  // Reset model-related state and the settings status before rebuilding the form.
  setFetchedImageModels([]);
  setSelectedImageModels([]);
  setSelectedImageModel('');
  setSettingsStatus('');

  setProviderForm((previous) => {
    // The previously entered name is kept only when the preset supplies
    // neither a sample name nor a label.
    const presetName = preset?.sampleName || preset?.label || previous.name;
    return {
      ...defaultProviderForm,
      api_key: '',
      base_url: presetBaseUrl,
      id: presetId,
      kind,
      name: presetName,
    };
  });
}
/**
 * Applies an aspect-ratio choice and resets the image size to that ratio's
 * default size.
 *
 * @param value Aspect-ratio value, resolved through `getAspectRatioOption`.
 */
function updateAspectRatio(value: string) {
  // Store the resolved option's own value rather than the raw input.
  const { value: ratio, defaultSize } = getAspectRatioOption(value);
  setImageAspectRatio(ratio);
  setImageSize(defaultSize);
}
/**
 * Stores the chosen image size and, when `getAspectRatioForSize` finds a
 * matching aspect ratio for it, updates the aspect ratio as well.
 *
 * @param value Image size value picked in the size selector.
 */
function updateImageSize(value: string) {
  setImageSize(value);

  const matchingRatio = getAspectRatioForSize(value);
  if (!matchingRatio) return; // No ratio matches this size: leave the current ratio untouched.

  setImageAspectRatio(matchingRatio.value);
}
/**
 * Toggles a model in the checked image-model list.
 *
 * - A model that is not checked yet is appended to the list.
 * - A checked model is removed, unless it is the only one left (the list is
 *   never emptied through this function).
 * - When the removed model is the active image model, the first remaining
 *   checked model becomes the active one.
 *
 * @param modelId Id of the model whose checkbox was toggled.
 */
function updateSelectedModelRecords(modelId: string) {
  // Keep the updater pure: it only computes the next list. React may invoke
  // updater functions more than once (e.g. in Strict Mode), so other state
  // must not be set from inside it.
  setSelectedImageModels((current) => {
    if (!current.includes(modelId)) return [...current, modelId];
    if (current.length === 1) return current;
    return current.filter((id) => id !== modelId);
  });

  // Mirror the removal rule above against the rendered state to decide
  // whether the active model needs a replacement.
  const removesActiveModel =
    selectedImageModel === modelId &&
    selectedImageModels.includes(modelId) &&
    selectedImageModels.length > 1;
  if (removesActiveModel) {
    setSelectedImageModel(selectedImageModels.find((id) => id !== modelId) ?? '');
  }
}
/**
 * Fetches the image models of the provider currently in the form via the
 * `fetch_provider_models` command and synchronises the fetched list, the
 * checked list and the active image model with the result.
 *
 * Progress and outcome are written to both the main status line and the
 * settings status line. Failures are reported there and not rethrown.
 */
async function fetchProviderModels() {
  setIsFetchingModels(true);
  setStatus('正在获取图片模型列表...');
  setSettingsStatus('正在获取图片模型列表...');
  try {
    const models = await invoke<ProviderModel[]>('fetch_provider_models', {
      input: { ...providerForm, image_model: null },
    });
    const modelIds = models.map((model) => model.id);
    setFetchedImageModels(models);

    if (modelIds.length === 0) {
      setSelectedImageModels([]);
      setSelectedImageModel('');
      setStatus('未从模型列表中识别到图片模型');
      setSettingsStatus('接口可访问,但没有识别到图片模型');
      return;
    }

    const availableIds = new Set(modelIds);

    // Pure updater: previously checked models that still exist keep their
    // position, every other fetched model is appended after them.
    setSelectedImageModels((current) => {
      const kept = current.filter((id) => availableIds.has(id));
      return [...kept, ...modelIds.filter((id) => !kept.includes(id))];
    });

    // The active model is updated separately instead of from inside the
    // updater above, because updater functions must stay free of side effects.
    // The fallback is the head of the merged list: the first previously checked
    // model that is still available, otherwise the first fetched model.
    const fallbackModel =
      selectedImageModels.find((id) => availableIds.has(id)) ?? modelIds[0] ?? '';
    setSelectedImageModel((selected) => (availableIds.has(selected) ? selected : fallbackModel));

    const successMessage = `已获取 ${modelIds.length} 个图片模型`;
    setStatus(successMessage);
    setSettingsStatus(successMessage);
  } catch (error) {
    const failureMessage = `获取模型失败:${formatError(error)}`;
    setStatus(failureMessage);
    setSettingsStatus(failureMessage);
  } finally {
    // Always release the "fetching" flag, including on the early return above.
    setIsFetchingModels(false);
  }
}
async function refreshProviders() {
const result = await invoke<ProviderConfig[]>('list_providers');
setProviders(result);
const current = result.find((provider) => provider.id === providerForm.id) ?? result[0];
if (current) {
const capabilities = parseProviderCapabilities(current.capabilities);
const storedModels = capabilities.image_models ?? [];
const storedSelectedModels =
capabilities.selected_image_models?.filter((model) =>
storedModels.some((storedModel) => storedModel.id === model),
) ?? [];
const nextSelectedModels =
storedSelectedModels.length > 0
? storedSelectedModels
: storedModels.length > 0
? storedModels.map((model) => model.id)
: defaultImageModelOptions;
setFetchedImageModels(storedModels);
setSelectedImageModels(nextSelectedModels);
setSelectedImageModel((model) =>
nextSelectedModels.includes(model) ? model : (nextSelectedModels[0] ?? ''),
);
setProviderForm((form) => ({
...form,
id: current.id,
name: current.name,
kind: current.kind,
base_url: current.base_url,
api_key: current.api_key ?? form.api_key,
api_key: current.api_key ?? '',
text_model: current.text_model ?? null,
image_model: null,
enabled: current.enabled,
@@ -133,14 +446,29 @@ function App() {
}
async function saveProvider() {
if (!providerForm.api_key.trim()) {
setStatus('请先填写 API Key再保存配置');
setSettingsStatus('请先填写 API Key再保存配置');
return;
}
setIsBusy(true);
setStatus('正在保存配置...');
setSettingsStatus('正在保存配置...');
try {
await invoke('upsert_provider', { input: { ...providerForm, image_model: null } });
await invoke('upsert_provider', {
input: {
...providerForm,
image_model: null,
capabilities: buildProviderCapabilities(fetchedImageModels, selectedImageModels),
},
});
await refreshProviders();
setStatus('配置已保存');
setSettingsStatus('配置已保存');
} catch (error) {
setStatus(`保存失败:${formatError(error)}`);
setSettingsStatus(`保存失败:${formatError(error)}`);
} finally {
setIsBusy(false);
}
@@ -149,15 +477,21 @@ function App() {
async function deleteProvider(id: string) {
setIsBusy(true);
setStatus('正在删除配置...');
setSettingsStatus('正在删除配置...');
try {
await invoke('delete_provider', { id });
await refreshProviders();
if (providerForm.id === id) {
setProviderForm(defaultProviderForm);
setFetchedImageModels([]);
setSelectedImageModels([]);
setSelectedImageModel('');
}
setStatus('配置已删除');
setSettingsStatus('配置已删除');
} catch (error) {
setStatus(`删除失败:${formatError(error)}`);
setSettingsStatus(`删除失败:${formatError(error)}`);
} finally {
setIsBusy(false);
}
@@ -170,21 +504,51 @@ function App() {
name: provider.name,
kind: provider.kind,
base_url: provider.base_url,
api_key: provider.api_key ?? form.api_key,
api_key: provider.api_key ?? '',
text_model: provider.text_model ?? null,
image_model: null,
enabled: provider.enabled,
}));
const capabilities = parseProviderCapabilities(provider.capabilities);
const storedModels = capabilities.image_models ?? [];
const storedSelectedModels =
capabilities.selected_image_models?.filter((model) =>
storedModels.some((storedModel) => storedModel.id === model),
) ?? [];
const nextSelectedModels =
storedSelectedModels.length > 0
? storedSelectedModels
: storedModels.length > 0
? storedModels.map((model) => model.id)
: defaultImageModelOptions;
setFetchedImageModels(storedModels);
setSelectedImageModels(nextSelectedModels);
setSelectedImageModel((model) =>
nextSelectedModels.includes(model) ? model : (nextSelectedModels[0] ?? ''),
);
setStatus('已切换模型配置');
setSettingsStatus('');
}
async function generateImage() {
if (!selectedImageModel) {
setStatus('请先在设置中获取并选择图像模型');
return;
}
if (!providerForm.api_key.trim()) {
setStatus('请先在设置中填写并保存 API Key');
return;
}
if (!providers.some((provider) => provider.id === providerForm.id)) {
setStatus('请先保存当前模型供应商配置,再开始生成');
return;
}
setIsBusy(true);
setGenerationSteps(initialGenerationSteps);
setStatus('正在生成图片...');
try {
startStep(0);
await invoke('upsert_provider', { input: { ...providerForm, image_model: null } });
await new Promise((resolve) => window.setTimeout(resolve, 80));
startStep(1);
await new Promise((resolve) => window.setTimeout(resolve, 120));
startStep(2);
@@ -264,18 +628,42 @@ function App() {
return (
<main className="app-shell">
<aside className="side-rail">
<div className="rail-logo">
<img src={appLogo} alt="Image Draw AI" />
</div>
<nav className="rail-nav" aria-label="主导航">
<button className="rail-button active" title="生成">
<span className="rail-icon"><RobotOutlined /></span>
<strong></strong>
</button>
<button className="rail-button" title="素材" onClick={pickMaterialImages} disabled={isBusy}>
<span className="rail-icon"><PictureOutlined /></span>
<strong></strong>
</button>
<button className="rail-button" title="图库" onClick={openGeneratedDir}>
<span className="rail-icon"><FolderOpenOutlined /></span>
<strong></strong>
</button>
</nav>
<button className="rail-button rail-settings" title="设置" onClick={() => setIsSettingsOpen(true)}>
<span className="rail-icon"><SettingOutlined /></span>
<strong></strong>
</button>
</aside>
<section className="app-stage">
<header className="topbar">
<div className="brand">
<img className="brand-mark" src={appLogo} alt="Image Draw AI" />
<div>
<h1>Image Draw AI</h1>
<p></p>
<p>Image Draw AI</p>
<h1></h1>
</div>
</div>
<div className="topbar-actions">
<div className="current-provider">
<span></span>
<strong>{selectedImageModel}</strong>
<span></span>
<strong>{activeProviderName}</strong>
</div>
<button className="ghost" onClick={() => setIsSettingsOpen(true)}></button>
</div>
@@ -284,8 +672,11 @@ function App() {
<section className="workspace">
<aside className="compose-card">
<div className="section-heading">
<div>
<span></span>
<strong>{materialPaths.length > 0 ? '素材生成' : '文字生成'}</strong>
<strong>{activeMode}</strong>
</div>
<small>{imageAspectRatio} / {imageSize} / {imageQuality}</small>
</div>
<label className="field prompt-field">
@@ -297,7 +688,7 @@ function App() {
<div className="material-header">
<div>
<strong></strong>
<span>{materialPaths.length > 0 ? `${materialPaths.length},图像编辑模式` : '可选,支持多张'}</span>
<span>{materialPaths.length > 0 ? `${materialPaths.length}素材` : '未导入'}</span>
</div>
{materialPaths.length > 0 && (
<button className="ghost mini" onClick={() => setMaterialPaths([])} disabled={isBusy}></button>
@@ -307,7 +698,7 @@ function App() {
<div className="reference-strip">
<button className="add-reference-card" onClick={pickMaterialImages} disabled={isBusy}>
<span>+</span>
<strong></strong>
<strong></strong>
<small>PNG/JPG/WEBP</small>
</button>
{materialPaths.map((path, index) => (
@@ -322,26 +713,21 @@ function App() {
<div className="params-card">
<div className="section-heading">
<div>
<span></span>
<strong></strong>
</div>
</div>
<div className="params-grid">
<label className="field compact-field">
<span></span>
<select value={selectedImageModel} onChange={(event) => setSelectedImageModel(event.target.value)} disabled={isBusy}>
{imageModelOptions.map((model) => (
{visibleImageModelOptions.length === 0 && <option value=""></option>}
{visibleImageModelOptions.map((model) => (
<option key={model} value={model}>{model}</option>
))}
</select>
</label>
<div className="grid two">
<label className="field compact-field">
<span></span>
<select value={imageSize} onChange={(event) => setImageSize(event.target.value)} disabled={isBusy}>
{imageSizeOptions.map((size) => (
<option key={size} value={size}>{size}</option>
))}
</select>
</label>
<label className="field compact-field">
<span></span>
<select value={imageQuality} onChange={(event) => setImageQuality(event.target.value)} disabled={isBusy}>
@@ -350,6 +736,22 @@ function App() {
))}
</select>
</label>
<label className="field compact-field">
<span></span>
<select value={imageAspectRatio} onChange={(event) => updateAspectRatio(event.target.value)} disabled={isBusy}>
{imageAspectRatioOptions.map((option) => (
<option key={option.value} value={option.value}>{option.value}</option>
))}
</select>
</label>
<label className="field compact-field">
<span></span>
<select value={imageSize} onChange={(event) => updateImageSize(event.target.value)} disabled={isBusy}>
{selectedAspectRatioOption.sizes.map((size) => (
<option key={size.value} value={size.value}>{size.label}</option>
))}
</select>
</label>
</div>
</div>
@@ -357,22 +759,27 @@ function App() {
{isBusy ? '正在生成...' : '开始生成'}
</button>
{status !== '准备就绪' && <p className="status">{status}</p>}
{status !== '准备就绪' && (
<p className={`status ${isErrorStatus(status) ? 'error' : ''}`}>{status}</p>
)}
</aside>
<section className="result-card">
<div className="section-heading result-heading">
<div>
<span></span>
<div className="heading-actions">
<strong> {sessionImages.length} </strong>
</div>
<div className="heading-actions">
<button className="ghost mini" onClick={openGeneratedDir}></button>
</div>
</div>
{sessionImages.length === 0 ? (
<div className="empty-state">
<div></div>
<p></p>
<img src={appLogo} alt="" />
<div></div>
<p>{selectedImageModel || '未获取模型'} / {imageSize} / {imageQuality}</p>
</div>
) : (
<div className="image-grid">
@@ -408,6 +815,7 @@ function App() {
</div>
</section>
</section>
</section>
{isSettingsOpen && (
<div className="drawer-layer">
@@ -422,12 +830,7 @@ function App() {
</div>
<div className="settings-content">
<section className="settings-group">
<div className="section-heading">
<span></span>
<strong></strong>
</div>
<div className="grid two">
<div className="settings-form">
<label className="field">
<span> ID</span>
<input value={providerForm.id} onChange={(event) => updateProviderForm('id', event.target.value)} />
@@ -436,14 +839,16 @@ function App() {
<span></span>
<input value={providerForm.name} onChange={(event) => updateProviderForm('name', event.target.value)} />
</label>
</div>
</section>
<section className="settings-group">
<div className="section-heading">
<span></span>
<strong> / OpenAI</strong>
</div>
<label className="field">
<span>API </span>
<select value={providerForm.kind} onChange={(event) => updateProviderKind(event.target.value)}>
{apiKindOptions.map((option) => (
<option key={option.value} value={option.value} disabled={!option.supported}>
{option.label}
</option>
))}
</select>
</label>
<label className="field">
<span>Base URL</span>
<input
@@ -451,30 +856,69 @@ function App() {
onChange={(event) => updateProviderForm('base_url', event.target.value)}
placeholder="https://api.openai.com/v1"
/>
<small> API /v1 </small>
</label>
<label className="field">
<span>API Key</span>
<input
value={providerForm.api_key}
onChange={(event) => updateProviderForm('api_key', event.target.value)}
placeholder="sk-... 或中转站 key"
placeholder={apiKeyPlaceholder(providerForm.kind)}
type="password"
/>
</label>
</section>
<p className="settings-tip">{providerSettingsTip(providerForm.kind)}</p>
</div>
<div className="drawer-actions">
<button onClick={saveProvider} disabled={isBusy}></button>
<button className="ghost" onClick={refreshProviders} disabled={isBusy}></button>
<button className="ghost" onClick={fetchProviderModels} disabled={isBusy || isFetchingModels}>
{isFetchingModels ? '获取中...' : '获取模型'}
</button>
</div>
{settingsStatus && (
<p className={`settings-status ${isErrorStatus(settingsStatus) ? 'error' : ''}`}>
{settingsStatus}
</p>
)}
<div className="model-list-panel">
<div className="section-heading">
<span></span>
<strong>{fetchedImageModels.length} </strong>
</div>
{fetchedImageModels.length === 0 ? (
<p className="muted model-empty"></p>
) : (
<ul className="model-list">
{fetchedImageModels.map((model) => (
<li key={model.id}>
<label className="model-record">
<input
checked={selectedImageModels.includes(model.id)}
disabled={isBusy}
onChange={() => updateSelectedModelRecords(model.id)}
type="checkbox"
/>
<span>
<strong>{model.id}</strong>
{model.owned_by && <small>{model.owned_by}</small>}
</span>
</label>
</li>
))}
</ul>
)}
</div>
<div className="saved-providers">
<div className="section-heading">
<div>
<span></span>
<strong>{providers.length} </strong>
</div>
<button className="ghost mini" onClick={refreshProviders} disabled={isBusy}></button>
</div>
{providers.length === 0 ? (
<p className="muted"></p>
) : (

File diff suppressed because it is too large Load Diff