feat: add multi-provider image generation
This commit is contained in:
11
package.json
11
package.json
@@ -16,17 +16,18 @@
|
||||
"build:win": "tauri build"
|
||||
},
|
||||
"dependencies": {
|
||||
"@ant-design/icons": "^6.2.2",
|
||||
"@tauri-apps/api": "^2.0.0",
|
||||
"@tauri-apps/plugin-dialog": "^2.0.0",
|
||||
"@vitejs/plugin-react": "^5.0.0",
|
||||
"vite": "^7.0.0",
|
||||
"typescript": "^5.0.0",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.0.0"
|
||||
"react-dom": "^19.0.0",
|
||||
"typescript": "^5.0.0",
|
||||
"vite": "^7.0.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@tauri-apps/cli": "^2.0.0",
|
||||
"@types/react": "^19.0.0",
|
||||
"@types/react-dom": "^19.0.0",
|
||||
"@tauri-apps/cli": "^2.0.0"
|
||||
"@types/react-dom": "^19.0.0"
|
||||
}
|
||||
}
|
||||
|
||||
66
pnpm-lock.yaml
generated
66
pnpm-lock.yaml
generated
@@ -8,6 +8,9 @@ importers:
|
||||
|
||||
.:
|
||||
dependencies:
|
||||
'@ant-design/icons':
|
||||
specifier: ^6.2.2
|
||||
version: 6.2.2(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
|
||||
'@tauri-apps/api':
|
||||
specifier: ^2.0.0
|
||||
version: 2.10.1
|
||||
@@ -42,6 +45,23 @@ importers:
|
||||
|
||||
packages:
|
||||
|
||||
'@ant-design/colors@8.0.1':
|
||||
resolution: {integrity: sha512-foPVl0+SWIslGUtD/xBr1p9U4AKzPhNYEseXYRRo5QSzGACYZrQbe11AYJbYfAWnWSpGBx6JjBmSeugUsD9vqQ==}
|
||||
|
||||
'@ant-design/fast-color@3.0.1':
|
||||
resolution: {integrity: sha512-esKJegpW4nckh0o6kV3Tkb7NPIZYbPnnFxmQDUmL08ukXZAvV85TZBr70eGuke/CIArLaP6aw8lt9KILjnWuOw==}
|
||||
engines: {node: '>=8.x'}
|
||||
|
||||
'@ant-design/icons-svg@4.4.2':
|
||||
resolution: {integrity: sha512-vHbT+zJEVzllwP+CM+ul7reTEfBR0vgxFe7+lREAsAA7YGsYpboiq2sQNeQeRvh09GfQgs/GyFEvZpJ9cLXpXA==}
|
||||
|
||||
'@ant-design/icons@6.2.2':
|
||||
resolution: {integrity: sha512-zlJtE7AMbG12TeYVPhtBXwNpFInNy8mjLzcIm+0BPw16/b8ODG87YJ1G37VIF5VFscdgfsf6EweAFPTobu/3iQ==}
|
||||
engines: {node: '>=8'}
|
||||
peerDependencies:
|
||||
react: '>=16.0.0'
|
||||
react-dom: '>=16.0.0'
|
||||
|
||||
'@babel/code-frame@7.29.0':
|
||||
resolution: {integrity: sha512-9NhCeYjq9+3uxgdtp20LSiJXJvN0FeCtNGpJxuMFZ1Kv3cWUNb6DOhJwUvcVCzKGR66cw4njwM6hrJLqgOwbcw==}
|
||||
engines: {node: '>=6.9.0'}
|
||||
@@ -297,6 +317,12 @@ packages:
|
||||
'@jridgewell/trace-mapping@0.3.31':
|
||||
resolution: {integrity: sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw==}
|
||||
|
||||
'@rc-component/util@1.10.1':
|
||||
resolution: {integrity: sha512-q++9S6rUa5Idb/xIBNz6jtvumw5+O5YV5V0g4iK9mn9jWs4oGJheE3ZN1kAnE723AXyaD8v95yeOASmdk8Jnng==}
|
||||
peerDependencies:
|
||||
react: '>=18.0.0'
|
||||
react-dom: '>=18.0.0'
|
||||
|
||||
'@rolldown/pluginutils@1.0.0-rc.3':
|
||||
resolution: {integrity: sha512-eybk3TjzzzV97Dlj5c+XrBFW57eTNhzod66y9HrBlzJ6NsCrWCp/2kaPS3K9wJmurBC0Tdw4yPjXKZqlznim3Q==}
|
||||
|
||||
@@ -544,6 +570,10 @@ packages:
|
||||
caniuse-lite@1.0.30001790:
|
||||
resolution: {integrity: sha512-bOoxfJPyYo+ds6W0YfptaCWbFnJYjh2Y1Eow5lRv+vI2u8ganPZqNm1JwNh0t2ELQCqIWg4B3dWEusgAmsoyOw==}
|
||||
|
||||
clsx@2.1.1:
|
||||
resolution: {integrity: sha512-eYm0QWBtUrBWZWG0d386OGAw16Z995PiOVo2B7bjWSbHedGl5e0ZWaq65kOGgUSNesEIDkB9ISbTg/JK9dhCZA==}
|
||||
engines: {node: '>=6'}
|
||||
|
||||
convert-source-map@2.0.0:
|
||||
resolution: {integrity: sha512-Kvp459HrV2FEJ1CAsi1Ku+MY3kasH19TFykTz2xWmMeq6bk2NU3XXvfJ+Q61m0xktWwt+1HSYf3JZsTms3aRJg==}
|
||||
|
||||
@@ -589,6 +619,9 @@ packages:
|
||||
resolution: {integrity: sha512-3hN7NaskYvMDLQY55gnW3NQ+mesEAepTqlg+VEbj7zzqEMBVNhzcGYYeqFo/TlYz6eQiFcp1HcsCZO+nGgS8zg==}
|
||||
engines: {node: '>=6.9.0'}
|
||||
|
||||
is-mobile@5.0.0:
|
||||
resolution: {integrity: sha512-Tz/yndySvLAEXh+Uk8liFCxOwVH6YutuR74utvOcu7I9Di+DwM0mtdPVZNaVvvBUM2OXxne/NhOs1zAO7riusQ==}
|
||||
|
||||
js-tokens@4.0.0:
|
||||
resolution: {integrity: sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==}
|
||||
|
||||
@@ -632,6 +665,9 @@ packages:
|
||||
peerDependencies:
|
||||
react: ^19.2.5
|
||||
|
||||
react-is@18.3.1:
|
||||
resolution: {integrity: sha512-/LLMVyas0ljjAtoYiPqYiL8VWXzUUdThrmU5+n20DZv+a+ClRoevUzw5JxU+Ieh5/c87ytoTBV9G1FiKfNJdmg==}
|
||||
|
||||
react-refresh@0.18.0:
|
||||
resolution: {integrity: sha512-QgT5//D3jfjJb6Gsjxv0Slpj23ip+HtOpnNgnb2S5zU3CB26G/IDPGoy4RJB42wzFE46DRsstbW6tKHoKbhAxw==}
|
||||
engines: {node: '>=0.10.0'}
|
||||
@@ -716,6 +752,23 @@ packages:
|
||||
|
||||
snapshots:
|
||||
|
||||
'@ant-design/colors@8.0.1':
|
||||
dependencies:
|
||||
'@ant-design/fast-color': 3.0.1
|
||||
|
||||
'@ant-design/fast-color@3.0.1': {}
|
||||
|
||||
'@ant-design/icons-svg@4.4.2': {}
|
||||
|
||||
'@ant-design/icons@6.2.2(react-dom@19.2.5(react@19.2.5))(react@19.2.5)':
|
||||
dependencies:
|
||||
'@ant-design/colors': 8.0.1
|
||||
'@ant-design/icons-svg': 4.4.2
|
||||
'@rc-component/util': 1.10.1(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
|
||||
clsx: 2.1.1
|
||||
react: 19.2.5
|
||||
react-dom: 19.2.5(react@19.2.5)
|
||||
|
||||
'@babel/code-frame@7.29.0':
|
||||
dependencies:
|
||||
'@babel/helper-validator-identifier': 7.28.5
|
||||
@@ -925,6 +978,13 @@ snapshots:
|
||||
'@jridgewell/resolve-uri': 3.1.2
|
||||
'@jridgewell/sourcemap-codec': 1.5.5
|
||||
|
||||
'@rc-component/util@1.10.1(react-dom@19.2.5(react@19.2.5))(react@19.2.5)':
|
||||
dependencies:
|
||||
is-mobile: 5.0.0
|
||||
react: 19.2.5
|
||||
react-dom: 19.2.5(react@19.2.5)
|
||||
react-is: 18.3.1
|
||||
|
||||
'@rolldown/pluginutils@1.0.0-rc.3': {}
|
||||
|
||||
'@rollup/rollup-android-arm-eabi@4.60.2':
|
||||
@@ -1110,6 +1170,8 @@ snapshots:
|
||||
|
||||
caniuse-lite@1.0.30001790: {}
|
||||
|
||||
clsx@2.1.1: {}
|
||||
|
||||
convert-source-map@2.0.0: {}
|
||||
|
||||
csstype@3.2.3: {}
|
||||
@@ -1160,6 +1222,8 @@ snapshots:
|
||||
|
||||
gensync@1.0.0-beta.2: {}
|
||||
|
||||
is-mobile@5.0.0: {}
|
||||
|
||||
js-tokens@4.0.0: {}
|
||||
|
||||
jsesc@3.1.0: {}
|
||||
@@ -1191,6 +1255,8 @@ snapshots:
|
||||
react: 19.2.5
|
||||
scheduler: 0.27.0
|
||||
|
||||
react-is@18.3.1: {}
|
||||
|
||||
react-refresh@0.18.0: {}
|
||||
|
||||
react@19.2.5: {}
|
||||
|
||||
@@ -26,6 +26,9 @@ reqwest = { version = "0.12", default-features = false, features = ["json", "mul
|
||||
thiserror = "2"
|
||||
async-trait = "0.1"
|
||||
base64 = "0.22"
|
||||
hex = "0.4"
|
||||
hmac = "0.12"
|
||||
sha2 = "0.10"
|
||||
|
||||
[features]
|
||||
default = ["custom-protocol"]
|
||||
|
||||
354
src-tauri/src/ai/dashscope.rs
Normal file
354
src-tauri/src/ai/dashscope.rs
Normal file
@@ -0,0 +1,354 @@
|
||||
use async_trait::async_trait;
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use reqwest::Client;
|
||||
use serde_json::{json, Value};
|
||||
use std::{fs, path::Path, time::Duration};
|
||||
use tokio::time::sleep;
|
||||
|
||||
use super::provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest, ImageResult};
|
||||
use crate::AppError;
|
||||
|
||||
pub struct DashScopeProvider {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
api_key: String,
|
||||
}
|
||||
|
||||
impl DashScopeProvider {
|
||||
pub fn new(base_url: String, api_key: String) -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_qwen_image_model(model: &str) -> bool {
|
||||
model.starts_with("qwen-image")
|
||||
}
|
||||
|
||||
fn is_sync_multimodal_model(model: &str) -> bool {
|
||||
is_qwen_image_model(model) || model.starts_with("z-image")
|
||||
}
|
||||
|
||||
fn dashscope_size(size: Option<&str>) -> Option<String> {
|
||||
size.map(|value| value.replace('x', "*"))
|
||||
}
|
||||
|
||||
fn mime_for_path(path: &str) -> &'static str {
|
||||
match Path::new(path)
|
||||
.extension()
|
||||
.and_then(|extension| extension.to_str())
|
||||
.map(|extension| extension.to_ascii_lowercase())
|
||||
.as_deref()
|
||||
{
|
||||
Some("jpg" | "jpeg") => "image/jpeg",
|
||||
Some("webp") => "image/webp",
|
||||
_ => "image/png",
|
||||
}
|
||||
}
|
||||
|
||||
fn path_to_data_url(path: &str) -> Result<String, AppError> {
|
||||
let bytes = fs::read(path)?;
|
||||
Ok(format!(
|
||||
"data:{};base64,{}",
|
||||
mime_for_path(path),
|
||||
general_purpose::STANDARD.encode(bytes)
|
||||
))
|
||||
}
|
||||
|
||||
fn parse_dashscope_error(response_text: &str, fallback: &str) -> String {
|
||||
serde_json::from_str::<Value>(response_text)
|
||||
.ok()
|
||||
.and_then(|value| {
|
||||
let code = value
|
||||
.get("code")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or_default();
|
||||
let message = value
|
||||
.get("message")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or_default();
|
||||
if code.is_empty() && message.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(format!("{code}: {message}"))
|
||||
}
|
||||
})
|
||||
.unwrap_or_else(|| fallback.to_string())
|
||||
}
|
||||
|
||||
fn image_url_from_choices(value: &Value) -> Option<String> {
|
||||
value
|
||||
.get("output")?
|
||||
.get("choices")?
|
||||
.as_array()?
|
||||
.iter()
|
||||
.flat_map(|choice| {
|
||||
choice
|
||||
.get("message")
|
||||
.and_then(|message| message.get("content"))
|
||||
.and_then(Value::as_array)
|
||||
.into_iter()
|
||||
.flatten()
|
||||
})
|
||||
.find_map(|content| {
|
||||
content
|
||||
.get("image")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
})
|
||||
}
|
||||
|
||||
fn image_url_from_results(value: &Value) -> Option<String> {
|
||||
value
|
||||
.get("output")?
|
||||
.get("results")?
|
||||
.as_array()?
|
||||
.iter()
|
||||
.find_map(|item| item.get("url").and_then(Value::as_str).map(str::to_string))
|
||||
}
|
||||
|
||||
fn parse_image_url(value: &Value) -> Option<String> {
|
||||
image_url_from_choices(value).or_else(|| image_url_from_results(value))
|
||||
}
|
||||
|
||||
fn build_messages(prompt: &str, image_paths: &[String]) -> Result<Vec<Value>, AppError> {
|
||||
let mut content = image_paths
|
||||
.iter()
|
||||
.map(|path| path_to_data_url(path).map(|image| json!({ "image": image })))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
content.push(json!({ "text": prompt }));
|
||||
|
||||
Ok(vec![json!({
|
||||
"role": "user",
|
||||
"content": content,
|
||||
})])
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AiProvider for DashScopeProvider {
|
||||
async fn generate_image(&self, request: ImageGenerateRequest) -> Result<ImageResult, AppError> {
|
||||
if is_sync_multimodal_model(&request.model) {
|
||||
return self
|
||||
.run_qwen_image(
|
||||
&request.prompt,
|
||||
&request.model,
|
||||
request.size.as_deref(),
|
||||
&[],
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
self.run_async_image_generation(
|
||||
&request.prompt,
|
||||
&request.model,
|
||||
request.size.as_deref(),
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn edit_image(&self, request: ImageEditRequest) -> Result<ImageResult, AppError> {
|
||||
if is_sync_multimodal_model(&request.model) {
|
||||
return self
|
||||
.run_qwen_image(
|
||||
&request.prompt,
|
||||
&request.model,
|
||||
request.size.as_deref(),
|
||||
&request.image_paths,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
self.run_async_image_generation(
|
||||
&request.prompt,
|
||||
&request.model,
|
||||
request.size.as_deref(),
|
||||
&request.image_paths,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl DashScopeProvider {
|
||||
async fn run_qwen_image(
|
||||
&self,
|
||||
prompt: &str,
|
||||
model: &str,
|
||||
size: Option<&str>,
|
||||
image_paths: &[String],
|
||||
) -> Result<ImageResult, AppError> {
|
||||
let mut parameters = json!({
|
||||
"n": 1,
|
||||
"watermark": false,
|
||||
"prompt_extend": true,
|
||||
});
|
||||
if is_qwen_image_model(model) {
|
||||
parameters["negative_prompt"] = json!(" ");
|
||||
}
|
||||
if let Some(size) = dashscope_size(size) {
|
||||
parameters["size"] = json!(size);
|
||||
}
|
||||
|
||||
let body = json!({
|
||||
"model": model,
|
||||
"input": {
|
||||
"messages": build_messages(prompt, image_paths)?,
|
||||
},
|
||||
"parameters": parameters,
|
||||
});
|
||||
let response = self
|
||||
.client
|
||||
.post(format!(
|
||||
"{}/services/aigc/multimodal-generation/generation",
|
||||
self.base_url
|
||||
))
|
||||
.bearer_auth(&self.api_key)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
let status = response.status();
|
||||
let response_text = response.text().await?;
|
||||
if !status.is_success() {
|
||||
return Err(AppError::Provider(format!(
|
||||
"DashScope image generation failed ({status}): {}",
|
||||
parse_dashscope_error(&response_text, &response_text)
|
||||
)));
|
||||
}
|
||||
|
||||
let response_body = serde_json::from_str::<Value>(&response_text).map_err(|error| {
|
||||
AppError::Provider(format!(
|
||||
"failed to decode DashScope response: {error}; response body: {response_text}"
|
||||
))
|
||||
})?;
|
||||
let image_url = parse_image_url(&response_body).ok_or_else(|| {
|
||||
AppError::Provider(format!(
|
||||
"DashScope response did not include image url: {response_text}"
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(ImageResult {
|
||||
mime_type: "image/png".to_string(),
|
||||
data: ImageData::Url(image_url),
|
||||
})
|
||||
}
|
||||
|
||||
async fn run_async_image_generation(
|
||||
&self,
|
||||
prompt: &str,
|
||||
model: &str,
|
||||
size: Option<&str>,
|
||||
image_paths: &[String],
|
||||
) -> Result<ImageResult, AppError> {
|
||||
let mut parameters = json!({
|
||||
"n": 1,
|
||||
"watermark": false,
|
||||
});
|
||||
if let Some(size) = dashscope_size(size) {
|
||||
parameters["size"] = json!(size);
|
||||
}
|
||||
if model.starts_with("wan2.7-image") {
|
||||
parameters["thinking_mode"] = json!(true);
|
||||
} else {
|
||||
parameters["prompt_extend"] = json!(true);
|
||||
}
|
||||
|
||||
let body = json!({
|
||||
"model": model,
|
||||
"input": {
|
||||
"messages": build_messages(prompt, image_paths)?,
|
||||
},
|
||||
"parameters": parameters,
|
||||
});
|
||||
let response = self
|
||||
.client
|
||||
.post(format!(
|
||||
"{}/services/aigc/image-generation/generation",
|
||||
self.base_url
|
||||
))
|
||||
.bearer_auth(&self.api_key)
|
||||
.header("X-DashScope-Async", "enable")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
let status = response.status();
|
||||
let response_text = response.text().await?;
|
||||
if !status.is_success() {
|
||||
return Err(AppError::Provider(format!(
|
||||
"DashScope async image task failed ({status}): {}",
|
||||
parse_dashscope_error(&response_text, &response_text)
|
||||
)));
|
||||
}
|
||||
|
||||
let response_body = serde_json::from_str::<Value>(&response_text).map_err(|error| {
|
||||
AppError::Provider(format!(
|
||||
"failed to decode DashScope task response: {error}; response body: {response_text}"
|
||||
))
|
||||
})?;
|
||||
let task_id = response_body
|
||||
.get("output")
|
||||
.and_then(|output| output.get("task_id"))
|
||||
.and_then(Value::as_str)
|
||||
.ok_or_else(|| {
|
||||
AppError::Provider(format!(
|
||||
"DashScope task response did not include task_id: {response_text}"
|
||||
))
|
||||
})?;
|
||||
|
||||
self.wait_async_task(task_id).await
|
||||
}
|
||||
|
||||
async fn wait_async_task(&self, task_id: &str) -> Result<ImageResult, AppError> {
|
||||
for _ in 0..40 {
|
||||
sleep(Duration::from_secs(3)).await;
|
||||
let response = self
|
||||
.client
|
||||
.get(format!("{}/tasks/{task_id}", self.base_url))
|
||||
.bearer_auth(&self.api_key)
|
||||
.send()
|
||||
.await?;
|
||||
let status = response.status();
|
||||
let response_text = response.text().await?;
|
||||
if !status.is_success() {
|
||||
return Err(AppError::Provider(format!(
|
||||
"DashScope task polling failed ({status}): {}",
|
||||
parse_dashscope_error(&response_text, &response_text)
|
||||
)));
|
||||
}
|
||||
|
||||
let response_body = serde_json::from_str::<Value>(&response_text).map_err(|error| {
|
||||
AppError::Provider(format!(
|
||||
"failed to decode DashScope task result: {error}; response body: {response_text}"
|
||||
))
|
||||
})?;
|
||||
let task_status = response_body
|
||||
.get("output")
|
||||
.and_then(|output| output.get("task_status"))
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or_default();
|
||||
if task_status == "SUCCEEDED" {
|
||||
let image_url = parse_image_url(&response_body).ok_or_else(|| {
|
||||
AppError::Provider(format!(
|
||||
"DashScope task succeeded but did not include image url: {response_text}"
|
||||
))
|
||||
})?;
|
||||
return Ok(ImageResult {
|
||||
mime_type: "image/png".to_string(),
|
||||
data: ImageData::Url(image_url),
|
||||
});
|
||||
}
|
||||
if matches!(task_status, "FAILED" | "CANCELED" | "UNKNOWN") {
|
||||
return Err(AppError::Provider(format!(
|
||||
"DashScope task ended with status {task_status}: {}",
|
||||
parse_dashscope_error(&response_text, &response_text)
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Err(AppError::Provider(
|
||||
"DashScope image task polling timed out".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
189
src-tauri/src/ai/google_gemini.rs
Normal file
189
src-tauri/src/ai/google_gemini.rs
Normal file
@@ -0,0 +1,189 @@
|
||||
use async_trait::async_trait;
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use reqwest::Client;
|
||||
use serde_json::{json, Value};
|
||||
use std::{fs, path::Path};
|
||||
|
||||
use super::provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest, ImageResult};
|
||||
use crate::AppError;
|
||||
|
||||
pub struct GoogleGeminiProvider {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
api_key: String,
|
||||
}
|
||||
|
||||
impl GoogleGeminiProvider {
|
||||
pub fn new(base_url: String, api_key: String) -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn mime_for_path(path: &str) -> &'static str {
|
||||
match Path::new(path)
|
||||
.extension()
|
||||
.and_then(|extension| extension.to_str())
|
||||
.map(|extension| extension.to_ascii_lowercase())
|
||||
.as_deref()
|
||||
{
|
||||
Some("jpg" | "jpeg") => "image/jpeg",
|
||||
Some("webp") => "image/webp",
|
||||
_ => "image/png",
|
||||
}
|
||||
}
|
||||
|
||||
fn size_to_aspect_ratio(size: Option<&str>) -> Option<&'static str> {
|
||||
let (width, height) = size?.split_once('x')?;
|
||||
let width = width.parse::<f32>().ok()?;
|
||||
let height = height.parse::<f32>().ok()?;
|
||||
if width <= 0.0 || height <= 0.0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let ratio = width / height;
|
||||
if (ratio - 1.0).abs() < 0.08 {
|
||||
Some("1:1")
|
||||
} else if (ratio - 16.0 / 9.0).abs() < 0.12 {
|
||||
Some("16:9")
|
||||
} else if (ratio - 9.0 / 16.0).abs() < 0.12 {
|
||||
Some("9:16")
|
||||
} else if (ratio - 4.0 / 3.0).abs() < 0.12 {
|
||||
Some("4:3")
|
||||
} else if (ratio - 3.0 / 4.0).abs() < 0.12 {
|
||||
Some("3:4")
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn image_part(path: &str) -> Result<Value, AppError> {
|
||||
let bytes = fs::read(path)?;
|
||||
Ok(json!({
|
||||
"inline_data": {
|
||||
"mime_type": mime_for_path(path),
|
||||
"data": general_purpose::STANDARD.encode(bytes),
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
fn inline_data_from_part(part: &Value) -> Option<(&str, &str)> {
|
||||
let inline_data = part.get("inlineData").or_else(|| part.get("inline_data"))?;
|
||||
let data = inline_data.get("data").and_then(Value::as_str)?;
|
||||
let mime_type = inline_data
|
||||
.get("mimeType")
|
||||
.or_else(|| inline_data.get("mime_type"))
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or("image/png");
|
||||
Some((mime_type, data))
|
||||
}
|
||||
|
||||
fn parse_gemini_response(response_text: &str) -> Result<ImageResult, AppError> {
|
||||
let response_body = serde_json::from_str::<Value>(response_text).map_err(|error| {
|
||||
AppError::Provider(format!(
|
||||
"failed to decode Gemini image response: {error}; response body: {response_text}"
|
||||
))
|
||||
})?;
|
||||
let parts = response_body
|
||||
.get("candidates")
|
||||
.and_then(Value::as_array)
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.flat_map(|candidate| {
|
||||
candidate
|
||||
.get("content")
|
||||
.and_then(|content| content.get("parts"))
|
||||
.and_then(Value::as_array)
|
||||
.into_iter()
|
||||
.flatten()
|
||||
});
|
||||
|
||||
for part in parts {
|
||||
if let Some((mime_type, data)) = inline_data_from_part(part) {
|
||||
return Ok(ImageResult {
|
||||
mime_type: mime_type.to_string(),
|
||||
data: ImageData::Base64(data.to_string()),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Err(AppError::Provider(format!(
|
||||
"Gemini response did not include image data: {response_text}"
|
||||
)))
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AiProvider for GoogleGeminiProvider {
|
||||
async fn generate_image(&self, request: ImageGenerateRequest) -> Result<ImageResult, AppError> {
|
||||
self.generate_with_parts(
|
||||
&request.prompt,
|
||||
&request.model,
|
||||
request.size.as_deref(),
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn edit_image(&self, request: ImageEditRequest) -> Result<ImageResult, AppError> {
|
||||
self.generate_with_parts(
|
||||
&request.prompt,
|
||||
&request.model,
|
||||
request.size.as_deref(),
|
||||
&request.image_paths,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl GoogleGeminiProvider {
|
||||
async fn generate_with_parts(
|
||||
&self,
|
||||
prompt: &str,
|
||||
model: &str,
|
||||
size: Option<&str>,
|
||||
image_paths: &[String],
|
||||
) -> Result<ImageResult, AppError> {
|
||||
let mut parts = vec![json!({ "text": prompt })];
|
||||
parts.extend(
|
||||
image_paths
|
||||
.iter()
|
||||
.map(|path| image_part(path))
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
);
|
||||
let mut generation_config = json!({
|
||||
"responseModalities": ["TEXT", "IMAGE"],
|
||||
});
|
||||
if let Some(aspect_ratio) = size_to_aspect_ratio(size) {
|
||||
generation_config["imageConfig"] = json!({
|
||||
"aspectRatio": aspect_ratio,
|
||||
});
|
||||
}
|
||||
|
||||
let body = json!({
|
||||
"contents": [{
|
||||
"role": "user",
|
||||
"parts": parts,
|
||||
}],
|
||||
"generationConfig": generation_config,
|
||||
});
|
||||
let response = self
|
||||
.client
|
||||
.post(format!("{}/models/{model}:generateContent", self.base_url))
|
||||
.header("x-goog-api-key", &self.api_key)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
let status = response.status();
|
||||
let response_text = response.text().await?;
|
||||
if !status.is_success() {
|
||||
return Err(AppError::Provider(format!(
|
||||
"Gemini image generation failed ({status}): {response_text}"
|
||||
)));
|
||||
}
|
||||
|
||||
parse_gemini_response(&response_text)
|
||||
}
|
||||
}
|
||||
@@ -1,2 +1,6 @@
|
||||
pub mod dashscope;
|
||||
pub mod google_gemini;
|
||||
pub mod openai_compatible;
|
||||
pub mod provider;
|
||||
pub mod seedream;
|
||||
pub mod tencent_hunyuan;
|
||||
|
||||
@@ -32,11 +32,12 @@ fn parse_image_response(response_text: &str) -> Result<ImageResult, AppError> {
|
||||
));
|
||||
}
|
||||
|
||||
let response_body: ImageResponseBody = serde_json::from_str(response_text).map_err(|error| {
|
||||
AppError::Provider(format!(
|
||||
"failed to decode image response: {error}; response body: {response_text}"
|
||||
))
|
||||
})?;
|
||||
let response_body: ImageResponseBody =
|
||||
serde_json::from_str(response_text).map_err(|error| {
|
||||
AppError::Provider(format!(
|
||||
"failed to decode image response: {error}; response body: {response_text}"
|
||||
))
|
||||
})?;
|
||||
let image_data = response_body
|
||||
.data
|
||||
.into_iter()
|
||||
@@ -45,7 +46,9 @@ fn parse_image_response(response_text: &str) -> Result<ImageResult, AppError> {
|
||||
.map(ImageData::Base64)
|
||||
.or_else(|| item.url.map(ImageData::Url))
|
||||
})
|
||||
.ok_or_else(|| AppError::Provider("image response did not include b64_json or url".to_string()))?;
|
||||
.ok_or_else(|| {
|
||||
AppError::Provider("image response did not include b64_json or url".to_string())
|
||||
})?;
|
||||
|
||||
Ok(ImageResult {
|
||||
mime_type: "image/png".to_string(),
|
||||
@@ -96,8 +99,13 @@ impl AiProvider for OpenAiCompatibleProvider {
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let message = response.text().await.unwrap_or_else(|_| "request failed".to_string());
|
||||
return Err(AppError::Provider(format!("image generation failed ({status}): {message}")));
|
||||
let message = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "request failed".to_string());
|
||||
return Err(AppError::Provider(format!(
|
||||
"image generation failed ({status}): {message}"
|
||||
)));
|
||||
}
|
||||
|
||||
let response_text = response.text().await?;
|
||||
@@ -134,7 +142,9 @@ impl AiProvider for OpenAiCompatibleProvider {
|
||||
Some("webp") => "image/webp",
|
||||
_ => "image/png",
|
||||
};
|
||||
let part = multipart::Part::bytes(bytes).file_name(file_name).mime_str(mime)?;
|
||||
let part = multipart::Part::bytes(bytes)
|
||||
.file_name(file_name)
|
||||
.mime_str(mime)?;
|
||||
form = form.part("image[]", part);
|
||||
}
|
||||
|
||||
@@ -148,8 +158,13 @@ impl AiProvider for OpenAiCompatibleProvider {
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let message = response.text().await.unwrap_or_else(|_| "request failed".to_string());
|
||||
return Err(AppError::Provider(format!("image edit failed ({status}): {message}")));
|
||||
let message = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "request failed".to_string());
|
||||
return Err(AppError::Provider(format!(
|
||||
"image edit failed ({status}): {message}"
|
||||
)));
|
||||
}
|
||||
|
||||
let response_text = response.text().await?;
|
||||
|
||||
180
src-tauri/src/ai/seedream.rs
Normal file
180
src-tauri/src/ai/seedream.rs
Normal file
@@ -0,0 +1,180 @@
|
||||
use async_trait::async_trait;
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{fs, path::Path};
|
||||
|
||||
use super::provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest, ImageResult};
|
||||
use crate::AppError;
|
||||
|
||||
pub struct SeedreamProvider {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
api_key: String,
|
||||
}
|
||||
|
||||
impl SeedreamProvider {
|
||||
pub fn new(base_url: String, api_key: String) -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct SeedreamRequestBody<'a> {
|
||||
model: &'a str,
|
||||
prompt: &'a str,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
size: Option<&'a str>,
|
||||
response_format: &'a str,
|
||||
stream: bool,
|
||||
watermark: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
image: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SeedreamResponseBody {
|
||||
data: Vec<SeedreamResponseItem>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SeedreamResponseItem {
|
||||
b64_json: Option<String>,
|
||||
url: Option<String>,
|
||||
}
|
||||
|
||||
fn mime_for_path(path: &str) -> &'static str {
|
||||
match Path::new(path)
|
||||
.extension()
|
||||
.and_then(|extension| extension.to_str())
|
||||
.map(|extension| extension.to_ascii_lowercase())
|
||||
.as_deref()
|
||||
{
|
||||
Some("jpg" | "jpeg") => "image/jpeg",
|
||||
Some("webp") => "image/webp",
|
||||
_ => "image/png",
|
||||
}
|
||||
}
|
||||
|
||||
fn path_to_data_url(path: &str) -> Result<String, AppError> {
|
||||
let bytes = fs::read(path)?;
|
||||
Ok(format!(
|
||||
"data:{};base64,{}",
|
||||
mime_for_path(path),
|
||||
general_purpose::STANDARD.encode(bytes)
|
||||
))
|
||||
}
|
||||
|
||||
fn parse_seedream_response(response_text: &str) -> Result<ImageResult, AppError> {
|
||||
if response_text.trim_start().starts_with("<!doctype html")
|
||||
|| response_text.trim_start().starts_with("<html")
|
||||
{
|
||||
return Err(AppError::Provider(
|
||||
"火山方舟返回了 HTML 页面,不是 API JSON 响应。请检查 Base URL 是否为 API 地址。"
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let response_body: SeedreamResponseBody =
|
||||
serde_json::from_str(response_text).map_err(|error| {
|
||||
AppError::Provider(format!(
|
||||
"failed to decode Seedream response: {error}; response body: {response_text}"
|
||||
))
|
||||
})?;
|
||||
let image_data = response_body
|
||||
.data
|
||||
.into_iter()
|
||||
.find_map(|item| {
|
||||
item.b64_json
|
||||
.map(ImageData::Base64)
|
||||
.or_else(|| item.url.map(ImageData::Url))
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
AppError::Provider("Seedream response did not include b64_json or url".to_string())
|
||||
})?;
|
||||
|
||||
Ok(ImageResult {
|
||||
mime_type: "image/png".to_string(),
|
||||
data: image_data,
|
||||
})
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AiProvider for SeedreamProvider {
|
||||
async fn generate_image(&self, request: ImageGenerateRequest) -> Result<ImageResult, AppError> {
|
||||
let body = SeedreamRequestBody {
|
||||
model: &request.model,
|
||||
prompt: &request.prompt,
|
||||
size: request.size.as_deref(),
|
||||
response_format: "b64_json",
|
||||
stream: false,
|
||||
watermark: false,
|
||||
image: None,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(format!("{}/images/generations", self.base_url))
|
||||
.bearer_auth(&self.api_key)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let message = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "request failed".to_string());
|
||||
return Err(AppError::Provider(format!(
|
||||
"Seedream image generation failed ({status}): {message}"
|
||||
)));
|
||||
}
|
||||
|
||||
let response_text = response.text().await?;
|
||||
parse_seedream_response(&response_text)
|
||||
}
|
||||
|
||||
async fn edit_image(&self, request: ImageEditRequest) -> Result<ImageResult, AppError> {
|
||||
let images = request
|
||||
.image_paths
|
||||
.iter()
|
||||
.map(|path| path_to_data_url(path))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let body = SeedreamRequestBody {
|
||||
model: &request.model,
|
||||
prompt: &request.prompt,
|
||||
size: request.size.as_deref(),
|
||||
response_format: "b64_json",
|
||||
stream: false,
|
||||
watermark: false,
|
||||
image: Some(images),
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(format!("{}/images/generations", self.base_url))
|
||||
.bearer_auth(&self.api_key)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let message = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "request failed".to_string());
|
||||
return Err(AppError::Provider(format!(
|
||||
"Seedream image edit failed ({status}): {message}"
|
||||
)));
|
||||
}
|
||||
|
||||
let response_text = response.text().await?;
|
||||
parse_seedream_response(&response_text)
|
||||
}
|
||||
}
|
||||
463
src-tauri/src/ai/tencent_hunyuan.rs
Normal file
463
src-tauri/src/ai/tencent_hunyuan.rs
Normal file
@@ -0,0 +1,463 @@
|
||||
use async_trait::async_trait;
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use chrono::{TimeZone, Utc};
|
||||
use hmac::{Hmac, Mac};
|
||||
use reqwest::{Client, Url};
|
||||
use serde_json::{json, Value};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::{fs, time::Duration};
|
||||
use tokio::time::sleep;
|
||||
|
||||
use super::provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest, ImageResult};
|
||||
use crate::AppError;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
pub struct TencentHunyuanProvider {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
secret_id: String,
|
||||
secret_key: String,
|
||||
}
|
||||
|
||||
struct TencentApiConfig {
|
||||
service: &'static str,
|
||||
version: &'static str,
|
||||
lite_action: &'static str,
|
||||
rapid_action: &'static str,
|
||||
submit_action: &'static str,
|
||||
query_action: &'static str,
|
||||
}
|
||||
|
||||
impl TencentHunyuanProvider {
|
||||
pub fn new(base_url: String, api_key: String) -> Result<Self, AppError> {
|
||||
let (secret_id, secret_key) = parse_secret_pair(&api_key)?;
|
||||
Ok(Self {
|
||||
client: Client::new(),
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
secret_id,
|
||||
secret_key,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_secret_pair(api_key: &str) -> Result<(String, String), AppError> {
|
||||
let (secret_id, secret_key) = api_key.split_once(':').ok_or_else(|| {
|
||||
AppError::Provider("腾讯云 API Key 需要填写为 SecretId:SecretKey".to_string())
|
||||
})?;
|
||||
let secret_id = secret_id.trim();
|
||||
let secret_key = secret_key.trim();
|
||||
if secret_id.is_empty() || secret_key.is_empty() {
|
||||
return Err(AppError::Provider(
|
||||
"腾讯云 API Key 需要填写为 SecretId:SecretKey".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok((secret_id.to_string(), secret_key.to_string()))
|
||||
}
|
||||
|
||||
fn hmac_sha256(key: &[u8], message: &str) -> Result<Vec<u8>, AppError> {
|
||||
let mut mac = HmacSha256::new_from_slice(key)
|
||||
.map_err(|error| AppError::Provider(format!("failed to create HMAC: {error}")))?;
|
||||
mac.update(message.as_bytes());
|
||||
Ok(mac.finalize().into_bytes().to_vec())
|
||||
}
|
||||
|
||||
fn sha256_hex(value: &str) -> String {
|
||||
hex::encode(Sha256::digest(value.as_bytes()))
|
||||
}
|
||||
|
||||
fn host_from_base_url(base_url: &str) -> Result<String, AppError> {
|
||||
let url = Url::parse(base_url)
|
||||
.map_err(|error| AppError::Provider(format!("腾讯云 Base URL 不是有效 URL: {error}")))?;
|
||||
url.host_str()
|
||||
.map(str::to_string)
|
||||
.ok_or_else(|| AppError::Provider("腾讯云 Base URL 缺少 host".to_string()))
|
||||
}
|
||||
|
||||
fn api_config(host: &str) -> TencentApiConfig {
|
||||
if host == "aiart.tencentcloudapi.com" {
|
||||
TencentApiConfig {
|
||||
service: "aiart",
|
||||
version: "2022-12-29",
|
||||
lite_action: "TextToImageLite",
|
||||
rapid_action: "TextToImageRapid",
|
||||
submit_action: "SubmitTextToImageJob",
|
||||
query_action: "QueryTextToImageJob",
|
||||
}
|
||||
} else {
|
||||
TencentApiConfig {
|
||||
service: "hunyuan",
|
||||
version: "2023-09-01",
|
||||
lite_action: "TextToImageLite",
|
||||
rapid_action: "TextToImageLite",
|
||||
submit_action: "SubmitHunyuanImageJob",
|
||||
query_action: "QueryHunyuanImageJob",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn hunyuan_resolution(size: Option<&str>, lite: bool, has_reference: bool) -> Option<String> {
|
||||
let size = size?;
|
||||
let (width, height) = size.split_once('x')?;
|
||||
let width = width.parse::<f32>().ok()?;
|
||||
let height = height.parse::<f32>().ok()?;
|
||||
if width <= 0.0 || height <= 0.0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let ratio = width / height;
|
||||
let value = if (ratio - 1.0).abs() < 0.12 {
|
||||
"1024:1024"
|
||||
} else if (ratio - 0.75).abs() < 0.12 {
|
||||
"768:1024"
|
||||
} else if (ratio - 1.333).abs() < 0.12 {
|
||||
"1024:768"
|
||||
} else if ratio < 0.65 {
|
||||
if lite && !has_reference {
|
||||
"1080:1920"
|
||||
} else {
|
||||
"720:1280"
|
||||
}
|
||||
} else if ratio > 1.55 {
|
||||
if lite && !has_reference {
|
||||
"1920:1080"
|
||||
} else {
|
||||
"1280:720"
|
||||
}
|
||||
} else if width < height {
|
||||
"768:1024"
|
||||
} else {
|
||||
"1024:768"
|
||||
};
|
||||
|
||||
if has_reference && !matches!(value, "1024:1024" | "768:1024" | "1024:768") {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(value.to_string())
|
||||
}
|
||||
|
||||
fn first_image_base64(image_paths: &[String]) -> Result<Option<String>, AppError> {
|
||||
image_paths
|
||||
.first()
|
||||
.map(|path| fs::read(path).map(|bytes| general_purpose::STANDARD.encode(bytes)))
|
||||
.transpose()
|
||||
.map_err(AppError::from)
|
||||
}
|
||||
|
||||
fn parse_tencent_error(response: &Value) -> Option<String> {
|
||||
let error = response.get("Error")?;
|
||||
let code = error
|
||||
.get("Code")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or_default();
|
||||
let message = error
|
||||
.get("Message")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or_default();
|
||||
Some(format!("{code}: {message}"))
|
||||
}
|
||||
|
||||
fn result_image_url(response: &Value) -> Option<String> {
|
||||
let result = response.get("ResultImage")?;
|
||||
if let Some(url) = result.as_str().filter(|value| value.starts_with("http")) {
|
||||
return Some(url.to_string());
|
||||
}
|
||||
|
||||
result.as_array()?.first().and_then(|image| {
|
||||
image
|
||||
.as_str()
|
||||
.filter(|value| value.starts_with("http"))
|
||||
.map(str::to_string)
|
||||
.or_else(|| {
|
||||
image
|
||||
.get("Url")
|
||||
.or_else(|| image.get("URL"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn result_image_base64(response: &Value) -> Option<String> {
|
||||
let result = response.get("ResultImage")?;
|
||||
if let Some(value) = result.as_str().filter(|value| !value.starts_with("http")) {
|
||||
return Some(value.to_string());
|
||||
}
|
||||
|
||||
result.as_array()?.first().and_then(|image| {
|
||||
image
|
||||
.as_str()
|
||||
.filter(|value| !value.starts_with("http"))
|
||||
.map(str::to_string)
|
||||
.or_else(|| {
|
||||
image
|
||||
.get("Base64")
|
||||
.or_else(|| image.get("ImageBase64"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AiProvider for TencentHunyuanProvider {
|
||||
async fn generate_image(&self, request: ImageGenerateRequest) -> Result<ImageResult, AppError> {
|
||||
if request.model == "hunyuan-image-lite" {
|
||||
return self
|
||||
.run_text_to_image_lite(&request.prompt, request.size.as_deref())
|
||||
.await;
|
||||
}
|
||||
if request.model == "hunyuan-image-2.0" {
|
||||
return self
|
||||
.run_text_to_image_rapid(&request.prompt, request.size.as_deref(), &[])
|
||||
.await;
|
||||
}
|
||||
|
||||
self.run_async_image_job(&request.prompt, request.size.as_deref(), &[])
|
||||
.await
|
||||
}
|
||||
|
||||
async fn edit_image(&self, request: ImageEditRequest) -> Result<ImageResult, AppError> {
|
||||
if request.model == "hunyuan-image-2.0" {
|
||||
return self
|
||||
.run_text_to_image_rapid(
|
||||
&request.prompt,
|
||||
request.size.as_deref(),
|
||||
&request.image_paths,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
self.run_async_image_job(
|
||||
&request.prompt,
|
||||
request.size.as_deref(),
|
||||
&request.image_paths,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl TencentHunyuanProvider {
|
||||
async fn call(&self, action: &str, payload: Value) -> Result<Value, AppError> {
|
||||
let region = "ap-guangzhou";
|
||||
let host = host_from_base_url(&self.base_url)?;
|
||||
let config = api_config(&host);
|
||||
let timestamp = Utc::now().timestamp();
|
||||
let date = Utc
|
||||
.timestamp_opt(timestamp, 0)
|
||||
.single()
|
||||
.ok_or_else(|| AppError::Provider("invalid Tencent Cloud timestamp".to_string()))?
|
||||
.format("%Y-%m-%d")
|
||||
.to_string();
|
||||
let payload_text = serde_json::to_string(&payload).map_err(|error| {
|
||||
AppError::Provider(format!("failed to encode Tencent Cloud request: {error}"))
|
||||
})?;
|
||||
let content_type = "application/json; charset=utf-8";
|
||||
let signed_headers = "content-type;host";
|
||||
let canonical_headers = format!("content-type:{content_type}\nhost:{host}\n");
|
||||
let canonical_request = format!(
|
||||
"POST\n/\n\n{canonical_headers}\n{signed_headers}\n{}",
|
||||
sha256_hex(&payload_text)
|
||||
);
|
||||
let credential_scope = format!("{}/{}/tc3_request", date, config.service);
|
||||
let string_to_sign = format!(
|
||||
"TC3-HMAC-SHA256\n{timestamp}\n{credential_scope}\n{}",
|
||||
sha256_hex(&canonical_request)
|
||||
);
|
||||
let secret_date = hmac_sha256(format!("TC3{}", self.secret_key).as_bytes(), &date)?;
|
||||
let secret_service = hmac_sha256(&secret_date, config.service)?;
|
||||
let secret_signing = hmac_sha256(&secret_service, "tc3_request")?;
|
||||
let signature = hex::encode(hmac_sha256(&secret_signing, &string_to_sign)?);
|
||||
let authorization = format!(
|
||||
"TC3-HMAC-SHA256 Credential={}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}",
|
||||
self.secret_id
|
||||
);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&self.base_url)
|
||||
.header("Authorization", authorization)
|
||||
.header("Content-Type", content_type)
|
||||
.header("Host", host)
|
||||
.header("X-TC-Action", action)
|
||||
.header("X-TC-Timestamp", timestamp.to_string())
|
||||
.header("X-TC-Version", config.version)
|
||||
.header("X-TC-Region", region)
|
||||
.body(payload_text)
|
||||
.send()
|
||||
.await?;
|
||||
let status = response.status();
|
||||
let response_text = response.text().await?;
|
||||
if !status.is_success() {
|
||||
return Err(AppError::Provider(format!(
|
||||
"Tencent Hunyuan request failed ({status}): {response_text}"
|
||||
)));
|
||||
}
|
||||
|
||||
let response_body = serde_json::from_str::<Value>(&response_text).map_err(|error| {
|
||||
AppError::Provider(format!(
|
||||
"failed to decode Tencent Hunyuan response: {error}; response body: {response_text}"
|
||||
))
|
||||
})?;
|
||||
let inner = response_body.get("Response").ok_or_else(|| {
|
||||
AppError::Provider(format!(
|
||||
"Tencent Hunyuan response missing Response: {response_text}"
|
||||
))
|
||||
})?;
|
||||
if let Some(message) = parse_tencent_error(inner) {
|
||||
return Err(AppError::Provider(format!(
|
||||
"Tencent Hunyuan {action} failed: {message}"
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(inner.clone())
|
||||
}
|
||||
|
||||
async fn run_text_to_image_lite(
|
||||
&self,
|
||||
prompt: &str,
|
||||
size: Option<&str>,
|
||||
) -> Result<ImageResult, AppError> {
|
||||
let mut payload = json!({
|
||||
"Prompt": prompt,
|
||||
"RspImgType": "url",
|
||||
"LogoAdd": 0,
|
||||
});
|
||||
if let Some(resolution) = hunyuan_resolution(size, true, false) {
|
||||
payload["Resolution"] = json!(resolution);
|
||||
}
|
||||
let host = host_from_base_url(&self.base_url)?;
|
||||
let action = api_config(&host).lite_action;
|
||||
let response = self.call(action, payload).await?;
|
||||
let image_url = result_image_url(&response).ok_or_else(|| {
|
||||
AppError::Provider(format!(
|
||||
"Tencent Hunyuan TextToImageLite did not include image url: {response}"
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(ImageResult {
|
||||
mime_type: "image/png".to_string(),
|
||||
data: ImageData::Url(image_url),
|
||||
})
|
||||
}
|
||||
|
||||
async fn run_text_to_image_rapid(
|
||||
&self,
|
||||
prompt: &str,
|
||||
size: Option<&str>,
|
||||
image_paths: &[String],
|
||||
) -> Result<ImageResult, AppError> {
|
||||
let has_reference = !image_paths.is_empty();
|
||||
let mut payload = json!({
|
||||
"Prompt": prompt,
|
||||
"RspImgType": "url",
|
||||
"LogoAdd": 0,
|
||||
});
|
||||
if let Some(resolution) = hunyuan_resolution(size, false, has_reference) {
|
||||
payload["Resolution"] = json!(resolution);
|
||||
}
|
||||
if let Some(image_base64) = first_image_base64(image_paths)? {
|
||||
payload["Image"] = json!({
|
||||
"Base64": image_base64,
|
||||
});
|
||||
}
|
||||
let host = host_from_base_url(&self.base_url)?;
|
||||
let action = api_config(&host).rapid_action;
|
||||
let response = self.call(action, payload).await?;
|
||||
let image_url = result_image_url(&response).ok_or_else(|| {
|
||||
AppError::Provider(format!(
|
||||
"Tencent Hunyuan TextToImageRapid did not include image url: {response}"
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(ImageResult {
|
||||
mime_type: "image/png".to_string(),
|
||||
data: ImageData::Url(image_url),
|
||||
})
|
||||
}
|
||||
|
||||
async fn run_async_image_job(
|
||||
&self,
|
||||
prompt: &str,
|
||||
size: Option<&str>,
|
||||
image_paths: &[String],
|
||||
) -> Result<ImageResult, AppError> {
|
||||
let has_reference = !image_paths.is_empty();
|
||||
let mut payload = json!({
|
||||
"Prompt": prompt,
|
||||
"Num": 1,
|
||||
"Revise": 1,
|
||||
"LogoAdd": 0,
|
||||
});
|
||||
if let Some(resolution) = hunyuan_resolution(size, false, has_reference) {
|
||||
payload["Resolution"] = json!(resolution);
|
||||
}
|
||||
let host = host_from_base_url(&self.base_url)?;
|
||||
let config = api_config(&host);
|
||||
if let Some(image_base64) = first_image_base64(image_paths)? {
|
||||
if config.service == "aiart" {
|
||||
payload["Images"] = json!([image_base64]);
|
||||
} else {
|
||||
payload["ContentImage"] = json!({
|
||||
"ImageBase64": image_base64,
|
||||
});
|
||||
}
|
||||
}
|
||||
let response = self.call(config.submit_action, payload).await?;
|
||||
let job_id = response
|
||||
.get("JobId")
|
||||
.and_then(Value::as_str)
|
||||
.ok_or_else(|| {
|
||||
AppError::Provider(format!(
|
||||
"Tencent Hunyuan submit response did not include JobId: {response}"
|
||||
))
|
||||
})?;
|
||||
|
||||
self.wait_image_job(job_id).await
|
||||
}
|
||||
|
||||
async fn wait_image_job(&self, job_id: &str) -> Result<ImageResult, AppError> {
|
||||
for _ in 0..40 {
|
||||
sleep(Duration::from_secs(3)).await;
|
||||
let host = host_from_base_url(&self.base_url)?;
|
||||
let action = api_config(&host).query_action;
|
||||
let response = self.call(action, json!({ "JobId": job_id })).await?;
|
||||
let job_status = response
|
||||
.get("JobStatusCode")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or_default();
|
||||
if job_status == "5" {
|
||||
if let Some(image_base64) = result_image_base64(&response) {
|
||||
return Ok(ImageResult {
|
||||
mime_type: "image/png".to_string(),
|
||||
data: ImageData::Base64(image_base64),
|
||||
});
|
||||
}
|
||||
let image_url = result_image_url(&response).ok_or_else(|| {
|
||||
AppError::Provider(format!(
|
||||
"Tencent Hunyuan job succeeded but did not include image: {response}"
|
||||
))
|
||||
})?;
|
||||
return Ok(ImageResult {
|
||||
mime_type: "image/png".to_string(),
|
||||
data: ImageData::Url(image_url),
|
||||
});
|
||||
}
|
||||
if job_status == "4" {
|
||||
let message = response
|
||||
.get("JobErrorMsg")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or("unknown error");
|
||||
return Err(AppError::Provider(format!(
|
||||
"Tencent Hunyuan image job failed: {message}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Err(AppError::Provider(
|
||||
"Tencent Hunyuan image job polling timed out".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
@@ -11,7 +11,11 @@ pub async fn pick_material_images(app: tauri::AppHandle) -> Result<Vec<String>,
|
||||
let paths = files
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.filter_map(|file_path| file_path.as_path().map(|path| path.to_string_lossy().to_string()))
|
||||
.filter_map(|file_path| {
|
||||
file_path
|
||||
.as_path()
|
||||
.map(|path| path.to_string_lossy().to_string())
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(paths)
|
||||
|
||||
@@ -2,16 +2,21 @@ use tauri::{AppHandle, State};
|
||||
|
||||
use crate::{
|
||||
ai::{
|
||||
dashscope::DashScopeProvider,
|
||||
google_gemini::GoogleGeminiProvider,
|
||||
openai_compatible::OpenAiCompatibleProvider,
|
||||
provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest},
|
||||
seedream::SeedreamProvider,
|
||||
tencent_hunyuan::TencentHunyuanProvider,
|
||||
},
|
||||
db::{
|
||||
models::{CreateGenerationTaskInput, GenerateImageInput, GenerateImageOutput, GenerationTask},
|
||||
models::{
|
||||
CreateGenerationTaskInput, GenerateImageInput, GenerateImageOutput, GenerationTask,
|
||||
},
|
||||
repository,
|
||||
},
|
||||
state::AppState,
|
||||
storage,
|
||||
AppError,
|
||||
storage, AppError,
|
||||
};
|
||||
|
||||
#[tauri::command]
|
||||
@@ -30,21 +35,66 @@ pub async fn generate_image(
|
||||
) -> Result<GenerateImageOutput, AppError> {
|
||||
let provider = repository::get_provider_secret(&state.db, &input.provider_id).await?;
|
||||
if !provider.enabled {
|
||||
return Err(AppError::Provider(format!("provider {} is disabled", provider.name)));
|
||||
return Err(AppError::Provider(format!(
|
||||
"provider {} is disabled",
|
||||
provider.name
|
||||
)));
|
||||
}
|
||||
|
||||
if provider.kind != "openai-compatible" {
|
||||
if provider.kind != "openai"
|
||||
&& provider.kind != "openai-compatible"
|
||||
&& provider.kind != "volcengine-ark"
|
||||
&& provider.kind != "dashscope"
|
||||
&& provider.kind != "tencent-hunyuan"
|
||||
&& provider.kind != "google-gemini"
|
||||
{
|
||||
return Err(AppError::Provider(format!(
|
||||
"provider kind {} is not supported yet",
|
||||
provider.kind
|
||||
)));
|
||||
}
|
||||
|
||||
if !provider.base_url.trim_end_matches('/').ends_with("/v1") {
|
||||
if (provider.kind == "openai" || provider.kind == "openai-compatible")
|
||||
&& !provider.base_url.trim_end_matches('/').ends_with("/v1")
|
||||
{
|
||||
return Err(AppError::Provider(
|
||||
"Base URL 看起来不是 API 地址。OpenAI-compatible 地址通常需要以 /v1 结尾,例如 https://api.openai.com/v1 或 https://你的中转站域名/v1".to_string(),
|
||||
));
|
||||
}
|
||||
if provider.kind == "volcengine-ark"
|
||||
&& !provider.base_url.trim_end_matches('/').ends_with("/api/v3")
|
||||
{
|
||||
return Err(AppError::Provider(
|
||||
"火山方舟 Seedream 的 Base URL 通常需要以 /api/v3 结尾,例如 https://ark.cn-beijing.volces.com/api/v3".to_string(),
|
||||
));
|
||||
}
|
||||
if provider.kind == "dashscope" && !provider.base_url.trim_end_matches('/').ends_with("/api/v1")
|
||||
{
|
||||
return Err(AppError::Provider(
|
||||
"阿里云百炼 DashScope 的 Base URL 通常需要以 /api/v1 结尾,例如 https://dashscope.aliyuncs.com/api/v1".to_string(),
|
||||
));
|
||||
}
|
||||
if provider.kind == "google-gemini"
|
||||
&& !provider.base_url.trim_end_matches('/').ends_with("/v1beta")
|
||||
{
|
||||
return Err(AppError::Provider(
|
||||
"Google Gemini / Nano Banana 的 Base URL 通常填写 https://generativelanguage.googleapis.com/v1beta".to_string(),
|
||||
));
|
||||
}
|
||||
if provider.kind == "tencent-hunyuan"
|
||||
&& !provider
|
||||
.base_url
|
||||
.trim_end_matches('/')
|
||||
.ends_with("aiart.tencentcloudapi.com")
|
||||
&& !provider
|
||||
.base_url
|
||||
.trim_end_matches('/')
|
||||
.ends_with("hunyuan.tencentcloudapi.com")
|
||||
{
|
||||
return Err(AppError::Provider(
|
||||
"腾讯混元图像的 Base URL 通常填写 https://aiart.tencentcloudapi.com".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let api_key = provider
|
||||
.api_key_encrypted
|
||||
@@ -76,30 +126,121 @@ pub async fn generate_image(
|
||||
)
|
||||
.await?;
|
||||
|
||||
let ai_provider = OpenAiCompatibleProvider::new(provider.base_url, api_key);
|
||||
let image_result = match if input.image_paths.is_empty() {
|
||||
ai_provider
|
||||
.generate_image(ImageGenerateRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
})
|
||||
.await
|
||||
let image_result = match if provider.kind == "volcengine-ark" {
|
||||
let ai_provider = SeedreamProvider::new(provider.base_url, api_key);
|
||||
if input.image_paths.is_empty() {
|
||||
ai_provider
|
||||
.generate_image(ImageGenerateRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
})
|
||||
.await
|
||||
} else {
|
||||
ai_provider
|
||||
.edit_image(ImageEditRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
image_paths: input.image_paths,
|
||||
})
|
||||
.await
|
||||
}
|
||||
} else if provider.kind == "dashscope" {
|
||||
let ai_provider = DashScopeProvider::new(provider.base_url, api_key);
|
||||
if input.image_paths.is_empty() {
|
||||
ai_provider
|
||||
.generate_image(ImageGenerateRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
})
|
||||
.await
|
||||
} else {
|
||||
ai_provider
|
||||
.edit_image(ImageEditRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
image_paths: input.image_paths,
|
||||
})
|
||||
.await
|
||||
}
|
||||
} else if provider.kind == "tencent-hunyuan" {
|
||||
let ai_provider = TencentHunyuanProvider::new(provider.base_url, api_key)?;
|
||||
if input.image_paths.is_empty() {
|
||||
ai_provider
|
||||
.generate_image(ImageGenerateRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
})
|
||||
.await
|
||||
} else {
|
||||
ai_provider
|
||||
.edit_image(ImageEditRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
image_paths: input.image_paths,
|
||||
})
|
||||
.await
|
||||
}
|
||||
} else if provider.kind == "google-gemini" {
|
||||
let ai_provider = GoogleGeminiProvider::new(provider.base_url, api_key);
|
||||
if input.image_paths.is_empty() {
|
||||
ai_provider
|
||||
.generate_image(ImageGenerateRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
})
|
||||
.await
|
||||
} else {
|
||||
ai_provider
|
||||
.edit_image(ImageEditRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
image_paths: input.image_paths,
|
||||
})
|
||||
.await
|
||||
}
|
||||
} else {
|
||||
ai_provider
|
||||
.edit_image(ImageEditRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
image_paths: input.image_paths,
|
||||
})
|
||||
.await
|
||||
let ai_provider = OpenAiCompatibleProvider::new(provider.base_url, api_key);
|
||||
if input.image_paths.is_empty() {
|
||||
ai_provider
|
||||
.generate_image(ImageGenerateRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
})
|
||||
.await
|
||||
} else {
|
||||
ai_provider
|
||||
.edit_image(ImageEditRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
image_paths: input.image_paths,
|
||||
})
|
||||
.await
|
||||
}
|
||||
} {
|
||||
Ok(result) => result,
|
||||
Err(error) => {
|
||||
repository::mark_generation_task_failed(&state.db, &task.id, &error.to_string()).await?;
|
||||
repository::mark_generation_task_failed(&state.db, &task.id, &error.to_string())
|
||||
.await?;
|
||||
return Err(error);
|
||||
}
|
||||
};
|
||||
@@ -108,7 +249,8 @@ pub async fn generate_image(
|
||||
ImageData::Base64(data_base64) => storage::decode_base64_image(data_base64)?,
|
||||
ImageData::Url(url) => reqwest::get(url).await?.bytes().await?.to_vec(),
|
||||
};
|
||||
let stored_image = storage::save_generated_image_bytes(&app, &image_bytes, &image_result.mime_type)?;
|
||||
let stored_image =
|
||||
storage::save_generated_image_bytes(&app, &image_bytes, &image_result.mime_type)?;
|
||||
let file_path = stored_image.file_path.to_string_lossy().to_string();
|
||||
let asset = repository::create_image_asset(
|
||||
&state.db,
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use tauri::State;
|
||||
|
||||
use crate::{
|
||||
@@ -7,12 +10,29 @@ use crate::{
|
||||
};
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn list_providers(state: State<'_, AppState>) -> Result<Vec<crate::db::models::ProviderConfig>, AppError> {
|
||||
pub async fn list_providers(
|
||||
state: State<'_, AppState>,
|
||||
) -> Result<Vec<crate::db::models::ProviderConfig>, AppError> {
|
||||
repository::list_providers(&state.db).await
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn upsert_provider(state: State<'_, AppState>, input: UpsertProviderInput) -> Result<(), AppError> {
|
||||
pub async fn upsert_provider(
|
||||
state: State<'_, AppState>,
|
||||
input: UpsertProviderInput,
|
||||
) -> Result<(), AppError> {
|
||||
if input
|
||||
.api_key
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.unwrap_or_default()
|
||||
.is_empty()
|
||||
{
|
||||
return Err(AppError::Provider(
|
||||
"请先填写 API Key,再保存配置".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
repository::upsert_provider(&state.db, input).await
|
||||
}
|
||||
|
||||
@@ -20,3 +40,306 @@ pub async fn upsert_provider(state: State<'_, AppState>, input: UpsertProviderIn
|
||||
pub async fn delete_provider(state: State<'_, AppState>, id: String) -> Result<(), AppError> {
|
||||
repository::delete_provider(&state.db, &id).await
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ProviderModel {
|
||||
pub id: String,
|
||||
pub owned_by: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ModelsResponse {
|
||||
data: Vec<ModelItem>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ModelItem {
|
||||
id: String,
|
||||
owned_by: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ProviderCapabilities {
|
||||
image_models: Option<Vec<ProviderModel>>,
|
||||
selected_image_models: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
fn is_image_model(model_id: &str) -> bool {
|
||||
let id = model_id.to_ascii_lowercase();
|
||||
[
|
||||
"gpt-image",
|
||||
"image",
|
||||
"dall-e",
|
||||
"dalle",
|
||||
"imagen",
|
||||
"flux",
|
||||
"qwen-image",
|
||||
"wan",
|
||||
"z-image",
|
||||
"hunyuan-image",
|
||||
"gemini",
|
||||
"seedream",
|
||||
"seededit",
|
||||
"stable-diffusion",
|
||||
"sd-",
|
||||
"midjourney",
|
||||
"recraft",
|
||||
]
|
||||
.iter()
|
||||
.any(|marker| id.contains(marker))
|
||||
}
|
||||
|
||||
fn default_seedream_models() -> Vec<ProviderModel> {
|
||||
vec![
|
||||
ProviderModel {
|
||||
id: "doubao-seedream-4-5-251128".to_string(),
|
||||
owned_by: Some("volcengine".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "doubao-seedream-4-0-250828".to_string(),
|
||||
owned_by: Some("volcengine".to_string()),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
fn default_dashscope_models() -> Vec<ProviderModel> {
|
||||
vec![
|
||||
ProviderModel {
|
||||
id: "qwen-image-2.0-pro".to_string(),
|
||||
owned_by: Some("alibaba-cloud".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "qwen-image-2.0".to_string(),
|
||||
owned_by: Some("alibaba-cloud".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "qwen-image-plus".to_string(),
|
||||
owned_by: Some("alibaba-cloud".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "qwen-image".to_string(),
|
||||
owned_by: Some("alibaba-cloud".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "wan2.7-image-pro".to_string(),
|
||||
owned_by: Some("alibaba-cloud".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "wan2.7-image".to_string(),
|
||||
owned_by: Some("alibaba-cloud".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "z-image-turbo".to_string(),
|
||||
owned_by: Some("alibaba-cloud".to_string()),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
fn default_tencent_hunyuan_models() -> Vec<ProviderModel> {
|
||||
vec![
|
||||
ProviderModel {
|
||||
id: "hunyuan-image-3.0".to_string(),
|
||||
owned_by: Some("tencent-cloud".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "hunyuan-image-2.0".to_string(),
|
||||
owned_by: Some("tencent-cloud".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "hunyuan-image-lite".to_string(),
|
||||
owned_by: Some("tencent-cloud".to_string()),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
fn default_google_gemini_models() -> Vec<ProviderModel> {
|
||||
vec![
|
||||
ProviderModel {
|
||||
id: "gemini-2.5-flash-image".to_string(),
|
||||
owned_by: Some("google".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "gemini-3.1-flash-image-preview".to_string(),
|
||||
owned_by: Some("google".to_string()),
|
||||
},
|
||||
ProviderModel {
|
||||
id: "gemini-3-pro-image-preview".to_string(),
|
||||
owned_by: Some("google".to_string()),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn fetch_provider_models(
|
||||
state: State<'_, AppState>,
|
||||
input: UpsertProviderInput,
|
||||
) -> Result<Vec<ProviderModel>, AppError> {
|
||||
if input.kind != "openai"
|
||||
&& input.kind != "openai-compatible"
|
||||
&& input.kind != "volcengine-ark"
|
||||
&& input.kind != "dashscope"
|
||||
&& input.kind != "tencent-hunyuan"
|
||||
&& input.kind != "google-gemini"
|
||||
{
|
||||
return Err(AppError::Provider(format!(
|
||||
"API 分类 {} 暂未接入模型列表获取",
|
||||
input.kind
|
||||
)));
|
||||
}
|
||||
|
||||
if (input.kind == "openai" || input.kind == "openai-compatible")
|
||||
&& !input.base_url.trim_end_matches('/').ends_with("/v1")
|
||||
{
|
||||
return Err(AppError::Provider(
|
||||
"Base URL 看起来不是 API 地址。OpenAI-compatible 地址通常需要以 /v1 结尾。".to_string(),
|
||||
));
|
||||
}
|
||||
if input.kind == "volcengine-ark" && !input.base_url.trim_end_matches('/').ends_with("/api/v3")
|
||||
{
|
||||
return Err(AppError::Provider(
|
||||
"火山方舟 Seedream 的 Base URL 通常需要以 /api/v3 结尾。".to_string(),
|
||||
));
|
||||
}
|
||||
if input.kind == "dashscope" && !input.base_url.trim_end_matches('/').ends_with("/api/v1") {
|
||||
return Err(AppError::Provider(
|
||||
"阿里云百炼 DashScope 的 Base URL 通常需要以 /api/v1 结尾。".to_string(),
|
||||
));
|
||||
}
|
||||
if input.kind == "google-gemini" && !input.base_url.trim_end_matches('/').ends_with("/v1beta") {
|
||||
return Err(AppError::Provider(
|
||||
"Google Gemini / Nano Banana 的 Base URL 通常填写 https://generativelanguage.googleapis.com/v1beta".to_string(),
|
||||
));
|
||||
}
|
||||
if input.kind == "tencent-hunyuan"
|
||||
&& !input
|
||||
.base_url
|
||||
.trim_end_matches('/')
|
||||
.ends_with("aiart.tencentcloudapi.com")
|
||||
&& !input
|
||||
.base_url
|
||||
.trim_end_matches('/')
|
||||
.ends_with("hunyuan.tencentcloudapi.com")
|
||||
{
|
||||
return Err(AppError::Provider(
|
||||
"腾讯混元图像的 Base URL 通常填写 https://aiart.tencentcloudapi.com".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let saved_provider = repository::get_provider_secret(&state.db, &input.id)
|
||||
.await
|
||||
.ok();
|
||||
let api_key = match input.api_key.clone() {
|
||||
Some(key) if !key.trim().is_empty() => Some(key),
|
||||
Some(_) => None,
|
||||
None => saved_provider.and_then(|provider| provider.api_key_encrypted),
|
||||
};
|
||||
let Some(api_key) = api_key else {
|
||||
return Err(AppError::Provider(
|
||||
"API Key 为空,无法获取模型列表".to_string(),
|
||||
));
|
||||
};
|
||||
|
||||
if input.kind == "dashscope" {
|
||||
return save_provider_models(&state, input, default_dashscope_models()).await;
|
||||
}
|
||||
if input.kind == "tencent-hunyuan" {
|
||||
return save_provider_models(&state, input, default_tencent_hunyuan_models()).await;
|
||||
}
|
||||
if input.kind == "google-gemini" {
|
||||
return save_provider_models(&state, input, default_google_gemini_models()).await;
|
||||
}
|
||||
|
||||
let response = Client::new()
|
||||
.get(format!("{}/models", input.base_url.trim_end_matches('/')))
|
||||
.bearer_auth(api_key)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
if input.kind == "volcengine-ark" && matches!(status.as_u16(), 404 | 405 | 501) {
|
||||
let fetched_models = default_seedream_models();
|
||||
return save_provider_models(&state, input, fetched_models).await;
|
||||
}
|
||||
let message = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "request failed".to_string());
|
||||
return Err(AppError::Provider(format!(
|
||||
"获取模型列表失败 ({status}): {message}"
|
||||
)));
|
||||
}
|
||||
|
||||
let mut fetched_models: Vec<ProviderModel> = response
|
||||
.json::<ModelsResponse>()
|
||||
.await?
|
||||
.data
|
||||
.into_iter()
|
||||
.filter(|model| is_image_model(&model.id))
|
||||
.map(|model| ProviderModel {
|
||||
id: model.id,
|
||||
owned_by: model.owned_by,
|
||||
})
|
||||
.collect();
|
||||
if input.kind == "volcengine-ark" && fetched_models.is_empty() {
|
||||
fetched_models = default_seedream_models();
|
||||
}
|
||||
|
||||
save_provider_models(&state, input, fetched_models).await
|
||||
}
|
||||
|
||||
async fn save_provider_models(
|
||||
state: &State<'_, AppState>,
|
||||
input: UpsertProviderInput,
|
||||
fetched_models: Vec<ProviderModel>,
|
||||
) -> Result<Vec<ProviderModel>, AppError> {
|
||||
let saved_providers = repository::list_providers(&state.db)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
let saved_capabilities = saved_providers
|
||||
.iter()
|
||||
.find(|provider| provider.id == input.id)
|
||||
.and_then(|provider| provider.capabilities.as_deref())
|
||||
.and_then(|value| serde_json::from_str::<ProviderCapabilities>(value).ok());
|
||||
let mut models = saved_capabilities
|
||||
.as_ref()
|
||||
.and_then(|capabilities| capabilities.image_models.clone())
|
||||
.unwrap_or_default();
|
||||
models.extend(fetched_models);
|
||||
models.sort_by(|left, right| left.id.cmp(&right.id));
|
||||
models.dedup_by(|left, right| left.id == right.id);
|
||||
let model_ids = models
|
||||
.iter()
|
||||
.map(|model| model.id.clone())
|
||||
.collect::<Vec<_>>();
|
||||
let mut selected_model_ids = saved_capabilities
|
||||
.and_then(|capabilities| capabilities.selected_image_models)
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.filter(|model_id| model_ids.contains(model_id))
|
||||
.collect::<Vec<_>>();
|
||||
for model_id in &model_ids {
|
||||
if !selected_model_ids.contains(model_id) {
|
||||
selected_model_ids.push(model_id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let capabilities = json!({
|
||||
"responses_api": true,
|
||||
"images_api": true,
|
||||
"chat_completions": true,
|
||||
"image_edit": true,
|
||||
"image_models": models,
|
||||
"selected_image_models": selected_model_ids,
|
||||
});
|
||||
repository::upsert_provider(
|
||||
&state.db,
|
||||
UpsertProviderInput {
|
||||
capabilities: Some(capabilities.to_string()),
|
||||
..input
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(models)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::fs;
|
||||
|
||||
use sqlx::{sqlite::SqliteConnectOptions, SqlitePool};
|
||||
use sqlx::{sqlite::SqliteConnectOptions, Row, SqlitePool};
|
||||
use tauri::{AppHandle, Manager};
|
||||
|
||||
use crate::AppError;
|
||||
@@ -24,6 +24,38 @@ pub async fn init(app: &AppHandle) -> Result<SqlitePool, AppError> {
|
||||
sqlx::query(statement).execute(&pool).await?;
|
||||
}
|
||||
}
|
||||
ensure_column(
|
||||
&pool,
|
||||
"providers",
|
||||
"capabilities",
|
||||
"TEXT NOT NULL DEFAULT '{}'",
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
async fn ensure_column(
|
||||
pool: &SqlitePool,
|
||||
table: &str,
|
||||
column: &str,
|
||||
definition: &str,
|
||||
) -> Result<(), AppError> {
|
||||
let rows = sqlx::query(&format!("PRAGMA table_info({table})"))
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
let has_column = rows.iter().any(|row| {
|
||||
row.try_get::<String, _>("name")
|
||||
.map(|name| name == column)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
if !has_column {
|
||||
sqlx::query(&format!(
|
||||
"ALTER TABLE {table} ADD COLUMN {column} {definition}"
|
||||
))
|
||||
.execute(pool)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ pub struct ProviderConfig {
|
||||
pub api_key: Option<String>,
|
||||
pub text_model: Option<String>,
|
||||
pub image_model: Option<String>,
|
||||
pub capabilities: Option<String>,
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
@@ -33,6 +34,7 @@ pub struct UpsertProviderInput {
|
||||
pub api_key: Option<String>,
|
||||
pub text_model: Option<String>,
|
||||
pub image_model: Option<String>,
|
||||
pub capabilities: Option<String>,
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
|
||||
@@ -19,8 +19,10 @@ pub async fn list_providers(pool: &SqlitePool) -> Result<Vec<ProviderConfig>, Ap
|
||||
api_key_encrypted AS api_key,
|
||||
text_model,
|
||||
image_model,
|
||||
capabilities,
|
||||
enabled != 0 AS enabled
|
||||
FROM providers
|
||||
WHERE enabled != 0
|
||||
ORDER BY updated_at DESC
|
||||
"#,
|
||||
)
|
||||
@@ -30,24 +32,34 @@ pub async fn list_providers(pool: &SqlitePool) -> Result<Vec<ProviderConfig>, Ap
|
||||
Ok(providers)
|
||||
}
|
||||
|
||||
pub async fn upsert_provider(pool: &SqlitePool, input: UpsertProviderInput) -> Result<(), AppError> {
|
||||
pub async fn upsert_provider(
|
||||
pool: &SqlitePool,
|
||||
input: UpsertProviderInput,
|
||||
) -> Result<(), AppError> {
|
||||
let now = Utc::now().to_rfc3339();
|
||||
|
||||
let existing_api_key: Option<String> = sqlx::query_scalar(
|
||||
let existing: Option<(Option<String>, Option<String>)> = sqlx::query_as(
|
||||
r#"
|
||||
SELECT api_key_encrypted
|
||||
SELECT api_key_encrypted, capabilities
|
||||
FROM providers
|
||||
WHERE id = ?1
|
||||
"#,
|
||||
)
|
||||
.bind(&input.id)
|
||||
.fetch_optional(pool)
|
||||
.await?
|
||||
.flatten();
|
||||
let api_key = input
|
||||
.api_key
|
||||
.filter(|key| !key.trim().is_empty())
|
||||
.or(existing_api_key);
|
||||
.await?;
|
||||
let api_key = match input.api_key {
|
||||
Some(key) if key.trim().is_empty() => None,
|
||||
Some(key) => Some(key),
|
||||
None => existing.as_ref().and_then(|item| item.0.clone()),
|
||||
};
|
||||
let capabilities = input
|
||||
.capabilities
|
||||
.filter(|value| !value.trim().is_empty())
|
||||
.or_else(|| existing.and_then(|item| item.1))
|
||||
.unwrap_or_else(|| {
|
||||
r#"{"responses_api":true,"images_api":true,"chat_completions":true,"image_edit":true,"image_models":[],"selected_image_models":[]}"#.to_string()
|
||||
});
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
@@ -62,6 +74,7 @@ pub async fn upsert_provider(pool: &SqlitePool, input: UpsertProviderInput) -> R
|
||||
api_key_encrypted = excluded.api_key_encrypted,
|
||||
text_model = excluded.text_model,
|
||||
image_model = excluded.image_model,
|
||||
capabilities = excluded.capabilities,
|
||||
enabled = excluded.enabled,
|
||||
updated_at = excluded.updated_at
|
||||
"#,
|
||||
@@ -73,7 +86,7 @@ pub async fn upsert_provider(pool: &SqlitePool, input: UpsertProviderInput) -> R
|
||||
.bind(api_key)
|
||||
.bind(input.text_model)
|
||||
.bind(input.image_model)
|
||||
.bind(r#"{"responses_api":true,"images_api":true,"chat_completions":true,"image_edit":true}"#)
|
||||
.bind(capabilities)
|
||||
.bind(input.enabled)
|
||||
.bind(now)
|
||||
.execute(pool)
|
||||
@@ -98,6 +111,34 @@ pub async fn get_provider_secret(pool: &SqlitePool, id: &str) -> Result<Provider
|
||||
}
|
||||
|
||||
pub async fn delete_provider(pool: &SqlitePool, id: &str) -> Result<(), AppError> {
|
||||
let reference_count: i64 = sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT
|
||||
(SELECT COUNT(*) FROM generation_tasks WHERE provider_id = ?1) +
|
||||
(SELECT COUNT(*) FROM conversations WHERE provider_id = ?1) +
|
||||
(SELECT COUNT(*) FROM ai_request_logs WHERE provider_id = ?1)
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
if reference_count > 0 {
|
||||
let now = Utc::now().to_rfc3339();
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE providers
|
||||
SET enabled = 0, updated_at = ?2
|
||||
WHERE id = ?1
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.bind(now)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
sqlx::query("DELETE FROM providers WHERE id = ?1")
|
||||
.bind(id)
|
||||
.execute(pool)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
mod ai;
|
||||
mod commands;
|
||||
mod db;
|
||||
mod storage;
|
||||
mod state;
|
||||
mod storage;
|
||||
|
||||
use serde::Serialize;
|
||||
use state::AppState;
|
||||
@@ -53,6 +53,7 @@ pub fn run() {
|
||||
commands::provider::list_providers,
|
||||
commands::provider::upsert_provider,
|
||||
commands::provider::delete_provider,
|
||||
commands::provider::fetch_provider_models,
|
||||
commands::dialog::pick_material_images,
|
||||
commands::file::reveal_path,
|
||||
commands::file::open_generated_dir,
|
||||
|
||||
790
src/main.tsx
790
src/main.tsx
@@ -1,5 +1,6 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import ReactDOM from 'react-dom/client';
|
||||
import { FolderOpenOutlined, PictureOutlined, RobotOutlined, SettingOutlined } from '@ant-design/icons';
|
||||
import { convertFileSrc, invoke } from '@tauri-apps/api/core';
|
||||
import appLogo from './assets/logo.svg';
|
||||
import './styles.css';
|
||||
@@ -12,6 +13,7 @@ type ProviderConfig = {
|
||||
api_key?: string | null;
|
||||
text_model?: string | null;
|
||||
image_model?: string | null;
|
||||
capabilities?: string | null;
|
||||
enabled: boolean;
|
||||
};
|
||||
|
||||
@@ -44,15 +46,55 @@ type SessionImage = {
|
||||
created_at: string;
|
||||
};
|
||||
|
||||
type ProviderModel = {
|
||||
id: string;
|
||||
owned_by?: string | null;
|
||||
};
|
||||
|
||||
type ProviderCapabilities = {
|
||||
image_models?: ProviderModel[];
|
||||
selected_image_models?: string[];
|
||||
};
|
||||
|
||||
type GenerationStep = {
|
||||
label: string;
|
||||
status: 'pending' | 'active' | 'done' | 'error';
|
||||
};
|
||||
|
||||
function formatError(error: unknown) {
|
||||
if (typeof error === 'string') return error;
|
||||
if (error instanceof Error) return error.message;
|
||||
return JSON.stringify(error);
|
||||
const message =
|
||||
typeof error === 'string' ? error : error instanceof Error ? error.message : JSON.stringify(error);
|
||||
if (message.includes('502 Bad Gateway') || message.includes('upstream_error')) {
|
||||
return '上游模型服务返回 502。通常是中转站或模型供应商临时失败,不是本地程序错误;可以换模型、降低分辨率,或稍后/换供应商重试。';
|
||||
}
|
||||
if (message.includes('503') || message.includes('504')) {
|
||||
return '上游模型服务暂时不可用或超时。可以稍后重试,或切换模型/供应商。';
|
||||
}
|
||||
return message;
|
||||
}
|
||||
|
||||
function isErrorStatus(message: string) {
|
||||
return message.includes('失败') || message.includes('为空') || message.startsWith('请先');
|
||||
}
|
||||
|
||||
function parseProviderCapabilities(value?: string | null): ProviderCapabilities {
|
||||
if (!value) return {};
|
||||
try {
|
||||
return JSON.parse(value) as ProviderCapabilities;
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
function buildProviderCapabilities(models: ProviderModel[], selectedModels: string[]) {
|
||||
return JSON.stringify({
|
||||
responses_api: true,
|
||||
images_api: true,
|
||||
chat_completions: true,
|
||||
image_edit: true,
|
||||
image_models: models,
|
||||
selected_image_models: selectedModels,
|
||||
});
|
||||
}
|
||||
|
||||
const defaultProviderForm: ProviderForm = {
|
||||
@@ -66,32 +108,207 @@ const defaultProviderForm: ProviderForm = {
|
||||
enabled: true,
|
||||
};
|
||||
|
||||
const apiKindOptions = [
|
||||
{
|
||||
value: 'openai',
|
||||
label: 'OpenAI 官方',
|
||||
sampleId: 'openai',
|
||||
sampleName: 'OpenAI 官方',
|
||||
baseUrl: 'https://api.openai.com/v1',
|
||||
supported: true,
|
||||
},
|
||||
{
|
||||
value: 'openai-compatible',
|
||||
label: 'OpenAI-compatible / 中转站',
|
||||
sampleId: 'openai-compatible',
|
||||
sampleName: 'OpenAI-compatible / 中转站',
|
||||
baseUrl: 'https://api.openai.com/v1',
|
||||
supported: true,
|
||||
},
|
||||
{
|
||||
value: 'volcengine-ark',
|
||||
label: '火山方舟 / Seedream',
|
||||
sampleId: 'volcengine-seedream',
|
||||
sampleName: '火山方舟 / Seedream',
|
||||
baseUrl: 'https://ark.cn-beijing.volces.com/api/v3',
|
||||
supported: true,
|
||||
},
|
||||
{
|
||||
value: 'dashscope',
|
||||
label: '阿里云百炼 / 通义万相',
|
||||
sampleId: 'dashscope-image',
|
||||
sampleName: '阿里云百炼 / 通义万相',
|
||||
baseUrl: 'https://dashscope.aliyuncs.com/api/v1',
|
||||
supported: true,
|
||||
},
|
||||
{
|
||||
value: 'tencent-hunyuan',
|
||||
label: '腾讯混元图像',
|
||||
sampleId: 'tencent-hunyuan-image',
|
||||
sampleName: '腾讯混元图像',
|
||||
baseUrl: 'https://aiart.tencentcloudapi.com',
|
||||
supported: true,
|
||||
},
|
||||
{
|
||||
value: 'google-gemini',
|
||||
label: 'Google Gemini / Nano Banana',
|
||||
sampleId: 'google-nano-banana',
|
||||
sampleName: 'Google Gemini / Nano Banana',
|
||||
baseUrl: 'https://generativelanguage.googleapis.com/v1beta',
|
||||
supported: true,
|
||||
},
|
||||
{ value: 'stability-ai', label: 'Stability AI(待接入)', supported: false },
|
||||
{ value: 'replicate', label: 'Replicate(待接入)', supported: false },
|
||||
{ value: 'fal-ai', label: 'fal.ai(待接入)', supported: false },
|
||||
];
|
||||
|
||||
const initialGenerationSteps: GenerationStep[] = [
|
||||
{ label: '保存配置', status: 'pending' },
|
||||
{ label: '检查配置', status: 'pending' },
|
||||
{ label: '提交任务', status: 'pending' },
|
||||
{ label: '等待模型返回', status: 'pending' },
|
||||
{ label: '保存到应用文件夹', status: 'pending' },
|
||||
{ label: '更新结果列表', status: 'pending' },
|
||||
];
|
||||
|
||||
const imageModelOptions = ['gpt-image-1', 'gpt-image-1.5', 'gpt-image-2'];
|
||||
const imageSizeOptions = ['1024x1024', '1024x1536', '1536x1024'];
|
||||
const defaultImageModelOptions: string[] = [];
|
||||
const imageQualityOptions = ['auto', 'high', 'medium', 'low'];
|
||||
const imageAspectRatioOptions = [
|
||||
{
|
||||
value: '1:1',
|
||||
defaultSize: '1024x1024',
|
||||
sizes: [
|
||||
{ value: '1024x1024', label: '1024x1024' },
|
||||
{ value: '2048x2048', label: '2048x2048' },
|
||||
{ value: '4096x4096', label: '4096x4096 4K' },
|
||||
],
|
||||
},
|
||||
{
|
||||
value: '1:2',
|
||||
defaultSize: '1024x2048',
|
||||
sizes: [
|
||||
{ value: '1024x2048', label: '1024x2048' },
|
||||
{ value: '1536x3072', label: '1536x3072' },
|
||||
{ value: '2048x4096', label: '2048x4096 4K' },
|
||||
],
|
||||
},
|
||||
{
|
||||
value: '2:1',
|
||||
defaultSize: '2048x1024',
|
||||
sizes: [
|
||||
{ value: '2048x1024', label: '2048x1024' },
|
||||
{ value: '3072x1536', label: '3072x1536' },
|
||||
{ value: '4096x2048', label: '4096x2048 4K' },
|
||||
],
|
||||
},
|
||||
{
|
||||
value: '9:16',
|
||||
defaultSize: '1080x1920',
|
||||
sizes: [
|
||||
{ value: '1080x1920', label: '1080x1920' },
|
||||
{ value: '1440x2560', label: '1440x2560' },
|
||||
{ value: '2160x3840', label: '2160x3840 4K' },
|
||||
],
|
||||
},
|
||||
{
|
||||
value: '16:9',
|
||||
defaultSize: '1920x1080',
|
||||
sizes: [
|
||||
{ value: '1920x1080', label: '1920x1080' },
|
||||
{ value: '2560x1440', label: '2560x1440' },
|
||||
{ value: '3840x2160', label: '3840x2160 4K' },
|
||||
],
|
||||
},
|
||||
{
|
||||
value: '3:4',
|
||||
defaultSize: '1536x2048',
|
||||
sizes: [
|
||||
{ value: '1536x2048', label: '1536x2048' },
|
||||
{ value: '2304x3072', label: '2304x3072' },
|
||||
{ value: '3072x4096', label: '3072x4096 4K' },
|
||||
],
|
||||
},
|
||||
{
|
||||
value: '4:3',
|
||||
defaultSize: '2048x1536',
|
||||
sizes: [
|
||||
{ value: '2048x1536', label: '2048x1536' },
|
||||
{ value: '3072x2304', label: '3072x2304' },
|
||||
{ value: '4096x3072', label: '4096x3072 4K' },
|
||||
],
|
||||
},
|
||||
{
|
||||
value: '名片横版',
|
||||
defaultSize: '1050x600',
|
||||
sizes: [
|
||||
{ value: '1050x600', label: '1050x600' },
|
||||
{ value: '2100x1200', label: '2100x1200' },
|
||||
{ value: '3500x2000', label: '3500x2000' },
|
||||
],
|
||||
},
|
||||
{
|
||||
value: '名片竖版',
|
||||
defaultSize: '600x1050',
|
||||
sizes: [
|
||||
{ value: '600x1050', label: '600x1050' },
|
||||
{ value: '1200x2100', label: '1200x2100' },
|
||||
{ value: '2000x3500', label: '2000x3500' },
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
function getAspectRatioOption(value: string) {
|
||||
return imageAspectRatioOptions.find((option) => option.value === value) ?? imageAspectRatioOptions[0];
|
||||
}
|
||||
|
||||
function getAspectRatioForSize(value: string) {
|
||||
return imageAspectRatioOptions.find((option) => option.sizes.some((size) => size.value === value));
|
||||
}
|
||||
|
||||
function apiKeyPlaceholder(kind: string) {
|
||||
if (kind === 'tencent-hunyuan') return 'SecretId:SecretKey';
|
||||
if (kind === 'google-gemini') return 'Google AI Studio API Key';
|
||||
if (kind === 'dashscope') return 'sk-... 或阿里云百炼 API Key';
|
||||
return 'sk-... 或中转站 key';
|
||||
}
|
||||
|
||||
function providerSettingsTip(kind: string) {
|
||||
if (kind === 'tencent-hunyuan') {
|
||||
return '腾讯云使用 API 3.0 签名,API Key 填 SecretId:SecretKey。';
|
||||
}
|
||||
if (kind === 'dashscope') {
|
||||
return '阿里云百炼 Base URL 默认即可,API Key 填 DashScope Key。';
|
||||
}
|
||||
if (kind === 'google-gemini') {
|
||||
return 'Google Gemini 图像模型也叫 Nano Banana,API Key 填 Google AI Studio Key。';
|
||||
}
|
||||
return 'Base URL 填 API 地址,通常以 /v1 结尾。';
|
||||
}
|
||||
|
||||
function App() {
|
||||
const [providers, setProviders] = useState<ProviderConfig[]>([]);
|
||||
const [providerForm, setProviderForm] = useState<ProviderForm>(defaultProviderForm);
|
||||
const [prompt, setPrompt] = useState('一只赛博朋克风格的橘猫坐在霓虹灯下');
|
||||
const [selectedImageModel, setSelectedImageModel] = useState('gpt-image-2');
|
||||
const [selectedImageModel, setSelectedImageModel] = useState('');
|
||||
const [fetchedImageModels, setFetchedImageModels] = useState<ProviderModel[]>([]);
|
||||
const [selectedImageModels, setSelectedImageModels] = useState<string[]>(defaultImageModelOptions);
|
||||
const [imageAspectRatio, setImageAspectRatio] = useState('1:1');
|
||||
const [imageSize, setImageSize] = useState('1024x1024');
|
||||
const [imageQuality, setImageQuality] = useState('auto');
|
||||
const [status, setStatus] = useState('准备就绪');
|
||||
const [settingsStatus, setSettingsStatus] = useState('');
|
||||
const [isBusy, setIsBusy] = useState(false);
|
||||
const [isFetchingModels, setIsFetchingModels] = useState(false);
|
||||
const [isSettingsOpen, setIsSettingsOpen] = useState(false);
|
||||
const [previewImage, setPreviewImage] = useState<SessionImage | null>(null);
|
||||
const [sessionImages, setSessionImages] = useState<SessionImage[]>([]);
|
||||
const [materialPaths, setMaterialPaths] = useState<string[]>([]);
|
||||
const [generationSteps, setGenerationSteps] = useState<GenerationStep[]>(initialGenerationSteps);
|
||||
const selectedAspectRatioOption = getAspectRatioOption(imageAspectRatio);
|
||||
const visibleImageModelOptions =
|
||||
selectedImageModels.length > 0 ? selectedImageModels : defaultImageModelOptions;
|
||||
const activeProviderName =
|
||||
providers.find((provider) => provider.id === providerForm.id)?.name ?? providerForm.name;
|
||||
const activeMode = materialPaths.length > 0 ? '图像编辑' : '文字生成';
|
||||
|
||||
function setStep(index: number, status: GenerationStep['status']) {
|
||||
setGenerationSteps((steps) =>
|
||||
@@ -113,18 +330,114 @@ function App() {
|
||||
setProviderForm((current) => ({ ...current, [key]: value }));
|
||||
}
|
||||
|
||||
function updateProviderKind(kind: string) {
|
||||
const option = apiKindOptions.find((item) => item.value === kind);
|
||||
setFetchedImageModels([]);
|
||||
setSelectedImageModels([]);
|
||||
setSelectedImageModel('');
|
||||
setSettingsStatus('');
|
||||
setProviderForm((current) => ({
|
||||
...defaultProviderForm,
|
||||
api_key: '',
|
||||
base_url: option?.baseUrl || '',
|
||||
id: option?.sampleId || `${kind}-provider`,
|
||||
kind,
|
||||
name: option?.sampleName || option?.label || current.name,
|
||||
}));
|
||||
}
|
||||
|
||||
function updateAspectRatio(value: string) {
|
||||
const option = getAspectRatioOption(value);
|
||||
setImageAspectRatio(option.value);
|
||||
setImageSize(option.defaultSize);
|
||||
}
|
||||
|
||||
function updateImageSize(value: string) {
|
||||
setImageSize(value);
|
||||
const option = getAspectRatioForSize(value);
|
||||
if (option) {
|
||||
setImageAspectRatio(option.value);
|
||||
}
|
||||
}
|
||||
|
||||
function updateSelectedModelRecords(modelId: string) {
|
||||
setSelectedImageModels((current) => {
|
||||
if (current.includes(modelId)) {
|
||||
if (current.length === 1) return current;
|
||||
const next = current.filter((id) => id !== modelId);
|
||||
if (selectedImageModel === modelId) {
|
||||
setSelectedImageModel(next[0] ?? '');
|
||||
}
|
||||
return next;
|
||||
}
|
||||
|
||||
return [...current, modelId];
|
||||
});
|
||||
}
|
||||
|
||||
async function fetchProviderModels() {
|
||||
setIsFetchingModels(true);
|
||||
setStatus('正在获取图片模型列表...');
|
||||
setSettingsStatus('正在获取图片模型列表...');
|
||||
try {
|
||||
const models = await invoke<ProviderModel[]>('fetch_provider_models', {
|
||||
input: { ...providerForm, image_model: null },
|
||||
});
|
||||
const modelIds = models.map((model) => model.id);
|
||||
setFetchedImageModels(models);
|
||||
if (modelIds.length === 0) {
|
||||
setSelectedImageModels([]);
|
||||
setSelectedImageModel('');
|
||||
setStatus('未从模型列表中识别到图片模型');
|
||||
setSettingsStatus('接口可访问,但没有识别到图片模型');
|
||||
return;
|
||||
}
|
||||
|
||||
setSelectedImageModels((current) => {
|
||||
const kept = current.filter((model) => modelIds.includes(model));
|
||||
const next = [...kept, ...modelIds.filter((model) => !kept.includes(model))];
|
||||
setSelectedImageModel((selected) => (next.includes(selected) ? selected : (next[0] ?? '')));
|
||||
return next;
|
||||
});
|
||||
setStatus(`已获取 ${modelIds.length} 个图片模型`);
|
||||
setSettingsStatus(`已获取 ${modelIds.length} 个图片模型`);
|
||||
} catch (error) {
|
||||
setStatus(`获取模型失败:${formatError(error)}`);
|
||||
setSettingsStatus(`获取模型失败:${formatError(error)}`);
|
||||
} finally {
|
||||
setIsFetchingModels(false);
|
||||
}
|
||||
}
|
||||
|
||||
async function refreshProviders() {
|
||||
const result = await invoke<ProviderConfig[]>('list_providers');
|
||||
setProviders(result);
|
||||
const current = result.find((provider) => provider.id === providerForm.id) ?? result[0];
|
||||
if (current) {
|
||||
const capabilities = parseProviderCapabilities(current.capabilities);
|
||||
const storedModels = capabilities.image_models ?? [];
|
||||
const storedSelectedModels =
|
||||
capabilities.selected_image_models?.filter((model) =>
|
||||
storedModels.some((storedModel) => storedModel.id === model),
|
||||
) ?? [];
|
||||
const nextSelectedModels =
|
||||
storedSelectedModels.length > 0
|
||||
? storedSelectedModels
|
||||
: storedModels.length > 0
|
||||
? storedModels.map((model) => model.id)
|
||||
: defaultImageModelOptions;
|
||||
setFetchedImageModels(storedModels);
|
||||
setSelectedImageModels(nextSelectedModels);
|
||||
setSelectedImageModel((model) =>
|
||||
nextSelectedModels.includes(model) ? model : (nextSelectedModels[0] ?? ''),
|
||||
);
|
||||
setProviderForm((form) => ({
|
||||
...form,
|
||||
id: current.id,
|
||||
name: current.name,
|
||||
kind: current.kind,
|
||||
base_url: current.base_url,
|
||||
api_key: current.api_key ?? form.api_key,
|
||||
api_key: current.api_key ?? '',
|
||||
text_model: current.text_model ?? null,
|
||||
image_model: null,
|
||||
enabled: current.enabled,
|
||||
@@ -133,14 +446,29 @@ function App() {
|
||||
}
|
||||
|
||||
async function saveProvider() {
|
||||
if (!providerForm.api_key.trim()) {
|
||||
setStatus('请先填写 API Key,再保存配置');
|
||||
setSettingsStatus('请先填写 API Key,再保存配置');
|
||||
return;
|
||||
}
|
||||
|
||||
setIsBusy(true);
|
||||
setStatus('正在保存配置...');
|
||||
setSettingsStatus('正在保存配置...');
|
||||
try {
|
||||
await invoke('upsert_provider', { input: { ...providerForm, image_model: null } });
|
||||
await invoke('upsert_provider', {
|
||||
input: {
|
||||
...providerForm,
|
||||
image_model: null,
|
||||
capabilities: buildProviderCapabilities(fetchedImageModels, selectedImageModels),
|
||||
},
|
||||
});
|
||||
await refreshProviders();
|
||||
setStatus('配置已保存');
|
||||
setSettingsStatus('配置已保存');
|
||||
} catch (error) {
|
||||
setStatus(`保存失败:${formatError(error)}`);
|
||||
setSettingsStatus(`保存失败:${formatError(error)}`);
|
||||
} finally {
|
||||
setIsBusy(false);
|
||||
}
|
||||
@@ -149,15 +477,21 @@ function App() {
|
||||
async function deleteProvider(id: string) {
|
||||
setIsBusy(true);
|
||||
setStatus('正在删除配置...');
|
||||
setSettingsStatus('正在删除配置...');
|
||||
try {
|
||||
await invoke('delete_provider', { id });
|
||||
await refreshProviders();
|
||||
if (providerForm.id === id) {
|
||||
setProviderForm(defaultProviderForm);
|
||||
setFetchedImageModels([]);
|
||||
setSelectedImageModels([]);
|
||||
setSelectedImageModel('');
|
||||
}
|
||||
setStatus('配置已删除');
|
||||
setSettingsStatus('配置已删除');
|
||||
} catch (error) {
|
||||
setStatus(`删除失败:${formatError(error)}`);
|
||||
setSettingsStatus(`删除失败:${formatError(error)}`);
|
||||
} finally {
|
||||
setIsBusy(false);
|
||||
}
|
||||
@@ -170,21 +504,51 @@ function App() {
|
||||
name: provider.name,
|
||||
kind: provider.kind,
|
||||
base_url: provider.base_url,
|
||||
api_key: provider.api_key ?? form.api_key,
|
||||
api_key: provider.api_key ?? '',
|
||||
text_model: provider.text_model ?? null,
|
||||
image_model: null,
|
||||
enabled: provider.enabled,
|
||||
}));
|
||||
const capabilities = parseProviderCapabilities(provider.capabilities);
|
||||
const storedModels = capabilities.image_models ?? [];
|
||||
const storedSelectedModels =
|
||||
capabilities.selected_image_models?.filter((model) =>
|
||||
storedModels.some((storedModel) => storedModel.id === model),
|
||||
) ?? [];
|
||||
const nextSelectedModels =
|
||||
storedSelectedModels.length > 0
|
||||
? storedSelectedModels
|
||||
: storedModels.length > 0
|
||||
? storedModels.map((model) => model.id)
|
||||
: defaultImageModelOptions;
|
||||
setFetchedImageModels(storedModels);
|
||||
setSelectedImageModels(nextSelectedModels);
|
||||
setSelectedImageModel((model) =>
|
||||
nextSelectedModels.includes(model) ? model : (nextSelectedModels[0] ?? ''),
|
||||
);
|
||||
setStatus('已切换模型配置');
|
||||
setSettingsStatus('');
|
||||
}
|
||||
|
||||
async function generateImage() {
|
||||
if (!selectedImageModel) {
|
||||
setStatus('请先在设置中获取并选择图像模型');
|
||||
return;
|
||||
}
|
||||
if (!providerForm.api_key.trim()) {
|
||||
setStatus('请先在设置中填写并保存 API Key');
|
||||
return;
|
||||
}
|
||||
if (!providers.some((provider) => provider.id === providerForm.id)) {
|
||||
setStatus('请先保存当前模型供应商配置,再开始生成');
|
||||
return;
|
||||
}
|
||||
setIsBusy(true);
|
||||
setGenerationSteps(initialGenerationSteps);
|
||||
setStatus('正在生成图片...');
|
||||
try {
|
||||
startStep(0);
|
||||
await invoke('upsert_provider', { input: { ...providerForm, image_model: null } });
|
||||
await new Promise((resolve) => window.setTimeout(resolve, 80));
|
||||
startStep(1);
|
||||
await new Promise((resolve) => window.setTimeout(resolve, 120));
|
||||
startStep(2);
|
||||
@@ -264,148 +628,192 @@ function App() {
|
||||
|
||||
return (
|
||||
<main className="app-shell">
|
||||
<header className="topbar">
|
||||
<div className="brand">
|
||||
<img className="brand-mark" src={appLogo} alt="Image Draw AI" />
|
||||
<div>
|
||||
<h1>Image Draw AI</h1>
|
||||
<p>图片默认保存到应用数据文件夹</p>
|
||||
</div>
|
||||
<aside className="side-rail">
|
||||
<div className="rail-logo">
|
||||
<img src={appLogo} alt="Image Draw AI" />
|
||||
</div>
|
||||
<div className="topbar-actions">
|
||||
<div className="current-provider">
|
||||
<span>当前模型</span>
|
||||
<strong>{selectedImageModel}</strong>
|
||||
</div>
|
||||
<button className="ghost" onClick={() => setIsSettingsOpen(true)}>设置</button>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<section className="workspace">
|
||||
<aside className="compose-card">
|
||||
<div className="section-heading">
|
||||
<span>创作区</span>
|
||||
<strong>{materialPaths.length > 0 ? '素材生成' : '文字生成'}</strong>
|
||||
</div>
|
||||
|
||||
<label className="field prompt-field">
|
||||
<span>提示词</span>
|
||||
<textarea value={prompt} onChange={(event) => setPrompt(event.target.value)} />
|
||||
</label>
|
||||
|
||||
<div className="material-panel">
|
||||
<div className="material-header">
|
||||
<div>
|
||||
<strong>参考图</strong>
|
||||
<span>{materialPaths.length > 0 ? `${materialPaths.length} 张,图像编辑模式` : '可选,支持多张'}</span>
|
||||
</div>
|
||||
{materialPaths.length > 0 && (
|
||||
<button className="ghost mini" onClick={() => setMaterialPaths([])} disabled={isBusy}>清空</button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="reference-strip">
|
||||
<button className="add-reference-card" onClick={pickMaterialImages} disabled={isBusy}>
|
||||
<span>+</span>
|
||||
<strong>添加参考图</strong>
|
||||
<small>PNG/JPG/WEBP</small>
|
||||
</button>
|
||||
{materialPaths.map((path, index) => (
|
||||
<article className="reference-card" key={path}>
|
||||
<img src={convertFileSrc(path)} alt="素材图片" />
|
||||
<span>{index + 1}</span>
|
||||
<button onClick={() => removeMaterialImage(path)} disabled={isBusy}>×</button>
|
||||
</article>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="params-card">
|
||||
<div className="section-heading">
|
||||
<span>生成参数</span>
|
||||
<strong>基础</strong>
|
||||
</div>
|
||||
<label className="field compact-field">
|
||||
<span>图像模型</span>
|
||||
<select value={selectedImageModel} onChange={(event) => setSelectedImageModel(event.target.value)} disabled={isBusy}>
|
||||
{imageModelOptions.map((model) => (
|
||||
<option key={model} value={model}>{model}</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
<div className="grid two">
|
||||
<label className="field compact-field">
|
||||
<span>分辨率</span>
|
||||
<select value={imageSize} onChange={(event) => setImageSize(event.target.value)} disabled={isBusy}>
|
||||
{imageSizeOptions.map((size) => (
|
||||
<option key={size} value={size}>{size}</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
<label className="field compact-field">
|
||||
<span>质量</span>
|
||||
<select value={imageQuality} onChange={(event) => setImageQuality(event.target.value)} disabled={isBusy}>
|
||||
{imageQualityOptions.map((quality) => (
|
||||
<option key={quality} value={quality}>{quality}</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<button className="generate-button" onClick={generateImage} disabled={isBusy}>
|
||||
{isBusy ? '正在生成...' : '开始生成'}
|
||||
<nav className="rail-nav" aria-label="主导航">
|
||||
<button className="rail-button active" title="生成">
|
||||
<span className="rail-icon"><RobotOutlined /></span>
|
||||
<strong>生成</strong>
|
||||
</button>
|
||||
<button className="rail-button" title="素材" onClick={pickMaterialImages} disabled={isBusy}>
|
||||
<span className="rail-icon"><PictureOutlined /></span>
|
||||
<strong>素材</strong>
|
||||
</button>
|
||||
<button className="rail-button" title="图库" onClick={openGeneratedDir}>
|
||||
<span className="rail-icon"><FolderOpenOutlined /></span>
|
||||
<strong>图库</strong>
|
||||
</button>
|
||||
</nav>
|
||||
<button className="rail-button rail-settings" title="设置" onClick={() => setIsSettingsOpen(true)}>
|
||||
<span className="rail-icon"><SettingOutlined /></span>
|
||||
<strong>设置</strong>
|
||||
</button>
|
||||
</aside>
|
||||
|
||||
{status !== '准备就绪' && <p className="status">{status}</p>}
|
||||
</aside>
|
||||
|
||||
<section className="result-card">
|
||||
<div className="section-heading result-heading">
|
||||
<span>结果区</span>
|
||||
<div className="heading-actions">
|
||||
<strong>本次生成 {sessionImages.length} 张</strong>
|
||||
<button className="ghost mini" onClick={openGeneratedDir}>打开目录</button>
|
||||
<section className="app-stage">
|
||||
<header className="topbar">
|
||||
<div className="brand">
|
||||
<div>
|
||||
<p>Image Draw AI</p>
|
||||
<h1>图像生成工作台</h1>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{sessionImages.length === 0 ? (
|
||||
<div className="empty-state">
|
||||
<div>暂无图片</div>
|
||||
<p>输入提示词,点击开始生成后会显示在这里。</p>
|
||||
<div className="topbar-actions">
|
||||
<div className="current-provider">
|
||||
<span>供应商</span>
|
||||
<strong>{activeProviderName}</strong>
|
||||
</div>
|
||||
) : (
|
||||
<div className="image-grid">
|
||||
{sessionImages.map((image) => (
|
||||
<article className="image-card" key={image.id}>
|
||||
<button className="image-preview-button" onClick={() => setPreviewImage(image)}>
|
||||
<img src={convertFileSrc(image.file_path)} alt={image.prompt} />
|
||||
</button>
|
||||
<div>
|
||||
<strong>{image.created_at}</strong>
|
||||
<p>{image.prompt}</p>
|
||||
<button className="ghost mini" onClick={() => revealImage(image.file_path)}>定位文件</button>
|
||||
<span>{image.file_path}</span>
|
||||
</div>
|
||||
</article>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
<button className="ghost" onClick={() => setIsSettingsOpen(true)}>设置</button>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<div className={`progress-card result-progress ${isBusy ? 'is-loading' : ''}`}>
|
||||
<div className="spinner" aria-hidden="true" />
|
||||
<div className="progress-content">
|
||||
<strong>{isBusy ? '生成中' : '生成流程'}</strong>
|
||||
<ol className="step-list">
|
||||
{generationSteps.map((step) => (
|
||||
<li className={`step ${step.status}`} key={step.label}>
|
||||
<span />
|
||||
{step.label}
|
||||
</li>
|
||||
<section className="workspace">
|
||||
<aside className="compose-card">
|
||||
<div className="section-heading">
|
||||
<div>
|
||||
<span>创作区</span>
|
||||
<strong>{activeMode}</strong>
|
||||
</div>
|
||||
<small>{imageAspectRatio} / {imageSize} / {imageQuality}</small>
|
||||
</div>
|
||||
|
||||
<label className="field prompt-field">
|
||||
<span>提示词</span>
|
||||
<textarea value={prompt} onChange={(event) => setPrompt(event.target.value)} />
|
||||
</label>
|
||||
|
||||
<div className="material-panel">
|
||||
<div className="material-header">
|
||||
<div>
|
||||
<strong>参考图</strong>
|
||||
<span>{materialPaths.length > 0 ? `${materialPaths.length} 张素材` : '未导入'}</span>
|
||||
</div>
|
||||
{materialPaths.length > 0 && (
|
||||
<button className="ghost mini" onClick={() => setMaterialPaths([])} disabled={isBusy}>清空</button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="reference-strip">
|
||||
<button className="add-reference-card" onClick={pickMaterialImages} disabled={isBusy}>
|
||||
<span>+</span>
|
||||
<strong>参考图</strong>
|
||||
<small>PNG/JPG/WEBP</small>
|
||||
</button>
|
||||
{materialPaths.map((path, index) => (
|
||||
<article className="reference-card" key={path}>
|
||||
<img src={convertFileSrc(path)} alt="素材图片" />
|
||||
<span>{index + 1}</span>
|
||||
<button onClick={() => removeMaterialImage(path)} disabled={isBusy}>×</button>
|
||||
</article>
|
||||
))}
|
||||
</ol>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="params-card">
|
||||
<div className="section-heading">
|
||||
<div>
|
||||
<span>生成参数</span>
|
||||
<strong>基础</strong>
|
||||
</div>
|
||||
</div>
|
||||
<div className="params-grid">
|
||||
<label className="field compact-field">
|
||||
<span>图像模型</span>
|
||||
<select value={selectedImageModel} onChange={(event) => setSelectedImageModel(event.target.value)} disabled={isBusy}>
|
||||
{visibleImageModelOptions.length === 0 && <option value="">未获取模型</option>}
|
||||
{visibleImageModelOptions.map((model) => (
|
||||
<option key={model} value={model}>{model}</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
<label className="field compact-field">
|
||||
<span>质量</span>
|
||||
<select value={imageQuality} onChange={(event) => setImageQuality(event.target.value)} disabled={isBusy}>
|
||||
{imageQualityOptions.map((quality) => (
|
||||
<option key={quality} value={quality}>{quality}</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
<label className="field compact-field">
|
||||
<span>比例</span>
|
||||
<select value={imageAspectRatio} onChange={(event) => updateAspectRatio(event.target.value)} disabled={isBusy}>
|
||||
{imageAspectRatioOptions.map((option) => (
|
||||
<option key={option.value} value={option.value}>{option.value}</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
<label className="field compact-field">
|
||||
<span>分辨率</span>
|
||||
<select value={imageSize} onChange={(event) => updateImageSize(event.target.value)} disabled={isBusy}>
|
||||
{selectedAspectRatioOption.sizes.map((size) => (
|
||||
<option key={size.value} value={size.value}>{size.label}</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<button className="generate-button" onClick={generateImage} disabled={isBusy}>
|
||||
{isBusy ? '正在生成...' : '开始生成'}
|
||||
</button>
|
||||
|
||||
{status !== '准备就绪' && (
|
||||
<p className={`status ${isErrorStatus(status) ? 'error' : ''}`}>{status}</p>
|
||||
)}
|
||||
</aside>
|
||||
|
||||
<section className="result-card">
|
||||
<div className="section-heading result-heading">
|
||||
<div>
|
||||
<span>结果区</span>
|
||||
<strong>本次生成 {sessionImages.length} 张</strong>
|
||||
</div>
|
||||
<div className="heading-actions">
|
||||
<button className="ghost mini" onClick={openGeneratedDir}>打开目录</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{sessionImages.length === 0 ? (
|
||||
<div className="empty-state">
|
||||
<img src={appLogo} alt="" />
|
||||
<div>等待首张作品</div>
|
||||
<p>{selectedImageModel || '未获取模型'} / {imageSize} / {imageQuality}</p>
|
||||
</div>
|
||||
) : (
|
||||
<div className="image-grid">
|
||||
{sessionImages.map((image) => (
|
||||
<article className="image-card" key={image.id}>
|
||||
<button className="image-preview-button" onClick={() => setPreviewImage(image)}>
|
||||
<img src={convertFileSrc(image.file_path)} alt={image.prompt} />
|
||||
</button>
|
||||
<div>
|
||||
<strong>{image.created_at}</strong>
|
||||
<p>{image.prompt}</p>
|
||||
<button className="ghost mini" onClick={() => revealImage(image.file_path)}>定位文件</button>
|
||||
<span>{image.file_path}</span>
|
||||
</div>
|
||||
</article>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className={`progress-card result-progress ${isBusy ? 'is-loading' : ''}`}>
|
||||
<div className="spinner" aria-hidden="true" />
|
||||
<div className="progress-content">
|
||||
<strong>{isBusy ? '生成中' : '生成流程'}</strong>
|
||||
<ol className="step-list">
|
||||
{generationSteps.map((step) => (
|
||||
<li className={`step ${step.status}`} key={step.label}>
|
||||
<span />
|
||||
{step.label}
|
||||
</li>
|
||||
))}
|
||||
</ol>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
</section>
|
||||
</section>
|
||||
|
||||
@@ -422,28 +830,25 @@ function App() {
|
||||
</div>
|
||||
|
||||
<div className="settings-content">
|
||||
<section className="settings-group">
|
||||
<div className="section-heading">
|
||||
<span>基础信息</span>
|
||||
<strong>配置名称</strong>
|
||||
</div>
|
||||
<div className="grid two">
|
||||
<label className="field">
|
||||
<span>配置 ID</span>
|
||||
<input value={providerForm.id} onChange={(event) => updateProviderForm('id', event.target.value)} />
|
||||
</label>
|
||||
<label className="field">
|
||||
<span>名称</span>
|
||||
<input value={providerForm.name} onChange={(event) => updateProviderForm('name', event.target.value)} />
|
||||
</label>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section className="settings-group">
|
||||
<div className="section-heading">
|
||||
<span>接口信息</span>
|
||||
<strong>中转站 / OpenAI</strong>
|
||||
</div>
|
||||
<div className="settings-form">
|
||||
<label className="field">
|
||||
<span>配置 ID</span>
|
||||
<input value={providerForm.id} onChange={(event) => updateProviderForm('id', event.target.value)} />
|
||||
</label>
|
||||
<label className="field">
|
||||
<span>名称</span>
|
||||
<input value={providerForm.name} onChange={(event) => updateProviderForm('name', event.target.value)} />
|
||||
</label>
|
||||
<label className="field">
|
||||
<span>API 分类</span>
|
||||
<select value={providerForm.kind} onChange={(event) => updateProviderKind(event.target.value)}>
|
||||
{apiKindOptions.map((option) => (
|
||||
<option key={option.value} value={option.value} disabled={!option.supported}>
|
||||
{option.label}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
<label className="field">
|
||||
<span>Base URL</span>
|
||||
<input
|
||||
@@ -451,29 +856,68 @@ function App() {
|
||||
onChange={(event) => updateProviderForm('base_url', event.target.value)}
|
||||
placeholder="https://api.openai.com/v1"
|
||||
/>
|
||||
<small>填写 API 地址,不是中转站网页地址;通常以 /v1 结尾。</small>
|
||||
</label>
|
||||
|
||||
<label className="field">
|
||||
<span>API Key</span>
|
||||
<input
|
||||
value={providerForm.api_key}
|
||||
onChange={(event) => updateProviderForm('api_key', event.target.value)}
|
||||
placeholder="sk-... 或中转站 key"
|
||||
placeholder={apiKeyPlaceholder(providerForm.kind)}
|
||||
type="password"
|
||||
/>
|
||||
</label>
|
||||
</section>
|
||||
<p className="settings-tip">{providerSettingsTip(providerForm.kind)}</p>
|
||||
</div>
|
||||
|
||||
<div className="drawer-actions">
|
||||
<button onClick={saveProvider} disabled={isBusy}>保存配置</button>
|
||||
<button className="ghost" onClick={refreshProviders} disabled={isBusy}>刷新</button>
|
||||
<button className="ghost" onClick={fetchProviderModels} disabled={isBusy || isFetchingModels}>
|
||||
{isFetchingModels ? '获取中...' : '获取模型'}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{settingsStatus && (
|
||||
<p className={`settings-status ${isErrorStatus(settingsStatus) ? 'error' : ''}`}>
|
||||
{settingsStatus}
|
||||
</p>
|
||||
)}
|
||||
|
||||
<div className="model-list-panel">
|
||||
<div className="section-heading">
|
||||
<span>模型列表</span>
|
||||
<strong>{fetchedImageModels.length} 个图片模型</strong>
|
||||
</div>
|
||||
{fetchedImageModels.length === 0 ? (
|
||||
<p className="muted model-empty">暂无图片模型</p>
|
||||
) : (
|
||||
<ul className="model-list">
|
||||
{fetchedImageModels.map((model) => (
|
||||
<li key={model.id}>
|
||||
<label className="model-record">
|
||||
<input
|
||||
checked={selectedImageModels.includes(model.id)}
|
||||
disabled={isBusy}
|
||||
onChange={() => updateSelectedModelRecords(model.id)}
|
||||
type="checkbox"
|
||||
/>
|
||||
<span>
|
||||
<strong>{model.id}</strong>
|
||||
{model.owned_by && <small>{model.owned_by}</small>}
|
||||
</span>
|
||||
</label>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="saved-providers">
|
||||
<div className="section-heading">
|
||||
<span>已保存</span>
|
||||
<strong>{providers.length} 个配置</strong>
|
||||
<div>
|
||||
<span>已保存</span>
|
||||
<strong>{providers.length} 个配置</strong>
|
||||
</div>
|
||||
<button className="ghost mini" onClick={refreshProviders} disabled={isBusy}>刷新</button>
|
||||
</div>
|
||||
{providers.length === 0 ? (
|
||||
<p className="muted">暂无配置,保存后会出现在这里。</p>
|
||||
|
||||
1323
src/styles.css
1323
src/styles.css
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user