""" This module implements loss functions and regularizers for VQE, QCNN and Encoder"""
import jax
import jax.numpy as jnp
from typing import List, Callable
from numbers import Number
# VQE LOSSES
def vqe_fidelities(Y: jnp.ndarray, params: jnp.ndarray, q_circuit: Callable) -> jnp.ndarray:
    """
    LOSS: Compute the fidelity between each VQE PSI (output of q_circuit(params))
    and the corresponding TRUE PSI obtained by diagonalizing the Hamiltonian.

    The fidelity of one sample is |<psi_out|y>|^2, where psi_out = q_circuit(p).

    Parameters
    ----------
    Y : np.ndarray
        True wavefunctions stacked along axis 0 (one state vector per sample)
    params : np.ndarray
        VQE circuit parameters stacked along axis 0 (one parameter set per sample,
        aligned with Y)
    q_circuit : fun
        Quantum function of the VQE circuit; q_circuit(p) must return the state
        vector for parameters p

    Returns
    -------
    jnp.ndarray
        Array with one fidelity per sample (axis 0). The values are NOT averaged;
        apply jnp.mean to the result if a scalar loss is needed.
    """

    def _single_fidelity(y, p):
        # |<psi_out|y>|^2 : conjugate the circuit state, then inner product
        psi_out = q_circuit(p)
        return jnp.square(jnp.abs(jnp.conj(psi_out) @ y))

    # Vectorize over the leading (sample) axis of both Y and params
    v_fidelity = jax.vmap(_single_fidelity, in_axes=(0, 0))
    return v_fidelity(Y, params)
# QCNN LOSSES
def hinge(X, Y, params, q_circuit):
    """
    LOSS: (Experimental) Hinge loss for a binary classification task.

    The circuit output for each sample is mapped from [0, 1] to a score in
    [-1, 1], and the {0, 1} labels are mapped to {-1, +1}. Because the product
    score * label then lies in [-1, 1], the term 1 - score * label is already
    >= 0, so the usual max(0, .) is not applied. This assumes the circuit output
    really is a probability in [0, 1].

    Parameters
    ----------
    X : np.ndarray
        Array of VQE parameters (input of VQE), one sample per row
    Y : np.ndarray
        Array of labels, expected in {0, 1}
    params : np.ndarray
        Array of parameters of the QCNN circuit
    q_circuit : fun
        Quantum function of the QCNN circuit, called as q_circuit(x, params).
        Column 1 of its output is used as the probability of label 1
        (presumably per-class probabilities; verify against the circuit).

    Returns
    -------
    float
        Mean hinge loss over the batch
    """

    def _single(sample):
        return q_circuit(sample, params)

    probs = jax.vmap(_single)(X)
    # [0, 1] -> [-1, 1]
    scores = 2 * probs - 1
    # {0, 1} -> {-1, +1}
    signs = 2 * Y - 1
    margins = scores[:, 1] * signs
    return jnp.mean(1 - margins)
def cross_entropy1D(X, Y, params, q_circuit):
    """
    LOSS: Cross entropy (negative mean log-likelihood of the true label).

    For every sample, the log of the circuit output at the column indexed by
    that sample's label is selected. The result is the negated mean of those
    values.

    Parameters
    ----------
    X : np.ndarray
        Array of VQE parameters (input of VQE), one sample per row
    Y : np.ndarray
        Array of integer labels, used as column indices into the circuit output
    params : np.ndarray
        Array of parameters of the QCNN circuit
    q_circuit : fun
        Quantum function of the QCNN circuit, called as q_circuit(x, params);
        presumably returns one probability per class (verify against the circuit)

    Returns
    -------
    float
        Cross entropy <Circuit(X)|Y>
    """
    batched_circuit = jax.vmap(lambda sample: q_circuit(sample, params))
    log_p = jnp.log(batched_circuit(X))
    # Gather, row by row, the log-probability assigned to the true label
    label_index = jnp.expand_dims(Y, axis=1)
    picked = jnp.take_along_axis(log_p, label_index, axis=1)
    return -jnp.mean(picked)
def cross_entropy(X, Y, params, q_circuit):
    """
    LOSS: Compute the binary cross entropy for a binary classification task.

        CE = -mean( Y * log(p) + (1 - Y) * log(1 - p) )

    where p is the flattened circuit output. The value is non-negative and
    decreases as predictions approach the labels, so it is meant to be minimized.

    Parameters
    ----------
    X : np.ndarray
        Array of VQE parameters (input of VQE), one sample per row
    Y : np.ndarray
        Array of labels in {0, 1}; flattened, it must broadcast against the
        flattened circuit output
    params : np.ndarray
        Array of parameters of the QCNN circuit
    q_circuit : function
        Quantum function of the QCNN circuit, called as q_circuit(x, params)

    Returns
    -------
    float
        Binary cross entropy <Circuit(X)|Y>
    """
    v_qcnn_prob = jax.vmap(lambda v: q_circuit(v, params))
    predictions = v_qcnn_prob(X)
    logprobs1 = jnp.log(predictions).flatten()
    logprobs2 = jnp.log(1 - predictions).flatten()
    Y = Y.flatten()
    # Negate the mean log-likelihood. Returning it un-negated would be <= 0 and
    # minimizing it would drive predictions toward the WRONG class.
    return -jnp.mean(Y * logprobs1 + (1 - Y) * logprobs2)
def cross_entropy_power4(X, Y, params, q_circuit):
    """
    LOSS: Cross-entropy-like loss for a binary classification task, with each
    log term raised to the 4th power to punish uncertain classifications harder.

        L = mean( Y * log(p)^4 + (1 - Y) * log(1 - p)^4 )

    where p is the flattened circuit output. Each term is non-negative and
    vanishes when the prediction matches the label, so L is meant to be minimized.

    Parameters
    ----------
    X : np.ndarray
        Array of VQE parameters (input of VQE), one sample per row
    Y : np.ndarray
        Array of labels in {0, 1}; flattened, it must broadcast against the
        flattened circuit output
    params : np.ndarray
        Array of parameters of the QCNN circuit
    q_circuit : function
        Quantum function of the QCNN circuit, called as q_circuit(x, params)

    Returns
    -------
    float
        Mean 4th-power log loss <Circuit(X)|Y>
    """

    def _fourth_power(values):
        # (x^2)^2, computed as two squarings
        return jnp.square(jnp.square(values))

    probs = jax.vmap(lambda sample: q_circuit(sample, params))(X)
    pos_term = _fourth_power(jnp.log(probs).flatten())
    neg_term = _fourth_power(jnp.log(1 - probs).flatten())
    labels = Y.flatten()
    return jnp.mean(labels * pos_term + (1 - labels) * neg_term)