コードメモ > ナイーブベイズその1


※上記の広告は60日以上更新のないWIKIに表示されています。更新することで広告が下部へ移動します。

設計もしないままにそのまま書いていったもの。

データベース部分に難があるため、

うまく動作してくれない。

もういちどデータベース作成部分からやりなおし。


おそらく学習用データがまずい。

数が少ないことが。

あとデータベースのラッパとアクセスが地味に辛い。

再度打ち直す必要がある。

learn.py

#coding:utf-8
import MeCab
class Parser:
  def __init__(self):
    option='--node-format=%m,%f[6],%f[0],%h\\n'
    self.mec = MeCab.Tagger(option)
  def parse(self,string):
    parsed = self.mec.parse(string)
    return parsed
  def parse_iter(self,string):
    data = self.parse(string).split('\n')

    for line in data[:-1]:
      if not ',' in line:return
      param = line.split(',')
      if len(param)!=4:continue
      (pure_word,word,word_type,id) = param
      yield (pure_word,word,word_type,int(id))

def test():
  p = Parser()
  print list(p.parse_iter('あいうえおかきく'))
  for data in p.parse_iter('解析テストをしています。'
     '多分うまくいくはず。'
     '素性のIDとは何の事を言っているかイマイチ理解できていない自分。'
     '情けない。単語IDとはまた関係ない様子。'):
    print '%s,%s,%s,%d'%(data[0],data[1],data[2],data[3])

if __name__=='__main__':
  test()

MeCabとパーサ。MeCabについてはもう一度噛んでおきたい。


make_database.py

#coding:utf-8
import os
from optparse import OptionParser

def procfolder():
  def _makefolder(clsname_list):
    if os.access('./text_files',os.F_OK):
      return False
    for name in clsname_list:
      ps = os.path.join('.','text_files',name)
      os.makedirs(ps)
    return True

  import sys
  if len(sys.argv) > 1:
    print '次のクラスのフォルダを作成します'
    for name in sys.argv[1:]:
      print '%s '%name,
    print u'\nよろしいですか[y/n]',
    if 'y' == raw_input():
      _makefolder(sys.argv[1:])
    else:
      print u'フォルダ作成をキャンセルしました'
  else:
    print u'クラス名を入力してください'
    exit()

def create_database():
  import sqlite3
  def create_table(database):
    sql_text = '''
    create table words(
      id integer primary key autoincrement,
      word text,
      type integer
    );'''
    database.execute(sql_text)
  def have_word(word,database):
    sql_text = '''
    select * from words where word='%s';
    '''%(word)
    if database.execute(sql_text).fetchone():
      return True
    else:
      return False
  def addword(word,word_type,database):
    sql_text= '''
    insert into words (word,type) values ('%s',%d);
    '''%(word,word_type)
    database.execute(sql_text)
    print '.',

  def exist_table(database):
    sql_text='''
       select name from sqlite_master
         where type='table' order by name and name='words';
    '''
    if database.execute(sql_text).fetchone():
      return True
    else:
      return False

  print 'creating table:'
  db = sqlite3.connect('worddatabase.db')
  if not exist_table(db):
    create_table(db)
  for word,word_type in enumwords():
    if not have_word(word,db):
      addword(word,word_type,db)
  db.commit() #コミットしないと反映されないようで...
  db.close()
  print 'done.'

def enumwords():
  import os
  from learn import Parser
  root_folder = 'text_files'
  join = os.path.join
  parser = Parser()
  for clsname in os.listdir(join('.',root_folder)):
    for docname in os.listdir(join('.',root_folder,clsname)):
       file_path =join('.',root_folder,clsname,docname)
       try:
         for line in open(file_path):
           for words in parser.parse_iter(line):
             if words[-1]>=10:#品詞
               yield (words[1],words[-1])
       except IOError:
         print 'skip a directory.'

def get_wordid(database,word,type):
  sql_text='''
    select id from words
      where word='%s' and type='%s'
  '''%(word,type)
  return database.execute(sql_text).fetchone()

#未テスト
def get_wordcount():
  sql_text='''
    select count(id) from words;
  '''
  return database.execute(sql_text).fetchone()

def regist_words(database):
  sql_text='''
    select id from words;
  '''
  for data in database.execute(sql_text):
    yield data

'''
def create_doctable():
  import sqlite3
  db = sqlite3.connect('worddatabase.db')
  def create_doctable(database):
    sql_text=create table docment_prob(
         name text
    for word_id in regist_words(db):         
      sql_text += ',id_%d real'%word_id
    sql_text+=');'
    database.execute(sql_text)
  create_doctable(db)
  #バイナリストリームとして保存しても問題ない。
'''

def docprob(filepath):
  from learn import Parser
  import os
  dir,filename = os.path.split(filepath)
  parser = Parser()
  getbin = [0] * get_wordcount()
  for line in open(filepath):
    for word_data in parser.parse_iter(line):
      sql_text = '''
        select id from words
          where word = '%s' and type = '%s';
      '''%(word,type)
  #未テスト2

if __name__ == '__main__':
  #procfolder()
  create_database()
  #create_doctable() 
  #calc_clsprob()

フォルダ作成と学習用データからの単語の集計

clsprob.py

from calcrate import Database
def clsprob():
  import os
  filepath = os.path.join('.','text_files')
  cls_dict={}
  for clsname in os.listdir(filepath):
    filenum=len([file for file in os.listdir(os.path.join(filepath,clsname))
          if file[-4:]=='.mdl'])
    cls_dict[clsname] = filenum
  filenum = sum(cls_dict.values())

  #write
  db = ExtDatabase()
  db.create_clstable()
  for clsname in cls_dict:
    print u'クラス:%s ファイル数:%d'%(clsname,cls_dict[clsname])
    p_c = (cls_dict[clsname]+1)/(float( len(cls_dict) + filenum))
    print u'P(C)= %f'%p_c
    db.add_clsprob(clsname,p_c)
  del db

from calcrate import Database
import sqlite3
class ExtDatabase(Database):
  def __init__(self):
    Database.__init__(self)

  def create_clstable(self):
    sql_text='''
    create table clsprob(
      id integer primary key autoincrement,
      clsname text,
      prob real
    );
    '''
    try:
      self.database.execute(sql_text)
    except sqlite3.OperationalError:
      print 'exitst' 

  def add_clsprob(self,name,value):
    sql_text='''
    insert into clsprob(
      clsname,prob
    ) values ('%s',%f);
    '''%(name,value)
    self.database.execute(sql_text)


if __name__=='__main__':
  clsprob()

クラスの単語データ

P(W | C)P(C)を求める。

calcrate.py

#coding:utf-8
import sqlite3
class Database:
  inst=None
  @classmethod
  def get_instance(cls):
    if not cls.inst:
      cls.inst = Database()
    return cls.inst

  def __init__(self):
    if Database.inst:
      raise Exception('singletonでインスタンスが二回生成された。')
    self.database = sqlite3.connect('worddatabase.db')
    self.__make_wordtable()
  def __make_wordtable(self):
    sql_text='''
      select * from words;
    '''
    self.word_list={}
    for id,word,type in self.database.execute(sql_text):
      self.word_list[(word.encode('utf-8'),type)]=id
  
  def get_wordid(self,word,type):
    if not ( (word,type) in self.word_list):
      return -1
    else:
      return self.word_list[(word,type)]

  def __get_wordid_from(self,word,type):
    sql_text='''
      select id from words
      where word='%s' and type='%s'
    '''%(word,type)
    data = self.database.execute(sql_text).fetchone()
    if data == None:
      return -1
    else:
      return data[0]
  #未テスト
  def __get_wordcount(self):
    sql_text='''
      select count(id) from words;
    '''
    return self.database.execute(sql_text).fetchone()[0]
  def get_wordcount(self):
    return len(self.word_list)

  #ガペコレにまかせていーんでしょうか
  def __del__(self): 
   self.database.commit()
   self.database.close()

def make_model(path):
  import os
  dir,filename = os.path.split(path)
  print 'make %s\'s model'%filename
  db = Database.get_instance()
  datalist = [0] * db.get_wordcount()
  
  #モデル化  
  for word,type in enumword(path):
    pos = db.get_wordid(word,type)
    if pos != -1:
      datalist[pos-1]=1
  
  print 'outfile %s'%filename
  #QQQ:LLだとJSONの
  import struct
  outfile = open(os.path.join(dir,filename+'.mdl'),'w')
  for bin in datalist:
    outfile.write(struct.pack('d',bin))
  outfile.close()

def enumword(file):
  from learn import Parser
  parser=Parser()
  try:
    for line in open(file):
      for data in parser.parse_iter(line):
        #word and type
        yield (data[1],data[-1])
  except IOError:
    print 'directory pass'
def make_probfile():
  import os
  rootdir = 'text_files'
  for cls in os.listdir(os.path.join('.',rootdir)):
    crean_modelfile(os.path.join('.',rootdir,cls))
    for filename in os.listdir(os.path.join('.',rootdir,cls)):
      path = os.path.join('.',rootdir,cls,filename)
      make_model(path)

#未テスト
def crean_modelfile(directry):
  import os
  for file in (f for f in os.listdir(directry) if f[-4:]=='.mdl'):
    filepath = os.path.join(directry,file)
    os.remove(filepath)

if __name__=='__main__':
  make_probfile()

テキストをモデル化して保存する。

これは永続化する必要はないのではないか。

クラスの計算でpythonのデータ形式で計算してデータベースを作ればいいのではないか。

…まあテストを作りやすくするってのはどうすればいいのやら。


一つ思ったこと:

P(C)が結構強く効いてくるので、

訓練文書の数は揃えておいた方がいいのと、

ストップワードの蘇生を増やすべきかと。


分類用

solve.py

#coding:utf-8
def enum_words(io):
  from learn import Parser
  text = ''.join(io.readlines())
  parser = Parser()
  for words in parser.parse_iter(text):
    if words[-1] >= 10:
      yield (words[1],words[-1])

class WordTable:
  def __init__(self):
    import sqlite3
    db = sqlite3.connect('worddatabase.db')
    sql_text='''
      select * from words;
    '''
    self.table={}
    for id,word,type in db.execute(sql_text):
      key = (word,int(type))
      self.table[key] = int(id)
    db.close()

  def get_id(self,word,id):
    key = (word.decode('utf-8'),id)
    if key in self.table:
      return self.table[key]
    else:
      raise Exception('単語が存在しませぬ。')

  def __len__(self):
    return len(self.table)

def test_enum():
  text = raw_input()
  from StringIO import StringIO
  table = WordTable()
  for word,typeid in enum_words(StringIO(text)):
    print word
    print typeid
    print table.get_id(word,typeid)

def makemodel(text):
  table = WordTable()
  wordmodel=[0] * len(table)
  from StringIO import StringIO
  for word,typeid in enum_words(StringIO(text)):
    try:
      id = table.get_id(word,typeid)
    except Exception,inst:
      pass
      #print inst.args
    else:
       wordmodel[id-1] = 1
  
  return wordmodel

def makemodel_show():
  text = raw_input()
  model = makemodel(text)
  for i,prob in enumerate(model):
    print 'id_%d:%d'%(i+1,prob)

class ClsTable:
  def __init__(self,clsname):
    import sqlite3
    db = sqlite3.connect('worddatabase.db')
    sql_text='''
      select * from %s;
    '''%clsname
    self.table = []
    for wordid,prob in db.execute(sql_text):
      self.table.append(prob)

    self.clsprob = db.execute(
      "select prob from clsprob where clsname='%s'"%clsname).fetchone()[0]
    db.close()

  def word_prob(self,id):
    if id == 0 or id > len(self.table):
      raise Exception('out range')

    return self.table[id - 1]

  def cls_prob(self):
    return self.clsprob

def calcurate_prob(textmodel,clsname):
  import math
  clstable = ClsTable(clsname)
  p_dc = 1.0
  for i,word_exist  in enumerate(textmodel):
    if word_exist >= 1:
      p_dc += math.log(clstable.word_prob(i+1))
    else:
      p_dc += math.log((1 - clstable.word_prob(i+1)))

  return p_dc + math.log( clstable.cls_prob())

def getclsname():
  import sqlite3
  db = sqlite3.connect('worddatabase.db')
  columns = db.execute('select clsname from clsprob')
  return set([name[0] for name in columns])
   
if __name__ == '__main__':
  text = raw_input()
  model = makemodel(text)
  clsnames = getclsname()
  for cls in clsnames:
    value = calcurate_prob(model,cls)
    print 'cls %s : value = %d'%(cls,value)