feat: scaffold AI image desktop MVP
This commit is contained in:
32
src-tauri/Cargo.toml
Normal file
32
src-tauri/Cargo.toml
Normal file
@@ -0,0 +1,32 @@
|
||||
[package]
|
||||
name = "image-draw-ai"
|
||||
version = "0.1.0"
|
||||
description = "Cross-platform desktop AI image tool"
|
||||
authors = ["Image Draw AI"]
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
name = "image_draw_ai_lib"
|
||||
crate-type = ["staticlib", "cdylib", "rlib"]
|
||||
|
||||
[build-dependencies]
|
||||
tauri-build = { version = "2", features = [] }
|
||||
|
||||
[dependencies]
|
||||
tauri = { version = "2", features = ["protocol-asset"] }
|
||||
tauri-plugin-opener = "2"
|
||||
tauri-plugin-dialog = "2"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
sqlx = { version = "0.8", features = ["sqlite", "runtime-tokio-rustls", "chrono", "uuid"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
uuid = { version = "1", features = ["v4", "serde"] }
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
reqwest = { version = "0.12", default-features = false, features = ["json", "multipart", "rustls-tls"] }
|
||||
thiserror = "2"
|
||||
async-trait = "0.1"
|
||||
base64 = "0.22"
|
||||
|
||||
[features]
|
||||
default = ["custom-protocol"]
|
||||
custom-protocol = ["tauri/custom-protocol"]
|
||||
3
src-tauri/build.rs
Normal file
3
src-tauri/build.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
fn main() {
|
||||
tauri_build::build()
|
||||
}
|
||||
11
src-tauri/capabilities/default.json
Normal file
11
src-tauri/capabilities/default.json
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"$schema": "../gen/schemas/desktop-schema.json",
|
||||
"identifier": "default",
|
||||
"description": "Default desktop permissions",
|
||||
"windows": ["main"],
|
||||
"permissions": [
|
||||
"core:default",
|
||||
"dialog:allow-open",
|
||||
"opener:default"
|
||||
]
|
||||
}
|
||||
1
src-tauri/gen/schemas/acl-manifests.json
Normal file
1
src-tauri/gen/schemas/acl-manifests.json
Normal file
File diff suppressed because one or more lines are too long
1
src-tauri/gen/schemas/capabilities.json
Normal file
1
src-tauri/gen/schemas/capabilities.json
Normal file
@@ -0,0 +1 @@
|
||||
{"default":{"identifier":"default","description":"Default desktop permissions","local":true,"windows":["main"],"permissions":["core:default","dialog:allow-open","opener:default"]}}
|
||||
2543
src-tauri/gen/schemas/desktop-schema.json
Normal file
2543
src-tauri/gen/schemas/desktop-schema.json
Normal file
File diff suppressed because it is too large
Load Diff
2543
src-tauri/gen/schemas/macOS-schema.json
Normal file
2543
src-tauri/gen/schemas/macOS-schema.json
Normal file
File diff suppressed because it is too large
Load Diff
BIN
src-tauri/icons/icon.png
Normal file
BIN
src-tauri/icons/icon.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 800 B |
68
src-tauri/migrations/001_init.sql
Normal file
68
src-tauri/migrations/001_init.sql
Normal file
@@ -0,0 +1,68 @@
|
||||
CREATE TABLE IF NOT EXISTS providers (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
kind TEXT NOT NULL,
|
||||
base_url TEXT NOT NULL,
|
||||
api_key_encrypted TEXT,
|
||||
text_model TEXT,
|
||||
image_model TEXT,
|
||||
capabilities TEXT NOT NULL DEFAULT '{}',
|
||||
enabled INTEGER NOT NULL DEFAULT 1,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS generation_tasks (
|
||||
id TEXT PRIMARY KEY,
|
||||
provider_id TEXT NOT NULL,
|
||||
task_type TEXT NOT NULL,
|
||||
prompt TEXT NOT NULL,
|
||||
negative_prompt TEXT,
|
||||
model TEXT NOT NULL,
|
||||
size TEXT,
|
||||
quality TEXT,
|
||||
status TEXT NOT NULL,
|
||||
error_message TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
finished_at TEXT,
|
||||
FOREIGN KEY(provider_id) REFERENCES providers(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS image_assets (
|
||||
id TEXT PRIMARY KEY,
|
||||
task_id TEXT,
|
||||
file_path TEXT NOT NULL,
|
||||
thumbnail_path TEXT,
|
||||
mime_type TEXT,
|
||||
width INTEGER,
|
||||
height INTEGER,
|
||||
file_size INTEGER,
|
||||
source_type TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
FOREIGN KEY(task_id) REFERENCES generation_tasks(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS conversations (
|
||||
id TEXT PRIMARY KEY,
|
||||
title TEXT,
|
||||
provider_id TEXT NOT NULL,
|
||||
model TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
FOREIGN KEY(provider_id) REFERENCES providers(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS ai_request_logs (
|
||||
id TEXT PRIMARY KEY,
|
||||
task_id TEXT,
|
||||
provider_id TEXT NOT NULL,
|
||||
endpoint TEXT NOT NULL,
|
||||
request_summary TEXT,
|
||||
response_summary TEXT,
|
||||
status_code INTEGER,
|
||||
latency_ms INTEGER,
|
||||
created_at TEXT NOT NULL,
|
||||
FOREIGN KEY(task_id) REFERENCES generation_tasks(id),
|
||||
FOREIGN KEY(provider_id) REFERENCES providers(id)
|
||||
);
|
||||
2
src-tauri/src/ai/mod.rs
Normal file
2
src-tauri/src/ai/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod openai_compatible;
|
||||
pub mod provider;
|
||||
158
src-tauri/src/ai/openai_compatible.rs
Normal file
158
src-tauri/src/ai/openai_compatible.rs
Normal file
@@ -0,0 +1,158 @@
|
||||
use async_trait::async_trait;
|
||||
use std::{fs, path::Path};
|
||||
|
||||
use reqwest::{multipart, Client};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest, ImageResult};
|
||||
use crate::AppError;
|
||||
|
||||
pub struct OpenAiCompatibleProvider {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
api_key: String,
|
||||
}
|
||||
|
||||
impl OpenAiCompatibleProvider {
|
||||
pub fn new(base_url: String, api_key: String) -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_image_response(response_text: &str) -> Result<ImageResult, AppError> {
|
||||
if response_text.trim_start().starts_with("<!doctype html")
|
||||
|| response_text.trim_start().starts_with("<html")
|
||||
{
|
||||
return Err(AppError::Provider(
|
||||
"provider returned an HTML page, not an API JSON response. Please use the API base URL, usually ending with /v1, not the gateway website URL.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let response_body: ImageResponseBody = serde_json::from_str(response_text).map_err(|error| {
|
||||
AppError::Provider(format!(
|
||||
"failed to decode image response: {error}; response body: {response_text}"
|
||||
))
|
||||
})?;
|
||||
let image_data = response_body
|
||||
.data
|
||||
.into_iter()
|
||||
.find_map(|item| {
|
||||
item.b64_json
|
||||
.map(ImageData::Base64)
|
||||
.or_else(|| item.url.map(ImageData::Url))
|
||||
})
|
||||
.ok_or_else(|| AppError::Provider("image response did not include b64_json or url".to_string()))?;
|
||||
|
||||
Ok(ImageResult {
|
||||
mime_type: "image/png".to_string(),
|
||||
data: image_data,
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ImageRequestBody<'a> {
|
||||
model: &'a str,
|
||||
prompt: &'a str,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
size: Option<&'a str>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
quality: Option<&'a str>,
|
||||
response_format: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ImageResponseBody {
|
||||
data: Vec<ImageResponseItem>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ImageResponseItem {
|
||||
b64_json: Option<String>,
|
||||
url: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AiProvider for OpenAiCompatibleProvider {
|
||||
async fn generate_image(&self, request: ImageGenerateRequest) -> Result<ImageResult, AppError> {
|
||||
let body = ImageRequestBody {
|
||||
model: &request.model,
|
||||
prompt: &request.prompt,
|
||||
size: request.size.as_deref(),
|
||||
quality: request.quality.as_deref(),
|
||||
response_format: "b64_json",
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(format!("{}/images/generations", self.base_url))
|
||||
.bearer_auth(&self.api_key)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let message = response.text().await.unwrap_or_else(|_| "request failed".to_string());
|
||||
return Err(AppError::Provider(format!("image generation failed ({status}): {message}")));
|
||||
}
|
||||
|
||||
let response_text = response.text().await?;
|
||||
parse_image_response(&response_text)
|
||||
}
|
||||
|
||||
async fn edit_image(&self, request: ImageEditRequest) -> Result<ImageResult, AppError> {
|
||||
let mut form = multipart::Form::new()
|
||||
.text("model", request.model)
|
||||
.text("prompt", request.prompt)
|
||||
.text("response_format", "b64_json");
|
||||
|
||||
if let Some(size) = request.size {
|
||||
form = form.text("size", size);
|
||||
}
|
||||
if let Some(quality) = request.quality {
|
||||
form = form.text("quality", quality);
|
||||
}
|
||||
|
||||
for image_path in request.image_paths {
|
||||
let bytes = fs::read(&image_path)?;
|
||||
let file_name = Path::new(&image_path)
|
||||
.file_name()
|
||||
.and_then(|name| name.to_str())
|
||||
.unwrap_or("image.png")
|
||||
.to_string();
|
||||
let mime = match Path::new(&image_path)
|
||||
.extension()
|
||||
.and_then(|extension| extension.to_str())
|
||||
.map(|extension| extension.to_ascii_lowercase())
|
||||
.as_deref()
|
||||
{
|
||||
Some("jpg" | "jpeg") => "image/jpeg",
|
||||
Some("webp") => "image/webp",
|
||||
_ => "image/png",
|
||||
};
|
||||
let part = multipart::Part::bytes(bytes).file_name(file_name).mime_str(mime)?;
|
||||
form = form.part("image[]", part);
|
||||
}
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(format!("{}/images/edits", self.base_url))
|
||||
.bearer_auth(&self.api_key)
|
||||
.multipart(form)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let message = response.text().await.unwrap_or_else(|_| "request failed".to_string());
|
||||
return Err(AppError::Provider(format!("image edit failed ({status}): {message}")));
|
||||
}
|
||||
|
||||
let response_text = response.text().await?;
|
||||
parse_image_response(&response_text)
|
||||
}
|
||||
}
|
||||
39
src-tauri/src/ai/provider.rs
Normal file
39
src-tauri/src/ai/provider.rs
Normal file
@@ -0,0 +1,39 @@
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::AppError;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ImageGenerateRequest {
|
||||
pub prompt: String,
|
||||
pub model: String,
|
||||
pub size: Option<String>,
|
||||
pub quality: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ImageEditRequest {
|
||||
pub prompt: String,
|
||||
pub model: String,
|
||||
pub size: Option<String>,
|
||||
pub quality: Option<String>,
|
||||
pub image_paths: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ImageResult {
|
||||
pub mime_type: String,
|
||||
pub data: ImageData,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ImageData {
|
||||
Base64(String),
|
||||
Url(String),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait AiProvider: Send + Sync {
|
||||
async fn generate_image(&self, request: ImageGenerateRequest) -> Result<ImageResult, AppError>;
|
||||
async fn edit_image(&self, request: ImageEditRequest) -> Result<ImageResult, AppError>;
|
||||
}
|
||||
18
src-tauri/src/commands/dialog.rs
Normal file
18
src-tauri/src/commands/dialog.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
use tauri_plugin_dialog::DialogExt;
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn pick_material_images(app: tauri::AppHandle) -> Result<Vec<String>, String> {
|
||||
let files = app
|
||||
.dialog()
|
||||
.file()
|
||||
.add_filter("Images", &["png", "jpg", "jpeg", "webp"])
|
||||
.blocking_pick_files();
|
||||
|
||||
let paths = files
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.filter_map(|file_path| file_path.as_path().map(|path| path.to_string_lossy().to_string()))
|
||||
.collect();
|
||||
|
||||
Ok(paths)
|
||||
}
|
||||
129
src-tauri/src/commands/generation.rs
Normal file
129
src-tauri/src/commands/generation.rs
Normal file
@@ -0,0 +1,129 @@
|
||||
use tauri::{AppHandle, State};
|
||||
|
||||
use crate::{
|
||||
ai::{
|
||||
openai_compatible::OpenAiCompatibleProvider,
|
||||
provider::{AiProvider, ImageData, ImageEditRequest, ImageGenerateRequest},
|
||||
},
|
||||
db::{
|
||||
models::{CreateGenerationTaskInput, GenerateImageInput, GenerateImageOutput, GenerationTask},
|
||||
repository,
|
||||
},
|
||||
state::AppState,
|
||||
storage,
|
||||
AppError,
|
||||
};
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn create_generation_task(
|
||||
state: State<'_, AppState>,
|
||||
input: CreateGenerationTaskInput,
|
||||
) -> Result<GenerationTask, AppError> {
|
||||
repository::create_generation_task(&state.db, input).await
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn generate_image(
|
||||
app: AppHandle,
|
||||
state: State<'_, AppState>,
|
||||
input: GenerateImageInput,
|
||||
) -> Result<GenerateImageOutput, AppError> {
|
||||
let provider = repository::get_provider_secret(&state.db, &input.provider_id).await?;
|
||||
if !provider.enabled {
|
||||
return Err(AppError::Provider(format!("provider {} is disabled", provider.name)));
|
||||
}
|
||||
|
||||
if provider.kind != "openai-compatible" {
|
||||
return Err(AppError::Provider(format!(
|
||||
"provider kind {} is not supported yet",
|
||||
provider.kind
|
||||
)));
|
||||
}
|
||||
|
||||
if !provider.base_url.trim_end_matches('/').ends_with("/v1") {
|
||||
return Err(AppError::Provider(
|
||||
"Base URL 看起来不是 API 地址。OpenAI-compatible 地址通常需要以 /v1 结尾,例如 https://api.openai.com/v1 或 https://你的中转站域名/v1".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let api_key = provider
|
||||
.api_key_encrypted
|
||||
.clone()
|
||||
.filter(|key| !key.trim().is_empty())
|
||||
.ok_or_else(|| AppError::Provider("provider api_key is empty".to_string()))?;
|
||||
let model = input
|
||||
.model
|
||||
.clone()
|
||||
.or_else(|| provider.image_model.clone())
|
||||
.ok_or_else(|| AppError::Provider("image model is empty".to_string()))?;
|
||||
|
||||
let task = repository::create_generation_task(
|
||||
&state.db,
|
||||
CreateGenerationTaskInput {
|
||||
provider_id: provider.id.clone(),
|
||||
task_type: if input.image_paths.is_empty() {
|
||||
"text_to_image".to_string()
|
||||
} else {
|
||||
"image_edit".to_string()
|
||||
},
|
||||
prompt: input.prompt.clone(),
|
||||
model: model.clone(),
|
||||
size: input.size.clone(),
|
||||
quality: input.quality.clone(),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let ai_provider = OpenAiCompatibleProvider::new(provider.base_url, api_key);
|
||||
let image_result = match if input.image_paths.is_empty() {
|
||||
ai_provider
|
||||
.generate_image(ImageGenerateRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
})
|
||||
.await
|
||||
} else {
|
||||
ai_provider
|
||||
.edit_image(ImageEditRequest {
|
||||
prompt: input.prompt,
|
||||
model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
image_paths: input.image_paths,
|
||||
})
|
||||
.await
|
||||
} {
|
||||
Ok(result) => result,
|
||||
Err(error) => {
|
||||
repository::mark_generation_task_failed(&state.db, &task.id, &error.to_string()).await?;
|
||||
return Err(error);
|
||||
}
|
||||
};
|
||||
|
||||
let image_bytes = match &image_result.data {
|
||||
ImageData::Base64(data_base64) => storage::decode_base64_image(data_base64)?,
|
||||
ImageData::Url(url) => reqwest::get(url).await?.bytes().await?.to_vec(),
|
||||
};
|
||||
let stored_image = storage::save_generated_image_bytes(&app, &image_bytes, &image_result.mime_type)?;
|
||||
let file_path = stored_image.file_path.to_string_lossy().to_string();
|
||||
let asset = repository::create_image_asset(
|
||||
&state.db,
|
||||
&task.id,
|
||||
&file_path,
|
||||
&image_result.mime_type,
|
||||
stored_image.file_size,
|
||||
"generated",
|
||||
)
|
||||
.await?;
|
||||
repository::mark_generation_task_completed(&state.db, &task.id).await?;
|
||||
|
||||
Ok(GenerateImageOutput {
|
||||
task: GenerationTask {
|
||||
status: "completed".to_string(),
|
||||
..task
|
||||
},
|
||||
asset,
|
||||
})
|
||||
}
|
||||
3
src-tauri/src/commands/mod.rs
Normal file
3
src-tauri/src/commands/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod dialog;
|
||||
pub mod generation;
|
||||
pub mod provider;
|
||||
22
src-tauri/src/commands/provider.rs
Normal file
22
src-tauri/src/commands/provider.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
use tauri::State;
|
||||
|
||||
use crate::{
|
||||
db::{models::UpsertProviderInput, repository},
|
||||
state::AppState,
|
||||
AppError,
|
||||
};
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn list_providers(state: State<'_, AppState>) -> Result<Vec<crate::db::models::ProviderConfig>, AppError> {
|
||||
repository::list_providers(&state.db).await
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn upsert_provider(state: State<'_, AppState>, input: UpsertProviderInput) -> Result<(), AppError> {
|
||||
repository::upsert_provider(&state.db, input).await
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn delete_provider(state: State<'_, AppState>, id: String) -> Result<(), AppError> {
|
||||
repository::delete_provider(&state.db, &id).await
|
||||
}
|
||||
29
src-tauri/src/db/mod.rs
Normal file
29
src-tauri/src/db/mod.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use std::fs;
|
||||
|
||||
use sqlx::{sqlite::SqliteConnectOptions, SqlitePool};
|
||||
use tauri::{AppHandle, Manager};
|
||||
|
||||
use crate::AppError;
|
||||
|
||||
pub mod models;
|
||||
pub mod repository;
|
||||
|
||||
pub async fn init(app: &AppHandle) -> Result<SqlitePool, AppError> {
|
||||
let app_data_dir = app.path().app_data_dir()?;
|
||||
fs::create_dir_all(&app_data_dir)?;
|
||||
|
||||
let database_path = app_data_dir.join("image_draw_ai.sqlite");
|
||||
let options = SqliteConnectOptions::new()
|
||||
.filename(database_path)
|
||||
.create_if_missing(true);
|
||||
let pool = SqlitePool::connect_with(options).await?;
|
||||
|
||||
for statement in include_str!("../../migrations/001_init.sql").split(';') {
|
||||
let statement = statement.trim();
|
||||
if !statement.is_empty() {
|
||||
sqlx::query(statement).execute(&pool).await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(pool)
|
||||
}
|
||||
87
src-tauri/src/db/models.rs
Normal file
87
src-tauri/src/db/models.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::FromRow;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, FromRow)]
|
||||
pub struct ProviderConfig {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub kind: String,
|
||||
pub base_url: String,
|
||||
pub api_key: Option<String>,
|
||||
pub text_model: Option<String>,
|
||||
pub image_model: Option<String>,
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, FromRow)]
|
||||
pub struct ProviderSecret {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub kind: String,
|
||||
pub base_url: String,
|
||||
pub api_key_encrypted: Option<String>,
|
||||
pub image_model: Option<String>,
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct UpsertProviderInput {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub kind: String,
|
||||
pub base_url: String,
|
||||
pub api_key: Option<String>,
|
||||
pub text_model: Option<String>,
|
||||
pub image_model: Option<String>,
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, FromRow)]
|
||||
pub struct GenerationTask {
|
||||
pub id: String,
|
||||
pub provider_id: String,
|
||||
pub task_type: String,
|
||||
pub prompt: String,
|
||||
pub model: String,
|
||||
pub size: Option<String>,
|
||||
pub quality: Option<String>,
|
||||
pub status: String,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct CreateGenerationTaskInput {
|
||||
pub provider_id: String,
|
||||
pub task_type: String,
|
||||
pub prompt: String,
|
||||
pub model: String,
|
||||
pub size: Option<String>,
|
||||
pub quality: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, FromRow)]
|
||||
pub struct ImageAsset {
|
||||
pub id: String,
|
||||
pub task_id: Option<String>,
|
||||
pub file_path: String,
|
||||
pub mime_type: Option<String>,
|
||||
pub file_size: Option<i64>,
|
||||
pub source_type: String,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct GenerateImageInput {
|
||||
pub provider_id: String,
|
||||
pub prompt: String,
|
||||
pub model: Option<String>,
|
||||
pub size: Option<String>,
|
||||
pub quality: Option<String>,
|
||||
pub image_paths: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct GenerateImageOutput {
|
||||
pub task: GenerationTask,
|
||||
pub asset: ImageAsset,
|
||||
}
|
||||
227
src-tauri/src/db/repository.rs
Normal file
227
src-tauri/src/db/repository.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
use chrono::Utc;
|
||||
use sqlx::SqlitePool;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::models::{
|
||||
CreateGenerationTaskInput, GenerationTask, ImageAsset, ProviderConfig, ProviderSecret,
|
||||
UpsertProviderInput,
|
||||
};
|
||||
use crate::AppError;
|
||||
|
||||
pub async fn list_providers(pool: &SqlitePool) -> Result<Vec<ProviderConfig>, AppError> {
|
||||
let providers = sqlx::query_as::<_, ProviderConfig>(
|
||||
r#"
|
||||
SELECT
|
||||
id,
|
||||
name,
|
||||
kind,
|
||||
base_url,
|
||||
api_key_encrypted AS api_key,
|
||||
text_model,
|
||||
image_model,
|
||||
enabled != 0 AS enabled
|
||||
FROM providers
|
||||
ORDER BY updated_at DESC
|
||||
"#,
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
Ok(providers)
|
||||
}
|
||||
|
||||
pub async fn upsert_provider(pool: &SqlitePool, input: UpsertProviderInput) -> Result<(), AppError> {
|
||||
let now = Utc::now().to_rfc3339();
|
||||
|
||||
let existing_api_key: Option<String> = sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT api_key_encrypted
|
||||
FROM providers
|
||||
WHERE id = ?1
|
||||
"#,
|
||||
)
|
||||
.bind(&input.id)
|
||||
.fetch_optional(pool)
|
||||
.await?
|
||||
.flatten();
|
||||
let api_key = input
|
||||
.api_key
|
||||
.filter(|key| !key.trim().is_empty())
|
||||
.or(existing_api_key);
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO providers (
|
||||
id, name, kind, base_url, api_key_encrypted, text_model, image_model,
|
||||
capabilities, enabled, created_at, updated_at
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?10)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
name = excluded.name,
|
||||
kind = excluded.kind,
|
||||
base_url = excluded.base_url,
|
||||
api_key_encrypted = excluded.api_key_encrypted,
|
||||
text_model = excluded.text_model,
|
||||
image_model = excluded.image_model,
|
||||
enabled = excluded.enabled,
|
||||
updated_at = excluded.updated_at
|
||||
"#,
|
||||
)
|
||||
.bind(input.id)
|
||||
.bind(input.name)
|
||||
.bind(input.kind)
|
||||
.bind(input.base_url)
|
||||
.bind(api_key)
|
||||
.bind(input.text_model)
|
||||
.bind(input.image_model)
|
||||
.bind(r#"{"responses_api":true,"images_api":true,"chat_completions":true,"image_edit":true}"#)
|
||||
.bind(input.enabled)
|
||||
.bind(now)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_provider_secret(pool: &SqlitePool, id: &str) -> Result<ProviderSecret, AppError> {
|
||||
let provider = sqlx::query_as::<_, ProviderSecret>(
|
||||
r#"
|
||||
SELECT id, name, kind, base_url, api_key_encrypted, image_model, enabled != 0 AS enabled
|
||||
FROM providers
|
||||
WHERE id = ?1
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
Ok(provider)
|
||||
}
|
||||
|
||||
pub async fn delete_provider(pool: &SqlitePool, id: &str) -> Result<(), AppError> {
|
||||
sqlx::query("DELETE FROM providers WHERE id = ?1")
|
||||
.bind(id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn create_generation_task(
|
||||
pool: &SqlitePool,
|
||||
input: CreateGenerationTaskInput,
|
||||
) -> Result<GenerationTask, AppError> {
|
||||
let id = Uuid::new_v4().to_string();
|
||||
let now = Utc::now().to_rfc3339();
|
||||
let status = "pending".to_string();
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO generation_tasks (
|
||||
id, provider_id, task_type, prompt, model, size, quality, status, created_at, updated_at
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?9)
|
||||
"#,
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(&input.provider_id)
|
||||
.bind(&input.task_type)
|
||||
.bind(&input.prompt)
|
||||
.bind(&input.model)
|
||||
.bind(&input.size)
|
||||
.bind(&input.quality)
|
||||
.bind(&status)
|
||||
.bind(&now)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(GenerationTask {
|
||||
id,
|
||||
provider_id: input.provider_id,
|
||||
task_type: input.task_type,
|
||||
prompt: input.prompt,
|
||||
model: input.model,
|
||||
size: input.size,
|
||||
quality: input.quality,
|
||||
status,
|
||||
created_at: now,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn mark_generation_task_completed(pool: &SqlitePool, id: &str) -> Result<(), AppError> {
|
||||
let now = Utc::now().to_rfc3339();
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE generation_tasks
|
||||
SET status = 'completed', updated_at = ?2, finished_at = ?2
|
||||
WHERE id = ?1
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.bind(now)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn mark_generation_task_failed(
|
||||
pool: &SqlitePool,
|
||||
id: &str,
|
||||
error_message: &str,
|
||||
) -> Result<(), AppError> {
|
||||
let now = Utc::now().to_rfc3339();
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE generation_tasks
|
||||
SET status = 'failed', error_message = ?2, updated_at = ?3, finished_at = ?3
|
||||
WHERE id = ?1
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.bind(error_message)
|
||||
.bind(now)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn create_image_asset(
|
||||
pool: &SqlitePool,
|
||||
task_id: &str,
|
||||
file_path: &str,
|
||||
mime_type: &str,
|
||||
file_size: i64,
|
||||
source_type: &str,
|
||||
) -> Result<ImageAsset, AppError> {
|
||||
let id = Uuid::new_v4().to_string();
|
||||
let now = Utc::now().to_rfc3339();
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO image_assets (
|
||||
id, task_id, file_path, mime_type, file_size, source_type, created_at
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
|
||||
"#,
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(task_id)
|
||||
.bind(file_path)
|
||||
.bind(mime_type)
|
||||
.bind(file_size)
|
||||
.bind(source_type)
|
||||
.bind(&now)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(ImageAsset {
|
||||
id,
|
||||
task_id: Some(task_id.to_string()),
|
||||
file_path: file_path.to_string(),
|
||||
mime_type: Some(mime_type.to_string()),
|
||||
file_size: Some(file_size),
|
||||
source_type: source_type.to_string(),
|
||||
created_at: now,
|
||||
})
|
||||
}
|
||||
62
src-tauri/src/lib.rs
Normal file
62
src-tauri/src/lib.rs
Normal file
@@ -0,0 +1,62 @@
|
||||
mod ai;
|
||||
mod commands;
|
||||
mod db;
|
||||
mod storage;
|
||||
mod state;
|
||||
|
||||
use serde::Serialize;
|
||||
use state::AppState;
|
||||
use tauri::Manager;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum AppError {
|
||||
#[error("database error: {0}")]
|
||||
Database(#[from] sqlx::Error),
|
||||
#[error("io error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
#[error("tauri path error: {0}")]
|
||||
Path(#[from] tauri::Error),
|
||||
#[error("http error: {0}")]
|
||||
Http(#[from] reqwest::Error),
|
||||
#[error("mime error: {0}")]
|
||||
Mime(#[from] reqwest::header::InvalidHeaderValue),
|
||||
#[error("base64 decode error: {0}")]
|
||||
Base64(#[from] base64::DecodeError),
|
||||
#[error("provider error: {0}")]
|
||||
Provider(String),
|
||||
}
|
||||
|
||||
impl Serialize for AppError {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(&self.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(mobile, tauri::mobile_entry_point)]
|
||||
pub fn run() {
|
||||
tauri::Builder::default()
|
||||
.plugin(tauri_plugin_opener::init())
|
||||
.plugin(tauri_plugin_dialog::init())
|
||||
.setup(|app| {
|
||||
let app_handle = app.handle().clone();
|
||||
tauri::async_runtime::block_on(async move {
|
||||
let db = db::init(&app_handle).await?;
|
||||
app_handle.manage(AppState { db });
|
||||
Ok::<(), AppError>(())
|
||||
})?;
|
||||
Ok(())
|
||||
})
|
||||
.invoke_handler(tauri::generate_handler![
|
||||
commands::provider::list_providers,
|
||||
commands::provider::upsert_provider,
|
||||
commands::provider::delete_provider,
|
||||
commands::dialog::pick_material_images,
|
||||
commands::generation::create_generation_task,
|
||||
commands::generation::generate_image,
|
||||
])
|
||||
.run(tauri::generate_context!())
|
||||
.expect("error while running tauri application");
|
||||
}
|
||||
3
src-tauri/src/main.rs
Normal file
3
src-tauri/src/main.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
fn main() {
|
||||
image_draw_ai_lib::run();
|
||||
}
|
||||
6
src-tauri/src/state.rs
Normal file
6
src-tauri/src/state.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
use sqlx::SqlitePool;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub db: SqlitePool,
|
||||
}
|
||||
39
src-tauri/src/storage/mod.rs
Normal file
39
src-tauri/src/storage/mod.rs
Normal file
@@ -0,0 +1,39 @@
|
||||
use std::{fs, path::PathBuf};
|
||||
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use tauri::{AppHandle, Manager};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AppError;
|
||||
|
||||
pub struct StoredImage {
|
||||
pub file_path: PathBuf,
|
||||
pub file_size: i64,
|
||||
}
|
||||
|
||||
pub fn save_generated_image_bytes(
|
||||
app: &AppHandle,
|
||||
bytes: &[u8],
|
||||
mime_type: &str,
|
||||
) -> Result<StoredImage, AppError> {
|
||||
let images_dir = app.path().app_data_dir()?.join("images").join("generated");
|
||||
fs::create_dir_all(&images_dir)?;
|
||||
|
||||
let extension = match mime_type {
|
||||
"image/jpeg" => "jpg",
|
||||
"image/webp" => "webp",
|
||||
_ => "png",
|
||||
};
|
||||
let file_path = images_dir.join(format!("{}.{}", Uuid::new_v4(), extension));
|
||||
let file_size = i64::try_from(bytes.len()).unwrap_or(i64::MAX);
|
||||
fs::write(&file_path, bytes)?;
|
||||
|
||||
Ok(StoredImage {
|
||||
file_path,
|
||||
file_size,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn decode_base64_image(data_base64: &str) -> Result<Vec<u8>, AppError> {
|
||||
Ok(general_purpose::STANDARD.decode(data_base64)?)
|
||||
}
|
||||
35
src-tauri/tauri.conf.json
Normal file
35
src-tauri/tauri.conf.json
Normal file
@@ -0,0 +1,35 @@
|
||||
{
|
||||
"$schema": "https://schema.tauri.app/config/2",
|
||||
"productName": "Image Draw AI",
|
||||
"version": "0.1.0",
|
||||
"identifier": "com.imagedraw.ai",
|
||||
"build": {
|
||||
"beforeDevCommand": "pnpm dev",
|
||||
"devUrl": "http://localhost:1420",
|
||||
"beforeBuildCommand": "pnpm web:build",
|
||||
"frontendDist": "../dist"
|
||||
},
|
||||
"app": {
|
||||
"windows": [
|
||||
{
|
||||
"title": "Image Draw AI",
|
||||
"width": 1100,
|
||||
"height": 760,
|
||||
"minWidth": 900,
|
||||
"minHeight": 620
|
||||
}
|
||||
],
|
||||
"security": {
|
||||
"csp": null,
|
||||
"assetProtocol": {
|
||||
"enable": true,
|
||||
"scope": ["$APPDATA/**"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"bundle": {
|
||||
"active": true,
|
||||
"targets": "all",
|
||||
"icon": []
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user