{ "cells": [ { "cell_type": "markdown", "id": "538f25ce", "metadata": {}, "source": [ "\n", "이 노트북은 OpenAI의 **gpt-oss (open‑weight)** 모델을 **한국 뉴스 문체 + 최신 대화체**로 세밀 튜닝하는 방법을\n", "한국어/영어 **이중 언어**로 제공합니다. \n", "This notebook shows how to fine‑tune OpenAI's **gpt-oss (open‑weight)** models for **Korean news style + modern chat tone**, in **Korean & English**.\n", "\n", "---\n", "\n", "### MXFP4 workflow clarifications · MXFP4 워크플로 정리\n", "\n", "**EN:** \n", "- Training or fine-tuning **directly in MXFP4 is not supported** by public frameworks today. \n", "- Recommended path: train in **BF16** (or **QLoRA 4‑bit nf4**) → **merge LoRA** → **post‑training quantize to MXFP4** → `save_pretrained()` for deployment. \n", "- If you need an MXFP4 artifact, you must **re‑quantize from BF16** after merging adapters. (Export utilities are evolving; if your toolchain already supports MXFP4 serialization, that’s ideal.)\n", "\n", "**KR:** \n", "- 현재 공개 프레임워크에서는 **MXFP4로 직접 학습/파인튜닝**이 지원되지 않습니다. \n", "- 권장 경로: **BF16**(또는 **QLoRA 4‑bit nf4**)로 학습 → **LoRA 병합** → **사후(MXFP4) 양자화** → 배포용으로 `save_pretrained()` 저장. \n", "- MXFP4 아티팩트가 필요하면, 어댑터 병합 후 **BF16 → MXFP4 재양자화**가 필요합니다. (직렬화 유틸은 진화 중이며, 툴체인에서 MXFP4 저장을 지원하면 가장 좋습니다.)\n", "\n", "---\n", "\n", "### LoRA targets (MoE) · LoRA 타깃(MoE 포함)\n", "\n", "**EN:** \n", "- Minimal config (fast, low VRAM): target attention only, e.g. `[\"q_proj\",\"v_proj\"]`. \n", "- MoE‑aware config (better domain adaptation, more VRAM/time): include **expert projection layers** in addition to attention. \n", "\n", "```python\n", "from peft import LoraConfig\n", "\n", "TARGET_MODULES = [\"q_proj\", \"v_proj\"] # baseline\n", "MOE_TARGET_PARAMETERS = [\n", " # example expert layers; adjust indices to your model depth\n", " \"mlp.experts.gate_up_proj\",\n", " \"mlp.experts.down_proj\",\n", "]\n", "\n", "lora_cfg = LoraConfig(\n", " r=16, lora_alpha=32, lora_dropout=0.05,\n", " target_modules=\"all-linear\", # cover all linear layers\n", " target_parameters=MOE_TARGET_PARAMETERS, # add expert projections\n", " bias=\"none\", task_type=\"CAUSAL_LM\",\n", ")\n", "```\n", "\n", "- Start with attention‑only; if KR domain fit is insufficient, enable MoE targets and re‑eval.\n", "\n", "**KR:** \n", "- 최소 구성(빠르고 VRAM 절약): `[\"q_proj\",\"v_proj\"]` 등 **어텐션만** 적용. \n", "- **MoE 인지 구성**(도메인 적합성↑, 자원 소모↑): 어텐션에 **전문가(Expert) 투영 레이어**를 추가로 포함. \n", "- 먼저 어텐션만으로 시도한 뒤, 한국어 도메인 적합성이 부족하면 MoE 타깃을 켜고 재평가하세요." 
] }, { "cell_type": "markdown", "id": "bd7c12ff", "metadata": {}, "source": [ "## Contents · 목차\n", "0) Goals & Scope · 목표 & 범위 \n", "1) Environment check · 환경 점검 \n", "2) 설정값 · Config \n", "3) 패키지 설치 · Install Deps \n", "4) 데이터 소싱(한국형) · KR‑Context Data Sourcing \n", "5) 샘플 데이터 생성 · Create Sample Data \n", "6) 전처리(PIPA) & 스타일 라벨 · PII Scrubbing & Style Tags \n", "7) 데이터 로딩/포맷팅 · Load & Format \n", "8) 모델/토크나이저 로드 · Load Model & Tokenizer \n", "9) Fine‑Tuning (LoRA/QLoRA) · 세밀 튜닝 \n", " 9a) Data curation & splits \n", " 9b) Hyperparameters (r/alpha/dropout) \n", " 9c) Merge adapters (BF16) \n", " 9d) Save merged BF16 (`save_pretrained`) \n", " 9e) Export & Quantize (BF16 → MXFP4) · 내보내기 & 양자화 \n", "10) 평가(뉴스/대화) · Evaluation (News/Chat) \n", "11) Inference Prompt Templates · 추론 프롬프트 템플릿 \n", "12) 최신성 유지 · Freshness Strategy \n", "13) 안전/컴플라이언스 · Safety & Compliance \n", "14) 문제해결 & 다음 단계 · Troubleshooting & Next Steps\n" ] }, { "cell_type": "markdown", "id": "bb8655d2", "metadata": {}, "source": [ "### ⚙️ Training vs Quantization — What’s supported\n", "- **Do:** Train with BF16/FP16 or QLoRA; export merged weights.\n", "- **Then:** Quantize to **MXFP4** for inference using provided conversion scripts/utilities.\n", "- **Don’t:** Attempt to run an end‑to‑end “train in MXFP4” pipeline — not supported today." ] }, { "cell_type": "markdown", "id": "bb24a3d9", "metadata": {}, "source": [ "> **PII & Compliance Reminder:** For KR data, follow your enterprise policy (mask RRN/phone/account IDs, remove emails) **before** training & logging. Keep train/val/test splits stratified by source and style tags." ] }, { "cell_type": "markdown", "id": "e1e883f5", "metadata": {}, "source": [ "### 🧪 MoE adapters (optional)\n", "You can target MoE layers with adapters, but treat this as **advanced/experimental**. Start with attention projections first and validate KR benchmarks before expanding scope." ] }, { "cell_type": "markdown", "id": "179543e6", "metadata": {}, "source": [ "> **Note:** Keep `transformers`, `peft`, `accelerate`, and `trl` at versions known to support BF16/4‑bit LoRA. \n", "If you pin `safetensors`, remember that **native MXFP4 serialization is not yet standardized**; loaders may upcast internally." ] }, { "cell_type": "markdown", "id": "f8e743f0", "metadata": {}, "source": [ "### 🔎 Support Matrix — At a glance\n", "- **Fine‑tuning precision:** BF16/FP16 ✅ · QLoRA 4‑bit ✅ · **MXFP4 FT ❌**\n", "- **Quantization target:** MXFP4 ✅ (post‑training)\n", "- **API FT (hosted) for OSS models:** ❌\n", "- **Open‑source FT (Transformers/TRL/PEFT):** ✅\n", "- **LoRA targets:** `q_proj`, `k_proj`, `v_proj`, `o_proj` ✅; MoE expert adapters **experimental** ⚠️" ] }, { "cell_type": "markdown", "id": "f4dec1f6", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "e3d489c2", "metadata": {}, "source": [ "## 0) Goals & Scope · 목표 & 범위\n", "- **KR**: 한국어 일반 뉴스 + 일상/상담 대화체에 최적화. `style=news_headline|news_lead|news_body|kakao_casual|kakao_formal` 제어.\n", "- **EN**: Optimize for Korean news writing and modern chat tone; control output via style tags above.\n", "- **Stack**: `transformers`, `trl(SFTTrainer)`, `peft(LoRA/QLoRA)`, `datasets`.\n", "- **Hardware**: Single/few GPUs (BF16 preferred). CPU/Mac for lightweight tests." 
] }, { "cell_type": "markdown", "id": "db97218d", "metadata": {}, "source": [ "## 1) Environment check · 환경 점검" ] }, { "cell_type": "code", "execution_count": 9, "id": "5babb2c3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Python: 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]\n", "OS/Platform: Linux-6.8.0-60-generic-x86_64-with-glibc2.35\n", "CUDA_VISIBLE_DEVICES: \n", "Torch: 2.7.1+cu126 CUDA: True\n", "GPU: NVIDIA H100 80GB HBM3\n" ] } ], "source": [ "import os, sys, platform\n", "print(\"Python:\", sys.version)\n", "print(\"OS/Platform:\", platform.platform())\n", "print(\"CUDA_VISIBLE_DEVICES:\", os.environ.get(\"CUDA_VISIBLE_DEVICES\", \"\"))\n", "\n", "try:\n", " import torch\n", " print(\"Torch:\", torch.__version__, \"CUDA:\", torch.cuda.is_available())\n", " if torch.cuda.is_available():\n", " print(\"GPU:\", torch.cuda.get_device_name(0))\n", "except Exception as e:\n", " print(\"Torch not installed or GPU not detected:\", e)" ] }, { "cell_type": "markdown", "id": "25688688", "metadata": {}, "source": [ "## 2) 설정값 · Config" ] }, { "cell_type": "code", "execution_count": 10, "id": "c15817f7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Config ready.\n" ] } ], "source": [ "from pathlib import Path\n", "import os\n", "\n", "# === Model & Training Params ===\n", "BASE_URL = \"http://localhost:8000/v1\" # vLLM OpenAI-compatible endpoint\n", "API_KEY = \"dummy-key\" # vLLM ignores; SDK requires a value\n", "MODEL = \"openai/gpt-oss-120b\" # must match the model vLLM loaded\n", "OUTPUT_DIR = \"ft-oss-kr-news-chat-bilingual\"\n", "\n", "# Data mix (news : chat)\n", "MIX_NEWS = 0.6\n", "MIX_CHAT = 0.4\n", "\n", "# LoRA\n", "LORA_R = 8\n", "LORA_ALPHA = 16\n", "LORA_DROPOUT = 0.05\n", "TARGET_MODULES = [\"q_proj\", \"v_proj\"] # adjust per model\n", "\n", "# Training\n", "EPOCHS = 1\n", "PER_DEVICE_BS = 2\n", "GRAD_ACCUM = 8\n", "LEARNING_RATE = 2e-4\n", "BF16 = True\n", "LOG_STEPS = 20\n", "SAVE_STEPS = 200\n", "SAVE_TOTAL_LIMIT = 2\n", "\n", "print(\"Config ready.\")" ] }, { "cell_type": "markdown", "id": "85f258eb", "metadata": {}, "source": [ "## 3) 패키지 설치 · Install Deps" ] }, { "cell_type": "code", "execution_count": 11, "id": "b1b75968", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "transformers: 4.55.3\n", "accelerate: 1.10.0\n", "datasets: 4.0.0\n", "peft: not installed\n", "trl: 0.21.0\n", "bitsandbytes: not installed\n", "sentencepiece: 0.2.1\n", "vllm: 0.10.1\n", "llama_cpp: 0.3.16\n", "pip: 25.2\n", "Install cells are commented. Un-comment in your environment.\n" ] } ], "source": [ "# %pip install --upgrade pip\n", "# %pip install transformers accelerate datasets peft trl bitsandbytes sentencepiece\n", "# (optional) serving/runtimes\n", "# %pip install vllm\n", "# %pip install llama-cpp-python\n", "\n", "import importlib, pip\n", "\n", "for dep in [\"transformers\",\"accelerate\",\"datasets\",\"peft\",\"trl\",\n", " \"bitsandbytes\",\"sentencepiece\",\"vllm\",\"llama_cpp\"]:\n", " try:\n", " print(f\"{dep}: {importlib.import_module(dep).__version__}\")\n", " except Exception:\n", " print(f\"{dep}: not installed\")\n", "\n", "print(f\"pip: {pip.__version__}\")\n", "\n", "print(\"Install cells are commented. 
Un-comment in your environment.\")" ] }, { "cell_type": "markdown", "id": "de8647fd", "metadata": {}, "source": [ "## 4) 데이터 소싱(한국형) · KR‑Context Data Sourcing" ] }, { "cell_type": "markdown", "id": "da22cbd6", "metadata": {}, "source": [ "**KR** \n", "- 공개 벤치마크(주제 분류/요약/QA) + **허용된 뉴스 API의 메타데이터(제목/요약/섹션)** 중심으로 스타일 보정.\n", "- 기사 **원문 대량 재학습은 저작권/약관 이슈** → 메타데이터·공개 코퍼스 위주.\n", "- 대화체는 합법 공개 코퍼스(반말/존댓말/이모티콘/축약어 라벨 포함) 우선.\n", "- PIPA: 주민번호/연락처/이메일/계좌 등 개인정보는 **훈련 전/로그 전** 스크러빙.\n", "\n", "**EN** \n", "- Prefer public KR benchmarks (topic classification / summarization / QA) and **allowed news API metadata** for style calibration.\n", "- Avoid mass training on news full texts due to license/ToS constraints; use metadata + open corpora.\n", "- For chat, use lawful open corpora with tone/emoji/informal‑formal annotations.\n", "- Scrub PII (phone, RRNs, emails, accounts) before training/logging." ] }, { "cell_type": "markdown", "id": "9b918411", "metadata": {}, "source": [ "## 5) 샘플 데이터 생성 · Create Sample Data" ] }, { "cell_type": "code", "execution_count": 12, "id": "18db10a6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Created: data/news.jsonl, data/chat.jsonl\n" ] } ], "source": [ "import json, pathlib\n", "pathlib.Path(\"data\").mkdir(exist_ok=True)\n", "\n", "news_samples = [\n", " {\"style\":\"news_lead\",\"topic\":\"경제\",\"title\":\"반도체 수출 호조… 7월 수출액 20% 증가\",\"summary\":\"수출 개선세가 이어지며 경기 회복 기대가 커졌다.\"},\n", " {\"style\":\"news_headline\",\"topic\":\"정치\",\"title\":\"국회, 데이터 산업 육성법 본회의 통과\",\"summary\":\"데이터 활용 촉진과 개인정보 보호를 강화하는 내용.\"},\n", " {\n", " \"style\": \"news_lead\",\n", " \"topic\": \"경제\",\n", " \"title\": \"카카오페이 보안 점검… 고객문의: help+vip@corp.co.kr\",\n", " \"summary\": \"고객센터 010-1234-5678로 문의 폭주. 계좌 110-123-456789 관련 결제 오류 논란.\"\n", " },\n", " {\n", " \"style\": \"news_headline\",\n", " \"topic\": \"사회\",\n", " \"title\": \"개인정보 유출 의혹… 주민번호 901010-1234567 유통 주장\",\n", " \"summary\": \"서울특별시 강남구 테헤란로 123에서 자료 확보… 담당자 john.doe+news@example.com\"\n", " }\n", "]\n", "\n", "chat_samples = [\n", " {\"style\":\"kakao_casual\",\"dialog\":[\"주말에 비 온대?\",\"응 일요일에 꽤 온다더라 ☔\",\"헐 우산 챙겨야겠다\"]},\n", " {\"style\":\"kakao_formal\",\"dialog\":[\"안녕하세요. 배송 일정 확인 부탁드립니다.\",\"내일 중 도착 예정입니다.\",\"안내 감사합니다.\"]},\n", " {\n", " \"style\": \"kakao_formal\",\n", " \"dialog\": [\n", " \"배송 확인 부탁드립니다. 주문번호 ORD-2025-0001 입니다.\",\n", " \"연락처는 010-2222-3333 입니다. 
(유니코드 하이픈)\",\n", " \"주민등록번호는 제공할 수 없습니다.\"\n", " ]\n", " }\n", "]\n", "\n", "with open(\"data/news.jsonl\",\"w\",encoding=\"utf-8\") as f:\n", " for ex in news_samples: f.write(json.dumps(ex, ensure_ascii=False)+\"\\n\")\n", "with open(\"data/chat.jsonl\",\"w\",encoding=\"utf-8\") as f:\n", " for ex in chat_samples: f.write(json.dumps(ex, ensure_ascii=False)+\"\\n\")\n", "\n", "print(\"Created: data/news.jsonl, data/chat.jsonl\")" ] }, { "cell_type": "markdown", "id": "4f1eaa27", "metadata": {}, "source": [ "## 6) 전처리(PIPA) & 스타일 라벨 · PII Scrubbing & Style Tags" ] }, { "cell_type": "code", "execution_count": 13, "id": "430c1b68", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "data/news.jsonl -> data/news_clean.jsonl | rows: 4, redacted_rows: 2, hits: {'[EMAIL]': 2, '[ACCOUNT]': 1, '[RRN]': 1, '[CITY]': 1}\n", "data/chat.jsonl -> data/chat_clean.jsonl | rows: 3, redacted_rows: 1, hits: {'[PHONE]': 1}\n" ] } ], "source": [ "# Step 6 — PII scrubbing + style tags (no Harmony here)\n", "import json, re, unicodedata\n", "from pathlib import Path\n", "\n", "# --- Normalization helpers ---\n", "HYPHENS = dict.fromkeys(map(ord, \"‐-‒–—―﹘﹣-\"), ord(\"-\")) # map unicode hyphens → ASCII\n", "def normalize(s: str) -> str:\n", " if not isinstance(s, str): return s\n", " s = unicodedata.normalize(\"NFKC\", s)\n", " s = s.translate(HYPHENS)\n", " return s\n", "\n", "# --- PII patterns (illustrative; tune for production) ---\n", "RE_EMAIL = re.compile(r\"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}\")\n", "# KR mobile numbers with spaces/hyphens: 010-1234-5678, 010 1234 5678, etc.\n", "RE_PHONE = re.compile(r\"\\b01[016789][-\\s]?\\d{3,4}[-\\s]?\\d{4}\\b\")\n", "# Korean RRN (주민등록번호) basic pattern\n", "RE_RRN = re.compile(r\"\\b\\d{6}-\\d{7}\\b\")\n", "# Bank-ish account numbers: strictly digits in groups (avoid codes with letters)\n", "RE_ACCOUNT = re.compile(r\"\\b\\d{2,3}-\\d{2,4}-\\d{3,6}\\b\")\n", "# Very simple postal address cue (city names) – conservative, just redact the token (optional)\n", "RE_CITY = re.compile(r\"(서울특별시|부산광역시|대구광역시|인천광역시|광주광역시|대전광역시|울산광역시|세종특별자치시|경기도|강원도|충청북도|충청남도|전라북도|전라남도|경상북도|경상남도|제주특별자치도)\")\n", "\n", "# Allowlist: things that look like PII but aren’t (e.g., bill/order codes w/ letters)\n", "def looks_like_code(s: str) -> bool:\n", " return bool(re.search(r\"[A-Za-z]\", s)) # if letters present, treat as code, not account/phone\n", "\n", "# Order of application matters (longest/most specific first sometimes helps)\n", "SCRUBBERS = [\n", " (\"[RRN]\", RE_RRN),\n", " (\"[EMAIL]\", RE_EMAIL),\n", " (\"[PHONE]\", RE_PHONE),\n", " (\"[ACCOUNT]\", RE_ACCOUNT),\n", " (\"[CITY]\", RE_CITY), # optional; comment out if you don't want to redact city tokens\n", "]\n", "\n", "def scrub_text(text: str) -> tuple[str, dict]:\n", " \"\"\"Return (scrubbed_text, hits_dict). 
Avoid false positives with basic allowlisting.\"\"\"\n", " if not isinstance(text, str) or not text:\n", " return text, {}\n", " orig = text\n", " text = normalize(text)\n", " hits = {}\n", "\n", " # Guard account-like and phone-like strings that contain letters (likely codes)\n", " guarded = set()\n", " for m in RE_ACCOUNT.finditer(text):\n", " if looks_like_code(m.group(0)):\n", " guarded.add(m.span())\n", " for m in RE_PHONE.finditer(text):\n", " if looks_like_code(m.group(0)):\n", " guarded.add(m.span())\n", "\n", " # Apply scrubs\n", " for label, pattern in SCRUBBERS:\n", " out = []\n", " last = 0\n", " count = 0\n", " for m in pattern.finditer(text):\n", " span = m.span()\n", " if pattern in (RE_ACCOUNT, RE_PHONE) and span in guarded:\n", " continue\n", " out.append(text[last:span[0]])\n", " out.append(label)\n", " last = span[1]\n", " count += 1\n", " out.append(text[last:])\n", " text = \"\".join(out)\n", " if count:\n", " hits[label] = hits.get(label, 0) + count\n", "\n", " return text, hits if text != orig else {}\n", "\n", "def scrub_record(rec: dict, kind: str) -> tuple[dict, dict]:\n", " \"\"\"Scrub fields in a news/chat record; return (new_rec, hits).\"\"\"\n", " rec = dict(rec) # shallow copy\n", " total_hits = {}\n", "\n", " def scrub_field(key):\n", " val = rec.get(key)\n", " new, hits = scrub_text(val) if isinstance(val, str) else (val, {})\n", " rec[key] = new\n", " for k, v in hits.items():\n", " total_hits[k] = total_hits.get(k, 0) + v\n", "\n", " if kind == \"news\":\n", " for key in (\"title\", \"summary\", \"topic\"):\n", " scrub_field(key)\n", " elif kind == \"chat\":\n", " scrub_field(\"style\")\n", " if isinstance(rec.get(\"dialog\"), list):\n", " cleaned_dialog = []\n", " for turn in rec[\"dialog\"]:\n", " new, hits = scrub_text(turn) if isinstance(turn, str) else (turn, {})\n", " cleaned_dialog.append(new)\n", " for k, v in hits.items():\n", " total_hits[k] = total_hits.get(k, 0) + v\n", " rec[\"dialog\"] = cleaned_dialog\n", "\n", " return rec, total_hits\n", "\n", "# --- Style tagger (lightweight labels for later routing/metrics) ---\n", "def build_style_tags(rec: dict, kind: str) -> list[str]:\n", " tags = []\n", " if kind == \"news\":\n", " tags.append(\"domain:\" + (rec.get(\"topic\") or \"unknown\"))\n", " tags.append(\"style:\" + (rec.get(\"style\") or \"news\"))\n", " tags.append(\"tone:formal\")\n", " tags.append(\"medium:news\")\n", " elif kind == \"chat\":\n", " style = (rec.get(\"style\") or \"\").lower()\n", " tags.append(\"style:\" + (style or \"chat\"))\n", " tags.append(\"tone:\" + (\"formal\" if \"formal\" in style else \"casual\"))\n", " tags.append(\"medium:kakao\")\n", " return [t.replace(\" \", \"_\") for t in tags]\n", "\n", "# --- Process files ---\n", "def process_file(src: str, dst: str, kind: str):\n", " total = 0\n", " redacted = 0\n", " counters = {}\n", " with open(src, encoding=\"utf-8\") as fin, open(dst, \"w\", encoding=\"utf-8\") as fout:\n", " for line in fin:\n", " if not line.strip(): continue\n", " rec = json.loads(line)\n", " total += 1\n", " cleaned, hits = scrub_record(rec, kind)\n", " cleaned[\"style_tags\"] = build_style_tags(cleaned, kind)\n", " cleaned[\"_pii_hits\"] = hits # keep for inspection; drop later if you want\n", " if hits: redacted += 1\n", " for k, v in hits.items():\n", " counters[k] = counters.get(k, 0) + v\n", " fout.write(json.dumps(cleaned, ensure_ascii=False) + \"\\n\")\n", " print(f\"{src} -> {dst} | rows: {total}, redacted_rows: {redacted}, hits: {counters}\")\n", "\n", 
"process_file(\"data/news.jsonl\", \"data/news_clean.jsonl\", kind=\"news\")\n", "process_file(\"data/chat.jsonl\", \"data/chat_clean.jsonl\", kind=\"chat\")" ] }, { "cell_type": "markdown", "id": "6ac01dca", "metadata": {}, "source": [ "## 7) 데이터 로딩/포맷팅 · Load & Format" ] }, { "cell_type": "code", "execution_count": 15, "id": "9cd825e3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Created: data/news_harmony.jsonl data/chat_harmony.jsonl\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6f769d524f424ed5a11781a157cfa796", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating news split: 0 examples [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "af2e4dc971884747a719d500caf52722", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating chat split: 0 examples [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'train': 3, 'validation': 4}\n" ] } ], "source": [ "# Step 7 — Harmony conversion + dataset loading & tokenization\n", "import json, math\n", "from pathlib import Path\n", "from datasets import load_dataset, Dataset, concatenate_datasets\n", "from transformers import AutoTokenizer\n", "\n", "DATA = Path(\"data\")\n", "assert (DATA / \"news_clean.jsonl\").exists(), \"Run Step 6 first\"\n", "assert (DATA / \"chat_clean.jsonl\").exists(), \"Run Step 6 first\"\n", "\n", "# ---------- 7A) Convert cleaned → Harmony messages ----------\n", "\n", "def news_to_messages(rec):\n", " # system style from Step 6 tags; default to KR news tone\n", " system = \"한국 뉴스 문체로 간결하고 사실 위주로 작성.\"\n", " # user asks for a headline+lead from topic; assistant is the expected formatted answer\n", " user = f\"주제: {rec.get('topic','알수없음')}. 기사 제목과 요약을 생성해줘.\"\n", " assistant = f\"{rec.get('title','')} — {rec.get('summary','')}\"\n", " return [{\"role\":\"system\",\"content\":system},\n", " {\"role\":\"user\",\"content\":user},\n", " {\"role\":\"assistant\",\"content\":assistant}]\n", "\n", "def chat_to_messages(rec):\n", " # Keep style hint (casual/formal) in system\n", " style = (rec.get(\"style\") or \"\").lower()\n", " system = f\"카카오톡 대화 스타일. 
style={style or 'chat'}\"\n",
"    dialog = rec.get(\"dialog\") or []\n",
"    msgs = [{\"role\":\"system\",\"content\":system}]\n",
"    # Alternate user/assistant turns; if odd length, last user stays without assistant label\n",
"    roles = [\"user\",\"assistant\"]\n",
"    for i, turn in enumerate(dialog[:6]):  # cap tiny demos to avoid runaway\n",
"        msgs.append({\"role\": roles[i % 2], \"content\": str(turn)})\n",
"    # Ensure there is at least one assistant turn for SFT\n",
"    if not any(m[\"role\"]==\"assistant\" for m in msgs):\n",
"        msgs.append({\"role\":\"assistant\",\"content\":\"네, 확인했습니다.\"})\n",
"    return msgs\n",
"\n",
"def write_harmony(src, dst, kind):\n",
"    convert = news_to_messages if kind==\"news\" else chat_to_messages\n",
"    with open(src, encoding=\"utf-8\") as fin, open(dst, \"w\", encoding=\"utf-8\") as fout:\n",
"        for line in fin:\n",
"            if not line.strip(): continue\n",
"            rec = json.loads(line)\n",
"            msgs = convert(rec)\n",
"            fout.write(json.dumps({\"messages\": msgs}, ensure_ascii=False) + \"\\n\")\n",
"\n",
"write_harmony(DATA/\"news_clean.jsonl\", DATA/\"news_harmony.jsonl\", \"news\")\n",
"write_harmony(DATA/\"chat_clean.jsonl\", DATA/\"chat_harmony.jsonl\", \"chat\")\n",
"print(\"Created:\", DATA/\"news_harmony.jsonl\", DATA/\"chat_harmony.jsonl\")\n",
"\n",
"# ---------- 7B) Load Harmony JSONL with 🤗 Datasets ----------\n",
"raw = load_dataset(\n",
"    \"json\",\n",
"    data_files={\"news\": str(DATA/\"news_harmony.jsonl\"),\n",
"                \"chat\": str(DATA/\"chat_harmony.jsonl\")}\n",
")\n",
"\n",
"# Mix train split using your Step-2 mix ratios\n",
"news = raw[\"news\"]\n",
"chat = raw[\"chat\"]\n",
"\n",
"def take_portion(ds, frac):\n",
"    n = max(1, int(round(len(ds) * frac)))\n",
"    return ds.select(range(n)) if n < len(ds) else ds\n",
"\n",
"news_part = take_portion(news, MIX_NEWS if 'MIX_NEWS' in globals() else 0.5)\n",
"chat_part = take_portion(chat, MIX_CHAT if 'MIX_CHAT' in globals() else 0.5)\n",
"train_ds = concatenate_datasets([news_part, chat_part]).shuffle(seed=42)\n",
"\n",
"# Tiny validation built from remaining examples (if any)\n",
"remaining_news = news.select(range(len(news_part), len(news))) if len(news) > len(news_part) else news_part\n",
"remaining_chat = chat.select(range(len(chat_part), len(chat))) if len(chat) > len(chat_part) else chat_part\n",
"val_candidates = concatenate_datasets([remaining_news, remaining_chat])\n",
"val_ds = val_candidates.shuffle(seed=43).select(range(min(64, len(val_candidates)))) if len(val_candidates) else train_ds.select(range(min(32, len(train_ds))))\n",
"\n",
"dataset = {\"train\": train_ds, \"validation\": val_ds}\n",
"print({k: len(v) for k, v in dataset.items()})\n" ] }, { "cell_type": "markdown", "id": "c95c9122", "metadata": {}, "source": [ "## 8) 모델/토크나이저 로드 · Load Model & Tokenizer" ] }, { "cell_type": "code", "execution_count": 16, "id": "db67b6b3", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1cfc411479e145e4b5b161df311d4b13", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer_config.json: 0.00B [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ebea3ddd62e340cc83e2a484a04e3e89", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer_config.json: 0.00B [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "330fd60c5e1248998f0f5bc8c394b2ce", "version_major": 2, 
"version_minor": 0 }, "text/plain": [ "tokenizer.json: 0%| | 0.00/27.9M [00:00\n", "{{ m['content'] }}<|end|>\n", "{%- elif m['role'] == 'user' -%}<|user|>\n", "{{ m['content'] }}<|end|>\n", "{%- elif m['role'] == 'assistant' -%}<|assistant|>\n", "{{ m['content'] }}<|end|>\n", "{%- endif -%}\n", "{%- endfor -%}\"\"\"\n", "\n", "# Ensure pad/eos are sane\n", "tokenizer.pad_token = tokenizer.eos_token or tokenizer.pad_token\n", "\n", "# ---------- 7D) Tokenize with assistant-only labels ----------\n", "ASST_TOKEN = None\n", "END_TOKEN = None\n", "try:\n", " ASST_TOKEN = tokenizer.convert_tokens_to_ids(\"<|assistant|>\")\n", " END_TOKEN = tokenizer.convert_tokens_to_ids(\"<|end|>\")\n", "except Exception:\n", " # If the base vocab lacks these tokens, it's okay; masking fallback below will still work heuristically\n", " pass\n", "\n", "MAX_LEN = 2048 # you can raise this if you have room\n", "\n", "def tokenize_with_labels(example):\n", " # 1) Render with chat template (includes assistant answer)\n", " text = tokenizer.apply_chat_template(example[\"messages\"], tokenize=False, add_generation_prompt=False)\n", " # 2) Tokenize\n", " enc = tokenizer(text, truncation=True, max_length=MAX_LEN)\n", " input_ids = enc[\"input_ids\"]\n", " labels = [-100] * len(input_ids)\n", "\n", " # 3) Label only assistant content\n", " if ASST_TOKEN is not None and END_TOKEN is not None:\n", " start = None\n", " for i, tid in enumerate(input_ids):\n", " if tid == ASST_TOKEN:\n", " start = i + 1 # learn after the tag\n", " elif start is not None and tid == END_TOKEN:\n", " start = None\n", " elif start is not None:\n", " labels[i] = input_ids[i]\n", " else:\n", " # Heuristic fallback: learn on the last third of tokens (crude but avoids total silence)\n", " start = int(len(input_ids) * 0.66)\n", " for i in range(start, len(input_ids)):\n", " labels[i] = input_ids[i]\n", "\n", " return {\"input_ids\": input_ids, \"attention_mask\": enc[\"attention_mask\"], \"labels\": labels}\n", "\n", "tokenized_train = dataset[\"train\"].map(tokenize_with_labels, remove_columns=[\"messages\"])\n", "tokenized_val = dataset[\"validation\"].map(tokenize_with_labels, remove_columns=[\"messages\"])\n", "\n", "print(\"Tokenization done.\",\n", " \"train:\", len(tokenized_train),\n", " \"val:\", len(tokenized_val),\n", " \"example lens:\", tokenized_train[0][\"input_ids\"][:12], \"...\")" ] }, { "cell_type": "markdown", "id": "f67dd4ef", "metadata": {}, "source": [ "## 9) Fine‑Tuning (LoRA/QLoRA) · 세밀 튜닝\n", "### 9a) Data curation & splits\n", "_(See Section 7/8 for dataset prep; move relevant snippets here if needed.)_\n", "### 9b) Hyperparameters (r/alpha/dropout)\n", "```python\n", "# Example LoRA hyperparameters\n", "LORA_R = 8\n", "LORA_ALPHA = 16\n", "LORA_DROPOUT = 0.05\n", "```\n", "\n", "### 9c) Merge adapters (BF16)\n", "```python\n", "# Example merge step (after training)\n", "# model = PeftModel.from_pretrained(base_model, adapter_path)\n", "# merged_model = model.merge_and_unload()\n", "```\n", "\n", "### 9d) Save merged BF16 (`save_pretrained`)\n", "```python\n", "# merged_model.save_pretrained(OUTPUT_DIR)\n", "```\n" ] }, { "cell_type": "markdown", "id": "c9157315", "metadata": {}, "source": [ "### 9e) Export & Quantize (BF16 → MXFP4) · 내보내기 & 양자화\n", "\n", "**EN (neutral, framework-agnostic):** \n", "Public libraries currently do **not** support training/fine‑tuning *directly* in MXFP4. The common pipeline is:\n", "1) **Train/SFT** in **BF16** (or **QLoRA 4‑bit nf4**). 
\n", "2) **Merge LoRA adapters** into the base model (BF16). \n", "3) **Save** the merged BF16 checkpoint with `save_pretrained()`. \n", "4) **Post‑training quantize** the merged BF16 tensors to **MXFP4** using a **vendor/toolchain‑provided packer**. \n", "5) **Save/export** the MXFP4 artifact (same shape as Hugging Face `save_pretrained()` output) for deployment/serving.\n", "\n", "> Notes: \n", "> - If your serving stack supports **LoRA at inference**, you may skip merging and quantization and ship: **base (MXFP4 or BF16) + LoRA adapters**. \n", "> - If your runtime requires **merged MXFP4**, you must run a **BF16 → MXFP4** quantization step after merging adapters. \n", "> - Keep **tokenizer/config** files aligned across BF16 and MXFP4 exports.\n", "\n", "**KR (중립적, 도구 비의존):** \n", "현재 공개 라이브러리는 MXFP4에서 **직접 학습/파인튜닝을 지원하지 않습니다**. 일반적인 파이프라인은 다음과 같습니다: \n", "1) **BF16**(또는 **QLoRA 4‑bit nf4**)로 **학습/파인튜닝** \n", "2) **LoRA 어댑터 병합**(BF16 기준) \n", "3) `save_pretrained()`로 **병합된 BF16 체크포인트 저장** \n", "4) 벤더/툴체인에서 제공하는 **양자화 도구**로 **BF16 → MXFP4 사후 양자화** \n", "5) 배포/서빙용 **MXFP4 아티팩트 저장/내보내기** (Hugging Face `save_pretrained()` 구조와 동일)\n", "\n", "> 참고: \n", "> - **서빙에서 LoRA를 지원**한다면, 병합·양자화를 생략하고 **기저( MXFP4 또는 BF16 ) + LoRA 어댑터**로 제공할 수 있습니다. \n", "> - **병합된 MXFP4**가 필요한 런타임의 경우, 어댑터 병합 후 **BF16 → MXFP4 재양자화** 단계가 필요합니다. \n", "> - **tokenizer/config** 파일은 BF16과 MXFP4 아티팩트 간에 일관되게 유지하세요.\n" ] }, { "cell_type": "code", "execution_count": 19, "id": "48a5cbc9", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fine‑tuning skeleton ready. Un‑comment on your machine.\n" ] } ], "source": [ "from trl import SFTTrainer, SFTConfig\n", "from peft import LoraConfig, get_peft_model\n", "\n", "lora_cfg = LoraConfig(\n", " task_type=\"CAUSAL_LM\",\n", " r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,\n", " target_modules=TARGET_MODULES\n", ")\n", "\n", "# base_model = get_peft_model(base_model, lora_cfg)\n", "\n", "sft_args = SFTConfig(\n", " output_dir=OUTPUT_DIR,\n", " num_train_epochs=EPOCHS,\n", " per_device_train_batch_size=PER_DEVICE_BS,\n", " gradient_accumulation_steps=GRAD_ACCUM,\n", " learning_rate=LEARNING_RATE,\n", " lr_scheduler_type=\"cosine\",\n", " bf16=BF16,\n", " logging_steps=LOG_STEPS,\n", " save_steps=SAVE_STEPS,\n", " save_total_limit=SAVE_TOTAL_LIMIT\n", ")\n", "\n", "# trainer = SFTTrainer(model=base_model, args=sft_args, train_dataset=combined, tokenizer=tokenizer)\n", "# trainer.train()\n", "# trainer.save_model(OUTPUT_DIR)\n", "print(\"Fine‑tuning skeleton ready. Un‑comment on your machine.\")" ] }, { "cell_type": "markdown", "id": "490798f2", "metadata": {}, "source": [ "## 10) 평가(뉴스/대화) · Evaluation (News/Chat)" ] }, { "cell_type": "markdown", "id": "d1bdafe4", "metadata": {}, "source": [ "**KR 지표 · KR Metrics** \n", "- 뉴스성: 주제 분류 적합도(F1), 요약 품질(ROUGE‑1/2/L), 독해 QA(EM/F1). \n", "- 대화성: 자연성/맥락 유지, 경어/반말 전환 정확도, 이모티콘/축약어 적절성.\n", "\n", "**EN Notes** \n", "- Use public KR benchmarks (e.g., topic classification, KorQuAD‑like QA) where licenses permit.\n", "- Mix automatic metrics (F1/ROUGE) with human eval for tone & politeness." 
] }, { "cell_type": "code", "execution_count": 20, "id": "971b8dbd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Eval stubs ready.\n" ] } ], "source": [ "# Example helpers (stub)\n", "def simple_accuracy(preds, labels):\n", " return sum(int(p==g) for p,g in zip(preds, labels)) / max(1, len(labels))\n", "\n", "# For ROUGE:\n", "# import evaluate\n", "# rouge = evaluate.load(\"rouge\")\n", "# result = rouge.compute(predictions=pred_texts, references=ref_texts)\n", "# print(result)\n", "\n", "print(\"Eval stubs ready.\")" ] }, { "cell_type": "markdown", "id": "e0b5594e", "metadata": {}, "source": [ "## 11) Inference Prompt Templates · 추론 프롬프트 템플릿" ] }, { "cell_type": "code", "execution_count": 25, "id": "1f690452", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\n", "Knowledge cutoff: 2024-06\n", "Current date: 2025-08-21\n", "\n", "Reasoning: medium\n", "\n", "# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions\n", "\n", "너는 한국 고객을 돕는 유능한 AI 어시스턴트다.\n", "\n", "<|end|><|start|>user<|message|>국내 PIPA 규정을 준수하면서 사내 문서 요약기를 구성하려면 어떤 아키텍처가 좋을까?<|end|><|start|>assistant\n" ] } ], "source": [ "from openai_harmony import Message, ChatFormatter\n", "\n", "# Example prompt construction using Harmony\n", "messages = [\n", " Message(role=\"system\", content=\"너는 한국 고객을 돕는 유능한 AI 어시스턴트다.\"),\n", " Message(role=\"user\", content=\"국내 PIPA 규정을 준수하면서 사내 문서 요약기를 구성하려면 어떤 아키텍처가 좋을까?\")\n", "]\n", "\n", "prompt = ChatFormatter.to_chat_prompt(messages)\n", "print(prompt) # For preview; pass to tokenizer when running inference\n" ] }, { "cell_type": "markdown", "id": "5216d049", "metadata": {}, "source": [ "## 12) 최신성 유지 · Freshness Strategy" ] }, { "cell_type": "markdown", "id": "452decd1", "metadata": {}, "source": [ "- **주간 보정 SFT**: 허용된 뉴스 API **메타데이터(제목/요약/섹션)** 샘플링 → 스타일 보정. \n", "- **대화체 업데이트**: 최신 축약어/신조어/이모티콘 사전 반영(예: ㄱㄱ, ㅇㅋ, ㅋㅋ, ㄹㅇ). \n", "- **회귀 평가**: 동일 지표로 before/after 비교 → 혼합비/온도/패널티 튜닝.\n", "\n", "- Weekly calibration SFT using **allowed news API metadata** for style; \n", "- Update slang/emoji lexicons; \n", "- Regression evals to track drift and adjust data mix/decoding." ] }, { "cell_type": "markdown", "id": "718b9f2a", "metadata": {}, "source": [ "## 13) 안전/컴플라이언스 · Safety & Compliance" ] }, { "cell_type": "markdown", "id": "61ad24ef", "metadata": {}, "source": [ "- 데이터 출처/라이선스 확인(벤치마크, API, 내부 데이터) · Verify dataset/API licenses.\n", "- 개인정보 스크러빙(훈련/로그/평가 전) · Scrub PII before training/logging/eval.\n", "- 저작권/약관 준수(기사 **원문 대량 재학습 금지**) · Avoid mass training on full news articles.\n", "- 출력 검증(스키마/금칙어/민감도 규칙) · Output validation & forbidden‑term filters.\n", "- 버전/평가 리포트 관리 · Version datasets/models and keep eval reports." ] }, { "cell_type": "markdown", "id": "5cb8464b", "metadata": {}, "source": [ "## 14) 문제해결 & 다음 단계 · Troubleshooting & Next Steps" ] }, { "cell_type": "markdown", "id": "8ee17077", "metadata": {}, "source": [ "- 혼합 비율 튜닝: (뉴스:대화) 6:4 → 7:3 또는 5:5로 조정 \n", "- LoRA 하이퍼파라미터: r=8~16, α=16~32, dropout=0.05~0.1 \n", "- 서비스화: vLLM/llama.cpp 서빙 + 토픽/스타일 라우팅 \n", "- RAG 결합: 최신 사실성 보강을 위해 뉴스/문서 인덱스 결합 \n", "- A/B 테스트: 톤/길이/이모티콘 사용량 등 사용자 만족도 측정\n", "\n", "- Tune mix ratios, run A/B tests, consider vLLM serving, and pair with RAG for factuality." 
] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }