GAN in Layman's Terms:

  • Generative modeling is an unsupervised learning task in machine learning that involves automatically discovering and learning the regularities or patterns in input data in such a way that the model can be used to generate or output new examples that plausibly could have been drawn from the original dataset.
  • A GAN consists of two parts: a Generator and a Discriminator.
  • Discriminator:

    • This network takes data as input and outputs the probability that the input is real or fake.
    • The Generator's output is also fed into the Discriminator, which predicts whether that generated data is real or fake.
    • Over many iterations of backpropagation, the Generator learns to produce data that resembles the real data so closely that the Discriminator starts classifying it as real.
  • Generator:

    • This network takes random noise as input and outputs synthetic data that is meant to look like the real data.
    • The Generator's output has the same shape as the Discriminator's input, since it is fed directly into the Discriminator; the minimax objective below formalizes this adversarial game.
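  • Formally (going a bit beyond the layman's view), the two networks play the standard GAN minimax game, where G is the generator, D is the discriminator, p_data is the real-data distribution, and p_z is the noise distribution:

    $$\min_G \max_D \; \mathbb{E}_{x \sim p_{data}}\big[\log D(x)\big] \;+\; \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]$$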

Importing Required Libraries

import torch
from torch import nn
import math
import matplotlib.pyplot as plt
torch.manual_seed(111)
<torch._C.Generator at 0x1e2413ee2b0>

Generating Sinusoidal Data

  • Here we generate sinusoidal data with 1024 rows and 2 columns: the first column holds values in [0, 2π) and the second column holds their sine.
train_data_length = 1024
train_data = torch.zeros((train_data_length, 2))
train_data[:, 0] = 2 * math.pi * torch.rand(train_data_length)
train_data[:, 1] = torch.sin(train_data[:, 0])
train_labels = torch.zeros(train_data_length)
train_set = [
    (train_data[i], train_labels[i]) for i in range(train_data_length)
]
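  • As an aside, the same pairing of data and labels could also be built with torch.utils.data.TensorDataset instead of a Python list; a minimal sketch (train_set_alt is only an illustrative name and is not used below):
from torch.utils.data import TensorDataset

# Indexing a TensorDataset returns (train_data[i], train_labels[i]) pairs,
# just like the list comprehension above.
train_set_alt = TensorDataset(train_data, train_labels)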
  • Our training labels are all 0 here; they are only placeholders, since the training loop below ignores them (note the underscore in the loop) and builds its own real/fake labels.
train_labels
tensor([0., 0., 0.,  ..., 0., 0., 0.])
  • Printing the training data, which is 1024x2.
train_data
tensor([[ 4.4960, -0.9767],
        [ 5.7428, -0.5145],
        [ 1.7710,  0.9800],
        ...,
        [ 4.4772, -0.9725],
        [ 3.2305, -0.0887],
        [ 4.0663, -0.7984]])
  • Printing our train set: it consists of (data, label) pairs, pairing each row of the 1024x2 training data with a label of 0 (output truncated).
train_set
[(tensor([ 4.4960, -0.9767]), tensor(0.)),
 (tensor([ 5.7428, -0.5145]), tensor(0.)),
 (tensor([1.7710, 0.9800]), tensor(0.)),
 (tensor([1.6217, 0.9987]), tensor(0.)),
 (tensor([ 3.9654, -0.7337]), tensor(0.)),
 (tensor([ 3.7702, -0.5881]), tensor(0.)),
 (tensor([ 5.8509, -0.4190]), tensor(0.)),
 (tensor([1.3527, 0.9763]), tensor(0.)),
 (tensor([ 3.7905, -0.6043]), tensor(0.)),
 (tensor([ 4.6042, -0.9942]), tensor(0.)),
 ...]
  • Now plotting the training data, and we can see it forms a sinusoid.
plt.plot(train_data[:, 0], train_data[:, 1], ".")
[<matplotlib.lines.Line2D at 0x1e2407e4580>]
  • Here we set the batch size and load our dataset into a DataLoader called train_loader, which will feed shuffled batches to the models.
batch_size = 32
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True
)
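  • As an optional sanity check (not part of the original flow), we can pull one batch from train_loader and inspect its shapes; batch_samples and batch_labels are illustrative names:
batch_samples, batch_labels = next(iter(train_loader))
print(batch_samples.shape)  # expected: torch.Size([32, 2])
print(batch_labels.shape)   # expected: torch.Size([32])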

Creating a Discriminator Model Class

  • Here we create a simple neural network model for our discriminator. The discriminator is just a classifier: it takes a 2-D point and outputs the probability that the point is real.
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(2, 556),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(556, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        output = self.model(x)
        return output
  • Instantiating a Discriminator object.
discriminator = Discriminator()
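  • As an optional check, we can pass a few random 2-D points through the untrained discriminator and confirm it returns one probability per point (values in (0, 1) thanks to the final Sigmoid); dummy_points and dummy_scores are illustrative names:
dummy_points = torch.randn(4, 2)           # four random 2-D points
dummy_scores = discriminator(dummy_points)
print(dummy_scores.shape)                  # expected: torch.Size([4, 1])
print(dummy_scores.min().item(), dummy_scores.max().item())  # both between 0 and 1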

Creating a Generator Model Class

  • Here we create a simple neural network model for our generator. The generator is a neural network that turns random noise into fake data, which we will feed into the discriminator.
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Linear(16, 32),
            nn.ReLU(),
            nn.Linear(32, 2),
        )

    def forward(self, x):
        output = self.model(x)
        return output

generator = Generator()
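  • Similarly, an optional check that the untrained generator maps a batch of 2-D noise vectors to a batch of 2-D points (not yet sinusoidal, since it has not been trained); dummy_noise and dummy_fake are illustrative names:
dummy_noise = torch.randn(4, 2)     # latent-space samples
dummy_fake = generator(dummy_noise)
print(dummy_fake.shape)             # expected: torch.Size([4, 2])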
  • Setting the learning rate, number of epochs, and loss function.
lr = 0.001
num_epochs = 300
loss_function = nn.BCELoss()
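  • For reference, nn.BCELoss computes the binary cross-entropy between a predicted probability p and a target label y, averaged over the batch by default:

    $$\ell(p, y) = -\big[\, y \log p + (1 - y)\log(1 - p) \,\big]$$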
  • Now we initialize an Adam optimizer for each model: one for the discriminator and one for the generator.
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)

Creating and Training the GAN Model

  • Here we are going to put the pieces together into a GAN training loop.
  • For this we combine the training data with the discriminator and generator models defined above.
  • In each step, the Discriminator receives all_samples, the concatenation of real samples (the real data) and generated samples (produced by our generator), together with all_samples_labels, the concatenation of real_samples_labels (all 1s) and generated_samples_labels (all 0s), each of size batch x 1.

    • real_samples : the sinusoidal data that we created previously.
    • generated_samples : the data produced by our generator.
    • all_samples_labels : the labels built by concatenating real_samples_labels and generated_samples_labels.
      • real_samples_labels : a label tensor of all 1s of size batch x 1.
      • generated_samples_labels : a label tensor of all 0s of size batch x 1.
  • Training the Discriminator

    • In this step we train the discriminator and get its predictions as output_discriminator.
    • This output is fed into the loss function along with all_samples_labels to compute the discriminator loss.
    • Then we backpropagate the loss and update the discriminator's weights.
  • Training the Generator

    • In this step we train the generator, whose output is generated_samples.
    • The input to the generator is latent_space_samples, a batch of random vectors of size batch x 2.
    • We feed the generator's output to the discriminator, compute the loss, and backpropagate to update the generator's weights.
    • The generator loss is computed from the discriminator's output and real_samples_labels (all 1s), so the generator is pushed to produce data that the discriminator classifies as real.
for epoch in range(num_epochs):
    for n, (real_samples, _) in enumerate(train_loader):
        # Data for training the discriminator
        real_samples_labels = torch.ones((batch_size, 1))
        latent_space_samples = torch.randn((batch_size, 2))
        generated_samples = generator(latent_space_samples)
        generated_samples_labels = torch.zeros((batch_size, 1))
        all_samples = torch.cat((real_samples, generated_samples))
        all_samples_labels = torch.cat(
            (real_samples_labels, generated_samples_labels)
        )

        # Training the discriminator
        discriminator.zero_grad()
        output_discriminator = discriminator(all_samples)
        loss_discriminator = loss_function(
            output_discriminator, all_samples_labels)
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # Data for training the generator
        latent_space_samples = torch.randn((batch_size, 2))

        # Training the generator
        generator.zero_grad()
        generated_samples = generator(latent_space_samples)
        output_discriminator_generated = discriminator(generated_samples)
        loss_generator = loss_function(
            output_discriminator_generated, real_samples_labels
        )
        loss_generator.backward()
        optimizer_generator.step()

        # Show loss
        if epoch % 10 == 0 and n == batch_size - 1:
            print(f"Epoch: {epoch} Loss D.: {loss_discriminator}")
            print(f"Epoch: {epoch} Loss G.: {loss_generator}")
Epoch: 0 Loss D.: 0.22545303404331207
Epoch: 0 Loss G.: 3.307260274887085
Epoch: 10 Loss D.: 0.4244134724140167
Epoch: 10 Loss G.: 1.4460361003875732
Epoch: 20 Loss D.: 0.6815066933631897
Epoch: 20 Loss G.: 0.6763569712638855
Epoch: 30 Loss D.: 0.7113456726074219
Epoch: 30 Loss G.: 0.6826677322387695
Epoch: 40 Loss D.: 0.6589072942733765
Epoch: 40 Loss G.: 0.7110073566436768
Epoch: 50 Loss D.: 0.6537125110626221
Epoch: 50 Loss G.: 0.6886794567108154
Epoch: 60 Loss D.: 0.6310231685638428
Epoch: 60 Loss G.: 0.6411446332931519
Epoch: 70 Loss D.: 0.7475214004516602
Epoch: 70 Loss G.: 0.9075831174850464
Epoch: 80 Loss D.: 0.6742514371871948
Epoch: 80 Loss G.: 0.6912871599197388
Epoch: 90 Loss D.: 0.7131425142288208
Epoch: 90 Loss G.: 0.7293562293052673
Epoch: 100 Loss D.: 0.6102262735366821
Epoch: 100 Loss G.: 0.8444703817367554
Epoch: 110 Loss D.: 0.6022018194198608
Epoch: 110 Loss G.: 0.9215141534805298
Epoch: 120 Loss D.: 0.6511936187744141
Epoch: 120 Loss G.: 0.8186532258987427
Epoch: 130 Loss D.: 0.6581251621246338
Epoch: 130 Loss G.: 0.8208433985710144
Epoch: 140 Loss D.: 0.6667916774749756
Epoch: 140 Loss G.: 0.7491060495376587
Epoch: 150 Loss D.: 0.6819881796836853
Epoch: 150 Loss G.: 0.726937472820282
Epoch: 160 Loss D.: 0.7139056921005249
Epoch: 160 Loss G.: 0.7458503246307373
Epoch: 170 Loss D.: 0.6834815740585327
Epoch: 170 Loss G.: 0.7246146202087402
Epoch: 180 Loss D.: 0.6735205054283142
Epoch: 180 Loss G.: 0.7756674885749817
Epoch: 190 Loss D.: 0.6795198917388916
Epoch: 190 Loss G.: 0.8316391706466675
Epoch: 200 Loss D.: 0.6492403149604797
Epoch: 200 Loss G.: 0.9561660289764404
Epoch: 210 Loss D.: 0.7166296243667603
Epoch: 210 Loss G.: 0.7127377390861511
Epoch: 220 Loss D.: 0.6955048441886902
Epoch: 220 Loss G.: 0.6957291960716248
Epoch: 230 Loss D.: 0.6621055006980896
Epoch: 230 Loss G.: 0.8285194039344788
Epoch: 240 Loss D.: 0.6331629157066345
Epoch: 240 Loss G.: 0.7604714632034302
Epoch: 250 Loss D.: 0.6775591373443604
Epoch: 250 Loss G.: 0.6923383474349976
Epoch: 260 Loss D.: 0.6518263816833496
Epoch: 260 Loss G.: 0.7644563913345337
Epoch: 270 Loss D.: 0.6683076024055481
Epoch: 270 Loss G.: 0.8083503246307373
Epoch: 280 Loss D.: 0.7086775302886963
Epoch: 280 Loss G.: 0.6793789267539978
Epoch: 290 Loss D.: 0.7268669009208679
Epoch: 290 Loss G.: 0.6526477336883545
  • Now we can see that the output of the generator model closely matches the real data.
latent_space_samples = torch.randn(500, 2)
generated_samples = generator(latent_space_samples)
generated_samples = generated_samples.detach()
plt.plot(generated_samples[:, 0], generated_samples[:, 1], ".")
[<matplotlib.lines.Line2D at 0x1e249754ee0>]
  • This is the plot of the real data, which is sinusoidal.
plt.plot(train_data[:, 0], train_data[:, 1], ".")
[<matplotlib.lines.Line2D at 0x7f38c44bf890>]
  • From the above results we can see that our generator has learned to create fake data that follows the sinusoidal distribution; for a direct comparison, see the overlay sketch below.
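  • For a direct visual comparison, both point clouds can be overlaid on one figure; a small optional sketch, assuming train_data and the detached generated_samples from above are still in scope:
plt.plot(train_data[:, 0], train_data[:, 1], ".", label="real")
plt.plot(generated_samples[:, 0], generated_samples[:, 1], ".", label="generated")
plt.legend()
plt.show()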

Summary

  • In this blog post we have seen how a simple GAN can be implemented: first creating sinusoidal data, then building discriminator and generator models, and finally training them together as a GAN. As a possible next step, the trained generator can be saved for later use, as sketched below.
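  • A minimal sketch of saving and reloading the trained generator with torch.save / load_state_dict; the filename is chosen only for illustration:
# Save only the learned weights
torch.save(generator.state_dict(), "gan_generator.pt")

# Later: rebuild the architecture and load the weights back
generator_reloaded = Generator()
generator_reloaded.load_state_dict(torch.load("gan_generator.pt"))
generator_reloaded.eval()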