Instructions to use AiArtLab/sdxs with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use AiArtLab/sdxs with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("AiArtLab/sdxs", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
- Local Apps
- Draw Things
- DiffusionBee
| # pip install flash-attn --no-build-isolation | |
| import torch | |
| import os | |
| import gc | |
| import numpy as np | |
| import random | |
| import json | |
| import shutil | |
| import time | |
| from datasets import Dataset, load_from_disk, concatenate_datasets | |
| from diffusers import AutoencoderKL,AutoencoderKLWan | |
| from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode, Lambda | |
| from transformers import AutoModel, AutoImageProcessor, AutoTokenizer, AutoModelForCausalLM | |
| from typing import Dict, List, Tuple, Optional, Any | |
| from PIL import Image | |
| from tqdm import tqdm | |
| from datetime import timedelta | |
| # ---------------- 1️⃣ Настройки ---------------- | |
| dtype = torch.float32 | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| batch_size = 10 | |
| min_size = 192 #384 #320 #192 #256 #192 | |
| max_size = 320 #768 #640 #384 #256 #384 | |
| step = 64 #64 | |
| empty_share = 0.0 | |
| limit = 0 | |
| # Основная процедура обработки | |
| folder_path = "/workspace/mjnj" #alchemist" | |
| save_path = "/workspace/sdxs/datasets/mjnj" #"alchemist" | |
| os.makedirs(save_path, exist_ok=True) | |
| # Функция для очистки CUDA памяти | |
| def clear_cuda_memory(): | |
| if torch.cuda.is_available(): | |
| used_gb = torch.cuda.max_memory_allocated() / 1024**3 | |
| print(f"used_gb: {used_gb:.2f} GB") | |
| torch.cuda.empty_cache() | |
| gc.collect() | |
| # ---------------- 2️⃣ Загрузка моделей ---------------- | |
| def load_models(): | |
| print("Загрузка моделей...") | |
| vae = AutoencoderKL.from_pretrained("AiArtLab/sdxs",subfolder="vae1x",torch_dtype=dtype).to(device).eval() | |
| #model_name = "Qwen/Qwen3-0.6B" | |
| #tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| #model = AutoModelForCausalLM.from_pretrained( | |
| # model_name, | |
| # torch_dtype=dtype, | |
| # device_map=device | |
| #).eval() | |
| #tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-0.6B', padding_side='left') | |
| #model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-0.6B').to("cuda") | |
| return vae#, model, tokenizer | |
| #vae, model, tokenizer = load_models() | |
| vae = load_models() | |
| shift_factor = getattr(vae.config, "shift_factor", 0.0) | |
| if shift_factor is None: | |
| shift_factor = 0.0 | |
| scaling_factor = getattr(vae.config, "scaling_factor", 1.0) | |
| if scaling_factor is None: | |
| scaling_factor = 1.0 | |
| latents_mean = getattr(vae.config, "latents_mean", None) | |
| latents_std = getattr(vae.config, "latents_std", None) | |
| # ---------------- 3️⃣ Трансформации ---------------- | |
| def get_image_transform(min_size=256, max_size=512, step=64): | |
| def transform(img, dry_run=False): | |
| # Сохраняем исходные размеры изображения | |
| original_width, original_height = img.size | |
| # 0. Ресайз: масштабируем изображение, чтобы максимальная сторона была равна max_size | |
| if original_width >= original_height: | |
| new_width = max_size | |
| new_height = int(max_size * original_height / original_width) | |
| else: | |
| new_height = max_size | |
| new_width = int(max_size * original_width / original_height) | |
| if new_height < min_size or new_width < min_size: | |
| # 1. Ресайз: масштабируем изображение, чтобы минимальная сторона была равна min_size | |
| if original_width <= original_height: | |
| new_width = min_size | |
| new_height = int(min_size * original_height / original_width) | |
| else: | |
| new_height = min_size | |
| new_width = int(min_size * original_width / original_height) | |
| # 2. Проверка: если одна из сторон превышает max_size, готовимся к обрезке | |
| crop_width = min(max_size, (new_width // step) * step) | |
| crop_height = min(max_size, (new_height // step) * step) | |
| # Убеждаемся, что размеры обрезки не меньше min_size | |
| crop_width = max(min_size, crop_width) | |
| crop_height = max(min_size, crop_height) | |
| # Если запрошен только предварительный расчёт размеров | |
| if dry_run: | |
| return crop_width, crop_height | |
| # Конвертация в RGB и ресайз | |
| img_resized = img.convert("RGB").resize((new_width, new_height), Image.LANCZOS) | |
| # Определение координат обрезки (обрезаем с учетом вотермарок - треть сверху) | |
| top = (new_height - crop_height) // 3 | |
| left = 0 | |
| # Обрезка изображения | |
| img_cropped = img_resized.crop((left, top, left + crop_width, top + crop_height)) | |
| # Сохраняем итоговые размеры после всех преобразований | |
| final_width, final_height = img_cropped.size | |
| # тензор | |
| img_tensor = ToTensor()(img_cropped) | |
| img_tensor = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(img_tensor) | |
| return img_tensor, img_cropped, final_width, final_height | |
| return transform | |
| # ---------------- 4️⃣ Функции обработки ---------------- | |
| def last_token_pool(last_hidden_states: torch.Tensor, | |
| attention_mask: torch.Tensor) -> torch.Tensor: | |
| # Определяем, есть ли left padding | |
| left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) | |
| if left_padding: | |
| return last_hidden_states[:, -1] | |
| else: | |
| sequence_lengths = attention_mask.sum(dim=1) - 1 | |
| batch_size = last_hidden_states.shape[0] | |
| return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths] | |
| def encode_texts_batch(texts, tokenizer, model, device="cuda", max_length=150, normalize=False): | |
| with torch.inference_mode(): | |
| # Токенизация | |
| batch = tokenizer( | |
| texts, | |
| return_tensors="pt", | |
| padding="max_length", | |
| truncation=True, | |
| max_length=max_length | |
| ).to(device) | |
| # Прогон через модель | |
| #outputs = model(**batch) | |
| # Пулинг по last token | |
| #embeddings = last_token_pool(outputs.last_hidden_state, batch["attention_mask"]) | |
| # L2-нормализация (опционально, обычно нужна для семантического поиска) | |
| #if normalize: | |
| # embeddings = F.normalize(embeddings, p=2, dim=1) | |
| # Прогон через базовую модель (внутри CausalLM) | |
| outputs = model.model(**batch, output_hidden_states=True) | |
| # Берем последний слой (эмбеддинги всех токенов) | |
| hidden_states = outputs.hidden_states[-1] # [B, L, D] | |
| # Можно применить нормализацию по каждому токену (как в CLIP) | |
| if normalize: | |
| hidden_states = F.normalize(hidden_states, p=2, dim=-1) | |
| return hidden_states.cpu().numpy() # embeddings.unsqueeze(1).cpu().numpy() | |
| def clean_label(label): | |
| label = label.replace("Image 1", "").replace("Image 2", "").replace("Image 3", "").replace("Image 4", "").replace("The image depicts ","").replace("The image presents ","").replace("The image features ","").replace("The image portrays ","").replace("The image is ","").strip() | |
| if label.startswith("."): | |
| label = label[1:].lstrip() | |
| return label | |
| def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01): | |
| """ | |
| Обрабатывает список меток для classifier-free guidance. | |
| С вероятностью prob_to_make_empty: | |
| - Метка в первом списке заменяется на пустую строку. | |
| - К метке во втором списке добавляется префикс "zero:". | |
| В противном случае метки в обоих списках остаются оригинальными. | |
| """ | |
| labels_for_model = [] | |
| labels_for_logging = [] | |
| for label in original_labels: | |
| if random.random() < prob_to_make_empty: | |
| labels_for_model.append("") # Заменяем на пустую строку для модели | |
| labels_for_logging.append(f"zero: {label}") # Добавляем префикс для логгирования | |
| else: | |
| labels_for_model.append(label) # Оставляем оригинальную метку для модели | |
| labels_for_logging.append(label) # Оставляем оригинальную метку для логгирования | |
| return labels_for_model, labels_for_logging | |
| def encode_to_latents(images, texts): | |
| transform = get_image_transform(min_size, max_size, step) | |
| try: | |
| # Обработка изображений (все одинакового размера) | |
| transformed_tensors = [] | |
| pil_images = [] | |
| widths, heights = [], [] | |
| # Применяем трансформацию ко всем изображениям | |
| for img in images: | |
| try: | |
| t_img, pil_img, w, h = transform(img) | |
| transformed_tensors.append(t_img) | |
| pil_images.append(pil_img) | |
| widths.append(w) | |
| heights.append(h) | |
| except Exception as e: | |
| print(f"Ошибка трансформации: {e}") | |
| continue | |
| if not transformed_tensors: | |
| return None | |
| # Создаём батч | |
| batch_tensor = torch.stack(transformed_tensors).to(device, dtype) | |
| if batch_tensor.ndim==5: | |
| batch_tensor = batch_tensor.unsqueeze(2) # [B, C, 1, H, W] | |
| # Кодируем батч | |
| with torch.no_grad(): | |
| posteriors = vae.encode(batch_tensor).latent_dist.mode() | |
| latents = (posteriors - shift_factor) / scaling_factor | |
| latents_np = latents.to(dtype).cpu().numpy() | |
| # Обрабатываем тексты | |
| text_labels = [clean_label(text) for text in texts] | |
| model_prompts, text_labels = process_labels_for_guidance(text_labels, empty_share) | |
| #embeddings = encode_texts_batch(model_prompts, tokenizer, model) | |
| return { | |
| "vae": latents_np, | |
| #"embeddings": embeddings, | |
| "text": text_labels, | |
| "width": widths, | |
| "height": heights | |
| } | |
| except Exception as e: | |
| print(f"Критическая ошибка в encode_to_latents: {e}") | |
| raise | |
| # ---------------- 5️⃣ Обработка папки с изображениями и текстами ---------------- | |
| def process_folder(folder_path, limit=None): | |
| """ | |
| Рекурсивно обходит указанную директорию и все вложенные директории, | |
| собирая пути к изображениям и соответствующим текстовым файлам. | |
| """ | |
| image_paths = [] | |
| text_paths = [] | |
| width = [] | |
| height = [] | |
| transform = get_image_transform(min_size, max_size, step) | |
| # Используем os.walk для рекурсивного обхода директорий | |
| for root, dirs, files in os.walk(folder_path): | |
| for filename in files: | |
| # Проверяем, является ли файл изображением | |
| if filename.lower().endswith((".jpg", ".jpeg", ".png")): | |
| image_path = os.path.join(root, filename) | |
| try: | |
| img = Image.open(image_path) | |
| except Exception as e: | |
| print(f"Ошибка при открытии {image_path}: {e}") | |
| os.remove(image_path) | |
| text_path = os.path.splitext(image_path)[0] + ".txt" | |
| if os.path.exists(text_path): | |
| os.remove(text_path) | |
| continue | |
| # Применяем трансформацию только для получения размеров | |
| w, h = transform(img, dry_run=True) | |
| # Формируем путь к текстовому файлу | |
| text_path = os.path.splitext(image_path)[0] + ".txt" | |
| # Добавляем пути, если текстовый файл существует | |
| if os.path.exists(text_path) and min(w, h)>0: | |
| image_paths.append(image_path) | |
| text_paths.append(text_path) | |
| width.append(w) # Добавляем в список | |
| height.append(h) # Добавляем в список | |
| # Проверяем ограничение на количество | |
| if limit and limit>0 and len(image_paths) >= limit: | |
| print(f"Достигнут лимит в {limit} изображений") | |
| return image_paths, text_paths, width, height | |
| print(f"Найдено {len(image_paths)} изображений с текстовыми описаниями") | |
| return image_paths, text_paths, width, height | |
| def process_in_chunks(image_paths, text_paths, width, height, chunk_size=10000, batch_size=1): | |
| total_files = len(image_paths) | |
| start_time = time.time() | |
| chunks = range(0, total_files, chunk_size) | |
| for chunk_idx, start in enumerate(chunks, 1): | |
| end = min(start + chunk_size, total_files) | |
| chunk_image_paths = image_paths[start:end] | |
| chunk_text_paths = text_paths[start:end] | |
| chunk_widths = width[start:end] if isinstance(width, list) else [width] * len(chunk_image_paths) | |
| chunk_heights = height[start:end] if isinstance(height, list) else [height] * len(chunk_image_paths) | |
| # Чтение текстов | |
| chunk_texts = [] | |
| for text_path in chunk_text_paths: | |
| try: | |
| with open(text_path, 'r', encoding='utf-8') as f: | |
| text = f.read().strip() | |
| chunk_texts.append(text) | |
| except Exception as e: | |
| print(f"Ошибка чтения {text_path}: {e}") | |
| chunk_texts.append("") | |
| # Группируем изображения по размерам | |
| size_groups = {} | |
| for i in range(len(chunk_image_paths)): | |
| size_key = (chunk_widths[i], chunk_heights[i]) | |
| if size_key not in size_groups: | |
| size_groups[size_key] = {"image_paths": [], "texts": []} | |
| size_groups[size_key]["image_paths"].append(chunk_image_paths[i]) | |
| size_groups[size_key]["texts"].append(chunk_texts[i]) | |
| # Обрабатываем каждую группу размеров отдельно | |
| for size_key, group_data in size_groups.items(): | |
| print(f"Обработка группы с размером {size_key[0]}x{size_key[1]} - {len(group_data['image_paths'])} изображений") | |
| group_dataset = Dataset.from_dict({ | |
| "image_path": group_data["image_paths"], | |
| "text": group_data["texts"] | |
| }) | |
| # Теперь можно использовать указанный batch_size, т.к. все изображения одного размера | |
| processed_group = group_dataset.map( | |
| lambda examples: encode_to_latents( | |
| [Image.open(path) for path in examples["image_path"]], | |
| examples["text"] | |
| ), | |
| batched=True, | |
| batch_size=batch_size, | |
| #remove_columns=["image_path"], | |
| desc=f"Обработка группы размера {size_key[0]}x{size_key[1]}" | |
| ) | |
| # Сохраняем результаты группы | |
| group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_size_{size_key[0]}x{size_key[1]}" | |
| processed_group.save_to_disk(group_save_path) | |
| clear_cuda_memory() | |
| elapsed = time.time() - start_time | |
| processed = (chunk_idx - 1) * chunk_size + sum([len(sg["image_paths"]) for sg in list(size_groups.values())[:list(size_groups.values()).index(group_data) + 1]]) | |
| if processed > 0: | |
| remaining = (elapsed / processed) * (total_files - processed) | |
| elapsed_str = str(timedelta(seconds=int(elapsed))) | |
| remaining_str = str(timedelta(seconds=int(remaining))) | |
| print(f"ETA: Прошло {elapsed_str}, Осталось {remaining_str}, Прогресс {processed}/{total_files} ({processed/total_files:.1%})") | |
| # ---------------- 7️⃣ Объединение чанков ---------------- | |
| def combine_chunks(temp_path, final_path): | |
| """Объединение обработанных чанков в финальный датасет""" | |
| chunks = sorted([ | |
| os.path.join(temp_path, d) | |
| for d in os.listdir(temp_path) | |
| if d.startswith("chunk_") | |
| ]) | |
| datasets = [load_from_disk(chunk) for chunk in chunks] | |
| combined = concatenate_datasets(datasets) | |
| combined.save_to_disk(final_path) | |
| print(f"✅ Датасет успешно сохранен в: {final_path}") | |
| # Создаем временную папку для чанков | |
| temp_path = f"{save_path}_temp" | |
| os.makedirs(temp_path, exist_ok=True) | |
| # Получаем список файлов | |
| image_paths, text_paths, width, height = process_folder(folder_path,limit) | |
| print(f"Всего найдено {len(image_paths)} изображений") | |
| # Обработка с чанкованием | |
| process_in_chunks(image_paths, text_paths, width, height, chunk_size=20000, batch_size=batch_size) | |
| # Удаление папки | |
| try: | |
| shutil.rmtree(folder_path) | |
| print(f"✅ Папка {folder_path} успешно удалена") | |
| except Exception as e: | |
| print(f"⚠️ Ошибка при удалении папки: {e}") | |
| # Объединение чанков в финальный датасет | |
| combine_chunks(temp_path, save_path) | |
| # Удаление временной папки | |
| try: | |
| shutil.rmtree(temp_path) | |
| print(f"✅ Временная папка {temp_path} успешно удалена") | |
| except Exception as e: | |
| print(f"⚠️ Ошибка при удалении временной папки: {e}") |