From a4198f29d26023fef5021a9526722b9df8ad876a Mon Sep 17 00:00:00 2001 From: ctexthuang Date: Wed, 29 Apr 2026 15:47:18 +0800 Subject: [PATCH] feat: add multi-provider image generation --- package.json | 11 +- pnpm-lock.yaml | 66 ++ src-tauri/Cargo.toml | 3 + src-tauri/src/ai/dashscope.rs | 354 +++++++ src-tauri/src/ai/google_gemini.rs | 189 ++++ src-tauri/src/ai/mod.rs | 4 + src-tauri/src/ai/openai_compatible.rs | 37 +- src-tauri/src/ai/seedream.rs | 180 ++++ src-tauri/src/ai/tencent_hunyuan.rs | 463 +++++++++ src-tauri/src/commands/dialog.rs | 6 +- src-tauri/src/commands/generation.rs | 196 +++- src-tauri/src/commands/provider.rs | 327 +++++- src-tauri/src/db/mod.rs | 34 +- src-tauri/src/db/models.rs | 2 + src-tauri/src/db/repository.rs | 61 +- src-tauri/src/lib.rs | 3 +- src/main.tsx | 790 +++++++++++---- src/styles.css | 1323 +++++++++++++++++-------- 18 files changed, 3397 insertions(+), 652 deletions(-) create mode 100644 src-tauri/src/ai/dashscope.rs create mode 100644 src-tauri/src/ai/google_gemini.rs create mode 100644 src-tauri/src/ai/seedream.rs create mode 100644 src-tauri/src/ai/tencent_hunyuan.rs diff --git a/package.json b/package.json index 6dee6a5..e5f22a5 100644 --- a/package.json +++ b/package.json @@ -16,17 +16,18 @@ "build:win": "tauri build" }, "dependencies": { + "@ant-design/icons": "^6.2.2", "@tauri-apps/api": "^2.0.0", "@tauri-apps/plugin-dialog": "^2.0.0", "@vitejs/plugin-react": "^5.0.0", - "vite": "^7.0.0", - "typescript": "^5.0.0", "react": "^19.0.0", - "react-dom": "^19.0.0" + "react-dom": "^19.0.0", + "typescript": "^5.0.0", + "vite": "^7.0.0" }, "devDependencies": { + "@tauri-apps/cli": "^2.0.0", "@types/react": "^19.0.0", - "@types/react-dom": "^19.0.0", - "@tauri-apps/cli": "^2.0.0" + "@types/react-dom": "^19.0.0" } } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 9abc16d..a5b0846 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -8,6 +8,9 @@ importers: .: dependencies: + '@ant-design/icons': + specifier: ^6.2.2 + version: 6.2.2(react-dom@19.2.5(react@19.2.5))(react@19.2.5) '@tauri-apps/api': specifier: ^2.0.0 version: 2.10.1 @@ -42,6 +45,23 @@ importers: packages: + '@ant-design/colors@8.0.1': + resolution: {integrity: sha512-foPVl0+SWIslGUtD/xBr1p9U4AKzPhNYEseXYRRo5QSzGACYZrQbe11AYJbYfAWnWSpGBx6JjBmSeugUsD9vqQ==} + + '@ant-design/fast-color@3.0.1': + resolution: {integrity: sha512-esKJegpW4nckh0o6kV3Tkb7NPIZYbPnnFxmQDUmL08ukXZAvV85TZBr70eGuke/CIArLaP6aw8lt9KILjnWuOw==} + engines: {node: '>=8.x'} + + '@ant-design/icons-svg@4.4.2': + resolution: {integrity: sha512-vHbT+zJEVzllwP+CM+ul7reTEfBR0vgxFe7+lREAsAA7YGsYpboiq2sQNeQeRvh09GfQgs/GyFEvZpJ9cLXpXA==} + + '@ant-design/icons@6.2.2': + resolution: {integrity: sha512-zlJtE7AMbG12TeYVPhtBXwNpFInNy8mjLzcIm+0BPw16/b8ODG87YJ1G37VIF5VFscdgfsf6EweAFPTobu/3iQ==} + engines: {node: '>=8'} + peerDependencies: + react: '>=16.0.0' + react-dom: '>=16.0.0' + '@babel/code-frame@7.29.0': resolution: {integrity: sha512-9NhCeYjq9+3uxgdtp20LSiJXJvN0FeCtNGpJxuMFZ1Kv3cWUNb6DOhJwUvcVCzKGR66cw4njwM6hrJLqgOwbcw==} engines: {node: '>=6.9.0'} @@ -297,6 +317,12 @@ packages: '@jridgewell/trace-mapping@0.3.31': resolution: {integrity: sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw==} + '@rc-component/util@1.10.1': + resolution: {integrity: sha512-q++9S6rUa5Idb/xIBNz6jtvumw5+O5YV5V0g4iK9mn9jWs4oGJheE3ZN1kAnE723AXyaD8v95yeOASmdk8Jnng==} + peerDependencies: + react: '>=18.0.0' + react-dom: '>=18.0.0' + '@rolldown/pluginutils@1.0.0-rc.3': resolution: {integrity: sha512-eybk3TjzzzV97Dlj5c+XrBFW57eTNhzod66y9HrBlzJ6NsCrWCp/2kaPS3K9wJmurBC0Tdw4yPjXKZqlznim3Q==} @@ -544,6 +570,10 @@ packages: caniuse-lite@1.0.30001790: resolution: {integrity: sha512-bOoxfJPyYo+ds6W0YfptaCWbFnJYjh2Y1Eow5lRv+vI2u8ganPZqNm1JwNh0t2ELQCqIWg4B3dWEusgAmsoyOw==} + clsx@2.1.1: + resolution: {integrity: sha512-eYm0QWBtUrBWZWG0d386OGAw16Z995PiOVo2B7bjWSbHedGl5e0ZWaq65kOGgUSNesEIDkB9ISbTg/JK9dhCZA==} + engines: {node: '>=6'} + convert-source-map@2.0.0: resolution: {integrity: sha512-Kvp459HrV2FEJ1CAsi1Ku+MY3kasH19TFykTz2xWmMeq6bk2NU3XXvfJ+Q61m0xktWwt+1HSYf3JZsTms3aRJg==} @@ -589,6 +619,9 @@ packages: resolution: {integrity: sha512-3hN7NaskYvMDLQY55gnW3NQ+mesEAepTqlg+VEbj7zzqEMBVNhzcGYYeqFo/TlYz6eQiFcp1HcsCZO+nGgS8zg==} engines: {node: '>=6.9.0'} + is-mobile@5.0.0: + resolution: {integrity: sha512-Tz/yndySvLAEXh+Uk8liFCxOwVH6YutuR74utvOcu7I9Di+DwM0mtdPVZNaVvvBUM2OXxne/NhOs1zAO7riusQ==} + js-tokens@4.0.0: resolution: {integrity: sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==} @@ -632,6 +665,9 @@ packages: peerDependencies: react: ^19.2.5 + react-is@18.3.1: + resolution: {integrity: sha512-/LLMVyas0ljjAtoYiPqYiL8VWXzUUdThrmU5+n20DZv+a+ClRoevUzw5JxU+Ieh5/c87ytoTBV9G1FiKfNJdmg==} + react-refresh@0.18.0: resolution: {integrity: sha512-QgT5//D3jfjJb6Gsjxv0Slpj23ip+HtOpnNgnb2S5zU3CB26G/IDPGoy4RJB42wzFE46DRsstbW6tKHoKbhAxw==} engines: {node: '>=0.10.0'} @@ -716,6 +752,23 @@ packages: snapshots: + '@ant-design/colors@8.0.1': + dependencies: + '@ant-design/fast-color': 3.0.1 + + '@ant-design/fast-color@3.0.1': {} + + '@ant-design/icons-svg@4.4.2': {} + + '@ant-design/icons@6.2.2(react-dom@19.2.5(react@19.2.5))(react@19.2.5)': + dependencies: + '@ant-design/colors': 8.0.1 + '@ant-design/icons-svg': 4.4.2 + '@rc-component/util': 1.10.1(react-dom@19.2.5(react@19.2.5))(react@19.2.5) + clsx: 2.1.1 + react: 19.2.5 + react-dom: 19.2.5(react@19.2.5) + '@babel/code-frame@7.29.0': dependencies: '@babel/helper-validator-identifier': 7.28.5 @@ -925,6 +978,13 @@ snapshots: '@jridgewell/resolve-uri': 3.1.2 '@jridgewell/sourcemap-codec': 1.5.5 + '@rc-component/util@1.10.1(react-dom@19.2.5(react@19.2.5))(react@19.2.5)': + dependencies: + is-mobile: 5.0.0 + react: 19.2.5 + react-dom: 19.2.5(react@19.2.5) + react-is: 18.3.1 + '@rolldown/pluginutils@1.0.0-rc.3': {} '@rollup/rollup-android-arm-eabi@4.60.2': @@ -1110,6 +1170,8 @@ snapshots: caniuse-lite@1.0.30001790: {} + clsx@2.1.1: {} + convert-source-map@2.0.0: {} csstype@3.2.3: {} @@ -1160,6 +1222,8 @@ snapshots: gensync@1.0.0-beta.2: {} + is-mobile@5.0.0: {} + js-tokens@4.0.0: {} jsesc@3.1.0: {} @@ -1191,6 +1255,8 @@ snapshots: react: 19.2.5 scheduler: 0.27.0 + react-is@18.3.1: {} + react-refresh@0.18.0: {} react@19.2.5: {} diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 967cd75..96ff462 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -26,6 +26,9 @@ reqwest = { version = "0.12", default-features = false, features = ["json", "mul thiserror = "2" async-trait = "0.1" base64 = "0.22" +hex = "0.4" +hmac = "0.12" +sha2 = "0.10" [features] default = ["custom-protocol"] diff --git a/src-tauri/src/ai/dashscope.rs b/src-tauri/src/ai/dashscope.rs new file mode 100644 index 0000000..4ee2946 --- /dev/null +++ b/src-tauri/src/ai/dashscope.rs @@ -0,0 +1,354 @@ +use async_trait::async_trait; +use base64::{engine::general_purpose, Engine as _}; +use reqwest::Client; +use serde_json::{json, Value}; +use std::{fs, path::Path, time::Duration}; +use tokio::time::sleep; + +use super::provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest, ImageResult}; +use crate::AppError; + +pub struct DashScopeProvider { + client: Client, + base_url: String, + api_key: String, +} + +impl DashScopeProvider { + pub fn new(base_url: String, api_key: String) -> Self { + Self { + client: Client::new(), + base_url: base_url.trim_end_matches('/').to_string(), + api_key, + } + } +} + +fn is_qwen_image_model(model: &str) -> bool { + model.starts_with("qwen-image") +} + +fn is_sync_multimodal_model(model: &str) -> bool { + is_qwen_image_model(model) || model.starts_with("z-image") +} + +fn dashscope_size(size: Option<&str>) -> Option { + size.map(|value| value.replace('x', "*")) +} + +fn mime_for_path(path: &str) -> &'static str { + match Path::new(path) + .extension() + .and_then(|extension| extension.to_str()) + .map(|extension| extension.to_ascii_lowercase()) + .as_deref() + { + Some("jpg" | "jpeg") => "image/jpeg", + Some("webp") => "image/webp", + _ => "image/png", + } +} + +fn path_to_data_url(path: &str) -> Result { + let bytes = fs::read(path)?; + Ok(format!( + "data:{};base64,{}", + mime_for_path(path), + general_purpose::STANDARD.encode(bytes) + )) +} + +fn parse_dashscope_error(response_text: &str, fallback: &str) -> String { + serde_json::from_str::(response_text) + .ok() + .and_then(|value| { + let code = value + .get("code") + .and_then(Value::as_str) + .unwrap_or_default(); + let message = value + .get("message") + .and_then(Value::as_str) + .unwrap_or_default(); + if code.is_empty() && message.is_empty() { + None + } else { + Some(format!("{code}: {message}")) + } + }) + .unwrap_or_else(|| fallback.to_string()) +} + +fn image_url_from_choices(value: &Value) -> Option { + value + .get("output")? + .get("choices")? + .as_array()? + .iter() + .flat_map(|choice| { + choice + .get("message") + .and_then(|message| message.get("content")) + .and_then(Value::as_array) + .into_iter() + .flatten() + }) + .find_map(|content| { + content + .get("image") + .and_then(Value::as_str) + .map(str::to_string) + }) +} + +fn image_url_from_results(value: &Value) -> Option { + value + .get("output")? + .get("results")? + .as_array()? + .iter() + .find_map(|item| item.get("url").and_then(Value::as_str).map(str::to_string)) +} + +fn parse_image_url(value: &Value) -> Option { + image_url_from_choices(value).or_else(|| image_url_from_results(value)) +} + +fn build_messages(prompt: &str, image_paths: &[String]) -> Result, AppError> { + let mut content = image_paths + .iter() + .map(|path| path_to_data_url(path).map(|image| json!({ "image": image }))) + .collect::, _>>()?; + content.push(json!({ "text": prompt })); + + Ok(vec![json!({ + "role": "user", + "content": content, + })]) +} + +#[async_trait] +impl AiProvider for DashScopeProvider { + async fn generate_image(&self, request: ImageGenerateRequest) -> Result { + if is_sync_multimodal_model(&request.model) { + return self + .run_qwen_image( + &request.prompt, + &request.model, + request.size.as_deref(), + &[], + ) + .await; + } + + self.run_async_image_generation( + &request.prompt, + &request.model, + request.size.as_deref(), + &[], + ) + .await + } + + async fn edit_image(&self, request: ImageEditRequest) -> Result { + if is_sync_multimodal_model(&request.model) { + return self + .run_qwen_image( + &request.prompt, + &request.model, + request.size.as_deref(), + &request.image_paths, + ) + .await; + } + + self.run_async_image_generation( + &request.prompt, + &request.model, + request.size.as_deref(), + &request.image_paths, + ) + .await + } +} + +impl DashScopeProvider { + async fn run_qwen_image( + &self, + prompt: &str, + model: &str, + size: Option<&str>, + image_paths: &[String], + ) -> Result { + let mut parameters = json!({ + "n": 1, + "watermark": false, + "prompt_extend": true, + }); + if is_qwen_image_model(model) { + parameters["negative_prompt"] = json!(" "); + } + if let Some(size) = dashscope_size(size) { + parameters["size"] = json!(size); + } + + let body = json!({ + "model": model, + "input": { + "messages": build_messages(prompt, image_paths)?, + }, + "parameters": parameters, + }); + let response = self + .client + .post(format!( + "{}/services/aigc/multimodal-generation/generation", + self.base_url + )) + .bearer_auth(&self.api_key) + .json(&body) + .send() + .await?; + let status = response.status(); + let response_text = response.text().await?; + if !status.is_success() { + return Err(AppError::Provider(format!( + "DashScope image generation failed ({status}): {}", + parse_dashscope_error(&response_text, &response_text) + ))); + } + + let response_body = serde_json::from_str::(&response_text).map_err(|error| { + AppError::Provider(format!( + "failed to decode DashScope response: {error}; response body: {response_text}" + )) + })?; + let image_url = parse_image_url(&response_body).ok_or_else(|| { + AppError::Provider(format!( + "DashScope response did not include image url: {response_text}" + )) + })?; + + Ok(ImageResult { + mime_type: "image/png".to_string(), + data: ImageData::Url(image_url), + }) + } + + async fn run_async_image_generation( + &self, + prompt: &str, + model: &str, + size: Option<&str>, + image_paths: &[String], + ) -> Result { + let mut parameters = json!({ + "n": 1, + "watermark": false, + }); + if let Some(size) = dashscope_size(size) { + parameters["size"] = json!(size); + } + if model.starts_with("wan2.7-image") { + parameters["thinking_mode"] = json!(true); + } else { + parameters["prompt_extend"] = json!(true); + } + + let body = json!({ + "model": model, + "input": { + "messages": build_messages(prompt, image_paths)?, + }, + "parameters": parameters, + }); + let response = self + .client + .post(format!( + "{}/services/aigc/image-generation/generation", + self.base_url + )) + .bearer_auth(&self.api_key) + .header("X-DashScope-Async", "enable") + .json(&body) + .send() + .await?; + let status = response.status(); + let response_text = response.text().await?; + if !status.is_success() { + return Err(AppError::Provider(format!( + "DashScope async image task failed ({status}): {}", + parse_dashscope_error(&response_text, &response_text) + ))); + } + + let response_body = serde_json::from_str::(&response_text).map_err(|error| { + AppError::Provider(format!( + "failed to decode DashScope task response: {error}; response body: {response_text}" + )) + })?; + let task_id = response_body + .get("output") + .and_then(|output| output.get("task_id")) + .and_then(Value::as_str) + .ok_or_else(|| { + AppError::Provider(format!( + "DashScope task response did not include task_id: {response_text}" + )) + })?; + + self.wait_async_task(task_id).await + } + + async fn wait_async_task(&self, task_id: &str) -> Result { + for _ in 0..40 { + sleep(Duration::from_secs(3)).await; + let response = self + .client + .get(format!("{}/tasks/{task_id}", self.base_url)) + .bearer_auth(&self.api_key) + .send() + .await?; + let status = response.status(); + let response_text = response.text().await?; + if !status.is_success() { + return Err(AppError::Provider(format!( + "DashScope task polling failed ({status}): {}", + parse_dashscope_error(&response_text, &response_text) + ))); + } + + let response_body = serde_json::from_str::(&response_text).map_err(|error| { + AppError::Provider(format!( + "failed to decode DashScope task result: {error}; response body: {response_text}" + )) + })?; + let task_status = response_body + .get("output") + .and_then(|output| output.get("task_status")) + .and_then(Value::as_str) + .unwrap_or_default(); + if task_status == "SUCCEEDED" { + let image_url = parse_image_url(&response_body).ok_or_else(|| { + AppError::Provider(format!( + "DashScope task succeeded but did not include image url: {response_text}" + )) + })?; + return Ok(ImageResult { + mime_type: "image/png".to_string(), + data: ImageData::Url(image_url), + }); + } + if matches!(task_status, "FAILED" | "CANCELED" | "UNKNOWN") { + return Err(AppError::Provider(format!( + "DashScope task ended with status {task_status}: {}", + parse_dashscope_error(&response_text, &response_text) + ))); + } + } + + Err(AppError::Provider( + "DashScope image task polling timed out".to_string(), + )) + } +} diff --git a/src-tauri/src/ai/google_gemini.rs b/src-tauri/src/ai/google_gemini.rs new file mode 100644 index 0000000..74842b6 --- /dev/null +++ b/src-tauri/src/ai/google_gemini.rs @@ -0,0 +1,189 @@ +use async_trait::async_trait; +use base64::{engine::general_purpose, Engine as _}; +use reqwest::Client; +use serde_json::{json, Value}; +use std::{fs, path::Path}; + +use super::provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest, ImageResult}; +use crate::AppError; + +pub struct GoogleGeminiProvider { + client: Client, + base_url: String, + api_key: String, +} + +impl GoogleGeminiProvider { + pub fn new(base_url: String, api_key: String) -> Self { + Self { + client: Client::new(), + base_url: base_url.trim_end_matches('/').to_string(), + api_key, + } + } +} + +fn mime_for_path(path: &str) -> &'static str { + match Path::new(path) + .extension() + .and_then(|extension| extension.to_str()) + .map(|extension| extension.to_ascii_lowercase()) + .as_deref() + { + Some("jpg" | "jpeg") => "image/jpeg", + Some("webp") => "image/webp", + _ => "image/png", + } +} + +fn size_to_aspect_ratio(size: Option<&str>) -> Option<&'static str> { + let (width, height) = size?.split_once('x')?; + let width = width.parse::().ok()?; + let height = height.parse::().ok()?; + if width <= 0.0 || height <= 0.0 { + return None; + } + + let ratio = width / height; + if (ratio - 1.0).abs() < 0.08 { + Some("1:1") + } else if (ratio - 16.0 / 9.0).abs() < 0.12 { + Some("16:9") + } else if (ratio - 9.0 / 16.0).abs() < 0.12 { + Some("9:16") + } else if (ratio - 4.0 / 3.0).abs() < 0.12 { + Some("4:3") + } else if (ratio - 3.0 / 4.0).abs() < 0.12 { + Some("3:4") + } else { + None + } +} + +fn image_part(path: &str) -> Result { + let bytes = fs::read(path)?; + Ok(json!({ + "inline_data": { + "mime_type": mime_for_path(path), + "data": general_purpose::STANDARD.encode(bytes), + } + })) +} + +fn inline_data_from_part(part: &Value) -> Option<(&str, &str)> { + let inline_data = part.get("inlineData").or_else(|| part.get("inline_data"))?; + let data = inline_data.get("data").and_then(Value::as_str)?; + let mime_type = inline_data + .get("mimeType") + .or_else(|| inline_data.get("mime_type")) + .and_then(Value::as_str) + .unwrap_or("image/png"); + Some((mime_type, data)) +} + +fn parse_gemini_response(response_text: &str) -> Result { + let response_body = serde_json::from_str::(response_text).map_err(|error| { + AppError::Provider(format!( + "failed to decode Gemini image response: {error}; response body: {response_text}" + )) + })?; + let parts = response_body + .get("candidates") + .and_then(Value::as_array) + .into_iter() + .flatten() + .flat_map(|candidate| { + candidate + .get("content") + .and_then(|content| content.get("parts")) + .and_then(Value::as_array) + .into_iter() + .flatten() + }); + + for part in parts { + if let Some((mime_type, data)) = inline_data_from_part(part) { + return Ok(ImageResult { + mime_type: mime_type.to_string(), + data: ImageData::Base64(data.to_string()), + }); + } + } + + Err(AppError::Provider(format!( + "Gemini response did not include image data: {response_text}" + ))) +} + +#[async_trait] +impl AiProvider for GoogleGeminiProvider { + async fn generate_image(&self, request: ImageGenerateRequest) -> Result { + self.generate_with_parts( + &request.prompt, + &request.model, + request.size.as_deref(), + &[], + ) + .await + } + + async fn edit_image(&self, request: ImageEditRequest) -> Result { + self.generate_with_parts( + &request.prompt, + &request.model, + request.size.as_deref(), + &request.image_paths, + ) + .await + } +} + +impl GoogleGeminiProvider { + async fn generate_with_parts( + &self, + prompt: &str, + model: &str, + size: Option<&str>, + image_paths: &[String], + ) -> Result { + let mut parts = vec![json!({ "text": prompt })]; + parts.extend( + image_paths + .iter() + .map(|path| image_part(path)) + .collect::, _>>()?, + ); + let mut generation_config = json!({ + "responseModalities": ["TEXT", "IMAGE"], + }); + if let Some(aspect_ratio) = size_to_aspect_ratio(size) { + generation_config["imageConfig"] = json!({ + "aspectRatio": aspect_ratio, + }); + } + + let body = json!({ + "contents": [{ + "role": "user", + "parts": parts, + }], + "generationConfig": generation_config, + }); + let response = self + .client + .post(format!("{}/models/{model}:generateContent", self.base_url)) + .header("x-goog-api-key", &self.api_key) + .json(&body) + .send() + .await?; + let status = response.status(); + let response_text = response.text().await?; + if !status.is_success() { + return Err(AppError::Provider(format!( + "Gemini image generation failed ({status}): {response_text}" + ))); + } + + parse_gemini_response(&response_text) + } +} diff --git a/src-tauri/src/ai/mod.rs b/src-tauri/src/ai/mod.rs index b0c2621..d74eef2 100644 --- a/src-tauri/src/ai/mod.rs +++ b/src-tauri/src/ai/mod.rs @@ -1,2 +1,6 @@ +pub mod dashscope; +pub mod google_gemini; pub mod openai_compatible; pub mod provider; +pub mod seedream; +pub mod tencent_hunyuan; diff --git a/src-tauri/src/ai/openai_compatible.rs b/src-tauri/src/ai/openai_compatible.rs index 0768329..bfee0f6 100644 --- a/src-tauri/src/ai/openai_compatible.rs +++ b/src-tauri/src/ai/openai_compatible.rs @@ -32,11 +32,12 @@ fn parse_image_response(response_text: &str) -> Result { )); } - let response_body: ImageResponseBody = serde_json::from_str(response_text).map_err(|error| { - AppError::Provider(format!( - "failed to decode image response: {error}; response body: {response_text}" - )) - })?; + let response_body: ImageResponseBody = + serde_json::from_str(response_text).map_err(|error| { + AppError::Provider(format!( + "failed to decode image response: {error}; response body: {response_text}" + )) + })?; let image_data = response_body .data .into_iter() @@ -45,7 +46,9 @@ fn parse_image_response(response_text: &str) -> Result { .map(ImageData::Base64) .or_else(|| item.url.map(ImageData::Url)) }) - .ok_or_else(|| AppError::Provider("image response did not include b64_json or url".to_string()))?; + .ok_or_else(|| { + AppError::Provider("image response did not include b64_json or url".to_string()) + })?; Ok(ImageResult { mime_type: "image/png".to_string(), @@ -96,8 +99,13 @@ impl AiProvider for OpenAiCompatibleProvider { let status = response.status(); if !status.is_success() { - let message = response.text().await.unwrap_or_else(|_| "request failed".to_string()); - return Err(AppError::Provider(format!("image generation failed ({status}): {message}"))); + let message = response + .text() + .await + .unwrap_or_else(|_| "request failed".to_string()); + return Err(AppError::Provider(format!( + "image generation failed ({status}): {message}" + ))); } let response_text = response.text().await?; @@ -134,7 +142,9 @@ impl AiProvider for OpenAiCompatibleProvider { Some("webp") => "image/webp", _ => "image/png", }; - let part = multipart::Part::bytes(bytes).file_name(file_name).mime_str(mime)?; + let part = multipart::Part::bytes(bytes) + .file_name(file_name) + .mime_str(mime)?; form = form.part("image[]", part); } @@ -148,8 +158,13 @@ impl AiProvider for OpenAiCompatibleProvider { let status = response.status(); if !status.is_success() { - let message = response.text().await.unwrap_or_else(|_| "request failed".to_string()); - return Err(AppError::Provider(format!("image edit failed ({status}): {message}"))); + let message = response + .text() + .await + .unwrap_or_else(|_| "request failed".to_string()); + return Err(AppError::Provider(format!( + "image edit failed ({status}): {message}" + ))); } let response_text = response.text().await?; diff --git a/src-tauri/src/ai/seedream.rs b/src-tauri/src/ai/seedream.rs new file mode 100644 index 0000000..f054689 --- /dev/null +++ b/src-tauri/src/ai/seedream.rs @@ -0,0 +1,180 @@ +use async_trait::async_trait; +use base64::{engine::general_purpose, Engine as _}; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::{fs, path::Path}; + +use super::provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest, ImageResult}; +use crate::AppError; + +pub struct SeedreamProvider { + client: Client, + base_url: String, + api_key: String, +} + +impl SeedreamProvider { + pub fn new(base_url: String, api_key: String) -> Self { + Self { + client: Client::new(), + base_url: base_url.trim_end_matches('/').to_string(), + api_key, + } + } +} + +#[derive(Debug, Serialize)] +struct SeedreamRequestBody<'a> { + model: &'a str, + prompt: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + size: Option<&'a str>, + response_format: &'a str, + stream: bool, + watermark: bool, + #[serde(skip_serializing_if = "Option::is_none")] + image: Option>, +} + +#[derive(Debug, Deserialize)] +struct SeedreamResponseBody { + data: Vec, +} + +#[derive(Debug, Deserialize)] +struct SeedreamResponseItem { + b64_json: Option, + url: Option, +} + +fn mime_for_path(path: &str) -> &'static str { + match Path::new(path) + .extension() + .and_then(|extension| extension.to_str()) + .map(|extension| extension.to_ascii_lowercase()) + .as_deref() + { + Some("jpg" | "jpeg") => "image/jpeg", + Some("webp") => "image/webp", + _ => "image/png", + } +} + +fn path_to_data_url(path: &str) -> Result { + let bytes = fs::read(path)?; + Ok(format!( + "data:{};base64,{}", + mime_for_path(path), + general_purpose::STANDARD.encode(bytes) + )) +} + +fn parse_seedream_response(response_text: &str) -> Result { + if response_text.trim_start().starts_with(" Result { + let body = SeedreamRequestBody { + model: &request.model, + prompt: &request.prompt, + size: request.size.as_deref(), + response_format: "b64_json", + stream: false, + watermark: false, + image: None, + }; + + let response = self + .client + .post(format!("{}/images/generations", self.base_url)) + .bearer_auth(&self.api_key) + .json(&body) + .send() + .await?; + + let status = response.status(); + if !status.is_success() { + let message = response + .text() + .await + .unwrap_or_else(|_| "request failed".to_string()); + return Err(AppError::Provider(format!( + "Seedream image generation failed ({status}): {message}" + ))); + } + + let response_text = response.text().await?; + parse_seedream_response(&response_text) + } + + async fn edit_image(&self, request: ImageEditRequest) -> Result { + let images = request + .image_paths + .iter() + .map(|path| path_to_data_url(path)) + .collect::, _>>()?; + let body = SeedreamRequestBody { + model: &request.model, + prompt: &request.prompt, + size: request.size.as_deref(), + response_format: "b64_json", + stream: false, + watermark: false, + image: Some(images), + }; + + let response = self + .client + .post(format!("{}/images/generations", self.base_url)) + .bearer_auth(&self.api_key) + .json(&body) + .send() + .await?; + + let status = response.status(); + if !status.is_success() { + let message = response + .text() + .await + .unwrap_or_else(|_| "request failed".to_string()); + return Err(AppError::Provider(format!( + "Seedream image edit failed ({status}): {message}" + ))); + } + + let response_text = response.text().await?; + parse_seedream_response(&response_text) + } +} diff --git a/src-tauri/src/ai/tencent_hunyuan.rs b/src-tauri/src/ai/tencent_hunyuan.rs new file mode 100644 index 0000000..6484c06 --- /dev/null +++ b/src-tauri/src/ai/tencent_hunyuan.rs @@ -0,0 +1,463 @@ +use async_trait::async_trait; +use base64::{engine::general_purpose, Engine as _}; +use chrono::{TimeZone, Utc}; +use hmac::{Hmac, Mac}; +use reqwest::{Client, Url}; +use serde_json::{json, Value}; +use sha2::{Digest, Sha256}; +use std::{fs, time::Duration}; +use tokio::time::sleep; + +use super::provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest, ImageResult}; +use crate::AppError; + +type HmacSha256 = Hmac; + +pub struct TencentHunyuanProvider { + client: Client, + base_url: String, + secret_id: String, + secret_key: String, +} + +struct TencentApiConfig { + service: &'static str, + version: &'static str, + lite_action: &'static str, + rapid_action: &'static str, + submit_action: &'static str, + query_action: &'static str, +} + +impl TencentHunyuanProvider { + pub fn new(base_url: String, api_key: String) -> Result { + let (secret_id, secret_key) = parse_secret_pair(&api_key)?; + Ok(Self { + client: Client::new(), + base_url: base_url.trim_end_matches('/').to_string(), + secret_id, + secret_key, + }) + } +} + +fn parse_secret_pair(api_key: &str) -> Result<(String, String), AppError> { + let (secret_id, secret_key) = api_key.split_once(':').ok_or_else(|| { + AppError::Provider("腾讯云 API Key 需要填写为 SecretId:SecretKey".to_string()) + })?; + let secret_id = secret_id.trim(); + let secret_key = secret_key.trim(); + if secret_id.is_empty() || secret_key.is_empty() { + return Err(AppError::Provider( + "腾讯云 API Key 需要填写为 SecretId:SecretKey".to_string(), + )); + } + + Ok((secret_id.to_string(), secret_key.to_string())) +} + +fn hmac_sha256(key: &[u8], message: &str) -> Result, AppError> { + let mut mac = HmacSha256::new_from_slice(key) + .map_err(|error| AppError::Provider(format!("failed to create HMAC: {error}")))?; + mac.update(message.as_bytes()); + Ok(mac.finalize().into_bytes().to_vec()) +} + +fn sha256_hex(value: &str) -> String { + hex::encode(Sha256::digest(value.as_bytes())) +} + +fn host_from_base_url(base_url: &str) -> Result { + let url = Url::parse(base_url) + .map_err(|error| AppError::Provider(format!("腾讯云 Base URL 不是有效 URL: {error}")))?; + url.host_str() + .map(str::to_string) + .ok_or_else(|| AppError::Provider("腾讯云 Base URL 缺少 host".to_string())) +} + +fn api_config(host: &str) -> TencentApiConfig { + if host == "aiart.tencentcloudapi.com" { + TencentApiConfig { + service: "aiart", + version: "2022-12-29", + lite_action: "TextToImageLite", + rapid_action: "TextToImageRapid", + submit_action: "SubmitTextToImageJob", + query_action: "QueryTextToImageJob", + } + } else { + TencentApiConfig { + service: "hunyuan", + version: "2023-09-01", + lite_action: "TextToImageLite", + rapid_action: "TextToImageLite", + submit_action: "SubmitHunyuanImageJob", + query_action: "QueryHunyuanImageJob", + } + } +} + +fn hunyuan_resolution(size: Option<&str>, lite: bool, has_reference: bool) -> Option { + let size = size?; + let (width, height) = size.split_once('x')?; + let width = width.parse::().ok()?; + let height = height.parse::().ok()?; + if width <= 0.0 || height <= 0.0 { + return None; + } + + let ratio = width / height; + let value = if (ratio - 1.0).abs() < 0.12 { + "1024:1024" + } else if (ratio - 0.75).abs() < 0.12 { + "768:1024" + } else if (ratio - 1.333).abs() < 0.12 { + "1024:768" + } else if ratio < 0.65 { + if lite && !has_reference { + "1080:1920" + } else { + "720:1280" + } + } else if ratio > 1.55 { + if lite && !has_reference { + "1920:1080" + } else { + "1280:720" + } + } else if width < height { + "768:1024" + } else { + "1024:768" + }; + + if has_reference && !matches!(value, "1024:1024" | "768:1024" | "1024:768") { + return None; + } + + Some(value.to_string()) +} + +fn first_image_base64(image_paths: &[String]) -> Result, AppError> { + image_paths + .first() + .map(|path| fs::read(path).map(|bytes| general_purpose::STANDARD.encode(bytes))) + .transpose() + .map_err(AppError::from) +} + +fn parse_tencent_error(response: &Value) -> Option { + let error = response.get("Error")?; + let code = error + .get("Code") + .and_then(Value::as_str) + .unwrap_or_default(); + let message = error + .get("Message") + .and_then(Value::as_str) + .unwrap_or_default(); + Some(format!("{code}: {message}")) +} + +fn result_image_url(response: &Value) -> Option { + let result = response.get("ResultImage")?; + if let Some(url) = result.as_str().filter(|value| value.starts_with("http")) { + return Some(url.to_string()); + } + + result.as_array()?.first().and_then(|image| { + image + .as_str() + .filter(|value| value.starts_with("http")) + .map(str::to_string) + .or_else(|| { + image + .get("Url") + .or_else(|| image.get("URL")) + .and_then(Value::as_str) + .map(str::to_string) + }) + }) +} + +fn result_image_base64(response: &Value) -> Option { + let result = response.get("ResultImage")?; + if let Some(value) = result.as_str().filter(|value| !value.starts_with("http")) { + return Some(value.to_string()); + } + + result.as_array()?.first().and_then(|image| { + image + .as_str() + .filter(|value| !value.starts_with("http")) + .map(str::to_string) + .or_else(|| { + image + .get("Base64") + .or_else(|| image.get("ImageBase64")) + .and_then(Value::as_str) + .map(str::to_string) + }) + }) +} + +#[async_trait] +impl AiProvider for TencentHunyuanProvider { + async fn generate_image(&self, request: ImageGenerateRequest) -> Result { + if request.model == "hunyuan-image-lite" { + return self + .run_text_to_image_lite(&request.prompt, request.size.as_deref()) + .await; + } + if request.model == "hunyuan-image-2.0" { + return self + .run_text_to_image_rapid(&request.prompt, request.size.as_deref(), &[]) + .await; + } + + self.run_async_image_job(&request.prompt, request.size.as_deref(), &[]) + .await + } + + async fn edit_image(&self, request: ImageEditRequest) -> Result { + if request.model == "hunyuan-image-2.0" { + return self + .run_text_to_image_rapid( + &request.prompt, + request.size.as_deref(), + &request.image_paths, + ) + .await; + } + + self.run_async_image_job( + &request.prompt, + request.size.as_deref(), + &request.image_paths, + ) + .await + } +} + +impl TencentHunyuanProvider { + async fn call(&self, action: &str, payload: Value) -> Result { + let region = "ap-guangzhou"; + let host = host_from_base_url(&self.base_url)?; + let config = api_config(&host); + let timestamp = Utc::now().timestamp(); + let date = Utc + .timestamp_opt(timestamp, 0) + .single() + .ok_or_else(|| AppError::Provider("invalid Tencent Cloud timestamp".to_string()))? + .format("%Y-%m-%d") + .to_string(); + let payload_text = serde_json::to_string(&payload).map_err(|error| { + AppError::Provider(format!("failed to encode Tencent Cloud request: {error}")) + })?; + let content_type = "application/json; charset=utf-8"; + let signed_headers = "content-type;host"; + let canonical_headers = format!("content-type:{content_type}\nhost:{host}\n"); + let canonical_request = format!( + "POST\n/\n\n{canonical_headers}\n{signed_headers}\n{}", + sha256_hex(&payload_text) + ); + let credential_scope = format!("{}/{}/tc3_request", date, config.service); + let string_to_sign = format!( + "TC3-HMAC-SHA256\n{timestamp}\n{credential_scope}\n{}", + sha256_hex(&canonical_request) + ); + let secret_date = hmac_sha256(format!("TC3{}", self.secret_key).as_bytes(), &date)?; + let secret_service = hmac_sha256(&secret_date, config.service)?; + let secret_signing = hmac_sha256(&secret_service, "tc3_request")?; + let signature = hex::encode(hmac_sha256(&secret_signing, &string_to_sign)?); + let authorization = format!( + "TC3-HMAC-SHA256 Credential={}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}", + self.secret_id + ); + + let response = self + .client + .post(&self.base_url) + .header("Authorization", authorization) + .header("Content-Type", content_type) + .header("Host", host) + .header("X-TC-Action", action) + .header("X-TC-Timestamp", timestamp.to_string()) + .header("X-TC-Version", config.version) + .header("X-TC-Region", region) + .body(payload_text) + .send() + .await?; + let status = response.status(); + let response_text = response.text().await?; + if !status.is_success() { + return Err(AppError::Provider(format!( + "Tencent Hunyuan request failed ({status}): {response_text}" + ))); + } + + let response_body = serde_json::from_str::(&response_text).map_err(|error| { + AppError::Provider(format!( + "failed to decode Tencent Hunyuan response: {error}; response body: {response_text}" + )) + })?; + let inner = response_body.get("Response").ok_or_else(|| { + AppError::Provider(format!( + "Tencent Hunyuan response missing Response: {response_text}" + )) + })?; + if let Some(message) = parse_tencent_error(inner) { + return Err(AppError::Provider(format!( + "Tencent Hunyuan {action} failed: {message}" + ))); + } + + Ok(inner.clone()) + } + + async fn run_text_to_image_lite( + &self, + prompt: &str, + size: Option<&str>, + ) -> Result { + let mut payload = json!({ + "Prompt": prompt, + "RspImgType": "url", + "LogoAdd": 0, + }); + if let Some(resolution) = hunyuan_resolution(size, true, false) { + payload["Resolution"] = json!(resolution); + } + let host = host_from_base_url(&self.base_url)?; + let action = api_config(&host).lite_action; + let response = self.call(action, payload).await?; + let image_url = result_image_url(&response).ok_or_else(|| { + AppError::Provider(format!( + "Tencent Hunyuan TextToImageLite did not include image url: {response}" + )) + })?; + + Ok(ImageResult { + mime_type: "image/png".to_string(), + data: ImageData::Url(image_url), + }) + } + + async fn run_text_to_image_rapid( + &self, + prompt: &str, + size: Option<&str>, + image_paths: &[String], + ) -> Result { + let has_reference = !image_paths.is_empty(); + let mut payload = json!({ + "Prompt": prompt, + "RspImgType": "url", + "LogoAdd": 0, + }); + if let Some(resolution) = hunyuan_resolution(size, false, has_reference) { + payload["Resolution"] = json!(resolution); + } + if let Some(image_base64) = first_image_base64(image_paths)? { + payload["Image"] = json!({ + "Base64": image_base64, + }); + } + let host = host_from_base_url(&self.base_url)?; + let action = api_config(&host).rapid_action; + let response = self.call(action, payload).await?; + let image_url = result_image_url(&response).ok_or_else(|| { + AppError::Provider(format!( + "Tencent Hunyuan TextToImageRapid did not include image url: {response}" + )) + })?; + + Ok(ImageResult { + mime_type: "image/png".to_string(), + data: ImageData::Url(image_url), + }) + } + + async fn run_async_image_job( + &self, + prompt: &str, + size: Option<&str>, + image_paths: &[String], + ) -> Result { + let has_reference = !image_paths.is_empty(); + let mut payload = json!({ + "Prompt": prompt, + "Num": 1, + "Revise": 1, + "LogoAdd": 0, + }); + if let Some(resolution) = hunyuan_resolution(size, false, has_reference) { + payload["Resolution"] = json!(resolution); + } + let host = host_from_base_url(&self.base_url)?; + let config = api_config(&host); + if let Some(image_base64) = first_image_base64(image_paths)? { + if config.service == "aiart" { + payload["Images"] = json!([image_base64]); + } else { + payload["ContentImage"] = json!({ + "ImageBase64": image_base64, + }); + } + } + let response = self.call(config.submit_action, payload).await?; + let job_id = response + .get("JobId") + .and_then(Value::as_str) + .ok_or_else(|| { + AppError::Provider(format!( + "Tencent Hunyuan submit response did not include JobId: {response}" + )) + })?; + + self.wait_image_job(job_id).await + } + + async fn wait_image_job(&self, job_id: &str) -> Result { + for _ in 0..40 { + sleep(Duration::from_secs(3)).await; + let host = host_from_base_url(&self.base_url)?; + let action = api_config(&host).query_action; + let response = self.call(action, json!({ "JobId": job_id })).await?; + let job_status = response + .get("JobStatusCode") + .and_then(Value::as_str) + .unwrap_or_default(); + if job_status == "5" { + if let Some(image_base64) = result_image_base64(&response) { + return Ok(ImageResult { + mime_type: "image/png".to_string(), + data: ImageData::Base64(image_base64), + }); + } + let image_url = result_image_url(&response).ok_or_else(|| { + AppError::Provider(format!( + "Tencent Hunyuan job succeeded but did not include image: {response}" + )) + })?; + return Ok(ImageResult { + mime_type: "image/png".to_string(), + data: ImageData::Url(image_url), + }); + } + if job_status == "4" { + let message = response + .get("JobErrorMsg") + .and_then(Value::as_str) + .unwrap_or("unknown error"); + return Err(AppError::Provider(format!( + "Tencent Hunyuan image job failed: {message}" + ))); + } + } + + Err(AppError::Provider( + "Tencent Hunyuan image job polling timed out".to_string(), + )) + } +} diff --git a/src-tauri/src/commands/dialog.rs b/src-tauri/src/commands/dialog.rs index 558d07f..7edc63d 100644 --- a/src-tauri/src/commands/dialog.rs +++ b/src-tauri/src/commands/dialog.rs @@ -11,7 +11,11 @@ pub async fn pick_material_images(app: tauri::AppHandle) -> Result, let paths = files .unwrap_or_default() .into_iter() - .filter_map(|file_path| file_path.as_path().map(|path| path.to_string_lossy().to_string())) + .filter_map(|file_path| { + file_path + .as_path() + .map(|path| path.to_string_lossy().to_string()) + }) .collect(); Ok(paths) diff --git a/src-tauri/src/commands/generation.rs b/src-tauri/src/commands/generation.rs index 64c9680..7e5f5e5 100644 --- a/src-tauri/src/commands/generation.rs +++ b/src-tauri/src/commands/generation.rs @@ -2,16 +2,21 @@ use tauri::{AppHandle, State}; use crate::{ ai::{ + dashscope::DashScopeProvider, + google_gemini::GoogleGeminiProvider, openai_compatible::OpenAiCompatibleProvider, provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest}, + seedream::SeedreamProvider, + tencent_hunyuan::TencentHunyuanProvider, }, db::{ - models::{CreateGenerationTaskInput, GenerateImageInput, GenerateImageOutput, GenerationTask}, + models::{ + CreateGenerationTaskInput, GenerateImageInput, GenerateImageOutput, GenerationTask, + }, repository, }, state::AppState, - storage, - AppError, + storage, AppError, }; #[tauri::command] @@ -30,21 +35,66 @@ pub async fn generate_image( ) -> Result { let provider = repository::get_provider_secret(&state.db, &input.provider_id).await?; if !provider.enabled { - return Err(AppError::Provider(format!("provider {} is disabled", provider.name))); + return Err(AppError::Provider(format!( + "provider {} is disabled", + provider.name + ))); } - if provider.kind != "openai-compatible" { + if provider.kind != "openai" + && provider.kind != "openai-compatible" + && provider.kind != "volcengine-ark" + && provider.kind != "dashscope" + && provider.kind != "tencent-hunyuan" + && provider.kind != "google-gemini" + { return Err(AppError::Provider(format!( "provider kind {} is not supported yet", provider.kind ))); } - if !provider.base_url.trim_end_matches('/').ends_with("/v1") { + if (provider.kind == "openai" || provider.kind == "openai-compatible") + && !provider.base_url.trim_end_matches('/').ends_with("/v1") + { return Err(AppError::Provider( "Base URL 看起来不是 API 地址。OpenAI-compatible 地址通常需要以 /v1 结尾,例如 https://api.openai.com/v1 或 https://你的中转站域名/v1".to_string(), )); } + if provider.kind == "volcengine-ark" + && !provider.base_url.trim_end_matches('/').ends_with("/api/v3") + { + return Err(AppError::Provider( + "火山方舟 Seedream 的 Base URL 通常需要以 /api/v3 结尾,例如 https://ark.cn-beijing.volces.com/api/v3".to_string(), + )); + } + if provider.kind == "dashscope" && !provider.base_url.trim_end_matches('/').ends_with("/api/v1") + { + return Err(AppError::Provider( + "阿里云百炼 DashScope 的 Base URL 通常需要以 /api/v1 结尾,例如 https://dashscope.aliyuncs.com/api/v1".to_string(), + )); + } + if provider.kind == "google-gemini" + && !provider.base_url.trim_end_matches('/').ends_with("/v1beta") + { + return Err(AppError::Provider( + "Google Gemini / Nano Banana 的 Base URL 通常填写 https://generativelanguage.googleapis.com/v1beta".to_string(), + )); + } + if provider.kind == "tencent-hunyuan" + && !provider + .base_url + .trim_end_matches('/') + .ends_with("aiart.tencentcloudapi.com") + && !provider + .base_url + .trim_end_matches('/') + .ends_with("hunyuan.tencentcloudapi.com") + { + return Err(AppError::Provider( + "腾讯混元图像的 Base URL 通常填写 https://aiart.tencentcloudapi.com".to_string(), + )); + } let api_key = provider .api_key_encrypted @@ -76,30 +126,121 @@ pub async fn generate_image( ) .await?; - let ai_provider = OpenAiCompatibleProvider::new(provider.base_url, api_key); - let image_result = match if input.image_paths.is_empty() { - ai_provider - .generate_image(ImageGenerateRequest { - prompt: input.prompt, - model, - size: input.size, - quality: input.quality, - }) - .await + let image_result = match if provider.kind == "volcengine-ark" { + let ai_provider = SeedreamProvider::new(provider.base_url, api_key); + if input.image_paths.is_empty() { + ai_provider + .generate_image(ImageGenerateRequest { + prompt: input.prompt, + model, + size: input.size, + quality: input.quality, + }) + .await + } else { + ai_provider + .edit_image(ImageEditRequest { + prompt: input.prompt, + model, + size: input.size, + quality: input.quality, + image_paths: input.image_paths, + }) + .await + } + } else if provider.kind == "dashscope" { + let ai_provider = DashScopeProvider::new(provider.base_url, api_key); + if input.image_paths.is_empty() { + ai_provider + .generate_image(ImageGenerateRequest { + prompt: input.prompt, + model, + size: input.size, + quality: input.quality, + }) + .await + } else { + ai_provider + .edit_image(ImageEditRequest { + prompt: input.prompt, + model, + size: input.size, + quality: input.quality, + image_paths: input.image_paths, + }) + .await + } + } else if provider.kind == "tencent-hunyuan" { + let ai_provider = TencentHunyuanProvider::new(provider.base_url, api_key)?; + if input.image_paths.is_empty() { + ai_provider + .generate_image(ImageGenerateRequest { + prompt: input.prompt, + model, + size: input.size, + quality: input.quality, + }) + .await + } else { + ai_provider + .edit_image(ImageEditRequest { + prompt: input.prompt, + model, + size: input.size, + quality: input.quality, + image_paths: input.image_paths, + }) + .await + } + } else if provider.kind == "google-gemini" { + let ai_provider = GoogleGeminiProvider::new(provider.base_url, api_key); + if input.image_paths.is_empty() { + ai_provider + .generate_image(ImageGenerateRequest { + prompt: input.prompt, + model, + size: input.size, + quality: input.quality, + }) + .await + } else { + ai_provider + .edit_image(ImageEditRequest { + prompt: input.prompt, + model, + size: input.size, + quality: input.quality, + image_paths: input.image_paths, + }) + .await + } } else { - ai_provider - .edit_image(ImageEditRequest { - prompt: input.prompt, - model, - size: input.size, - quality: input.quality, - image_paths: input.image_paths, - }) - .await + let ai_provider = OpenAiCompatibleProvider::new(provider.base_url, api_key); + if input.image_paths.is_empty() { + ai_provider + .generate_image(ImageGenerateRequest { + prompt: input.prompt, + model, + size: input.size, + quality: input.quality, + }) + .await + } else { + ai_provider + .edit_image(ImageEditRequest { + prompt: input.prompt, + model, + size: input.size, + quality: input.quality, + image_paths: input.image_paths, + }) + .await + } } { Ok(result) => result, Err(error) => { - repository::mark_generation_task_failed(&state.db, &task.id, &error.to_string()).await?; + repository::mark_generation_task_failed(&state.db, &task.id, &error.to_string()) + .await?; return Err(error); } }; @@ -108,7 +249,8 @@ pub async fn generate_image( ImageData::Base64(data_base64) => storage::decode_base64_image(data_base64)?, ImageData::Url(url) => reqwest::get(url).await?.bytes().await?.to_vec(), }; - let stored_image = storage::save_generated_image_bytes(&app, &image_bytes, &image_result.mime_type)?; + let stored_image = + storage::save_generated_image_bytes(&app, &image_bytes, &image_result.mime_type)?; let file_path = stored_image.file_path.to_string_lossy().to_string(); let asset = repository::create_image_asset( &state.db, diff --git a/src-tauri/src/commands/provider.rs b/src-tauri/src/commands/provider.rs index f02e814..4475124 100644 --- a/src-tauri/src/commands/provider.rs +++ b/src-tauri/src/commands/provider.rs @@ -1,3 +1,6 @@ +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use serde_json::json; use tauri::State; use crate::{ @@ -7,12 +10,29 @@ use crate::{ }; #[tauri::command] -pub async fn list_providers(state: State<'_, AppState>) -> Result, AppError> { +pub async fn list_providers( + state: State<'_, AppState>, +) -> Result, AppError> { repository::list_providers(&state.db).await } #[tauri::command] -pub async fn upsert_provider(state: State<'_, AppState>, input: UpsertProviderInput) -> Result<(), AppError> { +pub async fn upsert_provider( + state: State<'_, AppState>, + input: UpsertProviderInput, +) -> Result<(), AppError> { + if input + .api_key + .as_deref() + .map(str::trim) + .unwrap_or_default() + .is_empty() + { + return Err(AppError::Provider( + "请先填写 API Key,再保存配置".to_string(), + )); + } + repository::upsert_provider(&state.db, input).await } @@ -20,3 +40,306 @@ pub async fn upsert_provider(state: State<'_, AppState>, input: UpsertProviderIn pub async fn delete_provider(state: State<'_, AppState>, id: String) -> Result<(), AppError> { repository::delete_provider(&state.db, &id).await } + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ProviderModel { + pub id: String, + pub owned_by: Option, +} + +#[derive(Debug, Deserialize)] +struct ModelsResponse { + data: Vec, +} + +#[derive(Debug, Deserialize)] +struct ModelItem { + id: String, + owned_by: Option, +} + +#[derive(Debug, Deserialize)] +struct ProviderCapabilities { + image_models: Option>, + selected_image_models: Option>, +} + +fn is_image_model(model_id: &str) -> bool { + let id = model_id.to_ascii_lowercase(); + [ + "gpt-image", + "image", + "dall-e", + "dalle", + "imagen", + "flux", + "qwen-image", + "wan", + "z-image", + "hunyuan-image", + "gemini", + "seedream", + "seededit", + "stable-diffusion", + "sd-", + "midjourney", + "recraft", + ] + .iter() + .any(|marker| id.contains(marker)) +} + +fn default_seedream_models() -> Vec { + vec![ + ProviderModel { + id: "doubao-seedream-4-5-251128".to_string(), + owned_by: Some("volcengine".to_string()), + }, + ProviderModel { + id: "doubao-seedream-4-0-250828".to_string(), + owned_by: Some("volcengine".to_string()), + }, + ] +} + +fn default_dashscope_models() -> Vec { + vec![ + ProviderModel { + id: "qwen-image-2.0-pro".to_string(), + owned_by: Some("alibaba-cloud".to_string()), + }, + ProviderModel { + id: "qwen-image-2.0".to_string(), + owned_by: Some("alibaba-cloud".to_string()), + }, + ProviderModel { + id: "qwen-image-plus".to_string(), + owned_by: Some("alibaba-cloud".to_string()), + }, + ProviderModel { + id: "qwen-image".to_string(), + owned_by: Some("alibaba-cloud".to_string()), + }, + ProviderModel { + id: "wan2.7-image-pro".to_string(), + owned_by: Some("alibaba-cloud".to_string()), + }, + ProviderModel { + id: "wan2.7-image".to_string(), + owned_by: Some("alibaba-cloud".to_string()), + }, + ProviderModel { + id: "z-image-turbo".to_string(), + owned_by: Some("alibaba-cloud".to_string()), + }, + ] +} + +fn default_tencent_hunyuan_models() -> Vec { + vec![ + ProviderModel { + id: "hunyuan-image-3.0".to_string(), + owned_by: Some("tencent-cloud".to_string()), + }, + ProviderModel { + id: "hunyuan-image-2.0".to_string(), + owned_by: Some("tencent-cloud".to_string()), + }, + ProviderModel { + id: "hunyuan-image-lite".to_string(), + owned_by: Some("tencent-cloud".to_string()), + }, + ] +} + +fn default_google_gemini_models() -> Vec { + vec![ + ProviderModel { + id: "gemini-2.5-flash-image".to_string(), + owned_by: Some("google".to_string()), + }, + ProviderModel { + id: "gemini-3.1-flash-image-preview".to_string(), + owned_by: Some("google".to_string()), + }, + ProviderModel { + id: "gemini-3-pro-image-preview".to_string(), + owned_by: Some("google".to_string()), + }, + ] +} + +#[tauri::command] +pub async fn fetch_provider_models( + state: State<'_, AppState>, + input: UpsertProviderInput, +) -> Result, AppError> { + if input.kind != "openai" + && input.kind != "openai-compatible" + && input.kind != "volcengine-ark" + && input.kind != "dashscope" + && input.kind != "tencent-hunyuan" + && input.kind != "google-gemini" + { + return Err(AppError::Provider(format!( + "API 分类 {} 暂未接入模型列表获取", + input.kind + ))); + } + + if (input.kind == "openai" || input.kind == "openai-compatible") + && !input.base_url.trim_end_matches('/').ends_with("/v1") + { + return Err(AppError::Provider( + "Base URL 看起来不是 API 地址。OpenAI-compatible 地址通常需要以 /v1 结尾。".to_string(), + )); + } + if input.kind == "volcengine-ark" && !input.base_url.trim_end_matches('/').ends_with("/api/v3") + { + return Err(AppError::Provider( + "火山方舟 Seedream 的 Base URL 通常需要以 /api/v3 结尾。".to_string(), + )); + } + if input.kind == "dashscope" && !input.base_url.trim_end_matches('/').ends_with("/api/v1") { + return Err(AppError::Provider( + "阿里云百炼 DashScope 的 Base URL 通常需要以 /api/v1 结尾。".to_string(), + )); + } + if input.kind == "google-gemini" && !input.base_url.trim_end_matches('/').ends_with("/v1beta") { + return Err(AppError::Provider( + "Google Gemini / Nano Banana 的 Base URL 通常填写 https://generativelanguage.googleapis.com/v1beta".to_string(), + )); + } + if input.kind == "tencent-hunyuan" + && !input + .base_url + .trim_end_matches('/') + .ends_with("aiart.tencentcloudapi.com") + && !input + .base_url + .trim_end_matches('/') + .ends_with("hunyuan.tencentcloudapi.com") + { + return Err(AppError::Provider( + "腾讯混元图像的 Base URL 通常填写 https://aiart.tencentcloudapi.com".to_string(), + )); + } + + let saved_provider = repository::get_provider_secret(&state.db, &input.id) + .await + .ok(); + let api_key = match input.api_key.clone() { + Some(key) if !key.trim().is_empty() => Some(key), + Some(_) => None, + None => saved_provider.and_then(|provider| provider.api_key_encrypted), + }; + let Some(api_key) = api_key else { + return Err(AppError::Provider( + "API Key 为空,无法获取模型列表".to_string(), + )); + }; + + if input.kind == "dashscope" { + return save_provider_models(&state, input, default_dashscope_models()).await; + } + if input.kind == "tencent-hunyuan" { + return save_provider_models(&state, input, default_tencent_hunyuan_models()).await; + } + if input.kind == "google-gemini" { + return save_provider_models(&state, input, default_google_gemini_models()).await; + } + + let response = Client::new() + .get(format!("{}/models", input.base_url.trim_end_matches('/'))) + .bearer_auth(api_key) + .send() + .await?; + + let status = response.status(); + if !status.is_success() { + if input.kind == "volcengine-ark" && matches!(status.as_u16(), 404 | 405 | 501) { + let fetched_models = default_seedream_models(); + return save_provider_models(&state, input, fetched_models).await; + } + let message = response + .text() + .await + .unwrap_or_else(|_| "request failed".to_string()); + return Err(AppError::Provider(format!( + "获取模型列表失败 ({status}): {message}" + ))); + } + + let mut fetched_models: Vec = response + .json::() + .await? + .data + .into_iter() + .filter(|model| is_image_model(&model.id)) + .map(|model| ProviderModel { + id: model.id, + owned_by: model.owned_by, + }) + .collect(); + if input.kind == "volcengine-ark" && fetched_models.is_empty() { + fetched_models = default_seedream_models(); + } + + save_provider_models(&state, input, fetched_models).await +} + +async fn save_provider_models( + state: &State<'_, AppState>, + input: UpsertProviderInput, + fetched_models: Vec, +) -> Result, AppError> { + let saved_providers = repository::list_providers(&state.db) + .await + .unwrap_or_default(); + let saved_capabilities = saved_providers + .iter() + .find(|provider| provider.id == input.id) + .and_then(|provider| provider.capabilities.as_deref()) + .and_then(|value| serde_json::from_str::(value).ok()); + let mut models = saved_capabilities + .as_ref() + .and_then(|capabilities| capabilities.image_models.clone()) + .unwrap_or_default(); + models.extend(fetched_models); + models.sort_by(|left, right| left.id.cmp(&right.id)); + models.dedup_by(|left, right| left.id == right.id); + let model_ids = models + .iter() + .map(|model| model.id.clone()) + .collect::>(); + let mut selected_model_ids = saved_capabilities + .and_then(|capabilities| capabilities.selected_image_models) + .unwrap_or_default() + .into_iter() + .filter(|model_id| model_ids.contains(model_id)) + .collect::>(); + for model_id in &model_ids { + if !selected_model_ids.contains(model_id) { + selected_model_ids.push(model_id.clone()); + } + } + + let capabilities = json!({ + "responses_api": true, + "images_api": true, + "chat_completions": true, + "image_edit": true, + "image_models": models, + "selected_image_models": selected_model_ids, + }); + repository::upsert_provider( + &state.db, + UpsertProviderInput { + capabilities: Some(capabilities.to_string()), + ..input + }, + ) + .await?; + + Ok(models) +} diff --git a/src-tauri/src/db/mod.rs b/src-tauri/src/db/mod.rs index c63fbd1..b24e8d3 100644 --- a/src-tauri/src/db/mod.rs +++ b/src-tauri/src/db/mod.rs @@ -1,6 +1,6 @@ use std::fs; -use sqlx::{sqlite::SqliteConnectOptions, SqlitePool}; +use sqlx::{sqlite::SqliteConnectOptions, Row, SqlitePool}; use tauri::{AppHandle, Manager}; use crate::AppError; @@ -24,6 +24,38 @@ pub async fn init(app: &AppHandle) -> Result { sqlx::query(statement).execute(&pool).await?; } } + ensure_column( + &pool, + "providers", + "capabilities", + "TEXT NOT NULL DEFAULT '{}'", + ) + .await?; Ok(pool) } + +async fn ensure_column( + pool: &SqlitePool, + table: &str, + column: &str, + definition: &str, +) -> Result<(), AppError> { + let rows = sqlx::query(&format!("PRAGMA table_info({table})")) + .fetch_all(pool) + .await?; + let has_column = rows.iter().any(|row| { + row.try_get::("name") + .map(|name| name == column) + .unwrap_or(false) + }); + if !has_column { + sqlx::query(&format!( + "ALTER TABLE {table} ADD COLUMN {column} {definition}" + )) + .execute(pool) + .await?; + } + + Ok(()) +} diff --git a/src-tauri/src/db/models.rs b/src-tauri/src/db/models.rs index 3d5607f..e075b0b 100644 --- a/src-tauri/src/db/models.rs +++ b/src-tauri/src/db/models.rs @@ -10,6 +10,7 @@ pub struct ProviderConfig { pub api_key: Option, pub text_model: Option, pub image_model: Option, + pub capabilities: Option, pub enabled: bool, } @@ -33,6 +34,7 @@ pub struct UpsertProviderInput { pub api_key: Option, pub text_model: Option, pub image_model: Option, + pub capabilities: Option, pub enabled: bool, } diff --git a/src-tauri/src/db/repository.rs b/src-tauri/src/db/repository.rs index e57b84f..fe0051e 100644 --- a/src-tauri/src/db/repository.rs +++ b/src-tauri/src/db/repository.rs @@ -19,8 +19,10 @@ pub async fn list_providers(pool: &SqlitePool) -> Result, Ap api_key_encrypted AS api_key, text_model, image_model, + capabilities, enabled != 0 AS enabled FROM providers + WHERE enabled != 0 ORDER BY updated_at DESC "#, ) @@ -30,24 +32,34 @@ pub async fn list_providers(pool: &SqlitePool) -> Result, Ap Ok(providers) } -pub async fn upsert_provider(pool: &SqlitePool, input: UpsertProviderInput) -> Result<(), AppError> { +pub async fn upsert_provider( + pool: &SqlitePool, + input: UpsertProviderInput, +) -> Result<(), AppError> { let now = Utc::now().to_rfc3339(); - let existing_api_key: Option = sqlx::query_scalar( + let existing: Option<(Option, Option)> = sqlx::query_as( r#" - SELECT api_key_encrypted + SELECT api_key_encrypted, capabilities FROM providers WHERE id = ?1 "#, ) .bind(&input.id) .fetch_optional(pool) - .await? - .flatten(); - let api_key = input - .api_key - .filter(|key| !key.trim().is_empty()) - .or(existing_api_key); + .await?; + let api_key = match input.api_key { + Some(key) if key.trim().is_empty() => None, + Some(key) => Some(key), + None => existing.as_ref().and_then(|item| item.0.clone()), + }; + let capabilities = input + .capabilities + .filter(|value| !value.trim().is_empty()) + .or_else(|| existing.and_then(|item| item.1)) + .unwrap_or_else(|| { + r#"{"responses_api":true,"images_api":true,"chat_completions":true,"image_edit":true,"image_models":[],"selected_image_models":[]}"#.to_string() + }); sqlx::query( r#" @@ -62,6 +74,7 @@ pub async fn upsert_provider(pool: &SqlitePool, input: UpsertProviderInput) -> R api_key_encrypted = excluded.api_key_encrypted, text_model = excluded.text_model, image_model = excluded.image_model, + capabilities = excluded.capabilities, enabled = excluded.enabled, updated_at = excluded.updated_at "#, @@ -73,7 +86,7 @@ pub async fn upsert_provider(pool: &SqlitePool, input: UpsertProviderInput) -> R .bind(api_key) .bind(input.text_model) .bind(input.image_model) - .bind(r#"{"responses_api":true,"images_api":true,"chat_completions":true,"image_edit":true}"#) + .bind(capabilities) .bind(input.enabled) .bind(now) .execute(pool) @@ -98,6 +111,34 @@ pub async fn get_provider_secret(pool: &SqlitePool, id: &str) -> Result Result<(), AppError> { + let reference_count: i64 = sqlx::query_scalar( + r#" + SELECT + (SELECT COUNT(*) FROM generation_tasks WHERE provider_id = ?1) + + (SELECT COUNT(*) FROM conversations WHERE provider_id = ?1) + + (SELECT COUNT(*) FROM ai_request_logs WHERE provider_id = ?1) + "#, + ) + .bind(id) + .fetch_one(pool) + .await?; + + if reference_count > 0 { + let now = Utc::now().to_rfc3339(); + sqlx::query( + r#" + UPDATE providers + SET enabled = 0, updated_at = ?2 + WHERE id = ?1 + "#, + ) + .bind(id) + .bind(now) + .execute(pool) + .await?; + return Ok(()); + } + sqlx::query("DELETE FROM providers WHERE id = ?1") .bind(id) .execute(pool) diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index b7f163e..5cafa70 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -1,8 +1,8 @@ mod ai; mod commands; mod db; -mod storage; mod state; +mod storage; use serde::Serialize; use state::AppState; @@ -53,6 +53,7 @@ pub fn run() { commands::provider::list_providers, commands::provider::upsert_provider, commands::provider::delete_provider, + commands::provider::fetch_provider_models, commands::dialog::pick_material_images, commands::file::reveal_path, commands::file::open_generated_dir, diff --git a/src/main.tsx b/src/main.tsx index 245d9cd..b1a5e9f 100644 --- a/src/main.tsx +++ b/src/main.tsx @@ -1,5 +1,6 @@ import React, { useEffect, useState } from 'react'; import ReactDOM from 'react-dom/client'; +import { FolderOpenOutlined, PictureOutlined, RobotOutlined, SettingOutlined } from '@ant-design/icons'; import { convertFileSrc, invoke } from '@tauri-apps/api/core'; import appLogo from './assets/logo.svg'; import './styles.css'; @@ -12,6 +13,7 @@ type ProviderConfig = { api_key?: string | null; text_model?: string | null; image_model?: string | null; + capabilities?: string | null; enabled: boolean; }; @@ -44,15 +46,55 @@ type SessionImage = { created_at: string; }; +type ProviderModel = { + id: string; + owned_by?: string | null; +}; + +type ProviderCapabilities = { + image_models?: ProviderModel[]; + selected_image_models?: string[]; +}; + type GenerationStep = { label: string; status: 'pending' | 'active' | 'done' | 'error'; }; function formatError(error: unknown) { - if (typeof error === 'string') return error; - if (error instanceof Error) return error.message; - return JSON.stringify(error); + const message = + typeof error === 'string' ? error : error instanceof Error ? error.message : JSON.stringify(error); + if (message.includes('502 Bad Gateway') || message.includes('upstream_error')) { + return '上游模型服务返回 502。通常是中转站或模型供应商临时失败,不是本地程序错误;可以换模型、降低分辨率,或稍后/换供应商重试。'; + } + if (message.includes('503') || message.includes('504')) { + return '上游模型服务暂时不可用或超时。可以稍后重试,或切换模型/供应商。'; + } + return message; +} + +function isErrorStatus(message: string) { + return message.includes('失败') || message.includes('为空') || message.startsWith('请先'); +} + +function parseProviderCapabilities(value?: string | null): ProviderCapabilities { + if (!value) return {}; + try { + return JSON.parse(value) as ProviderCapabilities; + } catch { + return {}; + } +} + +function buildProviderCapabilities(models: ProviderModel[], selectedModels: string[]) { + return JSON.stringify({ + responses_api: true, + images_api: true, + chat_completions: true, + image_edit: true, + image_models: models, + selected_image_models: selectedModels, + }); } const defaultProviderForm: ProviderForm = { @@ -66,32 +108,207 @@ const defaultProviderForm: ProviderForm = { enabled: true, }; +const apiKindOptions = [ + { + value: 'openai', + label: 'OpenAI 官方', + sampleId: 'openai', + sampleName: 'OpenAI 官方', + baseUrl: 'https://api.openai.com/v1', + supported: true, + }, + { + value: 'openai-compatible', + label: 'OpenAI-compatible / 中转站', + sampleId: 'openai-compatible', + sampleName: 'OpenAI-compatible / 中转站', + baseUrl: 'https://api.openai.com/v1', + supported: true, + }, + { + value: 'volcengine-ark', + label: '火山方舟 / Seedream', + sampleId: 'volcengine-seedream', + sampleName: '火山方舟 / Seedream', + baseUrl: 'https://ark.cn-beijing.volces.com/api/v3', + supported: true, + }, + { + value: 'dashscope', + label: '阿里云百炼 / 通义万相', + sampleId: 'dashscope-image', + sampleName: '阿里云百炼 / 通义万相', + baseUrl: 'https://dashscope.aliyuncs.com/api/v1', + supported: true, + }, + { + value: 'tencent-hunyuan', + label: '腾讯混元图像', + sampleId: 'tencent-hunyuan-image', + sampleName: '腾讯混元图像', + baseUrl: 'https://aiart.tencentcloudapi.com', + supported: true, + }, + { + value: 'google-gemini', + label: 'Google Gemini / Nano Banana', + sampleId: 'google-nano-banana', + sampleName: 'Google Gemini / Nano Banana', + baseUrl: 'https://generativelanguage.googleapis.com/v1beta', + supported: true, + }, + { value: 'stability-ai', label: 'Stability AI(待接入)', supported: false }, + { value: 'replicate', label: 'Replicate(待接入)', supported: false }, + { value: 'fal-ai', label: 'fal.ai(待接入)', supported: false }, +]; + const initialGenerationSteps: GenerationStep[] = [ - { label: '保存配置', status: 'pending' }, + { label: '检查配置', status: 'pending' }, { label: '提交任务', status: 'pending' }, { label: '等待模型返回', status: 'pending' }, { label: '保存到应用文件夹', status: 'pending' }, { label: '更新结果列表', status: 'pending' }, ]; -const imageModelOptions = ['gpt-image-1', 'gpt-image-1.5', 'gpt-image-2']; -const imageSizeOptions = ['1024x1024', '1024x1536', '1536x1024']; +const defaultImageModelOptions: string[] = []; const imageQualityOptions = ['auto', 'high', 'medium', 'low']; +const imageAspectRatioOptions = [ + { + value: '1:1', + defaultSize: '1024x1024', + sizes: [ + { value: '1024x1024', label: '1024x1024' }, + { value: '2048x2048', label: '2048x2048' }, + { value: '4096x4096', label: '4096x4096 4K' }, + ], + }, + { + value: '1:2', + defaultSize: '1024x2048', + sizes: [ + { value: '1024x2048', label: '1024x2048' }, + { value: '1536x3072', label: '1536x3072' }, + { value: '2048x4096', label: '2048x4096 4K' }, + ], + }, + { + value: '2:1', + defaultSize: '2048x1024', + sizes: [ + { value: '2048x1024', label: '2048x1024' }, + { value: '3072x1536', label: '3072x1536' }, + { value: '4096x2048', label: '4096x2048 4K' }, + ], + }, + { + value: '9:16', + defaultSize: '1080x1920', + sizes: [ + { value: '1080x1920', label: '1080x1920' }, + { value: '1440x2560', label: '1440x2560' }, + { value: '2160x3840', label: '2160x3840 4K' }, + ], + }, + { + value: '16:9', + defaultSize: '1920x1080', + sizes: [ + { value: '1920x1080', label: '1920x1080' }, + { value: '2560x1440', label: '2560x1440' }, + { value: '3840x2160', label: '3840x2160 4K' }, + ], + }, + { + value: '3:4', + defaultSize: '1536x2048', + sizes: [ + { value: '1536x2048', label: '1536x2048' }, + { value: '2304x3072', label: '2304x3072' }, + { value: '3072x4096', label: '3072x4096 4K' }, + ], + }, + { + value: '4:3', + defaultSize: '2048x1536', + sizes: [ + { value: '2048x1536', label: '2048x1536' }, + { value: '3072x2304', label: '3072x2304' }, + { value: '4096x3072', label: '4096x3072 4K' }, + ], + }, + { + value: '名片横版', + defaultSize: '1050x600', + sizes: [ + { value: '1050x600', label: '1050x600' }, + { value: '2100x1200', label: '2100x1200' }, + { value: '3500x2000', label: '3500x2000' }, + ], + }, + { + value: '名片竖版', + defaultSize: '600x1050', + sizes: [ + { value: '600x1050', label: '600x1050' }, + { value: '1200x2100', label: '1200x2100' }, + { value: '2000x3500', label: '2000x3500' }, + ], + }, +]; + +function getAspectRatioOption(value: string) { + return imageAspectRatioOptions.find((option) => option.value === value) ?? imageAspectRatioOptions[0]; +} + +function getAspectRatioForSize(value: string) { + return imageAspectRatioOptions.find((option) => option.sizes.some((size) => size.value === value)); +} + +function apiKeyPlaceholder(kind: string) { + if (kind === 'tencent-hunyuan') return 'SecretId:SecretKey'; + if (kind === 'google-gemini') return 'Google AI Studio API Key'; + if (kind === 'dashscope') return 'sk-... 或阿里云百炼 API Key'; + return 'sk-... 或中转站 key'; +} + +function providerSettingsTip(kind: string) { + if (kind === 'tencent-hunyuan') { + return '腾讯云使用 API 3.0 签名,API Key 填 SecretId:SecretKey。'; + } + if (kind === 'dashscope') { + return '阿里云百炼 Base URL 默认即可,API Key 填 DashScope Key。'; + } + if (kind === 'google-gemini') { + return 'Google Gemini 图像模型也叫 Nano Banana,API Key 填 Google AI Studio Key。'; + } + return 'Base URL 填 API 地址,通常以 /v1 结尾。'; +} function App() { const [providers, setProviders] = useState([]); const [providerForm, setProviderForm] = useState(defaultProviderForm); const [prompt, setPrompt] = useState('一只赛博朋克风格的橘猫坐在霓虹灯下'); - const [selectedImageModel, setSelectedImageModel] = useState('gpt-image-2'); + const [selectedImageModel, setSelectedImageModel] = useState(''); + const [fetchedImageModels, setFetchedImageModels] = useState([]); + const [selectedImageModels, setSelectedImageModels] = useState(defaultImageModelOptions); + const [imageAspectRatio, setImageAspectRatio] = useState('1:1'); const [imageSize, setImageSize] = useState('1024x1024'); const [imageQuality, setImageQuality] = useState('auto'); const [status, setStatus] = useState('准备就绪'); + const [settingsStatus, setSettingsStatus] = useState(''); const [isBusy, setIsBusy] = useState(false); + const [isFetchingModels, setIsFetchingModels] = useState(false); const [isSettingsOpen, setIsSettingsOpen] = useState(false); const [previewImage, setPreviewImage] = useState(null); const [sessionImages, setSessionImages] = useState([]); const [materialPaths, setMaterialPaths] = useState([]); const [generationSteps, setGenerationSteps] = useState(initialGenerationSteps); + const selectedAspectRatioOption = getAspectRatioOption(imageAspectRatio); + const visibleImageModelOptions = + selectedImageModels.length > 0 ? selectedImageModels : defaultImageModelOptions; + const activeProviderName = + providers.find((provider) => provider.id === providerForm.id)?.name ?? providerForm.name; + const activeMode = materialPaths.length > 0 ? '图像编辑' : '文字生成'; function setStep(index: number, status: GenerationStep['status']) { setGenerationSteps((steps) => @@ -113,18 +330,114 @@ function App() { setProviderForm((current) => ({ ...current, [key]: value })); } + function updateProviderKind(kind: string) { + const option = apiKindOptions.find((item) => item.value === kind); + setFetchedImageModels([]); + setSelectedImageModels([]); + setSelectedImageModel(''); + setSettingsStatus(''); + setProviderForm((current) => ({ + ...defaultProviderForm, + api_key: '', + base_url: option?.baseUrl || '', + id: option?.sampleId || `${kind}-provider`, + kind, + name: option?.sampleName || option?.label || current.name, + })); + } + + function updateAspectRatio(value: string) { + const option = getAspectRatioOption(value); + setImageAspectRatio(option.value); + setImageSize(option.defaultSize); + } + + function updateImageSize(value: string) { + setImageSize(value); + const option = getAspectRatioForSize(value); + if (option) { + setImageAspectRatio(option.value); + } + } + + function updateSelectedModelRecords(modelId: string) { + setSelectedImageModels((current) => { + if (current.includes(modelId)) { + if (current.length === 1) return current; + const next = current.filter((id) => id !== modelId); + if (selectedImageModel === modelId) { + setSelectedImageModel(next[0] ?? ''); + } + return next; + } + + return [...current, modelId]; + }); + } + + async function fetchProviderModels() { + setIsFetchingModels(true); + setStatus('正在获取图片模型列表...'); + setSettingsStatus('正在获取图片模型列表...'); + try { + const models = await invoke('fetch_provider_models', { + input: { ...providerForm, image_model: null }, + }); + const modelIds = models.map((model) => model.id); + setFetchedImageModels(models); + if (modelIds.length === 0) { + setSelectedImageModels([]); + setSelectedImageModel(''); + setStatus('未从模型列表中识别到图片模型'); + setSettingsStatus('接口可访问,但没有识别到图片模型'); + return; + } + + setSelectedImageModels((current) => { + const kept = current.filter((model) => modelIds.includes(model)); + const next = [...kept, ...modelIds.filter((model) => !kept.includes(model))]; + setSelectedImageModel((selected) => (next.includes(selected) ? selected : (next[0] ?? ''))); + return next; + }); + setStatus(`已获取 ${modelIds.length} 个图片模型`); + setSettingsStatus(`已获取 ${modelIds.length} 个图片模型`); + } catch (error) { + setStatus(`获取模型失败:${formatError(error)}`); + setSettingsStatus(`获取模型失败:${formatError(error)}`); + } finally { + setIsFetchingModels(false); + } + } + async function refreshProviders() { const result = await invoke('list_providers'); setProviders(result); const current = result.find((provider) => provider.id === providerForm.id) ?? result[0]; if (current) { + const capabilities = parseProviderCapabilities(current.capabilities); + const storedModels = capabilities.image_models ?? []; + const storedSelectedModels = + capabilities.selected_image_models?.filter((model) => + storedModels.some((storedModel) => storedModel.id === model), + ) ?? []; + const nextSelectedModels = + storedSelectedModels.length > 0 + ? storedSelectedModels + : storedModels.length > 0 + ? storedModels.map((model) => model.id) + : defaultImageModelOptions; + setFetchedImageModels(storedModels); + setSelectedImageModels(nextSelectedModels); + setSelectedImageModel((model) => + nextSelectedModels.includes(model) ? model : (nextSelectedModels[0] ?? ''), + ); setProviderForm((form) => ({ ...form, id: current.id, name: current.name, kind: current.kind, base_url: current.base_url, - api_key: current.api_key ?? form.api_key, + api_key: current.api_key ?? '', text_model: current.text_model ?? null, image_model: null, enabled: current.enabled, @@ -133,14 +446,29 @@ function App() { } async function saveProvider() { + if (!providerForm.api_key.trim()) { + setStatus('请先填写 API Key,再保存配置'); + setSettingsStatus('请先填写 API Key,再保存配置'); + return; + } + setIsBusy(true); setStatus('正在保存配置...'); + setSettingsStatus('正在保存配置...'); try { - await invoke('upsert_provider', { input: { ...providerForm, image_model: null } }); + await invoke('upsert_provider', { + input: { + ...providerForm, + image_model: null, + capabilities: buildProviderCapabilities(fetchedImageModels, selectedImageModels), + }, + }); await refreshProviders(); setStatus('配置已保存'); + setSettingsStatus('配置已保存'); } catch (error) { setStatus(`保存失败:${formatError(error)}`); + setSettingsStatus(`保存失败:${formatError(error)}`); } finally { setIsBusy(false); } @@ -149,15 +477,21 @@ function App() { async function deleteProvider(id: string) { setIsBusy(true); setStatus('正在删除配置...'); + setSettingsStatus('正在删除配置...'); try { await invoke('delete_provider', { id }); await refreshProviders(); if (providerForm.id === id) { setProviderForm(defaultProviderForm); + setFetchedImageModels([]); + setSelectedImageModels([]); + setSelectedImageModel(''); } setStatus('配置已删除'); + setSettingsStatus('配置已删除'); } catch (error) { setStatus(`删除失败:${formatError(error)}`); + setSettingsStatus(`删除失败:${formatError(error)}`); } finally { setIsBusy(false); } @@ -170,21 +504,51 @@ function App() { name: provider.name, kind: provider.kind, base_url: provider.base_url, - api_key: provider.api_key ?? form.api_key, + api_key: provider.api_key ?? '', text_model: provider.text_model ?? null, image_model: null, enabled: provider.enabled, })); + const capabilities = parseProviderCapabilities(provider.capabilities); + const storedModels = capabilities.image_models ?? []; + const storedSelectedModels = + capabilities.selected_image_models?.filter((model) => + storedModels.some((storedModel) => storedModel.id === model), + ) ?? []; + const nextSelectedModels = + storedSelectedModels.length > 0 + ? storedSelectedModels + : storedModels.length > 0 + ? storedModels.map((model) => model.id) + : defaultImageModelOptions; + setFetchedImageModels(storedModels); + setSelectedImageModels(nextSelectedModels); + setSelectedImageModel((model) => + nextSelectedModels.includes(model) ? model : (nextSelectedModels[0] ?? ''), + ); setStatus('已切换模型配置'); + setSettingsStatus(''); } async function generateImage() { + if (!selectedImageModel) { + setStatus('请先在设置中获取并选择图像模型'); + return; + } + if (!providerForm.api_key.trim()) { + setStatus('请先在设置中填写并保存 API Key'); + return; + } + if (!providers.some((provider) => provider.id === providerForm.id)) { + setStatus('请先保存当前模型供应商配置,再开始生成'); + return; + } setIsBusy(true); setGenerationSteps(initialGenerationSteps); setStatus('正在生成图片...'); try { startStep(0); - await invoke('upsert_provider', { input: { ...providerForm, image_model: null } }); + await new Promise((resolve) => window.setTimeout(resolve, 80)); startStep(1); await new Promise((resolve) => window.setTimeout(resolve, 120)); startStep(2); @@ -264,148 +628,192 @@ function App() { return (
-
-
- Image Draw AI -
-

Image Draw AI

-

图片默认保存到应用数据文件夹

-
+
- -
-