The Shape of Learning: A Deep Dive into Keras Fundamentals
Understanding Tensors, Backends, and the High-Level Grammar Beneath Every Model
Why I Went Back to Keras Fundamentals
We often talk about models. We always tweet about architectures. We constantly argue about “the right” optimizer like it’s a sports rivalry.
In the last couple of years, AI went from niche to dinner-table topic, and everyone started wiring generative models into products that barely needed them.
But beneath the launch posts and leaderboard screenshots, there’s the quieter substrate that actually moves numbers from “random noise” to “working system.”
In PyTorch, the conversation often starts with tensors and autograd. In Keras, it starts with something just as important and, in its own way, more opinionated:
a high-level grammar for assembling differentiable programs, married to a simple mental model: define a model, compile it, fit it. Keras earns its keep by making the common path painless, while still letting you drop down to raw ops when the road gets weird.
This post aims to be the Keras fundamentals guide I wish I had when I first bounced between notebooks, tutorials, and production code. It isn’t a step-by-step “click here” tutorial.
It’s a map of the ground floor: what Keras actually is, what lives under model.fit, how tensors, dtypes, devices, and shapes flow through the system, and how to take control when abstractions get in your way.
If you want to understand why your training loop behaves the way it does, and how to change that behavior deliberately, then start here.
What Is Keras, Really?
At first glance, Keras looks like a friendly API for TensorFlow. That was historically true. Today it’s much bigger: Keras 3 is multi-backend.
The same Keras code can run atop TensorFlow, JAX, or PyTorch: you pick a backend and Keras routes ops, autograd, and device placement to the engine you chose. The north-star idea is: write once, run on your favorite accelerator stack.
Keras 3 makes this possible with a backend-agnostic ops layer (keras.ops) that mirrors the bread and butter of array programming (think NumPy): reshape, matmul, einsum, broadcast_to, reductions, activations, you name it.
The goal isn’t to hide details forever; it’s to keep your model code portable and clean until you actually need to drop to backend-specific APIs.
How do you choose a backend? Typically by environment variable (KERAS_BACKEND=tensorflow|jax|torch) or configuration at import time.
The same .keras model file you save on one backend can be reloaded on another, as long as any custom components use backend-agnostic APIs.
That’s a powerful portability guarantee when your infra (or curiosity) changes.
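As a minimal sketch of that workflow (the backend name here is just an example), you set the variable before the first import of keras:

import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before keras is imported

import keras
print(keras.backend.backend())  # confirms which engine is active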
The Hidden Terrain Beneath Every Model
Keras models transform tensors. Depending on your backend, a “tensor” might be a tf.Tensor, a JAX device array, or a Torch tensor: Keras wraps and normalizes behavior through keras.ops.
That means:
Shape (ops.shape(x)) is the first truth.
Rank (number of dimensions) is your contract with layers.
Dtype (e.g., float32, float16) determines speed and numerical stability.
Device (CPU/GPU/TPU) determines where math happens.
You don’t have to think about all this on every line. But you do need to know it’s there, and reach for it deliberately when you debug or scale.
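A quick illustration of peeking at that terrain through keras.ops (the tensor here is a toy example):

from keras import ops

x = ops.ones((2, 3), dtype="float16")
print(ops.shape(x))       # (2, 3): the shape contract you pass between layers
print(len(ops.shape(x)))  # 2: the rank
print(x.dtype)            # float16 (the exact repr depends on the backend's tensor type)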
The Grammar of Shapes (And Why Your Model Fails at 2am)
Deep learning is ruthless about shapes. A dense layer expects [batch, features]. A 2D convolution expects [batch, height, width, channels] by default in Keras (channels-last, though channels-first is supported and configurable).
Get one axis wrong and you’ll either crash loudly or, worse, learn nonsense quietly.
You’ll constantly use a few shape shifters:
import keras as K
from keras import ops
x = ops.arange(0, 12) # [12]
x2 = ops.reshape(x, (3, 4)) # [3, 4]
x3 = ops.expand_dims(x2, 0) # [1, 3, 4]
x4 = ops.squeeze(x3, 0) # back to [3, 4]
x5 = ops.broadcast_to(x2[0], (3, 4)) # [3, 4]
If this code looks like NumPy, that’s by design. Under the hood, Keras routes each call to the backend you’ve previously selected.
Rule of thumb: before blaming your optimizer, print shapes at every boundary: inputs, after each major block, before the loss.
Making Tensors Out of Thin Air
As we said a second ago, Keras gives you backend-agnostic ways to craft data:
x = ops.array([[1., 2.], [3., 4.]])
r = K.random.uniform((2, 3)) # random in [0,1)
z = ops.zeros((2, 2))
seq = ops.arange(0, 10, 2) # 0, 2, 4, 6, 8
And Keras provides an RNG API for reproducibility: keras.utils.set_random_seed(42) seeds Python, NumPy, and the backend so you can tell real improvement from lucky initialization.
Where Learning Actually Happens
Yes, we all know that models are “layers on layers,” but at runtime you’re essentially doing vectorized math:
y = ops.matmul(x, w) + b
y = ops.maximum(y, 0.0) # ReLU
That’s it. That’s the whole plot. And it scales because Keras leans on your backend’s autograd and kernel libraries.
For day-to-day modeling, you won’t write low-level ops. You’ll use Layers that wrap those ops with clean shape contracts and parameter/gradient bookkeeping.
Three Ways to Build a Model
Keras gives you three equally valid mental models:
Sequential for straight stacks:
from keras import layers, models
model = models.Sequential([
layers.Input((28, 28, 1)),
layers.Conv2D(32, 3, activation="relu"),
layers.Flatten(),
layers.Dense(10)
])
Functional API for DAGs with branches/skip connections:
inputs = layers.Input((32,))
x = layers.Dense(64, activation="relu")(inputs)
x = layers.Dropout(0.2)(x)
outputs = layers.Dense(1)(x)
model = models.Model(inputs, outputs)
Subclassing for full control (you define call, optionally override train_step):
class MyMLP(models.Model):
def __init__(self):
super().__init__()
self.d1 = layers.Dense(64, activation="relu")
self.d2 = layers.Dense(1)
def call(self, x, training=False):
x = self.d1(x)
return self.d2(x)
When you need exotic training behavior (contrastive pairs, multi-optimizers, gradient clipping schedules, etc.), override train_step and keep the convenience of fit. That pattern is officially supported and portable.
The “compile / fit / evaluate” Contract
Keras distills the training loop to a crisp interface:
model.compile(
optimizer="adam",
loss=K.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[K.metrics.SparseCategoricalAccuracy()]
)
history = model.fit(train_ds, epochs=10, validation_data=val_ds)
model.evaluate(test_ds)
Behind the scenes, Keras builds a backend-specific train function that runs the forward pass, computes the loss, backpropagates gradients, and applies updates; all of that without you writing unnecessary boilerplate.
If that’s all you need, stop here. If not, override train_step for custom behavior while preserving callbacks, distribution, and logs.
Data Pipelines That Scale
Keras plays well with plain NumPy arrays for small problems. For real workloads, reach for dataset utilities:
Image folders → datasets: keras.utils.image_dataset_from_directory(...)
Timeseries windows: keras.utils.timeseries_dataset_from_array(...)
Built-in toy datasets: keras.datasets.mnist / cifar10 / imdb / ... for quick experiments.
If you’re on the TensorFlow backend, tf.data is battle-tested for performance: prefetch, cache, interleave, shard. On JAX or PyTorch backends, you can still use Keras datasets or the ecosystem’s native options: Keras doesn’t get in your way.
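For example, on the TensorFlow backend a typical performance pass over a Keras-built dataset might look like this sketch (the folder path and tuning values are placeholders):

import tensorflow as tf  # tf.data is a TensorFlow-backend choice
import keras

train = keras.utils.image_dataset_from_directory(
    "data/train", image_size=(224, 224), batch_size=64
)
# Cache decoded batches, reshuffle each epoch, and overlap input prep with training.
train = train.cache().shuffle(1000).prefetch(tf.data.AUTOTUNE)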
CPU, GPU, TPU (and How Keras Plays Nice)
Accelerators change the shape of your day. Keras relies on the backend for device semantics:
TensorFlow backend: use GPUs/TPUs via tf.config and tf.distribute (MirroredStrategy, MultiWorkerMirroredStrategy). Keras integrates this natively: call strategy.scope() and build/compile inside it for synchronous multi-device training.
JAX backend: works with multiple accelerators via JAX’s device model; Keras routes ops through the XLA compiler.
PyTorch backend: uses Torch device semantics under the hood while preserving the Keras public API.
What you need to remember: keep the whole training step (model, inputs, targets, and loss) on the same device. Keras handles the common paths; be explicit when you step off them.
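Here’s a minimal sketch of the TensorFlow-backend path described above; everything that creates variables (build and compile) lives inside strategy.scope():

import tensorflow as tf
import keras

strategy = tf.distribute.MirroredStrategy()  # one replica per visible GPU
with strategy.scope():
    model = keras.Sequential([
        keras.layers.Input((32,)),
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
# model.fit(...) then runs a synchronous step on every replica and aggregates gradients.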
Mixed Precision: Speed, with Care
Half-precision (float16 / bfloat16) can cut memory in half and dramatically speed up training on modern GPUs/TPUs. In Keras you enable it via a dtype policy:
from keras import mixed_precision
mixed_precision.set_global_policy("mixed_float16") # or "bfloat16"
That switches layer compute dtypes to half precision (the mixed policies keep variables in float32 for stability). It’s one line, but it’s not free: keep an eye on loss scaling and numerics, especially on older hardware, and validate after each change.
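One common guardrail, sketched below: after setting a mixed policy, pin the output layer to float32 so logits and softmax stay numerically stable (the layer sizes are arbitrary):

from keras import layers, mixed_precision, models

mixed_precision.set_global_policy("mixed_float16")

inputs = layers.Input((32,))
x = layers.Dense(64, activation="relu")(inputs)                        # computes in float16
outputs = layers.Dense(10, activation="softmax", dtype="float32")(x)   # kept in float32
model = models.Model(inputs, outputs)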
Precision Isn’t a Moral Quality (It’s a Tradeoff)
float32: the dependable default.
float16 / bfloat16: memory/perf wins with potential stability tradeoffs.
float64: rarely needed in DL; slower and heavier.
Choose based on hardware + model behavior, not vibes. Measure throughput and convergence.
Randomness and Reproducibility
Stochastic training means runs vary. When you’re tuning, you need to know if a change helped or if you just rolled a natural 20 on initialization. Use:
import keras as K
K.utils.set_random_seed(7)
That seeds Python/NumPy/backend in one go. Combine with deterministic dataloading where possible.
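A small sketch of what “deterministic dataloading where possible” can mean in practice (the folder path and seed values are placeholders):

import keras as K

K.utils.set_random_seed(7)  # seeds Python, NumPy, and the backend
train = K.utils.image_dataset_from_directory(
    "data/train", image_size=(224, 224), batch_size=64,
    shuffle=True, seed=7,   # fixes the shuffling order across runs
)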
The Control Panel You Actually Use
Callbacks are how you steer training without touching the loop:
EarlyStopping: stop when val loss stalls.
ModelCheckpoint: save weights/artifacts during training.
ReduceLROnPlateau / LearningRateScheduler: adapt the LR.
TensorBoard: logs that tell you what actually happened.
These are stable, boring, and indispensable. Glue them in early.
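A typical stack, as a sketch (the file paths and patience values are placeholders you’d tune):

import keras as K

callbacks = [
    K.callbacks.EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True),
    K.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=2),
    K.callbacks.ModelCheckpoint("best.keras", save_best_only=True),
    K.callbacks.TensorBoard(log_dir="logs"),
]
# history = model.fit(train_ds, validation_data=val_ds, epochs=50, callbacks=callbacks)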
Saving and Loading: The .keras Format
Keras 3 standardizes on a portable .keras format (a zip archive) that saves architecture, weights, and optimizer state, and can be reloaded across backends if you wrote custom bits with backend-agnostic APIs. This is the default you should reach for:
model.save("model.keras")
reloaded = keras.models.load_model("model.keras")
(You can still save just weights or use backend-specific exports when appropriate, but .keras is the clean portable path.)
When You Need to Go Off-Road: Custom Train Steps
Sometimes you need gradient accumulation, multiple optimizers, or nonstandard objectives. Override train_step. Its body leans on your backend’s autograd, so it’s the one piece that is backend-specific; the example below assumes the TensorFlow backend:
import tensorflow as tf  # train_step bodies use the backend's autograd
import keras
from keras import ops

class ContrastiveModel(keras.Model):
    def train_step(self, data):
        (xa, xb), y = data
        with tf.GradientTape() as tape:
            za = self(xa, training=True)  # embedding of view A
            zb = self(xb, training=True)  # embedding of view B
            # Single tensor for compute_loss; the loss passed to compile()
            # is assumed to expect the concatenated pair.
            y_pred = ops.concatenate([za, zb], axis=-1)
            loss = self.compute_loss(y=y, y_pred=y_pred)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply(grads, self.trainable_variables)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}
You still call fit. You still get callbacks and logs. You just own the per-batch logic. This is the right level of control for 90% of “weird” training.
Debugging: What to Look At (In Order)
Shapes at API boundaries (model.summary() is your friend).
Dtypes after you change policies (did your final layer output switch to float16 unintentionally?).
Device placement when you mix data sources.
Loss & metrics definitions (wrong reduction? logits vs. probabilities?).
Callbacks & LR schedule (plateaus often just need schedule adjustments).
If your curves look flat: test on a toy subset, overfit a tiny batch, and verify the model can drive training loss close to zero. If it can’t, the bug’s in plumbing, not scale.
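The “overfit a tiny batch” check, sketched with hypothetical in-memory arrays x_train and y_train:

# If the model can't memorize 32 samples, the problem is plumbing, not scale.
tiny_x, tiny_y = x_train[:32], y_train[:32]
model.fit(tiny_x, tiny_y, epochs=200, verbose=0)
print(model.evaluate(tiny_x, tiny_y, verbose=0))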
A Quick, Concrete Mini-Project
Task: image classification on a small folder dataset.
import os
os.environ["KERAS_BACKEND"] = "tensorflow" # or "jax" or "torch"
import keras as K
from keras import layers
train = K.utils.image_dataset_from_directory(
"data/train", image_size=(224, 224), batch_size=64
)
val = K.utils.image_dataset_from_directory(
"data/val", image_size=(224, 224), batch_size=64
)
model = K.Sequential([
layers.Input((224, 224, 3)),
layers.Rescaling(1./255),
layers.Conv2D(32, 3, activation="relu"),
layers.MaxPooling2D(),
layers.Conv2D(64, 3, activation="relu"),
layers.GlobalAveragePooling2D(),
layers.Dense(1)
])
model.compile(
optimizer=K.optimizers.Adam(1e-3),
loss=K.losses.BinaryCrossentropy(from_logits=True),
metrics=[K.metrics.BinaryAccuracy(threshold=0.0)]
)
cb = [
K.callbacks.ModelCheckpoint("chkpt.keras", save_best_only=True),
K.callbacks.EarlyStopping(patience=3, restore_best_weights=True),
]
history = model.fit(train, validation_data=val, epochs=12, callbacks=cb)
model.save("final.keras")
Swap KERAS_BACKEND and rerun to experiment with different engines and accelerators: your model code stays the same.
Mental Models That Scale
Keras is a language. Its primitives—Layers, Models, Losses, Metrics, Optimizers, Datasets—compose into programs with clear semantics.
Abstractions are levers. compile/fit is the fast path; train_step is the escape hatch.
Portability matters. Backend choice is a constraint now and a lever later. Write backend-agnostic code unless you need backend-specific APIs.
Performance is a pipeline. Dataloaders, device placement, mixed precision, and distribution strategies add up. Change one at a time and measure.
Saving is part of training. Decide early how you’ll serialize and where you’ll reload. Default to .keras.
The Keras Project Journey
Going back to Keras fundamentals isn’t nostalgia: it’s clarity.
When you peel away the hype cycles, paper drops, and leaderboard screenshots, what matters is understanding the levers you can actually pull: tensors, shapes, devices, dtypes, and the training loop itself.
Keras gives you a grammar for deep learning: clean defaults for the common path, and structured escape hatches when you need control.
Its portability across TensorFlow, JAX, and PyTorch means you can focus on the logic of your model instead of rewriting code for every backend du jour.
The point isn’t just to “get something working.” It’s to understand why it works, and how to change the system deliberately when it doesn’t.
Once you grasp the ground floor, that is, the contract between tensors, layers, losses, and devices, you stop treating deep learning as trial and error and start shaping it with intent.
In the end, Keras isn’t just a library. It’s a way of thinking about models as composable programs that remain flexible, portable, and understandable: so you can build systems that don’t just train, but endure.