Automatic Differentiation (JAX)
This example demonstrates how to use JAX for automatic differentiation in Python, minimizing the 5D Levy test function with the L-BFGS local solver. JAX computes the gradient of the Levy function efficiently, which helps the local solver converge to minima more reliably.

JAX is a library for high-performance numerical computing and machine learning in Python. It provides automatic differentiation capabilities, allowing users to compute gradients and optimize functions easily[^1].
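As a brief illustration of what JAX's automatic differentiation looks like on its own (a minimal sketch, independent of the example below), `jax.grad` transforms a scalar-valued function into a function that returns its gradient, and `jax.jit` compiles it with XLA:

```python
import jax.numpy as jnp
from jax import grad, jit

# A simple scalar-valued function: f(x) = sum(x_i^2)
def f(x):
    return jnp.sum(x**2)

# grad(f) returns a new function that evaluates df/dx;
# jit compiles it so repeated calls are fast
df = jit(grad(f))

print(df(jnp.array([1.0, 2.0, 3.0])))  # [2. 4. 6.]
```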
The Levy test function is a multimodal function commonly used to evaluate optimization algorithms[^2]. It is defined as:

$$
f(\mathbf{x}) = \sin^2(\pi w_1) + \sum_{i=1}^{d-1} (w_i - 1)^2 \left[1 + 10 \sin^2(\pi w_i + 1)\right] + (w_d - 1)^2 \left[1 + \sin^2(2\pi w_d)\right]
$$

where $w_i$ is defined as:

$$
w_i = 1 + \frac{x_i - 1}{4}, \quad i = 1, \ldots, d.
$$

In this case, we evaluate the 5-dimensional problem ($d = 5$). The bounds of the function are the hypercube $x_i \in [-10, 10]$ for all $i = 1, \ldots, 5$. It has a global minimum $f(\mathbf{x}^*) = 0$ at $\mathbf{x}^* = (1, 1, 1, 1, 1)$.
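To see why the minimum value is zero, note that at $\mathbf{x}^* = (1, \ldots, 1)$ every transformed variable satisfies $w_i = 1$, so each term of the function vanishes:

$$
\sin^2(\pi \cdot 1) = 0, \qquad (w_i - 1)^2 = 0 \ \text{for all } i \quad\Rightarrow\quad f(\mathbf{x}^*) = 0.
$$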
The file for this example can be found at `python/examples/levy_5d_jax.py`.
Code
```python
import pyglobalsearch as gs

# Importing JAX
import jax.numpy as jnp
from jax import grad, jit


# We begin by defining the 5D Levy function
@jit  # JIT compile for maximum performance
def levy_jax(x):
    """5D Levy function implemented in pure JAX using explicit computation"""
    # Transform variables: w_i = 1 + (x_i - 1) / 4
    w1 = 1.0 + (x[0] - 1.0) / 4.0
    w2 = 1.0 + (x[1] - 1.0) / 4.0
    w3 = 1.0 + (x[2] - 1.0) / 4.0
    w4 = 1.0 + (x[3] - 1.0) / 4.0
    w5 = 1.0 + (x[4] - 1.0) / 4.0

    # First term: sin**2(π * w_1)
    term1 = jnp.sin(jnp.pi * w1) ** 2

    # Middle terms: (w_i - 1)² * [1 + 10 * sin**2(π * w_i + 1)] for i = 1 to d-1
    term_mid1 = (w1 - 1.0) ** 2 * (1.0 + 10.0 * jnp.sin(jnp.pi * w1 + 1.0) ** 2)
    term_mid2 = (w2 - 1.0) ** 2 * (1.0 + 10.0 * jnp.sin(jnp.pi * w2 + 1.0) ** 2)
    term_mid3 = (w3 - 1.0) ** 2 * (1.0 + 10.0 * jnp.sin(jnp.pi * w3 + 1.0) ** 2)
    term_mid4 = (w4 - 1.0) ** 2 * (1.0 + 10.0 * jnp.sin(jnp.pi * w4 + 1.0) ** 2)

    # Last term: (w_5 - 1)**2 * [1 + sin**2(2π * w_5)]
    term_last = (w5 - 1.0) ** 2 * (1.0 + jnp.sin(2.0 * jnp.pi * w5) ** 2)

    return term1 + term_mid1 + term_mid2 + term_mid3 + term_mid4 + term_last


# We can automatically compute the gradient using JAX and JIT compilation
gradient_jax = jit(grad(levy_jax))


def obj(x) -> float:
    """Objective function wrapper"""
    result = levy_jax(x)
    # We return a float for PyGlobalSearch
    return float(result)


def grad_func(x):
    """Gradient function using JAX automatic differentiation"""
    grad_result = gradient_jax(x)
    return grad_result


# x_i ∈ [-10, 10] for all i = 1, ..., 5
bounds_jax = jnp.array(
    [[-10.0, 10.0], [-10.0, 10.0], [-10.0, 10.0], [-10.0, 10.0], [-10.0, 10.0]]
)


def variable_bounds():
    """Variable bounds for the 5D Levy function"""
    return bounds_jax


# Create optimization parameters
params = gs.PyOQNLPParams(
    iterations=1000,
    population_size=5000,
    wait_cycle=10,
    threshold_factor=0.1,
    distance_factor=0.5,
)

print("Optimization Parameters:")
print(f" Iterations: {params.iterations}")
print(f" Population size: {params.population_size}")
print(f" Wait cycle: {params.wait_cycle}")
print(f" Threshold factor: {params.threshold_factor}")
print(f" Distance factor: {params.distance_factor}")
print()

# Create the problem with JAX-computed gradient
problem = gs.PyProblem(obj, variable_bounds, grad_func)  # type: ignore

print("Starting optimization...")
print()

# Run optimization with L-BFGS
sol_set = gs.optimize(problem, params, local_solver="LBFGS", seed=0)

# Display results
if sol_set is not None and len(sol_set) > 0:
    print(f"Optimization completed! Found {len(sol_set)} solution(s):")
    print("=" * 50)

    for i, sol in enumerate(sol_set, 1):
        x_opt = sol["x"]
        f_opt = sol["fun"]

        print(f"Solution #{i}:")
        print(f" Parameters: {x_opt}")
        print(f" Objective: {f_opt:12.8f}")

        # Verify gradient is near zero at optimum
        grad_at_opt = grad_func(jnp.array(x_opt))
        grad_norm = jnp.linalg.norm(grad_at_opt)
        print(f" Gradient norm: {grad_norm:12.8f}")

        # Check if this is close to the known global minimum [1, 1, 1, 1, 1]
        known_minimum = jnp.array([1.0, 1.0, 1.0, 1.0, 1.0])
        distance_to_optimum = float(jnp.linalg.norm(jnp.array(x_opt) - known_minimum))
        error_sq = float(jnp.square(distance_to_optimum))
        print(f" Error (squared): {error_sq:.2e}")

        if distance_to_optimum < 0.1:
            print(
                f" Close to known global minimum (distance: {distance_to_optimum:.6f})"
            )
        else:
            print(f" Distance to known global minimum: {distance_to_optimum:.6f}")

        print()

else:
    print("No solution found!")
```
With this code, PyGlobalSearch is able to find the optimum of the objective function using the gradient supplied by JAX.
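If you want to sanity-check the JAX gradient against a numerical approximation before trusting it in the optimizer, a central finite-difference comparison is a quick test. The sketch below assumes the `levy_jax` and `gradient_jax` objects defined above are in scope; `finite_difference_grad` is an illustrative helper, not part of the example file:

```python
import numpy as np

def finite_difference_grad(f, x, eps=1e-3):
    """Central-difference approximation of the gradient of f at x."""
    x = np.asarray(x, dtype=float)
    g = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        g[i] = (f(x + step) - f(x - step)) / (2.0 * eps)
    return g

# Compare both gradients at an arbitrary point inside the bounds
x0 = np.array([0.5, -2.0, 3.0, 1.5, -4.0])
fd = finite_difference_grad(lambda v: float(levy_jax(jnp.array(v))), x0)
ad = np.array(gradient_jax(jnp.array(x0)))

# The difference should be small (limited mainly by JAX's default float32 precision)
print("max abs difference:", np.max(np.abs(fd - ad)))
```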
References
[^1]: Bradbury, J., Frostig, R., Hawkins, P., Johnson, M. J., Leary, C., Maclaurin, D., Necula, G., Paszke, A., VanderPlas, J., Wanderman-Milne, S., & Zhang, Q. (2018). JAX: Composable transformations of Python+NumPy programs (Version 0.3.13) [Software]. Retrieved July 2025, from http://github.com/jax-ml/jax

[^2]: Surjanovic, S., & Bingham, D. (2013). Virtual Library of Simulation Experiments: Test Functions and Datasets. Retrieved from https://www.sfu.ca/~ssurjano/levy.html