return type fixes (#9)
Reviewed-on: #9 Co-authored-by: Jeffrey Smith <jasafpro@gmail.com> Co-committed-by: Jeffrey Smith <jasafpro@gmail.com>
This commit is contained in:
149
venice/image.py
149
venice/image.py
@@ -15,12 +15,11 @@ description: |
|
||||
and attach images to chat via event emitter for inline display.
|
||||
|
||||
Re-entrant safe: Multiple concurrent calls accumulate images correctly.
|
||||
|
||||
v1.7.0: Added VeniceImage namespace class for helper functions to avoid
|
||||
method collisions with Open WebUI framework introspection.
|
||||
v1.6.0: Added UserValves for SAFE_MODE and HIDE_WATERMARK with proper
|
||||
admin/user override logic.
|
||||
changelog:
|
||||
1.7.1:
|
||||
- changed return type from string to dictionary, mirroring the default tools behavior
|
||||
- fixed issues with user valve overrides - Watermake and Safe Mode
|
||||
- status message will display either [SFW] or [NSFW] depending on flag not content
|
||||
1.7.0:
|
||||
- Added VeniceImage namespace class for helper functions
|
||||
- Moved get_api_key, parse_venice_image_response to VeniceImage namespace
|
||||
@@ -42,7 +41,7 @@ import httpx
|
||||
class VeniceImage:
|
||||
"""
|
||||
Namespaced helpers for Venice image operations.
|
||||
|
||||
|
||||
Using a separate class prevents Open WebUI framework introspection
|
||||
from colliding with tool methods that have generic names like _get_api_key.
|
||||
"""
|
||||
@@ -96,7 +95,7 @@ class Tools:
|
||||
|
||||
class UserValves(BaseModel):
|
||||
VENICE_API_KEY: str = Field(default="", description="Your Venice.ai API key (overrides admin)")
|
||||
SAFE_MODE: bool = Field(default=False, description="Enable SFW content filtering")
|
||||
SAFE_MODE: bool = Field(default=True, description="Enable SFW content filtering")
|
||||
HIDE_WATERMARK: bool = Field(default=False, description="Hide Venice.ai watermark")
|
||||
DEFAULT_MODEL: str = Field(default="", description="Your preferred image model")
|
||||
DEFAULT_NEGATIVE_PROMPT: str = Field(default="", description="Default negative prompt")
|
||||
@@ -111,11 +110,21 @@ class Tools:
|
||||
self._lock_init = threading.Lock()
|
||||
self._last_cleanup: float = 0.0
|
||||
|
||||
def _is_safe_mode_enabled(self) -> bool:
|
||||
return self.valves.SAFE_MODE or self.user_valves.SAFE_MODE
|
||||
def _is_safe_mode_enabled(self, __user__: dict = None) -> bool:
|
||||
user_safe_mode = self.user_valves.SAFE_MODE
|
||||
|
||||
def _is_watermark_hidden(self) -> bool:
|
||||
return self.valves.HIDE_WATERMARK or self.user_valves.HIDE_WATERMARK
|
||||
if __user__ and "valves" in __user__:
|
||||
user_safe_mode = __user__["valves"].SAFE_MODE
|
||||
|
||||
return self.valves.SAFE_MODE or user_safe_mode
|
||||
|
||||
def _is_watermark_hidden(self, __user__: dict = None) -> bool:
|
||||
user_hide_watermark = self.user_valves.HIDE_WATERMARK
|
||||
|
||||
if __user__ and "valves" in __user__:
|
||||
user_hide_watermark = __user__["valves"].HIDE_WATERMARK
|
||||
|
||||
return self.valves.HIDE_WATERMARK or user_hide_watermark
|
||||
|
||||
def _get_default_model(self) -> str:
|
||||
return self.user_valves.DEFAULT_MODEL or self.valves.DEFAULT_MODEL
|
||||
@@ -152,20 +161,29 @@ class Tools:
|
||||
|
||||
async def _accumulate_files(self, key: str, new_files: List[dict], __event_emitter__: Callable[[dict], Any] = None):
|
||||
all_files = []
|
||||
|
||||
async with self._get_lock():
|
||||
if key not in self._message_files:
|
||||
self._message_files[key] = {"files": [], "timestamp": time.time()}
|
||||
|
||||
for f in new_files:
|
||||
self._message_files[key]["files"].append(dict(f))
|
||||
|
||||
self._message_files[key]["timestamp"] = time.time()
|
||||
|
||||
all_files = list(self._message_files[key]["files"])
|
||||
|
||||
now = time.time()
|
||||
|
||||
if now - self._last_cleanup > 60:
|
||||
self._last_cleanup = now
|
||||
ttl = self.valves.ACCUMULATOR_TTL
|
||||
|
||||
expired = [k for k, v in self._message_files.items() if now - v.get("timestamp", 0) > ttl]
|
||||
|
||||
for k in expired:
|
||||
del self._message_files[k]
|
||||
|
||||
if all_files and __event_emitter__:
|
||||
await __event_emitter__({"type": "files", "data": {"files": all_files}})
|
||||
|
||||
@@ -290,38 +308,73 @@ class Tools:
|
||||
except Exception as e:
|
||||
return None, f"Fetch error: {type(e).__name__}: {e}"
|
||||
|
||||
async def generate_image(self, prompt: str, model: Optional[str] = None, width: int = 1024, height: int = 1024, negative_prompt: Optional[str] = None, style_preset: Optional[str] = None, variants: int = 1, __request__=None, __user__: dict = None, __metadata__: dict = None, __event_emitter__: Callable[[dict], Any] = None) -> str:
|
||||
async def generate_image(
|
||||
self,
|
||||
prompt: str,
|
||||
model: Optional[str] = None,
|
||||
width: int = 1024,
|
||||
height: int = 1024,
|
||||
negative_prompt: Optional[str] = None,
|
||||
style_preset: Optional[str] = None,
|
||||
variants: int = 1,
|
||||
__request__=None,
|
||||
__user__: dict = None,
|
||||
__metadata__: dict = None,
|
||||
__event_emitter__: Callable[[dict], Any] = None
|
||||
) -> dict:
|
||||
retVal = {
|
||||
"status": "failed",
|
||||
"message": "",
|
||||
"settings": {},
|
||||
"images": [],
|
||||
}
|
||||
venice_key = VeniceImage.get_api_key(self.valves, self.user_valves, __user__)
|
||||
|
||||
if not venice_key:
|
||||
return "Generate Image\nStatus: 0\nError: Venice.ai API key not configured."
|
||||
retVal["message"] = "Error: Venice.ai API key not configured",
|
||||
return retVal
|
||||
|
||||
if not prompt or not prompt.strip():
|
||||
return "Generate Image\nStatus: 0\nError: Prompt is required"
|
||||
retVal["message"] = "Error: Prompt is required",
|
||||
return retVal
|
||||
|
||||
msg_key = self._get_message_key(__metadata__)
|
||||
user_id = __user__.get("id", "default") if __user__ else "default"
|
||||
cooldown = self.valves.COOLDOWN_SECONDS
|
||||
|
||||
if cooldown > 0:
|
||||
now = time.time()
|
||||
last_gen = self._cooldowns.get(user_id, 0)
|
||||
is_reentrant = self._get_accumulated_count(msg_key) > 0
|
||||
|
||||
if not is_reentrant and now - last_gen < cooldown:
|
||||
remaining = cooldown - (now - last_gen)
|
||||
return f"Generate Image\nStatus: 429\nError: Rate limited. Wait {remaining:.1f}s."
|
||||
retVal["message"] = "Error: Rate limited. Wait {remaining:.1f}s.",
|
||||
return retVal
|
||||
|
||||
self._cooldowns[user_id] = now
|
||||
|
||||
model = model or self._get_default_model()
|
||||
safe_mode = self._is_safe_mode_enabled()
|
||||
hide_watermark = self._is_watermark_hidden()
|
||||
safe_mode = self._is_safe_mode_enabled(__user__)
|
||||
hide_watermark = self._is_watermark_hidden(__user__)
|
||||
effective_negative_prompt = negative_prompt or self._get_default_negative_prompt()
|
||||
variants = max(1, min(4, variants))
|
||||
width = max(512, min(1280, width))
|
||||
height = max(512, min(1280, height))
|
||||
existing_count = self._get_accumulated_count(msg_key)
|
||||
|
||||
if __event_emitter__:
|
||||
status_msg = f"Generating {variants} image{'s' if variants > 1 else ''} with {model}"
|
||||
if existing_count > 0:
|
||||
status_msg += f" (adding to {existing_count} existing)"
|
||||
|
||||
if safe_mode:
|
||||
status_msg += " [SFW]"
|
||||
else:
|
||||
status_msg += " [NSFW]"
|
||||
|
||||
await __event_emitter__({"type": "status", "data": {"description": f"{status_msg}...", "done": False}})
|
||||
|
||||
payload = {"model": model, "prompt": prompt, "width": width, "height": height, "safe_mode": safe_mode, "hide_watermark": hide_watermark, "return_binary": False, "variants": variants}
|
||||
if effective_negative_prompt:
|
||||
payload["negative_prompt"] = effective_negative_prompt
|
||||
@@ -329,6 +382,7 @@ class Tools:
|
||||
payload["style_preset"] = style_preset
|
||||
retried = False
|
||||
dropped_params = []
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=float(self.valves.GENERATION_TIMEOUT)) as client:
|
||||
response = await client.post("https://api.venice.ai/api/v1/image/generate", headers={"Authorization": f"Bearer {venice_key}", "Content-Type": "application/json"}, json=payload)
|
||||
@@ -349,65 +403,86 @@ class Tools:
|
||||
except httpx.HTTPStatusError as e:
|
||||
if __event_emitter__:
|
||||
await __event_emitter__({"type": "status", "data": {"done": True}})
|
||||
return f"Generate Image\nStatus: {e.response.status_code}\nError: {e.response.text[:200]}"
|
||||
|
||||
retVal["message"] = f"Status: {e.response.status_code} Error: {e.response.text[:200]}",
|
||||
return retVal["message"]
|
||||
except httpx.TimeoutException:
|
||||
if __event_emitter__:
|
||||
await __event_emitter__({"type": "status", "data": {"done": True}})
|
||||
return f"Generate Image\nStatus: 408\nError: Timed out after {self.valves.GENERATION_TIMEOUT}s"
|
||||
|
||||
retVal["message"] = f"Status: 408\nError: Timed out after {self.valves.GENERATION_TIMEOUT}s",
|
||||
return retVal
|
||||
except Exception as e:
|
||||
if __event_emitter__:
|
||||
await __event_emitter__({"type": "status", "data": {"done": True}})
|
||||
return f"Generate Image\nStatus: 0\nError: {type(e).__name__}: {e}"
|
||||
|
||||
retVal["message"] = f"Status: 0\nError: {type(e).__name__}: {e}",
|
||||
return retVal
|
||||
|
||||
images = result.get("images", [])
|
||||
|
||||
if not images:
|
||||
if __event_emitter__:
|
||||
await __event_emitter__({"type": "status", "data": {"done": True}})
|
||||
return "Generate Image\nStatus: 0\nError: No images returned"
|
||||
|
||||
retVal["message"] = f"Status: 0\nError: No images returned",
|
||||
return retVal
|
||||
|
||||
if __event_emitter__:
|
||||
await __event_emitter__({"type": "status", "data": {"description": f"Uploading {len(images)} images...", "done": False}})
|
||||
|
||||
chat_id = __metadata__.get("chat_id") if __metadata__ else None
|
||||
message_id = __metadata__.get("message_id") if __metadata__ else None
|
||||
uploaded_files = []
|
||||
errors = []
|
||||
|
||||
for i, image_b64 in enumerate(images):
|
||||
timestamp = int(time.time() * 1000)
|
||||
filename = f"venice_{model}_{timestamp}_{i+1}.webp"
|
||||
file_metadata = {"name": filename, "content_type": "image/webp", "data": {"model": model, "prompt": prompt, "negative_prompt": effective_negative_prompt, "style_preset": style_preset, "width": width, "height": height, "variant": i+1, "total_variants": len(images), "safe_mode": safe_mode, "hide_watermark": hide_watermark}}
|
||||
|
||||
if chat_id:
|
||||
file_metadata["chat_id"] = chat_id
|
||||
|
||||
if message_id:
|
||||
file_metadata["message_id"] = message_id
|
||||
|
||||
file_id, error = await self._upload_image(image_b64, filename, file_metadata, "image/webp", __request__)
|
||||
|
||||
if file_id:
|
||||
uploaded_files.append({"type": "image", "url": f"/api/v1/files/{file_id}/content"})
|
||||
else:
|
||||
errors.append(f"Variant {i+1}: {error}")
|
||||
|
||||
if uploaded_files:
|
||||
await self._accumulate_files(msg_key, uploaded_files, __event_emitter__)
|
||||
|
||||
final_count = self._get_accumulated_count(msg_key)
|
||||
|
||||
if __event_emitter__:
|
||||
await __event_emitter__({"type": "status", "data": {"description": f"Done ({final_count} images total)", "done": True}})
|
||||
parts = ["Generate Image", "Status: 200", "", f"Generated {len(uploaded_files)} image(s) for: {prompt[:100]}{'...' if len(prompt) > 100 else ''}", f"Model: {model} | Size: {width}x{height}"]
|
||||
settings_parts = []
|
||||
|
||||
retVal["status"] = "success"
|
||||
retVal["message"] = "The image has been successfully generated and is already visible to the user in the chat. You do not need to display or embed the image again - just acknowledge that it has been created.",
|
||||
|
||||
if safe_mode:
|
||||
settings_parts.append("SFW")
|
||||
retVal["settings"]["safe_mode"]: "SFW"
|
||||
else:
|
||||
retVal["settings"]["safe_mode"]: "NSFW"
|
||||
|
||||
if hide_watermark:
|
||||
settings_parts.append("No watermark")
|
||||
if settings_parts:
|
||||
parts.append(f"Settings: {', '.join(settings_parts)}")
|
||||
retVal["settings"]["hide_watermark"]: "hide_watermark"
|
||||
|
||||
if uploaded_files:
|
||||
parts.append("", "Files:")
|
||||
for i, f in enumerate(uploaded_files):
|
||||
parts.append(f" [{i+1}] {f['url']}")
|
||||
retVal["images"] = uploaded_files
|
||||
|
||||
if dropped_params:
|
||||
parts.append(f"Note: {model} doesn't support: {', '.join(dropped_params)} (ignored)")
|
||||
if final_count > len(uploaded_files):
|
||||
parts.append(f"Total images in message: {final_count}")
|
||||
retVal["note"] = f" {model} doesn't support: {', '.join(dropped_params)} (ignored)"
|
||||
|
||||
if errors:
|
||||
parts.append("", "Warnings:")
|
||||
for e in errors:
|
||||
parts.append(f" - {e}")
|
||||
return "\n".join(parts)
|
||||
retVal["warnings"] = errors
|
||||
|
||||
return retVal
|
||||
|
||||
async def upscale_image(self, image: str, scale: int = 2, enhance: bool = False, enhance_creativity: float = 0.5, enhance_prompt: Optional[str] = None, __request__=None, __user__: dict = None, __metadata__: dict = None, __files__: list = None, __event_emitter__: Callable[[dict], Any] = None) -> str:
|
||||
venice_key = VeniceImage.get_api_key(self.valves, self.user_valves, __user__)
|
||||
|
||||
Reference in New Issue
Block a user