InteractE: Improving Convolution-based Knowledge Graph Embeddings by Increasing Feature Interactions

Paper and code reading

Research questions

Building on ConvE, InteractE is proposed to further increase the interactions between relation and entity embeddings, improving performance on the link prediction task.

Background motivation

  • ConvE increases interactions by reshaping the one-dimensional entity and relation embeddings into two dimensions and concatenating them as the input to a convolution. This improves performance to some extent, but the interactions produced by this simple scheme are still limited.

Symbol definition

  • Assume the entity and relation embeddings are $\boldsymbol{e}_s = (a_1, \ldots, a_d)$ and $\boldsymbol{e}_r = (b_1, \ldots, b_d)$, and the convolution kernel is $\boldsymbol{w} \in \mathbb{R}^{k \times k}$
  • For matrices $M_k \in \mathbb{R}^{k \times k}$ and $N \in \mathbb{R}^{m \times n}$: if $M_k = N_{i:i+k,\, j:j+k}$ for some $i, j$, then $M_k$ is called a $k$-submatrix of $N$, written $M_k \subseteq N$
  • A reshaping function $\phi: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}^{m \times n}$ converts the input entity and relation embeddings into a matrix $\phi(\boldsymbol{e}_s, \boldsymbol{e}_r)$ with $m \times n = 2d$; the paper defines the following three reshaping functions

  • Stack: simple stacking, denoted $\phi_{stk}$
  • Alternate: row-wise interleaving, denoted $\phi_{alt}^{\tau}$
  • Chequer: element-wise interleaving, denoted $\phi_{chk}$
  • Interaction and interaction count: an interaction is a triple $(x, y, M_k)$ with $M_k \subseteq \phi(\boldsymbol{e}_s, \boldsymbol{e}_r)$ and $x, y \in M_k$; the interaction count $\mathcal{N}(\phi, k)$ is the number of all possible such triples. If $x$ and $y$ come one from $\boldsymbol{e}_s$ and one from $\boldsymbol{e}_r$, the interaction is called heterogeneous, otherwise homogeneous; the respective counts $\mathcal{N}_{het}(\phi, k)$ and $\mathcal{N}_{homo}(\phi, k)$ satisfy $\mathcal{N}_{het}(\phi, k) + \mathcal{N}_{homo}(\phi, k) = 2\binom{k^2}{2}$. For example, for a $3 \times 3$ matrix $M_3$ with 5 elements from $\boldsymbol{e}_s$ and 4 from $\boldsymbol{e}_r$, we get $\mathcal{N}_{het} = 2(5 \times 4) = 40$ and $\mathcal{N}_{homo} = 2\left[\binom{5}{2} + \binom{4}{2}\right] = 32$ (a quick numeric check follows below).
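
A quick numeric check of the counts above (my own sketch; interaction_counts is a hypothetical helper, not from the paper's code):

from math import comb

def interaction_counts(n_s, n_r):
    """(N_het, N_homo) for a k x k window holding n_s elements of e_s and n_r of e_r."""
    n_het = 2 * n_s * n_r                       # ordered pairs mixing e_s and e_r
    n_homo = 2 * (comb(n_s, 2) + comb(n_r, 2))  # ordered pairs from the same embedding
    return n_het, n_homo

print(interaction_counts(5, 4))  # (40, 32), matching the 3 x 3 example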

Model structure

  • Overall framework

For the input entity and relation embeddings, InteractE first generates several random feature permutations, reshapes each into a two-dimensional matrix with the Chequer function described above, then extracts features with circular convolution, and finally projects the result back into the embedding space.

  • Feature permutation

InteractE permutes the entity and relation embeddings $t$ times and records the results as $\mathcal{P}_t = \left[(\boldsymbol{e}_s^1, \boldsymbol{e}_r^1); \ldots; (\boldsymbol{e}_s^t, \boldsymbol{e}_r^t)\right]$. Since the number of possible random permutations is enormous, the permutations can be assumed not to overlap one another, so the number of possible interactions grows to $t$ times the original (a toy illustration follows).
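
A toy view of the permutation step (my own sketch with arbitrary sizes):

import numpy as np

# t independent random permutations of a d-dimensional embedding
d, t = 6, 3
perms = [np.random.permutation(d) for _ in range(t)]
e_s = np.arange(d)
print([e_s[p] for p in perms])  # t independently shuffled copies of e_s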

  • Chequer operation

The reshaping function $\phi_{stk}$ used in ConvE is replaced with $\phi_{chk}$ to maximize the number of heterogeneous interactions, as in the toy sketch below.
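
A toy version of the chequer reshaping (my own illustration; the repo instead precomputes index permutations, as shown in the code-reading section):

import numpy as np

def phi_chk(a, b, m, n):
    """Alternate the source embedding element-wise, flipping the phase each row
    so that no two neighbours come from the same embedding."""
    out = np.empty((m, n))
    ai = bi = 0
    for i in range(m):
        for j in range(n):
            if (i + j) % 2 == 0:
                out[i, j] = a[ai]; ai += 1
            else:
                out[i, j] = b[bi]; bi += 1
    return out

a, b = np.arange(1, 5), np.arange(5, 9)  # e_s, e_r with d = 4
print(phi_chk(a, b, 2, 4))
# [[1. 5. 2. 6.]
#  [7. 3. 8. 4.]]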

  • Circular convolution

Compared with standard convolution, circular convolution does not pad the input with zeros; instead it wraps around and reuses the elements from the opposite edge.

The input wraps around at the matrix borders, and the kernel index order is reversed (true convolution rather than cross-correlation).

The formula is

$$[\boldsymbol{I} \star \boldsymbol{w}]_{p, q} = \sum_{i=-\lfloor k/2 \rfloor}^{\lfloor k/2 \rfloor} \sum_{j=-\lfloor k/2 \rfloor}^{\lfloor k/2 \rfloor} \boldsymbol{I}_{[p-i]_m,\, [q-j]_n} \, \boldsymbol{w}_{i, j}$$

where $[x]_m$ denotes $x$ modulo $m$.
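
A direct implementation of this formula (my own sketch for odd k; the repo uses the vectorized circular-padding trick shown later in the model code instead):

import numpy as np

def circular_conv(I, w):
    """Circular convolution per the formula above, using modulo indexing."""
    m, n = I.shape
    k = w.shape[0]
    r = k // 2
    out = np.zeros((m, n))
    for p in range(m):
        for q in range(n):
            out[p, q] = sum(I[(p - i) % m, (q - j) % n] * w[i + r, j + r]
                            for i in range(-r, r + 1)
                            for j in range(-r, r + 1))
    return out

I = np.random.randn(4, 4)
w = np.random.randn(3, 3)
print(circular_conv(I, w).shape)  # (4, 4): output keeps the input size, no zero padding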

The paper mentions that sharing the convolution kernel across the different channels (one per permutation) gives better results.

  • Score function

$$\psi(s, r, o) = g\left(\operatorname{vec}\left(f\left(\phi(\mathcal{P}_k) \otimes \boldsymbol{w}\right)\right) \boldsymbol{W}\right) \boldsymbol{e}_o$$

where $\otimes$ denotes the depth-wise circular convolution described above.

Here $f$ and $g$ denote the ReLU and sigmoid functions, respectively. Standard binary cross-entropy is used as the loss function, with label smoothing applied during training.
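
A minimal sketch of the training objective with label smoothing (toy tensors of my own; the smoothing form matches the data-loader code shown later):

import torch
import torch.nn.functional as F

num_ent, lbl_smooth = 5, 0.1
pred = torch.sigmoid(torch.randn(2, num_ent))      # scores for 2 queries over all entities
y = torch.zeros(2, num_ent)
y[0, 1] = y[1, 3] = 1.0                            # one gold object per query
y_smooth = (1.0 - lbl_smooth) * y + 1.0 / num_ent  # smoothing as in the repo's data loader
loss = F.binary_cross_entropy(pred, y_smooth)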

Experiments

  • Comparison of link prediction results

  • Comparison of different convolution and shaping functions

  • Effect of the number of feature permutations

The number of permutations should not be too large: performance begins to decline once it exceeds two.

  • Comparison across relation types

The model's gains are most pronounced on the complex N-to-N relations.

Code reading

Sampling strategy

one_to_n vs. one_to_x: one_to_n scores each (sub, rel) pair against all entities at once, while one_to_x scores it against the positive object plus neg_num sampled negatives.

interact.py

# one_to_n: one training example per (sub, rel) pair, labelled with all gold objects
if self.p.train_strategy == 'one_to_n':
    for (sub, rel), obj in self.sr2o.items():
        self.triples['train'].append({
            'triple': (sub, rel, -1),
            'label': self.sr2o[(sub, rel)],
            'sub_samp': 1
        })
# one_to_x: one example per triple plus one for its inverse, each carrying a
# frequency-based subsampling weight sqrt(1 / freq)
else:
    for sub, rel, obj in self.data['train']:
        rel_inv = rel + self.p.num_rel
        sub_samp = len(self.sr2o[(sub, rel)]) + len(self.sr2o[(obj, rel_inv)])
        sub_samp = np.sqrt(1 / sub_samp)

        self.triples['train'].append({
            'triple': (sub, rel, obj),
            'label': self.sr2o[(sub, rel)],
            'sub_samp': sub_samp
        })
        self.triples['train'].append({
            'triple': (obj, rel_inv, sub),
            'label': self.sr2o[(obj, rel_inv)],
            'sub_samp': sub_samp
        })
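
As a quick worked example of the subsampling weight (hypothetical numbers): if (sub, rel) has 3 gold objects and (obj, rel_inv) has 1 gold subject, then sub_samp = sqrt(1 / (3 + 1)) = 0.5, so frequent pairs are down-weighted in training.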

data_loader.py

class TrainDataset(Dataset):
	"""
	Training Dataset class.

	Parameters
	----------
	triples:	The triples used for training the model
	params:		Parameters for the experiments
	
	Returns
	-------
	A training Dataset class instance used by DataLoader
	"""
	def __init__(self, triples, params):
		self.triples	= triples
		self.p 		= params
		self.strategy	= self.p.train_strategy
		self.entities	= np.arange(self.p.num_ent, dtype=np.int32)

	def __len__(self):
		return len(self.triples)

	def __getitem__(self, idx):
		ele			= self.triples[idx]
		triple, label, sub_samp	= torch.LongTensor(ele['triple']), np.int32(ele['label']), np.float32(ele['sub_samp'])
		trp_label		= self.get_label(label)

		if self.p.lbl_smooth != 0.0:
			trp_label = (1.0 - self.p.lbl_smooth)*trp_label + (1.0/self.p.num_ent)

		if self.strategy == 'one_to_n':
			return triple, trp_label, None, None

		elif self.strategy == 'one_to_x':
			sub_samp		= torch.FloatTensor([sub_samp])
			neg_ent			= torch.LongTensor(self.get_neg_ent(triple, label))
			return triple, trp_label, neg_ent, sub_samp
		else: 
			raise NotImplementedError


	@staticmethod
	def collate_fn(data):
		triple		= torch.stack([_[0] 	for _ in data], dim=0)
		trp_label	= torch.stack([_[1] 	for _ in data], dim=0)

		if data[0][2] is not None:							# one_to_x
			neg_ent		= torch.stack([_[2] 	for _ in data], dim=0)
			sub_samp	= torch.cat([_[3] 	for _ in data], dim=0)
			return triple, trp_label, neg_ent, sub_samp
		else:
			return triple, trp_label
	
	def get_neg_ent(self, triple, label):
		def get(triple, label):
			if self.strategy == 'one_to_x':
				pos_obj		= triple[2]
				# mask out every gold object so it cannot be drawn as a negative
				mask		= np.ones([self.p.num_ent], dtype=bool)
				mask[label]	= 0
				neg_ent		= np.int32(np.random.choice(self.entities[mask], self.p.neg_num, replace=False)).reshape([-1])
				# the positive object is placed first; get_label relies on this ordering
				neg_ent		= np.concatenate((pos_obj.reshape([-1]), neg_ent))
			else:
				pos_obj		= label
				mask		= np.ones([self.p.num_ent], dtype=bool)
				mask[label]	= 0
				neg_ent		= np.int32(np.random.choice(self.entities[mask], self.p.neg_num - len(label), replace=False)).reshape([-1])
				neg_ent		= np.concatenate((pos_obj.reshape([-1]), neg_ent))

				if len(neg_ent) > self.p.neg_num:
					import pdb; pdb.set_trace()	# debugging guard left in the original code

			return neg_ent

		neg_ent = get(triple, label)
		return neg_ent

	def get_label(self, label):
		if self.strategy == 'one_to_n':
			y = np.zeros([self.p.num_ent], dtype=np.float32)
			for e2 in label: y[e2] = 1.0
		elif self.strategy == 'one_to_x':
			y = [1] + [0] * self.p.neg_num
		else: 
			raise NotImplementedError
		return torch.FloatTensor(y)
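
A minimal sketch of how this dataset plugs into a DataLoader (names like params and triples follow the snippets above; the exact wiring in the repo may differ):

from torch.utils.data import DataLoader

train_loader = DataLoader(
    TrainDataset(triples['train'], params),
    batch_size=params.batch_size,
    shuffle=True,
    collate_fn=TrainDataset.collate_fn,
)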

Feature permutation and chequer reshaping

There are two main steps here: first, permute the embeddings t times; second, interleave each permuted result into a chequered pattern. Note that what is produced is not the elements themselves but the indices of the corresponding elements, which the model later uses to gather from the concatenated embedding.

interact.py

# The embedding dimension equals matrix height * matrix width
self.p.embed_dim = self.p.k_w * self.p.k_h if self.p.embed_dim is None else self.p.embed_dim

def get_chequer_perm(self):
    # Turn the t one-dimensional permutations of the entity and relation
    # embeddings into t two-dimensional matrices of shape k_h x (2 * k_w).
    # Whether a row starts with an entity or a relation element alternates
    # between permutations (the k % 2 test below).
    """
        Function to generate the chequer permutation required for InteractE model

        Parameters
        ----------
        
        Returns
        -------
        
        """
    ent_perm = np.int32(
        [np.random.permutation(self.p.embed_dim) for _ in range(self.p.perm)])
    rel_perm = np.int32(
        [np.random.permutation(self.p.embed_dim) for _ in range(self.p.perm)])

    comb_idx = []
    for k in range(self.p.perm):
        temp = []
        ent_idx, rel_idx = 0, 0
        for i in range(self.p.k_h):
            for j in range(self.p.k_w):
                if k % 2 == 0:
                    # rows alternate: entity, relation, entity, ... on even rows,
                    # relation, entity, relation, ... on odd rows
                    if i % 2 == 0:
                        temp.append(ent_perm[k, ent_idx])
                        ent_idx += 1
                        temp.append(rel_perm[k, rel_idx] + self.p.embed_dim)
                        rel_idx += 1
                    else:
                        temp.append(rel_perm[k, rel_idx] + self.p.embed_dim)
                        rel_idx += 1
                        temp.append(ent_perm[k, ent_idx])
                        ent_idx += 1
                else:
                    # reversed phase: relation, entity, relation, ... on even rows,
                    # entity, relation, entity, ... on odd rows
                    if i % 2 == 0:
                        temp.append(rel_perm[k, rel_idx] + self.p.embed_dim)
                        rel_idx += 1
                        temp.append(ent_perm[k, ent_idx])
                        ent_idx += 1
                    else:
                        temp.append(ent_perm[k, ent_idx])
                        ent_idx += 1
                        temp.append(rel_perm[k, rel_idx] + self.p.embed_dim)
                        rel_idx += 1

        comb_idx.append(temp)

    chequer_perm = torch.LongTensor(np.int32(comb_idx)).to(self.device)
    return chequer_perm
# The chequer permutation is computed once and passed to the model constructor
model = InteractE(self.p, self.chequer_perm)
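
A toy check of the index-gather trick (my own example, assuming embed_dim = 2, perm = 1, k_h = 2, k_w = 1, so each reshaped matrix is 2 x 2):

import torch

comb_emb = torch.tensor([[10., 11., 20., 21.]])  # [e_s | e_r] for a batch of 1
chequer_perm = torch.LongTensor([[0, 2, 3, 1]])  # ent 0, rel 0 (+2), rel 1 (+2), ent 1
out = comb_emb[:, chequer_perm]                  # shape (1, 1, 4) via fancy indexing
print(out.reshape(-1, 1, 2, 2))
# tensor([[[[10., 20.],
#           [21., 11.]]]])  -> rows alternate entity / relation elements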

Model definition

Note that the loss function is binary cross-entropy, and circular convolution is implemented by applying a circular padding manually before a standard convolution.

# PyTorch related imports
import torch
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from torch.nn.init import xavier_normal_, xavier_uniform_
from torch.nn import Parameter as Param
from torch.utils.data import DataLoader

class InteractE(torch.nn.Module):
    """
	Proposed method in the paper. Refer to Section 6 of the paper for more details

	Parameters
	----------
	params:        	Hyperparameters of the model
	chequer_perm:   Reshaping to be used by the model
	
	Returns
	-------
	The InteractE model instance
		
	"""
    def __init__(self, params, chequer_perm):
        super(InteractE, self).__init__()

        self.p = params
        self.ent_embed = torch.nn.Embedding(self.p.num_ent,
                                            self.p.embed_dim,
                                            padding_idx=None)
        xavier_normal_(self.ent_embed.weight)
        self.rel_embed = torch.nn.Embedding(self.p.num_rel * 2,
                                            self.p.embed_dim,
                                            padding_idx=None)
        xavier_normal_(self.rel_embed.weight)
        self.bceloss = torch.nn.BCELoss()

        self.inp_drop = torch.nn.Dropout(self.p.inp_drop)
        self.hidden_drop = torch.nn.Dropout(self.p.hid_drop)
        self.feature_map_drop = torch.nn.Dropout2d(self.p.feat_drop)
        self.bn0 = torch.nn.BatchNorm2d(self.p.perm)

        flat_sz_h = self.p.k_h
        flat_sz_w = 2 * self.p.k_w
        self.padding = 0

        self.bn1 = torch.nn.BatchNorm2d(self.p.num_filt * self.p.perm)
        self.flat_sz = flat_sz_h * flat_sz_w * self.p.num_filt * self.p.perm

        self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim)
        self.fc = torch.nn.Linear(self.flat_sz, self.p.embed_dim)
        self.chequer_perm = chequer_perm

        self.register_parameter('bias', Parameter(torch.zeros(self.p.num_ent)))
        # Note that the convolution kernel here is manually defined and shared across perm
        self.register_parameter(
            'conv_filt',
            Parameter(
            # kernel of size ker_sz x ker_sz (9 in the default configuration)
                torch.zeros(self.p.num_filt, 1, self.p.ker_sz, self.p.ker_sz)))
        xavier_normal_(self.conv_filt)

    def loss(self, pred, true_label=None, sub_samp=None):
        # label_pos / label_neg are computed but unused in this released version
        label_pos = true_label[0]
        label_neg = true_label[1:]
        loss = self.bceloss(pred, true_label)
        return loss

    def circular_padding_chw(self, batch, padding):
        upper_pad = batch[..., -padding:, :]
        lower_pad = batch[..., :padding, :]
        temp = torch.cat([upper_pad, batch, lower_pad], dim=2)

        left_pad = temp[..., -padding:]
        right_pad = temp[..., :padding]
        padded = torch.cat([left_pad, temp, right_pad], dim=3)
        return padded

    def forward(self, sub, rel, neg_ents, strategy='one_to_x'):
        # Suppose perm = 3, batch_size = 128, embed_dim = 200
        # sub_emb, rel_emb: torch.Size([128, 200])
        sub_emb = self.ent_embed(sub)
        rel_emb = self.rel_embed(rel)
        # torch.Size([128, 400])
        comb_emb = torch.cat([sub_emb, rel_emb], dim=1)
        # torch.Size([128, 3, 400])
        chequer_perm = comb_emb[:, self.chequer_perm]
        # torch.Size([128, 3, 20, 20])
        stack_inp = chequer_perm.reshape(
            (-1, self.p.perm, 2 * self.p.k_w, self.p.k_h))
        stack_inp = self.bn0(stack_inp)
        x = self.inp_drop(stack_inp)
        # circular padding + grouped convolution with a repeated filter implements
        # the shared-kernel circular convolution across all perm channels
        x = self.circular_padding_chw(x, self.p.ker_sz // 2)
        x = F.conv2d(x,
                     self.conv_filt.repeat(self.p.perm, 1, 1, 1),
                     padding=self.padding,
                     groups=self.p.perm)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_map_drop(x)
        x = x.view(-1, self.flat_sz)
        x = self.fc(x)
        x = self.hidden_drop(x)
        x = self.bn2(x)
        x = F.relu(x)

        if strategy == 'one_to_n':
            x = torch.mm(x, self.ent_embed.weight.transpose(1, 0))
            x += self.bias.expand_as(x)
        else:
            x = torch.mul(x.unsqueeze(1), self.ent_embed(neg_ents)).sum(dim=-1)
            x += self.bias[neg_ents]

        pred = torch.sigmoid(x)

        return pred
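
As a sanity check of circular_padding_chw (my own, not from the repo), the manual padding should agree with PyTorch's built-in circular mode:

import torch
import torch.nn.functional as F

def circular_padding_chw(batch, padding):
    upper_pad = batch[..., -padding:, :]
    lower_pad = batch[..., :padding, :]
    temp = torch.cat([upper_pad, batch, lower_pad], dim=2)
    left_pad = temp[..., -padding:]
    right_pad = temp[..., :padding]
    return torch.cat([left_pad, temp, right_pad], dim=3)

x = torch.randn(2, 3, 5, 5)
print(torch.allclose(circular_padding_chw(x, 1),
                     F.pad(x, (1, 1, 1, 1), mode='circular')))  # True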
