Create CsrfTokenClient and LoginClient by injection, along with a little cleanup (#5491)

This commit is contained in:
Paul Hawke 2024-01-28 21:36:57 -06:00 committed by GitHub
parent 5661e8c332
commit 84ffffbbe7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 73 additions and 81 deletions

View file

@ -64,10 +64,6 @@ public class LoginActivity extends AccountAuthenticatorActivity {
@Inject @Inject
SessionManager sessionManager; SessionManager sessionManager;
@Inject
@Named(NAMED_COMMONS_WIKI_SITE)
WikiSite commonsWikiSite;
@Inject @Inject
@Named("default_preferences") @Named("default_preferences")
JsonKvStore applicationKvStore; JsonKvStore applicationKvStore;
@ -231,13 +227,13 @@ public class LoginActivity extends AccountAuthenticatorActivity {
private void doLogin(String username, String password, String twoFactorCode) { private void doLogin(String username, String password, String twoFactorCode) {
progressDialog.show(); progressDialog.show();
loginToken = ServiceFactory.get(commonsWikiSite, LoginInterface.class).getLoginToken(); loginToken = loginClient.getLoginToken();
loginToken.enqueue( loginToken.enqueue(
new Callback<MwQueryResponse>() { new Callback<MwQueryResponse>() {
@Override @Override
public void onResponse(Call<MwQueryResponse> call, public void onResponse(Call<MwQueryResponse> call,
Response<MwQueryResponse> response) { Response<MwQueryResponse> response) {
loginClient.login(commonsWikiSite, username, password, null, twoFactorCode, loginClient.login(username, password, null, twoFactorCode,
response.body().query().loginToken(), Locale.getDefault().getLanguage(), new LoginCallback() { response.body().query().loginToken(), Locale.getDefault().getLanguage(), new LoginCallback() {
@Override @Override
public void success(@NonNull LoginResult result) { public void success(@NonNull LoginResult result) {

View file

@ -3,9 +3,7 @@ package fr.free.nrw.commons.auth.csrf
import androidx.annotation.VisibleForTesting import androidx.annotation.VisibleForTesting
import fr.free.nrw.commons.auth.SessionManager import fr.free.nrw.commons.auth.SessionManager
import org.wikipedia.AppAdapter import org.wikipedia.AppAdapter
import org.wikipedia.dataclient.ServiceFactory
import org.wikipedia.dataclient.SharedPreferenceCookieManager import org.wikipedia.dataclient.SharedPreferenceCookieManager
import org.wikipedia.dataclient.WikiSite
import org.wikipedia.dataclient.mwapi.MwQueryResponse import org.wikipedia.dataclient.mwapi.MwQueryResponse
import fr.free.nrw.commons.auth.login.LoginClient import fr.free.nrw.commons.auth.login.LoginClient
import fr.free.nrw.commons.auth.login.LoginCallback import fr.free.nrw.commons.auth.login.LoginCallback
@ -19,17 +17,16 @@ import java.util.concurrent.Callable
import java.util.concurrent.Executors.newSingleThreadExecutor import java.util.concurrent.Executors.newSingleThreadExecutor
class CsrfTokenClient( class CsrfTokenClient(
private val csrfWikiSite: WikiSite, private val sessionManager: SessionManager,
private val sessionManager: SessionManager private val csrfTokenInterface: CsrfTokenInterface,
private val loginClient: LoginClient
) { ) {
private var retries = 0 private var retries = 0
private var csrfTokenCall: Call<MwQueryResponse?>? = null private var csrfTokenCall: Call<MwQueryResponse?>? = null
private val loginClient = LoginClient()
@Throws(Throwable::class) @Throws(Throwable::class)
fun getTokenBlocking(): String { fun getTokenBlocking(): String {
var token = "" var token = ""
val service = ServiceFactory.get(csrfWikiSite, CsrfTokenInterface::class.java)
val userName = AppAdapter.get().getUserName() val userName = AppAdapter.get().getUserName()
val password = AppAdapter.get().getPassword() val password = AppAdapter.get().getPassword()
@ -37,13 +34,12 @@ class CsrfTokenClient(
try { try {
if (retry > 0) { if (retry > 0) {
// Log in explicitly // Log in explicitly
LoginClient() loginClient.loginBlocking(userName, password, "")
.loginBlocking(csrfWikiSite, userName, password, "")
} }
// Get CSRFToken response off the main thread. // Get CSRFToken response off the main thread.
val response = newSingleThreadExecutor().submit(Callable { val response = newSingleThreadExecutor().submit(Callable {
service.getCsrfTokenCall().execute() csrfTokenInterface.getCsrfTokenCall().execute()
}).get() }).get()
if (response.body()?.query()?.csrfToken().isNullOrEmpty()) { if (response.body()?.query()?.csrfToken().isNullOrEmpty()) {
@ -114,7 +110,7 @@ class CsrfTokenClient(
login(userName, password, callback) { login(userName, password, callback) {
Timber.i("retrying...") Timber.i("retrying...")
cancel() cancel()
csrfTokenCall = request(ServiceFactory.get(csrfWikiSite, CsrfTokenInterface::class.java), callback) csrfTokenCall = request(csrfTokenInterface, callback)
} }
} else { } else {
callback.failure(caught()) callback.failure(caught())
@ -126,8 +122,7 @@ class CsrfTokenClient(
password: String, password: String,
callback: Callback, callback: Callback,
retryCallback: () -> Unit retryCallback: () -> Unit
) = LoginClient() ) = loginClient.request(username, password, object : LoginCallback {
.request(csrfWikiSite, username, password, object : LoginCallback {
override fun success(loginResult: LoginResult) { override fun success(loginResult: LoginResult) {
if (loginResult.pass) { if (loginResult.pass) {
sessionManager.updateAccount(loginResult) sessionManager.updateAccount(loginResult)

View file

@ -6,8 +6,6 @@ import fr.free.nrw.commons.auth.login.LoginResult.ResetPasswordResult
import fr.free.nrw.commons.wikidata.WikidataConstants.WIKIPEDIA_URL import fr.free.nrw.commons.wikidata.WikidataConstants.WIKIPEDIA_URL
import io.reactivex.android.schedulers.AndroidSchedulers import io.reactivex.android.schedulers.AndroidSchedulers
import io.reactivex.schedulers.Schedulers import io.reactivex.schedulers.Schedulers
import org.wikipedia.dataclient.ServiceFactory
import org.wikipedia.dataclient.WikiSite
import org.wikipedia.dataclient.mwapi.MwQueryResponse import org.wikipedia.dataclient.mwapi.MwQueryResponse
import retrofit2.Call import retrofit2.Call
import retrofit2.Callback import retrofit2.Callback
@ -18,7 +16,7 @@ import java.io.IOException
/** /**
* Responsible for making login related requests to the server. * Responsible for making login related requests to the server.
*/ */
class LoginClient { class LoginClient(private val loginInterface: LoginInterface) {
private var tokenCall: Call<MwQueryResponse?>? = null private var tokenCall: Call<MwQueryResponse?>? = null
private var loginCall: Call<LoginResponse?>? = null private var loginCall: Call<LoginResponse?>? = null
@ -30,14 +28,18 @@ class LoginClient {
*/ */
private var userLanguage = "" private var userLanguage = ""
fun request(wiki: WikiSite, userName: String, password: String, cb: LoginCallback) { fun getLoginToken() = loginInterface.getLoginToken()
fun request(userName: String, password: String, cb: LoginCallback) {
cancel() cancel()
tokenCall = ServiceFactory.get(wiki, LoginInterface::class.java).getLoginToken() tokenCall = getLoginToken()
tokenCall!!.enqueue(object : Callback<MwQueryResponse?> { tokenCall!!.enqueue(object : Callback<MwQueryResponse?> {
override fun onResponse(call: Call<MwQueryResponse?>, response: Response<MwQueryResponse?>) { override fun onResponse(call: Call<MwQueryResponse?>, response: Response<MwQueryResponse?>) {
login(wiki, userName, password, null, null, login(
response.body()!!.query()!!.loginToken(), userLanguage, cb) userName, password, null, null, response.body()!!.query()!!.loginToken(),
userLanguage, cb
)
} }
override fun onFailure(call: Call<MwQueryResponse?>, caught: Throwable) { override fun onFailure(call: Call<MwQueryResponse?>, caught: Throwable) {
@ -50,16 +52,15 @@ class LoginClient {
} }
fun login( fun login(
wiki: WikiSite, userName: String, password: String, retypedPassword: String?, userName: String, password: String, retypedPassword: String?, twoFactorCode: String?,
twoFactorCode: String?, loginToken: String?, userLanguage: String, cb: LoginCallback loginToken: String?, userLanguage: String, cb: LoginCallback
) { ) {
this.userLanguage = userLanguage this.userLanguage = userLanguage
loginCall = if (twoFactorCode.isNullOrEmpty() && retypedPassword.isNullOrEmpty()) { loginCall = if (twoFactorCode.isNullOrEmpty() && retypedPassword.isNullOrEmpty()) {
ServiceFactory.get(wiki, LoginInterface::class.java) loginInterface.postLogIn(userName, password, loginToken, userLanguage, WIKIPEDIA_URL)
.postLogIn(userName, password, loginToken, userLanguage, WIKIPEDIA_URL)
} else { } else {
ServiceFactory.get(wiki, LoginInterface::class.java).postLogIn( loginInterface.postLogIn(
userName, password, retypedPassword, twoFactorCode, loginToken, userLanguage, true userName, password, retypedPassword, twoFactorCode, loginToken, userLanguage, true
) )
} }
@ -69,12 +70,12 @@ class LoginClient {
call: Call<LoginResponse?>, call: Call<LoginResponse?>,
response: Response<LoginResponse?> response: Response<LoginResponse?>
) { ) {
val loginResult = response.body()?.toLoginResult(wiki, password) val loginResult = response.body()?.toLoginResult(password)
if (loginResult != null) { if (loginResult != null) {
if (loginResult.pass && !loginResult.userName.isNullOrEmpty()) { if (loginResult.pass && !loginResult.userName.isNullOrEmpty()) {
// The server could do some transformations on user names, e.g. on some // The server could do some transformations on user names, e.g. on some
// wikis is uppercases the first letter. // wikis is uppercases the first letter.
getExtendedInfo(wiki, loginResult.userName, loginResult, cb) getExtendedInfo(loginResult.userName, loginResult, cb)
} else if ("UI" == loginResult.status) { } else if ("UI" == loginResult.status) {
when (loginResult) { when (loginResult) {
is OAuthResult -> cb.twoFactorPrompt( is OAuthResult -> cb.twoFactorPrompt(
@ -106,25 +107,24 @@ class LoginClient {
} }
@Throws(Throwable::class) @Throws(Throwable::class)
fun loginBlocking(wiki: WikiSite, userName: String, password: String, twoFactorCode: String?) { fun loginBlocking(userName: String, password: String, twoFactorCode: String?) {
val tokenResponse = ServiceFactory.get(wiki, LoginInterface::class.java).getLoginToken().execute() val tokenResponse = getLoginToken().execute()
if (tokenResponse.body()?.query()?.loginToken().isNullOrEmpty()) { if (tokenResponse.body()?.query()?.loginToken().isNullOrEmpty()) {
throw IOException("Unexpected response when getting login token.") throw IOException("Unexpected response when getting login token.")
} }
val loginToken = tokenResponse.body()?.query()?.loginToken() val loginToken = tokenResponse.body()?.query()?.loginToken()
val tempLoginCall = if (twoFactorCode.isNullOrEmpty()) { val tempLoginCall = if (twoFactorCode.isNullOrEmpty()) {
ServiceFactory.get(wiki, LoginInterface::class.java).postLogIn( loginInterface.postLogIn(userName, password, loginToken, userLanguage, WIKIPEDIA_URL)
userName, password, loginToken, userLanguage, WIKIPEDIA_URL)
} else { } else {
ServiceFactory.get(wiki, LoginInterface::class.java).postLogIn( loginInterface.postLogIn(
userName, password, null, twoFactorCode, loginToken, userLanguage, true userName, password, null, twoFactorCode, loginToken, userLanguage, true
) )
} }
val response = tempLoginCall.execute() val response = tempLoginCall.execute()
val loginResponse = response.body() ?: throw IOException("Unexpected response when logging in.") val loginResponse = response.body() ?: throw IOException("Unexpected response when logging in.")
val loginResult = loginResponse.toLoginResult(wiki, password) ?: throw IOException("Unexpected response when logging in.") val loginResult = loginResponse.toLoginResult(password) ?: throw IOException("Unexpected response when logging in.")
if ("UI" == loginResult.status) { if ("UI" == loginResult.status) {
if (loginResult is OAuthResult) { if (loginResult is OAuthResult) {
@ -139,19 +139,14 @@ class LoginClient {
} }
} }
private fun getExtendedInfo( private fun getExtendedInfo(userName: String, loginResult: LoginResult, cb: LoginCallback) =
wiki: WikiSite, userName: String, loginResult: LoginResult, cb: LoginCallback loginInterface.getUserInfo(userName)
) = ServiceFactory.get(wiki, LoginInterface::class.java).getUserInfo(userName)
.subscribeOn(Schedulers.io()).observeOn(AndroidSchedulers.mainThread()) .subscribeOn(Schedulers.io()).observeOn(AndroidSchedulers.mainThread())
.subscribe({ response: MwQueryResponse? -> .subscribe({ response: MwQueryResponse? ->
loginResult.userId = response?.query()?.userInfo()?.id() ?: 0 loginResult.userId = response?.query()?.userInfo()?.id() ?: 0
loginResult.groups = response?.query()?.getUserResponse(userName)?.groups ?: emptySet() loginResult.groups =
response?.query()?.getUserResponse(userName)?.groups ?: emptySet()
cb.success(loginResult) cb.success(loginResult)
Timber.v(
"Found user ID %s for %s",
response?.query()?.userInfo()?.id(),
wiki.subdomain()
)
}, { caught: Throwable -> }, { caught: Throwable ->
Timber.e(caught, "Login succeeded but getting group information failed. ") Timber.e(caught, "Login succeeded but getting group information failed. ")
cb.error(caught) cb.error(caught)

View file

@ -4,7 +4,6 @@ import com.google.gson.annotations.SerializedName
import fr.free.nrw.commons.auth.login.LoginResult.OAuthResult import fr.free.nrw.commons.auth.login.LoginResult.OAuthResult
import fr.free.nrw.commons.auth.login.LoginResult.ResetPasswordResult import fr.free.nrw.commons.auth.login.LoginResult.ResetPasswordResult
import fr.free.nrw.commons.auth.login.LoginResult.Result import fr.free.nrw.commons.auth.login.LoginResult.Result
import org.wikipedia.dataclient.WikiSite
import org.wikipedia.dataclient.mwapi.MwServiceError import org.wikipedia.dataclient.mwapi.MwServiceError
class LoginResponse { class LoginResponse {
@ -14,8 +13,8 @@ class LoginResponse {
@SerializedName("clientlogin") @SerializedName("clientlogin")
private val clientLogin: ClientLogin? = null private val clientLogin: ClientLogin? = null
fun toLoginResult(site: WikiSite, password: String): LoginResult? { fun toLoginResult(password: String): LoginResult? {
return clientLogin?.toLoginResult(site, password) return clientLogin?.toLoginResult(password)
} }
} }
@ -27,15 +26,15 @@ internal class ClientLogin {
@SerializedName("username") @SerializedName("username")
private val userName: String? = null private val userName: String? = null
fun toLoginResult(site: WikiSite, password: String): LoginResult { fun toLoginResult(password: String): LoginResult {
var userMessage = message var userMessage = message
if ("UI" == status) { if ("UI" == status) {
if (requests != null) { if (requests != null) {
for (req in requests) { for (req in requests) {
if ("MediaWiki\\Extension\\OATHAuth\\Auth\\TOTPAuthenticationRequest" == req.id()) { if ("MediaWiki\\Extension\\OATHAuth\\Auth\\TOTPAuthenticationRequest" == req.id()) {
return OAuthResult(site, status, userName, password, message) return OAuthResult(status, userName, password, message)
} else if ("MediaWiki\\Auth\\PasswordAuthenticationRequest" == req.id()) { } else if ("MediaWiki\\Auth\\PasswordAuthenticationRequest" == req.id()) {
return ResetPasswordResult(site, status, userName, password, message) return ResetPasswordResult(status, userName, password, message)
} }
} }
} }
@ -43,7 +42,7 @@ internal class ClientLogin {
//TODO: String resource -- Looks like needed for others in this class too //TODO: String resource -- Looks like needed for others in this class too
userMessage = "An unknown error occurred." userMessage = "An unknown error occurred."
} }
return Result(site, status ?: "", userName, password, userMessage) return Result(status ?: "", userName, password, userMessage)
} }
} }

View file

@ -1,9 +1,6 @@
package fr.free.nrw.commons.auth.login package fr.free.nrw.commons.auth.login
import org.wikipedia.dataclient.WikiSite
sealed class LoginResult( sealed class LoginResult(
val site: WikiSite,
val status: String, val status: String,
val userName: String?, val userName: String?,
val password: String?, val password: String?,
@ -14,26 +11,23 @@ sealed class LoginResult(
val pass: Boolean get() = "PASS" == status val pass: Boolean get() = "PASS" == status
class Result( class Result(
site: WikiSite,
status: String, status: String,
userName: String?, userName: String?,
password: String?, password: String?,
message: String? message: String?
): LoginResult(site, status, userName, password, message) ): LoginResult(status, userName, password, message)
class OAuthResult( class OAuthResult(
site: WikiSite,
status: String, status: String,
userName: String?, userName: String?,
password: String?, password: String?,
message: String? message: String?
) : LoginResult(site, status, userName, password, message) ) : LoginResult(status, userName, password, message)
class ResetPasswordResult( class ResetPasswordResult(
site: WikiSite,
status: String, status: String,
userName: String?, userName: String?,
password: String?, password: String?,
message: String? message: String?
) : LoginResult(site, status, userName, password, message) ) : LoginResult(status, userName, password, message)
} }

View file

@ -11,6 +11,8 @@ import fr.free.nrw.commons.actions.PageEditClient;
import fr.free.nrw.commons.actions.PageEditInterface; import fr.free.nrw.commons.actions.PageEditInterface;
import fr.free.nrw.commons.actions.ThanksInterface; import fr.free.nrw.commons.actions.ThanksInterface;
import fr.free.nrw.commons.auth.SessionManager; import fr.free.nrw.commons.auth.SessionManager;
import fr.free.nrw.commons.auth.csrf.CsrfTokenInterface;
import fr.free.nrw.commons.auth.login.LoginInterface;
import fr.free.nrw.commons.category.CategoryInterface; import fr.free.nrw.commons.category.CategoryInterface;
import fr.free.nrw.commons.explore.depictions.DepictsClient; import fr.free.nrw.commons.explore.depictions.DepictsClient;
import fr.free.nrw.commons.kvstore.JsonKvStore; import fr.free.nrw.commons.kvstore.JsonKvStore;
@ -37,7 +39,6 @@ import okhttp3.OkHttpClient;
import okhttp3.logging.HttpLoggingInterceptor; import okhttp3.logging.HttpLoggingInterceptor;
import okhttp3.logging.HttpLoggingInterceptor.Level; import okhttp3.logging.HttpLoggingInterceptor.Level;
import fr.free.nrw.commons.auth.csrf.CsrfTokenClient; import fr.free.nrw.commons.auth.csrf.CsrfTokenClient;
import org.wikipedia.dataclient.Service;
import org.wikipedia.dataclient.ServiceFactory; import org.wikipedia.dataclient.ServiceFactory;
import org.wikipedia.dataclient.WikiSite; import org.wikipedia.dataclient.WikiSite;
import org.wikipedia.json.GsonUtil; import org.wikipedia.json.GsonUtil;
@ -106,15 +107,27 @@ public class NetworkingModule {
@Named(NAMED_COMMONS_CSRF) @Named(NAMED_COMMONS_CSRF)
@Provides @Provides
@Singleton @Singleton
public CsrfTokenClient provideCommonsCsrfTokenClient( public CsrfTokenClient provideCommonsCsrfTokenClient(SessionManager sessionManager,
@Named(NAMED_COMMONS_WIKI_SITE) WikiSite commonsWikiSite, SessionManager sessionManager) { CsrfTokenInterface tokenInterface, LoginClient loginClient) {
return new CsrfTokenClient(commonsWikiSite, sessionManager); return new CsrfTokenClient(sessionManager, tokenInterface, loginClient);
} }
@Provides @Provides
@Singleton @Singleton
public LoginClient provideLoginClient() { public CsrfTokenInterface provideCsrfTokenInterface(@Named(NAMED_COMMONS_WIKI_SITE) WikiSite commonsWikiSite) {
return new LoginClient(); return ServiceFactory.get(commonsWikiSite, BuildConfig.COMMONS_URL, CsrfTokenInterface.class);
}
@Provides
@Singleton
public LoginInterface provideLoginInterface(@Named(NAMED_COMMONS_WIKI_SITE) WikiSite commonsWikiSite) {
return ServiceFactory.get(commonsWikiSite, BuildConfig.COMMONS_URL, LoginInterface.class);
}
@Provides
@Singleton
public LoginClient provideLoginClient(LoginInterface loginInterface) {
return new LoginClient(loginInterface);
} }
@Provides @Provides

View file

@ -3,6 +3,7 @@ package fr.free.nrw.commons.auth.csrf
import com.google.gson.stream.MalformedJsonException import com.google.gson.stream.MalformedJsonException
import fr.free.nrw.commons.MockWebServerTest import fr.free.nrw.commons.MockWebServerTest
import fr.free.nrw.commons.auth.SessionManager import fr.free.nrw.commons.auth.SessionManager
import fr.free.nrw.commons.auth.login.LoginClient
import org.junit.Test import org.junit.Test
import org.mockito.ArgumentMatchers.any import org.mockito.ArgumentMatchers.any
import org.mockito.ArgumentMatchers.eq import org.mockito.ArgumentMatchers.eq
@ -10,16 +11,15 @@ import org.mockito.ArgumentMatchers.isA
import org.mockito.Mockito.mock import org.mockito.Mockito.mock
import org.mockito.Mockito.never import org.mockito.Mockito.never
import org.mockito.Mockito.verify import org.mockito.Mockito.verify
import org.wikipedia.dataclient.Service
import org.wikipedia.dataclient.WikiSite
import org.wikipedia.dataclient.mwapi.MwException import org.wikipedia.dataclient.mwapi.MwException
import org.wikipedia.dataclient.okhttp.HttpStatusException import org.wikipedia.dataclient.okhttp.HttpStatusException
class CsrfTokenClientTest : MockWebServerTest() { class CsrfTokenClientTest : MockWebServerTest() {
private val wikiSite = WikiSite("test.wikipedia.org")
private val cb = mock(CsrfTokenClient.Callback::class.java) private val cb = mock(CsrfTokenClient.Callback::class.java)
private val sessionManager = mock(SessionManager::class.java) private val sessionManager = mock(SessionManager::class.java)
private val subject = CsrfTokenClient(wikiSite, sessionManager) private val tokenInterface = mock(CsrfTokenInterface::class.java)
private val loginClient = mock(LoginClient::class.java)
private val subject = CsrfTokenClient(sessionManager, tokenInterface, loginClient)
@Test @Test
@Throws(Throwable::class) @Throws(Throwable::class)