Gemma is a family of lightweight, open-weight large language models from Google, built from the same research and technology behind the Gemini models. It ships with pretrained checkpoints, tooling, and documentation that facilitate natural language processing (NLP) research and development.
Gemma's weights are freely available, allowing researchers, developers, and practitioners to leverage its capabilities for various NLP applications. Through KerasNLP, the models can run on the major deep learning frameworks, including JAX, TensorFlow, and PyTorch, offering flexibility and interoperability.
Gemma also integrates closely with Kaggle, a renowned platform for data science competitions and collaborative projects. Its pretrained weights are hosted on Kaggle Models, and with the Kaggle API, users can pull them directly into Kaggle kernels, notebooks, or external environments such as Google Colab. This makes it straightforward to access Gemma's pretrained models and build training and evaluation workflows around them.
Low-Rank Adaptation (LoRA) is a technique used in machine learning and optimization to efficiently adapt a pretrained model to a new task or dataset with fewer computational resources.
Imagine we have a large pretrained model that performs well in general. Retraining the entire model from scratch for a new task can be computationally expensive and time-consuming. LoRA addresses this challenge by freezing the pretrained weights and learning a small, low-rank update for selected weight matrices: instead of modifying a full weight matrix, it represents the change as the product of two much smaller matrices. This captures the task-specific adaptation while drastically reducing the number of trainable parameters.
In summary, LoRA is a technique that leverages low-rank updates to efficiently adapt pretrained models to new tasks or datasets with reduced computational cost, enabling faster training and resource-efficient machine learning.
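To make the idea concrete, below is a minimal NumPy sketch of the LoRA parameterization. It is only an illustration of the low-rank update, not the Gemma implementation; the dimensions and initialization values are arbitrary assumptions.
import numpy as np

d, r = 1024, 4                      # hidden size and LoRA rank (r << d)
W = np.random.randn(d, d)           # frozen pretrained weight matrix
A = np.random.randn(d, r) * 0.01    # trainable low-rank factor
B = np.zeros((r, d))                # trainable low-rank factor, initialized to zero

W_adapted = W + A @ B               # effective weight used in the forward pass

print("Full fine-tuning parameters:", W.size)          # 1048576
print("LoRA trainable parameters:", A.size + B.size)   # 8192
With rank 4, the trainable update is less than 1% of the size of the full matrix, which is exactly why enabling LoRA later in this lesson shrinks the trainable parameter count so dramatically.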
Let's fine-tune the Gemma model using LoRA!
To leverage the model from Kaggle, let's use its API to authenticate our account. Go to Kaggle, and then click "Settings". There will be an option to create an API token. Click on "Create New Token", and this will download a .json file containing your credentials.
After obtaining your credentials, go to your Google Colab notebook and change the runtime type to T4 GPU. Using a T4 GPU for running machine learning models, such as the Gemma model, offers a balanced solution regarding performance, cost-effectiveness, availability, and compatibility. T4 GPUs are designed for high-performance computing tasks and provide significant computational power with support for NVIDIA Tensor Cores, enabling faster training and inference.
After this, let's connect our Kaggle account with our Google Colab using the following steps:
Create two environment variables in the Secrets tab in Google Colab (the tab with the little key icon): KAGGLE_USERNAME and KAGGLE_KEY. Their values will be the "username" and "key" fields from your .json file, respectively.
After configuring this, write the following commands in the notebook:
import os
from google.colab import userdata

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
These commands configure the environment variables KAGGLE_USERNAME and KAGGLE_KEY with the credentials required to access the Kaggle API. This allows the user to authenticate with the Kaggle API and perform various tasks leveraging Kaggle's functionality.
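Optionally, we can confirm that authentication works by running a simple Kaggle CLI query (this assumes the kaggle package that comes preinstalled in Colab; the search term is arbitrary). If the credentials are wrong, the command returns an authentication error instead of a list of datasets:
!kaggle datasets list -s dolly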
To load and fine-tune Gemma with LoRA, we need to install some dependencies first.
We need tensorflow >= 2.15 and keras >= 3. Use the following commands to install the necessary dependencies:
!pip install -q -U tensorflow
!pip install -q -U "keras>=3"
!pip install -q -U keras-nlp
We can check if the correct version was installed using:
!pip show tensorflow
!pip show keras
Lastly, let's import the necessary libraries:
import tensorflow as tf
import keras
import keras_nlp
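As an optional sanity check, we can also confirm the installed versions from within Python (this assumes the imports above succeeded):
print(tf.__version__)     # should be >= 2.15
print(keras.__version__)  # should be >= 3.0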
There are multiple backends that keras can use, such as jax, tensorflow, and torch. In our code, and following the Gemma documentation, we will use jax as our backend. jax leverages just-in-time (JIT) compilation and XLA (Accelerated Linear Algebra) to execute computations efficiently, resulting in faster training and inference times compared to other backends.
To configure jax, we will use the following commands. Note that Keras selects its backend at import time, so in practice these variables should be set before the import statements above (or the runtime restarted after setting them):
os.environ["KERAS_BACKEND"] = "jax"os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
Furthermore, we have also configured another environment variable, XLA_PYTHON_CLIENT_MEM_FRACTION, to 1.00. Setting this variable lets jax use up to the specified fraction (in this case, 100%) of the available GPU memory, which can help avoid memory fragmentation and improve performance when using the jax backend.
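As a quick, optional check, we can confirm which backend Keras picked up (assuming the environment variable was set before Keras was imported):
print(keras.backend.backend())  # expected output: "jax"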
Now our dependencies are installed and our configurations are complete. It's time to load the model to see how it works. To load the model, write the following commands:
gemma_model = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_model.summary()
Note: You can gain access to the Gemma model weights by accepting the license agreement through Kaggle.
This should give us a result like this:
The usage of the code is explained as follows:
Line 1: We create an instance of the GemmaCausalLM model using the from_preset method from the KerasNLP library. The "gemma_2b_en" argument specifies the preset configuration of the GemmaCausalLM model, indicating a specific pretrained configuration or architecture to use.
Line 2: Next, we print a summary of the GemmaCausalLM model's architecture. The summary() method is a utility function provided by Keras models that prints a concise summary of the model's layers, their output shapes, and the number of parameters in each layer.
The Gemma model can be used to generate various types of outputs. To understand its functionality, let's generate a couple of outputs.
Simple text generation: We will ask the model a simple question and evaluate its answer. We will use the generate() function to generate an answer to our query, and we will provide the maximum length we want our answer to be.
gemma_model.generate("What is the meaning of life?", max_length=64)
This command gave us the following result:
As a result, we got an answer to our question within the maximum length we had provided. The answer is coherent, cohesive, and an adequate response to the question.
Using prompt engineering: We will ask the model to generate an answer based on a template and a prompt. Let's write a template and a prompt for our answer:
template = "Instruction: {instruction}\nResponse: {response}"prompt = template.format(instruction="What should I do on a trip to Europe?",response="", )
The explanation of the code above is as follows:
Line 1: We define a template string with placeholders for an instruction and a response. The {instruction} and {response} placeholders will be filled in later with specific values.
Lines 2–4: We create a prompt string by formatting the template with specific values. The instruction placeholder is filled with the instruction "What should I do on a trip to Europe?", and the response placeholder is left empty for now.
Next, let's compile our model with the prompt given above:
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_model.compile(sampler=sampler)
print(gemma_model.generate(prompt, max_length=256))
Let's break down the code step-by-step:
Line 1: We create a TopKSampler object with a k value of 5 (indicating that during text generation, the model will consider the top 5 most likely tokens at each step) and a seed value of 2 (for reproducibility).
Line 2: Next, we are compiling the Gemma model for text generation using the specified sampler. The Gemma model is configured to use the TopKSampler during text generation.
Line 3: Using the Gemma model, we generate a text sequence based on the provided prompt. The generate method takes the prompt string, the maximum length of the generated sequence (256 tokens in this case), and other generation parameters as input. The generated text sequence is then printed.
The result of the above code is as follows:
The result is now generated according to our template, and the model provides a comprehensive answer to our question.
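As an optional experiment (not part of the original workflow), we could swap in a different sampler to see how the decoding strategy affects the output. For example, a greedy sampler always picks the single most likely token, making the output deterministic but often more repetitive:
greedy_sampler = keras_nlp.samplers.GreedySampler()
gemma_model.compile(sampler=greedy_sampler)
print(gemma_model.generate(prompt, max_length=256))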
Now, it's time to fine-tune Gemma using LoRA, and our backend configuration will stay the same. First, let's download a dataset to fine-tune our model.
We will be using the databricks-dolly-15k dataset from Hugging Face. The dataset contains approximately 15,000 instruction/response records. However, to keep training simple, we will use only 1,000 examples.
Let's download the dataset using the following command:
!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
After downloading the dataset, we need to preprocess it. We will filter out examples that contain context so it is easier to fine-tune our model. Next, we will format each remaining example into a single string.
import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        if features["context"]:
            continue
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

data = data[:1000]
This code reads data from a file in JSON Lines format (.jsonl), filters out examples with context, and formats the remaining examples into a specific template. Let's break it down step-by-step:
Lines 1–2: We import the Python json module, which provides functions for encoding and decoding JSON data. Next, we initialize an empty list called data to store the formatted examples.
Lines 3–5: First, we open the file named databricks-dolly-15k.jsonl in read mode using a context manager. The file is expected to be in JSON Lines format, where each line represents a single JSON object. Next, we iterate over each line in the opened file. Lastly, we load the JSON data from the current line of the file and parse it into a Python dictionary named features using the json.loads() function.
Lines 6–7: We check whether the "context" field in the features dictionary is non-empty. If it is (i.e., the example comes with context), the code continues to the next iteration of the loop, skipping this example and thereby filtering out the examples that have context.
Lines 8–9: We define a template string with placeholders for the instruction and response parts of each example. Next, we format the current example using the template string and the features dictionary, and append the formatted string to the data list. The **features syntax unpacks the dictionary and passes its key-value pairs as keyword arguments to the format() method.
Line 11: Lastly, we limit the data list to the first 1,000 training examples. This is done to reduce the size of the dataset and speed up training.
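Before training, it can be helpful to eyeball one of the formatted examples to confirm the template was applied correctly (an optional check, not part of the original code):
print(len(data))  # should print 1000
print(data[0])    # should print an "Instruction: ... Response: ..." block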
Our dataset is now ready and we can begin fine-tuning our model.
First, let's enable LoRA on our model using the following commands:
gemma_model.backbone.enable_lora(rank=4)
gemma_model.summary()
The output of the above command is given below:
The number of trainable parameters has decreased considerably after enabling LoRA: their memory footprint dropped from 9.34 GB to 5.20 MB. The advantage of LoRA is already visible.
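If we want to verify this programmatically rather than reading it off the summary, we can sum the sizes of the model's trainable weights (a small optional check that assumes the standard Keras 3 trainable_weights attribute):
import numpy as np

# Only the LoRA adapter weights should remain trainable after enable_lora().
trainable = sum(int(np.prod(w.shape)) for w in gemma_model.trainable_weights)
print(f"Trainable parameters after enabling LoRA: {trainable:,}")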
Let's start our training:
gemma_model.preprocessor.sequence_length = 64

optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_model.fit(data, epochs=1, batch_size=1)
Let's break the implementation down step-by-step:
Line 1: We set the maximum input sequence length to 64 tokens. By limiting the input sequence length, the code aims to control memory usage during training. This can be important when dealing with large datasets or limited computational resources.
Lines 3–5: We first define an AdamW optimizer with a learning rate of 5e-5 and a weight decay of 0.01. AdamW is a variant of the Adam optimizer that includes weight decay regularization to prevent overfitting in transformer models.
Line 7: We specify that weight decay regularization should not be applied to the parameters named bias and scale. In transformer models, these parameters typically correspond to the bias terms and the layer normalization scale parameters, respectively. Excluding them from weight decay can help stabilize training and prevent certain parameters from being overly penalized during optimization.
Lines 9–12: Here, we compile the Gemma model for training, specifying the loss function, optimizer, and evaluation metrics. In this case, the loss function is SparseCategoricalCrossentropy, the optimizer is the AdamW optimizer defined earlier, and the evaluation metric is SparseCategoricalAccuracy.
Line 14: We train the Gemma model on the provided data for one epoch using a batch_size of 1. Training the model involves iteratively updating its parameters (weights) based on the input data and the specified loss function and optimizer. The number of epochs determines how many times the entire dataset is processed during training. Using a batch_size of 1 means that each training example is processed individually, which can be inefficient but is sometimes necessary for models with memory constraints or specific training requirements.
The result of the above code is as follows, which shows us the loss and the accuracy of the model:
Since our dataset was small and we trained our model for just one epoch, our results were expected to be poor. However, a 56% accuracy is not bad for the first epoch.
Let's make a prediction using the same example as above ("What should I do on a trip to Europe?") and see the difference in the results:
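The exact prediction code is not shown here; a minimal sketch, assuming we reuse the instruction/response template and the sampler configured earlier, might look like this:
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="")
print(gemma_model.generate(prompt, max_length=64))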
Our result is now concise and to the point. The answer is conversational, whereas before, it was very robotic. Since we reduced max_length to 64, the answer is shorter, but it is still to the point.
Now you have an idea of how Gemma works and the impact that fine-tuning has on the results of our prompts. You can implement it on your own now, and change up the prompts to better understand it!