diff --git a/paddle_quantum/utils.py b/paddle_quantum/utils.py index 26c289ec651e8e8bcd8a42a5461fc8d566a6d54a..593ec71a84bb97892e0d97822847e5dd507143e4 100644 --- a/paddle_quantum/utils.py +++ b/paddle_quantum/utils.py @@ -1284,60 +1284,114 @@ def __plot_bloch_sphere( def plot_n_qubit_state_in_bloch_sphere( state, - show_qubits=None, - **args + which_qubits=None, + save_gif=False, + filename=None, + view_angle=None, + view_dist=None, + set_color='#0000FF' ): r"""将输入的多量子比特的量子态展示在 Bloch 球面上 Args: - state (list(numpy.ndarray or paddle.Tensor)): 输入的量子态列表,可以支持态矢量和密度矩阵, + state (numpy.ndarray or paddle.Tensor): 输入的量子态,可以支持态矢量和密度矩阵, 该函数下,列表内每一个量子态对应一张单独的图片 - show_qubits(list(list)):若为多量子比特,则给出要展示的量子比特,默认为 None,表示全展示 + which_qubits(list or None):若为多量子比特,则给出要展示的量子比特,默认为 None,表示全展示 show_arrow (bool): 是否展示向量的箭头,默认为 ``False`` save_gif (bool): 是否存储 gif 动图,默认为 ``False`` filename (str): 存储的 gif 动图的名字 view_angle (list or tuple): 视图的角度, 第一个元素为关于 xy 平面的夹角 [0-360],第二个元素为关于 xz 平面的夹角 [0-360], 默认为 ``(30, 45)`` view_dist (int): 视图的距离,默认为 7 - set_color (str): 若要设置指定的颜色,请查阅 ``cmap`` 表。默认为红色到黑色的渐变颜色 + set_color (str): 若要设置指定的颜色,请查阅 ``cmap`` 表。默认为蓝色 """ - assert type(state) == list or type(state) == paddle.Tensor or type(state) == np.ndarray, \ - 'the type of "state" must be "list" or "paddle.Tensor" or "np.ndarray".' - if type(state) == paddle.Tensor or type(state) == np.ndarray: - state = [state] - state_len = len(state) - assert state_len >= 1, '"state" is NULL.' - for i in range(state_len): - assert type(state[i]) == paddle.Tensor or type(state[i]) == np.ndarray, \ - 'the type of "state[i]" should be "paddle.Tensor" or "numpy.ndarray".' + # Check input data + __input_args_dtype_check(show_arrow, save_gif, filename, view_angle, view_dist) + + assert type(state) == paddle.Tensor or type(state) == np.ndarray, \ + 'the type of "state" must be "paddle.Tensor" or "np.ndarray".' + assert type(set_color) == str, \ + 'the type of "set_color" should be "str".' + + n_qubits = np.log2(state.shape[0]) - if show_qubits is None: - show_qubits = [None]*state_len + if which_qubits is None: + which_qubits = list(range(int(n_qubits))) else: - assert len(show_qubits)==state_len,'show_qubits大小需要和state相同' - for i in range(state_len): - assert type(show_qubits[i])==list,'the type of show_qubits should be None or list' - assert 0=2 and state.size==2*state.shape[0]: + state_vector = state + state = np.outer(state_vector, np.conj(state_vector)) + rho = paddle.to_tensor(state) + tmp_s = [] + for q in which_qubits: + tmp_s.append(partial_trace_discontiguous(rho,[q])) + state = tmp_s + state_len = len(state) + # Calc the bloch_vectors + bloch_vector_list = [] for i in range(state_len): - #若为多量子态 - if state[i].shape[0]>2: - s = [] - if show_qubits[i] is None: - qubits_list = list(range(int(np.log2(state[i].shape[0])))) - else: - qubits_list = show_qubits[i] - rho = paddle.to_tensor(state[i]) - for q in qubits_list: - s.append(partial_trace_discontiguous(rho,[q])) - #多量子态的子图的箭头向量改为蓝色 - plot_state_in_bloch_sphere(s,n_qubit=len(qubits_list),set_color='#0000FF',**args) - else: - plot_state_in_bloch_sphere(state[i],**args) + bloch_vector_tmp = __density_matrix_convert_to_bloch_vector(state[i]) + bloch_vector_list.append(bloch_vector_tmp) + + # List must be converted to array for slicing. + bloch_vectors = np.array(bloch_vector_list) + + # A update function for animation class + def update(frame): + view_rotating_angle = 5 + new_view_angle = [view_angle[0], view_angle[1] + view_rotating_angle * frame] + __plot_bloch_sphere( + ax, bloch_vectors, show_arrow, clear_plt=True, + view_angle=new_view_angle, view_dist=view_dist, set_color=set_color + ) + + # Dynamic update and save + if save_gif: + # Helper function to plot vectors on a sphere. + fig = plt.figure(figsize=(8, 8), dpi=100) + fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + ax = fig.add_subplot(111, projection='3d') + + frames_num = 7 + anim = animation.FuncAnimation(fig, update, frames=frames_num, interval=600, repeat=False) + anim.save(filename, dpi=100, writer='pillow') + # close the plt + plt.close(fig) + + # Helper function to plot vectors on a sphere. + fig = plt.figure(figsize=(8, 8), dpi=100) + fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + dim = np.ceil(sqrt(n_qubit)) + for i in range(1,n_qubit+1): + ax = fig.add_subplot(dim,dim,i,projection='3d') + bloch_vector=np.array([bloch_vectors[i-1]]) + __plot_bloch_sphere( + ax, bloch_vector, show_arrow, clear_plt=True, + view_angle=view_angle, view_dist=view_dist, set_color=set_color + ) + plt.show() def plot_state_in_bloch_sphere( state, - qubits_list=None, show_arrow=False, save_gif=False, filename=None, @@ -1349,7 +1403,6 @@ def plot_state_in_bloch_sphere( Args: state (list(numpy.ndarray or paddle.Tensor)): 输入的量子态列表,可以支持态矢量和密度矩阵 - n_qubit (int): 若为多量子态,则为需要展示的量子比特的数目 show_arrow (bool): 是否展示向量的箭头,默认为 ``False`` save_gif (bool): 是否存储 gif 动图,默认为 ``False`` filename (str): 存储的 gif 动图的名字 @@ -1428,22 +1481,13 @@ def plot_state_in_bloch_sphere( fig = plt.figure(figsize=(8, 8), dpi=100) fig.subplots_adjust(left=0, right=1, bottom=0, top=1) - if qubits_list is None:#若为单量子态 - ax = fig.add_subplot(111, projection='3d') - __plot_bloch_sphere( - ax, bloch_vectors, show_arrow, clear_plt=True, - view_angle=view_angle, view_dist=view_dist, set_color=set_color - ) - else: #若为多量子态 - dim = np.ceil(sqrt(n_qubit)) - for i in range(1,n_qubit+1): - ax = fig.add_subplot(dim,dim,i,projection='3d') - bloch_vector=np.array([bloch_vectors[i-1]]) - __plot_bloch_sphere( - ax, bloch_vector, show_arrow, clear_plt=True, - view_angle=view_angle, view_dist=view_dist, set_color=set_color - ) - + + ax = fig.add_subplot(111, projection='3d') + __plot_bloch_sphere( + ax, bloch_vectors, show_arrow, clear_plt=True, + view_angle=view_angle, view_dist=view_dist, set_color=set_color + ) + plt.show()