百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

Siamese Networks:算法,应用程序和PyTorch实现

toyiye 2024-08-24 00:25 7 浏览 0 评论

由于暹罗网络在深度学习研究和应用程序中越来越受欢迎,我将解释什么是暹罗网络,并以PyTorch中一个简单的暹罗CNN网络为例进行总结。

什么是暹罗网络?

暹罗网络是包含两个或多个相同子网组件的神经网络。暹罗网络可能如下所示:

重要的是,不仅子网络的体系结构是相同的,而且必须在它们之间共享权重,使网络被称为“siamese”。siamese网络背后的主要思想是,他们可以学习有用的数据描述符,这些描述符可以进一步用于比较各个子网络的输入。因此,输入可以是数字数据(在这种情况下,子网络通常由完全连接的层组成)、图像数据(以CNN作为子网络),甚至是连续的数据,如句子或时间信号(以RNN为子网络)。

通常,暹罗网络在输出端执行二元分类。因此,在训练期间可以使用不同的损失函数。最流行的损失函数之一是二元交叉熵损失。这种损失可以计算为

,其中L是损失函数,y是类标签(0或1),p是预测。为了训练网络区分相似和不同的对象,我们可以一次给它一个正的和一个负的例子,并把损失加起来:

另一种使用triplet loss:

d是距离函数(例如L2损失),a是数据集的样本,p是随机正样本,n是负样本。m是任意边界,用于进一步分析正分数和负分数。

暹罗网络的应用

暹罗网络具有广泛的应用。这里有几个:

  • One-shot learning。在这个学习场景中,一个新的训练数据集被提供给训练过的(分类)网络,每个类只有一个样本。然后,在一个单独的测试数据集上测试这个新数据集的分类性能。当暹罗网络首先学习大型特定数据集的判别特征时,它们也可用于将这些知识推广到全新的类和分布。在(Koch,Gregory,Richard Zemel和Ruslan Salakhutdinov。“用于一次性图像识别的连体神经网络。”ICML Deep Learning Workshop.Vol.2。2015.)中,作者使用此功能进行一次性学习MNIST数据集使用在Omniglot数据集上训练的网络(完全不同的图像数据集)。
  • 用于视频监控的行人跟踪。在这项工作中,一个暹罗CNN网络与图像块的大小和位置特征相结合,通过检测它们在每个视频帧中的位置,学习多个帧之间的关联和计算,来跟踪摄像机视野中的多个人。轨迹。
  • Cosegmentation(Mukherjee,Prerana,Brejesh Lall和Snehith Lattupally。“使用深暹罗网络的对象分配。”arXiv preprint arXiv:1803.02555(2018)。)。
  • 匹配简历到工作。在这个应用程序中,该网络试图为应聘者找到匹配的工作岗位。为了做到这一点,一个训练有素的暹罗CNN网络从帖子和简历中提取深层上下文信息,并计算它们的语义相似性。假设匹配的简历-张贴配对在相似度上比不匹配的排序更高。

示例:在PyTorch中使用Siamese网络对MNIST图像进行分类

在解释了连体网络的基本原理之后,我们现在将在PyTorch中构建一个网络,以分类一对MNIST图像是否具有相同的数字。我们将使用二元交叉熵损失作为我们的训练损失函数,我们将使用精度测量来评估测试数据集上的网络。以下是这篇文章的完整Python代码:

import codecs
import errno
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
import random
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import torchvision.datasets.mnist
from torchvision import transforms
from tqdm import tqdm
 
do_learn = True
save_frequency = 2
batch_size = 16
lr = 0.001
num_epochs = 10
weight_decay = 0.0001
 
class BalancedMNISTPair(torch.utils.data.Dataset):
 """Dataset that on each iteration provides two random pairs of
 MNIST images. One pair is of the same number (positive sample), one
 is of two different numbers (negative sample).
 """
 urls = [
 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
 ]
 raw_folder = 'raw'
 processed_folder = 'processed'
 training_file = 'training.pt'
 test_file = 'test.pt'
 
 def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
 self.root = os.path.expanduser(root)
 self.transform = transform
 self.target_transform = target_transform
 self.train = train # training set or test set
 
 if download:
 self.download()
 
 if not self._check_exists():
 raise RuntimeError('Dataset not found.' + ' You can use download=True to download it')
 
 if self.train:
 self.train_data, self.train_labels = torch.load(
 os.path.join(self.root, self.processed_folder, self.training_file))
 
 train_labels_class = []
 train_data_class = []
 for i in range(10):
 indices = torch.squeeze((self.train_labels == i).nonzero())
 train_labels_class.append(torch.index_select(self.train_labels, 0, indices))
 train_data_class.append(torch.index_select(self.train_data, 0, indices))
 
 # generate balanced pairs
 self.train_data = []
 self.train_labels = []
 lengths = [x.shape[0] for x in train_labels_class]
 for i in range(10):
 for j in range(500): # create 500 pairs
 rnd_cls = random.randint(0,8) # choose random class that is not the same class
 if rnd_cls >= i:
 rnd_cls = rnd_cls + 1
 
 rnd_dist = random.randint(0, 100)
 
 self.train_data.append(torch.stack([train_data_class[i][j], train_data_class[i][j+rnd_dist], train_data_class[rnd_cls][j]]))
 self.train_labels.append([1,0])
 
 self.train_data = torch.stack(self.train_data)
 self.train_labels = torch.tensor(self.train_labels)
 
 else:
 self.test_data, self.test_labels = torch.load(
 os.path.join(self.root, self.processed_folder, self.test_file))
 
 test_labels_class = []
 test_data_class = []
 for i in range(10):
 indices = torch.squeeze((self.test_labels == i).nonzero())
 test_labels_class.append(torch.index_select(self.test_labels, 0, indices))
 test_data_class.append(torch.index_select(self.test_data, 0, indices))
 
 # generate balanced pairs
 self.test_data = []
 self.test_labels = []
 lengths = [x.shape[0] for x in test_labels_class]
 for i in range(10):
 for j in range(500): # create 500 pairs
 rnd_cls = random.randint(0,8) # choose random class that is not the same class
 if rnd_cls >= i:
 rnd_cls = rnd_cls + 1
 
 rnd_dist = random.randint(0, 100)
 
 self.test_data.append(torch.stack([test_data_class[i][j], test_data_class[i][j+rnd_dist], test_data_class[rnd_cls][j]]))
 self.test_labels.append([1,0])
 
 self.test_data = torch.stack(self.test_data)
 self.test_labels = torch.tensor(self.test_labels)
 
 def __getitem__(self, index):
 if self.train:
 imgs, target = self.train_data[index], self.train_labels[index]
 else:
 imgs, target = self.test_data[index], self.test_labels[index]
 
 img_ar = []
 for i in range(len(imgs)):
 img = Image.fromarray(imgs[i].numpy(), mode='L')
 if self.transform is not None:
 img = self.transform(img)
 img_ar.append(img)
 
 if self.target_transform is not None:
 target = self.target_transform(target)
 
 return img_ar, target
 
 def __len__(self):
 if self.train:
 return len(self.train_data)
 else:
 return len(self.test_data)
 
 def _check_exists(self):
 return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
 os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))
 
 def download(self):
 """Download the MNIST data if it doesn't exist in processed_folder already."""
 from six.moves import urllib
 import gzip
 
 if self._check_exists():
 return
 
 # download files
 try:
 os.makedirs(os.path.join(self.root, self.raw_folder))
 os.makedirs(os.path.join(self.root, self.processed_folder))
 except OSError as e:
 if e.errno == errno.EEXIST:
 pass
 else:
 raise
 
 for url in self.urls:
 print('Downloading ' + url)
 data = urllib.request.urlopen(url)
 filename = url.rpartition('/')[2]
 file_path = os.path.join(self.root, self.raw_folder, filename)
 with open(file_path, 'wb') as f:
 f.write(data.read())
 with open(file_path.replace('.gz', ''), 'wb') as out_f, \
 gzip.GzipFile(file_path) as zip_f:
 out_f.write(zip_f.read())
 os.unlink(file_path)
 
 # process and save as torch files
 print('Processing...')
 
 training_set = (
 read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
 read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
 )
 test_set = (
 read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
 read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
 )
 with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
 torch.save(training_set, f)
 with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
 torch.save(test_set, f)
 
 print('Done!')
 
 def __repr__(self):
 fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
 fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
 tmp = 'train' if self.train is True else 'test'
 fmt_str += ' Split: {}\n'.format(tmp)
 fmt_str += ' Root Location: {}\n'.format(self.root)
 tmp = ' Transforms (if any): '
 fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
 tmp = ' Target Transforms (if any): '
 fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
 return fmt_str
 
class Net(nn.Module):
 def __init__(self):
 super().__init__()
 
 self.conv1 = nn.Conv2d(1, 64, 7)
 self.pool1 = nn.MaxPool2d(2)
 self.conv2 = nn.Conv2d(64, 128, 5)
 self.conv3 = nn.Conv2d(128, 256, 5)
 self.linear1 = nn.Linear(2304, 512)
 
 self.linear2 = nn.Linear(512, 2)
 
 def forward(self, data):
 res = []
 for i in range(2): # Siamese nets; sharing weights
 x = data[i]
 x = self.conv1(x)
 x = F.relu(x)
 x = self.pool1(x)
 x = self.conv2(x)
 x = F.relu(x)
 x = self.conv3(x)
 x = F.relu(x)
 
 x = x.view(x.shape[0], -1)
 x = self.linear1(x)
 res.append(F.relu(x))
 
 res = torch.abs(res[1] - res[0])
 res = self.linear2(res)
 return res
 
def train(model, device, train_loader, epoch, optimizer):
 model.train()
 
 for batch_idx, (data, target) in enumerate(train_loader):
 for i in range(len(data)):
 data[i] = data[i].to(device)
 
 optimizer.zero_grad()
 output_positive = model(data[:2])
 output_negative = model(data[0:3:2])
 
 target = target.type(torch.LongTensor).to(device)
 target_positive = torch.squeeze(target[:,0])
 target_negative = torch.squeeze(target[:,1])
 
 loss_positive = F.cross_entropy(output_positive, target_positive)
 loss_negative = F.cross_entropy(output_negative, target_negative)
 
 loss = loss_positive + loss_negative
 loss.backward()
 
 optimizer.step()
 if batch_idx % 10 == 0:
 print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
 epoch, batch_idx*batch_size, len(train_loader.dataset), 100. * batch_idx*batch_size / len(train_loader.dataset),
 loss.item()))
 
def test(model, device, test_loader):
 model.eval()
 
 with torch.no_grad():
 accurate_labels = 0
 all_labels = 0
 loss = 0
 for batch_idx, (data, target) in enumerate(test_loader):
 for i in range(len(data)):
 data[i] = data[i].to(device)
 
 output_positive = model(data[:2])
 output_negative = model(data[0:3:2])
 
 target = target.type(torch.LongTensor).to(device)
 target_positive = torch.squeeze(target[:,0])
 target_negative = torch.squeeze(target[:,1])
 
 loss_positive = F.cross_entropy(output_positive, target_positive)
 loss_negative = F.cross_entropy(output_negative, target_negative)
 
 loss = loss + loss_positive + loss_negative
 
 accurate_labels_positive = torch.sum(torch.argmax(output_positive, dim=1) == target_positive).cpu()
 accurate_labels_negative = torch.sum(torch.argmax(output_negative, dim=1) == target_negative).cpu()
 
 accurate_labels = accurate_labels + accurate_labels_positive + accurate_labels_negative
 all_labels = all_labels + len(target_positive) + len(target_negative)
 
 accuracy = 100. * accurate_labels / all_labels
 print('Test accuracy: {}/{} ({:.3f}%)\tLoss: {:.6f}'.format(accurate_labels, all_labels, accuracy, loss))
 
def oneshot(model, device, data):
 model.eval()
 
 with torch.no_grad():
 for i in range(len(data)):
 data[i] = data[i].to(device)
 
 output = model(data)
 return torch.squeeze(torch.argmax(output, dim=1)).cpu().item()
 
def main():
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
 
 model = Net().to(device)
 
 if do_learn: # training mode
 train_loader = torch.utils.data.DataLoader(BalancedMNISTPair('../data', train=True, download=True, transform=trans), batch_size=batch_size, shuffle=True)
 test_loader = torch.utils.data.DataLoader(BalancedMNISTPair('../data', train=False, download=True, transform=trans), batch_size=batch_size, shuffle=False)
 
 optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
 for epoch in range(num_epochs):
 train(model, device, train_loader, epoch, optimizer)
 test(model, device, test_loader)
 if epoch & save_frequency == 0:
 torch.save(model, 'siamese_{:03}.pt'.format(epoch))
 else: # prediction
 prediction_loader = torch.utils.data.DataLoader(BalancedMNISTPair('../data', train=False, download=True, transform=trans), batch_size=1, shuffle=True)
 model.load_state_dict(torch.load(load_model_path))
 data = []
 data.extend(next(iter(prediction_loader))[0][:3:2])
 same = oneshot(model, device, data)
 if same > 0:
 print('These two images are of the same number')
 else:
 print('These two images are not of the same number')
 
if __name__ == '__main__':
 main()

如您所见,大部分代码都包括构建一个适当的Dataset类,它为我们提供随机的图像样本。为了训练网络,重要的是我们要得到一个平衡的数据集,有正的和负的样本。因此,在每次迭代中,我们同时提供这两种方法。数据集的代码很长,但最终很简单:对于每个数字(类)0-9,我们必须提供一个正对(另一个相同数字的图像)和一个负对(随机不同数字的图像)。

网络本身,在Net类中定义,是一个siamese tional neural network,由2个相同的子网络组成,每个子网络包含3个tional layer,内核大小分别为7、5和5,中间还有一个pooling层。在经过卷积层之后,我们让网络构建每个输入的一维描述符,方法是将特征扁平化,并将它们通过带有512个输出特征的线性层传递。注意,两个子网络中的层共享相同的权重。这允许网络为每个输入学习有意义的描述符,并使输出对称(输入的顺序应该与我们的目标无关)。

整个过程的关键步骤是下一个步骤:计算特征向量的平方距离。原则上,为了训练网络,我们可以使用三重损失和这个平方差的输出。但是,我使用二元交叉熵损失得到了更好的结果(收敛速度更快)。因此,我们在网络上附加一个带有两个输出特征的线性层(数量相同,数量不同)来获得逻辑。

代码中有三个主要的相关函数:训练函数、测试函数和预测函数。

在train函数中,我们向网络提供一个正样本和一个负样本(两对图像)。我们计算每个损失,并将它们相加(正样本的目标是1,负样本的目标是0)。

测试函数用于测量测试数据集中网络的准确性。我们在每个训练阶段结束后进行测试,观察训练进度,防止过拟合。

给定一对MNIST图像,该预测函数仅预测它们是否属于同一类。通过将全局变量do_learn设置为False,可以在培训结束后使用predict。

使用上面的实现,我能够在测试MNIST数据集上达到96%的准确率。

相关推荐

# Python 3 # Python 3字典Dictionary(1)

Python3字典字典是另一种可变容器模型,且可存储任意类型对象。字典的每个键值(key=>value)对用冒号(:)分割,每个对之间用逗号(,)分割,整个字典包括在花括号({})中,格式如...

Python第八课:数据类型中的字典及其函数与方法

Python3字典字典是另一种可变容器模型,且可存储任意类型对象。字典的每个键值...

Python中字典详解(python 中字典)

字典是Python中使用键进行索引的重要数据结构。它们是无序的项序列(键值对),这意味着顺序不被保留。键是不可变的。与列表一样,字典的值可以保存异构数据,即整数、浮点、字符串、NaN、布尔值、列表、数...

Python3.9又更新了:dict内置新功能,正式版十月见面

机器之心报道参与:一鸣、JaminPython3.8的热乎劲还没过去,Python就又双叒叕要更新了。近日,3.9版本的第四个alpha版已经开源。从文档中,我们可以看到官方透露的对dic...

Python3 基本数据类型详解(python三种基本数据类型)

文章来源:加米谷大数据Python中的变量不需要声明。每个变量在使用前都必须赋值,变量赋值以后该变量才会被创建。在Python中,变量就是变量,它没有类型,我们所说的"类型"是变...

一文掌握Python的字典(python字典用法大全)

字典是Python中最强大、最灵活的内置数据结构之一。它们允许存储键值对,从而实现高效的数据检索、操作和组织。本文深入探讨了字典,涵盖了它们的创建、操作和高级用法,以帮助中级Python开发...

超级完整|Python字典详解(python字典的方法或操作)

一、字典概述01字典的格式Python字典是一种可变容器模型,且可存储任意类型对象,如字符串、数字、元组等其他容器模型。字典的每个键值key=>value对用冒号:分割,每个对之间用逗号,...

Python3.9版本新特性:字典合并操作的详细解读

处于测试阶段的Python3.9版本中有一个新特性:我们在使用Python字典时,将能够编写出更可读、更紧凑的代码啦!Python版本你现在使用哪种版本的Python?3.7分?3.5分?还是2.7...

python 自学,字典3(一些例子)(python字典有哪些基本操作)

例子11;如何批量复制字典里的内容2;如何批量修改字典的内容3;如何批量修改字典里某些指定的内容...

Python3.9中的字典合并和更新,几乎影响了所有Python程序员

全文共2837字,预计学习时长9分钟Python3.9正在积极开发,并计划于今年10月发布。2月26日,开发团队发布了alpha4版本。该版本引入了新的合并(|)和更新(|=)运算符,这个新特性几乎...

Python3大字典:《Python3自学速查手册.pdf》限时下载中

最近有人会想了,2022了,想学Python晚不晚,学习python有前途吗?IT行业行业薪资高,发展前景好,是很多求职群里严重的香饽饽,而要进入这个高薪行业,也不是那么轻而易举的,拿信工专业的大学生...

python学习——字典(python字典基本操作)

字典Python的字典数据类型是基于hash散列算法实现的,采用键值对(key:value)的形式,根据key的值计算value的地址,具有非常快的查取和插入速度。但它是无序的,包含的元素个数不限,值...

324页清华教授撰写【Python 3 菜鸟查询手册】火了,小白入门字典

如何入门学习python...

Python3.9中的字典合并和更新,了解一下

全文共2837字,预计学习时长9分钟Python3.9正在积极开发,并计划于今年10月发布。2月26日,开发团队发布了alpha4版本。该版本引入了新的合并(|)和更新(|=)运算符,这个新特性几乎...

python3基础之字典(python中字典的基本操作)

字典和列表一样,也是python内置的一种数据结构。字典的结构如下图:列表用中括号[]把元素包起来,而字典是用大括号{}把元素包起来,只不过字典的每一个元素都包含键和值两部分。键和值是一一对应的...

取消回复欢迎 发表评论:

请填写验证码