"""Tools for reading writing and manipulating GMTK parameter files
@author: Arthur Kantor
@contact: akantorREMOVE_THIS@uiuc.edu
@copyright: 2008
@license: GPL version 3
@date: 11/18/2008
@version: 0.2
History:
This module is derived from a piece of the SVC library
which is Copyright (C) 2006-2008 by Jan Svec, honza.svec@gmail.com
@see: http://code.google.com/p/extended-hidden-vector-state-parser/
"""
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import numbers
import os.path
import subprocess
from copy import deepcopy
from fnmatch import fnmatchcase
from types import StringTypes
#TODO
#FIXME
#Arthur 11/18/2008
#1) children should not insert themselves into their parents
#2) replace dicts in Workspace and DTs class with NamedObjectCollection
# parentId value that marks a TreeBranch as a leaf (see TreeBranch.isLeaf);
# a leaf stores its TreeLeaf value in ``default``.
_LEAF = -1
class DCPTDict(dict):
    """Dict of mass functions keyed by parent-value tuples.

    ``flat`` yields every probability of every mass function, visiting the
    keys in sorted order (the order DCPT tables are written in).
    """

    def flat(self):
        ordered_keys = sorted(self)
        for parents in ordered_keys:
            for probability in self[parents]:
                yield probability
#class WordFile(file):
#"""Class for reading file wordwise
#"""
#def __init__(self, *args):
#super(WordFile, self).__init__(*args)
#self._stack = []
#def readWords(self):
#while not self._stack:
#line = self.readline()
#if not line:
#break
#line = line.split('%', 1)[0]
#self._stack.extend(line.split())
#while self._stack:
#yield self._stack.pop(0)
class ProbTable(object):
    """Abstract table of conditional probabilities P(vars | parents).

    Subclasses provide storage through ``_initTable`` / ``_eqTables`` and the
    ``_getMassFunction`` / ``_createMassFunction`` / ``_delMassFunction``
    hooks.

    Items are addressed with slice syntax (see ``_convertSlices``): values
    before the colon select the variable, values after it are parent values.
    So ``t[v:p0, p1]`` is a single probability and ``t[:p0, p1]`` is the whole
    mass function for parents ``(p0, p1)``.
    """

    def __init__(self, scard, pcard):
        """
        @param scard: cardinality (or sequence of cardinalities) of the variable(s)
        @param pcard: cardinality (or sequence of cardinalities) of the parents
        """
        super(ProbTable, self).__init__()
        if not issequence(scard):
            scard = [scard]
        if not issequence(pcard):
            pcard = [pcard]
        self._nvars = len(scard)
        self._npars = len(pcard)
        self._cvars = tuple(scard)
        self._cpars = tuple(pcard)
        self._table = self._initTable(scard, pcard)

    def __eq__(self, other):
        for attr in ('_cvars', '_cpars', '_eqTables'):
            if not hasattr(other, attr):
                # Not a probability table.  Return the NotImplemented
                # singleton (not a truthy exception instance) so Python falls
                # back to the reflected comparison / identity.
                return NotImplemented
        return (self._cvars == other._cvars and
                self._cpars == other._cpars and
                self._eqTables(other))

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this method, !=
        # would compare object identity.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def _initTable(self, scard, pcard):
        """Create and return the backing storage (abstract)."""
        raise TypeError('Abstract method called')

    def _eqTables(self, other):
        """Compare backing storage with ``other`` (abstract)."""
        raise TypeError('Abstract method called')

    def getTable(self):
        return self._table
    table = property(getTable)

    def _convertSlices(self, item):
        """Split an index into (variable values, parent values) tuples.

        A single slice separates the two groups: its start (if any) is a
        variable value and its stop is the first parent value.  Items before
        the slice are variable values; items after it are parent values.

        @raise ValueError: for a stepped slice, a slice without stop, or a
            second slice
        """
        var = []
        cond = []
        cur = var
        if not issequence(item):
            item = [item]
        for i in item:
            if isinstance(i, slice):
                if cur is cond:
                    raise ValueError("Malformed probability index")
                if i.step is not None:
                    raise ValueError("Malformed probability index")
                if i.start is not None:
                    var.append(i.start)
                if i.stop is None:
                    raise ValueError("Malformed probability index")
                else:
                    cond.append(i.stop)
                cur = cond
            else:
                cur.append(i)
        return tuple(var), tuple(cond)

    def _getMassFunction(self, parents):
        """Return the mass function for ``parents`` (abstract)."""
        raise TypeError('Abstract method called')

    def _createMassFunction(self, parents):
        """Return a writable mass function for ``parents`` (abstract)."""
        raise TypeError('Abstract method called')

    def _delMassFunction(self, parents):
        """Remove the mass function for ``parents`` (abstract)."""
        raise TypeError('Abstract method called')

    def __getitem__(self, item):
        var, cond = self._convertSlices(item)
        mf = self._getMassFunction(cond)
        if len(var) > 1:
            return mf[var]
        elif len(var) == 1:
            return mf[var[0]]
        else:
            return mf

    def __setitem__(self, item, value):
        var, cond = self._convertSlices(item)
        mf = self._createMassFunction(cond)
        if len(var) > 1:
            mf[var] = value
        elif len(var) == 1:
            mf[var[0]] = value
        elif len(mf) == len(value):
            # dcpt[:1] returns the whole mass function, so dcpt[:1] = [.4, .6]
            # assigns the whole mass function in place.
            mf[:] = value[:]
        else:
            raise ValueError('Cardinality is %d, but attempting to assign a mass function of cardinality %d.' % (len(mf), len(value)))

    def __delitem__(self, item):
        var, cond = self._convertSlices(item)
        if len(var) > 0:
            raise IndexError('You can delete only whole mass function, eg. [:2, 1, 3]')
        self._delMassFunction(cond)

    def getParentCards(self):
        return self._cpars
    parentCards = property(getParentCards)

    def getSelfCards(self):
        return self._cvars
    selfCards = property(getSelfCards)
class DenseTable(ProbTable):
def _initTable(self, scard, pcard):
    """Build a DCPTDict holding an all-zero mass function per parent key.

    Keys come from ``self.possibleParents`` (defined outside this view;
    presumably one tuple per parent-value combination -- confirm).

    @raise ValueError: if more than one self variable is requested
    """
    if len(scard) != 1:
        raise ValueError("DenseTable currently supports only 1D variables")
    width = scard[0]
    table = DCPTDict()
    for parent_key in self.possibleParents:
        table[parent_key] = [0.] * width
    return table
def _getMassFunction(self, parents):
    """Return the stored mass-function list for the parent tuple ``parents``.

    @raise ValueError: if ``parents`` does not have one value per parent
    """
    if len(parents) == self._npars:
        return self._table[parents]
    raise ValueError('Invalid count of parents')
def _eqTables(self, other):
    """Two dense tables match when their mass-function dicts are equal."""
    same = self._table == other._table
    return same
# _initTable allocates a mass function for every parent key up front, so
# "creating" one for writing is the same lookup as reading it.
_createMassFunction = _getMassFunction
def changeParentCards(self, newParentCards):
"""reshape table to have new parents with newParentCards cardinalities.
If new mass functions are added, they are set to 0 vectors """
if not issequence(newParentCards):
newParentCards = [newParentCards]
if 0 in newParentCards:
raise ValueError("cannot have 0 cardinality in parents")
#check if some PMs need to be deleted
toBeDeleted=[ range(newParentCards[p], self.parentCards[p]) for p in range(min(len(newParentCards), len(self.parentCards)))]
toBeDeleted += [ range(self.parentCards[p]) for p in range(len(newParentCards), len(self.parentCards))]
if len(newParentCards)newParentCards[i]:
parCards[i]=newParentCards[i]
#now check if some PMs need to be added
toBeAdded=[range(self.parentCards[p], newParentCards[p]) for p in range(min(len(newParentCards), len(self.parentCards)))]
toBeAdded += [range(newParentCards[p]) for p in range(len(self.parentCards), len(newParentCards))]
for i,p in enumerate(toBeAdded):
parRanges = [range(j) for j in parCards]
parRanges[i]=p
for pmKey in cartezian(*parRanges):
self.table[pmKey] = [0.] * self.selfCards[0]
if parCards[i]' % (self.__class__.__name__, self.name)
@classmethod
def readFromFile(cls, parent, stream):
    """Parse an object of this class from ``stream``.

    The base implementation only signals that a subclass forgot to override
    it.

    @param parent: container the parsed object will be added to
    @type stream: WorkspaceIO
    @param stream: source the object is parsed from
    @raise TypeError: always (abstract method)
    """
    raise TypeError('Abstract method called')
def writeToFile(self, stream):
    """Serialise this object to ``stream``.

    Subclasses must override this; the base version always fails.

    @type stream: WorkspaceIO
    @param stream: destination of the serialised object
    @raise TypeError: always (abstract method)
    """
    raise TypeError('Abstract method called')
def setName(self, name):
    """Store a new object name in ``_name``."""
    self._name = name

def getName(self):
    """Return the object name held in ``_name``."""
    return self._name

# Object name (read/write).
name = property(getName, setName)
def setParent(self, parent):
    """Store the owning container in ``_parent``."""
    self._parent = parent

def getParent(self):
    """Return the owning container held in ``_parent``."""
    return self._parent

# Container this object belongs to (read/write).
parent = property(getParent, setParent)
class NameCollection(_Object, list):
    """Named list of object names.

    On file: the collection name, the number of entries, then one name per
    line.
    """

    def __init__(self, parent, name):
        _Object.__init__(self, parent, name)

    @classmethod
    def readFromFile(cls, parent, stream):
        name = stream.readWord()
        coll = cls(parent, name)
        n_obj = stream.readInt()
        for i in range(n_obj):
            coll.append(stream.readWord())
        return coll

    def writeToFile(self, stream):
        stream.writelnWord(self.name)
        stream.writelnInt(len(self))
        for i in self:
            stream.writelnWord(i)

    def __eq__(self, other):
        """Compare entries only (the collection name is ignored)."""
        try:
            other_items = list(other)
        except TypeError:
            # Not iterable: let Python fall back instead of raising from ==.
            return NotImplemented
        return list(self) == other_items
class TreeLeaf(object):
    """Value stored at a decision-tree leaf.

    Either a constant integer, or a brace-delimited expression such as
    ``{p0 + 1}`` or ``{p0 != 1 && p1}``.  The expression is evaluated with the
    parent values bound to ``p0``, ``p1``, ...  C-style ``!``, ``&&`` and
    ``||`` are translated to Python's ``not``, ``and`` and ``or``.
    """

    def __init__(self, expression,):
        if isinstance(expression, int):
            self.eval = False
            self.expression = expression
            self.sexpression = str(expression)
        elif expression[0] == '{' and expression[-1] == '}':
            self.sexpression = expression
            body = expression[1:-1]
            # Only a '!' that is logical negation becomes ' not '; a '!'
            # belonging to '!=' must survive, otherwise 'p0 != 1' would turn
            # into 'p0  not = 1' (a SyntaxError).
            body = '!='.join(part.replace('!', ' not ')
                             for part in body.split('!='))
            body = body.replace('&&', ' and ').replace('||', ' or ')
            self.eval = True
            self.expression = compile('int(%s)' % body, '', 'eval')
        else:
            self.sexpression = expression
            self.eval = False
            self.expression = int(expression)

    def __call__(self, parents):
        """Return the leaf value for the sequence of parent values."""
        if self.eval:
            ns = {}
            for i, value in enumerate(parents):
                ns['p%d' % i] = value
            # SECURITY: the expression comes from a parameter file and is run
            # with eval(); only load trusted files.
            return eval(self.expression, ns)
        else:
            return self.expression

    def __eq__(self, other):
        if not hasattr(other, 'expression'):
            return NotImplemented
        return self.expression == other.expression

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    @classmethod
    def readFromFile(cls, stream):
        """Read a leaf; a '{' expression may span several words."""
        total = stream.readWord()
        if total[0] == '{':
            while total[-1] != '}':
                total += stream.readWord()
        leaf = cls(total)
        return leaf

    def writeToFile(self, stream, indent=''):
        stream.writeWord(self.sexpression)
class TreeBranch(object):
    """Node of a GMTK decision tree.

    An internal node tests parent ``parentId``: ``_questions`` maps explicit
    parent values to child TreeBranches, and ``default`` is the child used
    for every other value.  A leaf has ``parentId == _LEAF`` and keeps its
    TreeLeaf in ``default``.
    """

    def __init__(self, parent_id, default=0, comment=None):
        super(TreeBranch, self).__init__()
        self.parentId = parent_id
        self._questions = {}
        self.default = default
        self.comment = comment

    def __eq__(self, other):
        for attr in ('parentId', '_questions', 'default'):
            if not hasattr(other, attr):
                return NotImplemented
        # The comment is deliberately not compared.
        return (self.parentId == other.parentId and
                self._questions == other._questions and
                self.default == other.default)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def isLeaf(self):
        return self.parentId == _LEAF

    def vanish(self):
        """Collapse internal nodes without explicit questions, recursively.

        A node whose ``_questions`` is empty is replaced by its default
        subtree.
        """
        while True:
            if self.isLeaf():
                return
            if len(self._questions) == 0:
                self.parentId = self.default.parentId
                self._questions = self.default._questions
                self.default = self.default.default
            else:
                break
        for branch in self._questions.values():
            branch.vanish()
        self.default.vanish()

    def __contains__(self, item):
        return item in self._questions

    def numQuestions(self):
        return len(self._questions)

    def __getitem__(self, item):
        """Return the child for parent value ``item`` (default if unasked).

        @raise TypeError: if ``item`` is not an integer
        """
        # numbers.Integral covers int/long on Python 2, int on Python 3 and
        # numpy integer scalars.
        if not isinstance(item, numbers.Integral):
            raise TypeError('Bad index')
        return self._questions.get(item, self.default)

    def __setitem__(self, item, value):
        if not isinstance(item, numbers.Integral):
            raise TypeError('Bad index')
        self._questions[item] = value

    def append(self, value):
        """Adds answer value to question numQuestions()
        WARNING: No checks are made to make sure that question does not already exist.
        """
        self._questions[self.numQuestions()] = value

    def __delitem__(self, item):
        if not isinstance(item, numbers.Integral):
            raise TypeError('Bad index')
        try:
            del self._questions[item]
        except KeyError:
            raise ValueError("Tree hasn't branch for %r" % item)

    @classmethod
    def readFromFile(cls, stream):
        """Read a (sub)tree.

        Question values may be listed explicitly or as ranges ``a ... b``.

        @raise ValueError: for malformed question lists or a missing
            ``default``
        """
        parent_id = stream.readInt()
        branch = cls(parent_id)
        if parent_id == _LEAF:
            branch.default = TreeLeaf.readFromFile(stream)
            branch.comment = stream.comment
            return branch
        # n_quest counts the explicit questions plus the trailing 'default'.
        n_quest = stream.readInt()
        questions = []
        while len(questions) < n_quest - 1:
            w = stream.readWord()
            if w == '...':
                if not questions:
                    raise ValueError("'...' must follow an explicit question value")
                questions.extend(range(questions[-1] + 1, stream.readInt() + 1))
            else:
                questions.append(int(w))
        question = stream.readWord()
        if question != 'default':
            raise ValueError('Expected string "default", not %r' % question)
        branch.comment = stream.comment
        for q in questions:
            branch[q] = cls.readFromFile(stream)
        branch.default = cls.readFromFile(stream)
        return branch

    def writeToFile(self, stream, indent=''):
        if indent:
            stream.writeWord(indent)
        stream.writeInt(self.parentId)
        if self.isLeaf():
            self.default.writeToFile(stream)
            if self.comment:
                stream.writelnWord(' % ' + self.comment)
            else:
                stream.writeNewLine()
            return
        stream.writeInt(len(self._questions) + 1)
        l = sorted(self._questions)
        # Compress 0, 1, ..., n-1 into '0 ... n-1'.  list() keeps this
        # comparison meaningful on Python 3, where range is not a list.
        if l == list(range(len(l))):
            stream.writeWord(self.makeRange(len(l)))
        else:
            for q in l:
                stream.writeWord(q)
        stream.writeWord('default')
        if self.comment:
            stream.writelnWord(' % ' + self.comment)
        else:
            stream.writeNewLine()
        for i in l:
            self._questions[i].writeToFile(stream, indent + ' ')
        self.default.writeToFile(stream, indent + ' ')

    def makeRange(self, N):
        """Return the GMTK text for the values 0..N-1 ('', '0' or '0 ... N-1')."""
        if N < 0:
            raise ValueError("cannot make range into negative numbers")
        elif N == 0:
            return ''
        elif N == 1:
            return '0'
        else:
            return '0 ... ' + str(N - 1)
class DT(_Object):
    """A named GMTK decision tree mapping parent values to an integer.

    ``dt[parents]`` answers, ``dt[parents] = v`` stores and
    ``del dt[parents]`` removes an explicit branch; ``parents`` must hold
    exactly ``parentCount`` values.
    """

    # Template tree answering 0 for every parent configuration.  __init__
    # copies it because store() edits trees in place.
    NullTree = TreeBranch(_LEAF, TreeLeaf(0))

    def __init__(self, parent, name, parentCount, tree):
        if tree is DT.NullTree:
            # Sharing the class-level template would make store() on one DT
            # silently change every other DT built from NullTree.
            tree = deepcopy(tree)
        self._tree = tree
        self._parentCount = parentCount
        super(DT, self).__init__(parent, name)

    def __eq__(self, other):
        for attr in ('_parentCount', '_tree'):
            if not hasattr(other, attr):
                return NotImplemented
        return (self._parentCount == other._parentCount) and (self._tree == other._tree)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def getTree(self):
        return self._tree
    tree = property(getTree)

    def getParentCount(self):
        return self._parentCount
    parentCount = property(getParentCount)

    @classmethod
    def readFromFile(cls, parent, stream, readDTS=True):
        """Read one DT.

        If the word after the name is not an integer parent count, it is
        taken as a file name and a DTs object is returned instead.
        """
        name = stream.readWord()
        w = stream.readWord()
        try:
            parentCount = int(w)
        except ValueError:
            return DTs(parent, w, readDTS)
        tree = TreeBranch.readFromFile(stream)
        return cls(parent, name, parentCount, tree)

    def writeToFile(self, stream):
        stream.writelnWord(self.name)
        stream.writelnInt(self.parentCount)
        self.tree.writeToFile(stream)

    def __getitem__(self, item):
        return self.answer(item)

    def answer(self, values):
        """Walk the tree with the parent values and evaluate the leaf reached.

        @raise ValueError: if len(values) != parentCount
        """
        if len(values) != self.parentCount:
            raise ValueError('You must supply %d values' % self.parentCount)
        tree = self.tree
        while True:
            if tree.isLeaf():
                return tree.default(values)
            tree = tree[values[tree.parentId]]

    def __setitem__(self, item, value):
        self.store(item, value)

    def store(self, parents, value):
        """Make the tree answer ``value`` for exactly this parent configuration.

        Leaves and default branches are split as needed.  Other
        configurations keep their previous answers.

        @raise ValueError: if len(parents) != parentCount
        """
        if len(parents) != self.parentCount:
            raise ValueError('You must supply %d values' % self.parentCount)
        # Parents not yet tested on the path from the root.  list() so that
        # pop()/remove() also work on Python 3.
        p_indexes = list(range(len(parents)))
        if not isinstance(value, TreeLeaf):
            value = TreeLeaf(value)
        tree = self.tree
        while True:
            if tree.isLeaf():
                if p_indexes:
                    # Start branching in the leaf; its old value becomes the
                    # default for every other value of the new parent.
                    tree.default = deepcopy(tree)
                    tree.parentId = p_indexes.pop(0)
                    continue
                # Every parent is fixed: overwrite the stored value.
                tree.default = value
                break
            p_id = tree.parentId
            p_val = parents[p_id]
            # This node fixes parent p_id, so never branch on it again below
            # (previously a new subtree could re-test the same parent).
            if p_id in p_indexes:
                p_indexes.remove(p_id)
            if p_val in tree:
                tree = tree[p_val]
            elif p_indexes:
                # Insert a new subtree that behaves like the default branch.
                new_tree = TreeBranch(p_indexes.pop(0), deepcopy(tree.default))
                tree[p_val] = new_tree
                tree = new_tree
            else:
                tree[p_val] = TreeBranch(_LEAF, value)
                break

    def __delitem__(self, item):
        self.delete(item)

    def delete(self, values):
        """Remove the explicit branch leading to ``values``' leaf.

        Emptied nodes are then collapsed with vanish().

        @raise ValueError: for a wrong value count, or if the leaf is only
            reachable through default branches
        """
        if len(values) != self.parentCount:
            raise ValueError('You must supply %d values' % self.parentCount)
        tree = self.tree
        old_tree = None
        while True:
            if tree.isLeaf():
                if old_tree is not None:
                    val = values[old_tree.parentId]
                    del old_tree[val]
                    old_tree.vanish()
                    break
                else:
                    raise ValueError("Cannot delete value in default branch of tree")
            else:
                val = values[tree.parentId]
                if val in tree:
                    tree, old_tree = tree[val], tree
                else:
                    tree = tree[val]
                    old_tree = None
class DTs(_Object):
    """A group of decision trees read from a separate file.

    DT.readFromFile returns one of these when the word after a DT name is a
    file name rather than a parent count.  The object is named after the
    file's basename with its extension removed.
    """

    def __init__(self, parent, name, readDTS=True):
        gmtk_name = os.path.basename(name).rsplit('.', 1)[0]
        super(DTs, self).__init__(parent, gmtk_name)
        self._trees = []
        self._readDTS = readDTS
        # Assigning the filename loads the trees when readDTS is true.
        self.setDtsFilename(name)

    def __eq__(self, other):
        if not hasattr(other, '_trees'):
            return NotImplemented
        return (self._trees == other._trees)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def getDtsFilename(self):
        return self._dtsFilename

    def setDtsFilename(self, name):
        self._dtsFilename = name
        if self._readDTS:
            self.readTrees()
    dtsFilename = property(getDtsFilename, setDtsFilename)

    def writeToFile(self, stream):
        stream.writeInt(len(self._trees))
        stream.writelnWord('%Number of DTs')
        for i, tree in enumerate(self._trees):
            stream.writelnInt(i)
            tree.writeToFile(stream)
            stream.writeNewLine()

    def getTrees(self):
        return self._trees
    trees = property(getTrees)

    def discardTrees(self):
        """Drop all loaded trees and remove each one from the parent."""
        trees = self.trees
        parent = self.parent
        while trees:
            t = trees.pop()
            del parent[DT, t.name]

    def readTrees(self):
        """Replace the current trees with those read from ``dtsFilename``.

        @raise ValueError: if the numbered objects in the file are out of order
        """
        self.discardTrees()
        trees = self.trees
        dts_io = self.parent.preprocessFile(self.dtsFilename)
        try:
            nobj = dts_io.readInt()
            for i in range(nobj):
                ri = dts_io.readInt()
                if i != ri:
                    raise ValueError('Invalid object index, read %d, expected %d' % (ri, i))
                trees.append(DT.readFromFile(self.parent, dts_io))
        finally:
            # Close the preprocessed stream even when parsing fails.
            dts_io.close()
class _PMF(_Object):
    """Abstract named probability mass function.

    Subclasses implement ``_initTable`` to allocate ``cardinality`` outcomes
    and ``__len__`` to report it; ``cardinality`` is simply ``len(self)``.
    """

    def __init__(self, parent, name, cardinality):
        super(_PMF, self).__init__(parent, name)
        self._initTable(cardinality)

    def _initTable(self, cardinality):
        # Abstract hook: allocate storage for ``cardinality`` outcomes.
        raise TypeError('Abstract method called')

    def getCardinality(self):
        size = len(self)
        return size

    cardinality = property(getCardinality)
class DPMF(_PMF, list):
    """Dense PMF: a named list holding one probability per outcome.

    On file: name, cardinality, then ``cardinality`` floats.
    """

    def _initTable(self, cardinality):
        # Start from an all-zero distribution of the requested size.
        self[:] = [0] * cardinality

    @classmethod
    def readFromFile(cls, parent, stream):
        name = stream.readWord()
        size = stream.readInt()
        pmf = cls(parent, name, size)
        pmf[:] = [stream.readFloat() for _ in range(size)]
        return pmf

    def writeToFile(self, stream):
        stream.writeWord(self.name)
        stream.writeInt(self.cardinality)
        for probability in self:
            stream.writeFloat(probability)

    def copy(self, parent, name):
        '''return a copy of this object, with new name and parent'''
        duplicate = self.__class__(parent, name, self.cardinality)
        duplicate[:] = list(self)
        return duplicate
#MEAN is exactly the same as DPMF, except that the cardinality variable is replaced by dimensionality
class MEAN(_Object, list):
    """Named vector of floats (a Gaussian mean).

    On file: name, dimensionality, then ``dimensionality`` floats.
    """

    def __init__(self, parent, name, dimensionality):
        _Object.__init__(self, parent, name)
        self._initTable(dimensionality)

    def _initTable(self, dimensionality):
        # Zero vector of the requested length.
        self[:] = [0] * dimensionality

    def getDimensionality(self):
        return len(self)

    dimensionality = property(getDimensionality)

    @classmethod
    def readFromFile(cls, parent, stream):
        name = stream.readWord()
        size = stream.readInt()
        vector = cls(parent, name, size)
        vector[:] = [stream.readFloat() for _ in range(size)]
        return vector

    def writeToFile(self, stream):
        stream.writeWord(self.name)
        stream.writeInt(self.dimensionality)
        for component in self:
            stream.writeFloat(component)

    def copy(self, parent, name):
        '''return a copy of this object, with new name and parent'''
        duplicate = self.__class__(parent, name, self.dimensionality)
        duplicate[:] = list(self)
        return duplicate
#diagonal covariances are also the same as MEAN
class COVAR(MEAN):
    """Diagonal covariance vector; stored, read and written exactly like MEAN."""
#Gaussian Component
class GC(_Object):
    """Gaussian component referring to a mean and a variance object by name.

    On file: dimensionality, type (only 0 is supported), name, mean name,
    variance name.
    """

    def __init__(self, parent, name, dimensionality, meanName, varName):
        super(GC, self).__init__(parent, name)
        self.dimensionality = dimensionality
        self.meanName = meanName
        self.varName = varName

    @classmethod
    def readFromFile(cls, parent, stream):
        dimensionality = stream.readInt()
        gc_type = stream.readInt()
        if gc_type != 0:
            raise ValueError('only GC type 0 is supported, but %d is specified' % gc_type)
        name = stream.readWord()
        mean_name = stream.readWord()
        var_name = stream.readWord()
        return cls(parent, name, dimensionality, mean_name, var_name)

    def writeToFile(self, stream):
        stream.writeInt(self.dimensionality)
        stream.writeInt(0)  # component type 0, the only one supported
        stream.writeWord(self.name)
        stream.writeWord(self.meanName)
        stream.writelnWord(self.varName)
#Mixture of Gaussians
class MG(_Object):
    """Mixture of Gaussians: a weight DPMF name plus one GC name per component.

    On file: dimensionality, name, number of components, weight DPMF name,
    then the component names.
    """

    def __init__(self, parent, name, dimensionality, weightsDpmfName, gcNames):
        super(MG, self).__init__(parent, name)
        self.dimensionality = dimensionality
        self.weightsDpmfName = weightsDpmfName
        self.gcNames = gcNames

    def copy(self, parent, name):
        """Return a tied copy of self with a new parent and name.

        The copy refers to the same weight DPMF and components by name, but
        owns its own ``gcNames`` list, so editing one mixture's component
        list no longer changes the other's.
        """
        return self.__class__(parent, name, self.dimensionality,
                              self.weightsDpmfName, list(self.gcNames))

    @classmethod
    def readFromFile(cls, parent, stream):
        dimensionality = stream.readInt()
        name = stream.readWord()
        numComponents = stream.readInt()
        weightsDpmfName = stream.readWord()
        gcNames = [stream.readWord() for _ in range(numComponents)]
        return cls(parent, name, dimensionality, weightsDpmfName, gcNames)

    def writeToFile(self, stream):
        stream.writeInt(self.dimensionality)
        stream.writeWord(self.name)
        stream.writeInt(len(self.gcNames))
        stream.writelnWord(self.weightsDpmfName)
        for gcName in self.gcNames:
            stream.writeWord(gcName)
class DLINK_MAT(_Object):
    """Placeholder for GMTK DLINK_MAT objects; not supported yet."""

    def __init__(self, parent, name):
        # Construction is deliberately unsupported.
        raise NotImplementedError()

    @classmethod
    def readFromFile(cls, parent, stream):
        # Parsing is deliberately unsupported.
        raise NotImplementedError()
class WEIGHT_MAT(_Object):
    """Placeholder for GMTK WEIGHT_MAT objects; not supported yet."""

    def __init__(self, parent, name, cardinality):
        # Construction is deliberately unsupported.
        raise NotImplementedError()

    @classmethod
    def readFromFile(cls, parent, stream):
        # Parsing is deliberately unsupported.
        raise NotImplementedError()
class GSMG(_Object):
    """Placeholder for GMTK GSMG objects; not supported yet."""

    def __init__(self, parent, name, cardinality):
        # Construction is deliberately unsupported.
        raise NotImplementedError()

    @classmethod
    def readFromFile(cls, parent, stream):
        # Parsing is deliberately unsupported.
        raise NotImplementedError()
class LSMG(_Object):
    """Placeholder for GMTK LSMG objects; not supported yet."""

    def __init__(self, parent, name, cardinality):
        # Construction is deliberately unsupported.
        raise NotImplementedError()

    @classmethod
    def readFromFile(cls, parent, stream):
        # Parsing is deliberately unsupported.
        raise NotImplementedError()
class MSMG(_Object):
    """Placeholder for GMTK MSMG objects; not supported yet."""

    def __init__(self, parent, name, cardinality):
        # Construction is deliberately unsupported.
        raise NotImplementedError()

    @classmethod
    def readFromFile(cls, parent, stream):
        # Parsing is deliberately unsupported.
        raise NotImplementedError()
class SPMF(_PMF):
    """Sparse PMF whose non-zero outcomes point into a DPMF.

    ``_ptrs`` maps outcome index -> position in the DPMF named ``dpmfName``
    (looked up in ``parent``).  Outcomes missing from ``_ptrs`` have
    probability 0.0.
    """

    def __init__(self, parent, name, cardinality, dpmfName):
        super(SPMF, self).__init__(parent, name, cardinality)
        self._dpmfName = dpmfName
        self._ptrs = {}

    def __eq__(self, other):
        for attr in ('_cardinality', '_dpmfName', '_ptrs'):
            if not hasattr(other, attr):
                return NotImplemented
        # Cardinality is part of the written object, so it must match too.
        return (self._cardinality == other._cardinality and
                self._dpmfName == other._dpmfName and
                self._ptrs == other._ptrs)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def _initTable(self, cardinality):
        self._cardinality = cardinality

    def getDpmf(self):
        return self.parent[DPMF, self.dpmfName]
    dpmf = property(getDpmf)

    def getDpmfName(self):
        return self._dpmfName
    dpmfName = property(getDpmfName)

    def getPtrs(self):
        return self._ptrs
    ptrs = property(getPtrs)

    def __len__(self):
        return self._cardinality

    def _checkIndex(self, item):
        """Validate an outcome index and map a negative one into [0, len).

        @raise TypeError: if ``item`` is not an integer
        @raise IndexError: if ``item`` is out of range
        """
        if not isinstance(item, numbers.Integral):
            raise TypeError('Bad index')
        length = len(self)
        if item < 0:
            item += length
        if not (0 <= item < length):
            raise IndexError('Index out of range')
        return item

    def __getitem__(self, item):
        item = self._checkIndex(item)
        if item in self._ptrs:
            return self.dpmf[self._ptrs[item]]
        return 0.0

    def __setitem__(self, item, value):
        """Set an outcome's probability; a new outcome is appended to the DPMF."""
        item = self._checkIndex(item)
        dpmf = self.dpmf
        if item in self._ptrs:
            dpmf[self._ptrs[item]] = value
        else:
            self._ptrs[item] = len(dpmf)
            dpmf.append(value)

    def __delitem__(self, item):
        """Zero an outcome: drop its DPMF entry and shift later pointers down."""
        item = self._checkIndex(item)
        if item not in self._ptrs:
            return
        dpmf = self.dpmf
        ref = self._ptrs.pop(item)
        del dpmf[ref]
        for key, pos in list(self._ptrs.items()):
            if pos > ref:
                self._ptrs[key] = pos - 1

    @classmethod
    def readFromFile(cls, parent, stream):
        """On file: name, cardinality, pointer count, outcome indices, DPMF name."""
        name = stream.readWord()
        cardinality = stream.readInt()
        length = stream.readInt()
        ptrs = {}
        for i in range(length):
            ptrs[stream.readInt()] = i
        dpmfName = stream.readWord()
        spmf = cls(parent, name, cardinality, dpmfName)
        spmf.ptrs.update(ptrs)
        return spmf

    def writeToFile(self, stream):
        stream.writelnWord(self.name)
        stream.writelnInt(self.cardinality)
        # Outcome indices ordered by their DPMF position.
        t = [y[0] for y in sorted(self._ptrs.items(), key=lambda x: x[1])]
        stream.writelnInt(len(t))
        for n in t:
            stream.writeInt(n)
        stream.writeNewLine()
        stream.writelnWord(self.dpmfName)
class _CPT(_Object, ProbTable):
    """Named conditional probability table over one self variable.

    ``parent_cards`` are the parents' cardinalities and ``self_card`` is the
    variable's cardinality.
    """

    def __init__(self, parent, name, parent_cards, self_card):
        super(_CPT, self).__init__(parent, name, [self_card], parent_cards)

    def getSelfCard(self):
        self_cards = self.selfCards
        # CPTs are built with exactly one self cardinality.
        assert len(self_cards) == 1
        return self_cards[0]

    selfCard = property(getSelfCard)
class DCPT(_CPT, DenseTable):
    """Dense CPT with an explicit mass function per parent configuration.

    On file: name, parent count, parent cardinalities, self cardinality, then
    the probabilities, one mass function per line.  Mass functions appear in
    sorted parent-tuple order, the order DCPTDict.flat() uses.
    """

    @classmethod
    def readFromFile(cls, parent, stream):
        name = stream.readWord()
        n_parents = stream.readInt()
        parent_cards = [stream.readInt() for _ in range(n_parents)]
        self_card = stream.readInt()
        dcpt = cls(parent, name, parent_cards, self_card)
        table = dcpt.table
        for key in sorted(table):
            row = table[key]
            for i in range(self_card):
                row[i] = stream.readFloat()
        return dcpt

    def writeToFile(self, stream):
        stream.writelnWord(self.name)
        stream.writeInt(len(self.parentCards))
        for card in self.parentCards:
            stream.writeInt(card)
        stream.writeNewLine()
        self_card = self.selfCard
        stream.writelnInt(self_card)
        for i, val in enumerate(self.table.flat()):
            # Start a new line at the beginning of each mass function.
            if i > 0 and i % self_card == 0:
                stream.writeNewLine()
            stream.writeFloat(val)
        stream.writeNewLine()
class SCPT(_CPT):
    """Sparse CPT: the DT ``dtName`` maps a parent configuration to an index
    into the NameCollection ``collName``, which names the SPMF to use.
    """

    def __init__(self, parent, name, parent_cards, self_card, dt_name, coll_name):
        super(SCPT, self).__init__(parent, name, parent_cards, self_card)
        self._dtName = dt_name
        self._collName = coll_name

    def __eq__(self, other):
        # Check other's attributes before dereferencing them.
        for attr in ('_dtName', '_collName'):
            if not hasattr(other, attr):
                return NotImplemented
        base = super(SCPT, self).__eq__(other)
        if base is NotImplemented:
            return base
        return (bool(base) and
                self._dtName == other._dtName and
                self._collName == other._collName)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def _initTable(self, scard, pcard):
        # No local storage: mass functions live in SPMF objects.
        return None

    def _eqTables(self, other):
        for attr in ('_dtName', '_collName'):
            if not hasattr(other, attr):
                return NotImplemented
        return (self._dtName == other._dtName) and \
            (self._collName == other._collName)

    def getDtName(self):
        return self._dtName
    dtName = property(getDtName)

    def getDt(self):
        return self.parent[DT, self.dtName]
    dt = property(getDt)

    def getCollName(self):
        return self._collName
    collName = property(getCollName)

    def getColl(self):
        return self.parent[NameCollection, self.collName]
    coll = property(getColl)

    @classmethod
    def readFromFile(cls, parent, stream):
        name = stream.readWord()
        n_parents = stream.readInt()
        parent_cards = [stream.readInt() for _ in range(n_parents)]
        self_card = stream.readInt()
        dtName = stream.readWord()
        collName = stream.readWord()
        return cls(parent, name, parent_cards, self_card, dtName, collName)

    def writeToFile(self, stream):
        stream.writelnWord(self.name)
        stream.writeInt(len(self.parentCards))
        for c in self.parentCards:
            stream.writeInt(c)
        stream.writeNewLine()
        stream.writelnInt(self.selfCard)
        stream.writelnWord(self.dtName)
        stream.writelnWord(self.collName)

    @classmethod
    def create(cls, parent, name, parent_cards, self_card):
        """Create an SCPT together with its collection, null DPMF/SPMF and DT.

        Collection entry 0 is the null mass function ``name + '00000'``, and
        the DT initially answers 0 for every parent configuration.
        """
        null_name = name + '00000'
        # These objects are not used again here; presumably their
        # constructors register them with ``parent`` -- confirm in _Object.
        collection = NameCollection(parent, name)
        collection.append(null_name)
        null_dpmf = DPMF(parent, null_name, self_card)
        null_spmf = SPMF(parent, null_name, self_card, null_name)
        # Copy the shared template: DT.store() edits its tree in place.
        dt = DT(parent, name, len(parent_cards), deepcopy(DT.NullTree))
        return cls(parent, name, parent_cards, self_card, name, name)

    def _getMassFunction(self, parents):
        index = self.dt[parents]
        spmf = self.parent[SPMF, self.coll[index]]
        return spmf

    def newMassFunction(self):
        """Create new SPMF (and its DPMF) and register it in collection
        @return: Tuple (index, spmf), where `index` is index in NameCollection and spmf
        is created function.
        """
        index = len(self.coll)
        new_name = '%s%05d' % (self.name, index)
        dpmf = DPMF(self.parent, new_name, 0)
        spmf = SPMF(self.parent, new_name, self.selfCard, dpmf.name)
        self.coll.append(new_name)
        return index, spmf

    def _createMassFunction(self, parents):
        tree_value = self.dt[parents]
        if tree_value != 0:
            return self._getMassFunction(parents)
        # Index 0 is the shared null mass function: give these parents
        # their own SPMF instead of writing into it.
        index, spmf = self.newMassFunction()
        self.dt[parents] = index
        return spmf

    def _delMassFunction(self, parents):
        # Only the DT entry is removed; the SPMF stays in the workspace.
        del self.dt[parents]
class DetCPT(_CPT):
    """Deterministic CPT: the DT ``dtName`` selects the single outcome that
    gets probability 1.0 for each parent configuration.
    """

    def __init__(self, parent, name, parent_cards, self_card, dt_name):
        super(DetCPT, self).__init__(parent, name, parent_cards, self_card)
        self._dtName = dt_name

    def _eqTables(self, other):
        if hasattr(other, '_dtName'):
            return self._dtName == other._dtName
        return NotImplemented

    def getDtName(self):
        return self._dtName

    dtName = property(getDtName)

    def getDt(self):
        return self.parent[DT, self.dtName]

    dt = property(getDt)

    def _initTable(self, scard, pcard):
        # Nothing to store: mass functions are derived from the DT.
        return None

    def _getMassFunction(self, parents):
        chosen = self.dt[parents]
        mass = [0.] * self.selfCard
        in_range = 0 <= chosen < self.selfCard
        if not in_range:
            raise IndexError("DT %r returns value %d, which is out of range [0, %d]" % (self.dtName, chosen, self.selfCard-1))
        mass[chosen] = 1.0
        return mass

    @classmethod
    def readFromFile(cls, parent, stream):
        name = stream.readWord()
        parent_count = stream.readInt()
        cards = [stream.readInt() for _ in range(parent_count)]
        self_card = stream.readInt()
        dt_name = stream.readWord()
        return cls(parent, name, cards, self_card, dt_name)

    def writeToFile(self, stream):
        stream.writelnWord(self.name)
        stream.writeInt(len(self.parentCards))
        for card in self.parentCards:
            stream.writeInt(card)
        stream.writeNewLine()
        stream.writelnInt(self.selfCard)
        stream.writelnWord(self.dtName)
class FeatureDefinition(set):
    """A named feature together with the set of values it may take.

    On file: ``name ( value1 value2 ... )``.  The parentheses may stand
    alone or be attached to the first and last value.
    """

    def __init__(self, name, allowedValues):
        self.name = name
        self.update(allowedValues)

    @classmethod
    def readFromFile(cls, io):
        """Parse one feature definition.

        @raise ValueError: if '(' is missing or preceded by other characters,
            or if ')' is followed by other characters
        """
        name = io.readWord()
        before, paren, word = io.readWord().partition('(')
        if not paren:
            raise ValueError("Expected '(' after feature name on line %d" % io.line)
        if before:
            raise ValueError("Cannot have chars preceding '(' in the same word on line %d" % io.line)
        values = []
        while True:
            if ')' in word:
                lastVal, leftovers = word.split(')', 1)
                if leftovers:
                    raise ValueError("Was expecting a ')' by itself or at the end of a word on line %d" % io.line)
                if lastVal:
                    values.append(lastVal)
                return cls(name, values)
            if word:  # empty when '(' stood alone
                values.append(word)
            word = io.readWord()

    def writeToFile(self, io):
        io.writeWord(self.name)
        io.writeWord('(')
        # Sorted so the output is deterministic (set order is arbitrary).
        for v in sorted(self):
            io.writeWord(v)
        io.writeWord(')')
class FeatureValues(object):
    """A named list of feature values.

    On file: the name followed by the values on the same line.
    """

    def __init__(self, name, featureValues):
        self.featureValues = featureValues
        self.name = name

    @classmethod
    def readFromFile(cls, io):
        name = io.readWord()
        # FIXME: the feature definitions should determine how many words to
        # read; for now read every remaining word on the current line.
        # (File objects provide readline(); readLine() does not exist.)
        featureValues = io.file.readline().split()
        return cls(name, featureValues)

    def writeToFile(self, io):
        io.writeWord(self.name)
        for v in self.featureValues:
            io.writeWord(v)
class Question(set):
    """A named gmtkTie question: a set of values of one feature.

    On file: name, feature, value count, then the values.
    """

    def __init__(self, name, feature, values=()):
        # Immutable default: a shared set() default is a latent aliasing trap.
        self.name = name
        self.feature = feature
        self.update(values)

    @classmethod
    def readFromFile(cls, io):
        name = io.readWord()
        feature = io.readWord()
        # cls (not Question) so subclasses read back as themselves.
        q = cls(name, feature)
        for _ in range(io.readInt()):
            q.add(io.readWord())
        return q

    def writeToFile(self, io):
        io.writeWord(self.name)
        io.writeWord(self.feature)
        io.writeInt(len(self))
        for v in sorted(self):
            io.writeWord(v)
class NamedObjectCollection(dict):
    "A generic named object collection"

    def __init__(self, obj_type):
        # Class whose readFromFile(io) builds each item; the dict maps
        # item.name -> item.
        self.obj_type = obj_type

    def readFromIO(self, io):
        """Read a count followed by that many index-prefixed objects.

        @raise ValueError: if an object index is not the expected one
        """
        count = io.readInt()
        for expected in range(count):
            found = io.readInt()
            if found != expected:
                raise ValueError('Invalid object index, read %d, expected %d' % (found, expected))
            item = self.obj_type.readFromFile(io)
            self[item.name] = item

    def writeToFile(self, io):
        """Write the objects, ordered by name, as a numbered object list."""
        ordered = [self[key] for key in sorted(self)]
        out = NumberedObjectOutStream(io, self.obj_type)
        for item in ordered:
            out.writeObject(item)
        out.finalize()
class ObjectOutStream(object):
    """Base writer that reserves room for an object count and back-fills it.

    The count is unknown until every object is written, so the constructor
    writes a 20-space placeholder and ``finalize`` seeks back and writes the
    count over it.  This assumes ``io.file`` is seekable and the written
    count fits the placeholder -- verify for very large counts.
    """

    def __init__(self, io):
        self._io = io
        self.objCount = 0
        # Start of the placeholder, so finalize() can return to it.
        self._countBuf = io.file.tell()
        io.writeWord(' ' * 20)

    def finalize(self):
        handle = self._io.file
        resume_at = handle.tell()
        handle.seek(self._countBuf)
        self._io.writeInt(self.objCount)
        handle.seek(resume_at)

    def writeComment(self, s):
        self._io.writelnWord(s)
class UnnumberedOutStream(ObjectOutStream):
    """Writes objects back to back, without per-object index numbers."""

    def __init__(self, io):
        super(UnnumberedOutStream, self).__init__(io)
        out = self._io
        out.writelnWord('% number of objects')
        out.writeNewLine()

    def writeObject(self, obj):
        obj.writeToFile(self._io)
        self.objCount = self.objCount + 1
class NumberedObjectOutStream(ObjectOutStream):
    '''An object which writes namedObjects to a stream on the fly,
    without storing them in memory - useful for large object collections'''

    def __init__(self, io, obj_type):
        # A class contributes its __name__; anything else (e.g. a string) is
        # used as given.
        type_label = getattr(obj_type, '__name__', obj_type)
        super(NumberedObjectOutStream, self).__init__(io)
        self._io.writelnWord('%% number of %s objects' % type_label)
        self._io.writeNewLine()

    def writelnString(self, s):
        """Write ``s`` on one line, prefixed with the next index."""
        out = self._io
        out.writeInt(self.objCount)
        out.writelnWord(s)
        self.objCount += 1

    def writeObject(self, obj):
        """Write the next index, then the object, then a blank line."""
        out = self._io
        out.writelnInt(self.objCount)
        obj.writeToFile(out)
        out.writeNewLine()
        out.writeNewLine()
        self.objCount += 1
class NamedObjectList(list):
    "A generic named and ordered object list"

    def __init__(self, obj_type):
        # Class whose readFromFile(io) builds each element.
        self.obj_type = obj_type

    def readFromIO(self, io):
        """Read a count followed by that many index-prefixed objects.

        This differs from readFromFile because readFromFile is a @classMethod.
        Existing elements at the same positions are replaced; new positions
        are appended (the list starts empty, so ``self[i] = obj`` alone
        raised IndexError).

        @raise ValueError: if an object index is not the expected one
        """
        nobj = io.readInt()
        for i in range(nobj):
            ri = io.readInt()
            if i != ri:
                raise ValueError('Invalid object index, read %d, expected %d' % (ri, i))
            obj = self.obj_type.readFromFile(io)
            if i < len(self):
                self[i] = obj
            else:
                self.append(obj)

    def writeToFile(self, io):
        ostr = NumberedObjectOutStream(io, self.obj_type)
        for obj in self:
            ostr.writeObject(obj)
        ostr.finalize()
class FeatureValuesOutStream(NumberedObjectOutStream):
    """Numbered FeatureValues writer preceded by a two-line header.

    The header holds the collection name and the feature-definitions name.
    """

    def __init__(self, io, name, featureDefsName):
        # Header lines come before the count placeholder the base writes.
        io.writelnWord(name + " % name of this feature values collection")
        io.writelnWord(featureDefsName + " % feature definitions name")
        super(FeatureValuesOutStream, self).__init__(io, FeatureValues)
class FeatureValuesCollection(NamedObjectCollection, _Object):
    "A gmtkTie Feature Values collection"
    obj_type = FeatureValues

    def __init__(self, parent, name=None, featureDefsName=None, featureValuesList=None):
        """
        @param parent: owner of this collection (e.g. a Workspace).
        @param name: name of this feature values collection.
        @param featureDefsName: name of the FeatureDefinitionList it refers to.
        @param featureValuesList: optional initial contents, passed to update().
        """
        self.featureDefsName = featureDefsName
        NamedObjectCollection.__init__(self, self.obj_type)
        # Populate only after the base collection has been initialised, and
        # avoid a shared mutable default argument.
        if featureValuesList is not None:
            self.update(featureValuesList)
        _Object.__init__(self, parent, name)

    @classmethod
    def readFromFile(cls, parent, io):
        """Create a collection owned by `parent` and fill it from `io`."""
        coll = cls(parent)
        coll.readFromIO(io)
        return coll

    def readFromIO(self, io):
        """Read the header (name, feature definitions name), then the objects."""
        self.name = io.readWord()
        self.featureDefsName = io.readWord()
        super(FeatureValuesCollection, self).readFromIO(io)

    def getFeatureDefs(self):
        """Look up the referenced FeatureDefinitionList in the parent."""
        return self.parent[FeatureDefinitionList, self.featureDefsName]
    featureDefs = property(getFeatureDefs)

    def writeToFile(self, io):
        """Write the header and all objects in the layout readFromIO expects."""
        ostr = FeatureValuesOutStream(io, self.name, self.featureDefsName)
        for obj in self:
            ostr.writeObject(obj)
        ostr.finalize()
class QuestionCollection(NamedObjectCollection, _Object):
    "A gmtkTie questions collection"
    obj_type = Question

    def __init__(self, parent, name=None, featureDefsName=None, featureValuesName=None, questions=None):
        """
        @param parent: owner of this collection (e.g. a Workspace).
        @param name: name of this questions collection.
        @param featureDefsName: name of the FeatureDefinitionList it refers to.
        @param featureValuesName: name of the FeatureValuesCollection it refers to.
        @param questions: optional initial contents, passed to update().
        """
        self.featureDefsName = featureDefsName
        self.featureValuesName = featureValuesName
        NamedObjectCollection.__init__(self, self.obj_type)
        # Populate only after the base collection has been initialised, and
        # avoid a shared mutable default argument.
        if questions is not None:
            self.update(questions)
        _Object.__init__(self, parent, name)

    @classmethod
    def readFromFile(cls, parent, io):
        """Create a collection owned by `parent` and fill it from `io`."""
        coll = cls(parent)
        coll.readFromIO(io)
        return coll

    def readFromIO(self, io):
        """Read the 3-word header (name, feature defs, feature values), then the objects."""
        self.name = io.readWord()
        self.featureDefsName = io.readWord()
        self.featureValuesName = io.readWord()
        # super() must name this class; QuestionCollection is not a
        # FeatureValuesCollection, so super(FeatureValuesCollection, self) would
        # raise TypeError.
        super(QuestionCollection, self).readFromIO(io)

    def getFeatureDefs(self):
        """Look up the referenced FeatureDefinitionList in the parent."""
        return self.parent[FeatureDefinitionList, self.featureDefsName]
    featureDefs = property(getFeatureDefs)

    def writeToFile(self, io):
        """Write the header lines, then delegate the objects to the base collection."""
        io.writelnWord(self.name + " % name of this questions collection")
        io.writelnWord(self.featureDefsName + " % feature definitions name")
        io.writelnWord(self.featureValuesName + " % feature values name")
        super(QuestionCollection, self).writeToFile(io)

    def validate(self):
        "make sure all the questions are about existing features, and represent valid feature value sets"
        raise NotImplementedError()
class FeatureDefinitionList(NamedObjectList, _Object):
    "A gmtkTie Feature Definition collection"
    obj_type = FeatureDefinition

    def __init__(self, parent, name=None, featureDefList=None):
        """
        @param parent: owner of this list (e.g. a Workspace).
        @param name: name of this feature definitions list.
        @param featureDefList: optional initial FeatureDefinition objects.
        """
        NamedObjectList.__init__(self, self.obj_type)
        _Object.__init__(self, parent, name)
        # None instead of a shared mutable [] default
        if featureDefList is not None:
            self.extend(featureDefList)

    @classmethod
    def readFromFile(cls, parent, io):
        """Create a list owned by `parent` and fill it from `io`."""
        coll = cls(parent)
        coll.readFromIO(io)
        return coll

    def readFromIO(self, io):
        """Read the list name, then the numbered feature definitions."""
        self.name = io.readWord()
        # Delegate to the instance reader NamedObjectList.readFromIO (the
        # per-instance counterpart of the readFromFile classmethod).
        super(FeatureDefinitionList, self).readFromIO(io)

    def writeToFile(self, io):
        """Write the name line, then the numbered feature definitions."""
        io.writelnWord(self.name + " % name of this feature definitions list")
        super(FeatureDefinitionList, self).writeToFile(io)
class Workspace(object):
    """Container for all known GMTK objects, indexed by object type, then by name.

    It also holds collections of collections (e.g. DTs,
    FeatureValuesCollection, FeatureDefinitionList and QuestionCollection).

    Indexing:
        ws[obj_type]           -> dict mapping name -> object of that type
        ws[obj_type, name]     -> a single object
        (obj_type, name) in ws -> membership test
        ws[obj_type] = obj     -> register obj under obj.name (ValueError if taken)
        del ws[obj_type, name] -> unregister (and clear obj.parent)

    FIXME: no effort is made to fix name collisions. So if two FeatureValues
    with the same name belong to different FeatureValuesCollections,
    self[FeatureValues] will have only one of them.
    """
    knownObjects = {
        'NAME_COLLECTION': NameCollection,
        'DPMF': DPMF,
        'MEAN': MEAN,
        'COVAR': COVAR,
        'GC': GC,
        'MG': MG,
        'SPMF': SPMF,
        'DENSE_CPT': DCPT,
        'SPARSE_CPT': SCPT,
        'DETERMINISTIC_CPT': DetCPT,
        'DT': DT,
        '____DTs': DTs,
        'DLINK_MAT' : DLINK_MAT,
        'WEIGHT_MAT' : WEIGHT_MAT,
        'GSMG' : GSMG,
        'LSMG' : LSMG,
        'MSMG' : MSMG,
        '____FeatureValuesCollection' : FeatureValuesCollection, #gmtkTie feature values file: cannot be used in a Master file
        '____FeatureDefinitionList' : FeatureDefinitionList, #gmtkTie feature definitions file: cannot be used in a Master file
        '____QuestionCollection' : QuestionCollection, #gmtkTie questions file: cannot be used in a Master file
    }

    # Types stored in a trainable-parameters file, in the order they are read
    # and written (shared by the read and write methods so they stay in sync).
    _trainableTypes = (DPMF, SPMF, MEAN, COVAR, DLINK_MAT, WEIGHT_MAT, DCPT,
                       GC, MG, GSMG, LSMG, MSMG)

    def __init__(self, cppOptions=None, readDTS=True):
        """
        @param cppOptions: extra command-line arguments for cpp when reading files.
        @param readDTS: forwarded as readDTS to DT.readFromFile.
        """
        super(Workspace, self).__init__()
        # one (initially empty) name -> object dict per known object type
        self._objects = dict((obj_type, {}) for obj_type in self.knownTypes)
        self._cpp = Preprocessor(cppOptions=cppOptions)
        self._readDTS = readDTS

    def __str__(self):
        """Return a listing of every object kind and its sorted (name, object) pairs."""
        # __str__ must return a string; printing and returning None would make
        # str(workspace) raise TypeError.
        lines = ["%d object kinds" % len(self.objects)]
        for obj_type in self.objects:
            lines.append("%s obj_type : %d items" % (obj_type, len(self[obj_type])))
            for entry in sorted(self[obj_type].items()):
                lines.append(str(entry))
        return "\n".join(lines)

    def getKnownTypes(self):
        return self.knownObjects.values()
    knownTypes = property(getKnownTypes)

    def getObjects(self):
        return self._objects
    objects = property(getObjects)

    def preprocessFile(self, fn):
        """Run `fn` through cpp and wrap the output in a WorkspaceIO."""
        return WorkspaceIO(self._cpp.openFile(fn))

    def readMasterFile(self, mstr):
        """Read all objects referenced by the GMTK master file `mstr`.

        Each command is either "<TYPE>_IN_FILE inline" followed by the objects
        or "<TYPE>_IN_FILE <filename> ascii".

        @raise ValueError: on an unknown command or a non-ascii format.
        """
        IN_FILE = '_IN_FILE'
        INLINE = 'inline'
        ASCII = 'ascii'
        mstr_io = self.preprocessFile(mstr)
        # Several parameter types may live in the same file, so keep each
        # file open and continue reading where the previous command stopped.
        file_ios = {}
        try:
            while True:
                try:
                    command = mstr_io.readWord()
                except IOError:
                    break
                if not command.endswith(IN_FILE):
                    raise ValueError('Invalid master file command: %s' % command)
                type_name = command[:-len(IN_FILE)]
                obj_type = self.knownObjects[type_name]
                fn = mstr_io.readWord()
                if fn == INLINE:
                    self.readFromIO(obj_type, mstr_io)
                    continue
                fmt = mstr_io.readWord()
                if fmt != ASCII:
                    raise ValueError('Format of %r not supported: %s' % (fn, fmt))
                file_io = file_ios.get(fn)
                if file_io is None:
                    file_io = self.preprocessFile(fn)
                    file_ios[fn] = file_io
                # the objects live in the referenced file, not in the master file
                self.readFromIO(obj_type, file_io)
        finally:
            for file_io in file_ios.values():
                file_io.close()
            mstr_io.close()

    def writeMasterFile(self, mstr):
        """Write objects to the files named by the "<TYPE>_OUT_FILE <fn> ascii"
        commands of the master file `mstr`.

        @raise ValueError: on an unknown command or a non-ascii format.
        """
        OUT_FILE = '_OUT_FILE'
        ASCII = 'ascii'
        mstr_io = self.preprocessFile(mstr)
        # several types may be written to the same file; open each file once
        file_ios = {}
        try:
            while True:
                try:
                    command = mstr_io.readWord()
                except IOError:
                    break
                if not command.endswith(OUT_FILE):
                    raise ValueError('Invalid master file command: %s' % command)
                type_name = command[:-len(OUT_FILE)]
                obj_type = self.knownObjects[type_name]
                fn = mstr_io.readWord()
                fmt = mstr_io.readWord()
                if fmt != ASCII:
                    raise ValueError('Format of %r not supported: %s' % (fn, fmt))
                file_io = file_ios.get(fn)
                if file_io is None:
                    file_io = WorkspaceIO.withFile(fn, 'w')
                    file_ios[fn] = file_io
                self.writeToIO(obj_type, file_io)
        finally:
            for file_io in file_ios.values():
                file_io.close()
            mstr_io.close()

    def readTrainableParamsFile(self, trainableFile):
        """Read all trainable parameter types, in order, from one file."""
        io = self.preprocessFile(trainableFile)
        try:
            for typ in self._trainableTypes:
                self.readFromIO(typ, io)
        finally:
            io.close()

    def writeTrainableParamsFile(self, trainableFile):
        """Write all trainable parameter types to the file `trainableFile`."""
        io = WorkspaceIO.withFile(trainableFile, 'w')
        try:
            self.writeTrainableParamsIO(io)
        finally:
            io.close()

    def writeTrainableParamsIO(self, io):
        """Write all trainable parameter types, in order, to an open WorkspaceIO."""
        for typ in self._trainableTypes:
            self.writeToIO(typ, io)

    def readFromFile(self, obj_type, filename):
        """Read the objects of `obj_type` from a (cpp-preprocessed) file."""
        f = self.preprocessFile(filename)
        try:
            self.readFromIO(obj_type, f)
        finally:
            f.close()

    def writeToFile(self, obj_type, filename):
        """Write the objects of `obj_type` to the file `filename`."""
        f = WorkspaceIO.withFile(filename, 'w')
        try:
            self.writeToIO(obj_type, f)
        finally:
            f.close()

    def readFromIO(self, obj_type, io):
        """Read a numbered block of `obj_type` objects from `io`.

        @raise ValueError: if an object index does not match its position.
        """
        nobj = io.readInt()
        for i in range(nobj):
            ri = io.readInt()
            if i != ri:
                raise ValueError('Invalid object index, read %d, expected %d' % (ri, i))
            # The returned object is not stored here. Per the module TODO,
            # objects insert themselves into their parent (this workspace).
            if obj_type is DT:
                obj_type.readFromFile(self, io, readDTS=self._readDTS)
            else:
                obj_type.readFromFile(self, io)

    def writeToIO(self, obj_type, io):
        """Write all `obj_type` objects, sorted by name and numbered, to `io`."""
        items = sorted(self[obj_type].items())
        io.writelnWord('% ' + "%s objects" % obj_type.__name__)
        io.writeInt(len(items))
        io.writeNewLine()
        io.writeNewLine()
        for i, (name, obj) in enumerate(items):
            io.writelnInt(i)
            obj.writeToFile(io)
            io.writeNewLine()

    def __contains__(self, key):
        """`key` is an (obj_type, name) pair."""
        obj_type, name = key
        return name in self._objects[obj_type]

    def __getitem__(self, item):
        """ws[obj_type] -> name->object dict; ws[obj_type, name] -> one object."""
        if not issequence(item):
            item = [item]
        if len(item) == 1:
            return self._objects[item[0]]
        elif len(item) == 2:
            return self._objects[item[0]][item[1]]
        # wrap in a 1-tuple so a tuple index is formatted, not unpacked by %
        raise IndexError('Invalid index: %r' % (item,))

    def __setitem__(self, obj_type, value):
        """Register `value` under value.name; a duplicate name raises ValueError."""
        name = value.name
        if (obj_type, name) in self:
            raise ValueError('There is already %s object %r' % (obj_type.__name__, name))
        self._objects[obj_type][name] = value

    def __delitem__(self, key):
        """`key` is an (obj_type, name) pair; the object's parent is cleared."""
        obj_type, name = key
        obj = self._objects[obj_type][name]
        del self._objects[obj_type][name]
        obj.parent = None

    def getObjLike(self, obj_type, mask):
        """Return the `obj_type` objects whose names match the fnmatch `mask`."""
        return [obj for name, obj in self.objects[obj_type].items()
                if fnmatchcase(name, mask)]

    def delObjLike(self, obj_type, mask):
        """Delete the `obj_type` objects whose names match the fnmatch `mask`."""
        # collect first: the dict must not change size while being iterated
        to_del = [name for name in self.objects[obj_type]
                  if fnmatchcase(name, mask)]
        for name in to_del:
            del self[obj_type, name]
class Preprocessor(object):
    """Feeds input files through the C preprocessor (`cpp -P`).

    `cppOptions` is a list of extra command-line arguments inserted between
    `cpp` and `-P <file>`; it defaults to an empty list.
    """

    def __init__(self, cppOptions=None):
        super(Preprocessor, self).__init__()
        self.cppOptions = [] if cppOptions is None else cppOptions

    def createProcess(self, fn):
        """Start cpp on `fn` with its stdout piped back, and return the Popen object."""
        argv = ['cpp'] + self.cppOptions + ['-P', fn]
        return subprocess.Popen(argv, stdout=subprocess.PIPE)

    def openFile(self, fn):
        """Return a readable stream with the preprocessed contents of `fn`.

        NOTE(review): the cpp process is never waited on and its exit status is
        not checked, so a cpp failure presumably shows up only as short/empty
        output on this stream.
        """
        return self.createProcess(fn).stdout
class WorkspaceIO(object):
    """Word-oriented reader/writer for GMTK ASCII parameter files.

    Reading splits each input line at its first '%' (everything after it is a
    comment), then yields the whitespace-separated words of the remainder.
    The comment text of the line currently being read is kept in `comment`.
    Writing separates consecutive values on one line by a single space.
    """

    def __init__(self, fobj):
        super(WorkspaceIO, self).__init__()
        self._file = fobj
        self._line = 0  # counts lines read through this WorkspaceIO object
        self._stack = []
        self._ws = False  # True when the next value on this line needs a leading space
        self._wordsInFile = self.wordsGen()
        # the comment (text after the first '%') on the current input line
        self.comment = ''

    @classmethod
    def withFile(cls, *args, **kwargs):
        """Open a file with open(*args, **kwargs) and wrap it."""
        f = open(*args, **kwargs)
        return cls(f)

    def getFile(self):
        return self._file
    file = property(getFile)

    @property
    def name(self):
        return self._file.name

    def readWord(self):
        """Return the next word.

        @raise IOError: at end of input, including on every later call.
        """
        try:
            return next(self._wordsInFile)
        except StopIteration:
            # the generator already raised IOError once and is now exhausted;
            # keep reporting end of input the same way
            raise IOError('End of file')

    def getLine(self):
        return self._line
    line = property(getLine)

    def wordsGen(self):
        """Yield the words of the input, skipping comments; raise IOError at the end."""
        for line in self._file:
            self._line += 1
            # Split at the first '%' only, so a comment that itself contains
            # '%' is kept whole.
            parts = line.split('%', 1)
            self.comment = parts[1].strip() if len(parts) > 1 else ''
            for w in parts[0].split():
                yield w
        raise IOError('End of file')

    def readInt(self):
        return int(self.readWord())

    def readFloat(self):
        return float(self.readWord())

    def _writeToken(self, text):
        """Write already-formatted `text`, preceded by a space unless it starts a line."""
        if self._ws:
            self._file.write(' ')
        self._file.write(text)
        self._ws = True

    def writeWord(self, w):
        self._writeToken('%s' % w)

    def writelnWord(self, w):
        self.writeWord(w)
        self.writeNewLine()

    def writeInt(self, i):
        self._writeToken('%d' % i)

    def writelnInt(self, w):
        self.writeInt(w)
        self.writeNewLine()

    def writeFloat(self, f):
        # 10 significant digits
        self._writeToken('%.10g' % f)

    def writelnFloat(self, w):
        self.writeFloat(w)
        self.writeNewLine()

    def writeNewLine(self):
        self._file.write('\n')
        self._ws = False

    def close(self):
        self._file.close()
#some utility functions
def cartezian(*vectors):
    """Compute the Cartesian product of the passed sequences.

    Returns a list of tuples. The FIRST vector varies fastest, e.g.
    cartezian([1, 2], 'ab') == [(1, 'a'), (2, 'a'), (1, 'b'), (2, 'b')].
    With no arguments the result is [()] (the empty product) instead of
    raising IndexError.
    """
    result = [()]
    for vec in vectors:
        # new element outermost, so earlier positions cycle fastest
        result = [r + (v,) for v in vec for r in result]
    return result
def issequence(obj):
    """Tell whether `obj` is a sized container other than a string.

    Anything that supports len() counts (lists, tuples, dicts, ...), except
    instances of StringTypes, which are treated as scalars.
    @rtype: bool
    """
    if isinstance(obj, StringTypes):
        return False
    try:
        len(obj)
    except TypeError:
        return False
    return True