#!/usr/bin/env python
"""Sample code. Edit it as you like!"""
__authors__ = "Olivier Delalleau"
__copyright__ = "(c) 2010, Universite de Montreal"
__license__ = "3-clause BSD License"
__contact__ = "Olivier Delalleau <delallea@iro>"
# Standard library imports are on a single line.
import os, sys, time
# Third-party imports come after standard library imports, and there is
# only one import per line. Imports are sorted lexicographically.
import numpy
import scipy
import theano
# Individual 'from' imports come after packages.
from numpy import argmax
from theano import tensor
# Application-specific imports come last.
# The absolute path should always be used.
from pylearn import datasets, learner
from pylearn.formulas import noise
# All exceptions inherit from Exception.
class PylearnError(Exception):
    """Exception raised by the code of this module.

    TODO Write a more complete doc.
    """
# All top-level classes inherit from object.
class StorageExample(object):
    """Placeholder class: it defines no attribute and no method yet.

    TODO Write doc.
    """
# Two blank lines between definitions of top-level classes and functions.
class AwesomeLearner(learner.Learner):
    """Learner that memorizes the mean of each field of its training set.

    `train` stores, for every field found in the training samples, the mean
    of its values. `test_error` then measures the mean squared error between
    those stored means and the values found in another dataset.
    """

    def __init__(self, print_fields=None):
        """Initialize the learner.

        :param print_fields: List of strings naming the fields whose mean
            over the training set should be printed at the end of training.
            If None, then nothing is printed.
        """
        # Do not forget to call the parent class constructor.
        super(AwesomeLearner, self).__init__()
        # Use None instead of an empty list as default argument to
        # print_fields to avoid issues with mutable default arguments.
        self.print_fields = if_none(print_fields, [])

    # One blank line between method definitions.
    def add_field(self, field):
        """Add `field` to the fields printed at the end of training.

        Nothing is done if `field` is already among them.

        :param field: Name of the field to add.
        """
        # Test if something belongs to a container with `in`, not
        # container-specific methods like `index`.
        if field in self.print_fields:
            # TODO Print a warning and do nothing.
            pass
        else:
            # This is why using [] as default to print_fields in the
            # constructor would have been a bad idea.
            self.print_fields.append(field)

    def train(self, dataset):
        """Store the mean of each field found in the training set.

        The means are stored in the `mean_fields` dictionary (indexed by
        field name), and those of the fields listed in `print_fields` are
        printed.

        :param dataset: Iterable over samples, each sample being a
            dictionary-like object (providing `iteritems`) that maps a field
            name to a value supporting addition and division by a float.
        """
        self.mean_fields = {}
        count = {}
        for sample_dict in dataset:
            # Whenever it is enough for what you need, use iterative
            # instead of list versions of dictionary methods.
            for field, value in sample_dict.iteritems():
                # Keep line length to max 80 characters, using parentheses
                # instead of \ to continue long lines.
                self.mean_fields[field] = (self.mean_fields.get(field, 0) +
                                           value)
                count[field] = count.get(field, 0) + 1
        # At this point mean_fields holds sums: turn them into means.
        for field in self.mean_fields:
            self.mean_fields[field] /= float(count[field])
        for field in self.print_fields:
            # Test is done with `in`, not `has_key`.
            if field in self.mean_fields:
                # TODO Use log module instead.
                print('%s: %s' % (field, self.mean_fields[field]))
            else:
                # TODO Print warning.
                pass

    def test_error(self, dataset):
        """Compute, print and return the mean squared error on `dataset`.

        The error of a value is its squared difference with the mean stored
        for its field by `train`. It is averaged over all the values found
        in `dataset`.

        :param dataset: Iterable over samples, in the same format as the
            one given to `train`.

        :returns: The mean squared error, as the result of a floating point
            division.

        :raises PylearnError: If the learner was not trained, if a sample
            contains a field that was not in the training set, or if no
            value is found in `dataset`.
        """
        # `mean_fields` is only created by `train`.
        if not hasattr(self, 'mean_fields'):
            # Exceptions should be raised as follows (in particular, no
            # string exceptions!).
            raise PylearnError('Cannot test a learner that was not '
                               'trained.')
        error = 0
        count = 0
        for sample_dict in dataset:
            for field, value in sample_dict.iteritems():
                try:
                    # Minimize code into a try statement.
                    mean = self.mean_fields[field]
                # Always specify which kind of exception you are
                # intercepting with except.
                except KeyError:
                    raise PylearnError(
                        "Found in a test sample a field ('%s') that had "
                        "never been seen in the training set." % field)
                error += (value - mean)**2
                count += 1
        if count == 0:
            # Fail with an explicit message rather than a ZeroDivisionError.
            raise PylearnError('Cannot compute the test error on an empty '
                               'dataset.')
        # Remember to divide by a floating point number unless you
        # explicitly want an integer division (in which case you should
        # use //).
        mse = error / float(count)
        # TODO Use log module instead.
        print('MSE: %s' % mse)
        return mse
def if_none(val_if_not_none, val_if_none):
    """Return the first argument, or the second one if the first is None.

    :param val_if_not_none: Value returned as long as it is not None (other
        false values such as 0 or [] are returned as is).
    :param val_if_none: Fallback returned when `val_if_not_none` is None.
    """
    return val_if_none if val_if_not_none is None else val_if_not_none
def print_subdirs_in(directory):
    """Print the sub-directories of `directory`, then recurse into them.

    One line '<directory>: <names>' is printed, where the names of the
    sub-directories are sorted and separated by spaces. The same is then
    done for each of these sub-directories, in sorted order.

    :param directory: Path to the directory to explore.
    """
    # Using a comprehension rather than filter.
    names = sorted(name for name in os.listdir(directory)
                   if os.path.isdir(os.path.join(directory, name)))
    print('%s: %s' % (directory, ' '.join(names)))
    # A `for` loop is often easier to read than a call to `map`.
    for name in names:
        print_subdirs_in(os.path.join(directory, name))
def main():
    """Run the script on the directory given on the command line.

    :returns: 0 after printing the sub-directories of the directory given
        as single command line argument, or 1 after printing the usage
        message when the number of arguments is not exactly one.
    """
    if len(sys.argv) == 2:
        print_subdirs_in(sys.argv[1])
        return 0
    # Note: conventions on how to display script documentation and
    # parse arguments are still to-be-determined. This is just one
    # way to do it.
    usage = """\
Usage: %s <directory>
For the given directory and all sub-directories found inside it, print
the list of the directories they contain.""" % os.path.basename(sys.argv[0])
    print(usage)
    return 1
# Top-level executable code should be minimal.
if __name__ == '__main__':
    # The value returned by main() becomes the exit status of the script.
    exit_status = main()
    sys.exit(exit_status)