Module circuitgraph.locks

Expand source code
import code
import circuitgraph as cg
from circuitgraph.transform import miter,syn
from circuitgraph.sat import sat
from circuitgraph.logic import *
from random import sample,choice,randint
from pysat.solvers import Cadical
from pysat.formula import IDPool
from pysat.card import *

def xorLock(c,k):
        """Lock a copy of ``c`` by inserting ``k`` XOR/XNOR key gates.

        Each randomly chosen non-output node has its fanout rerouted through a
        new two-input key gate ``key_gate_<i>`` whose other input is a fresh
        primary input ``key_<i>``. The key gate is an XNOR when the chosen key
        bit is True and an XOR when it is False, so under the returned key
        every key gate passes its node's value through unchanged.

        Args:
                c: circuit to lock; it is copied, not modified.
                k: number of key gates to insert.

        Returns:
                ``(cl, key)``: the locked copy and a dict mapping each
                ``key_<i>`` input name to its correct boolean value.

        Raises:
                ValueError: (from ``random.sample``) if ``k`` is negative or
                        larger than the number of non-output nodes.
        """
        # create copy to lock
        cl = c.copy()

        # randomly select gates; random.sample() requires a sequence (sets
        # are rejected since Python 3.11) and sorting the (string) node names
        # keeps seeded runs reproducible regardless of hash randomization
        gates = sample(sorted(cl.nodes()-cl.outputs()),k)

        # insert key gates
        key = {}
        for i,gate in enumerate(gates):
                # select random key value
                key[f'key_{i}'] = choice([True,False])

                # xnor(x,1) == x and xor(x,0) == x, so the key bit picks the type
                gate_type = 'xnor' if key[f'key_{i}'] else 'xor'
                fanout = cl.fanout(gate)
                cl.disconnect(gate,fanout)
                cl.add(f'key_gate_{i}',gate_type,fanin=gate,fanout=fanout)
                cl.add(f'key_{i}','input',fanout=f'key_gate_{i}')

        return cl, key


def muxLock(c,k):
        """Lock a copy of ``c`` by inserting ``k`` key-controlled 2:1 muxes.

        For each of ``k`` randomly chosen non-output nodes, the node's fanout is
        rerouted through a new mux ``mux_<i>_*`` whose select is a fresh
        primary input ``key_<i>``. The real node drives ``in_1`` when the key
        bit is True and ``in_0`` when it is False; a random decoy node drives
        the other data input.

        Decoys come from the original non-output nodes, but never the locked
        node itself nor any node in its transitive fanout in the circuit as
        locked so far: such a decoy would feed the mux output back into the
        mux and form a combinational loop.

        Args:
                c: circuit to lock; it is copied, not modified.
                k: number of muxes to insert.

        Returns:
                ``(cl, key)``: the locked copy and a dict mapping each
                ``key_<i>`` input name to its correct boolean value.

        Raises:
                ValueError: if ``k`` is negative or exceeds the number of
                        non-output nodes, or if no loop-free decoy exists for a
                        selected node.
        """
        def _fanout_cone(circuit,node):
                # every node reachable from ``node`` along fanout edges
                cone = set()
                frontier = [node]
                while frontier:
                        for f in circuit.fanout(frontier.pop()):
                                if f not in cone:
                                        cone.add(f)
                                        frontier.append(f)
                return cone

        # create copy to lock
        cl = c.copy()

        # get 2:1 mux
        m = mux(2).stripIO()

        # randomly select gates; random.sample() requires a sequence (sets
        # are rejected since Python 3.11), sorting keeps seeded runs stable
        candidates = sorted(cl.nodes()-cl.outputs())
        gates = sample(candidates,k)

        # insert key gates
        key = {}
        for i,gate in enumerate(gates):
                # select random key value
                key[f'key_{i}'] = choice([True,False])

                # select a decoy that cannot close a combinational loop; the
                # cone is recomputed on the current cl so earlier muxes count
                forbidden = _fanout_cone(cl,gate)
                forbidden.add(gate)
                decoys = [n for n in candidates if n not in forbidden]
                if not decoys:
                        raise ValueError(f'no loop-free decoy available for {gate}')
                decoyGate = choice(decoys)

                # create and connect mux
                fanout = cl.fanout(gate)
                cl.disconnect(gate,fanout)
                cl.extend(m.relabel({n:f'mux_{i}_{n}' for n in m.nodes()}))
                cl.connect(f'mux_{i}_out',fanout)
                cl.add(f'key_{i}','input',fanout=f'mux_{i}_sel_0')
                if key[f'key_{i}']:
                        cl.connect(gate,f'mux_{i}_in_1')
                        cl.connect(decoyGate,f'mux_{i}_in_0')
                else:
                        cl.connect(gate,f'mux_{i}_in_0')
                        cl.connect(decoyGate,f'mux_{i}_in_1')

        return cl, key


def lutLock(c,n,w):
        """Replace ``n`` random gates of ``c`` with key-programmed ``w``-input LUTs.

        Each chosen gate (a non-IO node with at most ``w`` fanins) is replaced
        by a ``2**w``:1 mux ``lut_<i>_*`` used as a LUT.

        - The gate's fanins drive the mux selects ``lut_<i>_sel_<j>``.
        - When the gate has fewer than ``w`` fanins, the spare selects are
          driven by other eligible gates (padding). The programmed table
          ignores them.
        - Data input ``lut_<i>_in_<j>`` is driven by a new key input
          ``key_<i*2**w+j>``. Its correct value is the gate's output, computed
          with ``sat`` on the original circuit, when select ``sel_b`` carries
          bit ``b`` of ``j``. This assumes ``mux`` uses the same select
          ordering.
        - Finally the gate is removed and the mux output takes over its name.

        Padding is never taken from the gate's transitive fanout in the
        circuit as locked so far. Such padding would route the LUT's own
        output back into its select and form a combinational loop.

        Args:
                c: circuit to lock; it is copied, not modified.
                n: number of gates to replace.
                w: LUT input width.

        Returns:
                ``(cl, key)``: the locked circuit and a dict mapping each key
                input name to its correct boolean value.

        Raises:
                ValueError: (from ``random.sample``) if there are fewer than
                        ``n`` eligible gates or not enough padding candidates.
        """
        from itertools import product

        def _fanout_cone(circuit,node):
                # every node reachable from ``node`` along fanout edges
                cone = set()
                frontier = [node]
                while frontier:
                        for f in circuit.fanout(frontier.pop()):
                                if f not in cone:
                                        cone.add(f)
                                        frontier.append(f)
                return cone

        # create copy to lock
        cl = c.copy()

        # parse mux
        m = mux(2**w).stripIO()

        # randomly select gates; random.sample() requires a sequence (sets
        # are rejected since Python 3.11), sorting keeps seeded runs stable
        potentialGates = set(g for g in cl.nodes()-cl.io() if len(c.fanin(g))<=w)
        gates = sample(sorted(potentialGates),n)
        potentialGates -= set(gates)

        # insert key gates
        key = {}
        for i,gate in enumerate(gates):

                fanout = cl.fanout(gate)
                fanin = list(cl.fanin(gate))
                # padding must not depend on this gate, or the LUT would feed itself
                paddingPool = potentialGates-cl.fanin(gate)-_fanout_cone(cl,gate)
                padding = sample(sorted(paddingPool),w-len(fanin))
                sels = fanin+padding

                # create and connect LUT
                cl.extend(m.relabel({node:f'lut_{i}_{node}' for node in m.nodes()}))
                cl.connect(f'lut_{i}_out',fanout)

                # connect sel
                for j,f in enumerate(sels):
                        cl.connect(f,f'lut_{i}_sel_{j}')

                # connect keys: vs[::-1] pairs sel_0 with the fastest-varying
                # bit so entry j matches select pattern j; padding is left
                # unconstrained because the gate does not depend on it
                for j,vs in enumerate(product([False,True],repeat=len(sels))):
                        assumptions = {s:v for s,v in zip(sels,vs[::-1]) if s in fanin}
                        cl.add(f'key_{i*2**w+j}','input',fanout=f'lut_{i}_in_{j}')
                        key[f'key_{i*2**w+j}'] = sat(c,assumptions)[gate]

                # delete gate
                cl.remove(gate)
                cl = cl.relabel({f'lut_{i}_out':gate})

        return cl, key


def sfllHD(c,w,hd):
        """Lock one output of ``c`` with SFLL-HD (Hamming-distance) logic.

        Picks a random output with at least ``w`` startpoints, and ``w`` random
        startpoints ``inp_i`` of it. Two popcount units are added:

        - ``flip_pop`` counts ``inp_i != hardcoded_key_i``, where the
          constants hold the generated key.
        - ``restore_pop`` counts ``inp_i != key_i``, where ``key_i`` are new
          key inputs.

        Each count is compared bit-for-bit with the constant ``hd`` using
        XNORs ANDed into ``flip_out``/``restore_out``. Each therefore fires
        exactly when its Hamming distance equals ``hd``. Both are XORed into
        the chosen output, so under the correct key they cancel.

        Args:
                c: circuit to lock; it is copied, not modified.
                w: number of protected inputs / key bits.
                hd: target Hamming distance, ``0 <= hd <= w``.

        Returns:
                ``(cl, key)``, or ``None`` (after printing a message) when no
                output has at least ``w`` startpoints.

        Raises:
                ValueError: if ``hd`` is outside ``[0, w]``.
        """
        if not 0<=hd<=w:
                raise ValueError(f'hd must be in [0, {w}], got {hd}')

        # create copy to lock
        cl = c.copy()

        # parse popcount; lcomp is the width of its count output
        p = popcount(w)
        lcomp = len(p.outputs())
        p = p.stripIO()

        # find output with large enough fanin
        potential_outs = [o for o in cl.outputs() if len(cl.startpoints(o))>=w]
        if not potential_outs:
                print('input width too small')
                return None
        out = choice(potential_outs)
        out_driver = cl.fanin(out).pop()

        # create key
        key = {f'key_{i}':choice([True,False]) for i in range(w)}

        # instantiate and connect hd circuits
        cl.extend(p.relabel({n:f'flip_pop_{n}' for n in p.nodes()}))
        cl.extend(p.relabel({n:f'restore_pop_{n}' for n in p.nodes()}))

        # connect inputs (random.sample() requires a sequence, not a set)
        for i,inp in enumerate(sample(sorted(cl.startpoints(out)),w)):
                cl.add(f'key_{i}','input')
                cl.add(f'hardcoded_key_{i}','1' if key[f'key_{i}'] else '0')
                cl.add(f'restore_xor_{i}','xor',fanin=[f'key_{i}',inp])
                cl.add(f'flip_xor_{i}','xor',fanin=[f'hardcoded_key_{i}',inp])
                cl.connect(f'flip_xor_{i}',f'flip_pop_in_{i}')
                cl.connect(f'restore_xor_{i}',f'restore_pop_in_{i}')

        # connect outputs: XNOR each count bit with the matching bit of hd
        # (LSB first -- assumes popcount's out_0 is its LSB) so the AND fires
        # on equality; plain XOR would fire on the bitwise complement of hd,
        # which for e.g. hd=0 can exceed w and then never fires at all
        cl.add('flip_out','and')
        cl.add('restore_out','and')
        for i,v in enumerate(format(hd,f'0{lcomp}b')[::-1]):
                cl.add(f'hd_{i}',v)
                cl.add(f'restore_out_xor_{i}','xnor',fanin=[f'hd_{i}',f'restore_pop_out_{i}'],fanout='restore_out')
                cl.add(f'flip_out_xor_{i}','xnor',fanin=[f'hd_{i}',f'flip_pop_out_{i}'],fanout='flip_out')

        # flip output
        cl.disconnect(out_driver,out)
        cl.add('out_xor','xor',fanin=['restore_out','flip_out',out_driver],fanout=out)

        return cl, key


def ttLock(c,w):
        """Lock one output of ``c`` with TTLock-style single-pattern logic.

        Picks a random output with at least ``w`` startpoints, and ``w`` random
        startpoints ``inp_i`` of it.

        - ``flip_out`` ANDs ``xor(hardcoded_key_i, inp_i)``, where the
          constants hold the generated key.
        - ``restore_out`` ANDs ``xor(key_i, inp_i)``, where ``key_i`` are new
          key inputs.

        Each comparator fires when the inputs equal the bitwise complement of
        its key. Both are XORed into the chosen output, so under the correct
        key they cancel and the circuit's function is unchanged.

        Args:
                c: circuit to lock; it is copied, not modified.
                w: number of protected inputs / key bits.

        Returns:
                ``(cl, key)``, or ``None`` (after printing a message) when no
                output has at least ``w`` startpoints.

        Raises:
                ValueError: (from ``random.sample``) if ``w`` is negative.
        """
        # create copy to lock
        cl = c.copy()

        # find output with large enough fanin
        potential_outs = [o for o in cl.outputs() if len(cl.startpoints(o))>=w]
        if not potential_outs:
                print('input width too small')
                return None
        out = choice(potential_outs)
        out_driver = cl.fanin(out).pop()

        # create key
        key = {f'key_{i}':choice([True,False]) for i in range(w)}

        # connect comparators (random.sample() requires a sequence, not a set)
        cl.add('flip_out','and')
        cl.add('restore_out','and')
        for i,inp in enumerate(sample(sorted(cl.startpoints(out)),w)):
                cl.add(f'key_{i}','input')
                cl.add(f'hardcoded_key_{i}','1' if key[f'key_{i}'] else '0')
                cl.add(f'restore_xor_{i}','xor',fanin=[f'key_{i}',inp],fanout='restore_out')
                cl.add(f'flip_xor_{i}','xor',fanin=[f'hardcoded_key_{i}',inp],fanout='flip_out')

        # flip output
        cl.disconnect(out_driver,out)
        cl.add('out_xor','xor',fanin=['restore_out','flip_out',out_driver],fanout=out)

        return cl,key


def sfllFlex(c,w,n):
        """Lock one output of ``c`` with SFLL-Flex style logic over ``n`` key cubes.

        Picks a random output with at least ``w`` startpoints, and ``w`` random
        startpoints ``inp_i`` of it. Key bits ``key_{i+j*w}`` (``i < w``) form
        cube ``j``. For every cube two ``w``-input comparators are built:

        - ``flip_and_<j>`` compares the inputs with hard-coded copies of the
          generated key bits.
        - ``restore_and_<j>`` compares them with the key inputs.

        Each comparator ANDs ``xor(key bit, inp_i)`` terms, so it fires when
        the inputs equal the bitwise complement of that cube. The cubes are
        ORed into ``flip_out``/``restore_out``, and both are XORed into the
        chosen output, so under the correct key they cancel.

        Args:
                c: circuit to lock; it is copied, not modified.
                w: number of protected inputs (bits per cube).
                n: number of cubes.

        Returns:
                ``(cl, key)``, or ``None`` (after printing a message) when no
                output has at least ``w`` startpoints.
        """
        # create copy to lock
        cl = c.copy()

        # find output with large enough fanin
        potential_outs = [o for o in cl.outputs() if len(cl.startpoints(o))>=w]
        if not potential_outs:
                print('input width too small')
                return None
        out = choice(potential_outs)
        out_driver = cl.fanin(out).pop()

        # create key
        key = {f'key_{i}':choice([True,False]) for i in range(w*n)}

        # connect comparators: one AND per cube, all ORed together
        cl.add('flip_out','or')
        cl.add('restore_out','or')

        for j in range(n):
                cl.add(f'flip_and_{j}','and',fanout='flip_out')
                cl.add(f'restore_and_{j}','and',fanout='restore_out')

        # random.sample() requires a sequence; sets are rejected since 3.11
        for i,inp in enumerate(sample(sorted(cl.startpoints(out)),w)):
                for j in range(n):
                        key_name = f'key_{i+j*w}'
                        cl.add(key_name,'input')
                        cl.add(f'hardcoded_key_{i}_{j}','1' if key[key_name] else '0')
                        cl.add(f'restore_xor_{i}_{j}','xor',fanin=[key_name,inp],fanout=f'restore_and_{j}')
                        cl.add(f'flip_xor_{i}_{j}','xor',fanin=[f'hardcoded_key_{i}_{j}',inp],fanout=f'flip_and_{j}')

        # flip output
        cl.disconnect(out_driver,out)
        cl.add('out_xor','xor',fanin=['restore_out','flip_out',out_driver],fanout=out)

        return cl,key

def ttLockSen(params):
        """Unimplemented stub (the name suggests a TTLock variant); always returns ``None``.

        ``params`` is accepted but ignored.
        """


def switch():
        """Build a keyed 2x2 switch box as a standalone circuit named ``'switch'``.

        Two 2:1 muxes ``m0``/``m1`` share the select ``key_0``.

        - ``m0`` receives ``in_0``/``in_1`` on its data inputs ``in_0``/``in_1``.
        - ``m1`` receives them swapped.

        So ``key_0`` chooses between straight and crossed routing. ``out_0``
        and ``out_1`` XOR the mux outputs with ``key_1`` and ``key_2``, which
        optionally invert each output.

        Returns:
                The new circuit, with data ports ``in_0``, ``in_1``, ``out_0``,
                ``out_1`` and key inputs ``key_0``, ``key_1``, ``key_2``.
        """
        base = mux(2).stripIO()
        sw = Circuit(name='switch')

        # two 2:1 mux lanes, prefixed m0_ and m1_
        sw.extend(base.relabel({node:f'm0_{node}' for node in base.nodes()}))
        sw.extend(base.relabel({node:f'm1_{node}' for node in base.nodes()}))

        # data inputs: straight into m0, crossed into m1
        sw.add('in_0','buf',fanout=['m0_in_0','m1_in_1'])
        sw.add('in_1','buf',fanout=['m0_in_1','m1_in_0'])

        # lane outputs, each XORed with its own key bit below
        sw.add('out_0','xor',fanin='m0_out')
        sw.add('out_1','xor',fanin='m1_out')

        # key_0 drives both selects; key_1/key_2 invert out_0/out_1
        sw.add('key_0','input',fanout=['m0_sel_0','m1_sel_0'])
        sw.add('key_1','input',fanout='out_0')
        sw.add('key_2','input',fanout='out_1')
        return sw

def connectBanyan(cl,swb_ins,swb_outs,bw):
        """Wire switch-box ports in ``cl`` into a symmetric multistage network.

        ``swb_ins``/``swb_outs`` are flat port-name lists indexed stage-major:
        port ``p`` of switch ``j`` in stage ``s`` sits at index ``s*bw+2*j+p``.
        This is the layout ``fullLock`` builds. There are
        ``I = 2*clog2(bw)-2`` stages of ``J = bw/2`` switches each.

        Front half: for ``i`` in ``range(clog2(J))``, the outputs of stage
        ``i`` feed the inputs of stage ``i+1``. Each switch gets a "straight"
        link (same switch, same port) and a "cross" link to a switch
        ``r/2`` positions away, where ``r = J/2**i``.

        Back half: while ``r > 2`` the same pattern is mirrored from the back,
        with stage ``I-2-i`` feeding stage ``I-1-i``. For power-of-two ``bw``
        the two halves meet at stage ``clog2(J)``.

        Args:
                cl: circuit already containing the switch boxes; modified in
                        place through ``cl.connect``.
                swb_ins: switch input port names, laid out as above.
                swb_outs: switch output port names, laid out as above.
                bw: network width in ports; presumably a power of two -- TODO
                        confirm with callers.
        """
        # I: number of switch stages; J: switches per stage
        I = int(2*clog2(bw)-2)
        J = int(bw/2)
        # front half: link stage i to stage i+1
        for i in range(clog2(J)):
                # block size for this stage; cross links span r/2 switches
                r = J/(2**i)
                for j in range(J):
                        # True when switch j lies in the upper half of its r-block
                        t = (j%r)>=(r/2)
                        # straight
                        out_i = int((i*bw)+(2*j)+t)
                        in_i = int((i*bw+bw)+(2*j)+t)
                        cl.connect(swb_outs[out_i],swb_ins[in_i])

                        # cross
                        out_i = int((i*bw)+(2*j)+(1-t)+((r-1)*((1-t)*2-1)))
                        in_i = int((i*bw+bw)+(2*j)+(1-t))
                        cl.connect(swb_outs[out_i],swb_ins[in_i])

                        # mirrored link at the back: (I*J*2)-(k*bw) is the
                        # start of stage I-k, so this links stage I-2-i to I-1-i
                        if r>2:
                                # straight
                                out_i = int(((I*J*2)-((2+i)*bw))+(2*j)+t)
                                in_i = int(((I*J*2)-((1+i)*bw))+(2*j)+t)
                                cl.connect(swb_outs[out_i],swb_ins[in_i])

                                # cross
                                out_i = int(((I*J*2)-((2+i)*bw))+(2*j)+(1-t)+((r-1)*((1-t)*2-1)))
                                in_i = int(((I*J*2)-((1+i)*bw))+(2*j)+(1-t))
                                cl.connect(swb_outs[out_i],swb_ins[in_i])


def fullLock(c,bw,lw):
        """Lock ``c`` with LUTs whose selects are routed through a keyed switch network.

        Steps:

        1. ``lutLock`` replaces ``int(bw/lw)`` gates with ``lw``-input LUTs.
        2. ``I*J`` copies of ``switch()`` are added as ``swb_<i>_*``, with
           ``I = 2*clog2(bw)-2`` stages of ``J = bw/2`` switches. They are
           wired by ``connectBanyan``. The first ``bw`` switch input ports are
           the network inputs; the last ``bw`` switch output ports are its
           outputs.
        3. Three random key bits per switch are added to the LUT key dict.
        4. Under that key, each network input's route is found by SAT probing.
           All network inputs False is the baseline. Each input is then set
           True alone, and the first network output whose value changes is
           taken as its destination. That output's new value gives the
           route's polarity (True = non-inverting).
        5. Each LUT select ``lut_<i>_sel_<j>`` is re-driven from the output
           that network input ``i*lw+j`` reaches. That input is driven by the
           select's original driver, through a new NOT gate when the route
           inverts. This is intended to give each select its original signal
           under the returned key.

        Args:
                c: circuit to lock (``lutLock`` copies it).
                bw: network width in ports; presumably a power of two -- TODO
                        confirm against ``connectBanyan``.
                lw: LUT width; only the first ``int(bw/lw)*lw`` network inputs
                        are driven.

        Returns:
                ``(cl, key)``: the locked circuit and the combined LUT and
                switch key.

        Raises:
                KeyError: if a network input used by a LUT had no output
                        respond to it during probing.
        """
        # lock with luts
        cl,key = lutLock(c,int(bw/lw),lw)

        # generate switch
        s = switch()

        # generate banyan
        I = int(2*clog2(bw)-2)
        J = int(bw/2)

        # add switches
        for i in range(I*J):
                cl.extend(s.relabel({n:f'swb_{i}_{n}' for n in s}))

        # make connections
        swb_ins = [f'swb_{i//2}_in_{i%2}' for i in range(I*J*2)]
        swb_outs = [f'swb_{i//2}_out_{i%2}' for i in range(I*J*2)]
        connectBanyan(cl,swb_ins,swb_outs,bw)

        # get banyan io
        net_ins = swb_ins[:bw]
        net_outs = swb_outs[-bw:]

        # generate key (names match switch()'s key_0..key_2 after relabel)
        for i in range(I*J):
                for j in range(3):
                        key[f'swb_{i}_key_{j}'] = choice([True,False])

        # get banyan mapping: one-hot probe each network input against the
        # all-False baseline to find its destination output and polarity
        mapping = {}
        polarity = {}
        orig_result = sat(cl,{**{n:False for n in net_ins},**key})
        for net_in in net_ins:
                result = sat(cl,{**{n:False if n!=net_in else True for n in net_ins},**key})
                for net_out in net_outs:
                        if result[net_out]!=orig_result[net_out]:
                                mapping[net_in] = net_out
                                polarity[net_in] = result[net_out]
                                break

        # connect banyan io to luts: the original select driver now enters the
        # network and the mapped network output feeds the select instead
        for i in range(int(bw/lw)):
                for j in range(lw):
                        driver = cl.fanin(f'lut_{i}_sel_{j}').pop()
                        cl.disconnect(driver,f'lut_{i}_sel_{j}')
                        net_in = net_ins[i*lw+j]
                        cl.connect(mapping[net_in],f'lut_{i}_sel_{j}')
                        # undo an inverting route; presumably cl.add returns the new node's name
                        if not polarity[net_in]:
                                driver = cl.add(f'not_{net_in}','not',fanin=driver)
                        cl.connect(driver,net_in)

        return cl,key


#def lebl(c,bw,ng):
#       # create copy to lock
#       cl = c.copy()
#
#       # generate switch and mux
#       s = switch()
#       m = mux(4)
#
#       # generate banyan
#       I = int(2*clog2(bw)-2)
#       J = int(bw/2)
#
#       # add switches and muxes
#       for i in range(I*J):
#               cl.extend(s.relabel({n:f'swb_{i}_{n}' for n in s}))
#               cl.extend(m.relabel({n:f'mux_0_{i}_{n}' for n in m}))
#               cl.extend(m.relabel({n:f'mux_1_{i}_{n}' for n in m}))
#
#       # make connections
#       swb_ins = [f'swb_{i//2}_in_{i%2}' for i in range(I*J*2)]
#       swb_outs = [f'mux_{i%2}_{i//2}_out' for i in range(I*J*2)]
#       connectBanyan(cl,swb_ins,swb_outs,bw)
#
#       # get banyan io
#       net_ins = swb_ins[:bw]
#       net_outs = swb_outs[-bw:]
#
#       # generate connections between swb outs
#       swb_out_fin = {o:set() for o in swb_outs}
#       swb_out_fout = {o:set() for o in swb_outs}
#       for o in swb_outs:
#               fo_node = cl.fanout(o).pop()
#               swb_i = fo_node.split('_')[1]
#               swb_out_fin[f'mux_0_{swb_i}_out'].add(o)
#               swb_out_fin[f'mux_1_{swb_i}_out'].add(o)
#               swb_out_fout[o].add(f'mux_0_{swb_i}_out')
#               swb_out_fout[o].add(f'mux_1_{swb_i}_out')
#
#       # find a mapping of circuit onto banyan
#       net_map = IDPool()
#       for bn in swb_outs+net_ins:
#               for cn in c:
#                       net_map.id(f'm_{bn}_{cn}')
#
#       # mapping implications
#       clauses = []
#       for bn in swb_outs:
#               # fanin
#               if swb_out_fin[bn]:
#                       for cn in c:
#                               if c.fanin(cn):
#                                       for fcn in c.fanin(cn):
#                                               clause = [-net_map.id(f'm_{bn}_{cn}')]
#                                               clause += [net_map.id(f'm_{fbn}_{fcn}') for fbn in swb_out_fin[bn]]
#                                               clause += [net_map.id(f'm_{fbn}_{cn}') for fbn in swb_out_fin[bn]]
#                                               clauses.append(clause)
#                               else:
#                                       clause = [-net_map.id(f'm_{bn}_{cn}')]
#                                       clause += [net_map.id(f'm_{fbn}_{cn}') for fbn in swb_out_fin[bn]]
#                                       clauses.append(clause)
#
#               # fanout
#               if swb_out_fout[bn]:
#                       for cn in c:
#                               clause = [-net_map.id(f'm_{bn}_{cn}')]
#                               clause += [net_map.id(f'm_{fbn}_{cn}') for fbn in swb_out_fout[bn]]
#                               for fcn in c.fanout(cn):
#                                       clause += [net_map.id(f'm_{fbn}_{fcn}') for fbn in swb_out_fout[bn]]
#                               clauses.append(clause)
#
#       # input implications
#       for i,bn in enumerate(net_ins):
#               for cn in c:
#                       clause = [-net_map.id(f'm_{bn}_{cn}')]
#                       clause += [net_map.id(f'm_{fbn}_{cn}') for fbn in swb_outs[2*(i//2):2*(i//2)+2]]
#                       for fcn in c.fanout(cn):
#                               clause += [net_map.id(f'm_{fbn}_{fcn}') for fbn in swb_outs[2*(i//2):2*(i//2)+2]]
#                       clauses.append(clause)
#
#       # no feed through
#       for cn in c:
#               net_map.id(f'INPUT_OR_{cn}')
#               net_map.id(f'OUTPUT_OR_{cn}')
#               clauses.append([-net_map.id(f'INPUT_OR_{cn}')]+[net_map.id(f'm_{bn}_{cn}') for bn in net_ins])
#               clauses.append([-net_map.id(f'OUTPUT_OR_{cn}')]+[net_map.id(f'm_{bn}_{cn}') for bn in net_outs])
#               for bn in net_ins:
#                       clauses.append([net_map.id(f'INPUT_OR_{cn}'),-net_map.id(f'm_{bn}_{cn}')])
#               for bn in net_outs:
#                       clauses.append([net_map.id(f'OUTPUT_OR_{cn}'),-net_map.id(f'm_{bn}_{cn}')])
#               clauses.append([-net_map.id(f'OUTPUT_OR_{cn}'),-net_map.id(f'INPUT_OR_{cn}')])
#
#       # at least ngates
#       for bn in swb_outs+net_ins:
#               net_map.id(f'NGATES_OR_{bn}')
#               clauses.append([-net_map.id(f'NGATES_OR_{bn}')]+[net_map.id(f'm_{bn}_{cn}') for cn in c])
#               for cn in c:
#                       clauses.append([net_map.id(f'NGATES_OR_{bn}'),-net_map.id(f'm_{bn}_{cn}')])
#       clauses += CardEnc.atleast(bound=ng,lits=[net_map.id(f'NGATES_OR_{bn}') for bn in swb_outs+net_ins],vpool=net_map).clauses
#
#       # at most one mapping per out
#       for bn in swb_outs+net_ins:
#               clauses += CardEnc.atmost(lits=[net_map.id(f'm_{bn}_{cn}') for cn in c],vpool=net_map).clauses
#
#       # limit output usage
#       for cn in c:
#               lits = [net_map.id(f'm_{bn}_{cn}') for bn in net_outs]
#               bound = len(list(c.fanout(cn)))
#               if len(lits)<bound: continue
#               clauses += CardEnc.atmost(bound=bound,lits=lits,vpool=net_map).clauses
#
#       # remove outputs from middle of net
#       for bn in swb_outs[:-bw]+net_ins:
#               for cn in c:
#                       if c[cn]['output']:
#                               clauses += [[-net_map.id(f'm_{bn}_{cn}')]]
#       # solve
#       solver = Cadical()
#       solver.append_formula(clauses)
#       #found = solver.solve(assumptions=ass)
#       found = solver.solve()
#       if not found:
#               print(f'no config for width: {bw}')
#               core = solver.get_core()
#               print(core)
#               code.interact(local=locals())
#       else:
#               print('found')
#       model = solver.get_model()
#
#       # get each banyan outs gate
#       mapping = {}
#       for bn in swb_outs+net_ins:
#               selected_gates = [cn for cn in c if model[net_map.id(f'm_{bn}_{cn}')-1]>0]
#               if len(selected_gates)>1:
#                       print(f'multiple gates mapped to: {bn}')
#                       code.interact(local=locals())
#               mapping[bn] = selected_gates[0] if selected_gates else None
#
#       # get outputs in fanout
#       net_fanout_outs = set()
#       for bn in net_outs:
#               if mapping[bn]:
#                       net_fanout_outs |= set(n for n in nx.descendants(c,mapping[bn]) if c[n]['output'])
#
#       # get potential fanouts
#       potential_decoy_fanouts = set()
#       for n in c:
#               node_fanout_outs = set(d for d in nx.descendants(c,n) if c[n]['output'])
#               if all(f in net_fanout_outs for f in node_fanout_outs) and c[n]['gate']!='input':
#                       potential_decoy_fanouts.add(n)
#       potential_decoy_fanouts -= set(mapping.values())
#
#       # get inputs in fanin
#       net_fanin_ins = set()
#       for bn in net_ins:
#               if mapping[bn]:
#                       net_fanin_ins |= set(n for n in nx.ancestors(c,mapping[bn]) if c[n]['gate']=='input')
#
#       # get potential fanins
#       potential_decoy_fanins = set()
#       for n in c:
#               node_fanin_ins = set(d for d in nx.ancestors(c,n) if c[n]['gate']=='input')
#               if all(f in net_fanin_ins for f in node_fanin_ins) and not c[n]['output']:
#                       potential_decoy_fanins.add(n)
#       potential_decoy_fanins -= set(mapping.values())
#
#       # connect net inputs
#       c.add_edges_from((mapping[bn],bn) for bn in net_ins if mapping[bn])
#       c.add_edges_from((sample(potential_decoy_fanins,1)[0],bn) for bn in net_ins if not mapping[bn])
#       mapping.update({list(c.fanin(bn))[0]:list(c.fanin(bn))[0] for bn in net_ins})
#
#       # connect switch boxes
#       for i,bn in enumerate(swb_outs):
#               # get keys
#               if key_values[f'mux_0_{i//2}_key_0']==1 and key_values[f'mux_0_{i//2}_key_1']==1:
#                       key = 3
#               elif key_values[f'mux_0_{i//2}_key_0']==-1 and key_values[f'mux_0_{i//2}_key_1']==1:
#                       key = 2
#               elif key_values[f'mux_0_{i//2}_key_0']==1 and key_values[f'mux_0_{i//2}_key_1']==-1:
#                       key = 1
#               elif key_values[f'mux_0_{i//2}_key_0']==-1 and key_values[f'mux_0_{i//2}_key_1']==-1:
#                       key = 0
#               switch_key = 1 if key_values[f'switch_{i//2}_key']==1 else 0
#
#               # connect inner nodes
#               mux_gate_types = set()
#
#               # constant output, hookup to a node that is already in the affected outputs fanin, not in others
#               if not mapping[bn] and bn in net_outs:
#                       decoy_fanout_gate = sample(potential_decoy_fanouts,1)[0]
#                       c.add_edge(bn,decoy_fanout_gate)
#                       if c.nodes[decoy_fanout_gate]['gate'] in ['and','nand']:
#                               c.nodes[f'mux_{i%2}_{i//2}_in_{key}']['gate'] = '1'
#                       elif c.nodes[decoy_fanout_gate]['gate'] in ['or','nor','xor','xnor']:
#                               c.nodes[f'mux_{i%2}_{i//2}_in_{key}']['gate'] = '0'
#                       elif c.nodes[decoy_fanout_gate]['gate'] in ['buf']:
#                               if randint(0,1):
#                                       c.nodes[f'mux_{i%2}_{i//2}_in_{key}']['gate'] = '1'
#                                       c.nodes[decoy_fanout_gate]['gate'] = sample(['and','xnor'],1)[0]
#                               else:
#                                       c.nodes[f'mux_{i%2}_{i//2}_in_{key}']['gate'] = '0'
#                                       c.nodes[decoy_fanout_gate]['gate'] = sample(['or','xor'],1)[0]
#                       elif c.nodes[decoy_fanout_gate]['gate'] in ['not']:
#                               if randint(0,1):
#                                       c.nodes[f'mux_{i%2}_{i//2}_in_{key}']['gate'] = '1'
#                                       c.nodes[decoy_fanout_gate]['gate'] = sample(['nand','xor'],1)[0]
#                               else:
#                                       c.nodes[f'mux_{i%2}_{i//2}_in_{key}']['gate'] = '0'
#                                       c.nodes[decoy_fanout_gate]['gate'] = sample(['nor','xnor'],1)[0]
#                       mux_gate_types.add(c.nodes[f'mux_{i%2}_{i//2}_in_{key}']['gate'])
#
#               # feedthrough
#               elif mapping[bn] in [mapping[fbn] for fbn in swb_out_fin[bn]]:
#                       c.nodes[f'mux_{i%2}_{i//2}_in_{key}']['gate'] = 'buf'
#                       mux_gate_types.add('buf')
#                       if mapping[list(c.fanin(f'switch_{i//2}_in_0'))[0]]==mapping[bn]:
#                               c.add_edge(f'switch_{i//2}_out_{switch_key}',f'mux_{i%2}_{i//2}_in_{key}')
#                       else:
#                               c.add_edge(f'switch_{i//2}_out_{1-switch_key}',f'mux_{i%2}_{i//2}_in_{key}')
#
#               # gate
#               elif mapping[bn]:
#                       c.nodes[f'mux_{i%2}_{i//2}_in_{key}']['gate'] = c.nodes[mapping[bn]]['gate']
#                       mux_gate_types.add(c.nodes[mapping[bn]]['gate'])
#                       if mapping[list(c.fanin(f'switch_{i//2}_in_0'))[0]] in list(c.fanin(mapping[bn])):
#                               c.add_edge(f'switch_{i//2}_out_{switch_key}',f'mux_{i%2}_{i//2}_in_{key}')
#                       if mapping[list(c.fanin(f'switch_{i//2}_in_1'))[0]] in list(c.fanin(mapping[bn])):
#                               c.add_edge(f'switch_{i//2}_out_{1-switch_key}',f'mux_{i%2}_{i//2}_in_{key}')
#
#               # mapped to None, any key works
#               else:
#                       key = None
#
#               # fill out random gates
#               for j in range(4):
#                       if j != key:
#                               t = sample(set(['buf','or','nor','and','nand','not','xor','xnor','0','1'])-mux_gate_types,1)[0]
#                               mux_gate_types.add(t)
#                               c.nodes[f'mux_{i%2}_{i//2}_in_{j}']['gate'] = t
#                               if t=='not' or t=='buf':
#                                       # pick a random fanin
#                                       c.add_edge(f'switch_{i//2}_out_{randint(0,1)}',f'mux_{i%2}_{i//2}_in_{j}')
#                               elif t=='1' or t=='0':
#                                       pass
#                               else:
#                                       c.add_edge(f'switch_{i//2}_out_0',f'mux_{i%2}_{i//2}_in_{j}')
#                                       c.add_edge(f'switch_{i//2}_out_1',f'mux_{i%2}_{i//2}_in_{j}')
#
#       # connect outputs non constant outs
#       rev_mapping = {}
#       for bn in net_outs:
#               if mapping[bn]:
#                       if mapping[bn] not in rev_mapping:
#                               rev_mapping[mapping[bn]] = set()
#                       rev_mapping[mapping[bn]].add(bn)
#
#       for cn in rev_mapping.keys():
#               for fcn in c.fanout(cn):
#                       c.add_edge(sample(rev_mapping[cn],1)[0],fcn)
#
#       # delete mapped gates
#       deleted = True
#       while deleted:
#               deleted = False
#               for n in list(c.nodes):
#                       # node and all fanout are in the net
#                       if n not in mapping and n in mapping.values():
#                               if all(s not in mapping and s in mapping.values() for s in c.fanout(n)):
#                                       c.remove_node(n)
#                                       deleted = True
#                       # node in net fanout
#                       if n in [mapping[o] for o in net_outs] and n in c.nodes:
#                               c.remove_node(n)
#                               deleted = True
#
#       return cl,key

def checkLock(c,cl,key):
        """Use a miter to compare locked circuit ``cl`` with ``c`` under ``key``.

        Key input names get the ``c1_`` prefix. Presumably this is the prefix
        ``miter`` gives to ``cl``'s copy -- confirm against ``miter``.

        Returns:
                ``True`` when the miter is unsatisfiable with just the key
                applied. Otherwise, whatever ``sat`` returns with ``'sat'``
                also forced True. Presumably that is falsy when no input makes
                the circuits differ, i.e. when the key is correct -- confirm.
        """
        m = miter(c,cl)
        keyAssumptions = dict((f'c1_{k}',v) for k,v in key.items())

        if sat(m,assumptions=keyAssumptions):
                # miter satisfiable under the key: search for a differing input
                return sat(m,assumptions={'sat':True,**keyAssumptions})
        return True

Functions

def checkLock(c, cl, key)
Expand source code
def checkLock(c,cl,key):
        """Use a miter to compare locked circuit ``cl`` with ``c`` under ``key``.

        Key input names get the ``c1_`` prefix. Presumably this is the prefix
        ``miter`` gives to ``cl``'s copy -- confirm against ``miter``.

        Returns:
                ``True`` when the miter is unsatisfiable with just the key
                applied. Otherwise, whatever ``sat`` returns with ``'sat'``
                also forced True. Presumably that is falsy when no input makes
                the circuits differ, i.e. when the key is correct -- confirm.
        """
        m = miter(c,cl)
        keyAssumptions = dict((f'c1_{k}',v) for k,v in key.items())

        if sat(m,assumptions=keyAssumptions):
                # miter satisfiable under the key: search for a differing input
                return sat(m,assumptions={'sat':True,**keyAssumptions})
        return True
def connectBanyan(cl, swb_ins, swb_outs, bw)
Expand source code
def connectBanyan(cl,swb_ins,swb_outs,bw):
        """Wire switch-box ports in ``cl`` into a symmetric multistage network.

        ``swb_ins``/``swb_outs`` are flat port-name lists indexed stage-major:
        port ``p`` of switch ``j`` in stage ``s`` sits at index ``s*bw+2*j+p``.
        This is the layout ``fullLock`` builds. There are
        ``I = 2*clog2(bw)-2`` stages of ``J = bw/2`` switches each.

        Front half: for ``i`` in ``range(clog2(J))``, the outputs of stage
        ``i`` feed the inputs of stage ``i+1``. Each switch gets a "straight"
        link (same switch, same port) and a "cross" link to a switch
        ``r/2`` positions away, where ``r = J/2**i``.

        Back half: while ``r > 2`` the same pattern is mirrored from the back,
        with stage ``I-2-i`` feeding stage ``I-1-i``. For power-of-two ``bw``
        the two halves meet at stage ``clog2(J)``.

        Args:
                cl: circuit already containing the switch boxes; modified in
                        place through ``cl.connect``.
                swb_ins: switch input port names, laid out as above.
                swb_outs: switch output port names, laid out as above.
                bw: network width in ports; presumably a power of two -- TODO
                        confirm with callers.
        """
        # I: number of switch stages; J: switches per stage
        I = int(2*clog2(bw)-2)
        J = int(bw/2)
        # front half: link stage i to stage i+1
        for i in range(clog2(J)):
                # block size for this stage; cross links span r/2 switches
                r = J/(2**i)
                for j in range(J):
                        # True when switch j lies in the upper half of its r-block
                        t = (j%r)>=(r/2)
                        # straight
                        out_i = int((i*bw)+(2*j)+t)
                        in_i = int((i*bw+bw)+(2*j)+t)
                        cl.connect(swb_outs[out_i],swb_ins[in_i])

                        # cross
                        out_i = int((i*bw)+(2*j)+(1-t)+((r-1)*((1-t)*2-1)))
                        in_i = int((i*bw+bw)+(2*j)+(1-t))
                        cl.connect(swb_outs[out_i],swb_ins[in_i])

                        # mirrored link at the back: (I*J*2)-(k*bw) is the
                        # start of stage I-k, so this links stage I-2-i to I-1-i
                        if r>2:
                                # straight
                                out_i = int(((I*J*2)-((2+i)*bw))+(2*j)+t)
                                in_i = int(((I*J*2)-((1+i)*bw))+(2*j)+t)
                                cl.connect(swb_outs[out_i],swb_ins[in_i])

                                # cross
                                out_i = int(((I*J*2)-((2+i)*bw))+(2*j)+(1-t)+((r-1)*((1-t)*2-1)))
                                in_i = int(((I*J*2)-((1+i)*bw))+(2*j)+(1-t))
                                cl.connect(swb_outs[out_i],swb_ins[in_i])
def fullLock(c, bw, lw)
Expand source code
def fullLock(c,bw,lw):
        """Lock ``c`` with LUTs whose selects are routed through a keyed switch network.

        Steps:

        1. ``lutLock`` replaces ``int(bw/lw)`` gates with ``lw``-input LUTs.
        2. ``I*J`` copies of ``switch()`` are added as ``swb_<i>_*``, with
           ``I = 2*clog2(bw)-2`` stages of ``J = bw/2`` switches. They are
           wired by ``connectBanyan``. The first ``bw`` switch input ports are
           the network inputs; the last ``bw`` switch output ports are its
           outputs.
        3. Three random key bits per switch are added to the LUT key dict.
        4. Under that key, each network input's route is found by SAT probing.
           All network inputs False is the baseline. Each input is then set
           True alone, and the first network output whose value changes is
           taken as its destination. That output's new value gives the
           route's polarity (True = non-inverting).
        5. Each LUT select ``lut_<i>_sel_<j>`` is re-driven from the output
           that network input ``i*lw+j`` reaches. That input is driven by the
           select's original driver, through a new NOT gate when the route
           inverts. This is intended to give each select its original signal
           under the returned key.

        Args:
                c: circuit to lock (``lutLock`` copies it).
                bw: network width in ports; presumably a power of two -- TODO
                        confirm against ``connectBanyan``.
                lw: LUT width; only the first ``int(bw/lw)*lw`` network inputs
                        are driven.

        Returns:
                ``(cl, key)``: the locked circuit and the combined LUT and
                switch key.

        Raises:
                KeyError: if a network input used by a LUT had no output
                        respond to it during probing.
        """
        # lock with luts
        cl,key = lutLock(c,int(bw/lw),lw)

        # generate switch
        s = switch()

        # generate banyan
        I = int(2*clog2(bw)-2)
        J = int(bw/2)

        # add switches
        for i in range(I*J):
                cl.extend(s.relabel({n:f'swb_{i}_{n}' for n in s}))

        # make connections
        swb_ins = [f'swb_{i//2}_in_{i%2}' for i in range(I*J*2)]
        swb_outs = [f'swb_{i//2}_out_{i%2}' for i in range(I*J*2)]
        connectBanyan(cl,swb_ins,swb_outs,bw)

        # get banyan io
        net_ins = swb_ins[:bw]
        net_outs = swb_outs[-bw:]

        # generate key (names match switch()'s key_0..key_2 after relabel)
        for i in range(I*J):
                for j in range(3):
                        key[f'swb_{i}_key_{j}'] = choice([True,False])

        # get banyan mapping: one-hot probe each network input against the
        # all-False baseline to find its destination output and polarity
        mapping = {}
        polarity = {}
        orig_result = sat(cl,{**{n:False for n in net_ins},**key})
        for net_in in net_ins:
                result = sat(cl,{**{n:False if n!=net_in else True for n in net_ins},**key})
                for net_out in net_outs:
                        if result[net_out]!=orig_result[net_out]:
                                mapping[net_in] = net_out
                                polarity[net_in] = result[net_out]
                                break

        # connect banyan io to luts: the original select driver now enters the
        # network and the mapped network output feeds the select instead
        for i in range(int(bw/lw)):
                for j in range(lw):
                        driver = cl.fanin(f'lut_{i}_sel_{j}').pop()
                        cl.disconnect(driver,f'lut_{i}_sel_{j}')
                        net_in = net_ins[i*lw+j]
                        cl.connect(mapping[net_in],f'lut_{i}_sel_{j}')
                        # undo an inverting route; presumably cl.add returns the new node's name
                        if not polarity[net_in]:
                                driver = cl.add(f'not_{net_in}','not',fanin=driver)
                        cl.connect(driver,net_in)

        return cl,key
def lutLock(c, n, w)
Expand source code
def lutLock(c,n,w):
        """Replace ``n`` random gates of ``c`` with key-programmed ``w``-input LUTs.

        Each chosen gate (a non-IO node with at most ``w`` fanins) is replaced
        by a ``2**w``:1 mux ``lut_<i>_*`` used as a LUT.

        - The gate's fanins drive the mux selects ``lut_<i>_sel_<j>``.
        - When the gate has fewer than ``w`` fanins, the spare selects are
          driven by other eligible gates (padding). The programmed table
          ignores them.
        - Data input ``lut_<i>_in_<j>`` is driven by a new key input
          ``key_<i*2**w+j>``. Its correct value is the gate's output, computed
          with ``sat`` on the original circuit, when select ``sel_b`` carries
          bit ``b`` of ``j``. This assumes ``mux`` uses the same select
          ordering.
        - Finally the gate is removed and the mux output takes over its name.

        Padding is never taken from the gate's transitive fanout in the
        circuit as locked so far. Such padding would route the LUT's own
        output back into its select and form a combinational loop.

        Args:
                c: circuit to lock; it is copied, not modified.
                n: number of gates to replace.
                w: LUT input width.

        Returns:
                ``(cl, key)``: the locked circuit and a dict mapping each key
                input name to its correct boolean value.

        Raises:
                ValueError: (from ``random.sample``) if there are fewer than
                        ``n`` eligible gates or not enough padding candidates.
        """
        from itertools import product

        def _fanout_cone(circuit,node):
                # every node reachable from ``node`` along fanout edges
                cone = set()
                frontier = [node]
                while frontier:
                        for f in circuit.fanout(frontier.pop()):
                                if f not in cone:
                                        cone.add(f)
                                        frontier.append(f)
                return cone

        # create copy to lock
        cl = c.copy()

        # parse mux
        m = mux(2**w).stripIO()

        # randomly select gates; random.sample() requires a sequence (sets
        # are rejected since Python 3.11), sorting keeps seeded runs stable
        potentialGates = set(g for g in cl.nodes()-cl.io() if len(c.fanin(g))<=w)
        gates = sample(sorted(potentialGates),n)
        potentialGates -= set(gates)

        # insert key gates
        key = {}
        for i,gate in enumerate(gates):

                fanout = cl.fanout(gate)
                fanin = list(cl.fanin(gate))
                # padding must not depend on this gate, or the LUT would feed itself
                paddingPool = potentialGates-cl.fanin(gate)-_fanout_cone(cl,gate)
                padding = sample(sorted(paddingPool),w-len(fanin))
                sels = fanin+padding

                # create and connect LUT
                cl.extend(m.relabel({node:f'lut_{i}_{node}' for node in m.nodes()}))
                cl.connect(f'lut_{i}_out',fanout)

                # connect sel
                for j,f in enumerate(sels):
                        cl.connect(f,f'lut_{i}_sel_{j}')

                # connect keys: vs[::-1] pairs sel_0 with the fastest-varying
                # bit so entry j matches select pattern j; padding is left
                # unconstrained because the gate does not depend on it
                for j,vs in enumerate(product([False,True],repeat=len(sels))):
                        assumptions = {s:v for s,v in zip(sels,vs[::-1]) if s in fanin}
                        cl.add(f'key_{i*2**w+j}','input',fanout=f'lut_{i}_in_{j}')
                        key[f'key_{i*2**w+j}'] = sat(c,assumptions)[gate]

                # delete gate
                cl.remove(gate)
                cl = cl.relabel({f'lut_{i}_out':gate})

        return cl, key
def muxLock(c, k)
Expand source code
def muxLock(c,k):
        """Lock a copy of ``c`` by inserting ``k`` key-controlled 2:1 muxes.

        For each of ``k`` randomly chosen non-output nodes, the node's fanout is
        rerouted through a new mux ``mux_<i>_*`` whose select is a fresh
        primary input ``key_<i>``. The real node drives ``in_1`` when the key
        bit is True and ``in_0`` when it is False; a random decoy node drives
        the other data input.

        Decoys come from the original non-output nodes, but never the locked
        node itself nor any node in its transitive fanout in the circuit as
        locked so far: such a decoy would feed the mux output back into the
        mux and form a combinational loop.

        Args:
                c: circuit to lock; it is copied, not modified.
                k: number of muxes to insert.

        Returns:
                ``(cl, key)``: the locked copy and a dict mapping each
                ``key_<i>`` input name to its correct boolean value.

        Raises:
                ValueError: if ``k`` is negative or exceeds the number of
                        non-output nodes, or if no loop-free decoy exists for a
                        selected node.
        """
        def _fanout_cone(circuit,node):
                # every node reachable from ``node`` along fanout edges
                cone = set()
                frontier = [node]
                while frontier:
                        for f in circuit.fanout(frontier.pop()):
                                if f not in cone:
                                        cone.add(f)
                                        frontier.append(f)
                return cone

        # create copy to lock
        cl = c.copy()

        # get 2:1 mux
        m = mux(2).stripIO()

        # randomly select gates; random.sample() requires a sequence (sets
        # are rejected since Python 3.11), sorting keeps seeded runs stable
        candidates = sorted(cl.nodes()-cl.outputs())
        gates = sample(candidates,k)

        # insert key gates
        key = {}
        for i,gate in enumerate(gates):
                # select random key value
                key[f'key_{i}'] = choice([True,False])

                # select a decoy that cannot close a combinational loop; the
                # cone is recomputed on the current cl so earlier muxes count
                forbidden = _fanout_cone(cl,gate)
                forbidden.add(gate)
                decoys = [n for n in candidates if n not in forbidden]
                if not decoys:
                        raise ValueError(f'no loop-free decoy available for {gate}')
                decoyGate = choice(decoys)

                # create and connect mux
                fanout = cl.fanout(gate)
                cl.disconnect(gate,fanout)
                cl.extend(m.relabel({n:f'mux_{i}_{n}' for n in m.nodes()}))
                cl.connect(f'mux_{i}_out',fanout)
                cl.add(f'key_{i}','input',fanout=f'mux_{i}_sel_0')
                if key[f'key_{i}']:
                        cl.connect(gate,f'mux_{i}_in_1')
                        cl.connect(decoyGate,f'mux_{i}_in_0')
                else:
                        cl.connect(gate,f'mux_{i}_in_0')
                        cl.connect(decoyGate,f'mux_{i}_in_1')

        return cl, key
def sfllFlex(c, w, n)
Expand source code
def sfllFlex(c,w,n):
        """Lock a copy of ``c`` with an SFLL-flex style flip/restore unit.

        A random primary output ``out`` with at least ``w`` startpoints is
        chosen, and ``w`` of its startpoints are sampled. There are ``n``
        comparator terms. For each term ``j``:

        * ``flip_and_{j}`` ANDs ``xor(hardcoded_key_{i}_{j}, inp_i)`` over the
          sampled inputs.
        * ``restore_and_{j}`` ANDs ``xor(key_{i+j*w}, inp_i)`` using new key inputs.

        ``flip_out`` and ``restore_out`` OR their terms. The original driver of
        ``out`` is replaced by ``xor(restore_out, flip_out, out_driver)``, so when
        the key inputs equal the hard-coded constants the two terms cancel.

        Parameters
        ----------
        c : Circuit
                Circuit to lock; it is copied, not modified.
        w : int
                Number of output startpoints compared per term.
        n : int
                Number of comparator terms.

        Returns
        -------
        (Circuit, dict) or None
                The locked copy and ``{'key_{i}': bool}`` (``w*n`` bits). Returns
                ``None`` (after printing a message) if no output has at least
                ``w`` startpoints.
        """
        # create copy to lock
        cl = c.copy()

        # find output with large enough fanin. Sorting gives random a
        # deterministic sequence (random.sample rejects sets since Python 3.11).
        potential_outs = [o for o in sorted(cl.outputs()) if len(cl.startpoints(o))>=w]
        if not potential_outs:
                print('input width too small')
                return None
        out = choice(potential_outs)
        out_driver = cl.fanin(out).pop()

        # create key
        key = {f'key_{i}':choice([True,False]) for i in range(w*n)}

        # connect comparators
        cl.add('flip_out','or')
        cl.add('restore_out','or')

        for j in range(n):
                cl.add(f'flip_and_{j}','and',fanout='flip_out')
                cl.add(f'restore_and_{j}','and',fanout='restore_out')

        for i,inp in enumerate(sample(sorted(cl.startpoints(out)),w)):
                for j in range(n):
                        cl.add(f'key_{i+j*w}','input')
                        cl.add(f'hardcoded_key_{i}_{j}','1' if key[f'key_{i+j*w}'] else '0')
                        cl.add(f'restore_xor_{i}_{j}','xor',fanin=[f'key_{i+j*w}',inp],fanout=f'restore_and_{j}')
                        cl.add(f'flip_xor_{i}_{j}','xor',fanin=[f'hardcoded_key_{i}_{j}',inp],fanout=f'flip_and_{j}')

        # flip output: restore_out and flip_out cancel under the correct key
        cl.disconnect(out_driver,out)
        cl.add('out_xor','xor',fanin=['restore_out','flip_out',out_driver],fanout=out)

        return cl,key
def sfllHD(c, w, hd)
Expand source code
def sfllHD(c,w,hd):
        """Lock a copy of ``c`` with an SFLL-HD style flip/restore unit.

        A random primary output ``out`` with at least ``w`` startpoints is
        chosen, and ``w`` of its startpoints are sampled. Two ``popcount(w)``
        instances are added:

        * ``flip_pop_`` counts ``xor(hardcoded_key_i, inp_i)``, i.e. the Hamming
          distance to a hard-coded key.
        * ``restore_pop_`` counts ``xor(key_i, inp_i)`` using new key inputs.

        ``flip_out`` / ``restore_out`` are 1 exactly when the respective count
        equals ``hd``. The driver of ``out`` is replaced by
        ``xor(restore_out, flip_out, out_driver)``, so the two cancel under the
        correct key.

        Parameters
        ----------
        c : Circuit
                Circuit to lock; it is copied, not modified.
        w : int
                Number of output startpoints compared.
        hd : int
                Hamming distance that triggers the flip; presumably expected in
                ``0..w`` (not validated here).

        Returns
        -------
        (Circuit, dict) or None
                The locked copy and ``{'key_{i}': bool}`` (``w`` bits). Returns
                ``None`` (after printing a message) if no output has at least
                ``w`` startpoints.
        """
        # create copy to lock
        cl = c.copy()

        # parse popcount
        p = popcount(w)
        # NOTE(review): lcomp is unused. The comparator below encodes hd with
        # cg.clog2(w)+1 bits and assumes that equals popcount(w)'s output
        # width. Confirm this, or use lcomp as the width.
        lcomp = len(p.outputs())
        p = p.stripIO()

        # find output with large enough fanin. Sorting gives random a
        # deterministic sequence (random.sample rejects sets since Python 3.11).
        potential_outs = [o for o in sorted(cl.outputs()) if len(cl.startpoints(o))>=w]
        if not potential_outs:
                print('input width too small')
                return None
        out = choice(potential_outs)
        out_driver = cl.fanin(out).pop()

        # create key
        key = {f'key_{i}':choice([True,False]) for i in range(w)}

        # instantiate and connect hd circuits
        cl.extend(p.relabel({n:f'flip_pop_{n}' for n in p.nodes()}))
        cl.extend(p.relabel({n:f'restore_pop_{n}' for n in p.nodes()}))

        # connect inputs: each popcount sees the bitwise difference between the
        # sampled startpoints and its key, i.e. it computes the Hamming distance
        for i,inp in enumerate(sample(sorted(cl.startpoints(out)),w)):
                cl.add(f'key_{i}','input')
                cl.add(f'hardcoded_key_{i}','1' if key[f'key_{i}'] else '0')
                cl.add(f'restore_xor_{i}','xor',fanin=[f'key_{i}',inp])
                cl.add(f'flip_xor_{i}','xor',fanin=[f'hardcoded_key_{i}',inp])
                cl.connect(f'flip_xor_{i}',f'flip_pop_in_{i}')
                cl.connect(f'restore_xor_{i}',f'restore_pop_in_{i}')

        # connect outputs: an AND of per-bit XNORs (equality) is 1 exactly when
        # the popcount equals hd. An AND of XORs would instead fire only when
        # the popcount equals the bitwise complement of hd, which is usually
        # unreachable and left the lock ineffective. Node names are kept
        # unchanged for compatibility.
        cl.add('flip_out','and')
        cl.add('restore_out','and')
        for i,v in enumerate(format(hd, f'0{cg.clog2(w)+1}b')[::-1]):
                # hd bits, least-significant first (assumes popcount out_0 is the LSB)
                cl.add(f'hd_{i}',v)
                cl.add(f'restore_out_xor_{i}','xnor',fanin=[f'hd_{i}',f'restore_pop_out_{i}'],fanout='restore_out')
                cl.add(f'flip_out_xor_{i}','xnor',fanin=[f'hd_{i}',f'flip_pop_out_{i}'],fanout='flip_out')

        # flip output: restore_out and flip_out cancel under the correct key
        cl.disconnect(out_driver,out)
        cl.add('out_xor','xor',fanin=['restore_out','flip_out',out_driver],fanout=out)

        return cl, key
def switch()
Expand source code
def switch():
        """Build a keyed 2x2 switch sub-circuit named ``'switch'``.

        Two copies of a 2:1 mux (prefixes ``m0_`` and ``m1_``) are fed in
        opposite order from the buffers ``in_0`` and ``in_1``:

        * ``m0`` gets ``in_0 -> in_0`` and ``in_1 -> in_1``.
        * ``m1`` gets them crossed.

        ``key_0`` drives both selects, so it chooses between the straight and
        crossed routing (assuming the mux forwards ``in_{sel}``). ``key_1`` and
        ``key_2`` are XORed onto ``out_0`` and ``out_1`` respectively.

        NOTE(review): the data ports are ``buf``/``xor`` gates rather than
        ``input``/``output`` nodes, presumably so the result can be embedded
        into a larger circuit; confirm against callers.
        """
        base = mux(2).stripIO()
        sw = Circuit(name='switch')

        # two relabelled mux instances
        for prefix in ('m0', 'm1'):
                sw.extend(base.relabel({node: f'{prefix}_{node}' for node in base.nodes()}))

        # data inputs: straight into m0, crossed into m1
        sw.add('in_0', 'buf', fanout=['m0_in_0', 'm1_in_1'])
        sw.add('in_1', 'buf', fanout=['m0_in_1', 'm1_in_0'])

        # each mux output goes through an XOR (conditional inversion)
        for idx in (0, 1):
                sw.add(f'out_{idx}', 'xor', fanin=f'm{idx}_out')

        # key_0 steers both muxes; key_1 / key_2 invert out_0 / out_1
        sw.add('key_0', 'input', fanout=['m0_sel_0', 'm1_sel_0'])
        for idx in (0, 1):
                sw.add(f'key_{idx + 1}', 'input', fanout=f'out_{idx}')

        return sw
def ttLock(c, w)
Expand source code
def ttLock(c,w):
        """Lock a copy of ``c`` with a TTLock style flip/restore unit.

        A random primary output ``out`` with at least ``w`` startpoints is
        chosen, and ``w`` of its startpoints are sampled.

        * ``flip_out`` is the AND of ``xor(hardcoded_key_i, inp_i)``, so it is 1
          when every sampled input differs from its hard-coded bit (the
          protected pattern is the bitwise complement of the key).
        * ``restore_out`` is the same comparator using new inputs ``key_i``.

        The driver of ``out`` is replaced by
        ``xor(restore_out, flip_out, out_driver)``, so the two cancel under the
        correct key.

        Parameters
        ----------
        c : Circuit
                Circuit to lock; it is copied, not modified.
        w : int
                Number of output startpoints compared.

        Returns
        -------
        (Circuit, dict) or None
                The locked copy and ``{'key_{i}': bool}`` (``w`` bits). Returns
                ``None`` (after printing a message) if no output has at least
                ``w`` startpoints.
        """
        # create copy to lock
        cl = c.copy()

        # find output with large enough fanin. Sorting gives random a
        # deterministic sequence (random.sample rejects sets since Python 3.11).
        potential_outs = [o for o in sorted(cl.outputs()) if len(cl.startpoints(o))>=w]
        if not potential_outs:
                print('input width too small')
                return None
        out = choice(potential_outs)
        out_driver = cl.fanin(out).pop()

        # create key
        key = {f'key_{i}':choice([True,False]) for i in range(w)}

        # connect comparators
        cl.add('flip_out','and')
        cl.add('restore_out','and')
        for i,inp in enumerate(sample(sorted(cl.startpoints(out)),w)):
                cl.add(f'key_{i}','input')
                cl.add(f'hardcoded_key_{i}','1' if key[f'key_{i}'] else '0')
                cl.add(f'restore_xor_{i}','xor',fanin=[f'key_{i}',inp],fanout='restore_out')
                cl.add(f'flip_xor_{i}','xor',fanin=[f'hardcoded_key_{i}',inp],fanout='flip_out')

        # flip output: restore_out and flip_out cancel under the correct key
        cl.disconnect(out_driver,out)
        cl.add('out_xor','xor',fanin=['restore_out','flip_out',out_driver],fanout=out)

        return cl,key
def ttLockSen(params)
Expand source code
def ttLockSen(params):
        """Unimplemented stub.

        ``params`` is ignored and ``None`` is returned.
        TODO(review): intended behavior is not shown in this module.
        """
        return None
def xorLock(c, k)
Expand source code
def xorLock(c,k):
        """Lock a copy of circuit ``c`` by inserting ``k`` XOR/XNOR key gates.

        For each of ``k`` randomly chosen non-output nodes, the node's fanout
        is rerouted through a new gate ``key_gate_{i}``. That gate's other input
        is a new primary input ``key_{i}``. The gate is XNOR when the key bit is
        True and XOR when it is False. Since ``xnor(x, 1) == x`` and
        ``xor(x, 0) == x``, applying the returned key makes every key gate pass
        the original signal through.

        Parameters
        ----------
        c : Circuit
                Circuit to lock; it is copied, not modified.
        k : int
                Number of key gates (and key bits) to insert.

        Returns
        -------
        (Circuit, dict)
                The locked copy and ``{'key_{i}': bool}`` with the correct key.

        Raises
        ------
        ValueError
                From ``random.sample`` if ``k`` exceeds the number of
                non-output nodes.
        """
        # create copy to lock
        cl = c.copy()

        # randomly select gates. random.sample needs a sequence (sets are
        # rejected since Python 3.11). Sorting also makes seeded runs
        # reproducible; this assumes node names are mutually comparable, e.g.
        # all str.
        gates = sample(sorted(cl.nodes()-cl.outputs()),k)

        # insert key gates
        key = {}
        for i,gate in enumerate(gates):
                # select random key value
                key[f'key_{i}'] = choice([True,False])

                # create xor/xnor and its key input between the gate and its fanout
                gate_type = 'xnor' if key[f'key_{i}'] else 'xor'
                fanout = cl.fanout(gate)
                cl.disconnect(gate,fanout)
                cl.add(f'key_gate_{i}',gate_type,fanin=gate,fanout=fanout)
                cl.add(f'key_{i}','input',fanout=f'key_gate_{i}')

        return cl, key