Source code for PhaseEstimation.visualization

"""Plotting functions for the classes hamiltonians, vqe, qcnn, encoder.
These functions are not meant to be used directly, but are called within their respective classes."""

import pennylane as qml
from pennylane import numpy as np
import jax
import jax.numpy as jnp
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import colors
from matplotlib.colors import LinearSegmentedColormap, LogNorm
import plotly.graph_objects as go
from tqdm.auto import tqdm

from PhaseEstimation import general as qmlgen
from PhaseEstimation import losses

from typing import List, Callable

from matplotlib import rc

# Global matplotlib font configuration, applied as a side effect of importing
# this module.
# First call: selects the sans-serif family with Helvetica as the preferred
# sans-serif face.
rc("font", **{"family": "sans-serif", "sans-serif": ["Helvetica"]})
# Second call: switches the active family to serif (Computer Modern Roman),
# which overrides the "family" chosen by the call above; the sans-serif
# preference list set above is left in place.
rc("font", **{"family": "serif", "serif": ["Computer Modern Roman"]})
# Render all text through LaTeX.
# NOTE(review): presumably requires a working LaTeX installation at plot
# time -- confirm for the target environment.
rc("text", usetex=True)

#  __           _______  _______ .__   __.  _______ .______          ___       __      
# /_ |         /  _____||   ____||  \ |  | |   ____||   _  \        /   \     |  |     
#  | |        |  |  __  |  |__   |   \|  | |  |__   |  |_)  |      /  ^  \    |  |     
#  | |        |  | |_ | |   __|  |  . `  | |   __|  |      /      /  /_\  \   |  |     
#  | |  __    |  |__| | |  |____ |  |\   | |  |____ |  |\  \----./  _____  \  |  `----.
#  |_| (__)    \______| |_______||__| \__| |_______|| _| `._____/__/     \__\ |_______|
                                                                                     
def getlines_from_Hs(
    Hs, func: Callable, xrange: List[float], res: int = 100, **kwargs
):
    """
    Draw the curve ``h = func(kappa)`` on the current axes, expressed in the
    pixel coordinates of a (n_kappas x n_hs) heat-map of the parameter space.

    Parameters
    ----------
    Hs : hamiltonians.hamiltonian
        Custom Hamiltonian class; its ``n_kappas``, ``n_hs``, ``kappa_max``
        and ``h_max`` attributes define the coordinate conversion
    func : function
        Function to plot, usually:
        > general.paraanti : Transition line between paramagnetic phase and antiphase;
        > general.paraferro : Transition line between paramagnetic phase and ferromagnetic phase;
        > general.b1 : Pseudo-transition line inside the antiphase subspace;
        > general.peshel_emery : Peshel Emery Line.
    xrange : List[float]
        The curve is sampled from xrange[0] to xrange[1]
    res : int
        Number of sampling points
    **kwargs
        Forwarded unchanged to ``plt.plot``
    """
    n_columns = Hs.n_kappas
    n_rows = Hs.n_hs
    kappa_top = Hs.kappa_max
    h_top = Hs.h_max

    kappas = np.linspace(xrange[0], xrange[1], res)
    hs = func(kappas)
    # Values above the plotted range are flattened onto its upper edge
    hs[hs > h_top] = h_top

    # Physical values -> pixel coordinates. Pixel centres sit on integers,
    # hence the half-pixel shift; the vertical axis is flipped so that
    # h = 0 ends up at the bottom of the image.
    pixel_x = n_columns * kappas / kappa_top - 0.5
    pixel_y = n_rows - hs * n_rows / h_top - 0.5
    plt.plot(pixel_x, pixel_y, **kwargs)
def plot_layout(Hs, pe_line, phase_lines, title, figure_already_defined = False):
    """
    Apply the layout shared by most plots of this module: axis names, ticks
    labelled with the physical parameter values, optional reference lines
    and the legend.

    Parameters
    ----------
    Hs : hamiltonians.hamiltonian
        Custom hamiltonian class, it is needed to set the ticks and to
        draw the reference lines
    pe_line : bool
        if True plots Peshel Emery line
    phase_lines : bool
        if True plots the phase transition lines
    title : str
        Title of the legend of the plot; with an empty string no legend
        is drawn
    figure_already_defined : bool
        if False it calls the plt.figure function
    """
    if not figure_already_defined:
        plt.figure(figsize=(8, 6), dpi=80)

    # Axis names and tick-label sizes
    plt.ylabel(r"$h$", fontsize=24)
    plt.xlabel(r"$\kappa$", fontsize=24)
    for which_axis in ("x", "y"):
        plt.tick_params(axis=which_axis, labelsize=18)

    # Five ticks per axis, from the outer edge of the first pixel (-0.5)
    # to the outer edge of the last one (n - 0.5)
    quarter_steps = range(0, 5)
    kappa_tick_positions = [k * Hs.n_kappas / 4 - .5 for k in quarter_steps]
    h_tick_positions = [k * Hs.n_hs / 4 - .5 for k in quarter_steps]

    plt.xticks(
        ticks=kappa_tick_positions,
        labels=[np.round(k * Hs.kappa_max / 4, 2) for k in quarter_steps],
    )
    # The h labels run from h_max down to 0 along increasing pixel row
    plt.yticks(
        ticks=h_tick_positions,
        labels=[np.round(k * Hs.h_max / 4, 2) for k in reversed(quarter_steps)],
    )

    if pe_line:
        getlines_from_Hs(
            Hs, qmlgen.peshel_emery, [0, 0.5], res=100,
            color="blue", alpha=1, ls='--', dashes=(4,5),
            label='Peshel-Emery line',
        )

    if phase_lines:
        # Only the first line carries a label so the legend shows one entry
        getlines_from_Hs(
            Hs, qmlgen.paraanti, [0.5, Hs.kappa_max], res=100,
            color="red", label='Phase-transition\n lines',
        )
        getlines_from_Hs(Hs, qmlgen.paraferro, [0, 0.5], res=100, color="red")

    if len(title) > 0:
        plt.legend(
            bbox_to_anchor=(1, 1),
            loc="upper right",
            fontsize=16,
            facecolor="white",
            markerscale=1,
            framealpha=0.9,
            title=title,
            title_fontsize=16,
        )

    plt.tight_layout()
# ___ __ __ ___ .___ ___. __ __ .___________. ______ .__ __. __ ___ .__ __. _______. # |__ \ | | | | / \ | \/ | | | | | | | / __ \ | \ | | | | / \ | \ | | / | # ) | | |__| | / ^ \ | \ / | | | | | `---| |----`| | | | | \| | | | / ^ \ | \| | | (----` # / / | __ | / /_\ \ | |\/| | | | | | | | | | | | | . ` | | | / /_\ \ | . ` | \ \ # / /_ __ | | | | / _____ \ | | | | | | | `----. | | | `--' | | |\ | | | / _____ \ | |\ | .----) | # |____| (__) |__| |__| /__/ \__\ |__| |__| |__| |_______| |__| \______/ |__| \__| |__| /__/ \__\ |__| \__| |_______/
def HAM_mass_gap(Hs, phase_lines = False, pe_line = False):
    """
    Heat-map of the mass gap, i.e. the difference between the first excited
    energy level and the ground energy level, for each point in the
    parameter space.

    Parameters
    ----------
    Hs : hamiltonians.hamiltonian
        Custom hamiltonian class; provides ``true_e0``, ``true_e1`` and is
        needed to call plot_layout
    phase_lines : bool
        if True plots the phase transition lines
    pe_line : bool
        if True plots Peshel Emery line
    """
    n_kappas, n_hs = Hs.n_kappas, Hs.n_hs

    # Flat list of gaps -> (kappa, h) grid
    gap_grid = np.reshape(Hs.true_e1 - Hs.true_e0, (n_kappas, n_hs))

    legend_title = r"Mass Gap, $N = {0}$".format(str(Hs.N))
    plot_layout(Hs, phase_lines=phase_lines, pe_line=pe_line, title=legend_title)

    # rot90 puts kappa on the horizontal axis and h on the vertical one
    plt.imshow(np.rot90(gap_grid), aspect=Hs.n_kappas / Hs.n_hs)

    colorbar = plt.colorbar(fraction=0.04)
    colorbar.ax.tick_params(labelsize=16)
def HAM_phases_plot(Hs):
    """
    Colour the parameter space by phase, using the state-of-the-art
    transition lines as boundaries.

    Parameters
    ----------
    Hs : hamiltonians.hamiltonian
        Custom hamiltonian class, it is needed to call plot_layout
    """
    n_kappas = Hs.n_kappas
    n_hs = Hs.n_hs

    kappa_axis = np.linspace(0, Hs.kappa_max, n_kappas)
    h_axis = np.linspace(0, Hs.h_max, n_hs)

    def phase_at(kappa, h):
        # Phase id of a single point: 0 and 2 lie below the respective
        # transition line, 1 lies above either line.
        if kappa == 0:
            # On the kappa = 0 axis the boundary is taken at h = 1
            return 0 if h <= 1 else 1
        if kappa <= .5:
            return 0 if h <= qmlgen.paraferro(kappa) else 1
        return 2 if h <= qmlgen.paraanti(kappa) else 1

    # kappa is the slow index, matching the (n_kappas, n_hs) reshape below
    phase_ids = [phase_at(kappa, h) for kappa in kappa_axis for h in h_axis]

    phase_cmap = colors.ListedColormap(['palegreen', 'moccasin', 'lightblue'])
    phase_norm = colors.BoundaryNorm([0,1,2,3], phase_cmap.N)

    plot_layout(Hs, pe_line=False, phase_lines=True, title = r"State of the art phases plot")
    phase_grid = np.rot90(np.reshape(phase_ids, (n_kappas, n_hs)))
    plt.imshow(phase_grid, cmap=phase_cmap, norm=phase_norm, aspect=Hs.n_kappas / Hs.n_hs)
# ____ ____ ____ ______ _______ # |___ \ \ \ / / / __ \ | ____| # __) | \ \/ / | | | | | |__ # |__ < \ / | | | | | __| # ___) | __ \ / | `--' '--.| |____ # |____/ (__) \__/ \_____\_____\_______| #
def VQE_show_isingchain(vqeclass):
    """
    Shows results of a trained VQE (Nearest Neighbour Ising Model) run:
    the upper panel compares the VQE and the true ground-state energies,
    the lower panel shows the relative error of the VQE energies.

    Parameters
    ----------
    vqeclass : vqe.vqe
        Custom VQE class after being trained
    """
    exact_energies = vqeclass.true_e0
    vqe_energies = vqeclass.vqe_e0
    title_template = "Ground States of Ising Hamiltonian ({0}-spins), J = {1}"

    lambdas = np.linspace(0, 2 * vqeclass.Hs.J, vqeclass.Hs.n_states)

    _, (energy_ax, error_ax) = plt.subplots(2, 1, figsize=(12, 6))

    # Upper panel: energies
    energy_ax.plot(lambdas, exact_energies, "--", label="True", color="red", lw=3)
    energy_ax.plot(lambdas, vqe_energies, ".", label="VQE", color="green", lw=2)
    energy_ax.plot(lambdas, vqe_energies, color="green", lw=2, alpha=0.6)
    energy_ax.grid(True)
    energy_ax.set_title(title_template.format(vqeclass.Hs.N, vqeclass.Hs.J))
    energy_ax.set_xlabel(r"$\lambda$")
    energy_ax.set_ylabel(r"$E(\lambda)$")
    energy_ax.legend()

    # Lower panel: relative error, with the 1% threshold splitting the
    # background into a red (above) and a green (below) band
    threshold = 0.01
    accuracy = np.abs((exact_energies - vqe_energies) / exact_energies)
    band_top = max(np.max(accuracy), 0.01)
    band_bottom = min(np.min(accuracy), 0)
    error_ax.fill_between(lambdas, threshold, band_top, color="r", alpha=0.3)
    error_ax.fill_between(lambdas, threshold, band_bottom, color="green", alpha=0.3)
    error_ax.axhline(y=threshold, color="r", linestyle="--")
    error_ax.scatter(lambdas, accuracy)
    error_ax.grid(True)
    error_ax.set_title("Accuracy of VQE")
    error_ax.set_xlabel(r"$\lambda$")
    error_ax.set_ylabel(r"$|(E_{vqe} - E_{true})/E_{true}|$")

    plt.tight_layout()
def VQE_show_annni(vqeclass, log_heatmap = False, plot3d=True, phase_lines = False, pe_line = False):
    """
    Shows results of a trained VQE (ANNNI) run as a heat-map of the relative
    energy error, optionally preceded by a 3D plot of the energies.

    Parameters
    ----------
    vqeclass : vqe.vqe
        Custom VQE class after being trained
    log_heatmap : bool
        if True, the accuracy is displayed in logscale
    plot3d : bool
        if True the predicted energies and true energies will be displayed in a 3D plot
    phase_lines : bool
        if True plots the phase transition lines
    pe_line : bool
        if True plots Peshel Emery line
    """
    Hs = vqeclass.Hs
    n_kappas, n_hs = Hs.n_kappas, Hs.n_hs
    kappa_top, h_top = Hs.kappa_max, Hs.h_max

    # Energies arranged on the (kappa, h) grid
    true_grid = np.reshape(Hs.true_e0, (n_kappas, n_hs))
    vqe_grid = np.reshape(vqeclass.vqe_e0, (n_kappas, n_hs))

    if plot3d:
        kappa_axis = np.linspace(-kappa_top, 0, n_kappas)
        h_axis = np.linspace(0, h_top, n_hs)
        # x and y needed to be swapped for it to properly show the graph
        surfaces = [
            go.Surface(opacity=0.2, colorscale="Reds", z=true_grid, x=h_axis, y=kappa_axis),
            go.Surface(opacity=1, colorscale="Blues", z=vqe_grid, x=h_axis, y=kappa_axis),
        ]
        figure3d = go.Figure(data=surfaces)
        figure3d.update_layout(height=500)
        figure3d.show()

    # Add the default layout (axes limits, names, ticks...)
    plot_layout(Hs, pe_line=pe_line, phase_lines=phase_lines,
                title = r"VQE, $N = {0}$".format(str(Hs.N)))

    # Accuracy := |E_true - E_pred|/|E_true|
    accuracy = np.rot90(np.abs(vqe_grid - true_grid) / np.abs(true_grid))
    pixel_aspect = n_kappas / n_hs

    if log_heatmap:
        plt.imshow(accuracy, norm=LogNorm(), aspect=pixel_aspect)
    else:
        # Custom colormap: 25 RGB steps for the low-error end followed by
        # 100 steps fading from red to black for the high-error end
        good_rgb = np.column_stack(
            (np.linspace(0.3, 0, 25), np.linspace(0.8, 1, 25), np.linspace(1, 0, 25))
        )
        bad_rgb = np.column_stack(
            (np.linspace(1, 0, 100), [0] * 100, [0] * 100)
        )
        palette = np.vstack((good_rgb, bad_rgb))
        cmap_acc = LinearSegmentedColormap.from_list("accuracies", palette)

        plt.imshow(accuracy, cmap=cmap_acc, aspect=pixel_aspect)
        # The colour scale saturates at a 5% relative error
        plt.clim(0, 0.05)

    colorbar = plt.colorbar(fraction=0.04)
    colorbar.ax.tick_params(labelsize=16)
def VQE_psi_truepsi_fidelity(vqeclass, phase_lines = False, pe_line = False):
    """
    Heat-map of the fidelity between each VQE state and the corresponding
    true state obtained through diagonalization of the Hamiltonian.

    Parameters
    ----------
    vqeclass : vqe.vqe
        Custom VQE class after being trained
    phase_lines : bool
        if True plots the phase transition lines
    pe_line : bool
        if True plots Peshel Emery line
    """
    Hs = vqeclass.Hs
    n_kappas = Hs.n_kappas
    n_hs = Hs.n_hs

    # Circuit returning the state vector prepared by the VQE ansatz
    @qml.qnode(vqeclass.device, interface="jax")
    def q_vqe_state(vqe_params):
        vqeclass.circuit(vqe_params)
        return qml.state()

    # Jitted wrapper around the fidelity routine of the losses module
    # (presumably batched over all states inside losses.vqe_fidelities)
    @jax.jit
    def all_fidelities(true_states, vqe_params):
        return losses.vqe_fidelities(true_states, vqe_params, q_vqe_state)

    fidelities = all_fidelities(Hs.true_psi0, vqeclass.vqe_params0)
    fidelity_grid = np.reshape(fidelities, (n_kappas, n_hs))

    legend_title = r"Fidelities, $N = {0}$".format(str(Hs.N))
    plot_layout(Hs, phase_lines=phase_lines, pe_line=pe_line, title=legend_title)
    plt.imshow(np.rot90(fidelity_grid), aspect=Hs.n_kappas / Hs.n_hs)

    colorbar = plt.colorbar(fraction=0.04)
    colorbar.ax.tick_params(labelsize=16)
def VQE_fidelity_slice(vqeclass, slice_value, axis = 0, truestates = False):
    """
    Shows confusion matrix of fidelities of only a 'slice' of states in the
    parameter space. In other words, the fidelity of each state with every
    other state sharing the same h (or the same kappa) is computed.

    Two figures are produced: the phases plot with the slice marked by a
    blue line, and the confusion matrix of the states on that slice.

    Parameters
    ----------
    vqeclass : vqe.vqe
        Custom VQE class after being trained
    slice_value : float
        if axis = 0, only the states having h = slice_value are picked (any kappa)
        if axis = 1, only the states having kappa = slice_value are picked (any h)
    axis : int
        Direction of the slice, 0 is horizontal (fixed h), 1 is vertical (fixed kappa)
    truestates : bool
        if True the true states will be employed
        if False the VQE states will be employed

    Raises
    ------
    ValueError
        if axis is neither 0 nor 1
    """
    # Validate before drawing anything, so a bad call leaves no figure behind
    if axis not in (0, 1):
        raise ValueError('Invalid axis, it can only be either 0 or 1')

    Hs = vqeclass.Hs
    # Grid geometry. These must be known BEFORE the slice marker is drawn:
    # reading them afterwards raised UnboundLocalError on every valid call.
    sidey, ymax = Hs.n_hs, Hs.h_max
    sidex, xmax = Hs.n_kappas, Hs.kappa_max

    #########################################################
    # 1. Show the parameter space and the line of the slice #
    #########################################################
    # The phases plot is drawn first so the marker lands on top of it.
    # NOTE(review): assumes show_phasesplot() leaves its figure current and
    # does not call plt.show() itself -- confirm in the hamiltonians module.
    Hs.show_phasesplot()
    if axis == 0:
        plt.axhline(y = sidey - slice_value*sidey/ymax - .5, color='blue', lw=2)
    else:
        plt.axvline(x = slice_value*sidex/xmax - .5, color='blue', lw=2)
    plt.show()

    ############################################################
    # 2. Show the confusion matrix of fidelities of the states #
    ############################################################
    def plot_confusion_matrix(vectors):
        # Entry (i, j) is Re(<psi_j|psi_i>)^2, computed for all pairs with a
        # single matrix product.
        # NOTE(review): this equals the fidelity |<psi_j|psi_i>|^2 only when
        # the overlaps are real -- confirm that this is intended.
        states = np.asarray(vectors)
        c_matrix = np.square(np.real(states @ np.conj(states).T))
        plt.imshow(c_matrix, origin = 'lower')
        return c_matrix

    plt.figure(figsize=(8, 6), dpi=80)

    # States are stored kappa-major: flat index = i_kappa * sidey + i_h
    if axis == 0:
        # h : h_max = index : sidey -> index = h * sidey / h_max
        # (clamped so that slice_value == h_max selects the last row)
        i_h = min(int(slice_value * sidey / ymax), sidey - 1)
        indexes = np.arange(i_h, sidex * sidey, sidey)

        tick_positions = [k * sidex / 4 - .5 for k in range(0, 5)]
        tick_labels = [np.round(k * xmax / 4, 2) for k in range(0, 5)]
        axis_name = r"$\kappa$"
        title = f'h = {slice_value}'
    else:
        # Same mapping as the vertical marker above; the previous formula
        # ignored kappa_max and mixed up the two grid sides, so it was only
        # right for square grids with kappa_max == 1.
        i_kappa = min(int(slice_value * sidex / xmax), sidex - 1)
        indexes = np.arange(i_kappa * sidey, (i_kappa + 1) * sidey)

        tick_positions = [k * sidey / 4 - .5 for k in range(0, 5)]
        # NOTE(review): labels run from h_max down to 0 although the matrix
        # is drawn with origin='lower' and increasing h index -- the labels
        # look reversed; kept as in the original, please confirm.
        tick_labels = [np.round(k * ymax / 4, 2) for k in range(4, -1, -1)]
        axis_name = r"$h$"
        title = f'k = {slice_value}'

    # The matrix is square, so both axes share ticks, labels and name
    plt.xticks(ticks=tick_positions, labels=tick_labels)
    plt.yticks(ticks=tick_positions, labels=tick_labels)
    plt.ylabel(axis_name, fontsize=24)
    plt.xlabel(axis_name, fontsize=24)

    if truestates:
        # Compute the true states only if they are not available yet
        try:
            true_states = Hs.true_psi0
        except AttributeError:
            Hs.add_true()
            true_states = Hs.true_psi0
        plot_confusion_matrix(np.array(true_states)[indexes])
    else:
        # Quantum circuit for computing the states from the parameters
        @qml.qnode(vqeclass.device, interface="jax")
        def q_vqe_state(vqe_params):
            vqeclass.circuit(vqe_params)
            return qml.state()

        vqe_psi0 = jax.jit(jax.vmap(q_vqe_state, in_axes=(0)))(vqeclass.vqe_params0[indexes])
        plot_confusion_matrix(vqe_psi0)

    # The legend box only carries the slice description as its title
    plt.legend(
        bbox_to_anchor=(1, 1),
        loc="upper right",
        fontsize=16,
        facecolor="white",
        markerscale=1,
        framealpha=0.9,
        title=title,
        title_fontsize=16,
    )
    cbar = plt.colorbar(fraction=0.04)
    cbar.ax.tick_params(labelsize=16)
    plt.tight_layout()
    plt.show()
# _ _ ____ _____ _ _ _ _ # | || | / __ \ / ____| \ | | \ | | # | || |_ | | | | | | \| | \| | # |__ _| | | | | | | . ` | . ` | # | |_ | |__| | |____| |\ | |\ | # |_(_) \___\_\\_____|_| \_|_| \_|
def QCNN_classification_ising(qcnnclass, train_index):
    """
    Plots performance of the classifier on the whole data for a QCNN of a
    Nearest Neighbour Interaction Hamiltonian.

    The upper panel distinguishes training from test samples, the lower
    panel marks correct (green) and wrong (red) predictions.

    Parameters
    ----------
    qcnnclass : qcnn.qcnn
        Custom QCNN class after being trained
    train_index : List[Number]
        List of the indexes of the training set. On displaying they will be
        marked with a different colour
    """
    # Quantum Circuit to output the probabilities
    @qml.qnode(qcnnclass.device, interface="jax")
    def qcnn_circuit_prob(params_vqe, params):
        qcnnclass._vqe_qcnn_circuit(params_vqe, params)
        return qml.probs(wires=qcnnclass.N - 1)

    batched_circuit = jax.vmap(
        lambda v: qcnn_circuit_prob(v, qcnnclass.params), in_axes=(0)
    )
    # Column 1 holds the probability of measuring outcome 1
    predictions = batched_circuit(qcnnclass.vqe_params)[:, 1]

    n_states = len(qcnnclass.vqe_params)
    # The test index is the set difference of the whole dataset and the training set
    test_index = np.setdiff1d(np.arange(n_states), train_index)

    def verdict_colour(position, prediction):
        # Green when the rounded prediction agrees with the stored label
        predicted_zero = bool(np.round(prediction) == 0)
        labelled_zero = bool(qcnnclass.labels[position] == 0)
        return "green" if predicted_zero == labelled_zero else "red"

    predictions_train, colors_train = [], []
    predictions_test, colors_test = [], []
    for i, prediction in enumerate(predictions):
        if i in train_index:
            bucket, bucket_colours = predictions_train, colors_train
        else:
            bucket, bucket_colours = predictions_test, colors_test
        bucket.append(prediction)
        bucket_colours.append(verdict_colour(i, prediction))

    def decorate(panel):
        # Decoration common to both panels
        panel.set_xlim(-0.1, 2.1)
        panel.set_ylim(0, 1)
        panel.grid(True)
        panel.axhline(y=0.5, color="gray", linestyle="--")
        panel.axvline(x=1, color="gray", linestyle="--")
        panel.text(0.375, 0.68, "I", fontsize=24, fontfamily="serif")
        panel.text(1.6, 0.68, "II", fontsize=24, fontfamily="serif")
        panel.set_ylabel("Prediction of label II")
        panel.set_title("Predictions of labels; J = 1")

    # Sample index rescaled to the [0, 2] range of the horizontal axis
    train_positions = 2 * np.sort(train_index) / n_states
    test_positions = 2 * np.sort(test_index) / n_states

    _, (split_panel, verdict_panel) = plt.subplots(2, 1, figsize=(16, 10))

    decorate(split_panel)
    split_panel.scatter(
        train_positions,
        predictions_train,
        c="royalblue",
        label="Training samples",
    )
    split_panel.scatter(
        test_positions,
        predictions_test,
        c="orange",
        label="Test samples",
    )
    split_panel.legend()

    decorate(verdict_panel)
    verdict_panel.scatter(train_positions, predictions_train, c=colors_train)
    verdict_panel.scatter(test_positions, predictions_test, c=colors_test)
def QCNN_classification_ANNNI_marginal(qcnnclass):
    """
    Displays the output probabilities of the states lying on the two axes of
    the parameter space, one figure per axis. It is used more as a debug
    function to test if the classes are being trained correctly.

    Parameters
    ----------
    qcnnclass : qcnn.qcnn
        Custom QCNN class after being trained
    """
    @qml.qnode(qcnnclass.device, interface="jax")
    def qcnn_circuit_prob(params_vqe, params):
        qcnnclass._vqe_qcnn_circuit(params_vqe, params)
        return [qml.probs(wires=int(k)) for k in qcnnclass.final_active_wires]

    # Subset of the states on the two axes: those whose model parameter in
    # column 1 (first subset) or column 2 (second subset) is zero
    model_params = jnp.array(qcnnclass.vqe.Hs.model_params)
    axis_masks = (model_params[:, 1] == 0, model_params[:, 2] == 0)

    subsets = [
        (
            qcnnclass.vqe_params[mask],
            qcnnclass.labels[mask, :].astype(int),
            np.arange(int(jnp.sum(mask))),
        )
        for mask in axis_masks
    ]

    batched_circuit = jax.vmap(
        lambda v: qcnn_circuit_prob(v, qcnnclass.params), in_axes=(0)
    )
    # All predictions are computed before any figure is shown
    subset_predictions = [batched_circuit(states) for states, _, _ in subsets]

    for (_, labels, positions), predictions in zip(subsets, subset_predictions):
        first_wire_p1, second_wire_p1, verdicts = [], [], []
        for idx, pred in enumerate(predictions):
            first_wire_p1.append(pred[0][1])
            second_wire_p1.append(pred[1][1])
            # Green only when both output wires reproduce the label
            both_right = (np.argmax(pred[0]) == labels[idx][0]) and (
                np.argmax(pred[1]) == labels[idx][1]
            )
            verdicts.append("green" if both_right else "red")

        _, panels = plt.subplots(1, 2, figsize=(20, 6))
        for panel, probabilities in zip(panels, (first_wire_p1, second_wire_p1)):
            panel.grid(True)
            panel.scatter(positions, probabilities, c=verdicts)
            panel.set_ylim(-0.1, 1.1)
        plt.show()
def QCNN_classification_ANNNI(
    qcnnclass,
    hard_thr=True,
    predicted_line = False,
    label=False,
    info=False,
):
    """
    Plots performance of the classifier on the whole data for a QCNN of a
    ANNNI model.

    Parameters
    ----------
    qcnnclass : qcnn.qcnn
        Custom QCNN class after being trained
    hard_thr : bool
        if True the prediction will be displayed through an argmax instead
        of using color channels to entail the 3 (4 considering the trash
        case) probabilities
    predicted_line : bool
        if True it displays the predicted transition line
    label : str
        Label to assign to the picture, needed for the paper
    info : bool
        if True more infos will be displayed such as the names of the phases
    """
    plt.figure(figsize=(8, 6), dpi=80)

    Hs = qcnnclass.vqe.Hs
    n_kappas = Hs.n_kappas
    n_hs = Hs.n_hs
    pixel_aspect = n_kappas / n_hs

    predictions = qcnnclass.predict()

    if hard_thr:
        # One flat colour per class
        predictions = np.argmax(predictions, axis=1)
        class_cmap = mpl.colors.ListedColormap(
            ["black", "skyblue", "yellow", "palegreen"]
        )
        class_norm = mpl.colors.BoundaryNorm(np.arange(0, 5), class_cmap.N)
        class_grid = np.rot90(np.reshape(predictions, (n_kappas, n_hs)))
        plt.imshow(class_grid, cmap=class_cmap, norm=class_norm, aspect=pixel_aspect)
    else:
        # Blend three base colours, weighted by the probabilities of
        # classes 3, 1 and 2 (class 0 contributes nothing)
        mygreen = np.array([90, 255, 100]) / 255
        myblue = np.array([50, 50, 200]) / 255
        myyellow = np.array([300, 270, 0]) / 255

        blended = np.array(
            [p[3] * mygreen + p[1] * myblue + p[2] * myyellow for p in predictions],
            dtype=float,
        )
        rgb_grid = np.rot90(np.reshape(blended, (n_kappas, n_hs, 3)))
        plt.imshow(rgb_grid, alpha=1, aspect=pixel_aspect)

    if predicted_line:
        plt.plot(
            qcnnclass.predict_lines(predictions=predictions),
            color='magenta',
            label='Predicted Transition Lines',
        )

    plot_layout(
        Hs, False, True,
        r"QCNN, $N = {0}$".format(str(qcnnclass.N)),
        figure_already_defined=True,
    )

    if label:
        plt.figtext(0.28, 0.79, "(" + label + ")", color="black", fontsize=20)

    # Only for (x,y)=(1,2) parameter space Hamiltonians
    # TODO: add labels for whatever hamiltonian
    if info and (Hs.h_max == 2) and (Hs.kappa_max == 1):
        phase_names = (
            (0.5, 0.4, "para.", "black"),
            (0.18, 0.88, "ferro.", "white"),
            (0.82, 0.88, "anti.", "black"),
        )
        for frac_x, frac_y, name, text_colour in phase_names:
            plt.text(
                n_kappas * frac_x,
                n_hs * frac_y,
                name,
                color=text_colour,
                fontsize=20,
                ha="center",
                va="center",
            )
# _____ ______ _ # | ____| | ____| | | # | |__ | |__ _ __ ___ ___ __| | ___ _ __ # |___ \ | __| | '_ \ / __/ _ \ / _` |/ _ \ '__| # ___) | | |____| | | | (_| (_) | (_| | __/ | # |____(_) |______|_| |_|\___\___/ \__,_|\___|_|
def ENC_show_compression_ANNNI(encclass, trainingpoint=False, label=False, plot3d=False):
    """
    Plots performance of the compression on the whole data for an encoder
    on the ANNNI model.

    Parameters
    ----------
    encclass : encoder.encoder
        Custom encoder class after being trained
    trainingpoint : int
        Index of the single training point to mark on the plot; with any
        non-int value (such as the default False) no point is marked
    label : str
        Label to assign to the picture, needed for the paper
    plot3d : bool
        If True the 3D plot will be displayed as well
    """
    Hs = encclass.vqe.Hs
    sidex = Hs.n_kappas
    sidey = Hs.n_hs
    max_x = Hs.kappa_max
    max_y = Hs.h_max

    x = np.linspace(-max_x, 0, sidex)
    y = np.linspace(0, max_y, sidey)

    X = jnp.array(encclass.vqe_params0)

    # Circuit returning <Z> on each trash wire
    @qml.qnode(encclass.device, interface="jax")
    def encoder_circuit(vqe_params, params):
        encclass._vqe_enc_circuit(vqe_params, params)
        return [qml.expval(qml.PauliZ(int(k))) for k in encclass.wires_trash]

    v_encoder_circuit = jax.vmap(lambda p: encoder_circuit(p, encclass.params))

    # Maps the summed <Z> values to a score: (1 - sum / 4) / 2
    # NOTE(review): the divisor 4 presumably is the number of trash wires;
    # confirm, and use len(encclass.wires_trash) if other sizes are supported.
    exps = (1 - np.sum(v_encoder_circuit(X), axis=1) / 4) / 2
    exps = np.rot90(np.reshape(exps, (sidex, sidey)))

    if plot3d:
        fig = go.Figure(data=[go.Surface(z=exps, x=x, y=y)])
        fig.update_layout(height=500)
        fig.show()

    plt.figure(figsize=(8, 6), dpi=80)
    # The figure was created just above: tell plot_layout so, otherwise it
    # opens a second figure and this one is left behind empty.
    plot_layout(Hs, pe_line=True, phase_lines=True, title='',
                figure_already_defined=True)

    plt.imshow(exps, aspect = Hs.n_kappas / Hs.n_hs)

    # bool is a subclass of int: exclude it explicitly so that the default
    # False keeps meaning "no training point"
    if isinstance(trainingpoint, int) and not isinstance(trainingpoint, bool):
        # Flat index -> pixel coordinates (vertical axis is flipped)
        train_x = trainingpoint // sidey
        train_y = sidey - trainingpoint % sidey
        # Nudge markers lying on the border towards the inside of the image
        if train_x == 0:
            train_x += 1
        if train_y == sidey:
            train_y -= 2
        plt.scatter(
            [train_x],
            [train_y],
            marker="+",
            s=300,
            color="orangered",
            label=r"Initial state $\left|\psi\right\rangle$",
        )

    if label:
        plt.figtext(0.23, 0.79, "(" + label + ")", color="black", fontsize=20)

    # Names of the phases, placed at fixed fractions of the image size
    phase_names = (
        (0.5, 0.4, "para.", "black"),
        (0.18, 0.88, "ferro.", "white"),
        (0.82, 0.88, "anti.", "black"),
    )
    for frac_x, frac_y, name, text_colour in phase_names:
        plt.text(
            sidex * frac_x,
            sidey * frac_y,
            name,
            color=text_colour,
            fontsize=22,
            ha="center",
            va="center",
        )

    leg = plt.legend(
        bbox_to_anchor=(1, 1),
        loc="upper right",
        fontsize=16,
        facecolor="white",
        markerscale=0.8,
        framealpha=0.9,
        title=r"AD, $N = {0}$".format(str(Hs.N)),
        title_fontsize=16,
    )
    leg.get_frame().set_boxstyle("Square")

    cbar = plt.colorbar()
    cbar.ax.tick_params(labelsize=18)