未验证 提交 9fd4fc63 编写于 作者: Y yangguohao 提交者: GitHub

Update utils.py

上级 27ad07b1
......@@ -1310,14 +1310,7 @@ def plot_n_qubit_state_in_bloch_sphere(
for i in range(state_len):
assert type(state[i]) == paddle.Tensor or type(state[i]) == np.ndarray, \
'the type of "state[i]" should be "paddle.Tensor" or "numpy.ndarray".'
if show_qubits is None:
show_qubits = [None]*state_len
else:
assert len(show_qubits)==state_len,'show_qubits大小需要和state相同'
for i in range(state_len):
assert type(show_qubits[i])==list,'the type of show_qubits should be None or list'
# Convert Tensor to numpy
for i in range(state_len):
if type(state[i]) == paddle.Tensor:
......@@ -1328,23 +1321,33 @@ def plot_n_qubit_state_in_bloch_sphere(
if state[i].size == 2:
state_vector = state[i]
state[i] = np.outer(state_vector, np.conj(state_vector))
if show_qubits is None:
show_qubits = [None]*state_len
else:
assert len(show_qubits)==state_len,'show_qubits大小需要和state相同'
for i in range(state_len):
assert type(show_qubits[i])==list,'the type of show_qubits should be None or list'
assert 0<len(show_qubits[i])<=state[i].shape[0], '要展示的量子比特数目错误'
for i in range(state_len):
#若为多量子态
if state[i].shape[0]>2:
s = []
if show_qubits[i] is None:
qubits_list = [*range(int(np.log2(state[i].shape[0])))]
qubits_list = list(range(int(np.log2(state[i].shape[0]))))
else:
qubits_list = show_qubits[i]
rho = paddle.to_tensor(state[i])
for q in qubits_list:
s.append(partial_trace_discontiguous(rho,[q]))
plot_state_in_bloch_sphere(s,**args)
#多量子态的子图的箭头向量改为蓝色
plot_state_in_bloch_sphere(s,n_qubit=len(qubits_list),set_color='#0000FF',**args)
else:
plot_state_in_bloch_sphere(state[i],**args)
def plot_state_in_bloch_sphere(
state,
qubits_list=None,
show_arrow=False,
save_gif=False,
filename=None,
......@@ -1356,6 +1359,7 @@ def plot_state_in_bloch_sphere(
Args:
state (list(numpy.ndarray or paddle.Tensor)): 输入的量子态列表,可以支持态矢量和密度矩阵
n_qubit (int): 若为多量子态,则为需要展示的量子比特的数目
show_arrow (bool): 是否展示向量的箭头,默认为 ``False``
save_gif (bool): 是否存储 gif 动图,默认为 ``False``
filename (str): 存储的 gif 动图的名字
......@@ -1433,13 +1437,23 @@ def plot_state_in_bloch_sphere(
# Helper function to plot vectors on a sphere.
fig = plt.figure(figsize=(8, 8), dpi=100)
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
ax = fig.add_subplot(111, projection='3d')
__plot_bloch_sphere(
if qubits_list is None:#若为单量子态
ax = fig.add_subplot(111, projection='3d')
__plot_bloch_sphere(
ax, bloch_vectors, show_arrow, clear_plt=True,
view_angle=view_angle, view_dist=view_dist, set_color=set_color
)
)
else: #若为多量子态
dim = np.ceil(sqrt(n_qubit))
for i in range(1,n_qubit+1):
ax = fig.add_subplot(dim,dim,i,projection='3d')
bloch_vector=np.array([bloch_vectors[i-1]])
__plot_bloch_sphere(
ax, bloch_vector, show_arrow, clear_plt=True,
view_angle=view_angle, view_dist=view_dist, set_color=set_color
)
plt.show()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册