How to fine-tune the Gemma model with LoRA

Gemma is a family of lightweight, state-of-the-art open models from Google, built from the same research and technology behind the Gemini models. It ships with pretrained and instruction-tuned checkpoints in several sizes, along with tooling that facilitates natural language processing (NLP) research and development.

Gemma's weights are freely available under an open license, allowing researchers, developers, and practitioners to leverage its capabilities for various NLP applications. Through Keras 3 and KerasNLP, it can run on the JAX, TensorFlow, or PyTorch backends, offering flexibility and interoperability.

Gemma integrates closely with Kaggle, a renowned platform for data science competitions and collaborative projects. Its pretrained checkpoints are hosted there, and with the Kaggle API, users can download them directly into their Kaggle or Colab notebooks and plug them into their training and evaluation workflows.

Low-Rank Adaptation (LoRA)

Low-Rank Adaptation (LoRA) is a technique used in machine learning and optimization to efficiently adapt a pretrained model to a new task or dataset with fewer computational resources.

Imagine we have a large pretrained model that performs well on a general task. Retraining the entire model from scratch for a new task would be computationally expensive and time-consuming. LoRA addresses this by freezing the pretrained weights and injecting small, trainable low-rank matrices that represent the weight updates needed for the new task. Because the rank of these matrices is much smaller than the dimensions of the original weights, only a tiny fraction of the parameters is trained, while the model still captures the task-specific patterns it needs.

In summary, LoRA adapts pretrained models to new tasks or datasets by training only a small set of low-rank parameters, enabling faster training and resource-efficient machine learning.
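
To make the idea concrete, here is a minimal NumPy sketch of the low-rank update (the layer dimensions and rank below are illustrative, not Gemma's actual sizes): the pretrained weight matrix W stays frozen, and only the small matrices A and B are trained.

import numpy as np

d, k, r = 512, 512, 4              # illustrative layer dimensions and LoRA rank
W = np.random.randn(d, k)          # frozen pretrained weight matrix
A = np.random.randn(r, k) * 0.01   # trainable low-rank matrix (r x k)
B = np.zeros((d, r))               # trainable low-rank matrix (d x r), starts at zero

x = np.random.randn(k)             # an example input vector

y_pretrained = W @ x               # output of the original layer
y_lora = (W + B @ A) @ x           # output with the low-rank update applied

# The adapter adds far fewer trainable parameters than W itself.
print(W.size, A.size + B.size)     # 262144 vs. 4096

With a rank of 4 and a 512x512 weight matrix, the adapter contains only about 1.5% as many parameters as the original layer, which is why LoRA fine-tuning is so much cheaper.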

Let's fine-tune the Gemma model using LoRA!

Use your Kaggle API

To download the model from Kaggle, we first need to authenticate with its API. Go to Kaggle and click "Settings". There will be an option to create an API token. Clicking "Create New Token" downloads a kaggle.json file containing your credentials: your username and an API key.

Create new token on Kaggle

After obtaining your credentials, go to your Google Colab notebook and change the runtime type to T4 GPU. A T4 offers a good balance of performance, cost-effectiveness, availability, and compatibility for running models such as Gemma: it is designed for high-performance computing and supports NVIDIA Tensor Cores, enabling faster training and inference.
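
Optionally, you can confirm that a GPU is attached to the runtime before proceeding (the exact output depends on the hardware Colab assigns you):

!nvidia-smi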

After this, let's connect our Kaggle account with our Google Colab using the following steps:

  1. Create two secrets named KAGGLE_USERNAME and KAGGLE_KEY in the Secrets tab in Google Colab (the tab with the little key icon). Set their values to the "username" and "key" fields from your kaggle.json file, respectively.

Environment variables required to connect Colab with the Kaggle account.

After configuring this, write the following commands in the notebook:

import os
from google.colab import userdata
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Python code to connect our notebook with our Kaggle API

These commands read the secrets stored in Colab and expose them as the KAGGLE_USERNAME and KAGGLE_KEY environment variables, which is what the Kaggle API looks for when authenticating. Once they are set, we can download models and datasets from Kaggle directly in the notebook.
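
As a quick sanity check (just standard environment-variable lookups, nothing Kaggle-specific), you can verify that both variables are set before moving on:

# Raise an error early if the Colab secrets were not configured correctly.
assert os.environ.get("KAGGLE_USERNAME"), "KAGGLE_USERNAME is not set"
assert os.environ.get("KAGGLE_KEY"), "KAGGLE_KEY is not set"
print("Kaggle credentials are configured.")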

Install the dependencies

To load and fine-tune Gemma with LoRA, we need to install some dependencies first.

We need tensorflow >= 2.15 and keras >= 3. Use the following commands to install the necessary dependencies:

!pip install -q -U tensorflow
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"

We can check if the correct version was installed using:

!pip show tensorflow
!pip show keras

Lastly, let's import the necessary libraries:

import tensorflow as tf
import keras
import keras_nlp
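
With the libraries imported, we can also confirm the installed versions directly from Python:

# Quick check that TensorFlow >= 2.15 and Keras >= 3 are active.
print("TensorFlow:", tf.__version__)
print("Keras:", keras.__version__)
print("KerasNLP:", keras_nlp.__version__)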

Configure the backend

Keras 3 can run on multiple backends: JAX, TensorFlow, and PyTorch. In our code, and following the Gemma documentation, we will use JAX. JAX leverages just-in-time (JIT) compilation and XLA (Accelerated Linear Algebra) to execute computations efficiently, which often results in faster training and inference than the other backends. Note that Keras reads the KERAS_BACKEND environment variable at import time, so this variable should be set before keras is imported (set it at the top of the notebook, or restart the runtime after changing it).

To configure JAX, we will use the following commands:

os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

We have also set a second environment variable, XLA_PYTHON_CLIENT_MEM_FRACTION, to 1.00. This tells JAX's XLA client to preallocate up to 100% of the available GPU memory, which helps avoid memory fragmentation and out-of-memory errors when loading a model as large as Gemma.
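
You can verify which backend Keras is actually using with the standard Keras API call:

# Should print "jax" if the environment variable was set before Keras was imported.
print(keras.backend.backend())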

Load the model

Now our dependencies are installed and our configurations are complete. It's time to load the model to see how it works. To load the model, write the following commands:

gemma_model = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_model.summary()

Note: You need to gain access to the Gemma model by accepting its license agreement on Kaggle before the preset can be downloaded.

This should give us a result like this:

The usage of the code is explained as follows:

  • Line 1: We are creating an instance of the GemmaCausalLM model using the from_preset method from the Keras NLP library. The "gemma_2b_en" argument specifies the preset configuration of the GemmaCausalLM model, indicating a specific pretrained configuration or architecture to use.

  • Line 2: Next, we print a summary of the GemmaCausalLM model’s architecture. The summary() method is a utility function provided by Keras models, which prints a concise summary of the model’s layers, their output shapes, and the number of parameters in each layer.

Generate outputs from the model

The Gemma model can be used to generate various kinds of output. To understand its functionality, let's generate a couple of examples.

  1. Simple text generation: We will ask the model a simple question and evaluate its answer. We will use the generate() function to produce an answer to our query, and we will specify the maximum length (in tokens) that the output may have.

gemma_model.generate("What is the meaning of life?", max_length=64)

This command gave us the following result:

Answer generated by our model

As a result, we got an answer to our question within the maximum length we had provided. The answer is coherent, cohesive, and adequate for the question.
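
The generate() method of KerasNLP causal language models can also take a list of prompts, which is handy for producing several answers in one call (the second prompt below is just an illustrative example):

# Batched generation: one call, multiple prompts.
prompts = [
    "What is the meaning of life?",
    "Explain photosynthesis in one sentence.",
]
for text in gemma_model.generate(prompts, max_length=64):
    print(text)
    print("-" * 40)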

  2. Using prompt engineering: We will ask the model to generate an answer based on a template and a prompt. Let's write a template and a prompt for our answer:

template = "Instruction: {instruction}\nResponse: {response}"
prompt = template.format(
instruction="What should I do on a trip to Europe?",
response="", )

Code explanation

The explanation of the code above is as follows:

  • Line 1: We define a template string with placeholders for an instruction and a response. The {instruction} and {response} placeholders will be filled in later with specific values.

  • Lines 2–4: We create a prompt string by formatting the template with specific values. The instruction placeholder is filled with the instruction "What should I do on a trip to Europe?", and the response placeholder is left empty for now.

Next, let's configure a sampler, compile the model with it, and generate a response to the prompt above:

sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_model.compile(sampler=sampler)
print(gemma_model.generate(prompt, max_length=256))

Let's break the code step-by-step:

  • Line 1: We are creating a TopKSampler object with a k value of 5 (indicating that during text generation, the model will consider the top 5 most likely tokens for each step) and a seed value of 2 (for reproducibility).

  • Line 2: Next, we are compiling the Gemma model for text generation using the specified sampler. The Gemma model is configured to use the TopKSampler during text generation.

  • Line 3: Using the Gemma model, we generate a text sequence based on the provided prompt. The generate method takes the prompt string, the maximum length of the generated sequence (256 tokens in this case), and other generation parameters as input. The generated text sequence is then printed.

The result of the above code is as follows:

Answer generated by our model

The result is now generated according to our template, and the model provides a comprehensive answer to our question.
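
To see how the sampling strategy affects the output, you can swap in a different sampler, for example KerasNLP's GreedySampler, and regenerate (the result shown above used top-k sampling):

# Greedy decoding always picks the most likely next token: deterministic,
# but often more repetitive than top-k sampling.
gemma_model.compile(sampler=keras_nlp.samplers.GreedySampler())
print(gemma_model.generate(prompt, max_length=256))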

Fine-tune Gemma using LoRA

Now, it's time to fine-tune Gemma using LoRA, and our backend configuration will stay the same. First, let's download a dataset to fine-tune our model.

Load and preprocess the dataset

We will be using the databricks-dolly-15k dataset, released by Databricks and hosted on Hugging Face. The dataset contains roughly 15,000 instruction/response records. However, to keep training fast, we will use only 1,000 examples.

Let's download the dataset using the following command:

!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl

After downloading the dataset, we need to preprocess it. We will drop the examples that include a context field, which keeps the fine-tuning setup simple, and then format each remaining example into a single string.

import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        if features["context"]:
            continue
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

data = data[:1000]

This code reads data from a file in JSON Lines format (.jsonl), filters out examples with context, and formats the remaining examples into a specific template. Let’s break it down step-by-step:

  • Lines 1–2: We are importing the Python json module, which provides functions for encoding and decoding JSON data. Next, we are initializing an empty list called data to store formatted examples.

  • Lines 3–5: First, we open the specified file named databricks-dolly-15k.jsonl in read mode using a context manager. The file is expected to be in JSON Lines format, where each line represents a single JSON object. Next, we are iterating over each line in the opened file. Lastly, we load the JSON data from the current line of the file and parse it into a Python dictionary named features using the json.loads() function.

  • Lines 6–7: We check whether the "context" field of the current example is non-empty. If it is, the code continues to the next iteration of the loop, skipping this example; only context-free records are kept.

  • Lines 8–9: We are defining a template string with placeholders for the instruction and response parts of each example. Next, we format the current example using the template string and the features dictionary, and append the formatted string to the data list. The **features syntax unpacks the dictionary and passes its key-value pairs as keyword arguments to the format() method.

  • Line 11: Lastly, we are limiting the data list to contain only the first 1000 training examples. This is done to reduce the size of the dataset and speed up processing.

Our dataset is now ready and we can begin fine-tuning our model.
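
Before training, it is worth spot-checking the processed data (the exact record you see depends on the dataset ordering):

# How many examples did we keep, and what does one look like?
print(len(data))   # 1000
print(data[0])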

Training Gemma model on the dataset

First, let's enable LoRA on our model using the following commands:

gemma_model.backbone.enable_lora(rank=4)
gemma_model.summary()

The output of the above command is given below:

Summary of the model with LoRA enabled

The trainable parameters have decreased considerably after enabling LoRA: the memory they occupy drops from 9.34 GB to 5.20 MB. The advantage of LoRA is already visible.
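
You can also count the trainable parameters yourself using the standard Keras trainable_weights attribute (a small sanity check, independent of the summary output):

import numpy as np

# Total number of scalar values across all trainable weight tensors.
trainable = sum(int(np.prod(w.shape)) for w in gemma_model.trainable_weights)
print(f"Trainable parameters: {trainable:,}")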

Let's start our training:

gemma_model.preprocessor.sequence_length = 64

optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_model.fit(data, epochs=1, batch_size=1)

Let's break the implementation down step-by-step:

  • Line 1: We set the maximum input sequence length to 64 tokens. By limiting the input sequence length, the code aims to control memory usage during training. This can be important when dealing with large datasets or limited computational resources.

  • Lines 3–5: We first define an AdamW optimizer with a learning rate of 5e-5 and a weight decay of 0.01. AdamW is a variant of the Adam optimizer that includes weight decay regularization to prevent overfitting in transformer models.

  • Line 7: We specify a rule that the weight decay regularization should not be applied to the parameters named bias and scale. In transformer models, these parameters typically correspond to the bias terms and layer normalization scale parameters, respectively. Excluding them from weight decay can help stabilize training and prevent certain parameters from becoming overly penalized during optimization.

  • Lines 9–12: Here, we compile the Gemma model for training. It specifies the loss function, optimizer, and evaluation metrics for training. In this case, the loss function is SparseCategoricalCrossentropy, the optimizer is the AdamW optimizer defined earlier, and the evaluation metric is SparseCategoricalAccuracy.

  • Line 14: We are training the Gemma model on the provided data for one epoch using a batch_size of 1. Training the model involves iteratively updating its parameters (weights) based on the input data and the specified loss function and optimizer. The number of epochs determines how many times the entire dataset is processed during training. Using a batch_size of 1 means that each training example is processed individually, which can be inefficient but is sometimes necessary for models with memory constraints or specific training requirements.

The result of the above code is as follows, which shows us the loss and the accuracy of the model:

Result of training our model on the dataset

Since our dataset was small and we trained for just one epoch, we shouldn't expect stellar results. Still, roughly 56% token-level accuracy is not bad for a single epoch.

Let's make a prediction using the same example as above ("What should I do on a trip to Europe?") and see the difference in the results:
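
Here is a minimal sketch of that generation call, assuming we reuse the top-k sampler from earlier and format the prompt with the same Instruction/Response template used in the fine-tuning data:

# Rebuild the prompt in the format the model was fine-tuned on.
prompt = "Instruction:\n{instruction}\n\nResponse:\n{response}".format(
    instruction="What should I do on a trip to Europe?",
    response="")

gemma_model.compile(sampler=keras_nlp.samplers.TopKSampler(k=5, seed=2))
print(gemma_model.generate(prompt, max_length=64))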

Result of our query on a fine-tuned model

Our result is now concise and to the point. The answer is conversational, whereas before it read as fairly robotic. Since we reduced max_length to 64, the answer is shorter, but it still stays on point.

Now you have an idea of how Gemma works and the impact that fine-tuning has on the results of our prompts. You can implement it on your own now, and change up the prompts to better understand it!

