# DeepProbLog #

DeepProbLog is an extension of ProbLog that integrates Probabilistic Logic Programming with Deep Learning.

The project's git repository can be cloned with:

git clone https://bitbucket.org/problog/deepproblog.git

## Example: MNIST Digit Addition ##

In this experiment, the task is to predict the sum of two multi-digit numbers,
each given as a list of MNIST digit images.

First, we create a ProbLog file containing the logic part of the program. The file will be saved as `tutorial/multi_digit.pl`.

```prolog
nn(mnist_net,[X],Y,[0,1,2,3,4,5,6,7,8,9]) :: digit(X,Y).

number([],Result,Result).
number([H|T],Acc,Result) :- digit(H,Nr), 
 Acc2 is Nr+10*Acc,
 number(T,Acc2,Result).
number(X,Y) :- number(X,0,Y).

addition(X,Y,Z) :- number(X,X2), number(Y,Y2), Z is X2+Y2.
```

Then, we create the query files for both train and test, connecting MNIST images to instances of the `addition` Prolog predicate.

In [13]:
from torchvision.datasets import MNIST
import random 

# MNIST splits keyed by the split name that appears in the query files.
# No transform is passed here: the generator below only reads the labels.
datasets = {
    split: MNIST(root='data/MNIST', train=(split == 'train'), download=True)
    for split in ('train', 'test')
}

def generate_examples(dataset_name, op, length, out):
    """Build query examples from a shuffled MNIST split and save them.

    Args:
        dataset_name: Key into ``datasets``, i.e. ``'train'`` or ``'test'``.
        op: A (lambda) function for the operation to be learned.
        length: The number of digits to be used per number.
        out: The output file name.
    """
    data = datasets[dataset_name]
    order = list(range(len(data)))
    random.shuffle(order)
    cursor = iter(order)
    collected = []
    try:
        # Keep drawing examples until the shared index iterator runs dry.
        while True:
            collected.append(next_example(cursor, data, op, length))
    except StopIteration:
        # Raised when all digits in the dataset have been used; an example
        # that was only partially drawn at that point is dropped.
        pass
    save_examples(dataset_name, collected, out)

def next_example(i, dataset, op, length):
    """Draw two numbers from the index iterator and apply ``op`` to them.

    Args:
        i: Iterator over dataset indices; advanced by ``2 * length`` items.
        dataset: Indexable collection passed through to ``next_number``.
        op: Binary function applied to the two numeric values.
        length: Number of digits per number.

    Returns:
        A tuple ``(ids_of_first, ids_of_second, op(first, second))``.
    """
    first_ids, first_value = next_number(i, dataset, length)
    second_ids, second_value = next_number(i, dataset, length)
    outcome = op(first_value, second_value)
    return first_ids, second_ids, outcome

def next_number(i, dataset, nr_digits):
    """Assemble one multi-digit number from the next ``nr_digits`` images.

    Args:
        i: Iterator yielding dataset indices.
        dataset: Indexable collection whose items unpack as ``(image, label)``.
        nr_digits: How many digits the number has.

    Returns:
        A tuple ``(ids, value)`` where ``ids`` is the list of image indices
        (as strings) and ``value`` is the integer the digit labels spell out.

    Raises:
        StopIteration: Propagated from ``next(i)`` when the iterator is exhausted.
    """
    value = 0
    image_ids = []
    for _ in range(nr_digits):
        idx = next(i)
        _image, label = dataset[idx]  # label is the digit that the image represents
        value = value * 10 + label  # shift the digits seen so far and append the new one
        image_ids.append(str(idx))  # remember which image supplied this digit
    return image_ids, value

def save_examples(dataset_name, examples, out):
    """Write addition examples to ``out`` as one ProbLog-style fact per line.

    Each example is indexed as ``(ids_of_first, ids_of_second, result)`` and
    is encoded as e.g.
    ``addition([test(9150),test(6809),test(1586)], [test(114),test(2039),test(5872)], 1574).``

    Args:
        dataset_name: Functor wrapped around every image id, e.g. ``test``.
        examples: Iterable of examples as described above.
        out: Path of the file to (over)write.
    """
    def as_terms(image_ids):
        # A number is encoded as e.g. test(9150),test(6809),test(1586)
        terms = ['{}({})'.format(dataset_name, image_id) for image_id in image_ids]
        return ','.join(terms)

    with open(out, 'w') as handle:
        for example in examples:
            left = as_terms(example[0])
            right = as_terms(example[1])
            line = 'addition([{}], [{}], {}).\n'.format(left, right, example[2])
            handle.write(line)
 
# Training queries: sums of two numbers with one digit each (length=1).
generate_examples('train', lambda x, y: x + y, 1, 'tutorial/train.txt')
# Test queries: sums of two numbers with three digits each (length=3).
generate_examples('test', lambda x, y: x + y, 3, 'tutorial/test.txt')

Train and test queries look like this:

```
addition([train(2764)], [train(8527)], 9).
addition([train(27012)], [train(56713)], 10).
...

addition([test(9271),test(5812),test(9788)], [test(4522),test(8572),test(3555)], 1575).
addition([test(4052),test(7966),test(5512)], [test(4884),test(5655),test(133)], 1554).
...
```

We can now define a Python class implementing a standard CNN for MNIST images, and a neural predicate that maps the image id (as found in the query) to the corresponding image and sends it to the neural net.

In [14]:
import torch
import torch.nn as nn
from torch.autograd import Variable

class MNIST_Net(nn.Module):
    """Small convolutional classifier that outputs a probability per class.

    Two conv/pool/ReLU stages feed a three-layer fully connected head that
    ends in a softmax over ``N`` classes. The shape comments assume a
    1x28x28 input, which is what the 16 * 4 * 4 flatten size implies.
    """

    def __init__(self, N=10):
        super().__init__()
        feature_layers = [
            nn.Conv2d(1, 6, kernel_size=5),         # 1 28 28 -> 6 24 24
            nn.MaxPool2d(kernel_size=2, stride=2),  # 6 24 24 -> 6 12 12
            nn.ReLU(inplace=True),
            nn.Conv2d(6, 16, kernel_size=5),        # 6 12 12 -> 16 8 8
            nn.MaxPool2d(kernel_size=2, stride=2),  # 16 8 8 -> 16 4 4
            nn.ReLU(inplace=True),
        ]
        self.encoder = nn.Sequential(*feature_layers)
        head_layers = [
            nn.Linear(16 * 4 * 4, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, N),
            nn.Softmax(dim=1),  # normalise over the class dimension
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class probabilities of shape (batch, N) for input ``x``."""
        features = self.encoder(x)
        flattened = features.view(-1, 16 * 4 * 4)
        return self.classifier(flattened)
 
def neural_predicate(network, i):
    """Run the network on the MNIST image named by a query term.

    Args:
        network: Object whose ``net`` attribute is called on the image batch
            (presumably the DeepProbLog ``Network`` wrapper -- confirm).
        i: Term such as ``train(2764)`` or ``test(4052)``. Its functor selects
            the split and its first argument is the image index.

    Returns:
        The network output for that single image, with the batch dimension
        removed.

    Raises:
        ValueError: If the functor is neither ``train`` nor ``test``.
    """
    dataset = str(i.functor)
    index = int(i.args[0])
    # mnist_train_data / mnist_test_data are module-level globals; the label
    # stored next to each image is not needed here.
    if dataset == 'train':
        image, _ = mnist_train_data[index]
    elif dataset == 'test':
        image, _ = mnist_test_data[index]
    else:
        # Without this branch an unknown functor fell through and crashed
        # later with an UnboundLocalError that hid the real cause.
        raise ValueError(
            "unknown dataset {!r}: expected 'train' or 'test'".format(dataset))
    batch = Variable(image.unsqueeze(0))  # add a leading batch dimension of 1
    output = network.net(batch)
    return output.squeeze(0)

We can now load the MNIST data, the queries and the ProbLog file.
Note that `mnist_train_data` and `mnist_test_data` are global variables used inside `neural_predicate`.

In [15]:
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from data_loader import load

# Convert each image to a tensor, then normalize its single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5, ))])
# These two globals are read by neural_predicate to look up images by index.
mnist_train_data = MNIST(root='data/MNIST', train=True, download=True,transform=transform)
mnist_test_data = MNIST(root='data/MNIST', train=False, download=True,transform=transform)

# Load the generated query files; only the first 100 test queries are kept.
train_queries = load('tutorial/train.txt')
test_queries = load('tutorial/test.txt')[:100]

# Read the logic part of the program as a single string.
with open('tutorial/multi_digit.pl') as f:
 problog_string = f.read()

Finally, we can create the network and the DeepProbLog model with the network as neural predicate, and train it as a standard torch model.

In [16]:
from train import train_model
from network import Network
from model import Model
from optimizer import Optimizer

def test(model):
    """Evaluate ``model`` on the global ``test_queries`` and report the accuracy.

    Args:
        model: Object exposing ``accuracy(queries, test=True)``.

    Returns:
        A one-element list ``[('accuracy', value)]`` holding whatever
        ``model.accuracy`` returned.
    """
    score = model.accuracy(test_queries, test=True)
    report = [('accuracy', score)]
    print('Accuracy: ', score)
    return report

# Fresh, untrained CNN used as the digit classifier.
network = MNIST_Net()
# Network is a DeepProbLog class that wraps a PyTorch network and interfaces with ProbLog.
# The name 'mnist_net' matches the nn(mnist_net, ...) declaration in the ProbLog program.
net = Network(network, 'mnist_net', neural_predicate)
net.optimizer = torch.optim.Adam(network.parameters(), lr=0.001)
# Model is a DeepProbLog class that combines reasoning via the ProbLog code
# and neural processing via a list of Network objects
model = Model(problog_string, [net], caching=False)
# NOTE(review): the meaning of the second argument (2) is not visible here -- confirm against Optimizer.
optimizer = Optimizer(model, 2)

# The third positional argument is 1 and the log below reports "Training for 1 epochs";
# the `test` callback defined above is passed together with test_iter=1000.
train_model(model, train_queries, 1, optimizer, test_iter=1000, test=test, snapshot_iter=10000)

Training for 1 epochs (30000 iterations).




Accuracy 0.0
Accuracy: [('Accuracy', 0.0)]
Epoch 1
Iteration: 100 	Average Loss: 2.7413997681308944
Iteration: 200 	Average Loss: 2.822181631299884
Iteration: 300 	Average Loss: 2.7558953509461053
Iteration: 400 	Average Loss: 2.641749291159689
Iteration: 500 	Average Loss: 2.4830741851766236
Iteration: 600 	Average Loss: 2.327804147989043
Iteration: 700 	Average Loss: 2.1164256080170687
Iteration: 800 	Average Loss: 2.234518181674348
Iteration: 900 	Average Loss: 2.0504583209665026
Iteration: 1000 	Average Loss: 1.7014140593617073
Accuracy 0.15
Accuracy: [('Accuracy', 0.15)]
Iteration: 1100 	Average Loss: 1.4452080341853497
Iteration: 1200 	Average Loss: 1.4222283941174698
Iteration: 1300 	Average Loss: 0.8295922330291584
Iteration: 1400 	Average Loss: 1.120169151657204
Iteration: 1500 	Average Loss: 0.5691411603139906
Iteration: 1600 	Average Loss: 0.7857685772432913
Iteration: 1700 	Average Loss: 0.8248206990974919
Iteration: 1800 	Average Loss: 0.5987613001722083
Iteration: 1900 	A

Iteration: 14300 	Average Loss: 0.1883522669075476
Iteration: 14400 	Average Loss: 0.23984416674510384
Iteration: 14500 	Average Loss: 0.27934753153450936
Iteration: 14600 	Average Loss: 0.14659613511216343
Iteration: 14700 	Average Loss: 0.09915756395264293
Iteration: 14800 	Average Loss: 0.21752628844546773
Iteration: 14900 	Average Loss: 0.16969644666065645
Iteration: 15000 	Average Loss: 0.17419191548977137
Accuracy 0.86
Accuracy: [('Accuracy', 0.86)]
Iteration: 15100 	Average Loss: 0.132667199302011
Iteration: 15200 	Average Loss: 0.06084775712248787
Iteration: 15300 	Average Loss: 0.28009641435711685
Iteration: 15400 	Average Loss: 0.22496273861193522
Iteration: 15500 	Average Loss: 0.17071578126496775
Iteration: 15600 	Average Loss: 0.12306111240204767
Iteration: 15700 	Average Loss: 0.13521139393961223
Iteration: 15800 	Average Loss: 0.2646886639959432
Iteration: 15900 	Average Loss: 0.06582526942718006
Iteration: 16000 	Average Loss: 0.4805922915421103
Accuracy 0.87
Accuracy: 

Iteration: 28300 	Average Loss: 0.13782404585076444
Iteration: 28400 	Average Loss: 0.08818393756758906
Iteration: 28500 	Average Loss: 0.1567092426496749
Iteration: 28600 	Average Loss: 0.11117896665074617
Iteration: 28700 	Average Loss: 0.18585810289019405
Iteration: 28800 	Average Loss: 0.1259237184502742
Iteration: 28900 	Average Loss: 0.11836703749145286
Iteration: 29000 	Average Loss: 0.031844159575524264
Accuracy 0.84
Accuracy: [('Accuracy', 0.84)]
Iteration: 29100 	Average Loss: 0.1102448297319265
Iteration: 29200 	Average Loss: 0.0699499300865066
Iteration: 29300 	Average Loss: 0.0830309191303204
Iteration: 29400 	Average Loss: 0.34941493655585887
Iteration: 29500 	Average Loss: 0.06943456046559118
Iteration: 29600 	Average Loss: 0.08473602243635145
Iteration: 29700 	Average Loss: 0.09539665225016306
Iteration: 29800 	Average Loss: 0.042837079973181647
Iteration: 29900 	Average Loss: 0.18574071140404833
Writing snapshot to model_iter_30000.mdl
Iteration: 30000 	Average Loss: 0



## Example: a + b - c = d 

Let's change the code to answer queries of the type `a + b - c = ?`

First, we create a ProbLog file containing the logic part of the program. The file will be saved as `tutorial/addition_subtraction.pl`.

```prolog
nn(mnist_net,[X],Y,[0,1,2,3,4,5,6,7,8,9]) :: digit(X,Y).

number([],Result,Result).
number([H|T],Acc,Result) :- digit(H,Nr), 
 Acc2 is Nr+10*Acc,
 number(T,Acc2,Result).
number(X,Y) :- number(X,0,Y).

addition(X,Y,Z) :- number(X,N1), number(Y,N2), Z is N1+N2.
 
addition_subtraction(A,B,C,D) :- addition(A,B,N1), number(C,N2), D is N1-N2. 
```

Then, we create the query files for both train and test, connecting MNIST images to instances of the `addition_subtraction` Prolog predicate.

In [17]:
def next_example(i, dataset, op, length):
    """Draw three numbers so that ``op`` on them is not negative.

    The first two numbers are drawn once. The third is redrawn from the
    iterator until ``op(first, second, third)`` is no longer below zero.

    Args:
        i: Iterator over dataset indices, shared across all draws.
        dataset: Indexable collection passed through to ``next_number``.
        op: Ternary function applied to the three numeric values.
        length: Number of digits per number.

    Returns:
        A tuple ``(ids_of_first, ids_of_second, ids_of_third, result)``.
    """
    first_ids, first_value = next_number(i, dataset, length)
    second_ids, second_value = next_number(i, dataset, length)
    while True:
        third_ids, third_value = next_number(i, dataset, length)
        outcome = op(first_value, second_value, third_value)
        # Make sure the result is non-negative before accepting the draw.
        if not outcome < 0:
            return first_ids, second_ids, third_ids, outcome

def save_examples(dataset_name, examples, out):
    """Write addition/subtraction examples to ``out``, one fact per line.

    Each example is indexed as ``(ids_a, ids_b, ids_c, result)`` and is
    written as ``addition_subtraction([...], [...], [...], result).`` with
    every image id wrapped in the ``dataset_name`` functor.

    Args:
        dataset_name: Functor wrapped around every image id, e.g. ``train``.
        examples: Iterable of examples as described above.
        out: Path of the file to (over)write.
    """
    def as_terms(image_ids):
        terms = ['{}({})'.format(dataset_name, image_id) for image_id in image_ids]
        return ','.join(terms)

    with open(out, 'w') as handle:
        for example in examples:
            first = as_terms(example[0])
            second = as_terms(example[1])
            third = as_terms(example[2])
            line = 'addition_subtraction([{}], [{}], [{}], {}).\n'.format(
                first, second, third, example[3])
            handle.write(line)
 
# These calls write to the same paths as the addition example, so the
# earlier tutorial/train.txt and tutorial/test.txt are overwritten.
# Training queries use numbers with one digit each (length=1).
generate_examples('train', lambda x, y, z: x + y - z, 1, 'tutorial/train.txt')
# Test queries use numbers with three digits each (length=3).
generate_examples('test', lambda x, y, z: x + y - z, 3, 'tutorial/test.txt')

Reload the queries and the ProbLog file, instantiate the network and the model, and retrain.

In [19]:
# Reload the regenerated query files; only the first 100 test queries are kept.
train_queries = load('tutorial/train.txt')
test_queries = load('tutorial/test.txt')[:100]

# Read the addition/subtraction logic program as a single string.
with open('tutorial/addition_subtraction.pl') as f:
 problog_string = f.read()

# Start from a fresh, untrained CNN rather than reusing the previous one.
network = MNIST_Net()
# Network is a DeepProbLog class that wraps a PyTorch network and interfaces with ProbLog.
# The name 'mnist_net' matches the nn(mnist_net, ...) declaration in the ProbLog program.
net = Network(network, 'mnist_net', neural_predicate)
net.optimizer = torch.optim.Adam(network.parameters(), lr=0.001)
# Model is a DeepProbLog class that combines reasoning via the ProbLog code
# and neural processing via a list of Network objects
model = Model(problog_string, [net], caching=False)
# NOTE(review): the meaning of the second argument (2) is not visible here -- confirm against Optimizer.
optimizer = Optimizer(model, 2)

# Same training call as in the addition example, now on the new queries and model.
train_model(model, train_queries, 1, optimizer, test_iter=1000, test=test, snapshot_iter=10000)

Training for 1 epochs (17416 iterations).
Accuracy 0.0
Accuracy: [('Accuracy', 0.0)]
Epoch 1
Iteration: 100 	Average Loss: 2.7642198722784004
Iteration: 200 	Average Loss: 2.801872244313671
Iteration: 300 	Average Loss: 2.6762651530991595
Iteration: 400 	Average Loss: 2.6976396559065567
Iteration: 500 	Average Loss: 2.4950326703163945
Iteration: 600 	Average Loss: 2.6426878719611553
Iteration: 700 	Average Loss: 2.5003180111536354
Iteration: 800 	Average Loss: 2.470023140425346
Iteration: 900 	Average Loss: 2.3809421806027244
Iteration: 1000 	Average Loss: 2.3171711316796855
Accuracy 0.0
Accuracy: [('Accuracy', 0.0)]
Iteration: 1100 	Average Loss: 2.240560181079019
Iteration: 1200 	Average Loss: 2.290888851917434
Iteration: 1300 	Average Loss: 2.305876613045047
Iteration: 1400 	Average Loss: 2.1712287488969992
Iteration: 1500 	Average Loss: 2.2892492135890303
Iteration: 1600 	Average Loss: 2.0477930040540913
Iteration: 1700 	Average Loss: 2.1463180474731125
Iteration: 1800 	Average Los

