{ "cells": [ { "cell_type": "markdown", "id": "3fdec832", "metadata": {}, "source": [ "# DeepProbLog #\n", "\n", "DeepProbLog is an extension of ProbLog that integrates Probabilistic Logic Programming with Deep Learning.\n", "\n", "The project's git repository can be cloned with:\n", "\n", "`git clone https://bitbucket.org/problog/deepproblog.git`" ] },
{ "cell_type": "markdown", "id": "0621ecc2", "metadata": {}, "source": [ "## Example: MNIST Digit Addition ##\n", "\n", "In this experiment, the task is to predict the sum of two multi-digit numbers,\n", "each given as a list of MNIST digit images." ] },
{ "cell_type": "markdown", "id": "9d23fed3", "metadata": {}, "source": [ "First, we create a ProbLog file containing the logic part of the program. The file will be saved as `tutorial/multi_digit.pl`.\n", "\n", "```prolog\n", "nn(mnist_net,[X],Y,[0,1,2,3,4,5,6,7,8,9]) :: digit(X,Y).\n", "\n", "number([],Result,Result).\n", "number([H|T],Acc,Result) :- digit(H,Nr),\n", "    Acc2 is Nr+10*Acc,\n", "    number(T,Acc2,Result).\n", "number(X,Y) :- number(X,0,Y).\n", "\n", "addition(X,Y,Z) :- number(X,X2), number(Y,Y2), Z is X2+Y2.\n", "```" ] },
{ "cell_type": "markdown", "id": "5f8c7371", "metadata": {}, "source": [ "Then, we create the query files for both training and testing, connecting MNIST images to instances of the `addition` Prolog predicate." ] },
{ "cell_type": "code", "execution_count": 13, "id": "588057e2", "metadata": {}, "outputs": [], "source": [ "from torchvision.datasets import MNIST\n", "import random\n", "\n", "datasets = {'train': MNIST(root='data/MNIST', train=True, download=True),\n", "            'test': MNIST(root='data/MNIST', train=False, download=True)}\n", "\n", "# dataset_name is 'train' or 'test'\n", "# op is a (lambda) function for the operation to be learned\n", "# length is the number of digits to be used\n", "# out is the output file name\n", "def generate_examples(dataset_name, op, length, out):\n", "    dataset = datasets[dataset_name]\n", "    indices = list(range(len(dataset)))\n", "    random.shuffle(indices)\n", "    i = iter(indices)\n", "    examples = []\n", "    while True:\n", "        try:\n", "            examples.append(next_example(i, dataset, op, length))\n", "        # exception is raised when all digits in dataset have been used\n", "        except StopIteration:\n", "            break\n", "    save_examples(dataset_name, examples, out)\n", "\n", "def next_example(i, dataset, op, length):\n", "    nr1, n1 = next_number(i, dataset, length)\n", "    nr2, n2 = next_number(i, dataset, length)\n", "    return nr1, nr2, op(n1, n2)\n", "\n", "def next_number(i, dataset, nr_digits):\n", "    n = 0\n", "    nr = []\n", "    for _ in range(nr_digits):\n", "        x = next(i)\n", "        _, c = dataset[x]  # c is the digit that the image represents\n", "        n = n * 10 + c  # the number is incrementally built from the sequence of its digits\n", "        nr.append(str(x))  # nr is the list of ids of the digit images\n", "    return nr, n\n", "\n", "def save_examples(dataset_name, examples, out):\n", "    with open(out, 'w') as f:\n", "        for example in examples:\n", "            # number encoded as e.g. (test(9150),test(6809),test(1586))\n", "            args1 = tuple('{}({})'.format(dataset_name, e) for e in example[0])\n", "            args2 = tuple('{}({})'.format(dataset_name, e) for e in example[1])\n", "            # example encoded as e.g.\n", "            # addition([test(9150),test(6809),test(1586)], [test(114),test(2039),test(5872)], 1574).\n", "            f.write('addition([{}], [{}], {}).\\n'.format(','.join(args1),\n", "                                                          ','.join(args2),\n", "                                                          example[2]))\n", "\n", "generate_examples('train', lambda x, y: x + y, 1, 'tutorial/train.txt')\n", "generate_examples('test', lambda x, y: x + y, 3, 'tutorial/test.txt')" ] },
{ "cell_type": "markdown", "id": "183b32c1", "metadata": {}, "source": [ "Train and test queries look like this:\n", "\n", "```\n", "addition([train(2764)], [train(8527)], 9).\n", "addition([train(27012)], [train(56713)], 10).\n", "...\n", "\n", "addition([test(9271),test(5812),test(9788)], [test(4522),test(8572),test(3555)], 1575).\n", "addition([test(4052),test(7966),test(5512)], [test(4884),test(5655),test(133)], 1554).\n", "...\n", "```" ] },
{ "cell_type": "markdown", "id": "eb2a1bd6", "metadata": {}, "source": [ "We can now define a Python class implementing a standard CNN for MNIST images, and a neural predicate that maps the image id (as found in the query) to the corresponding image and then sends it to the neural net." ] },
{ "cell_type": "code", "execution_count": 14, "id": "1f7f67aa", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.autograd import Variable\n", "\n", "class MNIST_Net(nn.Module):\n", "    def __init__(self, N=10):\n", "        super(MNIST_Net, self).__init__()\n", "        self.encoder = nn.Sequential(\n", "            nn.Conv2d(1, 6, 5),\n", "            nn.MaxPool2d(2, 2),  # 6 24 24 -> 6 12 12\n", "            nn.ReLU(True),\n", "            nn.Conv2d(6, 16, 5),  # 6 12 12 -> 16 8 8\n", "            nn.MaxPool2d(2, 2),  # 16 8 8 -> 16 4 4\n", "            nn.ReLU(True)\n", "        )\n", "        self.classifier = nn.Sequential(\n", "            nn.Linear(16 * 4 * 4, 120),\n", "            nn.ReLU(),\n", "            nn.Linear(120, 84),\n", "            nn.ReLU(),\n", "            nn.Linear(84, N),\n", "            nn.Softmax(1)\n", "        )\n", "\n", "    def forward(self, x):\n", "        x = self.encoder(x)\n", "        x = x.view(-1, 16 * 4 * 4)\n", "        x = self.classifier(x)\n", "        return x\n", "\n", "def neural_predicate(network, i):\n", "    # i is something like train(2764) or test(4052)\n", "    dataset = str(i.functor)\n", "    i = int(i.args[0])\n", "    if dataset == 'train':\n", "        d, l = mnist_train_data[i]\n", "    elif dataset == 'test':\n", "        d, l = mnist_test_data[i]\n", "    d = Variable(d.unsqueeze(0))\n", "    output = network.net(d)\n", "    return output.squeeze(0)" ] },
{ "cell_type": "markdown", "id": "3eec06a6", "metadata": {}, "source": [ "We can now load the MNIST data, the queries and the ProbLog file.\n", "Note that `mnist_train_data` and `mnist_test_data` are global variables used inside `neural_predicate`."
] }, { "cell_type": "code", "execution_count": 15, "id": "f7c83283", "metadata": {}, "outputs": [], "source": [ "from torchvision.datasets import MNIST\n", "import torchvision.transforms as transforms\n", "from data_loader import load\n", "\n", "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])\n", "mnist_train_data = MNIST(root='data/MNIST', train=True, download=True, transform=transform)\n", "mnist_test_data = MNIST(root='data/MNIST', train=False, download=True, transform=transform)\n", "\n", "train_queries = load('tutorial/train.txt')\n", "test_queries = load('tutorial/test.txt')[:100]\n", "\n", "with open('tutorial/multi_digit.pl') as f:\n", "    problog_string = f.read()" ] }, { "cell_type": "markdown", "id": "f6676fa0", "metadata": {}, "source": [ "Finally, we can create the network and the DeepProbLog model with the network as neural predicate, and train it like a standard PyTorch model." ] }, { "cell_type": "code", "execution_count": 16, "id": "7c4ce6f1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training for 1 epochs (30000 iterations).\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/anaconda/lib/python3.6/site-packages/torch/nn/modules/module.py:795: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. 
Please use register_full_backward_hook to get the documented behavior.\n", " warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Accuracy 0.0\n", "Accuracy: [('Accuracy', 0.0)]\n", "Epoch 1\n", "Iteration: 100 \tAverage Loss: 2.7413997681308944\n", "Iteration: 200 \tAverage Loss: 2.822181631299884\n", "Iteration: 300 \tAverage Loss: 2.7558953509461053\n", "Iteration: 400 \tAverage Loss: 2.641749291159689\n", "Iteration: 500 \tAverage Loss: 2.4830741851766236\n", "Iteration: 600 \tAverage Loss: 2.327804147989043\n", "Iteration: 700 \tAverage Loss: 2.1164256080170687\n", "Iteration: 800 \tAverage Loss: 2.234518181674348\n", "Iteration: 900 \tAverage Loss: 2.0504583209665026\n", "Iteration: 1000 \tAverage Loss: 1.7014140593617073\n", "Accuracy 0.15\n", "Accuracy: [('Accuracy', 0.15)]\n", "Iteration: 1100 \tAverage Loss: 1.4452080341853497\n", "Iteration: 1200 \tAverage Loss: 1.4222283941174698\n", "Iteration: 1300 \tAverage Loss: 0.8295922330291584\n", "Iteration: 1400 \tAverage Loss: 1.120169151657204\n", "Iteration: 1500 \tAverage Loss: 0.5691411603139906\n", "Iteration: 1600 \tAverage Loss: 0.7857685772432913\n", "Iteration: 1700 \tAverage Loss: 0.8248206990974919\n", "Iteration: 1800 \tAverage Loss: 0.5987613001722083\n", "Iteration: 1900 \tAverage Loss: 0.5953608247291569\n", "Iteration: 2000 \tAverage Loss: 0.5034975361035011\n", "Accuracy 0.61\n", "Accuracy: [('Accuracy', 0.61)]\n", "Iteration: 2100 \tAverage Loss: 0.5188218929677152\n", "Iteration: 2200 \tAverage Loss: 0.523973429487216\n", "Iteration: 2300 \tAverage Loss: 0.41990364064317265\n", "Iteration: 2400 \tAverage Loss: 0.4417395804486099\n", "Iteration: 2500 \tAverage Loss: 0.5128849973724388\n", "Iteration: 2600 \tAverage Loss: 0.3416890195978941\n", "Iteration: 2700 \tAverage Loss: 0.5345317742810108\n", "Iteration: 2800 \tAverage Loss: 0.5050523024390245\n", "Iteration: 2900 \tAverage 
Loss: 0.5693561440351348\n", "Iteration: 3000 \tAverage Loss: 0.5945739425526969\n", "Accuracy 0.78\n", "Accuracy: [('Accuracy', 0.78)]\n", "Iteration: 3100 \tAverage Loss: 0.3202775028317631\n", "Iteration: 3200 \tAverage Loss: 0.3049114087490498\n", "Iteration: 3300 \tAverage Loss: 0.29740947723341893\n", "Iteration: 3400 \tAverage Loss: 0.3677314219155365\n", "Iteration: 3500 \tAverage Loss: 0.17985040430688012\n", "Iteration: 3600 \tAverage Loss: 0.37336326862481406\n", "Iteration: 3700 \tAverage Loss: 0.3027671797362257\n", "Iteration: 3800 \tAverage Loss: 0.40385123920772215\n", "Iteration: 3900 \tAverage Loss: 0.19532375634360097\n", "Iteration: 4000 \tAverage Loss: 0.27540275159160804\n", "Accuracy 0.64\n", "Accuracy: [('Accuracy', 0.64)]\n", "Iteration: 4100 \tAverage Loss: 0.24825260026916843\n", "Iteration: 4200 \tAverage Loss: 0.2793046616772926\n", "Iteration: 4300 \tAverage Loss: 0.23741380762783032\n", "Iteration: 4400 \tAverage Loss: 0.2682812208307791\n", "Iteration: 4500 \tAverage Loss: 0.2925631595225186\n", "Iteration: 4600 \tAverage Loss: 0.19133973434923257\n", "Iteration: 4700 \tAverage Loss: 0.320692249864229\n", "Iteration: 4800 \tAverage Loss: 0.4859827643472217\n", "Iteration: 4900 \tAverage Loss: 0.32097451510841407\n", "Iteration: 5000 \tAverage Loss: 0.22904597450428293\n", "Accuracy 0.78\n", "Accuracy: [('Accuracy', 0.78)]\n", "Iteration: 5100 \tAverage Loss: 0.19584102959917815\n", "Iteration: 5200 \tAverage Loss: 0.3108532107924113\n", "Iteration: 5300 \tAverage Loss: 0.2823863028555626\n", "Iteration: 5400 \tAverage Loss: 0.3605731029965517\n", "Iteration: 5500 \tAverage Loss: 0.322385318502695\n", "Iteration: 5600 \tAverage Loss: 0.2265811622449126\n", "Iteration: 5700 \tAverage Loss: 0.24071394576269634\n", "Iteration: 5800 \tAverage Loss: 0.16813618426018462\n", "Iteration: 5900 \tAverage Loss: 0.22782210949362078\n", "Iteration: 6000 \tAverage Loss: 0.26975794411046794\n", "Accuracy 0.76\n", "Accuracy: [('Accuracy', 0.76)]\n", 
"Iteration: 6100 \tAverage Loss: 0.2461363996404004\n", "Iteration: 6200 \tAverage Loss: 0.12465014836239988\n", "Iteration: 6300 \tAverage Loss: 0.32069424261125934\n", "Iteration: 6400 \tAverage Loss: 0.3930008014971391\n", "Iteration: 6500 \tAverage Loss: 0.2959384059876967\n", "Iteration: 6600 \tAverage Loss: 0.4572670147461945\n", "Iteration: 6700 \tAverage Loss: 0.1794271599814699\n", "Iteration: 6800 \tAverage Loss: 0.3149823348374577\n", "Iteration: 6900 \tAverage Loss: 0.44661863399384333\n", "Iteration: 7000 \tAverage Loss: 0.3002844598254217\n", "Accuracy 0.82\n", "Accuracy: [('Accuracy', 0.82)]\n", "Iteration: 7100 \tAverage Loss: 0.24582124245217352\n", "Iteration: 7200 \tAverage Loss: 0.19918088183796878\n", "Iteration: 7300 \tAverage Loss: 0.3103837369357009\n", "Iteration: 7400 \tAverage Loss: 0.23128589339601952\n", "Iteration: 7500 \tAverage Loss: 0.18252199195827248\n", "Iteration: 7600 \tAverage Loss: 0.47914493928388213\n", "Iteration: 7700 \tAverage Loss: 0.23675691093399734\n", "Iteration: 7800 \tAverage Loss: 0.1816176635210822\n", "Iteration: 7900 \tAverage Loss: 0.1854788445087472\n", "Iteration: 8000 \tAverage Loss: 0.35170606901089374\n", "Accuracy 0.74\n", "Accuracy: [('Accuracy', 0.74)]\n", "Iteration: 8100 \tAverage Loss: 0.30293682275996453\n", "Iteration: 8200 \tAverage Loss: 0.3984703775867888\n", "Iteration: 8300 \tAverage Loss: 0.13351786149086972\n", "Iteration: 8400 \tAverage Loss: 0.364155371310647\n", "Iteration: 8500 \tAverage Loss: 0.11352509567746091\n", "Iteration: 8600 \tAverage Loss: 0.24082989771012642\n", "Iteration: 8700 \tAverage Loss: 0.24505272113038357\n", "Iteration: 8800 \tAverage Loss: 0.16560642368460626\n", "Iteration: 8900 \tAverage Loss: 0.2588549286203799\n", "Iteration: 9000 \tAverage Loss: 0.2426359481336863\n", "Accuracy 0.8\n", "Accuracy: [('Accuracy', 0.8)]\n", "Iteration: 9100 \tAverage Loss: 0.2567513272541813\n", "Iteration: 9200 \tAverage Loss: 0.22150124369625077\n", "Iteration: 9300 \tAverage 
Loss: 0.2253330826665129\n", "Iteration: 9400 \tAverage Loss: 0.14539374694836135\n", "Iteration: 9500 \tAverage Loss: 0.3622619653957859\n", "Iteration: 9600 \tAverage Loss: 0.19034479432067075\n", "Iteration: 9700 \tAverage Loss: 0.32536212020428124\n", "Iteration: 9800 \tAverage Loss: 0.21895144834528668\n", "Iteration: 9900 \tAverage Loss: 0.15719734055598789\n", "Writing snapshot to model_iter_10000.mdl\n", "Iteration: 10000 \tAverage Loss: 0.3079834116289314\n", "Accuracy 0.84\n", "Accuracy: [('Accuracy', 0.84)]\n", "Iteration: 10100 \tAverage Loss: 0.19060910254159258\n", "Iteration: 10200 \tAverage Loss: 0.27193299648845554\n", "Iteration: 10300 \tAverage Loss: 0.21544136691614593\n", "Iteration: 10400 \tAverage Loss: 0.0657081925431275\n", "Iteration: 10500 \tAverage Loss: 0.2453694732276121\n", "Iteration: 10600 \tAverage Loss: 0.28914190147952884\n", "Iteration: 10700 \tAverage Loss: 0.30542481182254805\n", "Iteration: 10800 \tAverage Loss: 0.20195578005447457\n", "Iteration: 10900 \tAverage Loss: 0.26921456883119416\n", "Iteration: 11000 \tAverage Loss: 0.1392901474057747\n", "Accuracy 0.8\n", "Accuracy: [('Accuracy', 0.8)]\n", "Iteration: 11100 \tAverage Loss: 0.143647419993001\n", "Iteration: 11200 \tAverage Loss: 0.2470866667844499\n", "Iteration: 11300 \tAverage Loss: 0.26351741912245047\n", "Iteration: 11400 \tAverage Loss: 0.1348753336064603\n", "Iteration: 11500 \tAverage Loss: 0.193422043725449\n", "Iteration: 11600 \tAverage Loss: 0.16633427606278675\n", "Iteration: 11700 \tAverage Loss: 0.22747680769393935\n", "Iteration: 11800 \tAverage Loss: 0.14133106481572055\n", "Iteration: 11900 \tAverage Loss: 0.2919465145036055\n", "Iteration: 12000 \tAverage Loss: 0.19366950438802508\n", "Accuracy 0.88\n", "Accuracy: [('Accuracy', 0.88)]\n", "Iteration: 12100 \tAverage Loss: 0.3169109344771518\n", "Iteration: 12200 \tAverage Loss: 0.19741893322292514\n", "Iteration: 12300 \tAverage Loss: 0.32080870167704384\n", "Iteration: 12400 \tAverage Loss: 
0.3675305361183614\n", "Iteration: 12500 \tAverage Loss: 0.18325031092887906\n", "Iteration: 12600 \tAverage Loss: 0.14770074659939977\n", "Iteration: 12700 \tAverage Loss: 0.27761534127688337\n", "Iteration: 12800 \tAverage Loss: 0.2590366166859017\n", "Iteration: 12900 \tAverage Loss: 0.28034352902054094\n", "Iteration: 13000 \tAverage Loss: 0.12950466241226358\n", "Accuracy 0.85\n", "Accuracy: [('Accuracy', 0.85)]\n", "Iteration: 13100 \tAverage Loss: 0.2673317791614003\n", "Iteration: 13200 \tAverage Loss: 0.1173330592025138\n", "Iteration: 13300 \tAverage Loss: 0.1300282790669049\n", "Iteration: 13400 \tAverage Loss: 0.17167217437039287\n", "Iteration: 13500 \tAverage Loss: 0.08524931142603429\n", "Iteration: 13600 \tAverage Loss: 0.19513037404078257\n", "Iteration: 13700 \tAverage Loss: 0.06395972544878306\n", "Iteration: 13800 \tAverage Loss: 0.20439881240971136\n", "Iteration: 13900 \tAverage Loss: 0.4312087404015388\n", "Iteration: 14000 \tAverage Loss: 0.18846608651821486\n", "Accuracy 0.84\n", "Accuracy: [('Accuracy', 0.84)]\n", "Iteration: 14100 \tAverage Loss: 0.11301485910286006\n", "Iteration: 14200 \tAverage Loss: 0.14940368180403488\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration: 14300 \tAverage Loss: 0.1883522669075476\n", "Iteration: 14400 \tAverage Loss: 0.23984416674510384\n", "Iteration: 14500 \tAverage Loss: 0.27934753153450936\n", "Iteration: 14600 \tAverage Loss: 0.14659613511216343\n", "Iteration: 14700 \tAverage Loss: 0.09915756395264293\n", "Iteration: 14800 \tAverage Loss: 0.21752628844546773\n", "Iteration: 14900 \tAverage Loss: 0.16969644666065645\n", "Iteration: 15000 \tAverage Loss: 0.17419191548977137\n", "Accuracy 0.86\n", "Accuracy: [('Accuracy', 0.86)]\n", "Iteration: 15100 \tAverage Loss: 0.132667199302011\n", "Iteration: 15200 \tAverage Loss: 0.06084775712248787\n", "Iteration: 15300 \tAverage Loss: 0.28009641435711685\n", "Iteration: 15400 \tAverage Loss: 0.22496273861193522\n", "Iteration: 15500 
\tAverage Loss: 0.17071578126496775\n", "Iteration: 15600 \tAverage Loss: 0.12306111240204767\n", "Iteration: 15700 \tAverage Loss: 0.13521139393961223\n", "Iteration: 15800 \tAverage Loss: 0.2646886639959432\n", "Iteration: 15900 \tAverage Loss: 0.06582526942718006\n", "Iteration: 16000 \tAverage Loss: 0.4805922915421103\n", "Accuracy 0.87\n", "Accuracy: [('Accuracy', 0.87)]\n", "Iteration: 16100 \tAverage Loss: 0.30833755784650857\n", "Iteration: 16200 \tAverage Loss: 0.1850719239898596\n", "Iteration: 16300 \tAverage Loss: 0.17863257396191148\n", "Iteration: 16400 \tAverage Loss: 0.1991606463181808\n", "Iteration: 16500 \tAverage Loss: 0.06643169739114059\n", "Iteration: 16600 \tAverage Loss: 0.20384018973517792\n", "Iteration: 16700 \tAverage Loss: 0.27035049948107814\n", "Iteration: 16800 \tAverage Loss: 0.12040008847753782\n", "Iteration: 16900 \tAverage Loss: 0.2622468688842014\n", "Iteration: 17000 \tAverage Loss: 0.18025100402802668\n", "Accuracy 0.85\n", "Accuracy: [('Accuracy', 0.85)]\n", "Iteration: 17100 \tAverage Loss: 0.23808840644263604\n", "Iteration: 17200 \tAverage Loss: 0.2041076692219245\n", "Iteration: 17300 \tAverage Loss: 0.30260700736093127\n", "Iteration: 17400 \tAverage Loss: 0.20777963909413558\n", "Iteration: 17500 \tAverage Loss: 0.24664657779618895\n", "Iteration: 17600 \tAverage Loss: 0.14479609043156338\n", "Iteration: 17700 \tAverage Loss: 0.21905160831023562\n", "Iteration: 17800 \tAverage Loss: 0.15346063853733716\n", "Iteration: 17900 \tAverage Loss: 0.1782339265297359\n", "Iteration: 18000 \tAverage Loss: 0.29375429019719695\n", "Accuracy 0.82\n", "Accuracy: [('Accuracy', 0.82)]\n", "Iteration: 18100 \tAverage Loss: 0.1883395353325695\n", "Iteration: 18200 \tAverage Loss: 0.2016204574636708\n", "Iteration: 18300 \tAverage Loss: 0.18355585875886465\n", "Iteration: 18400 \tAverage Loss: 0.1276961119543715\n", "Iteration: 18500 \tAverage Loss: 0.10454972295383456\n", "Iteration: 18600 \tAverage Loss: 0.12035369979167472\n", 
"Iteration: 18700 \tAverage Loss: 0.03511755790076845\n", "Iteration: 18800 \tAverage Loss: 0.2571861505326727\n", "Iteration: 18900 \tAverage Loss: 0.13771844387484936\n", "Iteration: 19000 \tAverage Loss: 0.1881724938888093\n", "Accuracy 0.85\n", "Accuracy: [('Accuracy', 0.85)]\n", "Iteration: 19100 \tAverage Loss: 0.11988065006854548\n", "Iteration: 19200 \tAverage Loss: 0.1566391568918511\n", "Iteration: 19300 \tAverage Loss: 0.1400116846222099\n", "Iteration: 19400 \tAverage Loss: 0.19807367180215832\n", "Iteration: 19500 \tAverage Loss: 0.14658791191240467\n", "Iteration: 19600 \tAverage Loss: 0.36836803273171315\n", "Iteration: 19700 \tAverage Loss: 0.22936162349438244\n", "Iteration: 19800 \tAverage Loss: 0.30888039175086635\n", "Iteration: 19900 \tAverage Loss: 0.12732230576683562\n", "Writing snapshot to model_iter_20000.mdl\n", "Iteration: 20000 \tAverage Loss: 0.15057461729095953\n", "Accuracy 0.89\n", "Accuracy: [('Accuracy', 0.89)]\n", "Iteration: 20100 \tAverage Loss: 0.1734634611421744\n", "Iteration: 20200 \tAverage Loss: 0.18010946408314543\n", "Iteration: 20300 \tAverage Loss: 0.1720534752912393\n", "Iteration: 20400 \tAverage Loss: 0.27533175362800066\n", "Iteration: 20500 \tAverage Loss: 0.21156122666956478\n", "Iteration: 20600 \tAverage Loss: 0.23225283561954033\n", "Iteration: 20700 \tAverage Loss: 0.0526244190425389\n", "Iteration: 20800 \tAverage Loss: 0.223016134143968\n", "Iteration: 20900 \tAverage Loss: 0.13261611045244254\n", "Iteration: 21000 \tAverage Loss: 0.15689545593867152\n", "Accuracy 0.81\n", "Accuracy: [('Accuracy', 0.81)]\n", "Iteration: 21100 \tAverage Loss: 0.1493440130899875\n", "Iteration: 21200 \tAverage Loss: 0.034721575487417645\n", "Iteration: 21300 \tAverage Loss: 0.2417561484834072\n", "Iteration: 21400 \tAverage Loss: 0.18475796524271668\n", "Iteration: 21500 \tAverage Loss: 0.031167932140879095\n", "Iteration: 21600 \tAverage Loss: 0.018396504385441655\n", "Iteration: 21700 \tAverage Loss: 0.1356240585069712\n", 
"Iteration: 21800 \tAverage Loss: 0.11248268850723653\n", "Iteration: 21900 \tAverage Loss: 0.26317313884128457\n", "Iteration: 22000 \tAverage Loss: 0.14646365805051625\n", "Accuracy 0.84\n", "Accuracy: [('Accuracy', 0.84)]\n", "Iteration: 22100 \tAverage Loss: 0.17701286528319024\n", "Iteration: 22200 \tAverage Loss: 0.20506601808716987\n", "Iteration: 22300 \tAverage Loss: 0.13796745352811718\n", "Iteration: 22400 \tAverage Loss: 0.17956407271188415\n", "Iteration: 22500 \tAverage Loss: 0.14661292138772486\n", "Iteration: 22600 \tAverage Loss: 0.25346146147751275\n", "Iteration: 22700 \tAverage Loss: 0.22640644539894247\n", "Iteration: 22800 \tAverage Loss: 0.21639704825103181\n", "Iteration: 22900 \tAverage Loss: 0.1255657285978953\n", "Iteration: 23000 \tAverage Loss: 0.1402354362468853\n", "Accuracy 0.85\n", "Accuracy: [('Accuracy', 0.85)]\n", "Iteration: 23100 \tAverage Loss: 0.0780190291692648\n", "Iteration: 23200 \tAverage Loss: 0.15345504421361975\n", "Iteration: 23300 \tAverage Loss: 0.11137553041931855\n", "Iteration: 23400 \tAverage Loss: 0.24618553116287875\n", "Iteration: 23500 \tAverage Loss: 0.08965773598525613\n", "Iteration: 23600 \tAverage Loss: 0.13335577233170212\n", "Iteration: 23700 \tAverage Loss: 0.11715168683586741\n", "Iteration: 23800 \tAverage Loss: 0.09496285436221114\n", "Iteration: 23900 \tAverage Loss: 0.24673271571999464\n", "Iteration: 24000 \tAverage Loss: 0.14756480740603792\n", "Accuracy 0.86\n", "Accuracy: [('Accuracy', 0.86)]\n", "Iteration: 24100 \tAverage Loss: 0.12757277554665158\n", "Iteration: 24200 \tAverage Loss: 0.17855245984663007\n", "Iteration: 24300 \tAverage Loss: 0.22687696941559882\n", "Iteration: 24400 \tAverage Loss: 0.10001853334363625\n", "Iteration: 24500 \tAverage Loss: 0.2971216761602711\n", "Iteration: 24600 \tAverage Loss: 0.25564607663523\n", "Iteration: 24700 \tAverage Loss: 0.24380333143939992\n", "Iteration: 24800 \tAverage Loss: 0.12218807585846213\n", "Iteration: 24900 \tAverage Loss: 
0.22380610824532757\n", "Iteration: 25000 \tAverage Loss: 0.1492664363996664\n", "Accuracy 0.86\n", "Accuracy: [('Accuracy', 0.86)]\n", "Iteration: 25100 \tAverage Loss: 0.02358032236340111\n", "Iteration: 25200 \tAverage Loss: 0.15850524716711528\n", "Iteration: 25300 \tAverage Loss: 0.2595088638207099\n", "Iteration: 25400 \tAverage Loss: 0.1680172013918698\n", "Iteration: 25500 \tAverage Loss: 0.15741204367738199\n", "Iteration: 25600 \tAverage Loss: 0.16132882901189471\n", "Iteration: 25700 \tAverage Loss: 0.11930986991051926\n", "Iteration: 25800 \tAverage Loss: 0.10873162083489099\n", "Iteration: 25900 \tAverage Loss: 0.12927023988626904\n", "Iteration: 26000 \tAverage Loss: 0.06322001776286297\n", "Accuracy 0.85\n", "Accuracy: [('Accuracy', 0.85)]\n", "Iteration: 26100 \tAverage Loss: 0.25666175799860885\n", "Iteration: 26200 \tAverage Loss: 0.3134930123151474\n", "Iteration: 26300 \tAverage Loss: 0.16845278544195974\n", "Iteration: 26400 \tAverage Loss: 0.3404359265040135\n", "Iteration: 26500 \tAverage Loss: 0.22849257286027858\n", "Iteration: 26600 \tAverage Loss: 0.20360773151754682\n", "Iteration: 26700 \tAverage Loss: 0.2567829790227219\n", "Iteration: 26800 \tAverage Loss: 0.17790569837544923\n", "Iteration: 26900 \tAverage Loss: 0.09161017333153301\n", "Iteration: 27000 \tAverage Loss: 0.05346649254706619\n", "Accuracy 0.84\n", "Accuracy: [('Accuracy', 0.84)]\n", "Iteration: 27100 \tAverage Loss: 0.12817687724159785\n", "Iteration: 27200 \tAverage Loss: 0.1604917152484606\n", "Iteration: 27300 \tAverage Loss: 0.15668408897977776\n", "Iteration: 27400 \tAverage Loss: 0.13458660548996124\n", "Iteration: 27500 \tAverage Loss: 0.048097379389958256\n", "Iteration: 27600 \tAverage Loss: 0.15975369059084843\n", "Iteration: 27700 \tAverage Loss: 0.12483233423118115\n", "Iteration: 27800 \tAverage Loss: 0.10082192472999013\n", "Iteration: 27900 \tAverage Loss: 0.07534117738025596\n", "Iteration: 28000 \tAverage Loss: 0.062182813249212904\n", "Accuracy 
0.86\n", "Accuracy: [('Accuracy', 0.86)]\n", "Iteration: 28100 \tAverage Loss: 0.09248879025218955\n", "Iteration: 28200 \tAverage Loss: 0.25376271673263767\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration: 28300 \tAverage Loss: 0.13782404585076444\n", "Iteration: 28400 \tAverage Loss: 0.08818393756758906\n", "Iteration: 28500 \tAverage Loss: 0.1567092426496749\n", "Iteration: 28600 \tAverage Loss: 0.11117896665074617\n", "Iteration: 28700 \tAverage Loss: 0.18585810289019405\n", "Iteration: 28800 \tAverage Loss: 0.1259237184502742\n", "Iteration: 28900 \tAverage Loss: 0.11836703749145286\n", "Iteration: 29000 \tAverage Loss: 0.031844159575524264\n", "Accuracy 0.84\n", "Accuracy: [('Accuracy', 0.84)]\n", "Iteration: 29100 \tAverage Loss: 0.1102448297319265\n", "Iteration: 29200 \tAverage Loss: 0.0699499300865066\n", "Iteration: 29300 \tAverage Loss: 0.0830309191303204\n", "Iteration: 29400 \tAverage Loss: 0.34941493655585887\n", "Iteration: 29500 \tAverage Loss: 0.06943456046559118\n", "Iteration: 29600 \tAverage Loss: 0.08473602243635145\n", "Iteration: 29700 \tAverage Loss: 0.09539665225016306\n", "Iteration: 29800 \tAverage Loss: 0.042837079973181647\n", "Iteration: 29900 \tAverage Loss: 0.18574071140404833\n", "Writing snapshot to model_iter_30000.mdl\n", "Iteration: 30000 \tAverage Loss: 0.34288293634919803\n", "Accuracy 0.79\n", "Accuracy: [('Accuracy', 0.79)]\n", "Epoch time: 5540.953398942947\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from train import train_model\n", "from network import Network\n", "from model import Model\n", "from optimizer import Optimizer\n", "\n", "def test(model):\n", " acc = model.accuracy(test_queries, test=True)\n", " print('Accuracy: ', acc)\n", " return [('accuracy', acc)]\n", "\n", "network = MNIST_Net()\n", "# Network is a DeepProbLog class that wraps a pytorch networks and interfaces with ProbLog\n", "net = 
Network(network, 'mnist_net', neural_predicate)\n", "net.optimizer = torch.optim.Adam(network.parameters(), lr=0.001)\n", "# Model is a DeepProbLog class that combines reasoning via the ProbLog code \n", "# and neural processing via a list of Network objects\n", "model = Model(problog_string, [net], caching=False)\n", "optimizer = Optimizer(model, 2)\n", "\n", "train_model(model, train_queries, 1, optimizer, test_iter=1000, test=test, snapshot_iter=10000)" ] }, { "cell_type": "markdown", "id": "32da590c", "metadata": {}, "source": [ "## Example: a + b - c = d ##\n", "\n", "Let's change the code to answer queries of the type `a + b - c = ?`" ] }, { "cell_type": "markdown", "id": "67bd5371", "metadata": {}, "source": [ "First, we create a ProbLog file containing the logic part of the program. The file will be saved as `tutorial/addition_subtraction.pl`.\n", "\n", "```prolog\n", "nn(mnist_net,[X],Y,[0,1,2,3,4,5,6,7,8,9]) :: digit(X,Y).\n", "\n", "number([],Result,Result).\n", "number([H|T],Acc,Result) :- digit(H,Nr),\n", "    Acc2 is Nr+10*Acc,\n", "    number(T,Acc2,Result).\n", "number(X,Y) :- number(X,0,Y).\n", "\n", "addition(X,Y,Z) :- number(X,N1), number(Y,N2), Z is N1+N2.\n", "\n", "addition_subtraction(A,B,C,D) :- addition(A,B,N1), number(C,N2), D is N1-N2.\n", "```" ] }, { "cell_type": "markdown", "id": "e13a30b4", "metadata": {}, "source": [ "Then, we create the query files for both training and testing, connecting MNIST images to instances of the `addition_subtraction` Prolog predicate."
] }, { "cell_type": "code", "execution_count": 17, "id": "ccb20d90", "metadata": {}, "outputs": [], "source": [ "def next_example(i, dataset, op, length):\n", "    nr1, n1 = next_number(i, dataset, length)\n", "    nr2, n2 = next_number(i, dataset, length)\n", "    res = -1\n", "    # make sure the result is non-negative\n", "    while res < 0:\n", "        nr3, n3 = next_number(i, dataset, length)\n", "        res = op(n1, n2, n3)\n", "    return nr1, nr2, nr3, res\n", "\n", "def save_examples(dataset_name, examples, out):\n", "    with open(out, 'w') as f:\n", "        for example in examples:\n", "            args1 = tuple('{}({})'.format(dataset_name, e) for e in example[0])\n", "            args2 = tuple('{}({})'.format(dataset_name, e) for e in example[1])\n", "            args3 = tuple('{}({})'.format(dataset_name, e) for e in example[2])\n", "            f.write('addition_subtraction([{}], [{}], [{}], {}).\\n'.format(','.join(args1),\n", "                                                                            ','.join(args2),\n", "                                                                            ','.join(args3),\n", "                                                                            example[3]))\n", "\n", "generate_examples('train', lambda x, y, z: x + y - z, 1, 'tutorial/train.txt')\n", "generate_examples('test', lambda x, y, z: x + y - z, 3, 'tutorial/test.txt')" ] }, { "cell_type": "markdown", "id": "dbc4e45f", "metadata": {}, "source": [ "Reload the queries and the ProbLog file, instantiate the network and the model, and retrain."
] }, { "cell_type": "code", "execution_count": 19, "id": "98c6a192", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training for 1 epochs (17416 iterations).\n", "Accuracy 0.0\n", "Accuracy: [('Accuracy', 0.0)]\n", "Epoch 1\n", "Iteration: 100 \tAverage Loss: 2.7642198722784004\n", "Iteration: 200 \tAverage Loss: 2.801872244313671\n", "Iteration: 300 \tAverage Loss: 2.6762651530991595\n", "Iteration: 400 \tAverage Loss: 2.6976396559065567\n", "Iteration: 500 \tAverage Loss: 2.4950326703163945\n", "Iteration: 600 \tAverage Loss: 2.6426878719611553\n", "Iteration: 700 \tAverage Loss: 2.5003180111536354\n", "Iteration: 800 \tAverage Loss: 2.470023140425346\n", "Iteration: 900 \tAverage Loss: 2.3809421806027244\n", "Iteration: 1000 \tAverage Loss: 2.3171711316796855\n", "Accuracy 0.0\n", "Accuracy: [('Accuracy', 0.0)]\n", "Iteration: 1100 \tAverage Loss: 2.240560181079019\n", "Iteration: 1200 \tAverage Loss: 2.290888851917434\n", "Iteration: 1300 \tAverage Loss: 2.305876613045047\n", "Iteration: 1400 \tAverage Loss: 2.1712287488969992\n", "Iteration: 1500 \tAverage Loss: 2.2892492135890303\n", "Iteration: 1600 \tAverage Loss: 2.0477930040540913\n", "Iteration: 1700 \tAverage Loss: 2.1463180474731125\n", "Iteration: 1800 \tAverage Loss: 2.0614224536890107\n", "Iteration: 1900 \tAverage Loss: 1.8095064850261915\n", "Iteration: 2000 \tAverage Loss: 1.8911919073127195\n", "Accuracy 0.06\n", "Accuracy: [('Accuracy', 0.06)]\n", "Iteration: 2100 \tAverage Loss: 1.4773763290400703\n", "Iteration: 2200 \tAverage Loss: 1.4500693256767463\n", "Iteration: 2300 \tAverage Loss: 1.1228666945155312\n", "Iteration: 2400 \tAverage Loss: 1.2100257274026083\n", "Iteration: 2500 \tAverage Loss: 1.0694808288004232\n", "Iteration: 2600 \tAverage Loss: 0.7912047818138435\n", "Iteration: 2700 \tAverage Loss: 0.661085098942301\n", "Iteration: 2800 \tAverage Loss: 0.501763028022165\n", "Iteration: 2900 \tAverage Loss: 0.7016890604679269\n", "Iteration: 3000 
\tAverage Loss: 0.4881873789130234\n", "Accuracy 0.56\n", "Accuracy: [('Accuracy', 0.56)]\n", "Iteration: 3100 \tAverage Loss: 0.7186009653241326\n", "Iteration: 3200 \tAverage Loss: 0.4578437913477562\n", "Iteration: 3300 \tAverage Loss: 0.4626975750480106\n", "Iteration: 3400 \tAverage Loss: 0.4547800428223713\n", "Iteration: 3500 \tAverage Loss: 0.5841922046960463\n", "Iteration: 3600 \tAverage Loss: 0.4948286437944828\n", "Iteration: 3700 \tAverage Loss: 0.5403618149431492\n", "Iteration: 3800 \tAverage Loss: 0.4935502692606493\n", "Iteration: 3900 \tAverage Loss: 0.3868899525287208\n", "Iteration: 4000 \tAverage Loss: 0.49382182856048007\n", "Accuracy 0.64\n", "Accuracy: [('Accuracy', 0.64)]\n", "Iteration: 4100 \tAverage Loss: 0.562454299655663\n", "Iteration: 4200 \tAverage Loss: 0.3669372467437896\n", "Iteration: 4300 \tAverage Loss: 0.49008016652484643\n", "Iteration: 4400 \tAverage Loss: 0.47610636170733317\n", "Iteration: 4500 \tAverage Loss: 0.4178939992562782\n", "Iteration: 4600 \tAverage Loss: 0.44782084012233475\n", "Iteration: 4700 \tAverage Loss: 0.385732653663784\n", "Iteration: 4800 \tAverage Loss: 0.3747322734059361\n", "Iteration: 4900 \tAverage Loss: 0.353573361042046\n", "Iteration: 5000 \tAverage Loss: 0.49929911591782494\n", "Accuracy 0.66\n", "Accuracy: [('Accuracy', 0.66)]\n", "Iteration: 5100 \tAverage Loss: 0.25476817850091965\n", "Interrupted!\n", "Epoch time: 3875.5561311244965\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_queries = load('tutorial/train.txt')\n", "test_queries = load('tutorial/test.txt')[:100]\n", "\n", "with open('tutorial/addition_subtraction.pl') as f:\n", " problog_string = f.read()\n", "\n", "network = MNIST_Net()\n", "# Network is a DeepProbLog class that wraps a pytorch networks and interfaces with ProbLog\n", "net = Network(network, 'mnist_net', neural_predicate)\n", "net.optimizer = 
torch.optim.Adam(network.parameters(), lr=0.001)\n", "# Model is a DeepProbLog class that combines reasoning via the ProbLog code \n", "# and neural processing via a list of Network objects\n", "model = Model(problog_string, [net], caching=False)\n", "optimizer = Optimizer(model, 2)\n", "\n", "train_model(model, train_queries, 1, optimizer, test_iter=1000, test=test, snapshot_iter=10000)" ] }, { "cell_type": "code", "execution_count": null, "id": "f174708c", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" } }, "nbformat": 4, "nbformat_minor": 5 }