5.2. Linear Transformations
5.2.1. What is a Linear Transformation?
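A linear transformation takes an input vector and produces an output vector by multiplying it by a matrix and then adding a bias vector. (Strictly speaking, the added bias makes this an affine transformation; machine learning texts commonly call it linear anyway, and this section follows that convention.) In machine learning, the weights (\(W\)) and biases (\(\mathbf{b}\)) map data from one space to another.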
For an input vector \(\mathbf{x}\) in \(\mathbb{R}^n\), a linear transformation produces an output vector \(\mathbf{y}\) in \(\mathbb{R}^m\) using the equation
\[\mathbf{y} = W\mathbf{x} + \mathbf{b}\]
where \(W\) is an \(m \times n\) weight matrix, \(\mathbf{b}\) is an \(m\)-dimensional bias vector, \(\mathbf{x}\) is an \(n\)-dimensional input vector, and \(\mathbf{y}\) is an \(m\)-dimensional output vector.
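The relation \(\mathbf{y} = W\mathbf{x} + \mathbf{b}\) can be sketched in NumPy as follows. This is a minimal illustration, not a reference implementation: the helper name `linear_transform` and all numeric values are assumptions chosen for the example (here \(n = 3\) and \(m = 2\)). It also checks that each output component equals one row of \(W\) dotted with \(\mathbf{x}\), plus the matching bias entry.

```python
import numpy as np

def linear_transform(W, x, b):
    """Compute y = W x + b for a 1-D input x (hypothetical helper)."""
    W, x, b = np.asarray(W), np.asarray(x), np.asarray(b)
    m, n = W.shape
    # The input must match the columns of W, the bias must match its rows
    if x.shape != (n,):
        raise ValueError(f"x must have shape ({n},), got {x.shape}")
    if b.shape != (m,):
        raise ValueError(f"b must have shape ({m},), got {b.shape}")
    return W @ x + b

# Illustrative values (assumed): n = 3 inputs, m = 2 outputs
W = np.array([[1.0, 0.0, 2.0],
              [-1.0, 3.0, 0.5]])
b = np.array([0.5, -1.0])
x = np.array([2.0, 1.0, 4.0])

y = linear_transform(W, x, b)

# Row-wise view: y_i = (row i of W) . x + b_i
y_rowwise = np.array([W[i] @ x + b[i] for i in range(W.shape[0])])

print(y)                          # components 10.5 and 2.0
print(np.allclose(y, y_rowwise))  # True
```

The explicit shape checks are optional; they only turn a dimension mismatch into a readable error message.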
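Example 1: 2D to 1D Transformation. Consider transforming a 2D input to a 1D output. If we have input \(\mathbf{x} = [x_1, x_2]^T\) and want a scalar output \(y\), we use
\[y = w_1 x_1 + w_2 x_2 + b\]
where \(W = [w_1, w_2]\) is a \(1 \times 2\) weight matrix and \(b\) is a scalar bias.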
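As a worked instance of the 2D to 1D case, assume the illustrative values \(W = [2, -1]\), \(b = 0.5\), and \(\mathbf{x} = [3, 4]^T\) (these numbers are not from the original text and are chosen only to make the arithmetic easy to follow):
\[y = W\mathbf{x} + b = \begin{bmatrix} 2 & -1 \end{bmatrix} \begin{bmatrix} 3 \\ 4 \end{bmatrix} + 0.5 = (2)(3) + (-1)(4) + 0.5 = 2.5\]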
Exercise
If a 4-element array is linearly transformed to a 2-element array, what are the shapes of the matrix \(W\) and vector \(\mathbf{b}\)?
Solution
For a linear transformation from a 4-element array to a 2-element array, the input vector \(\mathbf{x}\) has shape \((4,)\) or \(4 \times 1\), and the output vector \(\mathbf{y}\) has shape \((2,)\) or \(2 \times 1\).
Using the equation \(\mathbf{y} = W\mathbf{x} + \mathbf{b}\), the weight matrix \(W\) has shape \(2 \times 4\), and the bias vector \(\mathbf{b}\) has shape \((2,)\) or \(2 \times 1\).
Explanation: The weight matrix \(W\) must have dimensions that allow matrix multiplication: \((m \times n) \times (n \times 1) = (m \times 1)\). Since we want \((2 \times 1)\) output from \((4 \times 1)\) input, we need \(W\) to be \(2 \times 4\). The bias vector \(\mathbf{b}\) must have the same dimensions as the output vector, so it is \(2 \times 1\).
Example in code:
```python
import numpy as np

# Weight matrix (2x4) and bias vector (2,)
W = np.array([[1, 2, -1, 3],    # weights for y1
              [0, 1, 2, -1]])   # weights for y2
b = np.array([1, -2])           # bias vector

# Input vector (4,)
x = np.array([1, 2, 3, 4])

# Apply transformation
y = W @ x + b

print(f"W shape: {W.shape}")  # (2, 4)
print(f"b shape: {b.shape}")  # (2,)
print(f"x shape: {x.shape}")  # (4,)
print(f"y shape: {y.shape}")  # (2,)
print(f"Output: {y}")         # [15  2]
```
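The solution gives two equivalent shape conventions: 1-D arrays such as \((2,)\) and column vectors such as \(2 \times 1\). The sketch below reuses the weights from the example with explicit column vectors, then shows what can go wrong when the two conventions are mixed. The variable names `x_col`, `b_col`, `b_flat`, and `mixed` are introduced here for illustration only.

```python
import numpy as np

W = np.array([[1, 2, -1, 3],
              [0, 1, 2, -1]])             # shape (2, 4)

# Column-vector convention: x is 4x1 and b is 2x1
x_col = np.array([[1], [2], [3], [4]])    # shape (4, 1)
b_col = np.array([[1], [-2]])             # shape (2, 1)

y_col = W @ x_col + b_col
print(y_col.shape)  # (2, 1)

# Mixing conventions: a (2,) bias added to a (2, 1) product is broadcast
# to (2, 2), which is not the intended transformation
b_flat = np.array([1, -2])                # shape (2,)
mixed = W @ x_col + b_flat
print(mixed.shape)  # (2, 2)
```

Because NumPy broadcasts silently here instead of raising an error, it helps to pick one convention for \(\mathbf{x}\), \(\mathbf{b}\), and \(\mathbf{y}\) and keep it throughout.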