Contenu du cours
Essentiels de Pytorch
Essentiels de Pytorch
Formes et Dimensions dans PyTorch
De même que pour les tableaux NumPy, la forme d'un tenseur détermine ses dimensions. Vous pouvez inspecter la forme d'un tenseur en utilisant l'attribut .shape
:
import torch tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) print(f"Tensor shape: {tensor.shape}")
Remodelage des Tenseurs avec view
La méthode .view()
crée une nouvelle vue du tenseur avec la forme spécifiée sans modifier le tenseur original. Le nombre total d'éléments doit rester le même.
import torch tensor = torch.arange(12) # Reshape a tensor to 4x3 reshaped_tensor = tensor.view(4, 3) print(f"Reshaped tensor: {reshaped_tensor}") # Original tensor remains unchanged print(f"Original tensor: {tensor}")
Remodelage des tenseurs avec reshape
La méthode .reshape()
est similaire à .view()
mais peut gérer les cas où le tenseur n'est pas stocké de manière contiguë en mémoire. Elle ne modifie pas non plus le tenseur original.
import torch tensor = torch.arange(12) # Reshape a tensor to 6x2 reshaped_tensor = tensor.reshape(6, 2) print(f"Reshaped tensor: {reshaped_tensor}")
Utilisation des dimensions négatives
Vous pouvez utiliser -1
dans la forme pour laisser PyTorch déduire la taille d'une dimension en fonction du nombre total d'éléments.
import torch tensor = torch.arange(12) # Automatically infer the second dimension inferred_tensor = tensor.view(2, -1) print("Inferred Tensor:", inferred_tensor)
Comprendre les Vues de Tenseur
Une vue d'un tenseur partage les mêmes données avec le tenseur original. Les modifications apportées à la vue affectent le tenseur original et vice versa.
import torch tensor = torch.arange(12) view_tensor = tensor.view(2, 6) view_tensor[0, 0] = 999 # Changes in the view are reflected in the original tensor print("View Tensor:", view_tensor) print("Original Tensor:", tensor)
Changer les Dimensions
Les deux méthodes suivantes vous permettent d'ajouter ou de supprimer des dimensions :
unsqueeze(dim)
ajoute une nouvelle dimension à la position spécifiée ;squeeze(dim)
supprime les dimensions de taille 1.
import torch tensor = torch.arange(12) # Add a new dimension unsqueezed_tensor = tensor.unsqueeze(0) # Add a batch dimension print(f"Unsqueezed tensor: {unsqueezed_tensor.shape}") # Remove a dimension of size 1 squeezed_tensor = unsqueezed_tensor.squeeze(0) print(f"Squeezed Tensor: {squeezed_tensor.shape}")
Merci pour vos commentaires !