
LLM Fine-Tuning: A Practical Guide to Customizing AI Models for Your Needs

Learn how to fine-tune Large Language Models for specific tasks. Covers data preparation, training strategies, evaluation, and deployment of custom AI models.

Sani Mridha

Senior Mobile Developer

2024-01-05 · 18 min read

Fine-tuning Large Language Models (LLMs) allows you to adapt powerful base models to your specific use case. Let's explore how to do it effectively.

Why Fine-Tune?

When to Fine-Tune

Fine-tune when you need:

  • Consistent output format or style
  • Domain-specific knowledge
  • Improved performance on specific tasks
  • Cost reduction (smaller fine-tuned models can replace larger ones)
  • Data privacy (sensitive data stays in training, not in prompts)

Don't fine-tune when:

  • Prompt engineering works well enough
  • You have insufficient training data (<100 examples)
  • The task is too general
  • You need real-time information (use RAG instead)

Fine-Tuning Methods

    1. Full Fine-Tuning

    Update all model parameters. Most effective but resource-intensive.

    from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
    
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
    
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        fp16=True,
        logging_steps=10,
    )
    
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    
    trainer.train()

    2. LoRA (Low-Rank Adaptation)

    Efficient fine-tuning by training small adapter layers.

    from peft import LoraConfig, get_peft_model
    
    lora_config = LoraConfig(
        r=16,  # Rank
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )
    
    model = get_peft_model(model, lora_config)
    
    # Less than 0.1% of parameters are trainable!
    model.print_trainable_parameters()
    # Output: trainable params: 4,194,304 || all params: 6,738,415,616 || trainable%: 0.062

    3. QLoRA (Quantized LoRA)

    LoRA with 4-bit quantization for even more efficiency.

    import torch
    from transformers import BitsAndBytesConfig
    
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        quantization_config=bnb_config,
        device_map="auto"
    )
    
    model = get_peft_model(model, lora_config)

    Data Preparation

    Quality Over Quantity

    Good training data is crucial:

    # Example: Customer support fine-tuning data
    training_examples = [
        {
            "instruction": "Respond to customer inquiry about shipping",
            "input": "When will my order arrive?",
            "output": "I'd be happy to check your order status. Could you please provide your order number? Once I have that, I can give you an accurate delivery estimate."
        },
        {
            "instruction": "Respond to customer inquiry about returns",
            "input": "I want to return this product",
            "output": "I understand you'd like to return your purchase. Our return policy allows returns within 30 days. To start the process, I'll need your order number and the reason for the return."
        },
        # ... more examples
    ]

    Data Format

    Convert to model-specific format:

    def format_instruction(example):
        prompt = f"""### Instruction:
    {example['instruction']}
    
    ### Input:
    {example['input']}
    
    ### Response:
    {example['output']}"""
        return {"text": prompt}
    
    formatted_dataset = dataset.map(format_instruction)

    Data Quality Checks

    def validate_dataset(dataset):
        issues = []
        seen_texts = set()
        
        for idx, example in enumerate(dataset):
            # Check length
            if len(example['text']) < 50:
                issues.append(f"Example {idx}: Too short")
            
            # Check for duplicates
            if example['text'] in seen_texts:
                issues.append(f"Example {idx}: Duplicate")
            seen_texts.add(example['text'])
            
            # Check format
            if "### Response:" not in example['text']:
                issues.append(f"Example {idx}: Missing response section")
        
        return issues

    Training Process

    Step 1: Setup

    from transformers import AutoTokenizer, AutoModelForCausalLM
    from datasets import load_dataset
    
    # Load model and tokenizer
    model_name = "meta-llama/Llama-2-7b-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        load_in_8bit=True,
        device_map="auto"
    )
    
    # Load and prepare dataset
    dataset = load_dataset("json", data_files="training_data.jsonl")
    dataset = dataset.map(lambda x: tokenizer(x["text"], truncation=True, max_length=512))

    Step 2: Configure Training

    from transformers import TrainingArguments
    
    training_args = TrainingArguments(
        output_dir="./llama-2-7b-finetuned",
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        fp16=True,
        save_total_limit=3,
        logging_steps=10,
        evaluation_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        load_best_model_at_end=True,
        report_to="wandb",  # For tracking
    )

    Step 3: Train

    from transformers import Trainer, DataCollatorForLanguageModeling
    
    # mlm=False gives causal-LM labels (inputs shifted by one)
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
    
    # train_dataset / eval_dataset: splits of the tokenized dataset from Step 1
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )
    
    # Start training
    trainer.train()
    
    # Save the model
    trainer.save_model("./final_model")

    Hyperparameter Tuning

    Learning Rate

    Critical parameter - too high causes instability, too low is inefficient:

    # Learning rate finder (uses the torch-lr-finder package)
    from torch_lr_finder import LRFinder
    
    def find_lr(model, train_loader, optimizer, criterion):
        lr_finder = LRFinder(model, optimizer, criterion)
        lr_finder.range_test(train_loader, end_lr=1, num_iter=100)
        lr_finder.plot()
        lr_finder.reset()

    Batch Size and Gradient Accumulation

    Balance memory and training speed:

    # Effective batch size = per_device_batch_size * gradient_accumulation_steps * num_gpus
    # Example: 4 * 4 * 1 = 16 effective batch size
    
    training_args = TrainingArguments(
        per_device_train_batch_size=4,  # Fits in memory
        gradient_accumulation_steps=4,   # Accumulate gradients
        # Effective batch size: 16
    )
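The arithmetic in the comment above can be captured in a tiny helper (the function name is illustrative, not a library API):

```python
def effective_batch_size(per_device_batch_size, gradient_accumulation_steps, num_gpus=1):
    # Gradients are averaged over this many examples per optimizer step
    return per_device_batch_size * gradient_accumulation_steps * num_gpus

print(effective_batch_size(4, 4))  # → 16
```

Raising `gradient_accumulation_steps` lets you hit a target effective batch size without increasing per-device memory use, at the cost of more forward/backward passes per optimizer step.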

    LoRA Parameters

    # Experiment with different configurations
    lora_configs = [
        {"r": 8, "lora_alpha": 16},   # Smaller, faster
        {"r": 16, "lora_alpha": 32},  # Balanced
        {"r": 32, "lora_alpha": 64},  # Larger, more capacity
    ]
    
    for config in lora_configs:
        model = train_with_config(config)
        evaluate(model)

    Evaluation

    Quantitative Metrics

    from evaluate import load
    
    # Perplexity
    perplexity = load("perplexity")
    results = perplexity.compute(predictions=predictions, model_id=model_name)
    
    # BLEU score (for translation/generation)
    bleu = load("bleu")
    results = bleu.compute(predictions=predictions, references=references)
    
    # ROUGE score (for summarization)
    rouge = load("rouge")
    results = rouge.compute(predictions=predictions, references=references)

    Qualitative Evaluation

    def evaluate_model_responses(model, test_cases):
        # format_prompt / generate_response / rate_* are your own helpers
        # (human raters or an LLM judge can back the rate_* scores)
        results = []
        
        for case in test_cases:
            prompt = format_prompt(case["input"])
            response = generate_response(model, prompt)
            
            evaluation = {
                "input": case["input"],
                "expected": case["expected"],
                "actual": response,
                "scores": {
                    "relevance": rate_relevance(response, case["expected"]),
                    "accuracy": rate_accuracy(response, case["expected"]),
                    "style": rate_style(response),
                }
            }
            results.append(evaluation)
        
        return results

    A/B Testing

    import numpy as np
    
    def ab_test(base_model, finetuned_model, test_set):
        base_scores = []
        finetuned_scores = []
        
        for example in test_set:
            base_response = base_model.generate(example["prompt"])
            finetuned_response = finetuned_model.generate(example["prompt"])
            
            # Human evaluation or automated scoring
            base_score = evaluate_response(base_response, example["expected"])
            finetuned_score = evaluate_response(finetuned_response, example["expected"])
            
            base_scores.append(base_score)
            finetuned_scores.append(finetuned_score)
        
        print(f"Base model avg: {np.mean(base_scores)}")
        print(f"Fine-tuned model avg: {np.mean(finetuned_scores)}")
        print(f"Improvement: {np.mean(finetuned_scores) - np.mean(base_scores)}")
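Beyond raw averages, a paired summary shows how often fine-tuning actually wins per example; a minimal sketch (the helper name is illustrative):

```python
def summarize_ab(base_scores, finetuned_scores):
    # Paired per-example comparison: average lift and win rate
    pairs = list(zip(base_scores, finetuned_scores))
    mean_lift = sum(f - b for b, f in pairs) / len(pairs)
    win_rate = sum(f > b for b, f in pairs) / len(pairs)
    return {"mean_lift": mean_lift, "win_rate": win_rate}

print(summarize_ab([0.5, 0.6, 0.7], [0.7, 0.6, 0.9]))
```

A high mean lift driven by a few examples can hide regressions elsewhere, which the win rate surfaces.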

    Production Deployment

    Model Optimization

    # Merge LoRA weights for faster inference
    from peft import PeftModel
    
    base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
    finetuned_model = PeftModel.from_pretrained(base_model, "./lora_weights")
    
    # Merge and save
    merged_model = finetuned_model.merge_and_unload()
    merged_model.save_pretrained("./merged_model")

    Inference Optimization

    # Use vLLM for fast inference
    from vllm import LLM, SamplingParams
    
    llm = LLM(model="./merged_model", tensor_parallel_size=1)
    
    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.95,
        max_tokens=256
    )
    
    def generate_batch(prompts):
        outputs = llm.generate(prompts, sampling_params)
        return [output.outputs[0].text for output in outputs]

    API Deployment

    from fastapi import FastAPI
    from pydantic import BaseModel
    
    app = FastAPI()
    
    class GenerationRequest(BaseModel):
        prompt: str
        max_tokens: int = 256
        temperature: float = 0.7
    
    @app.post("/generate")
    async def generate(request: GenerationRequest):
        response = llm.generate(
            request.prompt,
            SamplingParams(
                max_tokens=request.max_tokens,
                temperature=request.temperature
            )
        )
        return {"response": response[0].outputs[0].text}

    Cost Optimization

    Training Costs

    | Method           | GPU Memory | Training Time | Cost (AWS p3.2xlarge) |
    |------------------|------------|---------------|-----------------------|
    | Full Fine-tuning | 32GB+      | 24 hours      | ~$75                  |
    | LoRA             | 16GB       | 8 hours       | ~$25                  |
    | QLoRA            | 8GB        | 10 hours      | ~$15                  |
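These figures are roughly GPU-hours times the instance's hourly rate. The ~$3.06/hr on-demand price below is an assumption that varies by region, and the QLoRA row presumably assumes a cheaper instance, since that job fits in 8 GB:

```python
def estimate_training_cost(gpu_hours, hourly_rate=3.06):
    # Rough cloud training cost: GPU-hours x on-demand hourly rate
    return round(gpu_hours * hourly_rate, 2)

print(estimate_training_cost(24))  # full fine-tuning: ~$73
```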

    Inference Costs

    # Use quantization for cheaper inference
    from transformers import AutoModelForCausalLM
    
    model = AutoModelForCausalLM.from_pretrained(
        "./merged_model",
        load_in_8bit=True,  # ~2x memory reduction vs fp16 weights
        device_map="auto"
    )
    
    # Or use ONNX for even faster inference
    from optimum.onnxruntime import ORTModelForCausalLM
    
    ort_model = ORTModelForCausalLM.from_pretrained(
        "./merged_model",
        export=True,
        provider="CUDAExecutionProvider"
    )

    Real-World Examples

    Example 1: Code Generation

    # Training data format
    {
        "instruction": "Write a Python function",
        "input": "Create a function that calculates fibonacci numbers",
        "output": """def fibonacci(n):
        if n <= 1:
            return n
        return fibonacci(n-1) + fibonacci(n-2)"""
    }

    Example 2: Customer Service

    # Training data format
    {
        "instruction": "Respond professionally to customer",
        "input": "This product is broken!",
        "output": "I sincerely apologize for the inconvenience. I'd like to help resolve this immediately. Could you please describe the issue in detail so I can assist you better?"
    }

    Example 3: Medical Documentation

    # Training data format
    {
        "instruction": "Generate medical note",
        "input": "Patient: fever, cough, 3 days",
        "output": "Chief Complaint: Fever and cough for 3 days. Patient reports onset of symptoms 3 days ago with persistent fever and productive cough. Recommend further evaluation and symptomatic treatment."
    }

    Common Pitfalls

    1. Overfitting

    Problem: Model memorizes training data

    Solution: Use validation set, early stopping, regularization

    training_args = TrainingArguments(
        # ... other args
        evaluation_strategy="steps",
        eval_steps=100,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
    )
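Transformers' `EarlyStoppingCallback` (passed via `callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]`) implements the patience logic behind this; the idea in isolation looks like the following sketch, not the library's actual code:

```python
def early_stopping_step(eval_losses, patience=3):
    # Stop once eval loss has failed to improve `patience` times in a row
    best, bad = float("inf"), 0
    for step, loss in enumerate(eval_losses):
        if loss < best:
            best, bad = loss, 0
        else:
            bad += 1
            if bad >= patience:
                return step
    return len(eval_losses) - 1

print(early_stopping_step([1.2, 0.9, 0.91, 0.93, 0.95]))  # → 4
```

Combined with `load_best_model_at_end=True`, training halts after the plateau but you keep the checkpoint from the best evaluation step.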

    2. Catastrophic Forgetting

    Problem: Model forgets general knowledge

    Solution: Mix general data with specific data

    from datasets import concatenate_datasets
    
    # 80% specific data, 20% general data
    combined_dataset = concatenate_datasets([
        specific_dataset.select(range(8000)),
        general_dataset.select(range(2000))
    ])

    3. Poor Data Quality

    Problem: Inconsistent or incorrect training data

    Solution: Rigorous data validation and cleaning
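A minimal cleaning pass in the same spirit as the validation checks earlier (the function name and length threshold are illustrative):

```python
def clean_examples(examples, min_output_len=20):
    # Normalize whitespace, drop too-short outputs and exact duplicates
    seen, cleaned = set(), []
    for ex in examples:
        record = (
            ex["instruction"].strip(),
            ex["input"].strip(),
            " ".join(ex["output"].split()),
        )
        if len(record[2]) < min_output_len or record in seen:
            continue
        seen.add(record)
        cleaned.append(dict(zip(("instruction", "input", "output"), record)))
    return cleaned
```

Normalizing whitespace before deduplicating matters: two examples that differ only in spacing would otherwise both survive and effectively double-weight that answer during training.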

    Conclusion

    Fine-tuning LLMs is powerful when done right:

    1. Start with clear objectives

    2. Prepare high-quality data

    3. Choose the right method (LoRA for most cases)

    4. Evaluate thoroughly

    5. Optimize for production

    Remember: Fine-tuning is not always the answer. Try prompt engineering and RAG first!

    ---

    *Have questions about fine-tuning? Let's discuss your use case!*

    Tags

    #AI #LLM #Fine-tuning #MachineLearning #DeepLearning
