File size: 3,905 Bytes
dabebe4
 
 
 
 
 
 
 
 
 
 
 
 
 
ad3f97c
dabebe4
 
 
 
 
 
 
 
 
ad3f97c
 
 
 
dabebe4
 
 
 
 
 
 
ad3f97c
dabebe4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de608a4
 
 
dabebe4
 
de608a4
 
 
 
92e3646
 
 
de608a4
 
 
9628cab
dabebe4
 
 
 
 
 
de608a4
dabebe4
 
 
 
 
 
 
 
 
 
9628cab
 
 
dabebe4
9628cab
dabebe4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.45",
#     "datasets>=2.20",
#     "accelerate>=0.34",
#     "trackio",
#     "unsloth",
# ]
# ///
"""Phase-A LoRA SFT for the raunch page-mode model — runs inside HF Jobs.

Base: Sao10K/Llama-3.1-8B-Stheno-v3.4
Dataset: 4moha/raunch-page-mode-v0 (private)
Output: pushed to 4moha/raunch-stheno-v3.4-lora-v0

NSFW-only: training data is raunch's NSFW Claude-generated prose. The resulting
LoRA is deployed to the raunch server instance, NOT the SFW lili server.

This script is submitted as the body of the HF Job. It expects HF_TOKEN to be
set in the job environment; HF_DATASET_REPO and HF_MODEL_REPO are optional
overrides that fall back to the dataset and output repos listed above.
"""
# Unsloth MUST be imported before transformers/trl/peft — its module-init patches
# don't apply otherwise and you get the "imported late" warning + degraded perf.
from unsloth import FastLanguageModel

import os

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig


# Base checkpoint that the LoRA adapter is trained on top of.
BASE_MODEL = "Sao10K/Llama-3.1-8B-Stheno-v3.4"
# Hub repos for the training data and for the resulting adapter. Each one can
# be overridden through the job environment; otherwise the default is used.
DATASET_REPO = os.getenv("HF_DATASET_REPO", default="4moha/raunch-page-mode-v0")
MODEL_REPO = os.getenv("HF_MODEL_REPO", default="4moha/raunch-stheno-v3.4-lora-v0")


def main() -> None:
    """Fine-tune the base model with a LoRA adapter and push it to the Hub.

    Steps, in order:
      1. Load ``BASE_MODEL`` in 4-bit through Unsloth and attach LoRA adapters.
      2. Load ``train.jsonl`` from ``DATASET_REPO``, render each example's
         ``messages`` into a single ``text`` string with the tokenizer's chat
         template, and hold out 5% of the rows for evaluation.
      3. Train with TRL's ``SFTTrainer`` and push the result to ``MODEL_REPO``.

    Takes no arguments and returns nothing; all configuration comes from the
    module-level constants and the literals below.
    """
    # Load model + tokenizer via Unsloth (faster + leaner than vanilla transformers)
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=BASE_MODEL,
        max_seq_length=4096,
        dtype=None,         # auto
        load_in_4bit=True,  # QLoRA — fits more comfortably on A10G
    )

    # Attach LoRA adapters to every attention and MLP projection.
    model = FastLanguageModel.get_peft_model(
        model,
        r=16,
        lora_alpha=32,
        lora_dropout=0,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        use_gradient_checkpointing="unsloth",
        random_state=42,
    )

    # Load dataset, pre-render the chat template into a single "text" column,
    # then split. Avoids version-skew on TRL/Unsloth's formatting_func contracts —
    # SFTTrainer reads dataset_text_field="text" and tokenizes directly.
    full = load_dataset(DATASET_REPO, data_files="train.jsonl", split="train")

    def render_chat(example: dict) -> dict:
        """Render one example's ``messages`` list into a flat ``text`` string."""
        # NOTE(review): the rendered text may already begin with the template's
        # BOS token; confirm the trainer does not prepend a second one when it
        # tokenizes the "text" column.
        return {
            "text": tokenizer.apply_chat_template(
                example["messages"],
                tokenize=False,
                add_generation_prompt=False,
            )
        }

    full = full.map(render_chat, remove_columns=["messages"])
    # Fixed seed so the 95/5 train/eval split is reproducible across job runs.
    split = full.train_test_split(test_size=0.05, seed=42)

    training_args = SFTConfig(
        dataset_text_field="text",
        output_dir="raunch-stheno-v3.4-lora-v0",
        push_to_hub=True,
        hub_model_id=MODEL_REPO,
        hub_private_repo=True,
        hub_strategy="every_save",
        num_train_epochs=3,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        learning_rate=5e-5,
        lr_scheduler_type="cosine",
        # warmup_ratio is deprecated upstream (the original note said "TRL 5.x",
        # presumably meaning Transformers 5.x — TODO confirm), so warmup is
        # expressed as concrete steps instead.
        # ~22 steps/epoch × 3 epochs = ~65 steps; 5% warmup = ~3 steps.
        warmup_steps=3,
        max_length=4096,
        logging_steps=5,
        # NOTE(review): with ~65 total steps, save_steps=200 never fires, so
        # hub_strategy="every_save" uploads nothing mid-run and the explicit
        # push_to_hub() below is the only upload — confirm this is intended.
        save_strategy="steps",
        save_steps=200,
        eval_strategy="steps",
        eval_steps=50,
        seed=42,
        report_to="trackio",
        run_name="raunch-stheno-v3.4-lora-v0",
        project="raunch-page-mode",
    )

    trainer = SFTTrainer(
        model=model,
        # `tokenizer=` was renamed to `processing_class=` in TRL 0.12 (the
        # minimum version this script declares) and the old keyword was later
        # removed, so pass the tokenizer under the current name.
        processing_class=tokenizer,
        train_dataset=split["train"],
        eval_dataset=split["test"],
        args=training_args,
    )

    trainer.train()
    trainer.push_to_hub()
    print("Training complete. LoRA pushed to:", MODEL_REPO)


# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()