# Copyright 2016 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime

import sqlalchemy
import sqlalchemy.ext.declarative
import sqlalchemy.orm
import unittest2

import oauth2client
import oauth2client.client
import oauth2client.contrib.sqlalchemy

Base = sqlalchemy.ext.declarative.declarative_base()


class DummyModel(Base):
    """Minimal ORM model with one credentials column, used as test fixture."""

    __tablename__ = 'dummy'

    id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True)
    # We query against this column rather than `id`, because of ROWID
    # (the integer primary key is assigned by SQLite).
    key = sqlalchemy.Column(sqlalchemy.Integer)
    credentials = sqlalchemy.Column(
        oauth2client.contrib.sqlalchemy.CredentialsType)


class TestSQLAlchemyStorage(unittest2.TestCase):
    """Tests for oauth2client.contrib.sqlalchemy.Storage (get/put/delete)."""

    def setUp(self):
        # Fresh in-memory SQLite database for each test.
        engine = sqlalchemy.create_engine('sqlite://')
        # Registered first so it runs last: cleanups are LIFO, so every
        # session closed via addCleanup is closed before the engine is
        # disposed.
        self.addCleanup(engine.dispose)
        Base.metadata.create_all(engine)
        # NOTE: despite the name this is a session *factory* (sessionmaker);
        # call it to obtain a Session.
        self.session = sqlalchemy.orm.sessionmaker(bind=engine)
        self.credentials = oauth2client.client.OAuth2Credentials(
            access_token='token',
            client_id='client_id',
            client_secret='client_secret',
            refresh_token='refresh_token',
            token_expiry=datetime.datetime.utcnow(),
            token_uri=oauth2client.GOOGLE_TOKEN_URI,
            user_agent='DummyAgent',
        )

    def tearDown(self):
        # Remove any row written under key=1, and always release the session.
        session = self.session()
        try:
            session.query(DummyModel).filter_by(key=1).delete()
            session.commit()
        finally:
            session.close()

    def _new_session(self):
        """Return a new session that is closed automatically after the test."""
        session = self.session()
        self.addCleanup(session.close)
        return session

    def _storage(self, session):
        """Build the Storage under test, bound to DummyModel rows with key=1."""
        return oauth2client.contrib.sqlalchemy.Storage(
            session=session,
            model_class=DummyModel,
            key_name='key',
            key_value=1,
            property_name='credentials',
        )

    def compare_credentials(self, result):
        """Assert that `result` matches self.credentials field by field.

        Fails cleanly (rather than raising AttributeError) if `result` is None.
        """
        self.assertIsNotNone(result)
        self.assertEqual(result.access_token, self.credentials.access_token)
        self.assertEqual(result.client_id, self.credentials.client_id)
        self.assertEqual(result.client_secret, self.credentials.client_secret)
        self.assertEqual(result.refresh_token,
                         self.credentials.refresh_token)
        self.assertEqual(result.token_expiry, self.credentials.token_expiry)
        self.assertEqual(result.token_uri, self.credentials.token_uri)
        self.assertEqual(result.user_agent, self.credentials.user_agent)

    def test_get(self):
        """get() returns None when no row exists, else the stored value."""
        session = self._new_session()
        credentials_storage = self._storage(session)

        self.assertIsNone(credentials_storage.get())

        session.add(DummyModel(
            key=1,
            credentials=self.credentials,
        ))
        session.commit()

        self.compare_credentials(credentials_storage.get())

    def test_put(self):
        """put() persists credentials onto the row selected by key=1."""
        session = self._new_session()
        self._storage(session).put(self.credentials)
        session.commit()

        entity = session.query(DummyModel).filter_by(key=1).first()
        self.assertIsNotNone(entity)
        self.compare_credentials(entity.credentials)

    def test_delete(self):
        """delete() removes the row selected by key=1."""
        session = self._new_session()
        session.add(DummyModel(
            key=1,
            credentials=self.credentials,
        ))
        session.commit()

        query = session.query(DummyModel).filter_by(key=1)
        self.assertIsNotNone(query.first())
        self._storage(session).delete()
        session.commit()
        self.assertIsNone(query.first())