# Source code for PhaseEstimation.qcnn

""" This module provides the base functions for building and training a Quantum Convolutional Neural Network (QCNN) for the (ANNNI) Ising model. """
import pennylane as qml
from pennylane import numpy as np
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

from matplotlib import pyplot as plt

import copy, tqdm, pickle

from PhaseEstimation import circuits, vqe, general as qmlgen, ising_chain as ising, annni_model as annni, visualization as qplt

from typing import Tuple, List, Callable
from numbers import Number

##############


def qcnn_circuit(params: List[Number], N: int, n_outputs: int) -> Tuple[int, List[int]]:
    """
    Building function for the QCNN circuit.

    The layout is: two barriers, an RY wall followed by a serial CNOT wall,
    then repeated convolution + pooling blocks until at most ``n_outputs``
    wires remain active, and finally a serial CNOT wall followed by an RY
    wall on the remaining wires.

    Parameters
    ----------
    params : np.ndarray
        Array of QCNN parameters
    N : int
        Number of qubits
    n_outputs : int
        Output vector dimension

    Returns
    -------
    int
        Total number of parameters needed to build this circuit
        (the final parameter cursor plus one)
    np.ndarray
        Array of indexes of not-measured wires (due to pooling)
    """
    # Wires that are still active, i.e. not yet removed by pooling
    wires = np.arange(N)

    # Visual separation VQE || QCNN
    for _ in range(2):
        qml.Barrier()

    # `cursor` is the running position inside `params`; every helper that
    # consumes parameters hands back the updated position.
    cursor = circuits.wall_gate(wires, qml.RY, params, 0)
    circuits.wall_cgate_serial(wires, qml.CNOT)

    # Convolution + pooling, repeated until only n_outputs wires are left
    while len(wires) > n_outputs:
        cursor = circuits.convolution(wires, params, cursor)
        # Pooling also returns the reduced set of active wires
        cursor, wires = circuits.pooling(wires, qml.RX, params, cursor)
        qml.Barrier()

    # Closing layer on the surviving wires
    circuits.wall_cgate_serial(wires, qml.CNOT)
    cursor = circuits.wall_gate(wires, qml.RY, params, cursor)

    # NOTE(review): the "+ 1" assumes the helpers return the last used
    # position rather than the next free one -- confirm against `circuits`.
    return cursor + 1, wires
class qcnn:
    """
    Quantum Convolutional Neural Network stacked on top of a trained VQE.

    The class owns the QCNN parameters, builds the combined VQE + QCNN
    circuit, trains it with ADAM and exposes prediction / plotting helpers.
    """

    def __init__(self, vqe: vqe.vqe, qcnn_circuit: Callable, n_outputs: int = 1):
        """
        Class for the QCNN algorithm

        Parameters
        ----------
        vqe : class
            VQE class
        qcnn_circuit :
            Function of the QCNN circuit, called as
            ``qcnn_circuit(params, N, n_outputs)``
        n_outputs : int
            Output vector dimension
        """
        self.vqe = vqe
        self.N = vqe.Hs.N
        self.n_states = vqe.Hs.n_states
        self.n_outputs = n_outputs

        # Keep the raw builder around: `save` pickles it, which is not
        # possible with the lambda below (locally defined lambdas cannot be
        # pickled by the standard `pickle` module).
        self.qcnn_circuit_builder = qcnn_circuit
        self.qcnn_circuit_fun = lambda p: qcnn_circuit(p, self.N, n_outputs)

        # Dry run with a dummy parameter vector, only to read back how many
        # parameters the circuit consumes and which wires survive pooling.
        # NOTE(review): 10000 is assumed to exceed the parameter count.
        self.n_params, self.final_active_wires = self.qcnn_circuit_fun([0] * 10000)
        self.params = np.array(np.random.rand(self.n_params))

        self.device = vqe.device
        self.vqe_params = np.array(vqe.vqe_params0)
        self.labels = np.array(vqe.Hs.labels)

        self.loss_train: List[float] = []
        self.loss_test: List[float] = []

    def __repr__(self):
        """Return a text drawing of the QCNN part of the circuit."""

        @qml.qnode(self.device, interface="jax")
        def circuit_drawer():
            # Parameter values are irrelevant for drawing; use their indexes
            self.qcnn_circuit_fun(np.arange(self.n_params))
            if self.n_outputs == 1:
                return qml.probs(wires=self.N - 1)
            return qml.probs([int(k) for k in self.final_active_wires])

        return qml.draw(circuit_drawer)()

    def _vqe_qcnn_circuit(self, vqe_p, qcnn_p):
        """
        Circuit:
        VQE + QCNN
        """
        self.vqe.circuit(vqe_p)
        self.qcnn_circuit_fun(qcnn_p)

    # Training function
    def train(
        self,
        lr: float,
        n_epochs: int,
        train_index: List[Number],
        loss_fn: Callable,
        circuit: bool = False,
        plot: bool = False,
    ):
        """
        Training function for the QCNN.

        On completion ``self.params``, ``self.loss_train`` and
        ``self.loss_test`` are overwritten with the training results.
        Losses are recorded every 100 epochs.

        Parameters
        ----------
        lr : float
            Learning rate for the ADAM optimizer
        n_epochs : int
            Total number of epochs for each learning
        train_index : np.ndarray
            Index of training points
        loss_fn : function
            Loss function, called as ``loss_fn(X, Y, params, circuit)``
        circuit : bool
            if True -> Prints the circuit
        plot : bool
            if True -> It displays loss curve

        Raises
        ------
        ValueError
            If, in the ANNNI case, a label of the analytical subset is not
            one of [0,0], [0,1], [1,0], [1,1]
        """
        # -1 could be in the labels as [-1, -1] when training
        # ANNNI model which non-trivial cases have no solution
        if (-1 not in self.labels) and (None not in self.labels):
            X_train, Y_train = (
                jnp.array(self.vqe_params[train_index]),
                jnp.array(self.labels[train_index]),
            )
            test_index = np.setdiff1d(np.arange(len(self.vqe_params)), train_index)
            X_test, Y_test = (
                jnp.array(self.vqe_params[test_index]),
                jnp.array(self.labels[test_index]),
            )
        else:
            # If we are training an ANNNI model, we have to first restrict
            # to the trivial cases:
            # L = 0, K = whatever
            # K = 0, L = whatever
            model_params = jnp.array(self.vqe.Hs.model_params)
            mask = jnp.array(
                jnp.logical_or(model_params[:, 1] == 0, model_params[:, 2] == 0)
            )

            self.vqe_params = jnp.array(self.vqe_params)
            X, Y = self.vqe_params[mask], self.labels[mask, :].astype(int)

            # The labels stored in the Hamiltonian class are:
            # > [ 1, 1] for paramagnetic states
            # > [ 0, 1] for ferromagnetic states
            # > [ 1, 0] for antiphase states
            # > [ 0, 0] not used
            # > [-1,-1] for states with no analytical solutions
            # qml.probs(wires = active_wires) will output the following
            # probabilities (example for a two qubits output):
            # p(00), p(01), p(10), p(11)
            # so the labels are transformed into the matching one-hot vector.
            one_hot = {
                (0, 0): [1, 0, 0, 0],  # Trash
                (0, 1): [0, 1, 0, 0],  # Ferromagnetic
                (1, 0): [0, 0, 1, 0],  # Antiphase
                (1, 1): [0, 0, 0, 1],  # Paramagnetic
            }
            Ymix = []
            for label in Y:
                key = tuple(int(v) for v in label)
                if key not in one_hot:
                    # Skipping the label would silently misalign X and Y
                    raise ValueError(
                        "Unexpected label {0} among the analytical states".format(key)
                    )
                Ymix.append(one_hot[key])
            Y = jnp.array(Ymix)

            # The indexes of test are
            # All indexes (only analytical) \ train_index
            test_index = np.setdiff1d(np.arange(len(Y)), train_index)

            X_train, Y_train = X[train_index], Y[train_index]
            X_test, Y_test = X[test_index], Y[test_index]

        if circuit:
            # Display the circuit
            print("+--- CIRCUIT ---+")
            print(self)

        # QCircuit: Circuit(VQE, QCNNparams) -> probs
        @qml.qnode(self.device, interface="jax")
        def qcnn_circuit_prob(vqe_p, qcnn_p):
            self._vqe_qcnn_circuit(vqe_p, qcnn_p)
            return qml.probs([int(k) for k in self.final_active_wires])

        params = copy.copy(self.params)

        # Gradient of the Loss function
        jd_loss_fn = jax.jit(
            jax.grad(lambda p: loss_fn(X_train, Y_train, p, qcnn_circuit_prob))
        )

        # jitted loss function for training set loss(params)
        train_loss_fn = jax.jit(
            lambda p: loss_fn(X_train, Y_train, p, qcnn_circuit_prob)
        )
        # jitted loss function for test set loss(params)
        test_loss_fn = jax.jit(lambda p: loss_fn(X_test, Y_test, p, qcnn_circuit_prob))

        # Defining an optimizer in Jax
        opt_init, opt_update, get_params = optimizers.adam(lr)
        opt_state = opt_init(params)

        # Update function
        # Returns updated parameters, updated state of the optimizer
        def update(step, params, opt_state):
            grads = jd_loss_fn(params)
            # The optimizer needs the real step index: ADAM uses it for its
            # bias correction, so a constant 0 skews every step after the first.
            opt_state = opt_update(step, grads, opt_state)
            return get_params(opt_state), opt_state

        loss_history, loss_history_test = [], []

        # Training loop (the progress bar is closed when the loop is left)
        with tqdm.tqdm(range(n_epochs), position=0, leave=True) as progress:
            for epoch in range(n_epochs):
                params, opt_state = update(epoch, params, opt_state)

                # Every 100 iterations append the updated training
                # (and testing) loss
                if epoch % 100 == 0:
                    loss_history.append(train_loss_fn(params))
                    if len(Y_test) > 0:
                        loss_history_test.append(test_loss_fn(params))

                # Update progress bar
                progress.update(1)
                progress.set_description("Cost: {0}".format(loss_history[-1]))

        # Update qcnn class after training
        self.loss_train = loss_history
        self.loss_test = loss_history_test
        self.params = params

        if plot:
            plt.figure(figsize=(15, 5))
            plt.plot(
                np.arange(len(loss_history)) * 100,
                np.asarray(loss_history),
                label="Training Loss",
            )
            if len(X_test) > 0:
                plt.plot(
                    np.arange(len(loss_history_test)) * 100,
                    np.asarray(loss_history_test),
                    label="Test Loss",
                )
            plt.axhline(y=0, color="r", linestyle="--")
            plt.title("Loss history")
            plt.ylabel("Average Cross entropy")
            plt.xlabel("Epoch")
            plt.grid(True)
            plt.legend()

    def predict(self):
        """
        Get the phases probabilities for each VQE state

        Returns
        -------
        List[List[Number]]
            List of probabilities
        """

        @qml.qnode(self.device, interface="jax")
        def qcnn_circuit_prob(params_vqe, params):
            self._vqe_qcnn_circuit(params_vqe, params)
            return qml.probs([int(k) for k in self.final_active_wires])

        # Evaluate the circuit on every VQE state with the trained parameters
        vcircuit = jax.vmap(
            lambda v: qcnn_circuit_prob(v, self.params), in_axes=(0)
        )
        predictions = np.array(vcircuit(self.vqe_params))

        return predictions

    def predict_lines(self, predictions=None):
        """
        Get the predicted phase-transition line

        Parameters
        ----------
        predictions : List[List[Number]]
            This is the output of self.predict(); if it is not passed (or is
            empty), the predictions are computed anew

        Returns
        -------
        List[Number]
            y-coordinate of the transition point for each kappa value
        """
        sidex, sidey = self.vqe.Hs.n_kappas, self.vqe.Hs.n_hs

        if predictions is None or len(predictions) == 0:
            predictions = self.predict()

        # Most likely class of every state, arranged on the (kappa, h) grid
        predictions = np.reshape(np.argmax(predictions, axis=1), (sidex, sidey))

        line_trans = []
        for col in range(sidex):
            # Walk each column from the last row downwards and count how many
            # consecutive states are assigned to class 3 (p(11), the
            # paramagnetic class in the two-qubit encoding used by `train`).
            y_cord_trans = 0
            for row in range(sidey - 1, -1, -1):
                if predictions[col, row] != 3:
                    break
                y_cord_trans += 1
            line_trans.append(y_cord_trans)

        return np.array(line_trans)

    def save(self, filename: str):
        """
        Saves QCNN parameters to file

        The file holds ``[params, circuit_builder]`` where the builder is the
        function originally passed to the constructor. It must itself be
        picklable (e.g. a module-level function).

        Parameters
        ----------
        filename : str
            File where to save the parameters

        Raises
        ------
        TypeError
            If ``filename`` is not a string
        """
        if not isinstance(filename, str):
            raise TypeError("Invalid name for file")

        things_to_save = [self.params, self.qcnn_circuit_builder]

        with open(filename, "wb") as f:
            pickle.dump(things_to_save, f)

    def show(self, train_index=None, marginal=False, **kwargs):
        """
        Plot the classification of the QCNN.

        Parameters
        ----------
        train_index : List[Number]
            Index of training points, forwarded to the Ising plot only
        marginal : bool
            if True -> for the ANNNI model, the marginal plot is drawn as well
        **kwargs
            Forwarded to the ANNNI classification plot
        """
        if train_index is None:
            train_index = []

        if self.vqe.Hs.func == ising.build_Hs:
            qplt.QCNN_classification_ising(self, train_index)
        elif self.vqe.Hs.func == annni.build_Hs:
            if marginal:
                qplt.QCNN_classification_ANNNI_marginal(self)
            qplt.QCNN_classification_ANNNI(self, **kwargs)
def load(filename_vqe: str, filename_qcnn: str, n_outputs: int = 1) -> qcnn:
    """
    Load QCNN from VQE file and QCNN file

    Parameters
    ----------
    filename_vqe : str
        Name of the file from where to load the VQE class
    filename_qcnn : str
        Name of the file from where to load the main parameters of the QCNN class
    n_outputs : int
        Output vector dimension the QCNN is rebuilt with. It is not stored in
        the QCNN file, so it must match the value used when the saved
        parameters were trained (defaults to 1)

    Returns
    -------
    class
        QCNN class

    Raises
    ------
    TypeError
        If either file name is not a string
    """
    if not (isinstance(filename_vqe, str) and isinstance(filename_qcnn, str)):
        raise TypeError("Invalid name for file")

    loaded_vqe = vqe.load_vqe(filename_vqe)

    # SECURITY: unpickling executes arbitrary code embedded in the file;
    # only load QCNN files that come from a trusted source.
    with open(filename_qcnn, "rb") as f:
        params, qcnn_circuit_fun = pickle.load(f)

    # The constructor draws random parameters; replace them with the saved ones
    loaded_qcnn = qcnn(loaded_vqe, qcnn_circuit_fun, n_outputs)
    loaded_qcnn.params = params

    return loaded_qcnn
def get_trainset_gaussian(vqeclass: vqe.vqe, nS: int, sigma: float = 1) -> List[int]:
    """
    Draw randomly samples from the training for each axis according to
    the gaussian distribution centered around the phase transition on the
    axis and std sigma

    ``nS // 2`` distinct indexes are drawn for each axis, so an odd ``nS``
    yields ``nS - 1`` samples in total. Indexes in ``[0, side)`` belong to
    the Y axis, indexes in ``[side, 2 * side)`` to the X axis.

    Parameters
    ----------
    vqeclass : vqe.vqe
        VQE class to get the side size of the system
    nS : int
        Number of samples to draw in total
    sigma : float
        Standard deviation of the two distributions

    Returns
    -------
    np.ndarray
        List of the indexes of the subset of the training set

    Raises
    ------
    ValueError
        If ``nS`` is larger than ``2 * side - 1``, or if ``sigma`` is not
        positive while more than one sample per axis is requested
    """
    side = vqeclass.Hs.side

    if nS > 2 * side - 1:
        raise ValueError("Subset size too large!")

    per_axis = nS // 2  # Number of samples to draw among each axis
    mu = side // 2  # Mean of the distributions

    if per_axis > 1 and sigma <= 0:
        # With a degenerate distribution every draw is identical, so the
        # no-duplicates rejection loop below would never terminate.
        raise ValueError("sigma must be positive to draw distinct samples")

    def _draw_axis(offset: int) -> List[int]:
        """Draw `per_axis` distinct in-range indexes shifted by `offset`."""
        drawn: List[int] = []
        # NOTE: rejection sampling; a very small positive sigma can still
        # make this loop take a long time to find enough distinct values.
        while len(drawn) < per_axis:
            sample = int(np.random.normal(mu, sigma)) + offset
            # No duplicates allowed, and the sample has to be in range
            if sample not in drawn and offset <= sample < offset + side:
                drawn.append(sample)
        return drawn

    # Y axis first, then X axis (shifted by `side`)
    training_set = _draw_axis(0) + _draw_axis(side)

    return np.array(training_set)
def ANNNI_accuracy(qcnnclass: qcnn, plot: bool = False) -> float:
    """
    Compute accuracy of the QCNN of the whole ANNNI state space

    Every VQE state is classified by the QCNN and compared against the
    label obtained from the theoretical transition curves.

    Parameters
    ----------
    qcnnclass : qcnn
        QCNN class
    plot : bool
        if True -> displays the plot of the accuracy:
        if green: sample correctly classified
        if red  : sample wrongly classified

    Returns
    -------
    float
        Accuracy : (# samples correctly classified)/(# samples) (0,1)
    """
    side = qcnnclass.vqe.Hs.side

    @qml.qnode(qcnnclass.device, interface="jax")
    def qcnn_circuit_prob(params_vqe, params):
        qcnnclass._vqe_qcnn_circuit(params_vqe, params)
        # One probability vector per surviving wire
        return [qml.probs(wires=int(k)) for k in qcnnclass.final_active_wires]

    batched_probs = jax.vmap(
        lambda v: qcnn_circuit_prob(v, qcnnclass.params), in_axes=(0)
    )

    # Get the predictions of the QCNN among all states of the VQE
    # NOTE(review): axis=2 assumes the batched output is laid out as
    # (state, wire, outcome) -- confirm for the PennyLane version in use.
    predictions = np.array(np.argmax(batched_probs(qcnnclass.vqe_params), axis=2))

    def _theoretical_label(x, y):
        """Label of the point (x, y) according to the theoretical curves."""
        if x == 0:
            # The theoretical curve would give 0/0 here: use the fixed threshold
            threshold, ordered_phase = 1, [0, 1]
        elif x <= 0.5:
            threshold, ordered_phase = qmlgen.paraferro(x), [0, 1]
        else:
            threshold, ordered_phase = qmlgen.paraanti(x), [1, 0]
        # On or above the curve -> [1, 1], below -> the ordered phase
        return [1, 1] if threshold <= y else ordered_phase

    # Compare predictions to actual states
    # applying inequalities to theoretical curves
    labels = []
    for idx in range(qcnnclass.vqe.Hs.n_states):
        # compute coordinates and normalize: x is scaled by 1/side,
        # y by 2/side
        col, row = divmod(idx, side)
        labels.append(_theoretical_label(col / side, 2 * row / side))

    # A state is correct only when both output bits match its label
    matches = np.sum(np.array(labels) == predictions, axis=1).astype(int)
    correct = matches == 2
    accuracy = np.sum(correct) / (side * side)

    if plot:
        plt.imshow(np.rot90(np.reshape(correct, (side, side))), cmap="RdYlGn")
        plt.show()

    return accuracy